Commit 380aa884 authored by Matt Harden's avatar Matt Harden Committed by Brad Fitzpatrick

net: allow Resolver to use a custom dialer

In some cases it is desirable to customize the way the DNS server is
contacted, for instance to use a specific LocalAddr. While most
operating-system level resolvers do not allow this, we have the
opportunity to do so with the Go resolver. Most of the code was
already in place to allow tests to override the dialer. This exposes
that functionality, and as a side effect eliminates the need for a
testing hook.

Fixes #17404

Change-Id: I1c5e570f8edbcf630090f8ec6feb52e379e3e5c0
Reviewed-on: https://go-review.googlesource.com/37260
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 3b5637ff
...@@ -304,7 +304,7 @@ var pkgDeps = map[string][]string{ ...@@ -304,7 +304,7 @@ var pkgDeps = map[string][]string{
// do networking portably, it must have a small dependency set: just L0+basic os. // do networking portably, it must have a small dependency set: just L0+basic os.
"net": { "net": {
"L0", "CGO", "L0", "CGO",
"context", "math/rand", "os", "sort", "syscall", "time", "context", "math/rand", "os", "reflect", "sort", "syscall", "time",
"internal/nettrace", "internal/poll", "internal/nettrace", "internal/poll",
"internal/syscall/windows", "internal/singleflight", "internal/race", "internal/syscall/windows", "internal/singleflight", "internal/race",
"golang_org/x/net/lif", "golang_org/x/net/route", "golang_org/x/net/lif", "golang_org/x/net/route",
......
...@@ -25,13 +25,6 @@ import ( ...@@ -25,13 +25,6 @@ import (
"time" "time"
) )
// A dnsDialer provides dialing suitable for DNS queries.
type dnsDialer interface {
dialDNS(ctx context.Context, network, addr string) (dnsConn, error)
}
var testHookDNSDialer = func() dnsDialer { return &Dialer{} }
// A dnsConn represents a DNS transport endpoint. // A dnsConn represents a DNS transport endpoint.
type dnsConn interface { type dnsConn interface {
io.Closer io.Closer
...@@ -116,33 +109,8 @@ func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { ...@@ -116,33 +109,8 @@ func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
return resp, nil return resp, nil
} }
func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
switch network {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
default:
return nil, UnknownNetworkError(network)
}
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
// already checked that all the cfg.servers are IP
// addresses, which Dial will use without a DNS lookup.
c, err := d.DialContext(ctx, network, server)
if err != nil {
return nil, mapErr(err)
}
switch network {
case "tcp", "tcp4", "tcp6":
return c.(*TCPConn), nil
case "udp", "udp4", "udp6":
return c.(*UDPConn), nil
}
panic("unreachable")
}
// exchange sends a query on the connection and hopes for a response. // exchange sends a query on the connection and hopes for a response.
func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
d := testHookDNSDialer()
out := dnsMsg{ out := dnsMsg{
dnsMsgHdr: dnsMsgHdr{ dnsMsgHdr: dnsMsgHdr{
recursion_desired: true, recursion_desired: true,
...@@ -158,7 +126,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti ...@@ -158,7 +126,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel() defer cancel()
c, err := d.dialDNS(ctx, network, server) c, err := r.dial(ctx, network, server)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -181,7 +149,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti ...@@ -181,7 +149,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti
// Do a lookup for a single name, which must be rooted // Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers). // (otherwise answer will not find the answers).
func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
var lastErr error var lastErr error
serverOffset := cfg.serverOffset() serverOffset := cfg.serverOffset()
sLen := uint32(len(cfg.servers)) sLen := uint32(len(cfg.servers))
...@@ -190,7 +158,7 @@ func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) ...@@ -190,7 +158,7 @@ func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16)
for j := uint32(0); j < sLen; j++ { for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen] server := cfg.servers[(serverOffset+j)%sLen]
msg, err := exchange(ctx, server, name, qtype, cfg.timeout) msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout)
if err != nil { if err != nil {
lastErr = &DNSError{ lastErr = &DNSError{
Err: err.Error(), Err: err.Error(),
...@@ -333,7 +301,7 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname ...@@ -333,7 +301,7 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname
conf := resolvConf.dnsConfig conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock() resolvConf.mu.RUnlock()
for _, fqdn := range conf.nameList(name) { for _, fqdn := range conf.nameList(name) {
cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype) cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype)
if err == nil { if err == nil {
break break
} }
...@@ -512,7 +480,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order ...@@ -512,7 +480,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
for _, fqdn := range conf.nameList(name) { for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes { for _, qtype := range qtypes {
go func(qtype uint16) { go func(qtype uint16) {
cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype) cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
lane <- racer{cname, rrs, err} lane <- racer{cname, rrs, err}
}(qtype) }(qtype)
} }
......
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"context" "context"
"fmt" "fmt"
"internal/poll" "internal/poll"
"internal/testenv"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
...@@ -21,6 +20,8 @@ import ( ...@@ -21,6 +20,8 @@ import (
"time" "time"
) )
var goResolver = Resolver{PreferGo: true}
// Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation. // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
const TestAddr uint32 = 0xc0000201 const TestAddr uint32 = 0xc0000201
...@@ -41,18 +42,30 @@ var dnsTransportFallbackTests = []struct { ...@@ -41,18 +42,30 @@ var dnsTransportFallbackTests = []struct {
} }
func TestDNSTransportFallback(t *testing.T) { func TestDNSTransportFallback(t *testing.T) {
testenv.MustHaveExternalNetwork(t) fake := fakeDNSServer{
rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
rcode: dnsRcodeSuccess,
},
}
if n == "udp" {
r.truncated = true
}
return r, nil
},
}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
for _, tt := range dnsTransportFallbackTests { for _, tt := range dnsTransportFallbackTests {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
msg, err := exchange(ctx, tt.server, tt.name, tt.qtype, time.Second) msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
continue continue
} }
switch msg.rcode { switch msg.rcode {
case tt.rcode, dnsRcodeServerFailure: case tt.rcode:
default: default:
t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode) t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
continue continue
...@@ -82,13 +95,28 @@ var specialDomainNameTests = []struct { ...@@ -82,13 +95,28 @@ var specialDomainNameTests = []struct {
} }
func TestSpecialDomainName(t *testing.T) { func TestSpecialDomainName(t *testing.T) {
testenv.MustHaveExternalNetwork(t) fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
},
}
switch q.question[0].Name {
case "example.com.":
r.rcode = dnsRcodeSuccess
default:
r.rcode = dnsRcodeNameError
}
return r, nil
}}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
server := "8.8.8.8:53" server := "8.8.8.8:53"
for _, tt := range specialDomainNameTests { for _, tt := range specialDomainNameTests {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
msg, err := exchange(ctx, server, tt.name, tt.qtype, 3*time.Second) msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
continue continue
...@@ -143,15 +171,40 @@ func TestAvoidDNSName(t *testing.T) { ...@@ -143,15 +171,40 @@ func TestAvoidDNSName(t *testing.T) {
} }
} }
var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
},
question: q.question,
}
if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA {
r.answer = []dnsRR{
&dnsRR_A{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
},
A: TestAddr,
},
}
}
return r, nil
}}
// Issue 13705: don't try to resolve onion addresses, etc // Issue 13705: don't try to resolve onion addresses, etc
func TestLookupTorOnion(t *testing.T) { func TestLookupTorOnion(t *testing.T) {
addrs, err := DefaultResolver.goLookupIP(context.Background(), "foo.onion") r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
if len(addrs) > 0 { addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
t.Errorf("unexpected addresses: %v", addrs)
}
if err != nil { if err != nil {
t.Fatalf("lookup = %v; want nil", err) t.Fatalf("lookup = %v; want nil", err)
} }
if len(addrs) > 0 {
t.Errorf("unexpected addresses: %v", addrs)
}
} }
type resolvConfTest struct { type resolvConfTest struct {
...@@ -241,7 +294,7 @@ var updateResolvConfTests = []struct { ...@@ -241,7 +294,7 @@ var updateResolvConfTests = []struct {
} }
func TestUpdateResolvConf(t *testing.T) { func TestUpdateResolvConf(t *testing.T) {
testenv.MustHaveExternalNetwork(t) r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
...@@ -261,7 +314,7 @@ func TestUpdateResolvConf(t *testing.T) { ...@@ -261,7 +314,7 @@ func TestUpdateResolvConf(t *testing.T) {
for j := 0; j < N; j++ { for j := 0; j < N; j++ {
go func(name string) { go func(name string) {
defer wg.Done() defer wg.Done()
ips, err := DefaultResolver.goLookupIP(context.Background(), name) ips, err := r.LookupIPAddr(context.Background(), name)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
...@@ -396,7 +449,60 @@ var goLookupIPWithResolverConfigTests = []struct { ...@@ -396,7 +449,60 @@ var goLookupIPWithResolverConfigTests = []struct {
} }
func TestGoLookupIPWithResolverConfig(t *testing.T) { func TestGoLookupIPWithResolverConfig(t *testing.T) {
testenv.MustHaveExternalNetwork(t) fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break
default:
time.Sleep(10 * time.Millisecond)
return nil, poll.ErrTimeout
}
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
},
question: q.question,
}
for _, question := range q.question {
switch question.Qtype {
case dnsTypeA:
switch question.Name {
case "hostname.as112.net.":
break
case "ipv4.google.com.":
r.answer = append(r.answer, &dnsRR_A{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
},
A: TestAddr,
})
default:
}
case dnsTypeAAAA:
switch question.Name {
case "hostname.as112.net.":
break
case "ipv6.google.com.":
r.answer = append(r.answer, &dnsRR_AAAA{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeAAAA,
Class: dnsClassINET,
Rdlength: 16,
},
AAAA: TestAddr6,
})
}
}
}
return r, nil
}}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
...@@ -409,14 +515,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { ...@@ -409,14 +515,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
t.Error(err) t.Error(err)
continue continue
} }
addrs, err := DefaultResolver.goLookupIP(context.Background(), tt.name) addrs, err := r.LookupIPAddr(context.Background(), tt.name)
if err != nil { if err != nil {
// This test uses external network connectivity.
// We need to take care with errors on both
// DNS message exchange layer and DNS
// transport layer because goLookupIP may fail
// when the IP connectivity on node under test
// gets lost during its run.
if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) { if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
t.Errorf("got %v; want %v", err, tt.error) t.Errorf("got %v; want %v", err, tt.error)
} }
...@@ -441,7 +541,17 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { ...@@ -441,7 +541,17 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
// Test that goLookupIPOrder falls back to the host file when no DNS servers are available. // Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
func TestGoLookupIPOrderFallbackToFile(t *testing.T) { func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
testenv.MustHaveExternalNetwork(t) fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
},
question: q.question,
}
return r, nil
}}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
// Add a config that simulates no dns servers being available. // Add a config that simulates no dns servers being available.
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
...@@ -459,14 +569,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { ...@@ -459,14 +569,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
name := fmt.Sprintf("order %v", order) name := fmt.Sprintf("order %v", order)
// First ensure that we get an error when contacting a non-existent host. // First ensure that we get an error when contacting a non-existent host.
_, _, err := DefaultResolver.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order) _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order)
if err == nil { if err == nil {
t.Errorf("%s: expected error while looking up name not in hosts file", name) t.Errorf("%s: expected error while looking up name not in hosts file", name)
continue continue
} }
// Now check that we get an address when the name appears in the hosts file. // Now check that we get an address when the name appears in the hosts file.
addrs, _, err := DefaultResolver.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts" addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
if err != nil { if err != nil {
t.Errorf("%s: expected to successfully lookup host entry", name) t.Errorf("%s: expected to successfully lookup host entry", name)
continue continue
...@@ -489,9 +599,6 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { ...@@ -489,9 +599,6 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
func TestErrorForOriginalNameWhenSearching(t *testing.T) { func TestErrorForOriginalNameWhenSearching(t *testing.T) {
const fqdn = "doesnotexist.domain" const fqdn = "doesnotexist.domain"
origTestHookDNSDialer := testHookDNSDialer
defer func() { testHookDNSDialer = origTestHookDNSDialer }()
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -502,10 +609,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -502,10 +609,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
d := &fakeDNSDialer{} fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
testHookDNSDialer = func() dnsDialer { return d }
d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{ r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{ dnsMsgHdr: dnsMsgHdr{
id: q.id, id: q.id,
...@@ -520,7 +624,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -520,7 +624,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
} }
return r, nil return r, nil
} }}
cases := []struct { cases := []struct {
strictErrors bool strictErrors bool
...@@ -530,8 +634,8 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -530,8 +634,8 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
{false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error()}}, {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error()}},
} }
for _, tt := range cases { for _, tt := range cases {
r := Resolver{StrictErrors: tt.strictErrors} r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
_, err = r.goLookupIP(context.Background(), fqdn) _, err = r.LookupIPAddr(context.Background(), fqdn)
if err == nil { if err == nil {
t.Fatal("expected an error") t.Fatal("expected an error")
} }
...@@ -545,9 +649,6 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -545,9 +649,6 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
// Issue 15434. If a name server gives a lame referral, continue to the next. // Issue 15434. If a name server gives a lame referral, continue to the next.
func TestIgnoreLameReferrals(t *testing.T) { func TestIgnoreLameReferrals(t *testing.T) {
origTestHookDNSDialer := testHookDNSDialer
defer func() { testHookDNSDialer = origTestHookDNSDialer }()
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -559,10 +660,7 @@ func TestIgnoreLameReferrals(t *testing.T) { ...@@ -559,10 +660,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
d := &fakeDNSDialer{} fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
testHookDNSDialer = func() dnsDialer { return d }
d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
t.Log(s, q) t.Log(s, q)
r := &dnsMsg{ r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{ dnsMsgHdr: dnsMsgHdr{
...@@ -590,9 +688,10 @@ func TestIgnoreLameReferrals(t *testing.T) { ...@@ -590,9 +688,10 @@ func TestIgnoreLameReferrals(t *testing.T) {
} }
return r, nil return r, nil
} }}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
addrs, err := DefaultResolver.goLookupIP(context.Background(), "www.golang.org") addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -611,7 +710,7 @@ func BenchmarkGoLookupIP(b *testing.B) { ...@@ -611,7 +710,7 @@ func BenchmarkGoLookupIP(b *testing.B) {
ctx := context.Background() ctx := context.Background()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
DefaultResolver.goLookupIP(ctx, "www.example.com") goResolver.LookupIPAddr(ctx, "www.example.com")
} }
} }
...@@ -620,7 +719,7 @@ func BenchmarkGoLookupIPNoSuchHost(b *testing.B) { ...@@ -620,7 +719,7 @@ func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
ctx := context.Background() ctx := context.Background()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
DefaultResolver.goLookupIP(ctx, "some.nonexistent") goResolver.LookupIPAddr(ctx, "some.nonexistent")
} }
} }
...@@ -643,21 +742,22 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { ...@@ -643,21 +742,22 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
ctx := context.Background() ctx := context.Background()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
DefaultResolver.goLookupIP(ctx, "www.example.com") goResolver.LookupIPAddr(ctx, "www.example.com")
} }
} }
type fakeDNSDialer struct { type fakeDNSServer struct {
// reply handler rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
} }
func (f *fakeDNSDialer) dialDNS(_ context.Context, n, s string) (dnsConn, error) { func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
return &fakeDNSConn{f.rh, s, time.Time{}}, nil return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil
} }
type fakeDNSConn struct { type fakeDNSConn struct {
rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error) Conn
server *fakeDNSServer
n string
s string s string
t time.Time t time.Time
} }
...@@ -672,7 +772,7 @@ func (f *fakeDNSConn) SetDeadline(t time.Time) error { ...@@ -672,7 +772,7 @@ func (f *fakeDNSConn) SetDeadline(t time.Time) error {
} }
func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) { func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
return f.rh(f.s, q, f.t) return f.server.rh(f.n, f.s, q, f.t)
} }
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
...@@ -749,9 +849,6 @@ func TestIgnoreDNSForgeries(t *testing.T) { ...@@ -749,9 +849,6 @@ func TestIgnoreDNSForgeries(t *testing.T) {
// Issue 16865. If a name server times out, continue to the next. // Issue 16865. If a name server times out, continue to the next.
func TestRetryTimeout(t *testing.T) { func TestRetryTimeout(t *testing.T) {
origTestHookDNSDialer := testHookDNSDialer
defer func() { testHookDNSDialer = origTestHookDNSDialer }()
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -766,12 +863,9 @@ func TestRetryTimeout(t *testing.T) { ...@@ -766,12 +863,9 @@ func TestRetryTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
d := &fakeDNSDialer{}
testHookDNSDialer = func() dnsDialer { return d }
var deadline0 time.Time var deadline0 time.Time
d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
t.Log(s, q, deadline) t.Log(s, q, deadline)
if deadline.IsZero() { if deadline.IsZero() {
...@@ -789,9 +883,10 @@ func TestRetryTimeout(t *testing.T) { ...@@ -789,9 +883,10 @@ func TestRetryTimeout(t *testing.T) {
} }
return mockTXTResponse(q), nil return mockTXTResponse(q), nil
} }}
r := &Resolver{PreferGo: true, Dial: fake.DialContext}
_, err = LookupTXT("www.golang.org") _, err = r.LookupTXT(context.Background(), "www.golang.org")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -810,9 +905,6 @@ func TestRotate(t *testing.T) { ...@@ -810,9 +905,6 @@ func TestRotate(t *testing.T) {
} }
func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
origTestHookDNSDialer := testHookDNSDialer
defer func() { testHookDNSDialer = origTestHookDNSDialer }()
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -831,18 +923,16 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { ...@@ -831,18 +923,16 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
t.Fatal(err) t.Fatal(err)
} }
d := &fakeDNSDialer{}
testHookDNSDialer = func() dnsDialer { return d }
var usedServers []string var usedServers []string
d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
usedServers = append(usedServers, s) usedServers = append(usedServers, s)
return mockTXTResponse(q), nil return mockTXTResponse(q), nil
} }}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
// len(nameservers) + 1 to allow rotation to get back to start // len(nameservers) + 1 to allow rotation to get back to start
for i := 0; i < len(nameservers)+1; i++ { for i := 0; i < len(nameservers)+1; i++ {
if _, err := LookupTXT("www.golang.org"); err != nil { if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
...@@ -878,9 +968,6 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg { ...@@ -878,9 +968,6 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg {
// Issue 17448. With StrictErrors enabled, temporary errors should make // Issue 17448. With StrictErrors enabled, temporary errors should make
// LookupIP fail rather than return a partial result. // LookupIP fail rather than return a partial result.
func TestStrictErrorsLookupIP(t *testing.T) { func TestStrictErrorsLookupIP(t *testing.T) {
origTestHookDNSDialer := testHookDNSDialer
defer func() { testHookDNSDialer = origTestHookDNSDialer }()
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1017,10 +1104,7 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1017,10 +1104,7 @@ func TestStrictErrorsLookupIP(t *testing.T) {
} }
for i, tt := range cases { for i, tt := range cases {
d := &fakeDNSDialer{} fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
testHookDNSDialer = func() dnsDialer { return d }
d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
t.Log(s, q) t.Log(s, q)
switch tt.resolveWhich(&q.question[0]) { switch tt.resolveWhich(&q.question[0]) {
...@@ -1082,11 +1166,11 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1082,11 +1166,11 @@ func TestStrictErrorsLookupIP(t *testing.T) {
return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype) return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype)
} }
return r, nil return r, nil
} }}
for _, strict := range []bool{true, false} { for _, strict := range []bool{true, false} {
r := Resolver{StrictErrors: strict} r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
ips, err := r.goLookupIP(context.Background(), name) ips, err := r.LookupIPAddr(context.Background(), name)
var wantErr error var wantErr error
if strict { if strict {
...@@ -1118,9 +1202,6 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1118,9 +1202,6 @@ func TestStrictErrorsLookupIP(t *testing.T) {
// Issue 17448. With StrictErrors enabled, temporary errors should make // Issue 17448. With StrictErrors enabled, temporary errors should make
// LookupTXT stop walking the search list. // LookupTXT stop walking the search list.
func TestStrictErrorsLookupTXT(t *testing.T) { func TestStrictErrorsLookupTXT(t *testing.T) {
origTestHookDNSDialer := testHookDNSDialer
defer func() { testHookDNSDialer = origTestHookDNSDialer }()
conf, err := newResolvConfTest() conf, err := newResolvConfTest()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1141,10 +1222,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { ...@@ -1141,10 +1222,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
const searchY = "test.y.golang.org." const searchY = "test.y.golang.org."
const txt = "Hello World" const txt = "Hello World"
d := &fakeDNSDialer{} fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
testHookDNSDialer = func() dnsDialer { return d }
d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
t.Log(s, q) t.Log(s, q)
switch q.question[0].Name { switch q.question[0].Name {
...@@ -1155,10 +1233,10 @@ func TestStrictErrorsLookupTXT(t *testing.T) { ...@@ -1155,10 +1233,10 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
default: default:
return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name) return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
} }
} }}
for _, strict := range []bool{true, false} { for _, strict := range []bool{true, false} {
r := Resolver{StrictErrors: strict} r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
_, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT) _, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT)
var wantErr error var wantErr error
var wantRRs int var wantRRs int
......
...@@ -107,6 +107,15 @@ type Resolver struct { ...@@ -107,6 +107,15 @@ type Resolver struct {
// with resolvers that process AAAA queries incorrectly. // with resolvers that process AAAA queries incorrectly.
StrictErrors bool StrictErrors bool
// Dial optionally specifies an alternate dialer for use by
// Go's built-in DNS resolver to make TCP and UDP connections
// to DNS services. The provided addr will always be an IP
// address and not a hostname.
// The Conn returned must be a *TCPConn or *UDPConn as
// requested by the network parameter. If nil, the default
// dialer is used.
Dial func(ctx context.Context, network, addr string) (Conn, error)
// TODO(bradfitz): optional interface impl override hook // TODO(bradfitz): optional interface impl override hook
// TODO(bradfitz): Timeout time.Duration? // TODO(bradfitz): Timeout time.Duration?
} }
......
...@@ -8,6 +8,8 @@ package net ...@@ -8,6 +8,8 @@ package net
import ( import (
"context" "context"
"errors"
"reflect"
"sync" "sync"
) )
...@@ -51,6 +53,31 @@ func lookupProtocol(_ context.Context, name string) (int, error) { ...@@ -51,6 +53,31 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
return lookupProtocolMap(name) return lookupProtocolMap(name)
} }
func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
// already checked that all the cfg.servers are IP
// addresses, which Dial will use without a DNS lookup.
var c Conn
var err error
if r.Dial != nil {
c, err = r.Dial(ctx, network, server)
} else {
var d Dialer
c, err = d.DialContext(ctx, network, server)
}
if err != nil {
return nil, mapErr(err)
}
dc, ok := c.(dnsConn)
if !ok {
c.Close()
return nil, errors.New("net: Resolver.Dial returned unsupported connection type " + reflect.TypeOf(c).String())
}
return dc, nil
}
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
order := systemConf().hostLookupOrder(host) order := systemConf().hostLookupOrder(host)
if !r.PreferGo && order == hostLookupCgo { if !r.PreferGo && order == hostLookupCgo {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment