Commit 2bc5f125 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: add Resolver type, Dialer.Resolver, and DefaultResolver

The new Resolver type (a struct) has 9 Lookup methods, all taking a
context.Context.

There's now a new DefaultResolver global, like http's
DefaultTransport and DefaultClient.

net.Dialer now has an optional Resolver field to set the Resolver.

This also does finishes some resolver cleanup internally, deleting
lookupIPMerge and renaming lookupIPContext into Resolver.LookupIPAddr.

The Resolver currently doesn't let you tweak much, but it's a struct
specifically so we can add knobs in the future. Currently I just added
a bool to force the pure Go resolver. In the future we could let
people provide an interface to implement the methods, or add a Timeout
time.Duration, which would wrap all provided contexts in a
context.WithTimeout.

Fixes #16672

Change-Id: I7ba1f886704f06def7b6b5c4da9809db51bc1495
Reviewed-on: https://go-review.googlesource.com/29440
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
parent ea143c29
......@@ -59,6 +59,9 @@ type Dialer struct {
// that do not support keep-alives ignore this field.
KeepAlive time.Duration
// Resolver optionally specifies an alternate resolver to use.
Resolver *Resolver
// Cancel is an optional channel whose closure indicates that
// the dial should be canceled. Not all types of dials support
// cancelation.
......@@ -92,6 +95,13 @@ func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Tim
return minNonzeroTime(earliest, d.Deadline)
}
func (d *Dialer) resolver() *Resolver {
if d.Resolver != nil {
return d.Resolver
}
return DefaultResolver
}
// partialDeadline returns the deadline to use for a single address,
// when multiple addresses are pending.
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
......@@ -156,7 +166,7 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err
// resolverAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
afnet, _, err := parseNetwork(ctx, network)
if err != nil {
return nil, err
......@@ -166,7 +176,6 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
}
switch afnet {
case "unix", "unixgram", "unixpacket":
// TODO(bradfitz): push down context
addr, err := ResolveUnixAddr(afnet, addr)
if err != nil {
return nil, err
......@@ -176,7 +185,7 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
}
return addrList{addr}, nil
}
addrs, err := internetAddrList(ctx, afnet, addr)
addrs, err := r.internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
}
......@@ -326,7 +335,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
}
addrs, err := resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}
......@@ -525,7 +534,7 @@ func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error)
// instead of just the interface with the given host address.
// See Dial for more details about address syntax.
func Listen(net, laddr string) (Listener, error) {
addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
......@@ -552,7 +561,7 @@ func Listen(net, laddr string) (Listener, error) {
// instead of just the interface with the given host address.
// See Dial for the syntax of laddr.
func ListenPacket(net, laddr string) (PacketConn, error) {
addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
......
......@@ -456,8 +456,9 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
// goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go.
func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) {
return goLookupIPOrder(ctx, name, hostLookupFilesDNS)
func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
order := systemConf().hostLookupOrder(host)
return goLookupIPOrder(ctx, host, order)
}
func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) {
......
......@@ -65,7 +65,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
addrs, err := internetAddrList(context.Background(), afnet, addr)
addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, addr)
if err != nil {
return nil, err
}
......
......@@ -190,7 +190,7 @@ func JoinHostPort(host, port string) string {
// address or a DNS name, and returns a list of internet protocol
// family addresses. The result contains at least one address when
// error is nil.
func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
var (
err error
host, port string
......@@ -202,7 +202,7 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
if host, port, err = SplitHostPort(addr); err != nil {
return nil, err
}
if portnum, err = LookupPort(net, port); err != nil {
if portnum, err = r.LookupPort(ctx, net, port); err != nil {
return nil, err
}
}
......@@ -238,7 +238,7 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil
}
// Try as a DNS name.
ips, err := lookupIPContext(ctx, host)
ips, err := r.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
......
......@@ -84,86 +84,75 @@ func lookupPortMap(network, service string) (port int, error error) {
return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
}
// DefaultResolver is the resolver used by the package-level Lookup
// functions and by Dialers without a specified Resolver.
var DefaultResolver = &Resolver{}
// A Resolver looks up names and numbers.
//
// A nil *Resolver is equivalent to a zero Resolver.
type Resolver struct {
// PreferGo controls whether Go's built-in DNS resolver is preferred
// on platforms where it's available. It is equivalent to setting
// GODEBUG=netdns=go, but scoped to just this resolver.
PreferGo bool
// TODO(bradfitz): optional interface impl override hook
// TODO(bradfitz): Timeout time.Duration?
}
func (r *Resolver) lookupIPFunc() func(context.Context, string) ([]IPAddr, error) {
if r != nil && r.PreferGo {
return goLookupIP
}
return lookupIP
}
// LookupHost looks up the given host using the local resolver.
// It returns an array of that host's addresses.
// It returns a slice of that host's addresses.
func LookupHost(host string) (addrs []string, err error) {
// Make sure that no matter what we do later, host=="" is rejected.
// ParseIP, for example, does accept empty strings.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
}
if ip := ParseIP(host); ip != nil {
return []string{host}, nil
}
return lookupHost(context.Background(), host)
return DefaultResolver.LookupHost(context.Background(), host)
}
// LookupIP looks up host using the local resolver.
// It returns an array of that host's IPv4 and IPv6 addresses.
func LookupIP(host string) (ips []IP, err error) {
// LookupHost looks up the given host using the local resolver.
// It returns a slice of that host's addresses.
func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
// Make sure that no matter what we do later, host=="" is rejected.
// ParseIP, for example, does accept empty strings.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
}
if ip := ParseIP(host); ip != nil {
return []IP{ip}, nil
}
addrs, err := lookupIPMerge(context.Background(), host)
if err != nil {
return
}
ips = make([]IP, len(addrs))
for i, addr := range addrs {
ips[i] = addr.IP
return []string{host}, nil
}
return
return lookupHost(ctx, host)
}
var lookupGroup singleflight.Group
// lookupIPMerge wraps lookupIP, but makes sure that for any given
// host, only one lookup is in-flight at a time. The returned memory
// is always owned by the caller.
func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) {
addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) {
return testHookLookupIP(ctx, lookupIP, host)
})
return lookupIPReturn(addrsi, err, shared)
}
// lookupIPReturn turns the return values from singleflight.Do into
// the return values from LookupIP.
func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
// LookupIP looks up host using the local resolver.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func LookupIP(host string) ([]IP, error) {
addrs, err := DefaultResolver.LookupIPAddr(context.Background(), host)
if err != nil {
return nil, err
}
addrs := addrsi.([]IPAddr)
if shared {
clone := make([]IPAddr, len(addrs))
copy(clone, addrs)
addrs = clone
ips := make([]IP, len(addrs))
for i, ia := range addrs {
ips[i] = ia.IP
}
return addrs, nil
return ips, nil
}
// ipAddrsEface returns an empty interface slice of addrs.
func ipAddrsEface(addrs []IPAddr) []interface{} {
s := make([]interface{}, len(addrs))
for i, v := range addrs {
s[i] = v
// LookupIPAddr looks up host using the local resolver.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
// Make sure that no matter what we do later, host=="" is rejected.
// ParseIP, for example, does accept empty strings.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
}
if ip := ParseIP(host); ip != nil {
return []IPAddr{{IP: ip}}, nil
}
return s
}
// lookupIPContext looks up a hostname with a context.
//
// TODO(bradfitz): rename this function. All the other
// build-tag-specific lookupIP funcs also take a context now, so this
// name is no longer great. Maybe make this lookupIPMerge and ditch
// the other one, making its callers call this instead with a
// context.Background().
func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) {
trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
if trace != nil && trace.DNSStart != nil {
trace.DNSStart(host)
......@@ -171,7 +160,7 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
// The underlying resolver func is lookupIP by default but it
// can be overridden by tests. This is needed by net/http, so it
// uses a context key instead of unexported variables.
resolverFunc := lookupIP
resolverFunc := r.lookupIPFunc()
if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string) ([]IPAddr, error)); alt != nil {
resolverFunc = alt
}
......@@ -201,11 +190,46 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
}
}
// lookupGroup merges LookupIPAddr calls together for lookups
// for the same host. The lookupGroup key is is the LookupIPAddr.host
// argument.
// The return values are ([]IPAddr, error).
var lookupGroup singleflight.Group
// lookupIPReturn turns the return values from singleflight.Do into
// the return values from LookupIP.
func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
if err != nil {
return nil, err
}
addrs := addrsi.([]IPAddr)
if shared {
clone := make([]IPAddr, len(addrs))
copy(clone, addrs)
addrs = clone
}
return addrs, nil
}
// ipAddrsEface returns an empty interface slice of addrs.
func ipAddrsEface(addrs []IPAddr) []interface{} {
s := make([]interface{}, len(addrs))
for i, v := range addrs {
s[i] = v
}
return s
}
// LookupPort looks up the port for the given network and service.
func LookupPort(network, service string) (port int, err error) {
return DefaultResolver.LookupPort(context.Background(), network, service)
}
// LookupPort looks up the port for the given network and service.
func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
port, needsLookup := parsePort(service)
if needsLookup {
port, err = lookupPort(context.Background(), network, service)
port, err = lookupPort(ctx, network, service)
if err != nil {
return 0, err
}
......@@ -224,6 +248,14 @@ func LookupCNAME(name string) (cname string, err error) {
return lookupCNAME(context.Background(), name)
}
// LookupCNAME returns the canonical DNS host for the given name.
// Callers that do not care about the canonical name can call
// LookupHost or LookupIP directly; both take care of resolving
// the canonical name as part of the lookup.
func (r *Resolver) LookupCNAME(ctx context.Context, name string) (cname string, err error) {
return lookupCNAME(ctx, name)
}
// LookupSRV tries to resolve an SRV query of the given service,
// protocol, and domain name. The proto is "tcp" or "udp".
// The returned records are sorted by priority and randomized
......@@ -237,23 +269,57 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
return lookupSRV(context.Background(), service, proto, name)
}
// LookupSRV tries to resolve an SRV query of the given service,
// protocol, and domain name. The proto is "tcp" or "udp".
// The returned records are sorted by priority and randomized
// by weight within a priority.
//
// LookupSRV constructs the DNS name to look up following RFC 2782.
// That is, it looks up _service._proto.name. To accommodate services
// publishing SRV records under non-standard names, if both service
// and proto are empty strings, LookupSRV looks up name directly.
func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
return lookupSRV(ctx, service, proto, name)
}
// LookupMX returns the DNS MX records for the given domain name sorted by preference.
func LookupMX(name string) (mxs []*MX, err error) {
func LookupMX(name string) ([]*MX, error) {
return lookupMX(context.Background(), name)
}
// LookupMX returns the DNS MX records for the given domain name sorted by preference.
func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
return lookupMX(ctx, name)
}
// LookupNS returns the DNS NS records for the given domain name.
func LookupNS(name string) (nss []*NS, err error) {
func LookupNS(name string) ([]*NS, error) {
return lookupNS(context.Background(), name)
}
// LookupNS returns the DNS NS records for the given domain name.
func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
return lookupNS(ctx, name)
}
// LookupTXT returns the DNS TXT records for the given domain name.
func LookupTXT(name string) (txts []string, err error) {
func LookupTXT(name string) ([]string, error) {
return lookupTXT(context.Background(), name)
}
// LookupTXT returns the DNS TXT records for the given domain name.
func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
return lookupTXT(ctx, name)
}
// LookupAddr performs a reverse lookup for the given address, returning a list
// of names mapping to that address.
func LookupAddr(addr string) (names []string, err error) {
return lookupAddr(context.Background(), addr)
}
// LookupAddr performs a reverse lookup for the given address, returning a list
// of names mapping to that address.
func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) {
return lookupAddr(ctx, addr)
}
......@@ -19,6 +19,10 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
return nil, syscall.ENOPROTOOPT
}
func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
return nil, syscall.ENOPROTOOPT
}
func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
return nil, syscall.ENOPROTOOPT
}
......
......@@ -148,6 +148,8 @@ loop:
return
}
var goLookupIP = lookupIP
func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
lits, err := lookupHost(ctx, host)
if err != nil {
......
......@@ -398,11 +398,11 @@ func TestDNSFlood(t *testing.T) {
for i := 0; i < N; i++ {
name := fmt.Sprintf("%d.net-test.golang.org", i)
go func() {
_, err := lookupIPContext(ctxHalfTimeout, name)
_, err := DefaultResolver.LookupIPAddr(ctxHalfTimeout, name)
c <- err
}()
go func() {
_, err := lookupIPContext(ctxTimeout, name)
_, err := DefaultResolver.LookupIPAddr(ctxTimeout, name)
c <- err
}()
}
......
......@@ -66,8 +66,14 @@ func lookupHost(ctx context.Context, name string) ([]string, error) {
return addrs, nil
}
// goLookupIP isn't a Pure Go implementation on Windows.
// TODO(bradfitz): should it be? Not sure it can be. It's always used syscall.GetAddrInfoW.
func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
return lookupIP(ctx, host)
}
func lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
// TODO(bradfitz,brainman): use ctx?
// TODO(bradfitz,brainman): use ctx more. See TODO below.
type ret struct {
addrs []IPAddr
......
......@@ -64,7 +64,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
addrs, err := internetAddrList(context.Background(), net, addr)
addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr)
if err != nil {
return nil, err
}
......
......@@ -67,7 +67,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
addrs, err := internetAddrList(context.Background(), net, addr)
addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr)
if err != nil {
return nil, err
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment