Commit 672729eb authored by Ian Gudger's avatar Ian Gudger Committed by Brad Fitzpatrick

net: use golang.org/x/net/dns/dnsmessage for DNS resolution

Vendors golang.org/x/net/dns/dnsmessage from x/net git rev
892bf7b0c6e2f93b51166bf3882e50277fa5afc6

Updates #16218
Updates #21160

Change-Id: Ic4e8f3c3d83c2936354ec14c5be93b0d2b42dd91
Reviewed-on: https://go-review.googlesource.com/37879
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent c830e05a
......@@ -313,7 +313,7 @@ var pkgDeps = map[string][]string{
"context", "math/rand", "os", "reflect", "sort", "syscall", "time",
"internal/nettrace", "internal/poll",
"internal/syscall/windows", "internal/singleflight", "internal/race",
"golang_org/x/net/lif", "golang_org/x/net/route",
"golang_org/x/net/dns/dnsmessage", "golang_org/x/net/lif", "golang_org/x/net/route",
},
// NET enables use of basic network-related packages.
......
......@@ -7,6 +7,8 @@ package net
import (
"math/rand"
"sort"
"golang_org/x/net/dns/dnsmessage"
)
// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
......@@ -35,71 +37,13 @@ func reverseaddr(addr string) (arpa string, err error) {
return string(buf), nil
}
// Find answer for name in dns message.
// On return, if err == nil, addrs != nil.
func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err error) {
addrs = make([]dnsRR, 0, len(dns.answer))
if dns.rcode == dnsRcodeNameError {
return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
}
if dns.rcode != dnsRcodeSuccess {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly or
// having temporary trouble.
err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
if dns.rcode == dnsRcodeServerFailure {
err.IsTemporary = true
}
return "", nil, err
}
// Look for the name.
// Presotto says it's okay to assume that servers listed in
// /etc/resolv.conf are recursive resolvers.
// We asked for recursion, so it should have included
// all the answers we need in this one packet.
Cname:
for cnameloop := 0; cnameloop < 10; cnameloop++ {
addrs = addrs[0:0]
for _, rr := range dns.answer {
if _, justHeader := rr.(*dnsRR_Header); justHeader {
// Corrupt record: we only have a
// header. That header might say it's
// of type qtype, but we don't
// actually have it. Skip.
continue
}
h := rr.Header()
if h.Class == dnsClassINET && equalASCIILabel(h.Name, name) {
switch h.Rrtype {
case qtype:
addrs = append(addrs, rr)
case dnsTypeCNAME:
// redirect to cname
name = rr.(*dnsRR_CNAME).Cname
continue Cname
}
}
}
if len(addrs) == 0 {
return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
}
return name, addrs, nil
}
return "", nil, &DNSError{Err: "too many redirects", Name: name, Server: server}
}
func equalASCIILabel(x, y string) bool {
if len(x) != len(y) {
func equalASCIIName(x, y dnsmessage.Name) bool {
if x.Length != y.Length {
return false
}
for i := 0; i < len(x); i++ {
a := x[i]
b := y[i]
for i := 0; i < int(x.Length); i++ {
a := x.Data[i]
b := y.Data[i]
if 'A' <= a && a <= 'Z' {
a += 0x20
}
......
......@@ -67,51 +67,3 @@ func testWeighting(t *testing.T, margin float64) {
func TestWeighting(t *testing.T) {
testWeighting(t, 0.05)
}
// Issue 8434: verify that Temporary returns true on an error when rcode
// is SERVFAIL
func TestIssue8434(t *testing.T) {
msg := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
rcode: dnsRcodeServerFailure,
},
}
_, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
if err == nil {
t.Fatal("expected an error")
}
if ne, ok := err.(Error); !ok {
t.Fatalf("err = %#v; wanted something supporting net.Error", err)
} else if !ne.Temporary() {
t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
}
if de, ok := err.(*DNSError); !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
} else if !de.IsTemporary {
t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
}
}
// Issue 12778: verify that NXDOMAIN without RA bit errors as
// "no such host" and not "server misbehaving"
func TestIssue12778(t *testing.T) {
msg := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
rcode: dnsRcodeNameError,
recursion_available: false,
},
}
_, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
if err == nil {
t.Fatal("expected an error")
}
de, ok := err.(*DNSError)
if !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
}
if de.Err != errNoSuchHost.Error() {
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -9,6 +9,8 @@ package net
import (
"context"
"sync"
"golang_org/x/net/dns/dnsmessage"
)
var onceReadProtocols sync.Once
......@@ -51,7 +53,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
return lookupProtocolMap(name)
}
func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, error) {
func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
......@@ -68,10 +70,7 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
if err != nil {
return nil, mapErr(err)
}
if _, ok := c.(PacketConn); ok {
return &dnsPacketConn{c}, nil
}
return &dnsStreamConn{c}, nil
return c, nil
}
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
......@@ -98,8 +97,8 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, e
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order)
return
ips, _, err := r.goLookupIPCNAMEOrder(ctx, host, order)
return ips, err
}
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
......@@ -134,53 +133,176 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
} else {
target = "_" + service + "._" + proto + "." + name
}
cname, rrs, err := r.lookup(ctx, target, dnsTypeSRV)
p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
if err != nil {
return "", nil, err
}
srvs := make([]*SRV, len(rrs))
for i, rr := range rrs {
rr := rr.(*dnsRR_SRV)
srvs[i] = &SRV{Target: rr.Target, Port: rr.Port, Priority: rr.Priority, Weight: rr.Weight}
var srvs []*SRV
var cname dnsmessage.Name
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeSRV {
if err := p.SkipAnswer(); err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
srv, err := p.SRVResource()
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
}
byPriorityWeight(srvs).sort()
return cname, srvs, nil
return cname.String(), srvs, nil
}
func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
_, rrs, err := r.lookup(ctx, name, dnsTypeMX)
p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX)
if err != nil {
return nil, err
}
mxs := make([]*MX, len(rrs))
for i, rr := range rrs {
rr := rr.(*dnsRR_MX)
mxs[i] = &MX{Host: rr.Mx, Pref: rr.Pref}
var mxs []*MX
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeMX {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
mx, err := p.MXResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
}
byPref(mxs).sort()
return mxs, nil
}
func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
_, rrs, err := r.lookup(ctx, name, dnsTypeNS)
p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS)
if err != nil {
return nil, err
}
nss := make([]*NS, len(rrs))
for i, rr := range rrs {
nss[i] = &NS{Host: rr.(*dnsRR_NS).Ns}
var nss []*NS
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeNS {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
ns, err := p.NSResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
nss = append(nss, &NS{Host: ns.NS.String()})
}
return nss, nil
}
func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
_, rrs, err := r.lookup(ctx, name, dnsTypeTXT)
p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT)
if err != nil {
return nil, err
}
txts := make([]string, len(rrs))
for i, rr := range rrs {
txts[i] = rr.(*dnsRR_TXT).Txt
var txts []string
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeTXT {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
txt, err := p.TXTResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if len(txts) == 0 {
txts = txt.TXT
} else {
txts = append(txts, txt.TXT...)
}
}
return txts, nil
}
......
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dnsmessage_test
import (
"fmt"
"net"
"strings"
"golang_org/x/net/dns/dnsmessage"
)
func mustNewName(name string) dnsmessage.Name {
n, err := dnsmessage.NewName(name)
if err != nil {
panic(err)
}
return n
}
func ExampleParser() {
msg := dnsmessage.Message{
Header: dnsmessage.Header{Response: true, Authoritative: true},
Questions: []dnsmessage.Question{
{
Name: mustNewName("foo.bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
{
Name: mustNewName("bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
Answers: []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: mustNewName("foo.bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
},
{
Header: dnsmessage.ResourceHeader{
Name: mustNewName("bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}},
},
},
}
buf, err := msg.Pack()
if err != nil {
panic(err)
}
wantName := "bar.example.com."
var p dnsmessage.Parser
if _, err := p.Start(buf); err != nil {
panic(err)
}
for {
q, err := p.Question()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
panic(err)
}
if q.Name.String() != wantName {
continue
}
fmt.Println("Found question for name", wantName)
if err := p.SkipAllQuestions(); err != nil {
panic(err)
}
break
}
var gotIPs []net.IP
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
panic(err)
}
if (h.Type != dnsmessage.TypeA && h.Type != dnsmessage.TypeAAAA) || h.Class != dnsmessage.ClassINET {
continue
}
if !strings.EqualFold(h.Name.String(), wantName) {
if err := p.SkipAnswer(); err != nil {
panic(err)
}
continue
}
switch h.Type {
case dnsmessage.TypeA:
r, err := p.AResource()
if err != nil {
panic(err)
}
gotIPs = append(gotIPs, r.A[:])
case dnsmessage.TypeAAAA:
r, err := p.AAAAResource()
if err != nil {
panic(err)
}
gotIPs = append(gotIPs, r.AAAA[:])
}
}
fmt.Printf("Found A/AAAA records for name %s: %v\n", wantName, gotIPs)
// Output:
// Found question for name bar.example.com.
// Found A/AAAA records for name bar.example.com.: [127.0.0.2]
}
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment