Commit b6b4004d authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: context plumbing, add Dialer.DialContext

For #12580 (http.Transport tracing/analytics)
Updates #13021

Change-Id: I126e494a7bd872e42c388ecb58499ecbf0f014cc
Reviewed-on: https://go-review.googlesource.com/22101
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
Reviewed-by: default avatarMikio Hara <mikioh.mikioh@gmail.com>
parent 1d0977a1
...@@ -280,7 +280,9 @@ var pkgDeps = map[string][]string{ ...@@ -280,7 +280,9 @@ var pkgDeps = map[string][]string{
// Basic networking. // Basic networking.
// Because net must be used by any package that wants to // Because net must be used by any package that wants to
// do networking portably, it must have a small dependency set: just L0+basic os. // do networking portably, it must have a small dependency set: just L0+basic os.
"net": {"L0", "CGO", "math/rand", "os", "sort", "syscall", "time", "internal/syscall/windows", "internal/singleflight", "internal/race"}, "net": {"L0", "CGO",
"context", "math/rand", "os", "sort", "syscall", "time",
"internal/syscall/windows", "internal/singleflight", "internal/race"},
// NET enables use of basic network-related packages. // NET enables use of basic network-related packages.
"NET": { "NET": {
......
...@@ -7,7 +7,10 @@ ...@@ -7,7 +7,10 @@
package net package net
import "testing" import (
"context"
"testing"
)
func TestCgoLookupIP(t *testing.T) { func TestCgoLookupIP(t *testing.T) {
host := "localhost" host := "localhost"
...@@ -18,7 +21,7 @@ func TestCgoLookupIP(t *testing.T) { ...@@ -18,7 +21,7 @@ func TestCgoLookupIP(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if _, err := goLookupIP(host); err != nil { if _, err := goLookupIP(context.Background(), host); err != nil {
t.Error(err) t.Error(err)
} }
} }
This diff is collapsed.
...@@ -6,6 +6,7 @@ package net ...@@ -6,6 +6,7 @@ package net
import ( import (
"bufio" "bufio"
"context"
"internal/testenv" "internal/testenv"
"io" "io"
"net/internal/socktest" "net/internal/socktest"
...@@ -193,18 +194,11 @@ const ( ...@@ -193,18 +194,11 @@ const (
// In some environments, the slow IPs may be explicitly unreachable, and fail // In some environments, the slow IPs may be explicitly unreachable, and fail
// more quickly than expected. This test hook prevents dialTCP from returning // more quickly than expected. This test hook prevents dialTCP from returning
// before the deadline. // before the deadline.
func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { func slowDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
c, err := dialTCP(net, laddr, raddr, deadline, cancel) c, err := doDialTCP(ctx, net, laddr, raddr)
if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) { if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
// Wait for the deadline, or indefinitely if none exists. // Wait for the deadline, or indefinitely if none exists.
var wait <-chan time.Time <-ctx.Done()
if !deadline.IsZero() {
wait = time.After(deadline.Sub(time.Now()))
}
select {
case <-cancel:
case <-wait:
}
} }
return c, err return c, err
} }
...@@ -356,15 +350,14 @@ func TestDialParallel(t *testing.T) { ...@@ -356,15 +350,14 @@ func TestDialParallel(t *testing.T) {
d := Dialer{ d := Dialer{
FallbackDelay: fallbackDelay, FallbackDelay: fallbackDelay,
} }
ctx := &dialContext{
Dialer: d,
network: "tcp",
address: "?",
finalDeadline: d.deadline(time.Now()),
}
startTime := time.Now() startTime := time.Now()
c, err := dialParallel(ctx, primaries, fallbacks, nil) dp := &dialParam{
elapsed := time.Now().Sub(startTime) Dialer: d,
network: "tcp",
address: "?",
}
c, err := dialParallel(context.Background(), dp, primaries, fallbacks)
elapsed := time.Since(startTime)
if c != nil { if c != nil {
c.Close() c.Close()
...@@ -385,16 +378,16 @@ func TestDialParallel(t *testing.T) { ...@@ -385,16 +378,16 @@ func TestDialParallel(t *testing.T) {
} }
// Repeat each case, ensuring that it can be canceled quickly. // Repeat each case, ensuring that it can be canceled quickly.
cancel := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
close(cancel) cancel()
wg.Done() wg.Done()
}() }()
startTime = time.Now() startTime = time.Now()
c, err = dialParallel(ctx, primaries, fallbacks, cancel) c, err = dialParallel(ctx, dp, primaries, fallbacks)
if c != nil { if c != nil {
c.Close() c.Close()
} }
...@@ -406,7 +399,7 @@ func TestDialParallel(t *testing.T) { ...@@ -406,7 +399,7 @@ func TestDialParallel(t *testing.T) {
} }
} }
func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
switch host { switch host {
case "slow6loopback4": case "slow6loopback4":
// Returns a slow IPv6 address, and a local IPv4 address. // Returns a slow IPv6 address, and a local IPv4 address.
...@@ -415,7 +408,7 @@ func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, e ...@@ -415,7 +408,7 @@ func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, e
{IP: ParseIP("127.0.0.1")}, {IP: ParseIP("127.0.0.1")},
}, nil }, nil
default: default:
return fn(host) return fn(ctx, host)
} }
} }
...@@ -530,22 +523,24 @@ func TestDialParallelSpuriousConnection(t *testing.T) { ...@@ -530,22 +523,24 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
origTestHookDialTCP := testHookDialTCP origTestHookDialTCP := testHookDialTCP
defer func() { testHookDialTCP = origTestHookDialTCP }() defer func() { testHookDialTCP = origTestHookDialTCP }()
testHookDialTCP = func(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { testHookDialTCP = func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
// Sleep long enough for Happy Eyeballs to kick in, and inhibit cancelation. // Sleep long enough for Happy Eyeballs to kick in, and inhibit cancelation.
// This forces dialParallel to juggle two successful connections. // This forces dialParallel to juggle two successful connections.
time.Sleep(fallbackDelay * 2) time.Sleep(fallbackDelay * 2)
cancel = nil
return dialTCP(net, laddr, raddr, deadline, cancel) // Now ignore the provided context (which will be canceled) and use a
// different one to make sure this completes with a valid connection,
// which we hope to be closed below:
return doDialTCP(context.Background(), net, laddr, raddr)
} }
d := Dialer{ d := Dialer{
FallbackDelay: fallbackDelay, FallbackDelay: fallbackDelay,
} }
ctx := &dialContext{ dp := &dialParam{
Dialer: d, Dialer: d,
network: "tcp", network: "tcp",
address: "?", address: "?",
finalDeadline: d.deadline(time.Now()),
} }
makeAddr := func(ip string) addrList { makeAddr := func(ip string) addrList {
...@@ -557,7 +552,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) { ...@@ -557,7 +552,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
} }
// dialParallel returns one connection (and closes the other.) // dialParallel returns one connection (and closes the other.)
c, err := dialParallel(ctx, makeAddr("127.0.0.1"), makeAddr("::1"), nil) c, err := dialParallel(context.Background(), dp, makeAddr("127.0.0.1"), makeAddr("::1"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
package net package net
import ( import (
"context"
"errors" "errors"
"io" "io"
"math/rand" "math/rand"
...@@ -399,11 +400,11 @@ func (o hostLookupOrder) String() string { ...@@ -399,11 +400,11 @@ func (o hostLookupOrder) String() string {
// Normally we let cgo use the C library resolver instead of // Normally we let cgo use the C library resolver instead of
// depending on our lookup code, so that Go and C get the same // depending on our lookup code, so that Go and C get the same
// answers. // answers.
func goLookupHost(name string) (addrs []string, err error) { func goLookupHost(ctx context.Context, name string) (addrs []string, err error) {
return goLookupHostOrder(name, hostLookupFilesDNS) return goLookupHostOrder(ctx, name, hostLookupFilesDNS)
} }
func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) { func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles { if order == hostLookupFilesDNS || order == hostLookupFiles {
// Use entries from /etc/hosts if they match. // Use entries from /etc/hosts if they match.
addrs = lookupStaticHost(name) addrs = lookupStaticHost(name)
...@@ -411,7 +412,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err ...@@ -411,7 +412,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err
return return
} }
} }
ips, err := goLookupIPOrder(name, order) ips, err := goLookupIPOrder(ctx, name, order)
if err != nil { if err != nil {
return return
} }
...@@ -437,11 +438,11 @@ func goLookupIPFiles(name string) (addrs []IPAddr) { ...@@ -437,11 +438,11 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
// goLookupIP is the native Go implementation of LookupIP. // goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go. // The libc versions are in cgo_*.go.
func goLookupIP(name string) (addrs []IPAddr, err error) { func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) {
return goLookupIPOrder(name, hostLookupFilesDNS) return goLookupIPOrder(ctx, name, hostLookupFilesDNS)
} }
func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) { func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles { if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name) addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles { if len(addrs) > 0 || order == hostLookupFiles {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
package net package net
import ( import (
"context"
"fmt" "fmt"
"internal/testenv" "internal/testenv"
"io/ioutil" "io/ioutil"
...@@ -133,7 +134,7 @@ func TestAvoidDNSName(t *testing.T) { ...@@ -133,7 +134,7 @@ func TestAvoidDNSName(t *testing.T) {
// Issue 13705: don't try to resolve onion addresses, etc // Issue 13705: don't try to resolve onion addresses, etc
func TestLookupTorOnion(t *testing.T) { func TestLookupTorOnion(t *testing.T) {
addrs, err := goLookupIP("foo.onion") addrs, err := goLookupIP(context.Background(), "foo.onion")
if len(addrs) > 0 { if len(addrs) > 0 {
t.Errorf("unexpected addresses: %v", addrs) t.Errorf("unexpected addresses: %v", addrs)
} }
...@@ -249,7 +250,7 @@ func TestUpdateResolvConf(t *testing.T) { ...@@ -249,7 +250,7 @@ func TestUpdateResolvConf(t *testing.T) {
for j := 0; j < N; j++ { for j := 0; j < N; j++ {
go func(name string) { go func(name string) {
defer wg.Done() defer wg.Done()
ips, err := goLookupIP(name) ips, err := goLookupIP(context.Background(), name)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
...@@ -397,7 +398,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { ...@@ -397,7 +398,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
t.Error(err) t.Error(err)
continue continue
} }
addrs, err := goLookupIP(tt.name) addrs, err := goLookupIP(context.Background(), tt.name)
if err != nil { if err != nil {
// This test uses external network connectivity. // This test uses external network connectivity.
// We need to take care with errors on both // We need to take care with errors on both
...@@ -447,14 +448,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { ...@@ -447,14 +448,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
name := fmt.Sprintf("order %v", order) name := fmt.Sprintf("order %v", order)
// First ensure that we get an error when contacting a non-existent host. // First ensure that we get an error when contacting a non-existent host.
_, err := goLookupIPOrder("notarealhost", order) _, err := goLookupIPOrder(context.Background(), "notarealhost", order)
if err == nil { if err == nil {
t.Errorf("%s: expected error while looking up name not in hosts file", name) t.Errorf("%s: expected error while looking up name not in hosts file", name)
continue continue
} }
// Now check that we get an address when the name appears in the hosts file. // Now check that we get an address when the name appears in the hosts file.
addrs, err := goLookupIPOrder("thor", order) // entry is in "testdata/hosts" addrs, err := goLookupIPOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
if err != nil { if err != nil {
t.Errorf("%s: expected to successfully lookup host entry", name) t.Errorf("%s: expected to successfully lookup host entry", name)
continue continue
...@@ -510,7 +511,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -510,7 +511,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
return r, nil return r, nil
} }
_, err = goLookupIP(fqdn) _, err = goLookupIP(context.Background(), fqdn)
if err == nil { if err == nil {
t.Fatal("expected an error") t.Fatal("expected an error")
} }
...@@ -523,17 +524,19 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -523,17 +524,19 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
func BenchmarkGoLookupIP(b *testing.B) { func BenchmarkGoLookupIP(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks) testHookUninstaller.Do(uninstallTestHooks)
ctx := context.Background()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
goLookupIP("www.example.com") goLookupIP(ctx, "www.example.com")
} }
} }
func BenchmarkGoLookupIPNoSuchHost(b *testing.B) { func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks) testHookUninstaller.Do(uninstallTestHooks)
ctx := context.Background()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
goLookupIP("some.nonexistent") goLookupIP(ctx, "some.nonexistent")
} }
} }
...@@ -553,9 +556,10 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { ...@@ -553,9 +556,10 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
if err := conf.writeAndUpdate(lines); err != nil { if err := conf.writeAndUpdate(lines); err != nil {
b.Fatal(err) b.Fatal(err)
} }
ctx := context.Background()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
goLookupIP("www.example.com") goLookupIP(ctx, "www.example.com")
} }
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package net package net
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -138,7 +139,7 @@ func TestDialError(t *testing.T) { ...@@ -138,7 +139,7 @@ func TestDialError(t *testing.T) {
origTestHookLookupIP := testHookLookupIP origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }() defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { testHookLookupIP = func(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true} return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true}
} }
sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) { sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
...@@ -283,7 +284,7 @@ func TestListenError(t *testing.T) { ...@@ -283,7 +284,7 @@ func TestListenError(t *testing.T) {
origTestHookLookupIP := testHookLookupIP origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }() defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true} return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
} }
sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) { sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) {
...@@ -343,7 +344,7 @@ func TestListenPacketError(t *testing.T) { ...@@ -343,7 +344,7 @@ func TestListenPacketError(t *testing.T) {
origTestHookLookupIP := testHookLookupIP origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }() defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true} return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
} }
......
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
package net package net
import ( import (
"context"
"io" "io"
"os" "os"
"runtime" "runtime"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time"
) )
// Network file descriptor. // Network file descriptor.
...@@ -36,10 +36,6 @@ type netFD struct { ...@@ -36,10 +36,6 @@ type netFD struct {
func sysInit() { func sysInit() {
} }
func dial(network string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) {
return dialer(deadline)
}
func newFD(sysfd, family, sotype int, net string) (*netFD, error) { func newFD(sysfd, family, sotype int, net string) (*netFD, error) {
return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil
} }
...@@ -68,15 +64,17 @@ func (fd *netFD) name() string { ...@@ -68,15 +64,17 @@ func (fd *netFD) name() string {
return fd.net + ":" + ls + "->" + rs return fd.net + ":" + ls + "->" + rs
} }
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error { func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
// Do not need to call fd.writeLock here, // Do not need to call fd.writeLock here,
// because fd is not yet accessible to user, // because fd is not yet accessible to user,
// so no concurrent operations are possible. // so no concurrent operations are possible.
switch err := connectFunc(fd.sysfd, ra); err { switch err := connectFunc(fd.sysfd, ra); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case nil, syscall.EISCONN: case nil, syscall.EISCONN:
if !deadline.IsZero() && deadline.Before(time.Now()) { select {
return errTimeout case <-ctx.Done():
return mapErr(ctx.Err())
default:
} }
if err := fd.init(); err != nil { if err := fd.init(); err != nil {
return err return err
...@@ -98,27 +96,27 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c ...@@ -98,27 +96,27 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
if err := fd.init(); err != nil { if err := fd.init(); err != nil {
return err return err
} }
if !deadline.IsZero() { if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
fd.setWriteDeadline(deadline) fd.setWriteDeadline(deadline)
defer fd.setWriteDeadline(noDeadline) defer fd.setWriteDeadline(noDeadline)
} }
if cancel != nil {
done := make(chan bool) // Wait for the goroutine converting context.Done into a write timeout
defer func() { // to exist, otherwise our caller might cancel the context and
// This is unbuffered; wait for the goroutine before returning. // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
done <- true done := make(chan bool) // must be unbuffered
}() defer func() { done <- true }()
go func() { go func() {
select { select {
case <-cancel: case <-ctx.Done():
// Force the runtime's poller to immediately give // Force the runtime's poller to immediately give
// up waiting for writability. // up waiting for writability.
fd.setWriteDeadline(aLongTimeAgo) fd.setWriteDeadline(aLongTimeAgo)
<-done <-done
case <-done: case <-done:
} }
}() }()
}
for { for {
// Performing multiple connect system calls on a // Performing multiple connect system calls on a
// non-blocking socket under Unix variants does not // non-blocking socket under Unix variants does not
...@@ -130,8 +128,8 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c ...@@ -130,8 +128,8 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
// details. // details.
if err := fd.pd.waitWrite(); err != nil { if err := fd.pd.waitWrite(); err != nil {
select { select {
case <-cancel: case <-ctx.Done():
return errCanceled return mapErr(ctx.Err())
default: default:
} }
return err return err
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package net package net
import ( import (
"context"
"internal/race" "internal/race"
"os" "os"
"runtime" "runtime"
...@@ -320,14 +321,14 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { ...@@ -320,14 +321,14 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
runtime.SetFinalizer(fd, (*netFD).Close) runtime.SetFinalizer(fd, (*netFD).Close)
} }
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error { func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
// Do not need to call fd.writeLock here, // Do not need to call fd.writeLock here,
// because fd is not yet accessible to user, // because fd is not yet accessible to user,
// so no concurrent operations are possible. // so no concurrent operations are possible.
if err := fd.init(); err != nil { if err := fd.init(); err != nil {
return err return err
} }
if !deadline.IsZero() { if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
fd.setWriteDeadline(deadline) fd.setWriteDeadline(deadline)
defer fd.setWriteDeadline(noDeadline) defer fd.setWriteDeadline(noDeadline)
} }
...@@ -351,30 +352,30 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c ...@@ -351,30 +352,30 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-c
// Call ConnectEx API. // Call ConnectEx API.
o := &fd.wop o := &fd.wop
o.sa = ra o.sa = ra
if cancel != nil {
done := make(chan bool) // Wait for the goroutine converting context.Done into a write timeout
defer func() { // to exist, otherwise our caller might cancel the context and
// This is unbuffered; wait for the goroutine before returning. // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
done <- true done := make(chan bool) // must be unbuffered
}() defer func() { done <- true }()
go func() { go func() {
select { select {
case <-cancel: case <-ctx.Done():
// Force the runtime's poller to immediately give // Force the runtime's poller to immediately give
// up waiting for writability. // up waiting for writability.
fd.setWriteDeadline(aLongTimeAgo) fd.setWriteDeadline(aLongTimeAgo)
<-done <-done
case <-done: case <-done:
} }
}() }()
}
_, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error { _, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error {
return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o) return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o)
}) })
if err != nil { if err != nil {
select { select {
case <-cancel: case <-ctx.Done():
return errCanceled return mapErr(ctx.Err())
default: default:
if _, ok := err.(syscall.Errno); ok { if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("connectex", err) err = os.NewSyscallError("connectex", err)
......
...@@ -4,9 +4,19 @@ ...@@ -4,9 +4,19 @@
package net package net
import "context"
var ( var (
testHookDialTCP = dialTCP // if non-nil, overrides dialTCP.
testHookHostsPath = "/etc/hosts" testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
testHookLookupIP = func(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { return fn(host) }
testHookHostsPath = "/etc/hosts"
testHookLookupIP = func(
ctx context.Context,
fn func(context.Context, string) ([]IPAddr, error),
host string,
) ([]IPAddr, error) {
return fn(ctx, host)
}
testHookSetKeepAlive = func() {} testHookSetKeepAlive = func() {}
) )
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
package net package net
import "syscall" import (
"context"
"syscall"
)
// IPAddr represents the address of an IP end point. // IPAddr represents the address of an IP end point.
type IPAddr struct { type IPAddr struct {
...@@ -56,7 +59,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) { ...@@ -56,7 +59,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
default: default:
return nil, UnknownNetworkError(net) return nil, UnknownNetworkError(net)
} }
addrs, err := internetAddrList(afnet, addr, noDeadline) addrs, err := internetAddrList(context.Background(), afnet, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -171,7 +174,7 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} } ...@@ -171,7 +174,7 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
// netProto, which must be "ip", "ip4", or "ip6" followed by a colon // netProto, which must be "ip", "ip4", or "ip6" followed by a colon
// and a protocol number or name. // and a protocol number or name.
func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
c, err := dialIP(netProto, laddr, raddr, noDeadline) c, err := dialIP(context.Background(), netProto, laddr, raddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
} }
...@@ -183,7 +186,7 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { ...@@ -183,7 +186,7 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
// methods can be used to receive and send IP packets with per-packet // methods can be used to receive and send IP packets with per-packet
// addressing. // addressing.
func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
c, err := listenIP(netProto, laddr) c, err := listenIP(context.Background(), netProto, laddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err} return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err}
} }
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
package net package net
import ( import (
"context"
"syscall" "syscall"
"time"
) )
func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) { func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) {
...@@ -25,10 +25,10 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) ...@@ -25,10 +25,10 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return 0, 0, syscall.EPLAN9 return 0, 0, syscall.EPLAN9
} }
func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) { func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
package net package net
import ( import (
"context"
"syscall" "syscall"
"time"
) )
// BUG(mikio): On every POSIX platform, reads from the "ip4" network // BUG(mikio): On every POSIX platform, reads from the "ip4" network
...@@ -120,7 +120,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) ...@@ -120,7 +120,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return c.fd.writeMsg(b, oob, sa) return c.fd.writeMsg(b, oob, sa)
} }
func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(netProto) network, proto, err := parseNetwork(netProto)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -133,14 +133,14 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, ...@@ -133,14 +133,14 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn,
if raddr == nil { if raddr == nil {
return nil, errMissingAddress return nil, errMissingAddress
} }
fd, err := internetSocket(network, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", noCancel) fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newIPConn(fd), nil return newIPConn(fd), nil
} }
func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) { func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(netProto) network, proto, err := parseNetwork(netProto)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -150,7 +150,7 @@ func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) { ...@@ -150,7 +150,7 @@ func listenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
default: default:
return nil, UnknownNetworkError(netProto) return nil, UnknownNetworkError(netProto)
} }
fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", noCancel) fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -6,7 +6,9 @@ ...@@ -6,7 +6,9 @@
package net package net
import "time" import (
"context"
)
var ( var (
// supportsIPv4 reports whether the platform supports IPv4 // supportsIPv4 reports whether the platform supports IPv4
...@@ -188,7 +190,7 @@ func JoinHostPort(host, port string) string { ...@@ -188,7 +190,7 @@ func JoinHostPort(host, port string) string {
// address or a DNS name, and returns a list of internet protocol // address or a DNS name, and returns a list of internet protocol
// family addresses. The result contains at least one address when // family addresses. The result contains at least one address when
// error is nil. // error is nil.
func internetAddrList(net, addr string, deadline time.Time) (addrList, error) { func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
var ( var (
err error err error
host, port string host, port string
...@@ -236,7 +238,7 @@ func internetAddrList(net, addr string, deadline time.Time) (addrList, error) { ...@@ -236,7 +238,7 @@ func internetAddrList(net, addr string, deadline time.Time) (addrList, error) {
return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil
} }
// Try as a DNS name. // Try as a DNS name.
ips, err := lookupIPDeadline(host, deadline) ips, err := lookupIPContext(ctx, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
package net package net
import ( import (
"context"
"runtime" "runtime"
"syscall" "syscall"
"time"
) )
// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the // BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the
...@@ -152,9 +152,10 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family ...@@ -152,9 +152,10 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
return syscall.AF_INET6, false return syscall.AF_INET6, false
} }
func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, cancel <-chan struct{}) (fd *netFD, err error) { // Internet sockets (TCP, UDP, IP)
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, cancel) return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
} }
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) { func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
package net package net
import ( import (
"context"
"internal/singleflight" "internal/singleflight"
"time"
) )
// protocols contains minimal mappings between internet protocol // protocols contains minimal mappings between internet protocol
...@@ -33,7 +33,7 @@ func LookupHost(host string) (addrs []string, err error) { ...@@ -33,7 +33,7 @@ func LookupHost(host string) (addrs []string, err error) {
if ip := ParseIP(host); ip != nil { if ip := ParseIP(host); ip != nil {
return []string{host}, nil return []string{host}, nil
} }
return lookupHost(host) return lookupHost(context.Background(), host)
} }
// LookupIP looks up host using the local resolver. // LookupIP looks up host using the local resolver.
...@@ -47,7 +47,7 @@ func LookupIP(host string) (ips []IP, err error) { ...@@ -47,7 +47,7 @@ func LookupIP(host string) (ips []IP, err error) {
if ip := ParseIP(host); ip != nil { if ip := ParseIP(host); ip != nil {
return []IP{ip}, nil return []IP{ip}, nil
} }
addrs, err := lookupIPMerge(host) addrs, err := lookupIPMerge(context.Background(), host)
if err != nil { if err != nil {
return return
} }
...@@ -63,9 +63,9 @@ var lookupGroup singleflight.Group ...@@ -63,9 +63,9 @@ var lookupGroup singleflight.Group
// lookupIPMerge wraps lookupIP, but makes sure that for any given // lookupIPMerge wraps lookupIP, but makes sure that for any given
// host, only one lookup is in-flight at a time. The returned memory // host, only one lookup is in-flight at a time. The returned memory
// is always owned by the caller. // is always owned by the caller.
func lookupIPMerge(host string) (addrs []IPAddr, err error) { func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) {
addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) { addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) {
return testHookLookupIP(lookupIP, host) return testHookLookupIP(ctx, lookupIP, host)
}) })
return lookupIPReturn(addrsi, err, shared) return lookupIPReturn(addrsi, err, shared)
} }
...@@ -85,37 +85,26 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error ...@@ -85,37 +85,26 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error
return addrs, nil return addrs, nil
} }
// lookupIPDeadline looks up a hostname with a deadline. // lookupIPContext looks up a hostname with a context.
func lookupIPDeadline(host string, deadline time.Time) (addrs []IPAddr, err error) { func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) {
if deadline.IsZero() { // TODO(bradfitz): when adding trace hooks later here, make
return lookupIPMerge(host) // sure the tracing is done outside of the singleflight
} // merging. Both callers should see the DNS lookup delay, even
// if it's only being done once. The r.Shared bit can be
// We could push the deadline down into the name resolution // included in the trace for callers who need it.
// functions. However, the most commonly used implementation
// calls getaddrinfo, which has no timeout.
timeout := deadline.Sub(time.Now())
if timeout <= 0 {
return nil, errTimeout
}
t := time.NewTimer(timeout)
defer t.Stop()
ch := lookupGroup.DoChan(host, func() (interface{}, error) { ch := lookupGroup.DoChan(host, func() (interface{}, error) {
return testHookLookupIP(lookupIP, host) return testHookLookupIP(ctx, lookupIP, host)
}) })
select { select {
case <-t.C: case <-ctx.Done():
// The DNS lookup timed out for some reason. Force // The DNS lookup timed out for some reason. Force
// future requests to start the DNS lookup again // future requests to start the DNS lookup again
// rather than waiting for the current lookup to // rather than waiting for the current lookup to
// complete. See issue 8602. // complete. See issue 8602.
lookupGroup.Forget(host) lookupGroup.Forget(host)
return nil, mapErr(ctx.Err())
return nil, errTimeout
case r := <-ch: case r := <-ch:
return lookupIPReturn(r.Val, r.Err, r.Shared) return lookupIPReturn(r.Val, r.Err, r.Shared)
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package net package net
import ( import (
"context"
"errors" "errors"
"os" "os"
) )
...@@ -115,7 +116,7 @@ func lookupProtocol(name string) (proto int, err error) { ...@@ -115,7 +116,7 @@ func lookupProtocol(name string) (proto int, err error) {
return 0, UnknownNetworkError(name) return 0, UnknownNetworkError(name)
} }
func lookupHost(host string) (addrs []string, err error) { func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
// Use netdir/cs instead of netdir/dns because cs knows about // Use netdir/cs instead of netdir/dns because cs knows about
// host names in local network (e.g. from /lib/ndb/local) // host names in local network (e.g. from /lib/ndb/local)
lines, err := queryCS("net", host, "1") lines, err := queryCS("net", host, "1")
...@@ -146,7 +147,8 @@ loop: ...@@ -146,7 +147,8 @@ loop:
return return
} }
func lookupIP(host string) (addrs []IPAddr, err error) { func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
// TODO(bradfitz): push down ctx
lits, err := LookupHost(host) lits, err := LookupHost(host)
if err != nil { if err != nil {
return return
......
...@@ -6,17 +6,20 @@ ...@@ -6,17 +6,20 @@
package net package net
import "syscall" import (
"context"
"syscall"
)
func lookupProtocol(name string) (proto int, err error) { func lookupProtocol(name string) (proto int, err error) {
return 0, syscall.ENOPROTOOPT return 0, syscall.ENOPROTOOPT
} }
func lookupHost(host string) (addrs []string, err error) { func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
return nil, syscall.ENOPROTOOPT return nil, syscall.ENOPROTOOPT
} }
func lookupIP(host string) (addrs []IPAddr, err error) { func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
return nil, syscall.ENOPROTOOPT return nil, syscall.ENOPROTOOPT
} }
......
...@@ -6,6 +6,7 @@ package net ...@@ -6,6 +6,7 @@ package net
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"internal/testenv" "internal/testenv"
"runtime" "runtime"
...@@ -14,7 +15,7 @@ import ( ...@@ -14,7 +15,7 @@ import (
"time" "time"
) )
func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { func lookupLocalhost(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
switch host { switch host {
case "localhost": case "localhost":
return []IPAddr{ return []IPAddr{
...@@ -22,7 +23,7 @@ func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, ...@@ -22,7 +23,7 @@ func lookupLocalhost(fn func(string) ([]IPAddr, error), host string) ([]IPAddr,
{IP: IPv6loopback}, {IP: IPv6loopback},
}, nil }, nil
default: default:
return fn(host) return fn(ctx, host)
} }
} }
...@@ -375,15 +376,20 @@ func TestLookupIPDeadline(t *testing.T) { ...@@ -375,15 +376,20 @@ func TestLookupIPDeadline(t *testing.T) {
const N = 5000 const N = 5000
const timeout = 3 * time.Second const timeout = 3 * time.Second
ctxHalfTimeout, cancel := context.WithTimeout(context.Background(), timeout/2)
defer cancel()
ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
c := make(chan error, 2*N) c := make(chan error, 2*N)
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
name := fmt.Sprintf("%d.net-test.golang.org", i) name := fmt.Sprintf("%d.net-test.golang.org", i)
go func() { go func() {
_, err := lookupIPDeadline(name, time.Now().Add(timeout/2)) _, err := lookupIPContext(ctxHalfTimeout, name)
c <- err c <- err
}() }()
go func() { go func() {
_, err := lookupIPDeadline(name, time.Now().Add(timeout)) _, err := lookupIPContext(ctxTimeout, name)
c <- err c <- err
}() }()
} }
......
...@@ -6,7 +6,10 @@ ...@@ -6,7 +6,10 @@
package net package net
import "sync" import (
"context"
"sync"
)
var onceReadProtocols sync.Once var onceReadProtocols sync.Once
...@@ -49,7 +52,7 @@ func lookupProtocol(name string) (int, error) { ...@@ -49,7 +52,7 @@ func lookupProtocol(name string) (int, error) {
return proto, nil return proto, nil
} }
func lookupHost(host string) (addrs []string, err error) { func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
order := systemConf().hostLookupOrder(host) order := systemConf().hostLookupOrder(host)
if order == hostLookupCgo { if order == hostLookupCgo {
if addrs, err, ok := cgoLookupHost(host); ok { if addrs, err, ok := cgoLookupHost(host); ok {
...@@ -58,19 +61,20 @@ func lookupHost(host string) (addrs []string, err error) { ...@@ -58,19 +61,20 @@ func lookupHost(host string) (addrs []string, err error) {
// cgo not available (or netgo); fall back to Go's DNS resolver // cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS order = hostLookupFilesDNS
} }
return goLookupHostOrder(host, order) return goLookupHostOrder(ctx, host, order)
} }
func lookupIP(host string) (addrs []IPAddr, err error) { func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
order := systemConf().hostLookupOrder(host) order := systemConf().hostLookupOrder(host)
if order == hostLookupCgo { if order == hostLookupCgo {
// TODO(bradfitz): push down ctx, or at least its deadline to start
if addrs, err, ok := cgoLookupIP(host); ok { if addrs, err, ok := cgoLookupIP(host); ok {
return addrs, err return addrs, err
} }
// cgo not available (or netgo); fall back to Go's DNS resolver // cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS order = hostLookupFilesDNS
} }
return goLookupIPOrder(host, order) return goLookupIPOrder(ctx, host, order)
} }
func lookupPort(network, service string) (int, error) { func lookupPort(network, service string) (int, error) {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package net package net
import ( import (
"context"
"os" "os"
"runtime" "runtime"
"syscall" "syscall"
...@@ -51,8 +52,8 @@ func lookupProtocol(name string) (int, error) { ...@@ -51,8 +52,8 @@ func lookupProtocol(name string) (int, error) {
return r.proto, r.err return r.proto, r.err
} }
func lookupHost(name string) ([]string, error) { func lookupHost(ctx context.Context, name string) ([]string, error) {
ips, err := LookupIP(name) ips, err := lookupIP(ctx, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -83,59 +84,97 @@ func gethostbyname(name string) (addrs []IPAddr, err error) { ...@@ -83,59 +84,97 @@ func gethostbyname(name string) (addrs []IPAddr, err error) {
return addrs, nil return addrs, nil
} }
func oldLookupIP(name string) ([]IPAddr, error) { func oldLookupIP(ctx context.Context, name string) ([]IPAddr, error) {
// GetHostByName return value is stored in thread local storage. // GetHostByName return value is stored in thread local storage.
// Start new os thread before the call to prevent races. // Start new os thread before the call to prevent races.
type result struct { type ret struct {
addrs []IPAddr addrs []IPAddr
err error err error
} }
ch := make(chan result) ch := make(chan ret, 1)
go func() { go func() {
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
addrs, err := gethostbyname(name) addrs, err := gethostbyname(name)
ch <- result{addrs: addrs, err: err} ch <- ret{addrs: addrs, err: err}
}() }()
r := <-ch select {
if r.err != nil { case r := <-ch:
r.err = &DNSError{Err: r.err.Error(), Name: name} if r.err != nil {
r.err = &DNSError{Err: r.err.Error(), Name: name}
}
return r.addrs, r.err
case <-ctx.Done():
// TODO(bradfitz,brainman): cancel the ongoing
// gethostbyname? For now we just let it finish and
// write to the buffered channel.
return nil, &DNSError{
Name: name,
Err: ctx.Err().Error(),
IsTimeout: ctx.Err() == context.DeadlineExceeded,
}
} }
return r.addrs, r.err
} }
func newLookupIP(name string) ([]IPAddr, error) { func newLookupIP(ctx context.Context, name string) ([]IPAddr, error) {
acquireThread() // TODO(bradfitz,brainman): use ctx?
defer releaseThread()
hints := syscall.AddrinfoW{ type ret struct {
Family: syscall.AF_UNSPEC, addrs []IPAddr
Socktype: syscall.SOCK_STREAM, err error
Protocol: syscall.IPPROTO_IP,
}
var result *syscall.AddrinfoW
e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
if e != nil {
return nil, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}
} }
defer syscall.FreeAddrInfoW(result) ch := make(chan ret, 1)
addrs := make([]IPAddr, 0, 5) go func() {
for ; result != nil; result = result.Next { acquireThread()
addr := unsafe.Pointer(result.Addr) defer releaseThread()
switch result.Family { hints := syscall.AddrinfoW{
case syscall.AF_INET: Family: syscall.AF_UNSPEC,
a := (*syscall.RawSockaddrInet4)(addr).Addr Socktype: syscall.SOCK_STREAM,
addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])}) Protocol: syscall.IPPROTO_IP,
case syscall.AF_INET6: }
a := (*syscall.RawSockaddrInet6)(addr).Addr var result *syscall.AddrinfoW
zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id)) e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone}) if e != nil {
default: ch <- ret{err: &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}}
return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name} }
defer syscall.FreeAddrInfoW(result)
addrs := make([]IPAddr, 0, 5)
for ; result != nil; result = result.Next {
addr := unsafe.Pointer(result.Addr)
switch result.Family {
case syscall.AF_INET:
a := (*syscall.RawSockaddrInet4)(addr).Addr
addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
case syscall.AF_INET6:
a := (*syscall.RawSockaddrInet6)(addr).Addr
zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
default:
ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}}
}
}
ch <- ret{addrs: addrs}
}()
select {
case r := <-ch:
return r.addrs, r.err
case <-ctx.Done():
// TODO(bradfitz,brainman): cancel the ongoing
// GetAddrInfoW? It would require conditionally using
// GetAddrInfoEx with lpOverlapped, which requires
// Windows 8 or newer. I guess we'll need oldLookupIP,
// newLookupIP, and newerLookUP.
//
// For now we just let it finish and write to the
// buffered channel.
return nil, &DNSError{
Name: name,
Err: ctx.Err().Error(),
IsTimeout: ctx.Err() == context.DeadlineExceeded,
} }
} }
return addrs, nil
} }
func getservbyname(network, service string) (int, error) { func getservbyname(network, service string) (int, error) {
......
...@@ -79,6 +79,7 @@ On Windows, the resolver always uses C library functions, such as GetAddrInfo an ...@@ -79,6 +79,7 @@ On Windows, the resolver always uses C library functions, such as GetAddrInfo an
package net package net
import ( import (
"context"
"errors" "errors"
"io" "io"
"os" "os"
...@@ -377,6 +378,22 @@ var ( ...@@ -377,6 +378,22 @@ var (
ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection") ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection")
) )
// mapErr maps from the context errors to the historical internal net
// error values.
//
// TODO(bradfitz): get rid of this after adjusting tests and making
// context.DeadlineExceeded implement net.Error?
func mapErr(err error) error {
switch err {
case context.Canceled:
return errCanceled
case context.DeadlineExceeded:
return errTimeout
default:
return err
}
}
// OpError is the error type usually returned by functions in the net // OpError is the error type usually returned by functions in the net
// package. It describes the operation, network type, and address of // package. It describes the operation, network type, and address of
// an error. // an error.
......
...@@ -7,7 +7,10 @@ ...@@ -7,7 +7,10 @@
package net package net
import "testing" import (
"context"
"testing"
)
func TestGoLookupIP(t *testing.T) { func TestGoLookupIP(t *testing.T) {
host := "localhost" host := "localhost"
...@@ -18,7 +21,7 @@ func TestGoLookupIP(t *testing.T) { ...@@ -18,7 +21,7 @@ func TestGoLookupIP(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if _, err := goLookupIP(host); err != nil { if _, err := goLookupIP(context.Background(), host); err != nil {
t.Error(err) t.Error(err)
} }
} }
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
package net package net
import ( import (
"context"
"os" "os"
"syscall" "syscall"
"time"
) )
// A sockaddr represents a TCP, UDP, IP or Unix network endpoint // A sockaddr represents a TCP, UDP, IP or Unix network endpoint
...@@ -34,7 +34,7 @@ type sockaddr interface { ...@@ -34,7 +34,7 @@ type sockaddr interface {
// socket returns a network file descriptor that is ready for // socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller. // asynchronous I/O using the network poller.
func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) (fd *netFD, err error) { func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto) s, err := sysSocket(family, sotype, proto)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s ...@@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s
return fd, nil return fd, nil
} }
} }
if err := fd.dial(laddr, raddr, deadline, cancel); err != nil { if err := fd.dial(ctx, laddr, raddr); err != nil {
fd.Close() fd.Close()
return nil, err return nil, err
} }
...@@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr { ...@@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
return func(syscall.Sockaddr) Addr { return nil } return func(syscall.Sockaddr) Addr { return nil }
} }
func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) error { func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
var err error var err error
var lsa syscall.Sockaddr var lsa syscall.Sockaddr
if laddr != nil { if laddr != nil {
...@@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan s ...@@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan s
if rsa, err = raddr.sockaddr(fd.family); err != nil { if rsa, err = raddr.sockaddr(fd.family); err != nil {
return err return err
} }
if err := fd.connect(lsa, rsa, deadline, cancel); err != nil { if err := fd.connect(ctx, lsa, rsa); err != nil {
return err return err
} }
fd.isConnected = true fd.isConnected = true
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package net package net
import ( import (
"context"
"io" "io"
"os" "os"
"syscall" "syscall"
...@@ -60,7 +61,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { ...@@ -60,7 +61,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
default: default:
return nil, UnknownNetworkError(net) return nil, UnknownNetworkError(net)
} }
addrs, err := internetAddrList(net, addr, noDeadline) addrs, err := internetAddrList(context.Background(), net, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -186,7 +187,7 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { ...@@ -186,7 +187,7 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if raddr == nil { if raddr == nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
} }
c, err := dialTCP(net, laddr, raddr, noDeadline, noCancel) c, err := dialTCP(context.Background(), net, laddr, raddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
} }
...@@ -285,7 +286,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { ...@@ -285,7 +286,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
if laddr == nil { if laddr == nil {
laddr = &TCPAddr{} laddr = &TCPAddr{}
} }
ln, err := listenTCP(net, laddr) ln, err := listenTCP(context.Background(), net, laddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
} }
......
...@@ -5,17 +5,24 @@ ...@@ -5,17 +5,24 @@
package net package net
import ( import (
"context"
"io" "io"
"os" "os"
"time"
) )
func (c *TCPConn) readFrom(r io.Reader) (int64, error) { func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r) return genericReadFrom(c, r)
} }
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if !deadline.IsZero() { if testHookDialTCP != nil {
return testHookDialTCP(ctx, net, laddr, raddr)
}
return doDialTCP(ctx, net, laddr, raddr)
}
func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if d, _ := ctx.Deadline(); !d.IsZero() {
panic("net.dialTCP: deadline not implemented on Plan 9") panic("net.dialTCP: deadline not implemented on Plan 9")
} }
// TODO(bradfitz,0intro): also use the cancel channel. // TODO(bradfitz,0intro): also use the cancel channel.
...@@ -63,7 +70,7 @@ func (ln *TCPListener) file() (*os.File, error) { ...@@ -63,7 +70,7 @@ func (ln *TCPListener) file() (*os.File, error) {
return f, nil return f, nil
} }
func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) { func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
fd, err := listenPlan9(network, laddr) fd, err := listenPlan9(network, laddr)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -7,10 +7,10 @@ ...@@ -7,10 +7,10 @@
package net package net
import ( import (
"context"
"io" "io"
"os" "os"
"syscall" "syscall"
"time"
) )
func sockaddrToTCP(sa syscall.Sockaddr) Addr { func sockaddrToTCP(sa syscall.Sockaddr) Addr {
...@@ -47,8 +47,15 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { ...@@ -47,8 +47,15 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r) return genericReadFrom(c, r)
} }
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel) if testHookDialTCP != nil {
return testHookDialTCP(ctx, net, laddr, raddr)
}
return doDialTCP(ctx, net, laddr, raddr)
}
func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
// TCP has a rarely used mechanism called a 'simultaneous connection' in // TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can // which Dial("tcp", addr1, addr2) run on the machine at addr1 can
...@@ -78,7 +85,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-cha ...@@ -78,7 +85,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-cha
if err == nil { if err == nil {
fd.Close() fd.Close()
} }
fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel) fd, err = internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
} }
if err != nil { if err != nil {
...@@ -141,8 +148,8 @@ func (ln *TCPListener) file() (*os.File, error) { ...@@ -141,8 +148,8 @@ func (ln *TCPListener) file() (*os.File, error) {
return f, nil return f, nil
} }
func listenTCP(network string, laddr *TCPAddr) (*TCPListener, error) { func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", noCancel) fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_STREAM, 0, "listen")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
package net package net
import "syscall" import (
"context"
"syscall"
)
// UDPAddr represents the address of a UDP end point. // UDPAddr represents the address of a UDP end point.
type UDPAddr struct { type UDPAddr struct {
...@@ -55,7 +58,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { ...@@ -55,7 +58,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
default: default:
return nil, UnknownNetworkError(net) return nil, UnknownNetworkError(net)
} }
addrs, err := internetAddrList(net, addr, noDeadline) addrs, err := internetAddrList(context.Background(), net, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -181,7 +184,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { ...@@ -181,7 +184,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
if raddr == nil { if raddr == nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
} }
c, err := dialUDP(net, laddr, raddr, noDeadline) c, err := dialUDP(context.Background(), net, laddr, raddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
} }
...@@ -204,7 +207,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { ...@@ -204,7 +207,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
if laddr == nil { if laddr == nil {
laddr = &UDPAddr{} laddr = &UDPAddr{}
} }
c, err := listenUDP(net, laddr) c, err := listenUDP(context.Background(), net, laddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
} }
...@@ -231,7 +234,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon ...@@ -231,7 +234,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon
if gaddr == nil || gaddr.IP == nil { if gaddr == nil || gaddr.IP == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress} return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
} }
c, err := listenMulticastUDP(network, ifi, gaddr) c, err := listenMulticastUDP(context.Background(), network, ifi, gaddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err} return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err}
} }
......
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
package net package net
import ( import (
"context"
"errors" "errors"
"os" "os"
"syscall" "syscall"
"time"
) )
func (c *UDPConn) readFrom(b []byte) (n int, addr *UDPAddr, err error) { func (c *UDPConn) readFrom(b []byte) (n int, addr *UDPAddr, err error) {
...@@ -55,8 +55,8 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error ...@@ -55,8 +55,8 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
return 0, 0, syscall.EPLAN9 return 0, 0, syscall.EPLAN9
} }
func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
if !deadline.IsZero() { if deadline, _ := ctx.Deadline(); !deadline.IsZero() {
panic("net.dialUDP: deadline not implemented on Plan 9") panic("net.dialUDP: deadline not implemented on Plan 9")
} }
fd, err := dialPlan9(net, laddr, raddr) fd, err := dialPlan9(net, laddr, raddr)
...@@ -94,7 +94,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { ...@@ -94,7 +94,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
return h, b return h, b
} }
func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
l, err := listenPlan9(network, laddr) l, err := listenPlan9(network, laddr)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -111,6 +111,6 @@ func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { ...@@ -111,6 +111,6 @@ func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
return newUDPConn(fd), err return newUDPConn(fd), err
} }
func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
package net package net
import ( import (
"context"
"syscall" "syscall"
"time"
) )
func sockaddrToUDP(sa syscall.Sockaddr) Addr { func sockaddrToUDP(sa syscall.Sockaddr) Addr {
...@@ -90,24 +90,24 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error ...@@ -90,24 +90,24 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
return c.fd.writeMsg(b, oob, sa) return c.fd.writeMsg(b, oob, sa)
} }
func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", noCancel) fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newUDPConn(fd), nil return newUDPConn(fd), nil
} }
func listenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(network, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel) fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newUDPConn(fd), nil return newUDPConn(fd), nil
} }
func listenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel) fd, err := internetSocket(ctx, network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package net package net
import ( import (
"context"
"os" "os"
"syscall" "syscall"
"time" "time"
...@@ -188,7 +189,7 @@ func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) { ...@@ -188,7 +189,7 @@ func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
default: default:
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)} return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)}
} }
c, err := dialUnix(net, laddr, raddr, noDeadline) c, err := dialUnix(context.Background(), net, laddr, raddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
} }
...@@ -290,7 +291,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { ...@@ -290,7 +291,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
if laddr == nil { if laddr == nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress}
} }
ln, err := listenUnix(net, laddr) ln, err := listenUnix(context.Background(), net, laddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
} }
...@@ -310,7 +311,7 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { ...@@ -310,7 +311,7 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) {
if laddr == nil { if laddr == nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: errMissingAddress} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: errMissingAddress}
} }
c, err := listenUnixgram(net, laddr) c, err := listenUnixgram(context.Background(), net, laddr)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
} }
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
package net package net
import ( import (
"context"
"os" "os"
"syscall" "syscall"
"time"
) )
func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) { func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
...@@ -26,7 +26,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err ...@@ -26,7 +26,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
return 0, 0, syscall.EPLAN9 return 0, 0, syscall.EPLAN9
} }
func dialUnix(network string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { func dialUnix(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
...@@ -42,10 +42,10 @@ func (ln *UnixListener) file() (*os.File, error) { ...@@ -42,10 +42,10 @@ func (ln *UnixListener) file() (*os.File, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) { func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) { func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
...@@ -7,13 +7,13 @@ ...@@ -7,13 +7,13 @@
package net package net
import ( import (
"context"
"errors" "errors"
"os" "os"
"syscall" "syscall"
"time"
) )
func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Time) (*netFD, error) { func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) {
var sotype int var sotype int
switch net { switch net {
case "unix": case "unix":
...@@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti ...@@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti
return nil, errors.New("unknown mode: " + mode) return nil, errors.New("unknown mode: " + mode)
} }
fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, noCancel) fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -146,8 +146,8 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err ...@@ -146,8 +146,8 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
return c.fd.writeMsg(b, oob, sa) return c.fd.writeMsg(b, oob, sa)
} }
func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { func dialUnix(ctx context.Context, net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
fd, err := unixSocket(net, laddr, raddr, "dial", deadline) fd, err := unixSocket(ctx, net, laddr, raddr, "dial")
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -187,16 +187,16 @@ func (ln *UnixListener) file() (*os.File, error) { ...@@ -187,16 +187,16 @@ func (ln *UnixListener) file() (*os.File, error) {
return f, nil return f, nil
} }
func listenUnix(network string, laddr *UnixAddr) (*UnixListener, error) { func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
fd, err := unixSocket(network, laddr, nil, "listen", noDeadline) fd, err := unixSocket(ctx, network, laddr, nil, "listen")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil
} }
func listenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) { func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
fd, err := unixSocket(network, laddr, nil, "listen", noDeadline) fd, err := unixSocket(ctx, network, laddr, nil, "listen")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment