Commit 0ad36867 authored by Russ Cox's avatar Russ Cox

context: use fewer goroutines in WithCancel/WithTimeout

If the parent context passed to WithCancel or WithTimeout
is a known context implementation (one created by this package),
we attach the child to the parent by editing data structures directly;
otherwise, for unknown parent implementations, we make a
goroutine that watches for the parent to finish and propagates
the cancellation.

A common problem with this scheme, before this CL, is that
users who write custom context implementations to manage
their value sets cause WithCancel/WithTimeout to start
goroutines that would have not been started before.

This CL changes the way we map a parent context back to the
underlying data structure. Instead of walking up through
known context implementations to reach the *cancelCtx,
we look up parent.Value(&cancelCtxKey) to return the
innermost *cancelCtx, which we use if it matches parent.Done().

This way, a custom context implementation wrapping a
*cancelCtx but not changing Done-ness (and not refusing
to return wrapped keys) will not require a goroutine anymore
in WithCancel/WithTimeout.

For #28728.

Change-Id: Idba2f435c81b19fe38d0dbf308458ca87c7381e9
Reviewed-on: https://go-review.googlesource.com/c/go/+/196521Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent ae024d9d
...@@ -51,6 +51,7 @@ import ( ...@@ -51,6 +51,7 @@ import (
"errors" "errors"
"internal/reflectlite" "internal/reflectlite"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
...@@ -239,11 +240,24 @@ func newCancelCtx(parent Context) cancelCtx { ...@@ -239,11 +240,24 @@ func newCancelCtx(parent Context) cancelCtx {
return cancelCtx{Context: parent} return cancelCtx{Context: parent}
} }
// goroutines counts the number of goroutines ever created; for testing.
var goroutines int32
// propagateCancel arranges for child to be canceled when parent is. // propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) { func propagateCancel(parent Context, child canceler) {
if parent.Done() == nil { done := parent.Done()
if done == nil {
return // parent is never canceled return // parent is never canceled
} }
select {
case <-done:
// parent is already canceled
child.cancel(false, parent.Err())
return
default:
}
if p, ok := parentCancelCtx(parent); ok { if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock() p.mu.Lock()
if p.err != nil { if p.err != nil {
...@@ -257,6 +271,7 @@ func propagateCancel(parent Context, child canceler) { ...@@ -257,6 +271,7 @@ func propagateCancel(parent Context, child canceler) {
} }
p.mu.Unlock() p.mu.Unlock()
} else { } else {
atomic.AddInt32(&goroutines, +1)
go func() { go func() {
select { select {
case <-parent.Done(): case <-parent.Done():
...@@ -267,22 +282,31 @@ func propagateCancel(parent Context, child canceler) { ...@@ -267,22 +282,31 @@ func propagateCancel(parent Context, child canceler) {
} }
} }
// parentCancelCtx follows a chain of parent references until it finds a // &cancelCtxKey is the key that a cancelCtx returns itself for.
// *cancelCtx. This function understands how each of the concrete types in this var cancelCtxKey int
// package represents its parent.
// parentCancelCtx returns the underlying *cancelCtx for parent.
// It does this by looking up parent.Value(&cancelCtxKey) to find
// the innermost enclosing *cancelCtx and then checking whether
// parent.Done() matches that *cancelCtx. (If not, the *cancelCtx
// has been wrapped in a custom implementation providing a
// different done channel, in which case we should not bypass it.)
func parentCancelCtx(parent Context) (*cancelCtx, bool) { func parentCancelCtx(parent Context) (*cancelCtx, bool) {
for { done := parent.Done()
switch c := parent.(type) { if done == closedchan || done == nil {
case *cancelCtx:
return c, true
case *timerCtx:
return &c.cancelCtx, true
case *valueCtx:
parent = c.Context
default:
return nil, false return nil, false
} }
p, ok := parent.Value(&cancelCtxKey).(*cancelCtx)
if !ok {
return nil, false
}
p.mu.Lock()
ok = p.done == done
p.mu.Unlock()
if !ok {
return nil, false
} }
return p, true
} }
// removeChild removes a context from its parent. // removeChild removes a context from its parent.
...@@ -323,6 +347,13 @@ type cancelCtx struct { ...@@ -323,6 +347,13 @@ type cancelCtx struct {
err error // set to non-nil by the first cancel call err error // set to non-nil by the first cancel call
} }
func (c *cancelCtx) Value(key interface{}) interface{} {
if key == &cancelCtxKey {
return c
}
return c.Context.Value(key)
}
func (c *cancelCtx) Done() <-chan struct{} { func (c *cancelCtx) Done() <-chan struct{} {
c.mu.Lock() c.mu.Lock()
if c.done == nil { if c.done == nil {
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
...@@ -21,6 +22,7 @@ type testingT interface { ...@@ -21,6 +22,7 @@ type testingT interface {
Failed() bool Failed() bool
Fatal(args ...interface{}) Fatal(args ...interface{})
Fatalf(format string, args ...interface{}) Fatalf(format string, args ...interface{})
Helper()
Log(args ...interface{}) Log(args ...interface{})
Logf(format string, args ...interface{}) Logf(format string, args ...interface{})
Name() string Name() string
...@@ -401,7 +403,7 @@ func XTestAllocs(t testingT, testingShort func() bool, testingAllocsPerRun func( ...@@ -401,7 +403,7 @@ func XTestAllocs(t testingT, testingShort func() bool, testingAllocsPerRun func(
c, _ := WithTimeout(bg, 15*time.Millisecond) c, _ := WithTimeout(bg, 15*time.Millisecond)
<-c.Done() <-c.Done()
}, },
limit: 8, limit: 12,
gccgoLimit: 15, gccgoLimit: 15,
}, },
{ {
...@@ -648,3 +650,81 @@ func XTestDeadlineExceededSupportsTimeout(t testingT) { ...@@ -648,3 +650,81 @@ func XTestDeadlineExceededSupportsTimeout(t testingT) {
t.Fatal("wrong value for timeout") t.Fatal("wrong value for timeout")
} }
} }
type myCtx struct {
Context
}
type myDoneCtx struct {
Context
}
func (d *myDoneCtx) Done() <-chan struct{} {
c := make(chan struct{})
return c
}
func XTestCustomContextGoroutines(t testingT) {
g := atomic.LoadInt32(&goroutines)
checkNoGoroutine := func() {
t.Helper()
now := atomic.LoadInt32(&goroutines)
if now != g {
t.Fatalf("%d goroutines created", now-g)
}
}
checkCreatedGoroutine := func() {
t.Helper()
now := atomic.LoadInt32(&goroutines)
if now != g+1 {
t.Fatalf("%d goroutines created, want 1", now-g)
}
g = now
}
_, cancel0 := WithCancel(&myDoneCtx{Background()})
cancel0()
checkCreatedGoroutine()
_, cancel0 = WithTimeout(&myDoneCtx{Background()}, 1*time.Hour)
cancel0()
checkCreatedGoroutine()
checkNoGoroutine()
defer checkNoGoroutine()
ctx1, cancel1 := WithCancel(Background())
defer cancel1()
checkNoGoroutine()
ctx2 := &myCtx{ctx1}
ctx3, cancel3 := WithCancel(ctx2)
defer cancel3()
checkNoGoroutine()
_, cancel3b := WithCancel(&myDoneCtx{ctx2})
defer cancel3b()
checkCreatedGoroutine() // ctx1 is not providing Done, must not be used
ctx4, cancel4 := WithTimeout(ctx3, 1*time.Hour)
defer cancel4()
checkNoGoroutine()
ctx5, cancel5 := WithCancel(ctx4)
defer cancel5()
checkNoGoroutine()
cancel5()
checkNoGoroutine()
_, cancel6 := WithTimeout(ctx5, 1*time.Hour)
defer cancel6()
checkNoGoroutine()
// Check applied to cancelled context.
cancel6()
cancel1()
_, cancel7 := WithCancel(ctx5)
defer cancel7()
checkNoGoroutine()
}
...@@ -27,3 +27,4 @@ func TestCancelRemoves(t *testing.T) { XTestCancelRemoves(t) } ...@@ -27,3 +27,4 @@ func TestCancelRemoves(t *testing.T) { XTestCancelRemoves(t) }
func TestWithCancelCanceledParent(t *testing.T) { XTestWithCancelCanceledParent(t) } func TestWithCancelCanceledParent(t *testing.T) { XTestWithCancelCanceledParent(t) }
func TestWithValueChecksKey(t *testing.T) { XTestWithValueChecksKey(t) } func TestWithValueChecksKey(t *testing.T) { XTestWithValueChecksKey(t) }
func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) } func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) }
func TestCustomContextGoroutines(t *testing.T) { XTestCustomContextGoroutines(t) }
...@@ -252,7 +252,7 @@ var pkgDeps = map[string][]string{ ...@@ -252,7 +252,7 @@ var pkgDeps = map[string][]string{
"compress/gzip": {"L4", "compress/flate"}, "compress/gzip": {"L4", "compress/flate"},
"compress/lzw": {"L4"}, "compress/lzw": {"L4"},
"compress/zlib": {"L4", "compress/flate"}, "compress/zlib": {"L4", "compress/flate"},
"context": {"errors", "internal/reflectlite", "sync", "time"}, "context": {"errors", "internal/reflectlite", "sync", "sync/atomic", "time"},
"database/sql": {"L4", "container/list", "context", "database/sql/driver", "database/sql/internal"}, "database/sql": {"L4", "container/list", "context", "database/sql/driver", "database/sql/internal"},
"database/sql/driver": {"L4", "context", "time", "database/sql/internal"}, "database/sql/driver": {"L4", "context", "time", "database/sql/internal"},
"debug/dwarf": {"L4"}, "debug/dwarf": {"L4"},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment