Commit 5d392600 authored by Emmanuel T Odeke's avatar Emmanuel T Odeke Committed by Emmanuel Odeke

net: preserve unexpired context values for LookupIPAddr

To avoid any cancelation of the parent context from affecting
lookupGroup operations, Resolver.LookupIPAddr previously used
an entirely new context created from context.Background().
However, this meant that all the values in the parent context
with which LookupIPAddr was invoked were dropped.

This change provides a custom context implementation
that only preserves values of the parent context by composing
context.Background() and the parent context. It only falls back
to the parent context to perform value lookups if the parent
context has not yet expired.
This context is never canceled, and has no deadlines.

Fixes #28600

Change-Id: If2f570caa26c65bad638b7102c35c79d5e429fea
Reviewed-on: https://go-review.googlesource.com/c/148698
Run-TryBot: Emmanuel Odeke <emm.odeke@gmail.com>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 70e3b1df
......@@ -205,6 +205,33 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
return r.lookupIPAddr(ctx, "ip", host)
}
// onlyValuesCtx is a context that uses an underlying context
// for value lookup if the underlying context hasn't yet expired.
type onlyValuesCtx struct {
context.Context
lookupValues context.Context
}
var _ context.Context = (*onlyValuesCtx)(nil)
// Value performs a lookup if the original context hasn't expired.
func (ovc *onlyValuesCtx) Value(key interface{}) interface{} {
select {
case <-ovc.lookupValues.Done():
return nil
default:
return ovc.lookupValues.Value(key)
}
}
// withUnexpiredValuesPreserved returns a context.Context that only uses lookupCtx
// for its values, otherwise it is never canceled and has no deadline.
// If the lookup context expires, any looked up values will return nil.
// See Issue 28600.
func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context {
return &onlyValuesCtx{Context: context.Background(), lookupValues: lookupCtx}
}
// lookupIPAddr looks up host using the local resolver and particular network.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
......@@ -231,8 +258,9 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP
// We don't want a cancelation of ctx to affect the
// lookupGroup operation. Otherwise if our context gets
// canceled it might cause an error to be returned to a lookup
// using a completely different context.
lookupGroupCtx, lookupGroupCancel := context.WithCancel(context.Background())
// using a completely different context. However we need to preserve
// only the values in context. See Issue 28600.
lookupGroupCtx, lookupGroupCancel := context.WithCancel(withUnexpiredValuesPreserved(ctx))
dnsWaitGroup.Add(1)
ch, called := r.getLookupGroup().DoChan(host, func() (interface{}, error) {
......
......@@ -1034,3 +1034,78 @@ func TestIPVersion(t *testing.T) {
}
}
}
// Issue 28600: The context that is used to lookup ips should always
// preserve the values from the context that was passed into LookupIPAddr.
func TestLookupIPAddrPreservesContextValues(t *testing.T) {
origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
keyValues := []struct {
key, value interface{}
}{
{"key-1", 12},
{384, "value2"},
{new(float64), 137},
}
ctx := context.Background()
for _, kv := range keyValues {
ctx = context.WithValue(ctx, kv.key, kv.value)
}
wantIPs := []IPAddr{
{IP: IPv4(127, 0, 0, 1)},
{IP: IPv6loopback},
}
checkCtxValues := func(ctx_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
for _, kv := range keyValues {
g, w := ctx_.Value(kv.key), kv.value
if !reflect.DeepEqual(g, w) {
t.Errorf("Value lookup:\n\tGot: %v\n\tWant: %v", g, w)
}
}
return wantIPs, nil
}
testHookLookupIP = checkCtxValues
resolvers := []*Resolver{
nil,
new(Resolver),
}
for i, resolver := range resolvers {
gotIPs, err := resolver.LookupIPAddr(ctx, "golang.org")
if err != nil {
t.Errorf("Resolver #%d: unexpected error: %v", i, err)
}
if !reflect.DeepEqual(gotIPs, wantIPs) {
t.Errorf("#%d: mismatched IPAddr results\n\tGot: %v\n\tWant: %v", i, gotIPs, wantIPs)
}
}
}
func TestWithUnexpiredValuesPreserved(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// Insert a value into it.
key, value := "key-1", 2
ctx = context.WithValue(ctx, key, value)
// Now use the "values preserving context" like
// we would for LookupIPAddr. See Issue 28600.
ctx = withUnexpiredValuesPreserved(ctx)
// Lookup before expiry.
if g, w := ctx.Value(key), value; g != w {
t.Errorf("Lookup before expiry: Got %v Want %v", g, w)
}
// Cancel the context.
cancel()
// Lookup after expiry should return nil
if g := ctx.Value(key); g != nil {
t.Errorf("Lookup after expiry: Got %v want nil", g)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment