Commit d8a7990f authored by Ben Burkert's avatar Ben Burkert Committed by Brad Fitzpatrick

net: support all PacketConn and Conn returned by Resolver.Dial

Allow the Resolver.Dial func to return instances of Conn other than
*TCPConn and *UDPConn. If the Conn is also a PacketConn, assume DNS
messages transmitted over the Conn adhere to section 4.2.1. "UDP usage".
Otherwise, follow section 4.2.2. "TCP usage".

Provides a hook mechanism so that DNS queries generated by the net
package may be answered or modified before being sent to over the
network.

Updates #19910

Change-Id: Ib089a28ad4a1848bbeaf624ae889f1e82d56655b
Reviewed-on: https://go-review.googlesource.com/45153
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent d55d7b93
...@@ -36,14 +36,14 @@ type dnsConn interface { ...@@ -36,14 +36,14 @@ type dnsConn interface {
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
} }
func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { // dnsPacketConn implements the dnsConn interface for RFC 1035's
return dnsRoundTripUDP(c, query) // "UDP usage" transport mechanism. Conn is a packet-oriented connection,
// such as a *UDPConn.
type dnsPacketConn struct {
Conn
} }
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
// such as a *UDPConn.
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack() b, ok := query.Pack()
if !ok { if !ok {
return nil, errors.New("cannot marshal DNS message") return nil, errors.New("cannot marshal DNS message")
...@@ -69,14 +69,14 @@ func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { ...@@ -69,14 +69,14 @@ func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
} }
} }
func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) { // dnsStreamConn implements the dnsConn interface for RFC 1035's
return dnsRoundTripTCP(c, out) // "TCP usage" transport mechanism. Conn is a stream-oriented connection,
// such as a *TCPConn.
type dnsStreamConn struct {
Conn
} }
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
// such as a *TCPConn.
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack() b, ok := query.Pack()
if !ok { if !ok {
return nil, errors.New("cannot marshal DNS message") return nil, errors.New("cannot marshal DNS message")
......
...@@ -8,6 +8,7 @@ package net ...@@ -8,6 +8,7 @@ package net
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"internal/poll" "internal/poll"
"io/ioutil" "io/ioutil"
...@@ -43,11 +44,14 @@ var dnsTransportFallbackTests = []struct { ...@@ -43,11 +44,14 @@ var dnsTransportFallbackTests = []struct {
func TestDNSTransportFallback(t *testing.T) { func TestDNSTransportFallback(t *testing.T) {
fake := fakeDNSServer{ fake := fakeDNSServer{
rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) { rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{ r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{ dnsMsgHdr: dnsMsgHdr{
rcode: dnsRcodeSuccess, id: q.id,
response: true,
rcode: dnsRcodeSuccess,
}, },
question: q.question,
} }
if n == "udp" { if n == "udp" {
r.truncated = true r.truncated = true
...@@ -98,8 +102,10 @@ func TestSpecialDomainName(t *testing.T) { ...@@ -98,8 +102,10 @@ func TestSpecialDomainName(t *testing.T) {
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{ r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{ dnsMsgHdr: dnsMsgHdr{
id: q.id, id: q.id,
response: true,
}, },
question: q.question,
} }
switch q.question[0].Name { switch q.question[0].Name {
...@@ -612,8 +618,10 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -612,8 +618,10 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{ r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{ dnsMsgHdr: dnsMsgHdr{
id: q.id, id: q.id,
response: true,
}, },
question: q.question,
} }
switch q.question[0].Name { switch q.question[0].Name {
...@@ -751,7 +759,7 @@ type fakeDNSServer struct { ...@@ -751,7 +759,7 @@ type fakeDNSServer struct {
} }
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
} }
type fakeDNSConn struct { type fakeDNSConn struct {
...@@ -759,6 +767,7 @@ type fakeDNSConn struct { ...@@ -759,6 +767,7 @@ type fakeDNSConn struct {
server *fakeDNSServer server *fakeDNSServer
n string n string
s string s string
q *dnsMsg
t time.Time t time.Time
} }
...@@ -766,15 +775,45 @@ func (f *fakeDNSConn) Close() error { ...@@ -766,15 +775,45 @@ func (f *fakeDNSConn) Close() error {
return nil return nil
} }
func (f *fakeDNSConn) Read(b []byte) (int, error) {
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
if err != nil {
return 0, err
}
bb, ok := resp.Pack()
if !ok {
return 0, errors.New("cannot marshal DNS message")
}
if len(b) < len(bb) {
return 0, errors.New("read would fragment DNS message")
}
copy(b, bb)
return len(bb), nil
}
func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
return 0, nil, nil
}
func (f *fakeDNSConn) Write(b []byte) (int, error) {
f.q = new(dnsMsg)
if !f.q.Unpack(b) {
return 0, errors.New("cannot unmarshal DNS message")
}
return len(b), nil
}
func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
return 0, nil
}
func (f *fakeDNSConn) SetDeadline(t time.Time) error { func (f *fakeDNSConn) SetDeadline(t time.Time) error {
f.t = t f.t = t
return nil return nil
} }
func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
return f.server.rh(f.n, f.s, q, f.t)
}
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
func TestIgnoreDNSForgeries(t *testing.T) { func TestIgnoreDNSForgeries(t *testing.T) {
c, s := Pipe() c, s := Pipe()
...@@ -837,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) { ...@@ -837,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
}, },
} }
resp, err := dnsRoundTripUDP(c, msg) dc := &dnsPacketConn{c}
resp, err := dc.dnsRoundTrip(msg)
if err != nil { if err != nil {
t.Fatalf("dnsRoundTripUDP failed: %v", err) t.Fatalf("dnsRoundTripUDP failed: %v", err)
} }
...@@ -1113,7 +1153,14 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1113,7 +1153,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
case resolveOpError: case resolveOpError:
return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")} return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
case resolveServfail: case resolveServfail:
return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeServerFailure}}, nil return &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
rcode: dnsRcodeServerFailure,
},
question: q.question,
}, nil
case resolveTimeout: case resolveTimeout:
return nil, poll.ErrTimeout return nil, poll.ErrTimeout
default: default:
...@@ -1123,7 +1170,14 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1123,7 +1170,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
switch q.question[0].Name { switch q.question[0].Name {
case searchX, name + ".": case searchX, name + ".":
// Return NXDOMAIN to utilize the search list. // Return NXDOMAIN to utilize the search list.
return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeNameError}}, nil return &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
rcode: dnsRcodeNameError,
},
question: q.question,
}, nil
case searchY: case searchY:
// Return records below. // Return records below.
default: default:
......
...@@ -111,9 +111,11 @@ type Resolver struct { ...@@ -111,9 +111,11 @@ type Resolver struct {
// Go's built-in DNS resolver to make TCP and UDP connections // Go's built-in DNS resolver to make TCP and UDP connections
// to DNS services. The provided addr will always be an IP // to DNS services. The provided addr will always be an IP
// address and not a hostname. // address and not a hostname.
// The Conn returned must be a *TCPConn or *UDPConn as // If the Conn returned is also a PacketConn, sent and received DNS
// requested by the network parameter. If nil, the default // messages must adhere to section 4.2.1. "UDP usage" of RFC 1035.
// dialer is used. // Otherwise, DNS messages transmitted over Conn must adhere to section
// 4.2.2. "TCP usage".
// If nil, the default dialer is used.
Dial func(ctx context.Context, network, addr string) (Conn, error) Dial func(ctx context.Context, network, addr string) (Conn, error)
// TODO(bradfitz): optional interface impl override hook // TODO(bradfitz): optional interface impl override hook
......
...@@ -8,8 +8,6 @@ package net ...@@ -8,8 +8,6 @@ package net
import ( import (
"context" "context"
"errors"
"reflect"
"sync" "sync"
) )
...@@ -70,12 +68,10 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e ...@@ -70,12 +68,10 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
if err != nil { if err != nil {
return nil, mapErr(err) return nil, mapErr(err)
} }
dc, ok := c.(dnsConn) if _, ok := c.(PacketConn); ok {
if !ok { return &dnsPacketConn{c}, nil
c.Close()
return nil, errors.New("net: Resolver.Dial returned unsupported connection type " + reflect.TypeOf(c).String())
} }
return dc, nil return &dnsStreamConn{c}, nil
} }
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment