Commit b6b4004d authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: context plumbing, add Dialer.DialContext

For #12580 (http.Transport tracing/analytics)
Updates #13021

Change-Id: I126e494a7bd872e42c388ecb58499ecbf0f014cc
Reviewed-on: https://go-review.googlesource.com/22101
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
Reviewed-by: default avatarMikio Hara <mikioh.mikioh@gmail.com>
parent 1d0977a1
......@@ -280,7 +280,9 @@ var pkgDeps = map[string][]string{
// Basic networking.
// Because net must be used by any package that wants to
// do networking portably, it must have a small dependency set: just L0+basic os.
"net": {"L0", "CGO", "math/rand", "os", "sort", "syscall", "time", "internal/syscall/windows", "internal/singleflight", "internal/race"},
"net": {"L0", "CGO",
"context", "math/rand", "os", "sort", "syscall", "time",
"internal/syscall/windows", "internal/singleflight", "internal/race"},
// NET enables use of basic network-related packages.
"NET": {
......
......@@ -7,7 +7,10 @@
package net
import "testing"
import (
"context"
"testing"
)
func TestCgoLookupIP(t *testing.T) {
host := "localhost"
......@@ -18,7 +21,7 @@ func TestCgoLookupIP(t *testing.T) {
if err != nil {
t.Error(err)
}
if _, err := goLookupIP(host); err != nil {
if _, err := goLookupIP(context.Background(), host); err != nil {
t.Error(err)
}
}
......@@ -5,6 +5,7 @@
package net
import (
"context"
"runtime"
"time"
)
......@@ -61,21 +62,34 @@ type Dialer struct {
// Cancel is an optional channel whose closure indicates that
// the dial should be canceled. Not all types of dials support
// cancelation.
//
// Deprecated: Use DialContext instead.
Cancel <-chan struct{}
}
// Return either now+Timeout or Deadline, whichever comes first.
// Or zero, if neither is set.
func (d *Dialer) deadline(now time.Time) time.Time {
if d.Timeout == 0 {
return d.Deadline
func minNonzeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
timeoutDeadline := now.Add(d.Timeout)
if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) {
return timeoutDeadline
} else {
return d.Deadline
if b.IsZero() || a.Before(b) {
return a
}
return b
}
// deadline returns the earliest of:
// - now+Timeout
// - d.Deadline
// - the context's deadline
// Or zero, if none of Timeout, Deadline, or context's deadline is set.
func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
if d.Timeout != 0 { // including negative, for historical reasons
earliest = now.Add(d.Timeout)
}
if d, ok := ctx.Deadline(); ok {
earliest = minNonzeroTime(earliest, d)
}
return minNonzeroTime(earliest, d.Deadline)
}
// partialDeadline returns the deadline to use for a single address,
......@@ -142,7 +156,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
// resolverAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) {
func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
afnet, _, err := parseNetwork(network)
if err != nil {
return nil, err
......@@ -152,6 +166,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a
}
switch afnet {
case "unix", "unixgram", "unixpacket":
// TODO(bradfitz): push down context
addr, err := ResolveUnixAddr(afnet, addr)
if err != nil {
return nil, err
......@@ -161,7 +176,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a
}
return addrList{addr}, nil
}
addrs, err := internetAddrList(afnet, addr, deadline)
addrs, err := internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
}
......@@ -253,11 +268,10 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
return d.Dial(network, address)
}
// dialContext holds common state for all dial operations.
type dialContext struct {
// dialParam contains a Dial's parameters and configuration.
type dialParam struct {
Dialer
network, address string
finalDeadline time.Time
}
// Dial connects to the address on the named network.
......@@ -265,17 +279,50 @@ type dialContext struct {
// See func Dial for a description of the network and address
// parameters.
func (d *Dialer) Dial(network, address string) (Conn, error) {
finalDeadline := d.deadline(time.Now())
addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline)
return d.DialContext(context.Background(), network, address)
}
// DialContext connects to the address on the named network using
// the provided context.
//
// The provided Context must be non-nil.
//
// See func Dial for a description of the network and address
// parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
if ctx == nil {
panic("nil context")
}
deadline := d.deadline(ctx, time.Now())
if !deadline.IsZero() {
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
subCtx, cancel := context.WithDeadline(ctx, deadline)
defer cancel()
ctx = subCtx
}
}
if oldCancel := d.Cancel; oldCancel != nil {
subCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-oldCancel:
cancel()
case <-subCtx.Done():
}
}()
ctx = subCtx
}
addrs, err := resolveAddrList(ctx, "dial", network, address, d.LocalAddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}
ctx := &dialContext{
Dialer: *d,
network: network,
address: address,
finalDeadline: finalDeadline,
dp := &dialParam{
Dialer: *d,
network: network,
address: address,
}
// DualStack mode requires that dialTCP support cancelation. This is
......@@ -288,138 +335,128 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
}
var c Conn
if len(fallbacks) == 0 {
// dialParallel can accept an empty fallbacks list,
// but this shortcut avoids the goroutine/channel overhead.
c, err = dialSerial(ctx, primaries, ctx.Cancel)
if len(fallbacks) > 0 {
c, err = dialParallel(ctx, dp, primaries, fallbacks)
} else {
c, err = dialParallel(ctx, primaries, fallbacks, ctx.Cancel)
c, err = dialSerial(ctx, dp, primaries)
}
if err != nil {
return nil, err
}
if d.KeepAlive > 0 && err == nil {
if tc, ok := c.(*TCPConn); ok {
setKeepAlive(tc.fd, true)
setKeepAlivePeriod(tc.fd, d.KeepAlive)
testHookSetKeepAlive()
}
if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 {
setKeepAlive(tc.fd, true)
setKeepAlivePeriod(tc.fd, d.KeepAlive)
testHookSetKeepAlive()
}
return c, err
return c, nil
}
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
func dialParallel(ctx *dialContext, primaries, fallbacks addrList, userCancel <-chan struct{}) (Conn, error) {
results := make(chan dialResult, 2)
cancel := make(chan struct{})
func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) {
if len(fallbacks) == 0 {
return dialSerial(ctx, dp, primaries)
}
// Spawn the primary racer.
go dialSerialAsync(ctx, primaries, nil, cancel, results)
returned := make(chan struct{})
defer close(returned)
// Spawn the fallback racer.
fallbackTimer := time.NewTimer(ctx.fallbackDelay())
go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results)
type dialResult struct {
Conn
error
primary bool
done bool
}
results := make(chan dialResult) // unbuffered
// Wait for both racers to succeed or fail.
var primaryResult, fallbackResult dialResult
for !primaryResult.done || !fallbackResult.done {
startRacer := func(ctx context.Context, primary bool) {
ras := primaries
if !primary {
ras = fallbacks
}
c, err := dialSerial(ctx, dp, ras)
select {
case <-userCancel:
// Forward an external cancelation request.
if cancel != nil {
close(cancel)
cancel = nil
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
if c != nil {
c.Close()
}
userCancel = nil
}
}
var primary, fallback dialResult
// Start the main racer.
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go startRacer(primaryCtx, true)
// Start the timer for the fallback racer.
fallbackTimer := time.NewTimer(dp.fallbackDelay())
defer fallbackTimer.Stop()
for {
select {
case <-fallbackTimer.C:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
go startRacer(fallbackCtx, false)
case res := <-results:
// Drop the result into its assigned bucket.
if res.error == nil {
return res.Conn, nil
}
if res.primary {
primaryResult = res
primary = res
} else {
fallbackResult = res
fallback = res
}
// On success, cancel the other racer (if one exists.)
if res.error == nil && cancel != nil {
close(cancel)
cancel = nil
if primary.done && fallback.done {
return nil, primary.error
}
// If the fallbackTimer was pending, then either we've canceled the
// fallback because we no longer want it, or we haven't canceled yet
// and therefore want it to wake up immediately.
if fallbackTimer.Stop() && cancel != nil {
if res.primary && fallbackTimer.Stop() {
// If we were able to stop the timer, that means it
// was running (hadn't yet started the fallback), but
// we just got an error on the primary path, so start
// the fallback immediately (in 0 nanoseconds).
fallbackTimer.Reset(0)
}
}
}
// Return, in order of preference:
// 1. The primary connection (but close the other if we got both.)
// 2. The fallback connection.
// 3. The primary error.
if primaryResult.error == nil {
if fallbackResult.error == nil {
fallbackResult.Conn.Close()
}
return primaryResult.Conn, nil
} else if fallbackResult.error == nil {
return fallbackResult.Conn, nil
} else {
return nil, primaryResult.error
}
}
type dialResult struct {
Conn
error
primary bool
done bool
}
// dialSerialAsync runs dialSerial after some delay, and returns the
// resulting connection through a channel. When racing two connections,
// the primary goroutine uses a nil timer to omit the delay.
func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) {
if timer != nil {
// We're in the fallback goroutine; sleep before connecting.
select {
case <-timer.C:
case <-cancel:
// dialSerial will immediately return errCanceled in this case.
}
}
c, err := dialSerial(ctx, ras, cancel)
results <- dialResult{Conn: c, error: err, primary: timer == nil, done: true}
}
// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) {
func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) {
var firstErr error // The error from the first address is most relevant.
for i, ra := range ras {
select {
case <-cancel:
return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled}
case <-ctx.Done():
return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
default:
}
partialDeadline, err := partialDeadline(time.Now(), ctx.finalDeadline, len(ras)-i)
deadline, _ := ctx.Deadline()
partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
if err != nil {
// Ran out of time.
if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err}
firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err}
}
break
}
// If this dial is canceled, the implementation is expected to complete
// quickly, but it's still possible that we could return a spurious Conn,
// which the caller must Close.
dialer := func(d time.Time) (Conn, error) {
return dialSingle(ctx, ra, d, cancel)
dialCtx := ctx
if partialDeadline.Before(deadline) {
var cancel context.CancelFunc
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
defer cancel()
}
c, err := dial(ctx.network, ra, dialer, partialDeadline)
c, err := dialSingle(dialCtx, dp, ra)
if err == nil {
return c, nil
}
......@@ -429,7 +466,7 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
}
if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress}
firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress}
}
return nil, firstErr
}
......@@ -437,26 +474,26 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
// dialSingle attempts to establish and returns a single connection to
// the destination address. This must be called through the OS-specific
// dial function, because some OSes don't implement the deadline feature.
func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) {
la := ctx.LocalAddr
func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
la := dp.LocalAddr
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
c, err = testHookDialTCP(ctx.network, la, ra, deadline, cancel)
c, err = dialTCP(ctx, dp.network, la, ra)
case *UDPAddr:
la, _ := la.(*UDPAddr)
c, err = dialUDP(ctx.network, la, ra, deadline)
c, err = dialUDP(ctx, dp.network, la, ra)
case *IPAddr:
la, _ := la.(*IPAddr)
c, err = dialIP(ctx.network, la, ra, deadline)
c, err = dialIP(ctx, dp.network, la, ra)
case *UnixAddr:
la, _ := la.(*UnixAddr)
c, err = dialUnix(ctx.network, la, ra, deadline)
c, err = dialUnix(ctx, dp.network, la, ra)
default:
return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}}
return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}}
}
if err != nil {
return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}
......@@ -469,7 +506,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str
// instead of just the interface with the given host address.
// See Dial for more details about address syntax.
func Listen(net, laddr string) (Listener, error) {
addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
......@@ -496,7 +533,7 @@ func Listen(net, laddr string) (Listener, error) {
// instead of just the interface with the given host address.
// See Dial for the syntax of laddr.
func ListenPacket(net, laddr string) (PacketConn, error) {
addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
......
......@@ -6,6 +6,7 @@ package net
import (
"bufio"
"context"
"internal/testenv"
"io"
"net/internal/socktest"
......@@ -193,18 +194,11 @@ const (
// In some environments, the slow IPs may be explicitly unreachable, and fail
// more quickly than expected. This test hook prevents dialTCP from returning
// before the deadline.
func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
c, err := dialTCP(net, laddr, raddr, deadline, cancel)
func slowDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
c, err := doDialTCP(ctx, net, laddr, raddr)
if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
// Wait for the deadline, or indefinitely if none exists.
var wait <-chan time.Time
if !deadline.IsZero() {
wait = time.After(deadline.Sub(time.Now()))
}
select {
case <-cancel:
case <-wait:
}
<-ctx.Done()
}
return c, err
}
......@@ -356,15 +350,14 @@ func TestDialParallel(t *testing.T) {
d := Dialer{
FallbackDelay: fallbackDelay,
}
ctx := &dialContext{
Dialer: d,
network: "tcp",
address: "?",
finalDeadline: d.deadline(time.Now()),
}
startTime := time.Now()
c, err := dialParallel(ctx, primaries, fallbacks, nil)
elapsed := time.Now().Sub(startTime)
dp := &dialParam{
Dialer: d,
network: "tcp",
address: "?",
}
c, err := dialParallel(context.Background(), dp, primaries, fallbacks)
elapsed := time.Since(startTime)
if c != nil {
c.Close()
......@@ -385,16 +378,16 @@ func TestDialParallel(t *testing.T) {
}
// Repeat each case, ensuring that it can be canceled quickly.
cancel := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
time.Sleep(5 * time.Millisecond)
close(cancel)
cancel()
wg.Done()
}()
startTime = time.Now()
c, err = dialParallel(ctx, primaries, fallbacks, cancel)
c, err = dialParallel(ctx, dp, primaries, fallbacks)
if c != nil {
c.Close()
}
......@@ -406,7 +399,7 @@ func TestDialParallel(t *testing.T) {
}
}
func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
switch host {
case "slow6loopback4":
// Returns a slow IPv6 address, and a local IPv4 address.
......@@ -415,7 +408,7 @@ func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, e
{IP: ParseIP("127.0.0.1")},
}, nil
default:
return fn(host)
return fn(ctx, host)
}
}
......@@ -530,22 +523,24 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
origTestHookDialTCP := testHookDialTCP
defer func() { testHookDialTCP = origTestHookDialTCP }()
testHookDialTCP = func(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
testHookDialTCP = func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
// Sleep long enough for Happy Eyeballs to kick in, and inhibit cancelation.
// This forces dialParallel to juggle two successful connections.
time.Sleep(fallbackDelay * 2)
cancel = nil
return dialTCP(net, laddr, raddr, deadline, cancel)
// Now ignore the provided context (which will be canceled) and use a
// different one to make sure this completes with a valid connection,
// which we hope to be closed below:
return doDialTCP(context.Background(), net, laddr, raddr)
}
d := Dialer{
FallbackDelay: fallbackDelay,
}
ctx := &dialContext{
Dialer: d,
network: "tcp",
address: "?",
finalDeadline: d.deadline(time.Now()),
dp := &dialParam{
Dialer: d,
network: "tcp",
address: "?",
}
makeAddr := func(ip string) addrList {
......@@ -557,7 +552,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
}
// dialParallel returns one connection (and closes the other.)
c, err := dialParallel(ctx, makeAddr("127.0.0.1"), makeAddr("::1"), nil)
c, err := dialParallel(context.Background(), dp, makeAddr("127.0.0.1"), makeAddr("::1"))
if err != nil {
t.Fatal(err)
}
......
......@@ -16,6 +16,7 @@
package net
import (
"context"
"errors"
"io"
"math/rand"
......@@ -399,11 +400,11 @@ func (o hostLookupOrder) String() string {
// Normally we let cgo use the C library resolver instead of
// depending on our lookup code, so that Go and C get the same
// answers.
func goLookupHost(name string) (addrs []string, err error) {
return goLookupHostOrder(name, hostLookupFilesDNS)
func goLookupHost(ctx context.Context, name string) (addrs []string, err error) {
return goLookupHostOrder(ctx, name, hostLookupFilesDNS)
}
func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) {
func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
// Use entries from /etc/hosts if they match.
addrs = lookupStaticHost(name)
......@@ -411,7 +412,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err
return
}
}
ips, err := goLookupIPOrder(name, order)
ips, err := goLookupIPOrder(ctx, name, order)
if err != nil {
return
}
......@@ -437,11 +438,11 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
// goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go.
func goLookupIP(name string) (addrs []IPAddr, err error) {
return goLookupIPOrder(name, hostLookupFilesDNS)
func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) {
return goLookupIPOrder(ctx, name, hostLookupFilesDNS)
}
func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) {
func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles {
......
......@@ -7,6 +7,7 @@
package net
import (
"context"
"fmt"
"internal/testenv"
"io/ioutil"
......@@ -133,7 +134,7 @@ func TestAvoidDNSName(t *testing.T) {
// Issue 13705: don't try to resolve onion addresses, etc
func TestLookupTorOnion(t *testing.T) {
addrs, err := goLookupIP("foo.onion")
addrs, err := goLookupIP(context.Background(), "foo.onion")
if len(addrs) > 0 {
t.Errorf("unexpected addresses: %v", addrs)
}
......@@ -249,7 +250,7 @@ func TestUpdateResolvConf(t *testing.T) {
for j := 0; j < N; j++ {
go func(name string) {
defer wg.Done()
ips, err := goLookupIP(name)
ips, err := goLookupIP(context.Background(), name)
if err != nil {
t.Error(err)
return
......@@ -397,7 +398,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
t.Error(err)
continue
}
addrs, err := goLookupIP(tt.name)
addrs, err := goLookupIP(context.Background(), tt.name)
if err != nil {
// This test uses external network connectivity.
// We need to take care with errors on both
......@@ -447,14 +448,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
name := fmt.Sprintf("order %v", order)
// First ensure that we get an error when contacting a non-existent host.
_, err := goLookupIPOrder("notarealhost", order)
_, err := goLookupIPOrder(context.Background(), "notarealhost", order)
if err == nil {
t.Errorf("%s: expected error while looking up name not in hosts file", name)
continue
}
// Now check that we get an address when the name appears in the hosts file.
addrs, err := goLookupIPOrder("thor", order) // entry is in "testdata/hosts"
addrs, err := goLookupIPOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
if err != nil {
t.Errorf("%s: expected to successfully lookup host entry", name)
continue
......@@ -510,7 +511,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
return r, nil
}
_, err = goLookupIP(fqdn)
_, err = goLookupIP(context.Background(), fqdn)
if err == nil {
t.Fatal("expected an error")
}
......@@ -523,17 +524,19 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
func BenchmarkGoLookupIP(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
ctx := context.Background()
for i := 0; i < b.N; i++ {
goLookupIP("www.example.com")
goLookupIP(ctx, "www.example.com")
}
}
func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
ctx := context.Background()
for i := 0; i < b.N; i++ {
goLookupIP("some.nonexistent")
goLookupIP(ctx, "some.nonexistent")
}
}
......@@ -553,9 +556,10 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
if err := conf.writeAndUpdate(lines); err != nil {
b.Fatal(err)
}
ctx := context.Background()
for i := 0; i < b.N; i++ {
goLookupIP("www.example.com")
goLookupIP(ctx, "www.example.com")
}
}
......
......@@ -5,6 +5,7 @@
package net
import (
"context"
"fmt"
"io"
"io/ioutil"
......@@ -138,7 +139,7 @@ func TestDialError(t *testing.T) {
origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
testHookLookupIP = func(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true}
}
sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
......@@ -283,7 +284,7 @@ func TestListenError(t *testing.T) {
origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
}
sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) {
......@@ -343,7 +344,7 @@ func TestListenPacketError(t *testing.T) {
origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
}
......
......@@ -7,12 +7,12 @@
package net
import (
"context"
"io"
"os"
"runtime"
"sync/atomic"
"syscall"
"time"
)
// Network file descriptor.
......@@ -36,10 +36,6 @@ type netFD struct {
func sysInit() {
}
func dial(network string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) {
return dialer(deadline)
}
func newFD(sysfd, family, sotype int, net string) (*netFD, error) {
return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil
}
......@@ -68,15 +64,17 @@ func (fd *netFD) name() string {
return fd.net + ":" + ls + "->" + rs
}
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
// Do not need to call fd.writeLock here,
// because fd is not yet accessible to user,
// so no concurrent operations are possible.
switch err := connectFunc(fd.sysfd, ra); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case nil, syscall.EISCONN:
if !deadline.IsZero() && deadline.Before(time.Now()) {
return errTimeout
select {
case <-ctx.Done():
return mapErr(ctx.Err())
default:
}
if err := fd.init(); err != nil {
return err
......@@ -98,27 +96,27 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
if err := fd.init(); err != nil {
return err
}
if !deadline.IsZero() {
if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
fd.setWriteDeadline(deadline)
defer fd.setWriteDeadline(noDeadline)
}
if cancel != nil {
done := make(chan bool)
defer func() {
// This is unbuffered; wait for the goroutine before returning.
done <- true
}()
go func() {
select {
case <-cancel:
// Force the runtime's poller to immediately give
// up waiting for writability.
fd.setWriteDeadline(aLongTimeAgo)
<-done
case <-done:
}
}()
}
// Wait for the goroutine converting context.Done into a write timeout
// to exist, otherwise our caller might cancel the context and
// cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
done := make(chan bool) // must be unbuffered
defer func() { done <- true }()
go func() {
select {
case <-ctx.Done():
// Force the runtime's poller to immediately give
// up waiting for writability.
fd.setWriteDeadline(aLongTimeAgo)
<-done
case <-done:
}
}()
for {
// Performing multiple connect system calls on a
// non-blocking socket under Unix variants does not
......@@ -130,8 +128,8 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
// details.
if err := fd.pd.waitWrite(); err != nil {
select {
case <-cancel:
return errCanceled
case <-ctx.Done():
return mapErr(ctx.Err())
default:
}
return err
......
......@@ -5,6 +5,7 @@
package net
import (
"context"
"internal/race"
"os"
"runtime"
......@@ -320,14 +321,14 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
runtime.SetFinalizer(fd, (*netFD).Close)
}
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
// Do not need to call fd.writeLock here,
// because fd is not yet accessible to user,
// so no concurrent operations are possible.
if err := fd.init(); err != nil {
return err
}
if !deadline.IsZero() {
if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
fd.setWriteDeadline(deadline)
defer fd.setWriteDeadline(noDeadline)
}
......@@ -351,30 +352,30 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
// Call ConnectEx API.
o := &fd.wop
o.sa = ra
if cancel != nil {
done := make(chan bool)
defer func() {
// This is unbuffered; wait for the goroutine before returning.
done <- true
}()
go func() {
select {
case <-cancel:
// Force the runtime's poller to immediately give
// up waiting for writability.
fd.setWriteDeadline(aLongTimeAgo)
<-done
case <-done:
}
}()
}
// Wait for the goroutine converting context.Done into a write timeout
// to exist, otherwise our caller might cancel the context and
// cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
done := make(chan bool) // must be unbuffered
defer func() { done <- true }()
go func() {
select {
case <-ctx.Done():
// Force the runtime's poller to immediately give
// up waiting for writability.
fd.setWriteDeadline(aLongTimeAgo)
<-done
case <-done:
}
}()
_, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error {
return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o)
})
if err != nil {
select {
case <-cancel:
return errCanceled
case <-ctx.Done():
return mapErr(ctx.Err())
default:
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("connectex", err)
......
......@@ -4,9 +4,19 @@
package net
import "context"
var (
testHookDialTCP = dialTCP
testHookHostsPath = "/etc/hosts"
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { return fn(host) }
// if non-nil, overrides dialTCP.
testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
testHookHostsPath = "/etc/hosts"
testHookLookupIP = func(
ctx context.Context,
fn func(context.Context, string) ([]IPAddr, error),
host string,
) ([]IPAddr, error) {
return fn(ctx, host)
}
testHookSetKeepAlive = func() {}
)
......@@ -4,7 +4,10 @@
package net
import "syscall"
import (
"context"
"syscall"
)
// IPAddr represents the address of an IP end point.
type IPAddr struct {
......@@ -56,7 +59,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
addrs, err := internetAddrList(afnet, addr, noDeadline)
addrs, err := internetAddrList(context.Background(), afnet, addr)
if err != nil {
return nil, err
}
......@@ -171,7 +174,7 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
// netProto, which must be "ip", "ip4", or "ip6" followed by a colon
// and a protocol number or name.
func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
c, err := dialIP(netProto, laddr, raddr, noDeadline)
c, err := dialIP(context.Background(), netProto, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
......@@ -183,7 +186,7 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
// methods can be used to receive and send IP packets with per-packet
// addressing.
func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
c, err := listenIP(netProto, laddr)
c, err := listenIP(context.Background(), netProto, laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err}
}
......
......@@ -5,8 +5,8 @@
package net
import (
"context"
"syscall"
"time"
)
func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) {
......@@ -25,10 +25,10 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return 0, 0, syscall.EPLAN9
}
func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) {
func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9
}
func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9
}
......@@ -7,8 +7,8 @@
package net
import (
"context"
"syscall"
"time"
)
// BUG(mikio): On every POSIX platform, reads from the "ip4" network
......@@ -120,7 +120,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return c.fd.writeMsg(b, oob, sa)
}
func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) {
func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(netProto)
if err != nil {
return nil, err
......@@ -133,14 +133,14 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn,
if raddr == nil {
return nil, errMissingAddress
}
fd, err := internetSocket(network, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", noCancel)
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
if err != nil {
return nil, err
}
return newIPConn(fd), nil
}
func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(netProto)
if err != nil {
return nil, err
......@@ -150,7 +150,7 @@ func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
default:
return nil, UnknownNetworkError(netProto)
}
fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", noCancel)
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
if err != nil {
return nil, err
}
......
......@@ -6,7 +6,9 @@
package net
import "time"
import (
"context"
)
var (
// supportsIPv4 reports whether the platform supports IPv4
......@@ -188,7 +190,7 @@ func JoinHostPort(host, port string) string {
// address or a DNS name, and returns a list of internet protocol
// family addresses. The result contains at least one address when
// error is nil.
func internetAddrList(net, addr string, deadline time.Time) (addrList, error) {
func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
var (
err error
host, port string
......@@ -236,7 +238,7 @@ func internetAddrList(net, addr string, deadline time.Time) (addrList, error) {
return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil
}
// Try as a DNS name.
ips, err := lookupIPDeadline(host, deadline)
ips, err := lookupIPContext(ctx, host)
if err != nil {
return nil, err
}
......
......@@ -7,9 +7,9 @@
package net
import (
"context"
"runtime"
"syscall"
"time"
)
// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the
......@@ -152,9 +152,10 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
return syscall.AF_INET6, false
}
func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, cancel <-chan struct{}) (fd *netFD, err error) {
// Internet sockets (TCP, UDP, IP)
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, cancel)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
}
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
......
......@@ -5,8 +5,8 @@
package net
import (
"context"
"internal/singleflight"
"time"
)
// protocols contains minimal mappings between internet protocol
......@@ -33,7 +33,7 @@ func LookupHost(host string) (addrs []string, err error) {
if ip := ParseIP(host); ip != nil {
return []string{host}, nil
}
return lookupHost(host)
return lookupHost(context.Background(), host)
}
// LookupIP looks up host using the local resolver.
......@@ -47,7 +47,7 @@ func LookupIP(host string) (ips []IP, err error) {
if ip := ParseIP(host); ip != nil {
return []IP{ip}, nil
}
addrs, err := lookupIPMerge(host)
addrs, err := lookupIPMerge(context.Background(), host)
if err != nil {
return
}
......@@ -63,9 +63,9 @@ var lookupGroup singleflight.Group
// lookupIPMerge wraps lookupIP, but makes sure that for any given
// host, only one lookup is in-flight at a time. The returned memory
// is always owned by the caller.
func lookupIPMerge(host string) (addrs []IPAddr, err error) {
func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) {
addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) {
return testHookLookupIP(lookupIP, host)
return testHookLookupIP(ctx, lookupIP, host)
})
return lookupIPReturn(addrsi, err, shared)
}
......@@ -85,37 +85,26 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error
return addrs, nil
}
// lookupIPDeadline looks up a hostname with a deadline.
func lookupIPDeadline(host string, deadline time.Time) (addrs []IPAddr, err error) {
if deadline.IsZero() {
return lookupIPMerge(host)
}
// We could push the deadline down into the name resolution
// functions. However, the most commonly used implementation
// calls getaddrinfo, which has no timeout.
timeout := deadline.Sub(time.Now())
if timeout <= 0 {
return nil, errTimeout
}
t := time.NewTimer(timeout)
defer t.Stop()
// lookupIPContext looks up a hostname with a context.
func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) {
// TODO(bradfitz): when adding trace hooks later here, make
// sure the tracing is done outside of the singleflight
// merging. Both callers should see the DNS lookup delay, even
// if it's only being done once. The r.Shared bit can be
// included in the trace for callers who need it.
ch := lookupGroup.DoChan(host, func() (interface{}, error) {
return testHookLookupIP(lookupIP, host)
return testHookLookupIP(ctx, lookupIP, host)
})
select {
case <-t.C:
case <-ctx.Done():
// The DNS lookup timed out for some reason. Force
// future requests to start the DNS lookup again
// rather than waiting for the current lookup to
// complete. See issue 8602.
lookupGroup.Forget(host)
return nil, errTimeout
return nil, mapErr(ctx.Err())
case r := <-ch:
return lookupIPReturn(r.Val, r.Err, r.Shared)
}
......
......@@ -5,6 +5,7 @@
package net
import (
"context"
"errors"
"os"
)
......@@ -115,7 +116,7 @@ func lookupProtocol(name string) (proto int, err error) {
return 0, UnknownNetworkError(name)
}
func lookupHost(host string) (addrs []string, err error) {
func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
// Use netdir/cs instead of netdir/dns because cs knows about
// host names in local network (e.g. from /lib/ndb/local)
lines, err := queryCS("net", host, "1")
......@@ -146,7 +147,8 @@ loop:
return
}
func lookupIP(host string) (addrs []IPAddr, err error) {
func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
// TODO(bradfitz): push down ctx
lits, err := LookupHost(host)
if err != nil {
return
......
......@@ -6,17 +6,20 @@
package net
import "syscall"
import (
"context"
"syscall"
)
func lookupProtocol(name string) (proto int, err error) {
return 0, syscall.ENOPROTOOPT
}
func lookupHost(host string) (addrs []string, err error) {
func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
return nil, syscall.ENOPROTOOPT
}
func lookupIP(host string) (addrs []IPAddr, err error) {
func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
return nil, syscall.ENOPROTOOPT
}
......
......@@ -6,6 +6,7 @@ package net
import (
"bytes"
"context"
"fmt"
"internal/testenv"
"runtime"
......@@ -14,7 +15,7 @@ import (
"time"
)
func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) {
func lookupLocalhost(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
switch host {
case "localhost":
return []IPAddr{
......@@ -22,7 +23,7 @@ func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr,
{IP: IPv6loopback},
}, nil
default:
return fn(host)
return fn(ctx, host)
}
}
......@@ -375,15 +376,20 @@ func TestLookupIPDeadline(t *testing.T) {
const N = 5000
const timeout = 3 * time.Second
ctxHalfTimeout, cancel := context.WithTimeout(context.Background(), timeout/2)
defer cancel()
ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
c := make(chan error, 2*N)
for i := 0; i < N; i++ {
name := fmt.Sprintf("%d.net-test.golang.org", i)
go func() {
_, err := lookupIPDeadline(name, time.Now().Add(timeout/2))
_, err := lookupIPContext(ctxHalfTimeout, name)
c <- err
}()
go func() {
_, err := lookupIPDeadline(name, time.Now().Add(timeout))
_, err := lookupIPContext(ctxTimeout, name)
c <- err
}()
}
......
......@@ -6,7 +6,10 @@
package net
import "sync"
import (
"context"
"sync"
)
var onceReadProtocols sync.Once
......@@ -49,7 +52,7 @@ func lookupProtocol(name string) (int, error) {
return proto, nil
}
func lookupHost(host string) (addrs []string, err error) {
func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
order := systemConf().hostLookupOrder(host)
if order == hostLookupCgo {
if addrs, err, ok := cgoLookupHost(host); ok {
......@@ -58,19 +61,20 @@ func lookupHost(host string) (addrs []string, err error) {
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
return goLookupHostOrder(host, order)
return goLookupHostOrder(ctx, host, order)
}
func lookupIP(host string) (addrs []IPAddr, err error) {
func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
order := systemConf().hostLookupOrder(host)
if order == hostLookupCgo {
// TODO(bradfitz): push down ctx, or at least its deadline to start
if addrs, err, ok := cgoLookupIP(host); ok {
return addrs, err
}
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
return goLookupIPOrder(host, order)
return goLookupIPOrder(ctx, host, order)
}
func lookupPort(network, service string) (int, error) {
......
......@@ -5,6 +5,7 @@
package net
import (
"context"
"os"
"runtime"
"syscall"
......@@ -51,8 +52,8 @@ func lookupProtocol(name string) (int, error) {
return r.proto, r.err
}
func lookupHost(name string) ([]string, error) {
ips, err := LookupIP(name)
func lookupHost(ctx context.Context, name string) ([]string, error) {
ips, err := lookupIP(ctx, name)
if err != nil {
return nil, err
}
......@@ -83,59 +84,97 @@ func gethostbyname(name string) (addrs []IPAddr, err error) {
return addrs, nil
}
func oldLookupIP(name string) ([]IPAddr, error) {
func oldLookupIP(ctx context.Context, name string) ([]IPAddr, error) {
// GetHostByName return value is stored in thread local storage.
// Start new os thread before the call to prevent races.
type result struct {
type ret struct {
addrs []IPAddr
err error
}
ch := make(chan result)
ch := make(chan ret, 1)
go func() {
acquireThread()
defer releaseThread()
runtime.LockOSThread()
defer runtime.UnlockOSThread()
addrs, err := gethostbyname(name)
ch <- result{addrs: addrs, err: err}
ch <- ret{addrs: addrs, err: err}
}()
r := <-ch
if r.err != nil {
r.err = &DNSError{Err: r.err.Error(), Name: name}
select {
case r := <-ch:
if r.err != nil {
r.err = &DNSError{Err: r.err.Error(), Name: name}
}
return r.addrs, r.err
case <-ctx.Done():
// TODO(bradfitz,brainman): cancel the ongoing
// gethostbyname? For now we just let it finish and
// write to the buffered channel.
return nil, &DNSError{
Name: name,
Err: ctx.Err().Error(),
IsTimeout: ctx.Err() == context.DeadlineExceeded,
}
}
return r.addrs, r.err
}
func newLookupIP(name string) ([]IPAddr, error) {
acquireThread()
defer releaseThread()
hints := syscall.AddrinfoW{
Family: syscall.AF_UNSPEC,
Socktype: syscall.SOCK_STREAM,
Protocol: syscall.IPPROTO_IP,
}
var result *syscall.AddrinfoW
e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
if e != nil {
return nil, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}
func newLookupIP(ctx context.Context, name string) ([]IPAddr, error) {
// TODO(bradfitz,brainman): use ctx?
type ret struct {
addrs []IPAddr
err error
}
defer syscall.FreeAddrInfoW(result)
addrs := make([]IPAddr, 0, 5)
for ; result != nil; result = result.Next {
addr := unsafe.Pointer(result.Addr)
switch result.Family {
case syscall.AF_INET:
a := (*syscall.RawSockaddrInet4)(addr).Addr
addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
case syscall.AF_INET6:
a := (*syscall.RawSockaddrInet6)(addr).Addr
zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
default:
return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
ch := make(chan ret, 1)
go func() {
acquireThread()
defer releaseThread()
hints := syscall.AddrinfoW{
Family: syscall.AF_UNSPEC,
Socktype: syscall.SOCK_STREAM,
Protocol: syscall.IPPROTO_IP,
}
var result *syscall.AddrinfoW
e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
if e != nil {
ch <- ret{err: &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}}
}
defer syscall.FreeAddrInfoW(result)
addrs := make([]IPAddr, 0, 5)
for ; result != nil; result = result.Next {
addr := unsafe.Pointer(result.Addr)
switch result.Family {
case syscall.AF_INET:
a := (*syscall.RawSockaddrInet4)(addr).Addr
addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
case syscall.AF_INET6:
a := (*syscall.RawSockaddrInet6)(addr).Addr
zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
default:
ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}}
}
}
ch <- ret{addrs: addrs}
}()
select {
case r := <-ch:
return r.addrs, r.err
case <-ctx.Done():
// TODO(bradfitz,brainman): cancel the ongoing
// GetAddrInfoW? It would require conditionally using
// GetAddrInfoEx with lpOverlapped, which requires
// Windows 8 or newer. I guess we'll need oldLookupIP,
// newLookupIP, and newerLookUP.
//
// For now we just let it finish and write to the
// buffered channel.
return nil, &DNSError{
Name: name,
Err: ctx.Err().Error(),
IsTimeout: ctx.Err() == context.DeadlineExceeded,
}
}
return addrs, nil
}
func getservbyname(network, service string) (int, error) {
......
......@@ -79,6 +79,7 @@ On Windows, the resolver always uses C library functions, such as GetAddrInfo an
package net
import (
"context"
"errors"
"io"
"os"
......@@ -377,6 +378,22 @@ var (
ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection")
)
// mapErr maps from the context errors to the historical internal net
// error values.
//
// TODO(bradfitz): get rid of this after adjusting tests and making
// context.DeadlineExceeded implement net.Error?
func mapErr(err error) error {
switch err {
case context.Canceled:
return errCanceled
case context.DeadlineExceeded:
return errTimeout
default:
return err
}
}
// OpError is the error type usually returned by functions in the net
// package. It describes the operation, network type, and address of
// an error.
......
......@@ -7,7 +7,10 @@
package net
import "testing"
import (
"context"
"testing"
)
func TestGoLookupIP(t *testing.T) {
host := "localhost"
......@@ -18,7 +21,7 @@ func TestGoLookupIP(t *testing.T) {
if err != nil {
t.Error(err)
}
if _, err := goLookupIP(host); err != nil {
if _, err := goLookupIP(context.Background(), host); err != nil {
t.Error(err)
}
}
......@@ -7,9 +7,9 @@
package net
import (
"context"
"os"
"syscall"
"time"
)
// A sockaddr represents a TCP, UDP, IP or Unix network endpoint
......@@ -34,7 +34,7 @@ type sockaddr interface {
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) (fd *netFD, err error) {
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
......@@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s
return fd, nil
}
}
if err := fd.dial(laddr, raddr, deadline, cancel); err != nil {
if err := fd.dial(ctx, laddr, raddr); err != nil {
fd.Close()
return nil, err
}
......@@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
return func(syscall.Sockaddr) Addr { return nil }
}
func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) error {
func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
var err error
var lsa syscall.Sockaddr
if laddr != nil {
......@@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan s
if rsa, err = raddr.sockaddr(fd.family); err != nil {
return err
}
if err := fd.connect(lsa, rsa, deadline, cancel); err != nil {
if err := fd.connect(ctx, lsa, rsa); err != nil {
return err
}
fd.isConnected = true
......
......@@ -5,6 +5,7 @@
package net
import (
"context"
"io"
"os"
"syscall"
......@@ -60,7 +61,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
addrs, err := internetAddrList(net, addr, noDeadline)
addrs, err := internetAddrList(context.Background(), net, addr)
if err != nil {
return nil, err
}
......@@ -186,7 +187,7 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
c, err := dialTCP(net, laddr, raddr, noDeadline, noCancel)
c, err := dialTCP(context.Background(), net, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
......@@ -285,7 +286,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
if laddr == nil {
laddr = &TCPAddr{}
}
ln, err := listenTCP(net, laddr)
ln, err := listenTCP(context.Background(), net, laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
}
......
......@@ -5,17 +5,24 @@
package net
import (
"context"
"io"
"os"
"time"
)
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
if !deadline.IsZero() {
func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if testHookDialTCP != nil {
return testHookDialTCP(ctx, net, laddr, raddr)
}
return doDialTCP(ctx, net, laddr, raddr)
}
func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if d, _ := ctx.Deadline(); !d.IsZero() {
panic("net.dialTCP: deadline not implemented on Plan 9")
}
// TODO(bradfitz,0intro): also use the cancel channel.
......@@ -63,7 +70,7 @@ func (ln *TCPListener) file() (*os.File, error) {
return f, nil
}
func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
fd, err := listenPlan9(network, laddr)
if err != nil {
return nil, err
......
......@@ -7,10 +7,10 @@
package net
import (
"context"
"io"
"os"
"syscall"
"time"
)
func sockaddrToTCP(sa syscall.Sockaddr) Addr {
......@@ -47,8 +47,15 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if testHookDialTCP != nil {
return testHookDialTCP(ctx, net, laddr, raddr)
}
return doDialTCP(ctx, net, laddr, raddr)
}
func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
......@@ -78,7 +85,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-cha
if err == nil {
fd.Close()
}
fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
fd, err = internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
}
if err != nil {
......@@ -141,8 +148,8 @@ func (ln *TCPListener) file() (*os.File, error) {
return f, nil
}
func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", noCancel)
func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_STREAM, 0, "listen")
if err != nil {
return nil, err
}
......
......@@ -4,7 +4,10 @@
package net
import "syscall"
import (
"context"
"syscall"
)
// UDPAddr represents the address of a UDP end point.
type UDPAddr struct {
......@@ -55,7 +58,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
addrs, err := internetAddrList(net, addr, noDeadline)
addrs, err := internetAddrList(context.Background(), net, addr)
if err != nil {
return nil, err
}
......@@ -181,7 +184,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
c, err := dialUDP(net, laddr, raddr, noDeadline)
c, err := dialUDP(context.Background(), net, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
......@@ -204,7 +207,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
if laddr == nil {
laddr = &UDPAddr{}
}
c, err := listenUDP(net, laddr)
c, err := listenUDP(context.Background(), net, laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
}
......@@ -231,7 +234,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon
if gaddr == nil || gaddr.IP == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
}
c, err := listenMulticastUDP(network, ifi, gaddr)
c, err := listenMulticastUDP(context.Background(), network, ifi, gaddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err}
}
......
......@@ -5,10 +5,10 @@
package net
import (
"context"
"errors"
"os"
"syscall"
"time"
)
func (c *UDPConn) readFrom(b []byte) (n int, addr *UDPAddr, err error) {
......@@ -55,8 +55,8 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
return 0, 0, syscall.EPLAN9
}
func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
if !deadline.IsZero() {
func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
panic("net.dialUDP: deadline not implemented on Plan 9")
}
fd, err := dialPlan9(net, laddr, raddr)
......@@ -94,7 +94,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
return h, b
}
func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
l, err := listenPlan9(network, laddr)
if err != nil {
return nil, err
......@@ -111,6 +111,6 @@ func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
return newUDPConn(fd), err
}
func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
return nil, syscall.EPLAN9
}
......@@ -7,8 +7,8 @@
package net
import (
"context"
"syscall"
"time"
)
func sockaddrToUDP(sa syscall.Sockaddr) Addr {
......@@ -90,24 +90,24 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
return c.fd.writeMsg(b, oob, sa)
}
func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", noCancel)
func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial")
if err != nil {
return nil, err
}
return newUDPConn(fd), nil
}
func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen")
if err != nil {
return nil, err
}
return newUDPConn(fd), nil
}
func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(ctx, network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen")
if err != nil {
return nil, err
}
......
......@@ -5,6 +5,7 @@
package net
import (
"context"
"os"
"syscall"
"time"
......@@ -188,7 +189,7 @@ func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
default:
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)}
}
c, err := dialUnix(net, laddr, raddr, noDeadline)
c, err := dialUnix(context.Background(), net, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
......@@ -290,7 +291,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
if laddr == nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress}
}
ln, err := listenUnix(net, laddr)
ln, err := listenUnix(context.Background(), net, laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
}
......@@ -310,7 +311,7 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) {
if laddr == nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: errMissingAddress}
}
c, err := listenUnixgram(net, laddr)
c, err := listenUnixgram(context.Background(), net, laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
}
......
......@@ -5,9 +5,9 @@
package net
import (
"context"
"os"
"syscall"
"time"
)
func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
......@@ -26,7 +26,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
return 0, 0, syscall.EPLAN9
}
func dialUnix(network string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) {
func dialUnix(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
return nil, syscall.EPLAN9
}
......@@ -42,10 +42,10 @@ func (ln *UnixListener) file() (*os.File, error) {
return nil, syscall.EPLAN9
}
func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
return nil, syscall.EPLAN9
}
func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
return nil, syscall.EPLAN9
}
......@@ -7,13 +7,13 @@
package net
import (
"context"
"errors"
"os"
"syscall"
"time"
)
func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Time) (*netFD, error) {
func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) {
var sotype int
switch net {
case "unix":
......@@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti
return nil, errors.New("unknown mode: " + mode)
}
fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, noCancel)
fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr)
if err != nil {
return nil, err
}
......@@ -146,8 +146,8 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
return c.fd.writeMsg(b, oob, sa)
}
func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) {
fd, err := unixSocket(net, laddr, raddr, "dial", deadline)
func dialUnix(ctx context.Context, net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
fd, err := unixSocket(ctx, net, laddr, raddr, "dial")
if err != nil {
return nil, err
}
......@@ -187,16 +187,16 @@ func (ln *UnixListener) file() (*os.File, error) {
return f, nil
}
func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
fd, err := unixSocket(network, laddr, nil, "listen", noDeadline)
func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
fd, err := unixSocket(ctx, network, laddr, nil, "listen")
if err != nil {
return nil, err
}
return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil
}
func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
fd, err := unixSocket(network, laddr, nil, "listen", noDeadline)
func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
fd, err := unixSocket(ctx, network, laddr, nil, "listen")
if err != nil {
return nil, err
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment