Commit 03a48ebe authored by Dmitry Vyukov's avatar Dmitry Vyukov Committed by Russ Cox

sync: simplify WaitGroup

A comment in waitgroup.go describes the following scenario
as the reason to have dynamically created semaphores:

// G1: Add(1)
// G1: go G2()
// G1: Wait() // Context switch after Unlock() and before Semacquire().
// G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet.
// G3: Wait() // Finds counter == 0, waiters == 0, doesn't block.
// G3: Add(1) // Makes counter == 1, waiters == 0.
// G3: go G4()
// G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug.

However, the scenario is incorrect:
G3: Add(1) happens concurrently with G1: Wait(),
and so there is no reasonable behavior of the program
(G1: Wait() may or may not wait for G3: Add(1) which
can't be the intended behavior).

With this conclusion we can:
1. Remove dynamic allocation of semaphores.
2. Remove the mutex entirely and instead pack counter and waiters
   into single uint64.

This makes the logic significantly simpler, both Add and Wait
do only a single atomic RMW to update the state.

benchmark                            old ns/op     new ns/op     delta
BenchmarkWaitGroupUncontended        30.6          32.7          +6.86%
BenchmarkWaitGroupActuallyWait       722           595           -17.59%
BenchmarkWaitGroupActuallyWait-2     396           319           -19.44%
BenchmarkWaitGroupActuallyWait-4     224           183           -18.30%
BenchmarkWaitGroupActuallyWait-8     134           106           -20.90%

benchmark                          old allocs     new allocs     delta
BenchmarkWaitGroupActuallyWait     2              1              -50.00%

benchmark                          old bytes     new bytes     delta
BenchmarkWaitGroupActuallyWait     48            16            -66.67%

Change-Id: I28911f3243aa16544e99ac8f1f5af31944c7ea3a
Reviewed-on: https://go-review.googlesource.com/4117
Run-TryBot: Dmitry Vyukov <dvyukov@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
Reviewed-by: default avatarRuss Cox <rsc@golang.org>
parent 450988b5
...@@ -15,23 +15,21 @@ import ( ...@@ -15,23 +15,21 @@ import (
// runs and calls Done when finished. At the same time, // runs and calls Done when finished. At the same time,
// Wait can be used to block until all goroutines have finished. // Wait can be used to block until all goroutines have finished.
type WaitGroup struct { type WaitGroup struct {
m Mutex // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
counter int32 // 64-bit atomic operations require 64-bit alignment, but 32-bit
waiters int32 // compilers do not ensure it. So we allocate 12 bytes and then use
sema *uint32 // the aligned 8 bytes in them as state.
state1 [12]byte
sema uint32
} }
// WaitGroup creates a new semaphore each time the old semaphore func (wg *WaitGroup) state() *uint64 {
// is released. This is to avoid the following race: if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
// return (*uint64)(unsafe.Pointer(&wg.state1))
// G1: Add(1) } else {
// G1: go G2() return (*uint64)(unsafe.Pointer(&wg.state1[4]))
// G1: Wait() // Context switch after Unlock() and before Semacquire(). }
// G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet. }
// G3: Wait() // Finds counter == 0, waiters == 0, doesn't block.
// G3: Add(1) // Makes counter == 1, waiters == 0.
// G3: go G4()
// G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug.
// Add adds delta, which may be negative, to the WaitGroup counter. // Add adds delta, which may be negative, to the WaitGroup counter.
// If the counter becomes zero, all goroutines blocked on Wait are released. // If the counter becomes zero, all goroutines blocked on Wait are released.
...@@ -43,10 +41,13 @@ type WaitGroup struct { ...@@ -43,10 +41,13 @@ type WaitGroup struct {
// at any time. // at any time.
// Typically this means the calls to Add should execute before the statement // Typically this means the calls to Add should execute before the statement
// creating the goroutine or other event to be waited for. // creating the goroutine or other event to be waited for.
// If a WaitGroup is reused to wait for several independent sets of events,
// new Add calls must happen after all previous Wait calls have returned.
// See the WaitGroup example. // See the WaitGroup example.
func (wg *WaitGroup) Add(delta int) { func (wg *WaitGroup) Add(delta int) {
statep := wg.state()
if raceenabled { if raceenabled {
_ = wg.m.state // trigger nil deref early _ = *statep // trigger nil deref early
if delta < 0 { if delta < 0 {
// Synchronize decrements with Wait. // Synchronize decrements with Wait.
raceReleaseMerge(unsafe.Pointer(wg)) raceReleaseMerge(unsafe.Pointer(wg))
...@@ -54,7 +55,9 @@ func (wg *WaitGroup) Add(delta int) { ...@@ -54,7 +55,9 @@ func (wg *WaitGroup) Add(delta int) {
raceDisable() raceDisable()
defer raceEnable() defer raceEnable()
} }
v := atomic.AddInt32(&wg.counter, int32(delta)) state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32)
w := uint32(state)
if raceenabled { if raceenabled {
if delta > 0 && v == int32(delta) { if delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait. // The first increment must be synchronized with Wait.
...@@ -66,18 +69,25 @@ func (wg *WaitGroup) Add(delta int) { ...@@ -66,18 +69,25 @@ func (wg *WaitGroup) Add(delta int) {
if v < 0 { if v < 0 {
panic("sync: negative WaitGroup counter") panic("sync: negative WaitGroup counter")
} }
if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 { if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return return
} }
wg.m.Lock() // This goroutine has set counter to 0 when waiters > 0.
if atomic.LoadInt32(&wg.counter) == 0 { // Now there can't be concurrent mutations of state:
for i := int32(0); i < wg.waiters; i++ { // - Adds must not happen concurrently with Wait,
runtime_Semrelease(wg.sema) // - Wait does not increment waiters if it sees counter == 0.
} // Still do a cheap sanity check to detect WaitGroup misuse.
wg.waiters = 0 if *statep != state {
wg.sema = nil panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
*statep = 0
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema)
} }
wg.m.Unlock()
} }
// Done decrements the WaitGroup counter. // Done decrements the WaitGroup counter.
...@@ -87,51 +97,41 @@ func (wg *WaitGroup) Done() { ...@@ -87,51 +97,41 @@ func (wg *WaitGroup) Done() {
// Wait blocks until the WaitGroup counter is zero. // Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() { func (wg *WaitGroup) Wait() {
statep := wg.state()
if raceenabled { if raceenabled {
_ = wg.m.state // trigger nil deref early _ = *statep // trigger nil deref early
raceDisable() raceDisable()
} }
if atomic.LoadInt32(&wg.counter) == 0 { for {
if raceenabled { state := atomic.LoadUint64(statep)
raceEnable() v := int32(state >> 32)
raceAcquire(unsafe.Pointer(wg)) w := uint32(state)
} if v == 0 {
return // Counter is 0, no need to wait.
} if raceenabled {
wg.m.Lock() raceEnable()
w := atomic.AddInt32(&wg.waiters, 1) raceAcquire(unsafe.Pointer(wg))
// This code is racing with the unlocked path in Add above. }
// The code above modifies counter and then reads waiters. return
// We must modify waiters and then read counter (the opposite order)
// to avoid missing an Add.
if atomic.LoadInt32(&wg.counter) == 0 {
atomic.AddInt32(&wg.waiters, -1)
if raceenabled {
raceEnable()
raceAcquire(unsafe.Pointer(wg))
raceDisable()
} }
wg.m.Unlock() // Increment waiters count.
if raceenabled { if atomic.CompareAndSwapUint64(statep, state, state+1) {
raceEnable() if raceenabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
raceWrite(unsafe.Pointer(&wg.sema))
}
runtime_Semacquire(&wg.sema)
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if raceenabled {
raceEnable()
raceAcquire(unsafe.Pointer(wg))
}
return
} }
return
}
if raceenabled && w == 1 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
raceWrite(unsafe.Pointer(&wg.sema))
}
if wg.sema == nil {
wg.sema = new(uint32)
}
s := wg.sema
wg.m.Unlock()
runtime_Semacquire(s)
if raceenabled {
raceEnable()
raceAcquire(unsafe.Pointer(wg))
} }
} }
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package sync_test package sync_test
import ( import (
"runtime"
. "sync" . "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
...@@ -60,6 +61,90 @@ func TestWaitGroupMisuse(t *testing.T) { ...@@ -60,6 +61,90 @@ func TestWaitGroupMisuse(t *testing.T) {
t.Fatal("Should panic") t.Fatal("Should panic")
} }
func TestWaitGroupMisuse2(t *testing.T) {
if runtime.NumCPU() <= 2 {
t.Skip("NumCPU<=2, skipping: this test requires parallelism")
}
defer func() {
err := recover()
if err != "sync: negative WaitGroup counter" &&
err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
err != "sync: WaitGroup is reused before previous Wait has returned" {
t.Fatalf("Unexpected panic: %#v", err)
}
}()
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
done := make(chan interface{}, 2)
// The detection is opportunistically, so we want it to panic
// at least in one run out of a million.
for i := 0; i < 1e6; i++ {
var wg WaitGroup
wg.Add(1)
go func() {
defer func() {
done <- recover()
}()
wg.Wait()
}()
go func() {
defer func() {
done <- recover()
}()
wg.Add(1) // This is the bad guy.
wg.Done()
}()
wg.Done()
for j := 0; j < 2; j++ {
if err := <-done; err != nil {
panic(err)
}
}
}
t.Fatal("Should panic")
}
func TestWaitGroupMisuse3(t *testing.T) {
if runtime.NumCPU() <= 1 {
t.Skip("NumCPU==1, skipping: this test requires parallelism")
}
defer func() {
err := recover()
if err != "sync: negative WaitGroup counter" &&
err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
err != "sync: WaitGroup is reused before previous Wait has returned" {
t.Fatalf("Unexpected panic: %#v", err)
}
}()
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
done := make(chan interface{}, 1)
// The detection is opportunistically, so we want it to panic
// at least in one run out of a million.
for i := 0; i < 1e6; i++ {
var wg WaitGroup
wg.Add(1)
go func() {
wg.Done()
}()
go func() {
defer func() {
done <- recover()
}()
wg.Wait()
// Start reusing the wg before waiting for the Wait below to return.
wg.Add(1)
go func() {
wg.Done()
}()
wg.Wait()
}()
wg.Wait()
if err := <-done; err != nil {
panic(err)
}
}
t.Fatal("Should panic")
}
func TestWaitGroupRace(t *testing.T) { func TestWaitGroupRace(t *testing.T) {
// Run this test for about 1ms. // Run this test for about 1ms.
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
...@@ -85,6 +170,19 @@ func TestWaitGroupRace(t *testing.T) { ...@@ -85,6 +170,19 @@ func TestWaitGroupRace(t *testing.T) {
} }
} }
func TestWaitGroupAlign(t *testing.T) {
type X struct {
x byte
wg WaitGroup
}
var x X
x.wg.Add(1)
go func(x *X) {
x.wg.Done()
}(&x)
x.wg.Wait()
}
func BenchmarkWaitGroupUncontended(b *testing.B) { func BenchmarkWaitGroupUncontended(b *testing.B) {
type PaddedWaitGroup struct { type PaddedWaitGroup struct {
WaitGroup WaitGroup
...@@ -146,3 +244,17 @@ func BenchmarkWaitGroupWait(b *testing.B) { ...@@ -146,3 +244,17 @@ func BenchmarkWaitGroupWait(b *testing.B) {
func BenchmarkWaitGroupWaitWork(b *testing.B) { func BenchmarkWaitGroupWaitWork(b *testing.B) {
benchmarkWaitGroupWait(b, 100) benchmarkWaitGroupWait(b, 100)
} }
func BenchmarkWaitGroupActuallyWait(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var wg WaitGroup
wg.Add(1)
go func() {
wg.Done()
}()
wg.Wait()
}
})
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment