Commit 672729eb authored by Ian Gudger's avatar Ian Gudger Committed by Brad Fitzpatrick

net: use golang.org/x/net/dns/dnsmessage for DNS resolution

Vendors golang.org/x/net/dns/dnsmessage from x/net git rev
892bf7b0c6e2f93b51166bf3882e50277fa5afc6

Updates #16218
Updates #21160

Change-Id: Ic4e8f3c3d83c2936354ec14c5be93b0d2b42dd91
Reviewed-on: https://go-review.googlesource.com/37879
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent c830e05a
...@@ -313,7 +313,7 @@ var pkgDeps = map[string][]string{ ...@@ -313,7 +313,7 @@ var pkgDeps = map[string][]string{
"context", "math/rand", "os", "reflect", "sort", "syscall", "time", "context", "math/rand", "os", "reflect", "sort", "syscall", "time",
"internal/nettrace", "internal/poll", "internal/nettrace", "internal/poll",
"internal/syscall/windows", "internal/singleflight", "internal/race", "internal/syscall/windows", "internal/singleflight", "internal/race",
"golang_org/x/net/lif", "golang_org/x/net/route", "golang_org/x/net/dns/dnsmessage", "golang_org/x/net/lif", "golang_org/x/net/route",
}, },
// NET enables use of basic network-related packages. // NET enables use of basic network-related packages.
......
...@@ -7,6 +7,8 @@ package net ...@@ -7,6 +7,8 @@ package net
import ( import (
"math/rand" "math/rand"
"sort" "sort"
"golang_org/x/net/dns/dnsmessage"
) )
// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP // reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
...@@ -35,71 +37,13 @@ func reverseaddr(addr string) (arpa string, err error) { ...@@ -35,71 +37,13 @@ func reverseaddr(addr string) (arpa string, err error) {
return string(buf), nil return string(buf), nil
} }
// Find answer for name in dns message. func equalASCIIName(x, y dnsmessage.Name) bool {
// On return, if err == nil, addrs != nil. if x.Length != y.Length {
func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err error) {
addrs = make([]dnsRR, 0, len(dns.answer))
if dns.rcode == dnsRcodeNameError {
return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
}
if dns.rcode != dnsRcodeSuccess {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly or
// having temporary trouble.
err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
if dns.rcode == dnsRcodeServerFailure {
err.IsTemporary = true
}
return "", nil, err
}
// Look for the name.
// Presotto says it's okay to assume that servers listed in
// /etc/resolv.conf are recursive resolvers.
// We asked for recursion, so it should have included
// all the answers we need in this one packet.
Cname:
for cnameloop := 0; cnameloop < 10; cnameloop++ {
addrs = addrs[0:0]
for _, rr := range dns.answer {
if _, justHeader := rr.(*dnsRR_Header); justHeader {
// Corrupt record: we only have a
// header. That header might say it's
// of type qtype, but we don't
// actually have it. Skip.
continue
}
h := rr.Header()
if h.Class == dnsClassINET && equalASCIILabel(h.Name, name) {
switch h.Rrtype {
case qtype:
addrs = append(addrs, rr)
case dnsTypeCNAME:
// redirect to cname
name = rr.(*dnsRR_CNAME).Cname
continue Cname
}
}
}
if len(addrs) == 0 {
return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
}
return name, addrs, nil
}
return "", nil, &DNSError{Err: "too many redirects", Name: name, Server: server}
}
func equalASCIILabel(x, y string) bool {
if len(x) != len(y) {
return false return false
} }
for i := 0; i < len(x); i++ { for i := 0; i < int(x.Length); i++ {
a := x[i] a := x.Data[i]
b := y[i] b := y.Data[i]
if 'A' <= a && a <= 'Z' { if 'A' <= a && a <= 'Z' {
a += 0x20 a += 0x20
} }
......
...@@ -67,51 +67,3 @@ func testWeighting(t *testing.T, margin float64) { ...@@ -67,51 +67,3 @@ func testWeighting(t *testing.T, margin float64) {
func TestWeighting(t *testing.T) { func TestWeighting(t *testing.T) {
testWeighting(t, 0.05) testWeighting(t, 0.05)
} }
// Issue 8434: verify that Temporary returns true on an error when rcode
// is SERVFAIL
func TestIssue8434(t *testing.T) {
msg := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
rcode: dnsRcodeServerFailure,
},
}
_, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
if err == nil {
t.Fatal("expected an error")
}
if ne, ok := err.(Error); !ok {
t.Fatalf("err = %#v; wanted something supporting net.Error", err)
} else if !ne.Temporary() {
t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
}
if de, ok := err.(*DNSError); !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
} else if !de.IsTemporary {
t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
}
}
// Issue 12778: verify that NXDOMAIN without RA bit errors as
// "no such host" and not "server misbehaving"
func TestIssue12778(t *testing.T) {
msg := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
rcode: dnsRcodeNameError,
recursion_available: false,
},
}
_, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
if err == nil {
t.Fatal("expected an error")
}
de, ok := err.(*DNSError)
if !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
}
if de.Err != errNoSuchHost.Error() {
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
}
}
...@@ -23,142 +23,231 @@ import ( ...@@ -23,142 +23,231 @@ import (
"os" "os"
"sync" "sync"
"time" "time"
)
// A dnsConn represents a DNS transport endpoint.
type dnsConn interface {
io.Closer
SetDeadline(time.Time) error "golang_org/x/net/dns/dnsmessage"
)
// dnsRoundTrip executes a single DNS transaction, returning a func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
// DNS response message for the provided DNS query message. id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
b.EnableCompression()
if err := b.StartQuestions(); err != nil {
return 0, nil, nil, err
}
if err := b.Question(q); err != nil {
return 0, nil, nil, err
}
tcpReq, err = b.Finish()
udpReq = tcpReq[2:]
l := len(tcpReq) - 2
tcpReq[0] = byte(l >> 8)
tcpReq[1] = byte(l)
return id, udpReq, tcpReq, err
} }
// dnsPacketConn implements the dnsConn interface for RFC 1035's func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
// "UDP usage" transport mechanism. Conn is a packet-oriented connection, if !respHdr.Response {
// such as a *UDPConn. return false
type dnsPacketConn struct { }
Conn if reqID != respHdr.ID {
return false
}
if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
return false
}
return true
} }
func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
}
if _, err := c.Write(b); err != nil { if _, err := c.Write(b); err != nil {
return nil, err return dnsmessage.Parser{}, dnsmessage.Header{}, err
} }
b = make([]byte, 512) // see RFC 1035 b = make([]byte, 512) // see RFC 1035
for { for {
n, err := c.Read(b) n, err := c.Read(b)
if err != nil { if err != nil {
return nil, err return dnsmessage.Parser{}, dnsmessage.Header{}, err
} }
resp := &dnsMsg{} var p dnsmessage.Parser
if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) { // Ignore invalid responses as they may be malicious
// Ignore invalid responses as they may be malicious // forgery attempts. Instead continue waiting until
// forgery attempts. Instead continue waiting until // timeout. See golang.org/issue/13281.
// timeout. See golang.org/issue/13281. h, err := p.Start(b[:n])
if err != nil {
continue continue
} }
return resp, nil q, err := p.Question()
if err != nil || !checkResponse(id, query, h, q) {
continue
}
return p, h, nil
} }
} }
// dnsStreamConn implements the dnsConn interface for RFC 1035's func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
// such as a *TCPConn.
type dnsStreamConn struct {
Conn
}
func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
}
l := len(b)
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil { if _, err := c.Write(b); err != nil {
return nil, err return dnsmessage.Parser{}, dnsmessage.Header{}, err
} }
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil { if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err return dnsmessage.Parser{}, dnsmessage.Header{}, err
} }
l = int(b[0])<<8 | int(b[1]) l := int(b[0])<<8 | int(b[1])
if l > len(b) { if l > len(b) {
b = make([]byte, l) b = make([]byte, l)
} }
n, err := io.ReadFull(c, b[:l]) n, err := io.ReadFull(c, b[:l])
if err != nil { if err != nil {
return nil, err return dnsmessage.Parser{}, dnsmessage.Header{}, err
} }
resp := &dnsMsg{} var p dnsmessage.Parser
if !resp.Unpack(b[:n]) { h, err := p.Start(b[:n])
return nil, errors.New("cannot unmarshal DNS message") if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
} }
if !resp.IsResponseTo(query) { q, err := p.Question()
return nil, errors.New("invalid DNS response") if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
} }
return resp, nil if !checkResponse(id, query, h, q) {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
}
return p, h, nil
} }
// exchange sends a query on the connection and hopes for a response. // exchange sends a query on the connection and hopes for a response.
func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
out := dnsMsg{ q.Class = dnsmessage.ClassINET
dnsMsgHdr: dnsMsgHdr{ id, udpReq, tcpReq, err := newRequest(q)
recursion_desired: true, if err != nil {
}, return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot marshal DNS message")
question: []dnsQuestion{
{name, qtype, dnsClassINET},
},
} }
for _, network := range []string{"udp", "tcp"} { for _, network := range []string{"udp", "tcp"} {
// TODO(mdempsky): Refactor so defers from UDP-based
// exchanges happen before TCP-based exchange.
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel() defer cancel()
c, err := r.dial(ctx, network, server) c, err := r.dial(ctx, network, server)
if err != nil { if err != nil {
return nil, err return dnsmessage.Parser{}, dnsmessage.Header{}, err
} }
defer c.Close()
if d, ok := ctx.Deadline(); ok && !d.IsZero() { if d, ok := ctx.Deadline(); ok && !d.IsZero() {
c.SetDeadline(d) c.SetDeadline(d)
} }
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) var p dnsmessage.Parser
in, err := c.dnsRoundTrip(&out) var h dnsmessage.Header
if network == "tcp" {
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
} else {
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
}
c.Close()
if err != nil { if err != nil {
return nil, mapErr(err) return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
}
if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
} }
if in.truncated { // see RFC 5966 if h.Truncated { // see RFC 5966
continue continue
} }
return in, nil return p, h, nil
}
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("no answer from DNS server")
}
func checkHeaders(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) error {
_, err := p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
return &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434.
if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
return &DNSError{Err: "lame referral", Name: name, Server: server}
}
// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
// it means the response in msg was not useful and trying another
// server probably won't help. Return now in those cases.
// TODO: indicate this in a more obvious way, such as a field on DNSError?
if h.RCode == dnsmessage.RCodeNameError {
return &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
}
if h.RCode != dnsmessage.RCodeSuccess {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly or
// having temporary trouble.
err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
if h.RCode == dnsmessage.RCodeServerFailure {
err.IsTemporary = true
}
return err
}
return nil
}
func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type, name, server string) error {
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
return &DNSError{
Err: errNoSuchHost.Error(),
Name: name,
Server: server,
}
}
if err != nil {
return &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type == qtype {
return nil
}
if err := p.SkipAnswer(); err != nil {
return &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
} }
return nil, errors.New("no answer from DNS server")
} }
// Do a lookup for a single name, which must be rooted // Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers). // (otherwise answer will not find the answers).
func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
var lastErr error var lastErr error
serverOffset := cfg.serverOffset() serverOffset := cfg.serverOffset()
sLen := uint32(len(cfg.servers)) sLen := uint32(len(cfg.servers))
n, err := dnsmessage.NewName(name)
if err != nil {
return dnsmessage.Parser{}, "", errors.New("cannot marshal DNS message")
}
q := dnsmessage.Question{
Name: n,
Type: qtype,
Class: dnsmessage.ClassINET,
}
for i := 0; i < cfg.attempts; i++ { for i := 0; i < cfg.attempts; i++ {
for j := uint32(0); j < sLen; j++ { for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen] server := cfg.servers[(serverOffset+j)%sLen]
msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout) p, h, err := r.exchange(ctx, server, q, cfg.timeout)
if err != nil { if err != nil {
lastErr = &DNSError{ lastErr = &DNSError{
Err: err.Error(), Err: err.Error(),
...@@ -175,41 +264,19 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, ...@@ -175,41 +264,19 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
} }
continue continue
} }
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434. lastErr = checkHeaders(&p, h, name, server)
if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 { if lastErr != nil {
lastErr = &DNSError{Err: "lame referral", Name: name, Server: server}
continue continue
} }
cname, rrs, err := answer(name, server, msg, qtype)
// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError, lastErr = skipToAnswer(&p, qtype, name, server)
// it means the response in msg was not useful and trying another if lastErr == nil {
// server probably won't help. Return now in those cases. return p, server, nil
// TODO: indicate this in a more obvious way, such as a field on DNSError?
if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
return cname, rrs, err
} }
lastErr = err
} }
} }
return "", nil, lastErr return dnsmessage.Parser{}, "", lastErr
}
// addrRecordList converts and returns a list of IP addresses from DNS
// address records (both A and AAAA). Other record types are ignored.
func addrRecordList(rrs []dnsRR) []IPAddr {
addrs := make([]IPAddr, 0, 4)
for _, rr := range rrs {
switch rr := rr.(type) {
case *dnsRR_A:
addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
case *dnsRR_AAAA:
ip := make(IP, IPv6len)
copy(ip, rr.AAAA[:])
addrs = append(addrs, IPAddr{IP: ip})
}
}
return addrs
} }
// A resolverConfig represents a DNS stub resolver configuration. // A resolverConfig represents a DNS stub resolver configuration.
...@@ -287,21 +354,26 @@ func (conf *resolverConfig) releaseSema() { ...@@ -287,21 +354,26 @@ func (conf *resolverConfig) releaseSema() {
<-conf.ch <-conf.ch
} }
func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
if !isDomainName(name) { if !isDomainName(name) {
// We used to use "invalid domain name" as the error, // We used to use "invalid domain name" as the error,
// but that is a detail of the specific lookup mechanism. // but that is a detail of the specific lookup mechanism.
// Other lookups might allow broader name syntax // Other lookups might allow broader name syntax
// (for example Multicast DNS allows UTF-8; see RFC 6762). // (for example Multicast DNS allows UTF-8; see RFC 6762).
// For consistency with libc resolvers, report no such host. // For consistency with libc resolvers, report no such host.
return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name} return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
} }
resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock() resolvConf.mu.RLock()
conf := resolvConf.dnsConfig conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock() resolvConf.mu.RUnlock()
var (
p dnsmessage.Parser
server string
err error
)
for _, fqdn := range conf.nameList(name) { for _, fqdn := range conf.nameList(name) {
cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype) p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
if err == nil { if err == nil {
break break
} }
...@@ -311,13 +383,16 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname ...@@ -311,13 +383,16 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname
break break
} }
} }
if err == nil {
return p, server, nil
}
if err, ok := err.(*DNSError); ok { if err, ok := err.(*DNSError); ok {
// Show original name passed to lookup, not suffixed one. // Show original name passed to lookup, not suffixed one.
// In general we might have tried many suffixes; showing // In general we might have tried many suffixes; showing
// just one is misleading. See also golang.org/issue/6324. // just one is misleading. See also golang.org/issue/6324.
err.Name = name err.Name = name
} }
return return dnsmessage.Parser{}, "", err
} }
// avoidDNS reports whether this is a hostname for which we should not // avoidDNS reports whether this is a hostname for which we should not
...@@ -454,36 +529,36 @@ func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr, ...@@ -454,36 +529,36 @@ func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr,
return return
} }
func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) { func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles { if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name) addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles { if len(addrs) > 0 || order == hostLookupFiles {
return addrs, name, nil return addrs, dnsmessage.Name{}, nil
} }
} }
if !isDomainName(name) { if !isDomainName(name) {
// See comment in func lookup above about use of errNoSuchHost. // See comment in func lookup above about use of errNoSuchHost.
return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name} return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name}
} }
resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock() resolvConf.mu.RLock()
conf := resolvConf.dnsConfig conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock() resolvConf.mu.RUnlock()
type racer struct { type racer struct {
cname string p dnsmessage.Parser
rrs []dnsRR server string
error error
} }
lane := make(chan racer, 1) lane := make(chan racer, 1)
qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA} qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
var lastErr error var lastErr error
for _, fqdn := range conf.nameList(name) { for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes { for _, qtype := range qtypes {
dnsWaitGroup.Add(1) dnsWaitGroup.Add(1)
go func(qtype uint16) { go func(qtype dnsmessage.Type) {
defer dnsWaitGroup.Done() p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype) lane <- racer{p, server, err}
lane <- racer{cname, rrs, err} dnsWaitGroup.Done()
}(qtype) }(qtype)
} }
hitStrictError := false hitStrictError := false
...@@ -500,9 +575,74 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order ...@@ -500,9 +575,74 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
} }
continue continue
} }
addrs = append(addrs, addrRecordList(racer.rrs)...)
if cname == "" { // Presotto says it's okay to assume that servers listed in
cname = racer.cname // /etc/resolv.conf are recursive resolvers.
//
// We asked for recursion, so it should have included all the
// answers we need in this one packet.
//
// Further, RFC 1035 section 4.3.1 says that "the recursive
// response to a query will be... The answer to the query,
// possibly preface by one or more CNAME RRs that specify
// aliases encountered on the way to an answer."
//
// Therefore, we should be able to assume that we can ignore
// CNAMEs and that the A and AAAA records we requested are
// for the canonical name.
loop:
for {
h, err := racer.p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
}
if err != nil {
break
}
switch h.Type {
case dnsmessage.TypeA:
a, err := racer.p.AResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
case dnsmessage.TypeAAAA:
aaaa, err := racer.p.AAAAResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
default:
if err := racer.p.SkipAnswer(); err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
break loop
}
continue
}
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
} }
} }
if hitStrictError { if hitStrictError {
...@@ -528,17 +668,17 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order ...@@ -528,17 +668,17 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
addrs = goLookupIPFiles(name) addrs = goLookupIPFiles(name)
} }
if len(addrs) == 0 && lastErr != nil { if len(addrs) == 0 && lastErr != nil {
return nil, "", lastErr return nil, dnsmessage.Name{}, lastErr
} }
} }
return addrs, cname, nil return addrs, cname, nil
} }
// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME. // goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (cname string, err error) { func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) {
order := systemConf().hostLookupOrder(host) order := systemConf().hostLookupOrder(host)
_, cname, err = r.goLookupIPCNAMEOrder(ctx, host, order) _, cname, err := r.goLookupIPCNAMEOrder(ctx, host, order)
return return cname.String(), err
} }
// goLookupPTR is the native Go implementation of LookupAddr. // goLookupPTR is the native Go implementation of LookupAddr.
...@@ -555,13 +695,36 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, erro ...@@ -555,13 +695,36 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, rrs, err := r.lookup(ctx, arpa, dnsTypePTR) p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ptrs := make([]string, len(rrs)) var ptrs []string
for i, rr := range rrs { for {
ptrs[i] = rr.(*dnsRR_PTR).Ptr h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot marshal DNS message",
Name: addr,
Server: server,
}
}
if h.Type != dnsmessage.TypePTR {
continue
}
ptr, err := p.PTRResource()
if err != nil {
return nil, &DNSError{
Err: "cannot marshal DNS message",
Name: addr,
Server: server,
}
}
ptrs = append(ptrs, ptr.PTR.String())
} }
return ptrs, nil return ptrs, nil
} }
...@@ -19,42 +19,59 @@ import ( ...@@ -19,42 +19,59 @@ import (
"sync" "sync"
"testing" "testing"
"time" "time"
"golang_org/x/net/dns/dnsmessage"
) )
var goResolver = Resolver{PreferGo: true} var goResolver = Resolver{PreferGo: true}
// Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation. // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
const TestAddr uint32 = 0xc0000201 var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
// Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation. // Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation.
var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
func mustNewName(name string) dnsmessage.Name {
nn, err := dnsmessage.NewName(name)
if err != nil {
panic(fmt.Sprint("creating name: ", err))
}
return nn
}
func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
return dnsmessage.Question{
Name: mustNewName(name),
Type: qtype,
Class: class,
}
}
var dnsTransportFallbackTests = []struct { var dnsTransportFallbackTests = []struct {
server string server string
name string question dnsmessage.Question
qtype uint16 timeout int
timeout int rcode dnsmessage.RCode
rcode int
}{ }{
// Querying "com." with qtype=255 usually makes an answer // Querying "com." with qtype=255 usually makes an answer
// which requires more than 512 bytes. // which requires more than 512 bytes.
{"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess}, {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
{"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess}, {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
} }
func TestDNSTransportFallback(t *testing.T) { func TestDNSTransportFallback(t *testing.T) {
fake := fakeDNSServer{ fake := fakeDNSServer{
rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.Header.ID,
response: true, Response: true,
rcode: dnsRcodeSuccess, RCode: dnsmessage.RCodeSuccess,
}, },
question: q.question, Questions: q.Questions,
} }
if n == "udp" { if n == "udp" {
r.truncated = true r.Header.Truncated = true
} }
return r, nil return r, nil
}, },
...@@ -63,15 +80,13 @@ func TestDNSTransportFallback(t *testing.T) { ...@@ -63,15 +80,13 @@ func TestDNSTransportFallback(t *testing.T) {
for _, tt := range dnsTransportFallbackTests { for _, tt := range dnsTransportFallbackTests {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second) _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
continue continue
} }
switch msg.rcode { if h.RCode != tt.rcode {
case tt.rcode: t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
default:
t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
continue continue
} }
} }
...@@ -80,39 +95,38 @@ func TestDNSTransportFallback(t *testing.T) { ...@@ -80,39 +95,38 @@ func TestDNSTransportFallback(t *testing.T) {
// See RFC 6761 for further information about the reserved, pseudo // See RFC 6761 for further information about the reserved, pseudo
// domain names. // domain names.
var specialDomainNameTests = []struct { var specialDomainNameTests = []struct {
name string question dnsmessage.Question
qtype uint16 rcode dnsmessage.RCode
rcode int
}{ }{
// Name resolution APIs and libraries should not recognize the // Name resolution APIs and libraries should not recognize the
// followings as special. // followings as special.
{"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError}, {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
{"test.", dnsTypeALL, dnsRcodeNameError}, {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
{"example.com.", dnsTypeALL, dnsRcodeSuccess}, {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},
// Name resolution APIs and libraries should recognize the // Name resolution APIs and libraries should recognize the
// followings as special and should not send any queries. // followings as special and should not send any queries.
// Though, we test those names here for verifying negative // Though, we test those names here for verifying negative
// answers at DNS query-response interaction level. // answers at DNS query-response interaction level.
{"localhost.", dnsTypeALL, dnsRcodeNameError}, {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
{"invalid.", dnsTypeALL, dnsRcodeNameError}, {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
} }
func TestSpecialDomainName(t *testing.T) { func TestSpecialDomainName(t *testing.T) {
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
}, },
question: q.question, Questions: q.Questions,
} }
switch q.question[0].Name { switch q.Questions[0].Name.String() {
case "example.com.": case "example.com.":
r.rcode = dnsRcodeSuccess r.Header.RCode = dnsmessage.RCodeSuccess
default: default:
r.rcode = dnsRcodeNameError r.Header.RCode = dnsmessage.RCodeNameError
} }
return r, nil return r, nil
...@@ -122,15 +136,13 @@ func TestSpecialDomainName(t *testing.T) { ...@@ -122,15 +136,13 @@ func TestSpecialDomainName(t *testing.T) {
for _, tt := range specialDomainNameTests { for _, tt := range specialDomainNameTests {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second) _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
continue continue
} }
switch msg.rcode { if h.RCode != tt.rcode {
case tt.rcode, dnsRcodeServerFailure: t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
default:
t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
continue continue
} }
} }
...@@ -177,24 +189,26 @@ func TestAvoidDNSName(t *testing.T) { ...@@ -177,24 +189,26 @@ func TestAvoidDNSName(t *testing.T) {
} }
} }
var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
}, },
question: q.question, Questions: q.Questions,
} }
if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA { if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
r.answer = []dnsRR{ r.Answers = []dnsmessage.Resource{
&dnsRR_A{ {
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: q.question[0].Name, Name: q.Questions[0].Name,
Rrtype: dnsTypeA, Type: dnsmessage.TypeA,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
Rdlength: 4, Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
}, },
A: TestAddr,
}, },
} }
} }
...@@ -459,54 +473,57 @@ var goLookupIPWithResolverConfigTests = []struct { ...@@ -459,54 +473,57 @@ var goLookupIPWithResolverConfigTests = []struct {
func TestGoLookupIPWithResolverConfig(t *testing.T) { func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait() defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
switch s { switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53": case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break break
default: default:
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
return nil, poll.ErrTimeout return dnsmessage.Message{}, poll.ErrTimeout
} }
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
}, },
question: q.question, Questions: q.Questions,
} }
for _, question := range q.question { for _, question := range q.Questions {
switch question.Qtype { switch question.Type {
case dnsTypeA: case dnsmessage.TypeA:
switch question.Name { switch question.Name.String() {
case "hostname.as112.net.": case "hostname.as112.net.":
break break
case "ipv4.google.com.": case "ipv4.google.com.":
r.answer = append(r.answer, &dnsRR_A{ r.Answers = append(r.Answers, dnsmessage.Resource{
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: q.question[0].Name, Name: q.Questions[0].Name,
Rrtype: dnsTypeA, Type: dnsmessage.TypeA,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
Rdlength: 4, Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
}, },
A: TestAddr,
}) })
default: default:
} }
case dnsTypeAAAA: case dnsmessage.TypeAAAA:
switch question.Name { switch question.Name.String() {
case "hostname.as112.net.": case "hostname.as112.net.":
break break
case "ipv6.google.com.": case "ipv6.google.com.":
r.answer = append(r.answer, &dnsRR_AAAA{ r.Answers = append(r.Answers, dnsmessage.Resource{
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: q.question[0].Name, Name: q.Questions[0].Name,
Rrtype: dnsTypeAAAA, Type: dnsmessage.TypeAAAA,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
Rdlength: 16, Length: 16,
},
Body: &dnsmessage.AAAAResource{
AAAA: TestAddr6,
}, },
AAAA: TestAddr6,
}) })
} }
} }
...@@ -554,13 +571,13 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { ...@@ -554,13 +571,13 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
func TestGoLookupIPOrderFallbackToFile(t *testing.T) { func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait() defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
}, },
question: q.question, Questions: q.Questions,
} }
return r, nil return r, nil
}} }}
...@@ -624,20 +641,20 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -624,20 +641,20 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
}, },
question: q.question, Questions: q.Questions,
} }
switch q.question[0].Name { switch q.Questions[0].Name.String() {
case fqdn + ".servfail.": case fqdn + ".servfail.":
r.rcode = dnsRcodeServerFailure r.Header.RCode = dnsmessage.RCodeServerFailure
default: default:
r.rcode = dnsRcodeNameError r.Header.RCode = dnsmessage.RCodeNameError
} }
return r, nil return r, nil
...@@ -679,28 +696,30 @@ func TestIgnoreLameReferrals(t *testing.T) { ...@@ -679,28 +696,30 @@ func TestIgnoreLameReferrals(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
t.Log(s, q) t.Log(s, q)
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
}, },
question: q.question, Questions: q.Questions,
} }
if s == "192.0.2.2:53" { if s == "192.0.2.2:53" {
r.recursion_available = true r.Header.RecursionAvailable = true
if q.question[0].Qtype == dnsTypeA { if q.Questions[0].Type == dnsmessage.TypeA {
r.answer = []dnsRR{ r.Answers = []dnsmessage.Resource{
&dnsRR_A{ {
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: q.question[0].Name, Name: q.Questions[0].Name,
Rrtype: dnsTypeA, Type: dnsmessage.TypeA,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
Rdlength: 4, Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
}, },
A: TestAddr,
}, },
} }
} }
...@@ -766,20 +785,23 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { ...@@ -766,20 +785,23 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
} }
type fakeDNSServer struct { type fakeDNSServer struct {
rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
} }
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil tcp := n == "tcp" || n == "tcp4" || n == "tcp6"
return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil
} }
type fakeDNSConn struct { type fakeDNSConn struct {
Conn Conn
tcp bool
server *fakeDNSServer server *fakeDNSServer
n string n string
s string s string
q *dnsMsg q dnsmessage.Message
t time.Time t time.Time
buf []byte
} }
func (f *fakeDNSConn) Close() error { func (f *fakeDNSConn) Close() error {
...@@ -787,15 +809,32 @@ func (f *fakeDNSConn) Close() error { ...@@ -787,15 +809,32 @@ func (f *fakeDNSConn) Close() error {
} }
func (f *fakeDNSConn) Read(b []byte) (int, error) { func (f *fakeDNSConn) Read(b []byte) (int, error) {
if len(f.buf) > 0 {
n := copy(b, f.buf)
f.buf = f.buf[n:]
return n, nil
}
resp, err := f.server.rh(f.n, f.s, f.q, f.t) resp, err := f.server.rh(f.n, f.s, f.q, f.t)
if err != nil { if err != nil {
return 0, err return 0, err
} }
bb, ok := resp.Pack() bb := make([]byte, 2, 514)
if !ok { bb, err = resp.AppendPack(bb)
return 0, errors.New("cannot marshal DNS message") if err != nil {
return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
} }
if f.tcp {
l := len(bb) - 2
bb[0] = byte(l >> 8)
bb[1] = byte(l)
f.buf = bb
return f.Read(b)
}
bb = bb[2:]
if len(b) < len(bb) { if len(b) < len(bb) {
return 0, errors.New("read would fragment DNS message") return 0, errors.New("read would fragment DNS message")
} }
...@@ -809,9 +848,11 @@ func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) { ...@@ -809,9 +848,11 @@ func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
} }
func (f *fakeDNSConn) Write(b []byte) (int, error) { func (f *fakeDNSConn) Write(b []byte) (int, error) {
f.q = new(dnsMsg) if f.tcp && len(b) >= 2 {
if !f.q.Unpack(b) { b = b[2:]
return 0, errors.New("cannot unmarshal DNS message") }
if f.q.Unpack(b) != nil {
return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
} }
return len(b), nil return len(b), nil
} }
...@@ -836,64 +877,75 @@ func TestIgnoreDNSForgeries(t *testing.T) { ...@@ -836,64 +877,75 @@ func TestIgnoreDNSForgeries(t *testing.T) {
return return
} }
msg := &dnsMsg{} var msg dnsmessage.Message
if !msg.Unpack(b[:n]) { if msg.Unpack(b[:n]) != nil {
t.Error("invalid DNS query") t.Error("invalid DNS query:", err)
return return
} }
s.Write([]byte("garbage DNS response packet")) s.Write([]byte("garbage DNS response packet"))
msg.response = true msg.Header.Response = true
msg.id++ // make invalid ID msg.Header.ID++ // make invalid ID
b, ok := msg.Pack()
if !ok { if b, err = msg.Pack(); err != nil {
t.Error("failed to pack DNS response") t.Error("failed to pack DNS response:", err)
return return
} }
s.Write(b) s.Write(b)
msg.id-- // restore original ID msg.Header.ID-- // restore original ID
msg.answer = []dnsRR{ msg.Answers = []dnsmessage.Resource{
&dnsRR_A{ {
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: "www.example.com.", Name: mustNewName("www.example.com."),
Rrtype: dnsTypeA, Type: dnsmessage.TypeA,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
Rdlength: 4, Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
}, },
A: TestAddr,
}, },
} }
b, ok = msg.Pack() b, err = msg.Pack()
if !ok { if err != nil {
t.Error("failed to pack DNS response") t.Error("failed to pack DNS response:", err)
return return
} }
s.Write(b) s.Write(b)
}() }()
msg := &dnsMsg{ msg := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: 42, ID: 42,
}, },
question: []dnsQuestion{ Questions: []dnsmessage.Question{
{ {
Name: "www.example.com.", Name: mustNewName("www.example.com."),
Qtype: dnsTypeA, Type: dnsmessage.TypeA,
Qclass: dnsClassINET, Class: dnsmessage.ClassINET,
}, },
}, },
} }
dc := &dnsPacketConn{c} b, err := msg.Pack()
resp, err := dc.dnsRoundTrip(msg)
if err != nil { if err != nil {
t.Fatalf("dnsRoundTripUDP failed: %v", err) t.Fatal("Pack failed:", err)
} }
if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr { p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
if err != nil {
t.Fatalf("dnsPacketRoundTrip failed: %v", err)
}
p.SkipAllQuestions()
as, err := p.AllAnswers()
if err != nil {
t.Fatal("AllAnswers failed:", err)
}
if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
t.Errorf("got address %v, want %v", got, TestAddr) t.Errorf("got address %v, want %v", got, TestAddr)
} }
} }
...@@ -918,7 +970,7 @@ func TestRetryTimeout(t *testing.T) { ...@@ -918,7 +970,7 @@ func TestRetryTimeout(t *testing.T) {
var deadline0 time.Time var deadline0 time.Time
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q, deadline) t.Log(s, q, deadline)
if deadline.IsZero() { if deadline.IsZero() {
...@@ -928,7 +980,7 @@ func TestRetryTimeout(t *testing.T) { ...@@ -928,7 +980,7 @@ func TestRetryTimeout(t *testing.T) {
if s == "192.0.2.1:53" { if s == "192.0.2.1:53" {
deadline0 = deadline deadline0 = deadline
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
return nil, poll.ErrTimeout return dnsmessage.Message{}, poll.ErrTimeout
} }
if deadline.Equal(deadline0) { if deadline.Equal(deadline0) {
...@@ -979,7 +1031,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { ...@@ -979,7 +1031,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
} }
var usedServers []string var usedServers []string
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
usedServers = append(usedServers, s) usedServers = append(usedServers, s)
return mockTXTResponse(q), nil return mockTXTResponse(q), nil
}} }}
...@@ -997,22 +1049,24 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { ...@@ -997,22 +1049,24 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
} }
} }
func mockTXTResponse(q *dnsMsg) *dnsMsg { func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
recursion_available: true, RecursionAvailable: true,
}, },
question: q.question, Questions: q.Questions,
answer: []dnsRR{ Answers: []dnsmessage.Resource{
&dnsRR_TXT{ {
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: q.question[0].Name, Name: q.Questions[0].Name,
Rrtype: dnsTypeTXT, Type: dnsmessage.TypeTXT,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.TXTResource{
TXT: []string{"ok"},
}, },
Txt: "ok",
}, },
}, },
} }
...@@ -1080,22 +1134,22 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1080,22 +1134,22 @@ func TestStrictErrorsLookupIP(t *testing.T) {
cases := []struct { cases := []struct {
desc string desc string
resolveWhich func(quest *dnsQuestion) resolveWhichEnum resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
wantStrictErr error wantStrictErr error
wantLaxErr error wantLaxErr error
wantIPs []string wantIPs []string
}{ }{
{ {
desc: "No errors", desc: "No errors",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
return resolveOK return resolveOK
}, },
wantIPs: []string{ip4, ip6}, wantIPs: []string{ip4, ip6},
}, },
{ {
desc: "searchX error fails in strict mode", desc: "searchX error fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name == searchX { if quest.Name.String() == searchX {
return resolveTimeout return resolveTimeout
} }
return resolveOK return resolveOK
...@@ -1105,8 +1159,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1105,8 +1159,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}, },
{ {
desc: "searchX IPv4-only timeout fails in strict mode", desc: "searchX IPv4-only timeout fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name == searchX && quest.Qtype == dnsTypeA { if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
return resolveTimeout return resolveTimeout
} }
return resolveOK return resolveOK
...@@ -1116,8 +1170,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1116,8 +1170,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}, },
{ {
desc: "searchX IPv6-only servfail fails in strict mode", desc: "searchX IPv6-only servfail fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name == searchX && quest.Qtype == dnsTypeAAAA { if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
return resolveServfail return resolveServfail
} }
return resolveOK return resolveOK
...@@ -1127,8 +1181,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1127,8 +1181,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}, },
{ {
desc: "searchY error always fails", desc: "searchY error always fails",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name == searchY { if quest.Name.String() == searchY {
return resolveTimeout return resolveTimeout
} }
return resolveOK return resolveOK
...@@ -1138,8 +1192,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1138,8 +1192,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}, },
{ {
desc: "searchY IPv4-only socket error fails in strict mode", desc: "searchY IPv4-only socket error fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name == searchY && quest.Qtype == dnsTypeA { if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
return resolveOpError return resolveOpError
} }
return resolveOK return resolveOK
...@@ -1149,8 +1203,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1149,8 +1203,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}, },
{ {
desc: "searchY IPv6-only timeout fails in strict mode", desc: "searchY IPv6-only timeout fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name == searchY && quest.Qtype == dnsTypeAAAA { if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
return resolveTimeout return resolveTimeout
} }
return resolveOK return resolveOK
...@@ -1161,80 +1215,84 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1161,80 +1215,84 @@ func TestStrictErrorsLookupIP(t *testing.T) {
} }
for i, tt := range cases { for i, tt := range cases {
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q) t.Log(s, q)
switch tt.resolveWhich(&q.question[0]) { switch tt.resolveWhich(q.Questions[0]) {
case resolveOK: case resolveOK:
// Handle below. // Handle below.
case resolveOpError: case resolveOpError:
return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")} return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
case resolveServfail: case resolveServfail:
return &dnsMsg{ return dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
rcode: dnsRcodeServerFailure, RCode: dnsmessage.RCodeServerFailure,
}, },
question: q.question, Questions: q.Questions,
}, nil }, nil
case resolveTimeout: case resolveTimeout:
return nil, poll.ErrTimeout return dnsmessage.Message{}, poll.ErrTimeout
default: default:
t.Fatal("Impossible resolveWhich") t.Fatal("Impossible resolveWhich")
} }
switch q.question[0].Name { switch q.Questions[0].Name.String() {
case searchX, name + ".": case searchX, name + ".":
// Return NXDOMAIN to utilize the search list. // Return NXDOMAIN to utilize the search list.
return &dnsMsg{ return dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
rcode: dnsRcodeNameError, RCode: dnsmessage.RCodeNameError,
}, },
question: q.question, Questions: q.Questions,
}, nil }, nil
case searchY: case searchY:
// Return records below. // Return records below.
default: default:
return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name) return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
} }
r := &dnsMsg{ r := dnsmessage.Message{
dnsMsgHdr: dnsMsgHdr{ Header: dnsmessage.Header{
id: q.id, ID: q.ID,
response: true, Response: true,
}, },
question: q.question, Questions: q.Questions,
} }
switch q.question[0].Qtype { switch q.Questions[0].Type {
case dnsTypeA: case dnsmessage.TypeA:
r.answer = []dnsRR{ r.Answers = []dnsmessage.Resource{
&dnsRR_A{ {
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: q.question[0].Name, Name: q.Questions[0].Name,
Rrtype: dnsTypeA, Type: dnsmessage.TypeA,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
Rdlength: 4, Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
}, },
A: TestAddr,
}, },
} }
case dnsTypeAAAA: case dnsmessage.TypeAAAA:
r.answer = []dnsRR{ r.Answers = []dnsmessage.Resource{
&dnsRR_AAAA{ {
Hdr: dnsRR_Header{ Header: dnsmessage.ResourceHeader{
Name: q.question[0].Name, Name: q.Questions[0].Name,
Rrtype: dnsTypeAAAA, Type: dnsmessage.TypeAAAA,
Class: dnsClassINET, Class: dnsmessage.ClassINET,
Rdlength: 16, Length: 16,
},
Body: &dnsmessage.AAAAResource{
AAAA: TestAddr6,
}, },
AAAA: TestAddr6,
}, },
} }
default: default:
return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype) return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
} }
return r, nil return r, nil
}} }}
...@@ -1295,22 +1353,22 @@ func TestStrictErrorsLookupTXT(t *testing.T) { ...@@ -1295,22 +1353,22 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
const searchY = "test.y.golang.org." const searchY = "test.y.golang.org."
const txt = "Hello World" const txt = "Hello World"
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q) t.Log(s, q)
switch q.question[0].Name { switch q.Questions[0].Name.String() {
case searchX: case searchX:
return nil, poll.ErrTimeout return dnsmessage.Message{}, poll.ErrTimeout
case searchY: case searchY:
return mockTXTResponse(q), nil return mockTXTResponse(q), nil
default: default:
return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name) return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
} }
}} }}
for _, strict := range []bool{true, false} { for _, strict := range []bool{true, false} {
r := Resolver{StrictErrors: strict, Dial: fake.DialContext} r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
_, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT) p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT)
var wantErr error var wantErr error
var wantRRs int var wantRRs int
if strict { if strict {
...@@ -1326,8 +1384,12 @@ func TestStrictErrorsLookupTXT(t *testing.T) { ...@@ -1326,8 +1384,12 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
if !reflect.DeepEqual(err, wantErr) { if !reflect.DeepEqual(err, wantErr) {
t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr) t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
} }
if len(rrs) != wantRRs { a, err := p.AllAnswers()
t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs) if err != nil {
a = nil
}
if len(a) != wantRRs {
t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
} }
} }
} }
...@@ -1337,9 +1399,9 @@ func TestStrictErrorsLookupTXT(t *testing.T) { ...@@ -1337,9 +1399,9 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
func TestDNSGoroutineRace(t *testing.T) { func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait() defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) { fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
time.Sleep(10 * time.Microsecond) time.Sleep(10 * time.Microsecond)
return nil, poll.ErrTimeout return dnsmessage.Message{}, poll.ErrTimeout
}} }}
r := Resolver{PreferGo: true, Dial: fake.DialContext} r := Resolver{PreferGo: true, Dial: fake.DialContext}
...@@ -1353,3 +1415,76 @@ func TestDNSGoroutineRace(t *testing.T) { ...@@ -1353,3 +1415,76 @@ func TestDNSGoroutineRace(t *testing.T) {
t.Fatal("fake DNS lookup unexpectedly succeeded") t.Fatal("fake DNS lookup unexpectedly succeeded")
} }
} }
// Issue 8434: verify that Temporary returns true on an error when rcode
// is SERVFAIL
func TestIssue8434(t *testing.T) {
msg := dnsmessage.Message{
Header: dnsmessage.Header{
RCode: dnsmessage.RCodeServerFailure,
},
}
b, err := msg.Pack()
if err != nil {
t.Fatal("Pack failed:", err)
}
var p dnsmessage.Parser
h, err := p.Start(b)
if err != nil {
t.Fatal("Start failed:", err)
}
if err := p.SkipAllQuestions(); err != nil {
t.Fatal("SkipAllQuestions failed:", err)
}
err = checkHeaders(&p, h, "golang.org", "foo:53")
if err == nil {
t.Fatal("expected an error")
}
if ne, ok := err.(Error); !ok {
t.Fatalf("err = %#v; wanted something supporting net.Error", err)
} else if !ne.Temporary() {
t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
}
if de, ok := err.(*DNSError); !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
} else if !de.IsTemporary {
t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
}
}
// Issue 12778: verify that NXDOMAIN without RA bit errors as
// "no such host" and not "server misbehaving"
func TestIssue12778(t *testing.T) {
msg := dnsmessage.Message{
Header: dnsmessage.Header{
RCode: dnsmessage.RCodeNameError,
RecursionAvailable: false,
},
}
b, err := msg.Pack()
if err != nil {
t.Fatal("Pack failed:", err)
}
var p dnsmessage.Parser
h, err := p.Start(b)
if err != nil {
t.Fatal("Start failed:", err)
}
if err := p.SkipAllQuestions(); err != nil {
t.Fatal("SkipAllQuestions failed:", err)
}
err = checkHeaders(&p, h, "golang.org", "foo:53")
if err == nil {
t.Fatal("expected an error")
}
de, ok := err.(*DNSError)
if !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
}
if de.Err != errNoSuchHost.Error() {
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DNS packet assembly. See RFC 1035.
//
// This is intended to support name resolution during Dial.
// It doesn't have to be blazing fast.
//
// Each message structure has a Walk method that is used by
// a generic pack/unpack routine. Thus, if in the future we need
// to define new message structs, no new pack/unpack/printing code
// needs to be written.
//
// The first half of this file defines the DNS message formats.
// The second half implements the conversion to and from wire format.
// A few of the structure elements have string tags to aid the
// generic pack/unpack routines.
//
// TODO(rsc): There are enough names defined in this file that they're all
// prefixed with dns. Perhaps put this in its own package later.
package net
// Packet formats
// Wire constants.
const (
// valid dnsRR_Header.Rrtype and dnsQuestion.qtype
dnsTypeA = 1
dnsTypeNS = 2
dnsTypeMD = 3
dnsTypeMF = 4
dnsTypeCNAME = 5
dnsTypeSOA = 6
dnsTypeMB = 7
dnsTypeMG = 8
dnsTypeMR = 9
dnsTypeNULL = 10
dnsTypeWKS = 11
dnsTypePTR = 12
dnsTypeHINFO = 13
dnsTypeMINFO = 14
dnsTypeMX = 15
dnsTypeTXT = 16
dnsTypeAAAA = 28
dnsTypeSRV = 33
// valid dnsQuestion.qtype only
dnsTypeAXFR = 252
dnsTypeMAILB = 253
dnsTypeMAILA = 254
dnsTypeALL = 255
// valid dnsQuestion.qclass
dnsClassINET = 1
dnsClassCSNET = 2
dnsClassCHAOS = 3
dnsClassHESIOD = 4
dnsClassANY = 255
// dnsMsg.rcode
dnsRcodeSuccess = 0
dnsRcodeFormatError = 1
dnsRcodeServerFailure = 2
dnsRcodeNameError = 3
dnsRcodeNotImplemented = 4
dnsRcodeRefused = 5
)
// A dnsStruct describes how to iterate over its fields to emulate
// reflective marshaling.
type dnsStruct interface {
// Walk iterates over fields of a structure and calls f
// with a reference to that field, the name of the field
// and a tag ("", "domain", "ipv4", "ipv6") specifying
// particular encodings. Possible concrete types
// for v are *uint16, *uint32, *string, or []byte, and
// *int, *bool in the case of dnsMsgHdr.
// Whenever f returns false, Walk must stop and return
// false, and otherwise return true.
Walk(f func(v interface{}, name, tag string) (ok bool)) (ok bool)
}
// The wire format for the DNS packet header.
type dnsHeader struct {
Id uint16
Bits uint16
Qdcount, Ancount, Nscount, Arcount uint16
}
func (h *dnsHeader) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&h.Id, "Id", "") &&
f(&h.Bits, "Bits", "") &&
f(&h.Qdcount, "Qdcount", "") &&
f(&h.Ancount, "Ancount", "") &&
f(&h.Nscount, "Nscount", "") &&
f(&h.Arcount, "Arcount", "")
}
const (
// dnsHeader.Bits
_QR = 1 << 15 // query/response (response=1)
_AA = 1 << 10 // authoritative
_TC = 1 << 9 // truncated
_RD = 1 << 8 // recursion desired
_RA = 1 << 7 // recursion available
)
// DNS queries.
type dnsQuestion struct {
Name string
Qtype uint16
Qclass uint16
}
func (q *dnsQuestion) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&q.Name, "Name", "domain") &&
f(&q.Qtype, "Qtype", "") &&
f(&q.Qclass, "Qclass", "")
}
// DNS responses (resource records).
// There are many types of messages,
// but they all share the same header.
type dnsRR_Header struct {
Name string
Rrtype uint16
Class uint16
Ttl uint32
Rdlength uint16 // length of data after header
}
func (h *dnsRR_Header) Header() *dnsRR_Header {
return h
}
func (h *dnsRR_Header) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&h.Name, "Name", "domain") &&
f(&h.Rrtype, "Rrtype", "") &&
f(&h.Class, "Class", "") &&
f(&h.Ttl, "Ttl", "") &&
f(&h.Rdlength, "Rdlength", "")
}
type dnsRR interface {
dnsStruct
Header() *dnsRR_Header
}
// Specific DNS RR formats for each query type.
type dnsRR_CNAME struct {
Hdr dnsRR_Header
Cname string
}
func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_CNAME) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Cname, "Cname", "domain")
}
type dnsRR_MX struct {
Hdr dnsRR_Header
Pref uint16
Mx string
}
func (rr *dnsRR_MX) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_MX) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Pref, "Pref", "") && f(&rr.Mx, "Mx", "domain")
}
type dnsRR_NS struct {
Hdr dnsRR_Header
Ns string
}
func (rr *dnsRR_NS) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_NS) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Ns, "Ns", "domain")
}
type dnsRR_PTR struct {
Hdr dnsRR_Header
Ptr string
}
func (rr *dnsRR_PTR) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_PTR) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Ptr, "Ptr", "domain")
}
type dnsRR_SOA struct {
Hdr dnsRR_Header
Ns string
Mbox string
Serial uint32
Refresh uint32
Retry uint32
Expire uint32
Minttl uint32
}
func (rr *dnsRR_SOA) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_SOA) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) &&
f(&rr.Ns, "Ns", "domain") &&
f(&rr.Mbox, "Mbox", "domain") &&
f(&rr.Serial, "Serial", "") &&
f(&rr.Refresh, "Refresh", "") &&
f(&rr.Retry, "Retry", "") &&
f(&rr.Expire, "Expire", "") &&
f(&rr.Minttl, "Minttl", "")
}
type dnsRR_TXT struct {
Hdr dnsRR_Header
Txt string // not domain name
}
func (rr *dnsRR_TXT) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_TXT) Walk(f func(v interface{}, name, tag string) bool) bool {
if !rr.Hdr.Walk(f) {
return false
}
var n uint16 = 0
for n < rr.Hdr.Rdlength {
var txt string
if !f(&txt, "Txt", "") {
return false
}
// more bytes than rr.Hdr.Rdlength said there would be
if rr.Hdr.Rdlength-n < uint16(len(txt))+1 {
return false
}
n += uint16(len(txt)) + 1
rr.Txt += txt
}
return true
}
type dnsRR_SRV struct {
Hdr dnsRR_Header
Priority uint16
Weight uint16
Port uint16
Target string
}
func (rr *dnsRR_SRV) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_SRV) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) &&
f(&rr.Priority, "Priority", "") &&
f(&rr.Weight, "Weight", "") &&
f(&rr.Port, "Port", "") &&
f(&rr.Target, "Target", "domain")
}
type dnsRR_A struct {
Hdr dnsRR_Header
A uint32
}
func (rr *dnsRR_A) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_A) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.A, "A", "ipv4")
}
type dnsRR_AAAA struct {
Hdr dnsRR_Header
AAAA [16]byte
}
func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
return &rr.Hdr
}
func (rr *dnsRR_AAAA) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(rr.AAAA[:], "AAAA", "ipv6")
}
// Packing and unpacking.
//
// All the packers and unpackers take a (msg []byte, off int)
// and return (off1 int, ok bool). If they return ok==false, they
// also return off1==len(msg), so that the next unpacker will
// also fail. This lets us avoid checks of ok until the end of a
// packing sequence.
// Map of constructors for each RR wire type.
var rr_mk = map[int]func() dnsRR{
dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) },
dnsTypeMX: func() dnsRR { return new(dnsRR_MX) },
dnsTypeNS: func() dnsRR { return new(dnsRR_NS) },
dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) },
dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) },
dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) },
dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) },
dnsTypeA: func() dnsRR { return new(dnsRR_A) },
dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) },
}
// Pack a domain name s into msg[off:].
// Domain names are a sequence of counted strings
// split at the dots. They end with a zero-length string.
func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
// Add trailing dot to canonicalize name.
if n := len(s); n == 0 || s[n-1] != '.' {
s += "."
}
// Allow root domain.
if s == "." {
msg[off] = 0
off++
return off, true
}
// Each dot ends a segment of the name.
// We trade each dot byte for a length byte.
// There is also a trailing zero.
// Check that we have all the space we need.
tot := len(s) + 1
if off+tot > len(msg) {
return len(msg), false
}
// Emit sequence of counted strings, chopping at dots.
begin := 0
for i := 0; i < len(s); i++ {
if s[i] == '.' {
if i-begin >= 1<<6 { // top two bits of length must be clear
return len(msg), false
}
if i-begin == 0 {
return len(msg), false
}
msg[off] = byte(i - begin)
off++
for j := begin; j < i; j++ {
msg[off] = s[j]
off++
}
begin = i + 1
}
}
msg[off] = 0
off++
return off, true
}
// Unpack a domain name.
// In addition to the simple sequences of counted strings above,
// domain names are allowed to refer to strings elsewhere in the
// packet, to avoid repeating common suffixes when returning
// many entries in a single domain. The pointers are marked
// by a length byte with the top two bits set. Ignoring those
// two bits, that byte and the next give a 14 bit offset from msg[0]
// where we should pick up the trail.
// Note that if we jump elsewhere in the packet,
// we return off1 == the offset after the first pointer we found,
// which is where the next record will start.
// In theory, the pointers are only allowed to jump backward.
// We let them jump anywhere and stop jumping after a while.
func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
s = ""
ptr := 0 // number of pointers followed
Loop:
for {
if off >= len(msg) {
return "", len(msg), false
}
c := int(msg[off])
off++
switch c & 0xC0 {
case 0x00:
if c == 0x00 {
// end of name
break Loop
}
// literal string
if off+c > len(msg) {
return "", len(msg), false
}
s += string(msg[off:off+c]) + "."
off += c
case 0xC0:
// pointer to somewhere else in msg.
// remember location after first ptr,
// since that's how many bytes we consumed.
// also, don't follow too many pointers --
// maybe there's a loop.
if off >= len(msg) {
return "", len(msg), false
}
c1 := msg[off]
off++
if ptr == 0 {
off1 = off
}
if ptr++; ptr > 10 {
return "", len(msg), false
}
off = (c^0xC0)<<8 | int(c1)
default:
// 0x80 and 0x40 are reserved
return "", len(msg), false
}
}
if len(s) == 0 {
s = "."
}
if ptr == 0 {
off1 = off
}
return s, off1, true
}
// packStruct packs a structure into msg at specified offset off, and
// returns off1 such that msg[off:off1] is the encoded data.
func packStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
ok = any.Walk(func(field interface{}, name, tag string) bool {
switch fv := field.(type) {
default:
println("net: dns: unknown packing type")
return false
case *uint16:
i := *fv
if off+2 > len(msg) {
return false
}
msg[off] = byte(i >> 8)
msg[off+1] = byte(i)
off += 2
case *uint32:
i := *fv
msg[off] = byte(i >> 24)
msg[off+1] = byte(i >> 16)
msg[off+2] = byte(i >> 8)
msg[off+3] = byte(i)
off += 4
case []byte:
n := len(fv)
if off+n > len(msg) {
return false
}
copy(msg[off:off+n], fv)
off += n
case *string:
s := *fv
switch tag {
default:
println("net: dns: unknown string tag", tag)
return false
case "domain":
off, ok = packDomainName(s, msg, off)
if !ok {
return false
}
case "":
// Counted string: 1 byte length.
if len(s) > 255 || off+1+len(s) > len(msg) {
return false
}
msg[off] = byte(len(s))
off++
off += copy(msg[off:], s)
}
}
return true
})
if !ok {
return len(msg), false
}
return off, true
}
// unpackStruct decodes msg[off:] into the given structure, and
// returns off1 such that msg[off:off1] is the encoded data.
func unpackStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
ok = any.Walk(func(field interface{}, name, tag string) bool {
switch fv := field.(type) {
default:
println("net: dns: unknown packing type")
return false
case *uint16:
if off+2 > len(msg) {
return false
}
*fv = uint16(msg[off])<<8 | uint16(msg[off+1])
off += 2
case *uint32:
if off+4 > len(msg) {
return false
}
*fv = uint32(msg[off])<<24 | uint32(msg[off+1])<<16 |
uint32(msg[off+2])<<8 | uint32(msg[off+3])
off += 4
case []byte:
n := len(fv)
if off+n > len(msg) {
return false
}
copy(fv, msg[off:off+n])
off += n
case *string:
var s string
switch tag {
default:
println("net: dns: unknown string tag", tag)
return false
case "domain":
s, off, ok = unpackDomainName(msg, off)
if !ok {
return false
}
case "":
if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
return false
}
n := int(msg[off])
off++
b := make([]byte, n)
for i := 0; i < n; i++ {
b[i] = msg[off+i]
}
off += n
s = string(b)
}
*fv = s
}
return true
})
if !ok {
return len(msg), false
}
return off, true
}
// Generic struct printer. Prints fields with tag "ipv4" or "ipv6"
// as IP addresses.
func printStruct(any dnsStruct) string {
s := "{"
i := 0
any.Walk(func(val interface{}, name, tag string) bool {
i++
if i > 1 {
s += ", "
}
s += name + "="
switch tag {
case "ipv4":
i := *val.(*uint32)
s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
case "ipv6":
i := val.([]byte)
s += IP(i).String()
default:
var i int64
switch v := val.(type) {
default:
// can't really happen.
s += "<unknown type>"
return true
case *string:
s += *v
return true
case []byte:
s += string(v)
return true
case *bool:
if *v {
s += "true"
} else {
s += "false"
}
return true
case *int:
i = int64(*v)
case *uint:
i = int64(*v)
case *uint8:
i = int64(*v)
case *uint16:
i = int64(*v)
case *uint32:
i = int64(*v)
case *uint64:
i = int64(*v)
case *uintptr:
i = int64(*v)
}
s += itoa(int(i))
}
return true
})
s += "}"
return s
}
// Resource record packer.
func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
var off1 int
// pack twice, once to find end of header
// and again to find end of packet.
// a bit inefficient but this doesn't need to be fast.
// off1 is end of header
// off2 is end of rr
off1, ok = packStruct(rr.Header(), msg, off)
if !ok {
return len(msg), false
}
off2, ok = packStruct(rr, msg, off)
if !ok {
return len(msg), false
}
// pack a third time; redo header with correct data length
rr.Header().Rdlength = uint16(off2 - off1)
packStruct(rr.Header(), msg, off)
return off2, true
}
// Resource record unpacker.
func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) {
// unpack just the header, to find the rr type and length
var h dnsRR_Header
off0 := off
if off, ok = unpackStruct(&h, msg, off); !ok {
return nil, len(msg), false
}
end := off + int(h.Rdlength)
// make an rr of that type and re-unpack.
// again inefficient but doesn't need to be fast.
mk, known := rr_mk[int(h.Rrtype)]
if !known {
return &h, end, true
}
rr = mk()
off, ok = unpackStruct(rr, msg, off0)
if off != end {
return &h, end, true
}
return rr, off, ok
}
// Usable representation of a DNS packet.
// A manually-unpacked version of (id, bits).
// This is in its own struct for easy printing.
type dnsMsgHdr struct {
id uint16
response bool
opcode int
authoritative bool
truncated bool
recursion_desired bool
recursion_available bool
rcode int
}
func (h *dnsMsgHdr) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&h.id, "id", "") &&
f(&h.response, "response", "") &&
f(&h.opcode, "opcode", "") &&
f(&h.authoritative, "authoritative", "") &&
f(&h.truncated, "truncated", "") &&
f(&h.recursion_desired, "recursion_desired", "") &&
f(&h.recursion_available, "recursion_available", "") &&
f(&h.rcode, "rcode", "")
}
type dnsMsg struct {
dnsMsgHdr
question []dnsQuestion
answer []dnsRR
ns []dnsRR
extra []dnsRR
}
func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
var dh dnsHeader
// Convert convenient dnsMsg into wire-like dnsHeader.
dh.Id = dns.id
dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode)
if dns.recursion_available {
dh.Bits |= _RA
}
if dns.recursion_desired {
dh.Bits |= _RD
}
if dns.truncated {
dh.Bits |= _TC
}
if dns.authoritative {
dh.Bits |= _AA
}
if dns.response {
dh.Bits |= _QR
}
// Prepare variable sized arrays.
question := dns.question
answer := dns.answer
ns := dns.ns
extra := dns.extra
dh.Qdcount = uint16(len(question))
dh.Ancount = uint16(len(answer))
dh.Nscount = uint16(len(ns))
dh.Arcount = uint16(len(extra))
// Could work harder to calculate message size,
// but this is far more than we need and not
// big enough to hurt the allocator.
msg = make([]byte, 2000)
// Pack it in: header and then the pieces.
off := 0
off, ok = packStruct(&dh, msg, off)
if !ok {
return nil, false
}
for i := 0; i < len(question); i++ {
off, ok = packStruct(&question[i], msg, off)
if !ok {
return nil, false
}
}
for i := 0; i < len(answer); i++ {
off, ok = packRR(answer[i], msg, off)
if !ok {
return nil, false
}
}
for i := 0; i < len(ns); i++ {
off, ok = packRR(ns[i], msg, off)
if !ok {
return nil, false
}
}
for i := 0; i < len(extra); i++ {
off, ok = packRR(extra[i], msg, off)
if !ok {
return nil, false
}
}
return msg[0:off], true
}
func (dns *dnsMsg) Unpack(msg []byte) bool {
// Header.
var dh dnsHeader
off := 0
var ok bool
if off, ok = unpackStruct(&dh, msg, off); !ok {
return false
}
dns.id = dh.Id
dns.response = (dh.Bits & _QR) != 0
dns.opcode = int(dh.Bits>>11) & 0xF
dns.authoritative = (dh.Bits & _AA) != 0
dns.truncated = (dh.Bits & _TC) != 0
dns.recursion_desired = (dh.Bits & _RD) != 0
dns.recursion_available = (dh.Bits & _RA) != 0
dns.rcode = int(dh.Bits & 0xF)
// Arrays.
dns.question = make([]dnsQuestion, dh.Qdcount)
dns.answer = make([]dnsRR, 0, dh.Ancount)
dns.ns = make([]dnsRR, 0, dh.Nscount)
dns.extra = make([]dnsRR, 0, dh.Arcount)
var rec dnsRR
for i := 0; i < len(dns.question); i++ {
off, ok = unpackStruct(&dns.question[i], msg, off)
if !ok {
return false
}
}
for i := 0; i < int(dh.Ancount); i++ {
rec, off, ok = unpackRR(msg, off)
if !ok {
return false
}
dns.answer = append(dns.answer, rec)
}
for i := 0; i < int(dh.Nscount); i++ {
rec, off, ok = unpackRR(msg, off)
if !ok {
return false
}
dns.ns = append(dns.ns, rec)
}
for i := 0; i < int(dh.Arcount); i++ {
rec, off, ok = unpackRR(msg, off)
if !ok {
return false
}
dns.extra = append(dns.extra, rec)
}
// if off != len(msg) {
// println("extra bytes in dns packet", off, "<", len(msg));
// }
return true
}
func (dns *dnsMsg) String() string {
s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n"
if len(dns.question) > 0 {
s += "-- Questions\n"
for i := 0; i < len(dns.question); i++ {
s += printStruct(&dns.question[i]) + "\n"
}
}
if len(dns.answer) > 0 {
s += "-- Answers\n"
for i := 0; i < len(dns.answer); i++ {
s += printStruct(dns.answer[i]) + "\n"
}
}
if len(dns.ns) > 0 {
s += "-- Name servers\n"
for i := 0; i < len(dns.ns); i++ {
s += printStruct(dns.ns[i]) + "\n"
}
}
if len(dns.extra) > 0 {
s += "-- Extra\n"
for i := 0; i < len(dns.extra); i++ {
s += printStruct(dns.extra[i]) + "\n"
}
}
return s
}
// IsResponseTo reports whether m is an acceptable response to query.
func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
if !m.response {
return false
}
if m.id != query.id {
return false
}
if len(m.question) != len(query.question) {
return false
}
for i, q := range m.question {
q2 := query.question[i]
if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
return false
}
}
return true
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"encoding/hex"
"reflect"
"testing"
)
func TestStructPackUnpack(t *testing.T) {
want := dnsQuestion{
Name: ".",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
}
buf := make([]byte, 50)
n, ok := packStruct(&want, buf, 0)
if !ok {
t.Fatal("packing failed")
}
buf = buf[:n]
got := dnsQuestion{}
n, ok = unpackStruct(&got, buf, 0)
if !ok {
t.Fatal("unpacking failed")
}
if n != len(buf) {
t.Errorf("unpacked different amount than packed: got n = %d, want = %d", n, len(buf))
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %+v, want = %+v", got, want)
}
}
func TestDomainNamePackUnpack(t *testing.T) {
tests := []struct {
in string
want string
ok bool
}{
{"", ".", true},
{".", ".", true},
{"google..com", "", false},
{"google.com", "google.com.", true},
{"google..com.", "", false},
{"google.com.", "google.com.", true},
{".google.com.", "", false},
{"www..google.com.", "", false},
{"www.google.com.", "www.google.com.", true},
}
for _, test := range tests {
buf := make([]byte, 30)
n, ok := packDomainName(test.in, buf, 0)
if ok != test.ok {
t.Errorf("packing of %s: got ok = %t, want = %t", test.in, ok, test.ok)
continue
}
if !test.ok {
continue
}
buf = buf[:n]
got, n, ok := unpackDomainName(buf, 0)
if !ok {
t.Errorf("unpacking for %s failed", test.in)
continue
}
if n != len(buf) {
t.Errorf(
"unpacked different amount than packed for %s: got n = %d, want = %d",
test.in,
n,
len(buf),
)
}
if got != test.want {
t.Errorf("unpacking packing of %s: got = %s, want = %s", test.in, got, test.want)
}
}
}
func TestDNSPackUnpack(t *testing.T) {
want := dnsMsg{
question: []dnsQuestion{{
Name: ".",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
}},
answer: []dnsRR{},
ns: []dnsRR{},
extra: []dnsRR{},
}
b, ok := want.Pack()
if !ok {
t.Fatal("packing failed")
}
var got dnsMsg
ok = got.Unpack(b)
if !ok {
t.Fatal("unpacking failed")
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %+v, want = %+v", got, want)
}
}
func TestDNSParseSRVReply(t *testing.T) {
data, err := hex.DecodeString(dnsSRVReply)
if err != nil {
t.Fatal(err)
}
msg := new(dnsMsg)
ok := msg.Unpack(data)
if !ok {
t.Fatal("unpacking packet failed")
}
_ = msg.String() // exercise this code path
if g, e := len(msg.answer), 5; g != e {
t.Errorf("len(msg.answer) = %d; want %d", g, e)
}
for idx, rr := range msg.answer {
if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
}
if _, ok := rr.(*dnsRR_SRV); !ok {
t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
}
}
for _, name := range [...]string{
"_xmpp-server._tcp.google.com.",
"_XMPP-Server._TCP.Google.COM.",
"_XMPP-SERVER._TCP.GOOGLE.COM.",
} {
_, addrs, err := answer(name, "foo:53", msg, uint16(dnsTypeSRV))
if err != nil {
t.Error(err)
}
if g, e := len(addrs), 5; g != e {
t.Errorf("len(addrs) = %d; want %d", g, e)
t.Logf("addrs = %#v", addrs)
}
}
// repack and unpack.
data2, ok := msg.Pack()
msg2 := new(dnsMsg)
msg2.Unpack(data2)
switch {
case !ok:
t.Error("failed to repack message")
case !reflect.DeepEqual(msg, msg2):
t.Error("repacked message differs from original")
}
}
func TestDNSParseCorruptSRVReply(t *testing.T) {
data, err := hex.DecodeString(dnsSRVCorruptReply)
if err != nil {
t.Fatal(err)
}
msg := new(dnsMsg)
ok := msg.Unpack(data)
if !ok {
t.Fatal("unpacking packet failed")
}
_ = msg.String() // exercise this code path
if g, e := len(msg.answer), 5; g != e {
t.Errorf("len(msg.answer) = %d; want %d", g, e)
}
for idx, rr := range msg.answer {
if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
}
if idx == 4 {
if _, ok := rr.(*dnsRR_Header); !ok {
t.Errorf("answer[%d] = %T; want *dnsRR_Header", idx, rr)
}
} else {
if _, ok := rr.(*dnsRR_SRV); !ok {
t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
}
}
}
_, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV))
if err != nil {
t.Fatalf("answer: %v", err)
}
if g, e := len(addrs), 4; g != e {
t.Errorf("len(addrs) = %d; want %d", g, e)
t.Logf("addrs = %#v", addrs)
}
}
func TestDNSParseTXTReply(t *testing.T) {
expectedTxt1 := "v=spf1 redirect=_spf.google.com"
expectedTxt2 := "v=spf1 ip4:69.63.179.25 ip4:69.63.178.128/25 ip4:69.63.184.0/25 " +
"ip4:66.220.144.128/25 ip4:66.220.155.0/24 " +
"ip4:69.171.232.0/25 ip4:66.220.157.0/25 " +
"ip4:69.171.244.0/24 mx -all"
replies := []string{dnsTXTReply1, dnsTXTReply2}
expectedTxts := []string{expectedTxt1, expectedTxt2}
for i := range replies {
data, err := hex.DecodeString(replies[i])
if err != nil {
t.Fatal(err)
}
msg := new(dnsMsg)
ok := msg.Unpack(data)
if !ok {
t.Errorf("test %d: unpacking packet failed", i)
continue
}
if len(msg.answer) != 1 {
t.Errorf("test %d: len(rr.answer) = %d; want 1", i, len(msg.answer))
continue
}
rr := msg.answer[0]
rrTXT, ok := rr.(*dnsRR_TXT)
if !ok {
t.Errorf("test %d: answer[0] = %T; want *dnsRR_TXT", i, rr)
continue
}
if rrTXT.Txt != expectedTxts[i] {
t.Errorf("test %d: Txt = %s; want %s", i, rrTXT.Txt, expectedTxts[i])
}
}
}
func TestDNSParseTXTCorruptDataLengthReply(t *testing.T) {
replies := []string{dnsTXTCorruptDataLengthReply1, dnsTXTCorruptDataLengthReply2}
for i := range replies {
data, err := hex.DecodeString(replies[i])
if err != nil {
t.Fatal(err)
}
msg := new(dnsMsg)
ok := msg.Unpack(data)
if ok {
t.Errorf("test %d: expected to fail on unpacking corrupt packet", i)
}
}
}
func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
replies := []string{dnsTXTCorruptTXTLengthReply1, dnsTXTCorruptTXTLengthReply2}
for i := range replies {
data, err := hex.DecodeString(replies[i])
if err != nil {
t.Fatal(err)
}
msg := new(dnsMsg)
ok := msg.Unpack(data)
// Unpacking should succeed, but we should just get the header.
if !ok {
t.Errorf("test %d: unpacking packet failed", i)
continue
}
if len(msg.answer) != 1 {
t.Errorf("test %d: len(rr.answer) = %d; want 1", i, len(msg.answer))
continue
}
rr := msg.answer[0]
if _, justHeader := rr.(*dnsRR_Header); !justHeader {
t.Errorf("test %d: rr = %T; expected *dnsRR_Header", i, rr)
}
}
}
func TestIsResponseTo(t *testing.T) {
// Sample DNS query.
query := dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: 42,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
}
resp := query
resp.response = true
if !resp.IsResponseTo(&query) {
t.Error("got false, want true")
}
badResponses := []dnsMsg{
// Different ID.
{
dnsMsgHdr: dnsMsgHdr{
id: 43,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
},
// Different query name.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.google.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
},
// Different query type.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
},
},
},
// Different query class.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassCSNET,
},
},
},
// No questions.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
},
// Extra questions.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
{
Name: "www.golang.org.",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
},
},
},
}
for i := range badResponses {
if badResponses[i].IsResponseTo(&query) {
t.Errorf("%v: got true, want false", i)
}
}
}
// Valid DNS SRV reply
const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
"6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
"73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
"000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
"00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
"6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
"72016c06676f6f676c6503636f6d00c00c002100010000012c00210014000014950c78" +
"6d70702d73657276657231016c06676f6f676c6503636f6d00"
// Corrupt DNS SRV reply, with its final RR having a bogus length
// (perhaps it was truncated, or it's malicious) The mutation is the
// capital "FF" below, instead of the proper "21".
const dnsSRVCorruptReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
"6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
"73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
"000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
"00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
"6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
"72016c06676f6f676c6503636f6d00c00c002100010000012c00FF0014000014950c78" +
"6d70702d73657276657231016c06676f6f676c6503636f6d00"
// TXT reply with one <character-string>
const dnsTXTReply1 = "b3458180000100010004000505676d61696c03636f6d0000100001c00c001000010000012c00" +
"201f763d737066312072656469726563743d5f7370662e676f6f676c652e636f6dc00" +
"c0002000100025d4c000d036e733406676f6f676c65c012c00c0002000100025d4c00" +
"06036e7331c057c00c0002000100025d4c0006036e7333c057c00c0002000100025d4" +
"c0006036e7332c057c06c00010001000248b50004d8ef200ac09000010001000248b5" +
"0004d8ef220ac07e00010001000248b50004d8ef240ac05300010001000248b50004d" +
"8ef260a0000291000000000000000"
// TXT reply with more than one <character-string>.
// See https://tools.ietf.org/html/rfc1035#section-3.3.14
const dnsTXTReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
"100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
"36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
"62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
"343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
"06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
"070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
"f0cc0fd0001000100025d15000445abff0c"
// DataLength field should be sum of all TXT fields. In this case it's less.
const dnsTXTCorruptDataLengthReply1 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
"100000e1000967f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
"36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
"62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
"343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
"06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
"070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
"f0cc0fd0001000100025d15000445abff0c"
// Same as above but DataLength is more than sum of TXT fields.
const dnsTXTCorruptDataLengthReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
"100000e1001227f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
"36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
"62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
"343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
"06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
"070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
"f0cc0fd0001000100025d15000445abff0c"
// TXT Length field is less than actual length.
const dnsTXTCorruptTXTLengthReply1 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
"100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
"36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
"62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
"343a36392e3137312e3233322e302f323520691470343a36362e3232302e3135372e302f32352" +
"06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
"070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
"f0cc0fd0001000100025d15000445abff0c"
// TXT Length field is more than actual length.
const dnsTXTCorruptTXTLengthReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
"100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
"36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
"62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
"343a36392e3137312e3233322e302f323520693370343a36362e3232302e3135372e302f32352" +
"06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
"070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
"f0cc0fd0001000100025d15000445abff0c"
...@@ -9,6 +9,8 @@ package net ...@@ -9,6 +9,8 @@ package net
import ( import (
"context" "context"
"sync" "sync"
"golang_org/x/net/dns/dnsmessage"
) )
var onceReadProtocols sync.Once var onceReadProtocols sync.Once
...@@ -51,7 +53,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) { ...@@ -51,7 +53,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
return lookupProtocolMap(name) return lookupProtocolMap(name)
} }
func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, error) { func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
// Calling Dial here is scary -- we have to be sure not to // Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will // dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has // call back here to translate it. The DNS config parser has
...@@ -68,10 +70,7 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e ...@@ -68,10 +70,7 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
if err != nil { if err != nil {
return nil, mapErr(err) return nil, mapErr(err)
} }
if _, ok := c.(PacketConn); ok { return c, nil
return &dnsPacketConn{c}, nil
}
return &dnsStreamConn{c}, nil
} }
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
...@@ -98,8 +97,8 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, e ...@@ -98,8 +97,8 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, e
// cgo not available (or netgo); fall back to Go's DNS resolver // cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS order = hostLookupFilesDNS
} }
addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order) ips, _, err := r.goLookupIPCNAMEOrder(ctx, host, order)
return return ips, err
} }
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) { func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
...@@ -134,53 +133,176 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( ...@@ -134,53 +133,176 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
} else { } else {
target = "_" + service + "._" + proto + "." + name target = "_" + service + "._" + proto + "." + name
} }
cname, rrs, err := r.lookup(ctx, target, dnsTypeSRV) p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
srvs := make([]*SRV, len(rrs)) var srvs []*SRV
for i, rr := range rrs { var cname dnsmessage.Name
rr := rr.(*dnsRR_SRV) for {
srvs[i] = &SRV{Target: rr.Target, Port: rr.Port, Priority: rr.Priority, Weight: rr.Weight} h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeSRV {
if err := p.SkipAnswer(); err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
srv, err := p.SRVResource()
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
} }
byPriorityWeight(srvs).sort() byPriorityWeight(srvs).sort()
return cname, srvs, nil return cname.String(), srvs, nil
} }
func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
_, rrs, err := r.lookup(ctx, name, dnsTypeMX) p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mxs := make([]*MX, len(rrs)) var mxs []*MX
for i, rr := range rrs { for {
rr := rr.(*dnsRR_MX) h, err := p.AnswerHeader()
mxs[i] = &MX{Host: rr.Mx, Pref: rr.Pref} if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeMX {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
mx, err := p.MXResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
} }
byPref(mxs).sort() byPref(mxs).sort()
return mxs, nil return mxs, nil
} }
func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
_, rrs, err := r.lookup(ctx, name, dnsTypeNS) p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS)
if err != nil { if err != nil {
return nil, err return nil, err
} }
nss := make([]*NS, len(rrs)) var nss []*NS
for i, rr := range rrs { for {
nss[i] = &NS{Host: rr.(*dnsRR_NS).Ns} h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeNS {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
ns, err := p.NSResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
nss = append(nss, &NS{Host: ns.NS.String()})
} }
return nss, nil return nss, nil
} }
func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
_, rrs, err := r.lookup(ctx, name, dnsTypeTXT) p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT)
if err != nil { if err != nil {
return nil, err return nil, err
} }
txts := make([]string, len(rrs)) var txts []string
for i, rr := range rrs { for {
txts[i] = rr.(*dnsRR_TXT).Txt h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeTXT {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
txt, err := p.TXTResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if len(txts) == 0 {
txts = txt.TXT
} else {
txts = append(txts, txt.TXT...)
}
} }
return txts, nil return txts, nil
} }
......
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dnsmessage_test
import (
"fmt"
"net"
"strings"
"golang_org/x/net/dns/dnsmessage"
)
func mustNewName(name string) dnsmessage.Name {
n, err := dnsmessage.NewName(name)
if err != nil {
panic(err)
}
return n
}
func ExampleParser() {
msg := dnsmessage.Message{
Header: dnsmessage.Header{Response: true, Authoritative: true},
Questions: []dnsmessage.Question{
{
Name: mustNewName("foo.bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
{
Name: mustNewName("bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
Answers: []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: mustNewName("foo.bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
},
{
Header: dnsmessage.ResourceHeader{
Name: mustNewName("bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}},
},
},
}
buf, err := msg.Pack()
if err != nil {
panic(err)
}
wantName := "bar.example.com."
var p dnsmessage.Parser
if _, err := p.Start(buf); err != nil {
panic(err)
}
for {
q, err := p.Question()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
panic(err)
}
if q.Name.String() != wantName {
continue
}
fmt.Println("Found question for name", wantName)
if err := p.SkipAllQuestions(); err != nil {
panic(err)
}
break
}
var gotIPs []net.IP
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
panic(err)
}
if (h.Type != dnsmessage.TypeA && h.Type != dnsmessage.TypeAAAA) || h.Class != dnsmessage.ClassINET {
continue
}
if !strings.EqualFold(h.Name.String(), wantName) {
if err := p.SkipAnswer(); err != nil {
panic(err)
}
continue
}
switch h.Type {
case dnsmessage.TypeA:
r, err := p.AResource()
if err != nil {
panic(err)
}
gotIPs = append(gotIPs, r.A[:])
case dnsmessage.TypeAAAA:
r, err := p.AAAAResource()
if err != nil {
panic(err)
}
gotIPs = append(gotIPs, r.AAAA[:])
}
}
fmt.Printf("Found A/AAAA records for name %s: %v\n", wantName, gotIPs)
// Output:
// Found question for name bar.example.com.
// Found A/AAAA records for name bar.example.com.: [127.0.0.2]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package dnsmessage provides a mostly RFC 1035 compliant implementation of
// DNS message packing and unpacking.
//
// This implementation is designed to minimize heap allocations and avoid
// unnecessary packing and unpacking as much as possible.
package dnsmessage
import (
"errors"
)
// Message formats
// A Type is a type of DNS request and response.
type Type uint16
// A Class is a type of network.
type Class uint16
// An OpCode is a DNS operation code.
type OpCode uint16
// An RCode is a DNS response status code.
type RCode uint16
// Wire constants.
const (
// ResourceHeader.Type and Question.Type
TypeA Type = 1
TypeNS Type = 2
TypeCNAME Type = 5
TypeSOA Type = 6
TypePTR Type = 12
TypeMX Type = 15
TypeTXT Type = 16
TypeAAAA Type = 28
TypeSRV Type = 33
// Question.Type
TypeWKS Type = 11
TypeHINFO Type = 13
TypeMINFO Type = 14
TypeAXFR Type = 252
TypeALL Type = 255
// ResourceHeader.Class and Question.Class
ClassINET Class = 1
ClassCSNET Class = 2
ClassCHAOS Class = 3
ClassHESIOD Class = 4
// Question.Class
ClassANY Class = 255
// Message.Rcode
RCodeSuccess RCode = 0
RCodeFormatError RCode = 1
RCodeServerFailure RCode = 2
RCodeNameError RCode = 3
RCodeNotImplemented RCode = 4
RCodeRefused RCode = 5
)
var (
// ErrNotStarted indicates that the prerequisite information isn't
// available yet because the previous records haven't been appropriately
// parsed, skipped or finished.
ErrNotStarted = errors.New("parsing/packing of this type isn't available yet")
// ErrSectionDone indicated that all records in the section have been
// parsed or finished.
ErrSectionDone = errors.New("parsing/packing of this section has completed")
errBaseLen = errors.New("insufficient data for base length type")
errCalcLen = errors.New("insufficient data for calculated length type")
errReserved = errors.New("segment prefix is reserved")
errTooManyPtr = errors.New("too many pointers (>10)")
errInvalidPtr = errors.New("invalid pointer")
errNilResouceBody = errors.New("nil resource body")
errResourceLen = errors.New("insufficient data for resource body length")
errSegTooLong = errors.New("segment length too long")
errZeroSegLen = errors.New("zero length segment")
errResTooLong = errors.New("resource length too long")
errTooManyQuestions = errors.New("too many Questions to pack (>65535)")
errTooManyAnswers = errors.New("too many Answers to pack (>65535)")
errTooManyAuthorities = errors.New("too many Authorities to pack (>65535)")
errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)")
errNonCanonicalName = errors.New("name is not in canonical format (it must end with a .)")
errStringTooLong = errors.New("character string exceeds maximum length (255)")
)
// Internal constants.
const (
// packStartingCap is the default initial buffer size allocated during
// packing.
//
// The starting capacity doesn't matter too much, but most DNS responses
// Will be <= 512 bytes as it is the limit for DNS over UDP.
packStartingCap = 512
// uint16Len is the length (in bytes) of a uint16.
uint16Len = 2
// uint32Len is the length (in bytes) of a uint32.
uint32Len = 4
// headerLen is the length (in bytes) of a DNS header.
//
// A header is comprised of 6 uint16s and no padding.
headerLen = 6 * uint16Len
)
type nestedError struct {
// s is the current level's error message.
s string
// err is the nested error.
err error
}
// nestedError implements error.Error.
func (e *nestedError) Error() string {
return e.s + ": " + e.err.Error()
}
// Header is a representation of a DNS message header.
type Header struct {
ID uint16
Response bool
OpCode OpCode
Authoritative bool
Truncated bool
RecursionDesired bool
RecursionAvailable bool
RCode RCode
}
func (m *Header) pack() (id uint16, bits uint16) {
id = m.ID
bits = uint16(m.OpCode)<<11 | uint16(m.RCode)
if m.RecursionAvailable {
bits |= headerBitRA
}
if m.RecursionDesired {
bits |= headerBitRD
}
if m.Truncated {
bits |= headerBitTC
}
if m.Authoritative {
bits |= headerBitAA
}
if m.Response {
bits |= headerBitQR
}
return
}
// Message is a representation of a DNS message.
type Message struct {
Header
Questions []Question
Answers []Resource
Authorities []Resource
Additionals []Resource
}
type section uint8
const (
sectionNotStarted section = iota
sectionHeader
sectionQuestions
sectionAnswers
sectionAuthorities
sectionAdditionals
sectionDone
headerBitQR = 1 << 15 // query/response (response=1)
headerBitAA = 1 << 10 // authoritative
headerBitTC = 1 << 9 // truncated
headerBitRD = 1 << 8 // recursion desired
headerBitRA = 1 << 7 // recursion available
)
var sectionNames = map[section]string{
sectionHeader: "header",
sectionQuestions: "Question",
sectionAnswers: "Answer",
sectionAuthorities: "Authority",
sectionAdditionals: "Additional",
}
// header is the wire format for a DNS message header.
type header struct {
id uint16
bits uint16
questions uint16
answers uint16
authorities uint16
additionals uint16
}
func (h *header) count(sec section) uint16 {
switch sec {
case sectionQuestions:
return h.questions
case sectionAnswers:
return h.answers
case sectionAuthorities:
return h.authorities
case sectionAdditionals:
return h.additionals
}
return 0
}
// pack appends the wire format of the header to msg.
func (h *header) pack(msg []byte) []byte {
msg = packUint16(msg, h.id)
msg = packUint16(msg, h.bits)
msg = packUint16(msg, h.questions)
msg = packUint16(msg, h.answers)
msg = packUint16(msg, h.authorities)
return packUint16(msg, h.additionals)
}
func (h *header) unpack(msg []byte, off int) (int, error) {
newOff := off
var err error
if h.id, newOff, err = unpackUint16(msg, newOff); err != nil {
return off, &nestedError{"id", err}
}
if h.bits, newOff, err = unpackUint16(msg, newOff); err != nil {
return off, &nestedError{"bits", err}
}
if h.questions, newOff, err = unpackUint16(msg, newOff); err != nil {
return off, &nestedError{"questions", err}
}
if h.answers, newOff, err = unpackUint16(msg, newOff); err != nil {
return off, &nestedError{"answers", err}
}
if h.authorities, newOff, err = unpackUint16(msg, newOff); err != nil {
return off, &nestedError{"authorities", err}
}
if h.additionals, newOff, err = unpackUint16(msg, newOff); err != nil {
return off, &nestedError{"additionals", err}
}
return newOff, nil
}
func (h *header) header() Header {
return Header{
ID: h.id,
Response: (h.bits & headerBitQR) != 0,
OpCode: OpCode(h.bits>>11) & 0xF,
Authoritative: (h.bits & headerBitAA) != 0,
Truncated: (h.bits & headerBitTC) != 0,
RecursionDesired: (h.bits & headerBitRD) != 0,
RecursionAvailable: (h.bits & headerBitRA) != 0,
RCode: RCode(h.bits & 0xF),
}
}
// A Resource is a DNS resource record.
type Resource struct {
Header ResourceHeader
Body ResourceBody
}
// A ResourceBody is a DNS resource record minus the header.
type ResourceBody interface {
// pack packs a Resource except for its header.
pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error)
// realType returns the actual type of the Resource. This is used to
// fill in the header Type field.
realType() Type
}
// pack appends the wire format of the Resource to msg.
func (r *Resource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
if r.Body == nil {
return msg, errNilResouceBody
}
oldMsg := msg
r.Header.Type = r.Body.realType()
msg, length, err := r.Header.pack(msg, compression, compressionOff)
if err != nil {
return msg, &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
msg, err = r.Body.pack(msg, compression, compressionOff)
if err != nil {
return msg, &nestedError{"content", err}
}
if err := r.Header.fixLen(msg, length, preLen); err != nil {
return oldMsg, err
}
return msg, nil
}
// A Parser allows incrementally parsing a DNS message.
//
// When parsing is started, the Header is parsed. Next, each Question can be
// either parsed or skipped. Alternatively, all Questions can be skipped at
// once. When all Questions have been parsed, attempting to parse Questions
// will return (nil, nil) and attempting to skip Questions will return
// (true, nil). After all Questions have been either parsed or skipped, all
// Answers, Authorities and Additionals can be either parsed or skipped in the
// same way, and each type of Resource must be fully parsed or skipped before
// proceeding to the next type of Resource.
//
// Note that there is no requirement to fully skip or parse the message.
type Parser struct {
msg []byte
header header
section section
off int
index int
resHeaderValid bool
resHeader ResourceHeader
}
// Start parses the header and enables the parsing of Questions.
func (p *Parser) Start(msg []byte) (Header, error) {
if p.msg != nil {
*p = Parser{}
}
p.msg = msg
var err error
if p.off, err = p.header.unpack(msg, 0); err != nil {
return Header{}, &nestedError{"unpacking header", err}
}
p.section = sectionQuestions
return p.header.header(), nil
}
func (p *Parser) checkAdvance(sec section) error {
if p.section < sec {
return ErrNotStarted
}
if p.section > sec {
return ErrSectionDone
}
p.resHeaderValid = false
if p.index == int(p.header.count(sec)) {
p.index = 0
p.section++
return ErrSectionDone
}
return nil
}
func (p *Parser) resource(sec section) (Resource, error) {
var r Resource
var err error
r.Header, err = p.resourceHeader(sec)
if err != nil {
return r, err
}
p.resHeaderValid = false
r.Body, p.off, err = unpackResourceBody(p.msg, p.off, r.Header)
if err != nil {
return Resource{}, &nestedError{"unpacking " + sectionNames[sec], err}
}
p.index++
return r, nil
}
func (p *Parser) resourceHeader(sec section) (ResourceHeader, error) {
if p.resHeaderValid {
return p.resHeader, nil
}
if err := p.checkAdvance(sec); err != nil {
return ResourceHeader{}, err
}
var hdr ResourceHeader
off, err := hdr.unpack(p.msg, p.off)
if err != nil {
return ResourceHeader{}, err
}
p.resHeaderValid = true
p.resHeader = hdr
p.off = off
return hdr, nil
}
func (p *Parser) skipResource(sec section) error {
if p.resHeaderValid {
newOff := p.off + int(p.resHeader.Length)
if newOff > len(p.msg) {
return errResourceLen
}
p.off = newOff
p.resHeaderValid = false
p.index++
return nil
}
if err := p.checkAdvance(sec); err != nil {
return err
}
var err error
p.off, err = skipResource(p.msg, p.off)
if err != nil {
return &nestedError{"skipping: " + sectionNames[sec], err}
}
p.index++
return nil
}
// Question parses a single Question.
func (p *Parser) Question() (Question, error) {
if err := p.checkAdvance(sectionQuestions); err != nil {
return Question{}, err
}
var name Name
off, err := name.unpack(p.msg, p.off)
if err != nil {
return Question{}, &nestedError{"unpacking Question.Name", err}
}
typ, off, err := unpackType(p.msg, off)
if err != nil {
return Question{}, &nestedError{"unpacking Question.Type", err}
}
class, off, err := unpackClass(p.msg, off)
if err != nil {
return Question{}, &nestedError{"unpacking Question.Class", err}
}
p.off = off
p.index++
return Question{name, typ, class}, nil
}
// AllQuestions parses all Questions.
func (p *Parser) AllQuestions() ([]Question, error) {
// Multiple questions are valid according to the spec,
// but servers don't actually support them. There will
// be at most one question here.
//
// Do not pre-allocate based on info in p.header, since
// the data is untrusted.
qs := []Question{}
for {
q, err := p.Question()
if err == ErrSectionDone {
return qs, nil
}
if err != nil {
return nil, err
}
qs = append(qs, q)
}
}
// SkipQuestion skips a single Question.
func (p *Parser) SkipQuestion() error {
if err := p.checkAdvance(sectionQuestions); err != nil {
return err
}
off, err := skipName(p.msg, p.off)
if err != nil {
return &nestedError{"skipping Question Name", err}
}
if off, err = skipType(p.msg, off); err != nil {
return &nestedError{"skipping Question Type", err}
}
if off, err = skipClass(p.msg, off); err != nil {
return &nestedError{"skipping Question Class", err}
}
p.off = off
p.index++
return nil
}
// SkipAllQuestions skips all Questions.
func (p *Parser) SkipAllQuestions() error {
for {
if err := p.SkipQuestion(); err == ErrSectionDone {
return nil
} else if err != nil {
return err
}
}
}
// AnswerHeader parses a single Answer ResourceHeader.
func (p *Parser) AnswerHeader() (ResourceHeader, error) {
return p.resourceHeader(sectionAnswers)
}
// Answer parses a single Answer Resource.
func (p *Parser) Answer() (Resource, error) {
return p.resource(sectionAnswers)
}
// AllAnswers parses all Answer Resources.
func (p *Parser) AllAnswers() ([]Resource, error) {
// The most common query is for A/AAAA, which usually returns
// a handful of IPs.
//
// Pre-allocate up to a certain limit, since p.header is
// untrusted data.
n := int(p.header.answers)
if n > 20 {
n = 20
}
as := make([]Resource, 0, n)
for {
a, err := p.Answer()
if err == ErrSectionDone {
return as, nil
}
if err != nil {
return nil, err
}
as = append(as, a)
}
}
// SkipAnswer skips a single Answer Resource.
func (p *Parser) SkipAnswer() error {
return p.skipResource(sectionAnswers)
}
// SkipAllAnswers skips all Answer Resources.
func (p *Parser) SkipAllAnswers() error {
for {
if err := p.SkipAnswer(); err == ErrSectionDone {
return nil
} else if err != nil {
return err
}
}
}
// AuthorityHeader parses a single Authority ResourceHeader.
func (p *Parser) AuthorityHeader() (ResourceHeader, error) {
return p.resourceHeader(sectionAuthorities)
}
// Authority parses a single Authority Resource.
func (p *Parser) Authority() (Resource, error) {
return p.resource(sectionAuthorities)
}
// AllAuthorities parses all Authority Resources.
func (p *Parser) AllAuthorities() ([]Resource, error) {
// Authorities contains SOA in case of NXDOMAIN and friends,
// otherwise it is empty.
//
// Pre-allocate up to a certain limit, since p.header is
// untrusted data.
n := int(p.header.authorities)
if n > 10 {
n = 10
}
as := make([]Resource, 0, n)
for {
a, err := p.Authority()
if err == ErrSectionDone {
return as, nil
}
if err != nil {
return nil, err
}
as = append(as, a)
}
}
// SkipAuthority skips a single Authority Resource.
func (p *Parser) SkipAuthority() error {
return p.skipResource(sectionAuthorities)
}
// SkipAllAuthorities skips all Authority Resources.
func (p *Parser) SkipAllAuthorities() error {
for {
if err := p.SkipAuthority(); err == ErrSectionDone {
return nil
} else if err != nil {
return err
}
}
}
// AdditionalHeader parses a single Additional ResourceHeader.
func (p *Parser) AdditionalHeader() (ResourceHeader, error) {
return p.resourceHeader(sectionAdditionals)
}
// Additional parses a single Additional Resource.
func (p *Parser) Additional() (Resource, error) {
return p.resource(sectionAdditionals)
}
// AllAdditionals parses all Additional Resources.
func (p *Parser) AllAdditionals() ([]Resource, error) {
// Additionals usually contain OPT, and sometimes A/AAAA
// glue records.
//
// Pre-allocate up to a certain limit, since p.header is
// untrusted data.
n := int(p.header.additionals)
if n > 10 {
n = 10
}
as := make([]Resource, 0, n)
for {
a, err := p.Additional()
if err == ErrSectionDone {
return as, nil
}
if err != nil {
return nil, err
}
as = append(as, a)
}
}
// SkipAdditional skips a single Additional Resource.
func (p *Parser) SkipAdditional() error {
return p.skipResource(sectionAdditionals)
}
// SkipAllAdditionals skips all Additional Resources.
func (p *Parser) SkipAllAdditionals() error {
for {
if err := p.SkipAdditional(); err == ErrSectionDone {
return nil
} else if err != nil {
return err
}
}
}
// CNAMEResource parses a single CNAMEResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) CNAMEResource() (CNAMEResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeCNAME {
return CNAMEResource{}, ErrNotStarted
}
r, err := unpackCNAMEResource(p.msg, p.off)
if err != nil {
return CNAMEResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// MXResource parses a single MXResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) MXResource() (MXResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeMX {
return MXResource{}, ErrNotStarted
}
r, err := unpackMXResource(p.msg, p.off)
if err != nil {
return MXResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// NSResource parses a single NSResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) NSResource() (NSResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeNS {
return NSResource{}, ErrNotStarted
}
r, err := unpackNSResource(p.msg, p.off)
if err != nil {
return NSResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// PTRResource parses a single PTRResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) PTRResource() (PTRResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypePTR {
return PTRResource{}, ErrNotStarted
}
r, err := unpackPTRResource(p.msg, p.off)
if err != nil {
return PTRResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// SOAResource parses a single SOAResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) SOAResource() (SOAResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeSOA {
return SOAResource{}, ErrNotStarted
}
r, err := unpackSOAResource(p.msg, p.off)
if err != nil {
return SOAResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// TXTResource parses a single TXTResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) TXTResource() (TXTResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeTXT {
return TXTResource{}, ErrNotStarted
}
r, err := unpackTXTResource(p.msg, p.off, p.resHeader.Length)
if err != nil {
return TXTResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// SRVResource parses a single SRVResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) SRVResource() (SRVResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeSRV {
return SRVResource{}, ErrNotStarted
}
r, err := unpackSRVResource(p.msg, p.off)
if err != nil {
return SRVResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// AResource parses a single AResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) AResource() (AResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeA {
return AResource{}, ErrNotStarted
}
r, err := unpackAResource(p.msg, p.off)
if err != nil {
return AResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// AAAAResource parses a single AAAAResource.
//
// One of the XXXHeader methods must have been called before calling this
// method.
func (p *Parser) AAAAResource() (AAAAResource, error) {
if !p.resHeaderValid || p.resHeader.Type != TypeAAAA {
return AAAAResource{}, ErrNotStarted
}
r, err := unpackAAAAResource(p.msg, p.off)
if err != nil {
return AAAAResource{}, err
}
p.off += int(p.resHeader.Length)
p.resHeaderValid = false
p.index++
return r, nil
}
// Unpack parses a full Message.
func (m *Message) Unpack(msg []byte) error {
var p Parser
var err error
if m.Header, err = p.Start(msg); err != nil {
return err
}
if m.Questions, err = p.AllQuestions(); err != nil {
return err
}
if m.Answers, err = p.AllAnswers(); err != nil {
return err
}
if m.Authorities, err = p.AllAuthorities(); err != nil {
return err
}
if m.Additionals, err = p.AllAdditionals(); err != nil {
return err
}
return nil
}
// Pack packs a full Message.
func (m *Message) Pack() ([]byte, error) {
return m.AppendPack(make([]byte, 0, packStartingCap))
}
// AppendPack is like Pack but appends the full Message to b and returns the
// extended buffer.
func (m *Message) AppendPack(b []byte) ([]byte, error) {
// Validate the lengths. It is very unlikely that anyone will try to
// pack more than 65535 of any particular type, but it is possible and
// we should fail gracefully.
if len(m.Questions) > int(^uint16(0)) {
return nil, errTooManyQuestions
}
if len(m.Answers) > int(^uint16(0)) {
return nil, errTooManyAnswers
}
if len(m.Authorities) > int(^uint16(0)) {
return nil, errTooManyAuthorities
}
if len(m.Additionals) > int(^uint16(0)) {
return nil, errTooManyAdditionals
}
var h header
h.id, h.bits = m.Header.pack()
h.questions = uint16(len(m.Questions))
h.answers = uint16(len(m.Answers))
h.authorities = uint16(len(m.Authorities))
h.additionals = uint16(len(m.Additionals))
compressionOff := len(b)
msg := h.pack(b)
// RFC 1035 allows (but does not require) compression for packing. RFC
// 1035 requires unpacking implementations to support compression, so
// unconditionally enabling it is fine.
//
// DNS lookups are typically done over UDP, and RFC 1035 states that UDP
// DNS messages can be a maximum of 512 bytes long. Without compression,
// many DNS response messages are over this limit, so enabling
// compression will help ensure compliance.
compression := map[string]int{}
for i := range m.Questions {
var err error
if msg, err = m.Questions[i].pack(msg, compression, compressionOff); err != nil {
return nil, &nestedError{"packing Question", err}
}
}
for i := range m.Answers {
var err error
if msg, err = m.Answers[i].pack(msg, compression, compressionOff); err != nil {
return nil, &nestedError{"packing Answer", err}
}
}
for i := range m.Authorities {
var err error
if msg, err = m.Authorities[i].pack(msg, compression, compressionOff); err != nil {
return nil, &nestedError{"packing Authority", err}
}
}
for i := range m.Additionals {
var err error
if msg, err = m.Additionals[i].pack(msg, compression, compressionOff); err != nil {
return nil, &nestedError{"packing Additional", err}
}
}
return msg, nil
}
// A Builder allows incrementally packing a DNS message.
//
// Example usage:
// buf := make([]byte, 2, 514)
// b := NewBuilder(buf, Header{...})
// b.EnableCompression()
// // Optionally start a section and add things to that section.
// // Repeat adding sections as necessary.
// buf, err := b.Finish()
// // If err is nil, buf[2:] will contain the built bytes.
type Builder struct {
// msg is the storage for the message being built.
msg []byte
// section keeps track of the current section being built.
section section
// header keeps track of what should go in the header when Finish is
// called.
header header
// start is the starting index of the bytes allocated in msg for header.
start int
// compression is a mapping from name suffixes to their starting index
// in msg.
compression map[string]int
}
// NewBuilder creates a new builder with compression disabled.
//
// Note: Most users will want to immediately enable compression with the
// EnableCompression method. See that method's comment for why you may or may
// not want to enable compression.
//
// The DNS message is appended to the provided initial buffer buf (which may be
// nil) as it is built. The final message is returned by the (*Builder).Finish
// method, which may return the same underlying array if there was sufficient
// capacity in the slice.
func NewBuilder(buf []byte, h Header) Builder {
if buf == nil {
buf = make([]byte, 0, packStartingCap)
}
b := Builder{msg: buf, start: len(buf)}
b.header.id, b.header.bits = h.pack()
var hb [headerLen]byte
b.msg = append(b.msg, hb[:]...)
b.section = sectionHeader
return b
}
// EnableCompression enables compression in the Builder.
//
// Leaving compression disabled avoids compression related allocations, but can
// result in larger message sizes. Be careful with this mode as it can cause
// messages to exceed the UDP size limit.
//
// According to RFC 1035, section 4.1.4, the use of compression is optional, but
// all implementations must accept both compressed and uncompressed DNS
// messages.
//
// Compression should be enabled before any sections are added for best results.
func (b *Builder) EnableCompression() {
b.compression = map[string]int{}
}
func (b *Builder) startCheck(s section) error {
if b.section <= sectionNotStarted {
return ErrNotStarted
}
if b.section > s {
return ErrSectionDone
}
return nil
}
// StartQuestions prepares the builder for packing Questions.
func (b *Builder) StartQuestions() error {
if err := b.startCheck(sectionQuestions); err != nil {
return err
}
b.section = sectionQuestions
return nil
}
// StartAnswers prepares the builder for packing Answers.
func (b *Builder) StartAnswers() error {
if err := b.startCheck(sectionAnswers); err != nil {
return err
}
b.section = sectionAnswers
return nil
}
// StartAuthorities prepares the builder for packing Authorities.
func (b *Builder) StartAuthorities() error {
if err := b.startCheck(sectionAuthorities); err != nil {
return err
}
b.section = sectionAuthorities
return nil
}
// StartAdditionals prepares the builder for packing Additionals.
func (b *Builder) StartAdditionals() error {
if err := b.startCheck(sectionAdditionals); err != nil {
return err
}
b.section = sectionAdditionals
return nil
}
func (b *Builder) incrementSectionCount() error {
var count *uint16
var err error
switch b.section {
case sectionQuestions:
count = &b.header.questions
err = errTooManyQuestions
case sectionAnswers:
count = &b.header.answers
err = errTooManyAnswers
case sectionAuthorities:
count = &b.header.authorities
err = errTooManyAuthorities
case sectionAdditionals:
count = &b.header.additionals
err = errTooManyAdditionals
}
if *count == ^uint16(0) {
return err
}
*count++
return nil
}
// Question adds a single Question.
func (b *Builder) Question(q Question) error {
if b.section < sectionQuestions {
return ErrNotStarted
}
if b.section > sectionQuestions {
return ErrSectionDone
}
msg, err := q.pack(b.msg, b.compression, b.start)
if err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
func (b *Builder) checkResourceSection() error {
if b.section < sectionAnswers {
return ErrNotStarted
}
if b.section > sectionAdditionals {
return ErrSectionDone
}
return nil
}
// CNAMEResource adds a single CNAMEResource.
func (b *Builder) CNAMEResource(h ResourceHeader, r CNAMEResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"CNAMEResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// MXResource adds a single MXResource.
func (b *Builder) MXResource(h ResourceHeader, r MXResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"MXResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// NSResource adds a single NSResource.
func (b *Builder) NSResource(h ResourceHeader, r NSResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"NSResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// PTRResource adds a single PTRResource.
func (b *Builder) PTRResource(h ResourceHeader, r PTRResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"PTRResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// SOAResource adds a single SOAResource.
func (b *Builder) SOAResource(h ResourceHeader, r SOAResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"SOAResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// TXTResource adds a single TXTResource.
func (b *Builder) TXTResource(h ResourceHeader, r TXTResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"TXTResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// SRVResource adds a single SRVResource.
func (b *Builder) SRVResource(h ResourceHeader, r SRVResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"SRVResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// AResource adds a single AResource.
func (b *Builder) AResource(h ResourceHeader, r AResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"AResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// AAAAResource adds a single AAAAResource.
func (b *Builder) AAAAResource(h ResourceHeader, r AAAAResource) error {
if err := b.checkResourceSection(); err != nil {
return err
}
h.Type = r.realType()
msg, length, err := h.pack(b.msg, b.compression, b.start)
if err != nil {
return &nestedError{"ResourceHeader", err}
}
preLen := len(msg)
if msg, err = r.pack(msg, b.compression, b.start); err != nil {
return &nestedError{"AAAAResource body", err}
}
if err := h.fixLen(msg, length, preLen); err != nil {
return err
}
if err := b.incrementSectionCount(); err != nil {
return err
}
b.msg = msg
return nil
}
// Finish ends message building and generates a binary message.
func (b *Builder) Finish() ([]byte, error) {
if b.section < sectionHeader {
return nil, ErrNotStarted
}
b.section = sectionDone
// Space for the header was allocated in NewBuilder.
b.header.pack(b.msg[b.start:b.start])
return b.msg, nil
}
// A ResourceHeader is the header of a DNS resource record. There are
// many types of DNS resource records, but they all share the same header.
type ResourceHeader struct {
// Name is the domain name for which this resource record pertains.
Name Name
// Type is the type of DNS resource record.
//
// This field will be set automatically during packing.
Type Type
// Class is the class of network to which this DNS resource record
// pertains.
Class Class
// TTL is the length of time (measured in seconds) which this resource
// record is valid for (time to live). All Resources in a set should
// have the same TTL (RFC 2181 Section 5.2).
TTL uint32
// Length is the length of data in the resource record after the header.
//
// This field will be set automatically during packing.
Length uint16
}
// pack appends the wire format of the ResourceHeader to oldMsg.
//
// The bytes where length was packed are returned as a slice so they can be
// updated after the rest of the Resource has been packed.
func (h *ResourceHeader) pack(oldMsg []byte, compression map[string]int, compressionOff int) (msg []byte, length []byte, err error) {
msg = oldMsg
if msg, err = h.Name.pack(msg, compression, compressionOff); err != nil {
return oldMsg, nil, &nestedError{"Name", err}
}
msg = packType(msg, h.Type)
msg = packClass(msg, h.Class)
msg = packUint32(msg, h.TTL)
lenBegin := len(msg)
msg = packUint16(msg, h.Length)
return msg, msg[lenBegin : lenBegin+uint16Len], nil
}
func (h *ResourceHeader) unpack(msg []byte, off int) (int, error) {
newOff := off
var err error
if newOff, err = h.Name.unpack(msg, newOff); err != nil {
return off, &nestedError{"Name", err}
}
if h.Type, newOff, err = unpackType(msg, newOff); err != nil {
return off, &nestedError{"Type", err}
}
if h.Class, newOff, err = unpackClass(msg, newOff); err != nil {
return off, &nestedError{"Class", err}
}
if h.TTL, newOff, err = unpackUint32(msg, newOff); err != nil {
return off, &nestedError{"TTL", err}
}
if h.Length, newOff, err = unpackUint16(msg, newOff); err != nil {
return off, &nestedError{"Length", err}
}
return newOff, nil
}
func (h *ResourceHeader) fixLen(msg []byte, length []byte, preLen int) error {
conLen := len(msg) - preLen
if conLen > int(^uint16(0)) {
return errResTooLong
}
// Fill in the length now that we know how long the content is.
packUint16(length[:0], uint16(conLen))
h.Length = uint16(conLen)
return nil
}
func skipResource(msg []byte, off int) (int, error) {
newOff, err := skipName(msg, off)
if err != nil {
return off, &nestedError{"Name", err}
}
if newOff, err = skipType(msg, newOff); err != nil {
return off, &nestedError{"Type", err}
}
if newOff, err = skipClass(msg, newOff); err != nil {
return off, &nestedError{"Class", err}
}
if newOff, err = skipUint32(msg, newOff); err != nil {
return off, &nestedError{"TTL", err}
}
length, newOff, err := unpackUint16(msg, newOff)
if err != nil {
return off, &nestedError{"Length", err}
}
if newOff += int(length); newOff > len(msg) {
return off, errResourceLen
}
return newOff, nil
}
// packUint16 appends the wire format of field to msg.
func packUint16(msg []byte, field uint16) []byte {
return append(msg, byte(field>>8), byte(field))
}
func unpackUint16(msg []byte, off int) (uint16, int, error) {
if off+uint16Len > len(msg) {
return 0, off, errBaseLen
}
return uint16(msg[off])<<8 | uint16(msg[off+1]), off + uint16Len, nil
}
func skipUint16(msg []byte, off int) (int, error) {
if off+uint16Len > len(msg) {
return off, errBaseLen
}
return off + uint16Len, nil
}
// packType appends the wire format of field to msg.
func packType(msg []byte, field Type) []byte {
return packUint16(msg, uint16(field))
}
func unpackType(msg []byte, off int) (Type, int, error) {
t, o, err := unpackUint16(msg, off)
return Type(t), o, err
}
func skipType(msg []byte, off int) (int, error) {
return skipUint16(msg, off)
}
// packClass appends the wire format of field to msg.
func packClass(msg []byte, field Class) []byte {
return packUint16(msg, uint16(field))
}
func unpackClass(msg []byte, off int) (Class, int, error) {
c, o, err := unpackUint16(msg, off)
return Class(c), o, err
}
func skipClass(msg []byte, off int) (int, error) {
return skipUint16(msg, off)
}
// packUint32 appends the wire format of field to msg.
func packUint32(msg []byte, field uint32) []byte {
return append(
msg,
byte(field>>24),
byte(field>>16),
byte(field>>8),
byte(field),
)
}
func unpackUint32(msg []byte, off int) (uint32, int, error) {
if off+uint32Len > len(msg) {
return 0, off, errBaseLen
}
v := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
return v, off + uint32Len, nil
}
func skipUint32(msg []byte, off int) (int, error) {
if off+uint32Len > len(msg) {
return off, errBaseLen
}
return off + uint32Len, nil
}
// packText appends the wire format of field to msg.
func packText(msg []byte, field string) ([]byte, error) {
l := len(field)
if l > 255 {
return nil, errStringTooLong
}
msg = append(msg, byte(l))
msg = append(msg, field...)
return msg, nil
}
func unpackText(msg []byte, off int) (string, int, error) {
if off >= len(msg) {
return "", off, errBaseLen
}
beginOff := off + 1
endOff := beginOff + int(msg[off])
if endOff > len(msg) {
return "", off, errCalcLen
}
return string(msg[beginOff:endOff]), endOff, nil
}
func skipText(msg []byte, off int) (int, error) {
if off >= len(msg) {
return off, errBaseLen
}
endOff := off + 1 + int(msg[off])
if endOff > len(msg) {
return off, errCalcLen
}
return endOff, nil
}
// packBytes appends the wire format of field to msg.
func packBytes(msg []byte, field []byte) []byte {
return append(msg, field...)
}
func unpackBytes(msg []byte, off int, field []byte) (int, error) {
newOff := off + len(field)
if newOff > len(msg) {
return off, errBaseLen
}
copy(field, msg[off:newOff])
return newOff, nil
}
func skipBytes(msg []byte, off int, field []byte) (int, error) {
newOff := off + len(field)
if newOff > len(msg) {
return off, errBaseLen
}
return newOff, nil
}
const nameLen = 255
// A Name is a non-encoded domain name. It is used instead of strings to avoid
// allocations.
type Name struct {
Data [nameLen]byte
Length uint8
}
// NewName creates a new Name from a string.
func NewName(name string) (Name, error) {
if len([]byte(name)) > nameLen {
return Name{}, errCalcLen
}
n := Name{Length: uint8(len(name))}
copy(n.Data[:], []byte(name))
return n, nil
}
func (n Name) String() string {
return string(n.Data[:n.Length])
}
// pack appends the wire format of the Name to msg.
//
// Domain names are a sequence of counted strings split at the dots. They end
// with a zero-length string. Compression can be used to reuse domain suffixes.
//
// The compression map will be updated with new domain suffixes. If compression
// is nil, compression will not be used.
func (n *Name) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
oldMsg := msg
// Add a trailing dot to canonicalize name.
if n.Length == 0 || n.Data[n.Length-1] != '.' {
return oldMsg, errNonCanonicalName
}
// Allow root domain.
if n.Data[0] == '.' && n.Length == 1 {
return append(msg, 0), nil
}
// Emit sequence of counted strings, chopping at dots.
for i, begin := 0, 0; i < int(n.Length); i++ {
// Check for the end of the segment.
if n.Data[i] == '.' {
// The two most significant bits have special meaning.
// It isn't allowed for segments to be long enough to
// need them.
if i-begin >= 1<<6 {
return oldMsg, errSegTooLong
}
// Segments must have a non-zero length.
if i-begin == 0 {
return oldMsg, errZeroSegLen
}
msg = append(msg, byte(i-begin))
for j := begin; j < i; j++ {
msg = append(msg, n.Data[j])
}
begin = i + 1
continue
}
// We can only compress domain suffixes starting with a new
// segment. A pointer is two bytes with the two most significant
// bits set to 1 to indicate that it is a pointer.
if (i == 0 || n.Data[i-1] == '.') && compression != nil {
if ptr, ok := compression[string(n.Data[i:])]; ok {
// Hit. Emit a pointer instead of the rest of
// the domain.
return append(msg, byte(ptr>>8|0xC0), byte(ptr)), nil
}
// Miss. Add the suffix to the compression table if the
// offset can be stored in the available 14 bytes.
if len(msg) <= int(^uint16(0)>>2) {
compression[string(n.Data[i:])] = len(msg) - compressionOff
}
}
}
return append(msg, 0), nil
}
// unpack unpacks a domain name.
func (n *Name) unpack(msg []byte, off int) (int, error) {
// currOff is the current working offset.
currOff := off
// newOff is the offset where the next record will start. Pointers lead
// to data that belongs to other names and thus doesn't count towards to
// the usage of this name.
newOff := off
// ptr is the number of pointers followed.
var ptr int
// Name is a slice representation of the name data.
name := n.Data[:0]
Loop:
for {
if currOff >= len(msg) {
return off, errBaseLen
}
c := int(msg[currOff])
currOff++
switch c & 0xC0 {
case 0x00: // String segment
if c == 0x00 {
// A zero length signals the end of the name.
break Loop
}
endOff := currOff + c
if endOff > len(msg) {
return off, errCalcLen
}
name = append(name, msg[currOff:endOff]...)
name = append(name, '.')
currOff = endOff
case 0xC0: // Pointer
if currOff >= len(msg) {
return off, errInvalidPtr
}
c1 := msg[currOff]
currOff++
if ptr == 0 {
newOff = currOff
}
// Don't follow too many pointers, maybe there's a loop.
if ptr++; ptr > 10 {
return off, errTooManyPtr
}
currOff = (c^0xC0)<<8 | int(c1)
default:
// Prefixes 0x80 and 0x40 are reserved.
return off, errReserved
}
}
if len(name) == 0 {
name = append(name, '.')
}
if len(name) > len(n.Data) {
return off, errCalcLen
}
n.Length = uint8(len(name))
if ptr == 0 {
newOff = currOff
}
return newOff, nil
}
func skipName(msg []byte, off int) (int, error) {
// newOff is the offset where the next record will start. Pointers lead
// to data that belongs to other names and thus doesn't count towards to
// the usage of this name.
newOff := off
Loop:
for {
if newOff >= len(msg) {
return off, errBaseLen
}
c := int(msg[newOff])
newOff++
switch c & 0xC0 {
case 0x00:
if c == 0x00 {
// A zero length signals the end of the name.
break Loop
}
// literal string
newOff += c
if newOff > len(msg) {
return off, errCalcLen
}
case 0xC0:
// Pointer to somewhere else in msg.
// Pointers are two bytes.
newOff++
// Don't follow the pointer as the data here has ended.
break Loop
default:
// Prefixes 0x80 and 0x40 are reserved.
return off, errReserved
}
}
return newOff, nil
}
// A Question is a DNS query.
type Question struct {
Name Name
Type Type
Class Class
}
// pack appends the wire format of the Question to msg.
func (q *Question) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
msg, err := q.Name.pack(msg, compression, compressionOff)
if err != nil {
return msg, &nestedError{"Name", err}
}
msg = packType(msg, q.Type)
return packClass(msg, q.Class), nil
}
func unpackResourceBody(msg []byte, off int, hdr ResourceHeader) (ResourceBody, int, error) {
var (
r ResourceBody
err error
name string
)
switch hdr.Type {
case TypeA:
var rb AResource
rb, err = unpackAResource(msg, off)
r = &rb
name = "A"
case TypeNS:
var rb NSResource
rb, err = unpackNSResource(msg, off)
r = &rb
name = "NS"
case TypeCNAME:
var rb CNAMEResource
rb, err = unpackCNAMEResource(msg, off)
r = &rb
name = "CNAME"
case TypeSOA:
var rb SOAResource
rb, err = unpackSOAResource(msg, off)
r = &rb
name = "SOA"
case TypePTR:
var rb PTRResource
rb, err = unpackPTRResource(msg, off)
r = &rb
name = "PTR"
case TypeMX:
var rb MXResource
rb, err = unpackMXResource(msg, off)
r = &rb
name = "MX"
case TypeTXT:
var rb TXTResource
rb, err = unpackTXTResource(msg, off, hdr.Length)
r = &rb
name = "TXT"
case TypeAAAA:
var rb AAAAResource
rb, err = unpackAAAAResource(msg, off)
r = &rb
name = "AAAA"
case TypeSRV:
var rb SRVResource
rb, err = unpackSRVResource(msg, off)
r = &rb
name = "SRV"
}
if err != nil {
return nil, off, &nestedError{name + " record", err}
}
if r == nil {
return nil, off, errors.New("invalid resource type: " + string(hdr.Type+'0'))
}
return r, off + int(hdr.Length), nil
}
// A CNAMEResource is a CNAME Resource record.
type CNAMEResource struct {
CNAME Name
}
func (r *CNAMEResource) realType() Type {
return TypeCNAME
}
// pack appends the wire format of the CNAMEResource to msg.
func (r *CNAMEResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
return r.CNAME.pack(msg, compression, compressionOff)
}
func unpackCNAMEResource(msg []byte, off int) (CNAMEResource, error) {
var cname Name
if _, err := cname.unpack(msg, off); err != nil {
return CNAMEResource{}, err
}
return CNAMEResource{cname}, nil
}
// An MXResource is an MX Resource record.
type MXResource struct {
Pref uint16
MX Name
}
func (r *MXResource) realType() Type {
return TypeMX
}
// pack appends the wire format of the MXResource to msg.
func (r *MXResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
oldMsg := msg
msg = packUint16(msg, r.Pref)
msg, err := r.MX.pack(msg, compression, compressionOff)
if err != nil {
return oldMsg, &nestedError{"MXResource.MX", err}
}
return msg, nil
}
func unpackMXResource(msg []byte, off int) (MXResource, error) {
pref, off, err := unpackUint16(msg, off)
if err != nil {
return MXResource{}, &nestedError{"Pref", err}
}
var mx Name
if _, err := mx.unpack(msg, off); err != nil {
return MXResource{}, &nestedError{"MX", err}
}
return MXResource{pref, mx}, nil
}
// An NSResource is an NS Resource record.
type NSResource struct {
NS Name
}
func (r *NSResource) realType() Type {
return TypeNS
}
// pack appends the wire format of the NSResource to msg.
func (r *NSResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
return r.NS.pack(msg, compression, compressionOff)
}
func unpackNSResource(msg []byte, off int) (NSResource, error) {
var ns Name
if _, err := ns.unpack(msg, off); err != nil {
return NSResource{}, err
}
return NSResource{ns}, nil
}
// A PTRResource is a PTR Resource record.
type PTRResource struct {
PTR Name
}
func (r *PTRResource) realType() Type {
return TypePTR
}
// pack appends the wire format of the PTRResource to msg.
func (r *PTRResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
return r.PTR.pack(msg, compression, compressionOff)
}
func unpackPTRResource(msg []byte, off int) (PTRResource, error) {
var ptr Name
if _, err := ptr.unpack(msg, off); err != nil {
return PTRResource{}, err
}
return PTRResource{ptr}, nil
}
// An SOAResource is an SOA Resource record.
type SOAResource struct {
NS Name
MBox Name
Serial uint32
Refresh uint32
Retry uint32
Expire uint32
// MinTTL the is the default TTL of Resources records which did not
// contain a TTL value and the TTL of negative responses. (RFC 2308
// Section 4)
MinTTL uint32
}
func (r *SOAResource) realType() Type {
return TypeSOA
}
// pack appends the wire format of the SOAResource to msg.
func (r *SOAResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
oldMsg := msg
msg, err := r.NS.pack(msg, compression, compressionOff)
if err != nil {
return oldMsg, &nestedError{"SOAResource.NS", err}
}
msg, err = r.MBox.pack(msg, compression, compressionOff)
if err != nil {
return oldMsg, &nestedError{"SOAResource.MBox", err}
}
msg = packUint32(msg, r.Serial)
msg = packUint32(msg, r.Refresh)
msg = packUint32(msg, r.Retry)
msg = packUint32(msg, r.Expire)
return packUint32(msg, r.MinTTL), nil
}
func unpackSOAResource(msg []byte, off int) (SOAResource, error) {
var ns Name
off, err := ns.unpack(msg, off)
if err != nil {
return SOAResource{}, &nestedError{"NS", err}
}
var mbox Name
if off, err = mbox.unpack(msg, off); err != nil {
return SOAResource{}, &nestedError{"MBox", err}
}
serial, off, err := unpackUint32(msg, off)
if err != nil {
return SOAResource{}, &nestedError{"Serial", err}
}
refresh, off, err := unpackUint32(msg, off)
if err != nil {
return SOAResource{}, &nestedError{"Refresh", err}
}
retry, off, err := unpackUint32(msg, off)
if err != nil {
return SOAResource{}, &nestedError{"Retry", err}
}
expire, off, err := unpackUint32(msg, off)
if err != nil {
return SOAResource{}, &nestedError{"Expire", err}
}
minTTL, _, err := unpackUint32(msg, off)
if err != nil {
return SOAResource{}, &nestedError{"MinTTL", err}
}
return SOAResource{ns, mbox, serial, refresh, retry, expire, minTTL}, nil
}
// A TXTResource is a TXT Resource record.
type TXTResource struct {
TXT []string
}
func (r *TXTResource) realType() Type {
return TypeTXT
}
// pack appends the wire format of the TXTResource to msg.
func (r *TXTResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
oldMsg := msg
for _, s := range r.TXT {
var err error
msg, err = packText(msg, s)
if err != nil {
return oldMsg, err
}
}
return msg, nil
}
func unpackTXTResource(msg []byte, off int, length uint16) (TXTResource, error) {
txts := make([]string, 0, 1)
for n := uint16(0); n < length; {
var t string
var err error
if t, off, err = unpackText(msg, off); err != nil {
return TXTResource{}, &nestedError{"text", err}
}
// Check if we got too many bytes.
if length-n < uint16(len(t))+1 {
return TXTResource{}, errCalcLen
}
n += uint16(len(t)) + 1
txts = append(txts, t)
}
return TXTResource{txts}, nil
}
// An SRVResource is an SRV Resource record.
type SRVResource struct {
Priority uint16
Weight uint16
Port uint16
Target Name // Not compressed as per RFC 2782.
}
func (r *SRVResource) realType() Type {
return TypeSRV
}
// pack appends the wire format of the SRVResource to msg.
func (r *SRVResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
oldMsg := msg
msg = packUint16(msg, r.Priority)
msg = packUint16(msg, r.Weight)
msg = packUint16(msg, r.Port)
msg, err := r.Target.pack(msg, nil, compressionOff)
if err != nil {
return oldMsg, &nestedError{"SRVResource.Target", err}
}
return msg, nil
}
func unpackSRVResource(msg []byte, off int) (SRVResource, error) {
priority, off, err := unpackUint16(msg, off)
if err != nil {
return SRVResource{}, &nestedError{"Priority", err}
}
weight, off, err := unpackUint16(msg, off)
if err != nil {
return SRVResource{}, &nestedError{"Weight", err}
}
port, off, err := unpackUint16(msg, off)
if err != nil {
return SRVResource{}, &nestedError{"Port", err}
}
var target Name
if _, err := target.unpack(msg, off); err != nil {
return SRVResource{}, &nestedError{"Target", err}
}
return SRVResource{priority, weight, port, target}, nil
}
// An AResource is an A Resource record.
type AResource struct {
A [4]byte
}
func (r *AResource) realType() Type {
return TypeA
}
// pack appends the wire format of the AResource to msg.
func (r *AResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
return packBytes(msg, r.A[:]), nil
}
func unpackAResource(msg []byte, off int) (AResource, error) {
var a [4]byte
if _, err := unpackBytes(msg, off, a[:]); err != nil {
return AResource{}, err
}
return AResource{a}, nil
}
// An AAAAResource is an AAAA Resource record.
type AAAAResource struct {
AAAA [16]byte
}
func (r *AAAAResource) realType() Type {
return TypeAAAA
}
// pack appends the wire format of the AAAAResource to msg.
func (r *AAAAResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
return packBytes(msg, r.AAAA[:]), nil
}
func unpackAAAAResource(msg []byte, off int) (AAAAResource, error) {
var aaaa [16]byte
if _, err := unpackBytes(msg, off, aaaa[:]); err != nil {
return AAAAResource{}, err
}
return AAAAResource{aaaa}, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dnsmessage
import (
"bytes"
"fmt"
"reflect"
"strings"
"testing"
)
func mustNewName(name string) Name {
n, err := NewName(name)
if err != nil {
panic(err)
}
return n
}
func (m *Message) String() string {
s := fmt.Sprintf("Message: %#v\n", &m.Header)
if len(m.Questions) > 0 {
s += "-- Questions\n"
for _, q := range m.Questions {
s += fmt.Sprintf("%#v\n", q)
}
}
if len(m.Answers) > 0 {
s += "-- Answers\n"
for _, a := range m.Answers {
s += fmt.Sprintf("%#v\n", a)
}
}
if len(m.Authorities) > 0 {
s += "-- Authorities\n"
for _, ns := range m.Authorities {
s += fmt.Sprintf("%#v\n", ns)
}
}
if len(m.Additionals) > 0 {
s += "-- Additionals\n"
for _, e := range m.Additionals {
s += fmt.Sprintf("%#v\n", e)
}
}
return s
}
func TestNameString(t *testing.T) {
want := "foo"
name := mustNewName(want)
if got := fmt.Sprint(name); got != want {
t.Errorf("got fmt.Sprint(%#v) = %s, want = %s", name, got, want)
}
}
func TestQuestionPackUnpack(t *testing.T) {
want := Question{
Name: mustNewName("."),
Type: TypeA,
Class: ClassINET,
}
buf, err := want.pack(make([]byte, 1, 50), map[string]int{}, 1)
if err != nil {
t.Fatal("Packing failed:", err)
}
var p Parser
p.msg = buf
p.header.questions = 1
p.section = sectionQuestions
p.off = 1
got, err := p.Question()
if err != nil {
t.Fatalf("Unpacking failed: %v\n%s", err, string(buf[1:]))
}
if p.off != len(buf) {
t.Errorf("Unpacked different amount than packed: got n = %d, want = %d", p.off, len(buf))
}
if !reflect.DeepEqual(got, want) {
t.Errorf("Got = %+v, want = %+v", got, want)
}
}
func TestName(t *testing.T) {
tests := []string{
"",
".",
"google..com",
"google.com",
"google..com.",
"google.com.",
".google.com.",
"www..google.com.",
"www.google.com.",
}
for _, test := range tests {
n, err := NewName(test)
if err != nil {
t.Errorf("Creating name for %q: %v", test, err)
continue
}
if ns := n.String(); ns != test {
t.Errorf("Got %#v.String() = %q, want = %q", n, ns, test)
continue
}
}
}
func TestNamePackUnpack(t *testing.T) {
tests := []struct {
in string
want string
err error
}{
{"", "", errNonCanonicalName},
{".", ".", nil},
{"google..com", "", errNonCanonicalName},
{"google.com", "", errNonCanonicalName},
{"google..com.", "", errZeroSegLen},
{"google.com.", "google.com.", nil},
{".google.com.", "", errZeroSegLen},
{"www..google.com.", "", errZeroSegLen},
{"www.google.com.", "www.google.com.", nil},
}
for _, test := range tests {
in := mustNewName(test.in)
want := mustNewName(test.want)
buf, err := in.pack(make([]byte, 0, 30), map[string]int{}, 0)
if err != test.err {
t.Errorf("Packing of %q: got err = %v, want err = %v", test.in, err, test.err)
continue
}
if test.err != nil {
continue
}
var got Name
n, err := got.unpack(buf, 0)
if err != nil {
t.Errorf("Unpacking for %q failed: %v", test.in, err)
continue
}
if n != len(buf) {
t.Errorf(
"Unpacked different amount than packed for %q: got n = %d, want = %d",
test.in,
n,
len(buf),
)
}
if got != want {
t.Errorf("Unpacking packing of %q: got = %#v, want = %#v", test.in, got, want)
}
}
}
func checkErrorPrefix(err error, prefix string) bool {
e, ok := err.(*nestedError)
return ok && e.s == prefix
}
func TestHeaderUnpackError(t *testing.T) {
wants := []string{
"id",
"bits",
"questions",
"answers",
"authorities",
"additionals",
}
var buf []byte
var h header
for _, want := range wants {
n, err := h.unpack(buf, 0)
if n != 0 || !checkErrorPrefix(err, want) {
t.Errorf("got h.unpack([%d]byte, 0) = %d, %v, want = 0, %s", len(buf), n, err, want)
}
buf = append(buf, 0, 0)
}
}
func TestParserStart(t *testing.T) {
const want = "unpacking header"
var p Parser
for i := 0; i <= 1; i++ {
_, err := p.Start([]byte{})
if !checkErrorPrefix(err, want) {
t.Errorf("got p.Start(nil) = _, %v, want = _, %s", err, want)
}
}
}
func TestResourceNotStarted(t *testing.T) {
tests := []struct {
name string
fn func(*Parser) error
}{
{"CNAMEResource", func(p *Parser) error { _, err := p.CNAMEResource(); return err }},
{"MXResource", func(p *Parser) error { _, err := p.MXResource(); return err }},
{"NSResource", func(p *Parser) error { _, err := p.NSResource(); return err }},
{"PTRResource", func(p *Parser) error { _, err := p.PTRResource(); return err }},
{"SOAResource", func(p *Parser) error { _, err := p.SOAResource(); return err }},
{"TXTResource", func(p *Parser) error { _, err := p.TXTResource(); return err }},
{"SRVResource", func(p *Parser) error { _, err := p.SRVResource(); return err }},
{"AResource", func(p *Parser) error { _, err := p.AResource(); return err }},
{"AAAAResource", func(p *Parser) error { _, err := p.AAAAResource(); return err }},
}
for _, test := range tests {
if err := test.fn(&Parser{}); err != ErrNotStarted {
t.Errorf("got _, %v = p.%s(), want = _, %v", err, test.name, ErrNotStarted)
}
}
}
func TestDNSPackUnpack(t *testing.T) {
wants := []Message{
{
Questions: []Question{
{
Name: mustNewName("."),
Type: TypeAAAA,
Class: ClassINET,
},
},
Answers: []Resource{},
Authorities: []Resource{},
Additionals: []Resource{},
},
largeTestMsg(),
}
for i, want := range wants {
b, err := want.Pack()
if err != nil {
t.Fatalf("%d: packing failed: %v", i, err)
}
var got Message
err = got.Unpack(b)
if err != nil {
t.Fatalf("%d: unpacking failed: %v", i, err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("%d: got = %+v, want = %+v", i, &got, &want)
}
}
}
func TestDNSAppendPackUnpack(t *testing.T) {
wants := []Message{
{
Questions: []Question{
{
Name: mustNewName("."),
Type: TypeAAAA,
Class: ClassINET,
},
},
Answers: []Resource{},
Authorities: []Resource{},
Additionals: []Resource{},
},
largeTestMsg(),
}
for i, want := range wants {
b := make([]byte, 2, 514)
b, err := want.AppendPack(b)
if err != nil {
t.Fatalf("%d: packing failed: %v", i, err)
}
b = b[2:]
var got Message
err = got.Unpack(b)
if err != nil {
t.Fatalf("%d: unpacking failed: %v", i, err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("%d: got = %+v, want = %+v", i, &got, &want)
}
}
}
func TestSkipAll(t *testing.T) {
msg := largeTestMsg()
buf, err := msg.Pack()
if err != nil {
t.Fatal("Packing large test message:", err)
}
var p Parser
if _, err := p.Start(buf); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
f func() error
}{
{"SkipAllQuestions", p.SkipAllQuestions},
{"SkipAllAnswers", p.SkipAllAnswers},
{"SkipAllAuthorities", p.SkipAllAuthorities},
{"SkipAllAdditionals", p.SkipAllAdditionals},
}
for _, test := range tests {
for i := 1; i <= 3; i++ {
if err := test.f(); err != nil {
t.Errorf("Call #%d to %s(): %v", i, test.name, err)
}
}
}
}
func TestSkipEach(t *testing.T) {
msg := smallTestMsg()
buf, err := msg.Pack()
if err != nil {
t.Fatal("Packing test message:", err)
}
var p Parser
if _, err := p.Start(buf); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
f func() error
}{
{"SkipQuestion", p.SkipQuestion},
{"SkipAnswer", p.SkipAnswer},
{"SkipAuthority", p.SkipAuthority},
{"SkipAdditional", p.SkipAdditional},
}
for _, test := range tests {
if err := test.f(); err != nil {
t.Errorf("First call: got %s() = %v, want = %v", test.name, err, nil)
}
if err := test.f(); err != ErrSectionDone {
t.Errorf("Second call: got %s() = %v, want = %v", test.name, err, ErrSectionDone)
}
}
}
func TestSkipAfterRead(t *testing.T) {
msg := smallTestMsg()
buf, err := msg.Pack()
if err != nil {
t.Fatal("Packing test message:", err)
}
var p Parser
if _, err := p.Start(buf); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
skip func() error
read func() error
}{
{"Question", p.SkipQuestion, func() error { _, err := p.Question(); return err }},
{"Answer", p.SkipAnswer, func() error { _, err := p.Answer(); return err }},
{"Authority", p.SkipAuthority, func() error { _, err := p.Authority(); return err }},
{"Additional", p.SkipAdditional, func() error { _, err := p.Additional(); return err }},
}
for _, test := range tests {
if err := test.read(); err != nil {
t.Errorf("Got %s() = _, %v, want = _, %v", test.name, err, nil)
}
if err := test.skip(); err != ErrSectionDone {
t.Errorf("Got Skip%s() = %v, want = %v", test.name, err, ErrSectionDone)
}
}
}
func TestSkipNotStarted(t *testing.T) {
var p Parser
tests := []struct {
name string
f func() error
}{
{"SkipAllQuestions", p.SkipAllQuestions},
{"SkipAllAnswers", p.SkipAllAnswers},
{"SkipAllAuthorities", p.SkipAllAuthorities},
{"SkipAllAdditionals", p.SkipAllAdditionals},
}
for _, test := range tests {
if err := test.f(); err != ErrNotStarted {
t.Errorf("Got %s() = %v, want = %v", test.name, err, ErrNotStarted)
}
}
}
func TestTooManyRecords(t *testing.T) {
const recs = int(^uint16(0)) + 1
tests := []struct {
name string
msg Message
want error
}{
{
"Questions",
Message{
Questions: make([]Question, recs),
},
errTooManyQuestions,
},
{
"Answers",
Message{
Answers: make([]Resource, recs),
},
errTooManyAnswers,
},
{
"Authorities",
Message{
Authorities: make([]Resource, recs),
},
errTooManyAuthorities,
},
{
"Additionals",
Message{
Additionals: make([]Resource, recs),
},
errTooManyAdditionals,
},
}
for _, test := range tests {
if _, got := test.msg.Pack(); got != test.want {
t.Errorf("Packing %d %s: got = %v, want = %v", recs, test.name, got, test.want)
}
}
}
func TestVeryLongTxt(t *testing.T) {
want := Resource{
ResourceHeader{
Name: mustNewName("foo.bar.example.com."),
Type: TypeTXT,
Class: ClassINET,
},
&TXTResource{[]string{
"",
"",
"foo bar",
"",
"www.example.com",
"www.example.com.",
strings.Repeat(".", 255),
}},
}
buf, err := want.pack(make([]byte, 0, 8000), map[string]int{}, 0)
if err != nil {
t.Fatal("Packing failed:", err)
}
var got Resource
off, err := got.Header.unpack(buf, 0)
if err != nil {
t.Fatal("Unpacking ResourceHeader failed:", err)
}
body, n, err := unpackResourceBody(buf, off, got.Header)
if err != nil {
t.Fatal("Unpacking failed:", err)
}
got.Body = body
if n != len(buf) {
t.Errorf("Unpacked different amount than packed: got n = %d, want = %d", n, len(buf))
}
if !reflect.DeepEqual(got, want) {
t.Errorf("Got = %#v, want = %#v", got, want)
}
}
func TestTooLongTxt(t *testing.T) {
rb := TXTResource{[]string{strings.Repeat(".", 256)}}
if _, err := rb.pack(make([]byte, 0, 8000), map[string]int{}, 0); err != errStringTooLong {
t.Errorf("Packing TXTRecord with 256 character string: got err = %v, want = %v", err, errStringTooLong)
}
}
func TestStartAppends(t *testing.T) {
buf := make([]byte, 2, 514)
wantBuf := []byte{4, 44}
copy(buf, wantBuf)
b := NewBuilder(buf, Header{})
b.EnableCompression()
buf, err := b.Finish()
if err != nil {
t.Fatal("Building failed:", err)
}
if got, want := len(buf), headerLen+2; got != want {
t.Errorf("Got len(buf} = %d, want = %d", got, want)
}
if string(buf[:2]) != string(wantBuf) {
t.Errorf("Original data not preserved, got = %v, want = %v", buf[:2], wantBuf)
}
}
func TestStartError(t *testing.T) {
tests := []struct {
name string
fn func(*Builder) error
}{
{"Questions", func(b *Builder) error { return b.StartQuestions() }},
{"Answers", func(b *Builder) error { return b.StartAnswers() }},
{"Authorities", func(b *Builder) error { return b.StartAuthorities() }},
{"Additionals", func(b *Builder) error { return b.StartAdditionals() }},
}
envs := []struct {
name string
fn func() *Builder
want error
}{
{"sectionNotStarted", func() *Builder { return &Builder{section: sectionNotStarted} }, ErrNotStarted},
{"sectionDone", func() *Builder { return &Builder{section: sectionDone} }, ErrSectionDone},
}
for _, env := range envs {
for _, test := range tests {
if got := test.fn(env.fn()); got != env.want {
t.Errorf("got Builder{%s}.Start%s = %v, want = %v", env.name, test.name, got, env.want)
}
}
}
}
func TestBuilderResourceError(t *testing.T) {
tests := []struct {
name string
fn func(*Builder) error
}{
{"CNAMEResource", func(b *Builder) error { return b.CNAMEResource(ResourceHeader{}, CNAMEResource{}) }},
{"MXResource", func(b *Builder) error { return b.MXResource(ResourceHeader{}, MXResource{}) }},
{"NSResource", func(b *Builder) error { return b.NSResource(ResourceHeader{}, NSResource{}) }},
{"PTRResource", func(b *Builder) error { return b.PTRResource(ResourceHeader{}, PTRResource{}) }},
{"SOAResource", func(b *Builder) error { return b.SOAResource(ResourceHeader{}, SOAResource{}) }},
{"TXTResource", func(b *Builder) error { return b.TXTResource(ResourceHeader{}, TXTResource{}) }},
{"SRVResource", func(b *Builder) error { return b.SRVResource(ResourceHeader{}, SRVResource{}) }},
{"AResource", func(b *Builder) error { return b.AResource(ResourceHeader{}, AResource{}) }},
{"AAAAResource", func(b *Builder) error { return b.AAAAResource(ResourceHeader{}, AAAAResource{}) }},
}
envs := []struct {
name string
fn func() *Builder
want error
}{
{"sectionNotStarted", func() *Builder { return &Builder{section: sectionNotStarted} }, ErrNotStarted},
{"sectionHeader", func() *Builder { return &Builder{section: sectionHeader} }, ErrNotStarted},
{"sectionQuestions", func() *Builder { return &Builder{section: sectionQuestions} }, ErrNotStarted},
{"sectionDone", func() *Builder { return &Builder{section: sectionDone} }, ErrSectionDone},
}
for _, env := range envs {
for _, test := range tests {
if got := test.fn(env.fn()); got != env.want {
t.Errorf("got Builder{%s}.%s = %v, want = %v", env.name, test.name, got, env.want)
}
}
}
}
func TestFinishError(t *testing.T) {
var b Builder
want := ErrNotStarted
if _, got := b.Finish(); got != want {
t.Errorf("got Builder{}.Finish() = %v, want = %v", got, want)
}
}
func TestBuilder(t *testing.T) {
msg := largeTestMsg()
want, err := msg.Pack()
if err != nil {
t.Fatal("Packing without builder:", err)
}
b := NewBuilder(nil, msg.Header)
b.EnableCompression()
if err := b.StartQuestions(); err != nil {
t.Fatal("b.StartQuestions():", err)
}
for _, q := range msg.Questions {
if err := b.Question(q); err != nil {
t.Fatalf("b.Question(%#v): %v", q, err)
}
}
if err := b.StartAnswers(); err != nil {
t.Fatal("b.StartAnswers():", err)
}
for _, a := range msg.Answers {
switch a.Header.Type {
case TypeA:
if err := b.AResource(a.Header, *a.Body.(*AResource)); err != nil {
t.Fatalf("b.AResource(%#v): %v", a, err)
}
case TypeNS:
if err := b.NSResource(a.Header, *a.Body.(*NSResource)); err != nil {
t.Fatalf("b.NSResource(%#v): %v", a, err)
}
case TypeCNAME:
if err := b.CNAMEResource(a.Header, *a.Body.(*CNAMEResource)); err != nil {
t.Fatalf("b.CNAMEResource(%#v): %v", a, err)
}
case TypeSOA:
if err := b.SOAResource(a.Header, *a.Body.(*SOAResource)); err != nil {
t.Fatalf("b.SOAResource(%#v): %v", a, err)
}
case TypePTR:
if err := b.PTRResource(a.Header, *a.Body.(*PTRResource)); err != nil {
t.Fatalf("b.PTRResource(%#v): %v", a, err)
}
case TypeMX:
if err := b.MXResource(a.Header, *a.Body.(*MXResource)); err != nil {
t.Fatalf("b.MXResource(%#v): %v", a, err)
}
case TypeTXT:
if err := b.TXTResource(a.Header, *a.Body.(*TXTResource)); err != nil {
t.Fatalf("b.TXTResource(%#v): %v", a, err)
}
case TypeAAAA:
if err := b.AAAAResource(a.Header, *a.Body.(*AAAAResource)); err != nil {
t.Fatalf("b.AAAAResource(%#v): %v", a, err)
}
case TypeSRV:
if err := b.SRVResource(a.Header, *a.Body.(*SRVResource)); err != nil {
t.Fatalf("b.SRVResource(%#v): %v", a, err)
}
}
}
if err := b.StartAuthorities(); err != nil {
t.Fatal("b.StartAuthorities():", err)
}
for _, a := range msg.Authorities {
if err := b.NSResource(a.Header, *a.Body.(*NSResource)); err != nil {
t.Fatalf("b.NSResource(%#v): %v", a, err)
}
}
if err := b.StartAdditionals(); err != nil {
t.Fatal("b.StartAdditionals():", err)
}
for _, a := range msg.Additionals {
if err := b.TXTResource(a.Header, *a.Body.(*TXTResource)); err != nil {
t.Fatalf("b.TXTResource(%#v): %v", a, err)
}
}
got, err := b.Finish()
if err != nil {
t.Fatal("b.Finish():", err)
}
if !bytes.Equal(got, want) {
t.Fatalf("Got from Builder: %#v\nwant = %#v", got, want)
}
}
func TestResourcePack(t *testing.T) {
for _, tt := range []struct {
m Message
err error
}{
{
Message{
Questions: []Question{
{
Name: mustNewName("."),
Type: TypeAAAA,
Class: ClassINET,
},
},
Answers: []Resource{{ResourceHeader{}, nil}},
},
&nestedError{"packing Answer", errNilResouceBody},
},
{
Message{
Questions: []Question{
{
Name: mustNewName("."),
Type: TypeAAAA,
Class: ClassINET,
},
},
Authorities: []Resource{{ResourceHeader{}, (*NSResource)(nil)}},
},
&nestedError{"packing Authority",
&nestedError{"ResourceHeader",
&nestedError{"Name", errNonCanonicalName},
},
},
},
{
Message{
Questions: []Question{
{
Name: mustNewName("."),
Type: TypeA,
Class: ClassINET,
},
},
Additionals: []Resource{{ResourceHeader{}, nil}},
},
&nestedError{"packing Additional", errNilResouceBody},
},
} {
_, err := tt.m.Pack()
if !reflect.DeepEqual(err, tt.err) {
t.Errorf("got %v for %v; want %v", err, tt.m, tt.err)
}
}
}
func benchmarkParsingSetup() ([]byte, error) {
name := mustNewName("foo.bar.example.com.")
msg := Message{
Header: Header{Response: true, Authoritative: true},
Questions: []Question{
{
Name: name,
Type: TypeA,
Class: ClassINET,
},
},
Answers: []Resource{
{
ResourceHeader{
Name: name,
Class: ClassINET,
},
&AResource{[4]byte{}},
},
{
ResourceHeader{
Name: name,
Class: ClassINET,
},
&AAAAResource{[16]byte{}},
},
{
ResourceHeader{
Name: name,
Class: ClassINET,
},
&CNAMEResource{name},
},
{
ResourceHeader{
Name: name,
Class: ClassINET,
},
&NSResource{name},
},
},
}
buf, err := msg.Pack()
if err != nil {
return nil, fmt.Errorf("msg.Pack(): %v", err)
}
return buf, nil
}
func benchmarkParsing(tb testing.TB, buf []byte) {
var p Parser
if _, err := p.Start(buf); err != nil {
tb.Fatal("p.Start(buf):", err)
}
for {
_, err := p.Question()
if err == ErrSectionDone {
break
}
if err != nil {
tb.Fatal("p.Question():", err)
}
}
for {
h, err := p.AnswerHeader()
if err == ErrSectionDone {
break
}
if err != nil {
panic(err)
}
switch h.Type {
case TypeA:
if _, err := p.AResource(); err != nil {
tb.Fatal("p.AResource():", err)
}
case TypeAAAA:
if _, err := p.AAAAResource(); err != nil {
tb.Fatal("p.AAAAResource():", err)
}
case TypeCNAME:
if _, err := p.CNAMEResource(); err != nil {
tb.Fatal("p.CNAMEResource():", err)
}
case TypeNS:
if _, err := p.NSResource(); err != nil {
tb.Fatal("p.NSResource():", err)
}
default:
tb.Fatalf("unknown type: %T", h)
}
}
}
func BenchmarkParsing(b *testing.B) {
buf, err := benchmarkParsingSetup()
if err != nil {
b.Fatal(err)
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkParsing(b, buf)
}
}
func TestParsingAllocs(t *testing.T) {
buf, err := benchmarkParsingSetup()
if err != nil {
t.Fatal(err)
}
if allocs := testing.AllocsPerRun(100, func() { benchmarkParsing(t, buf) }); allocs > 0.5 {
t.Errorf("Allocations during parsing: got = %f, want ~0", allocs)
}
}
func benchmarkBuildingSetup() (Name, []byte) {
name := mustNewName("foo.bar.example.com.")
buf := make([]byte, 0, packStartingCap)
return name, buf
}
func benchmarkBuilding(tb testing.TB, name Name, buf []byte) {
bld := NewBuilder(buf, Header{Response: true, Authoritative: true})
if err := bld.StartQuestions(); err != nil {
tb.Fatal("bld.StartQuestions():", err)
}
q := Question{
Name: name,
Type: TypeA,
Class: ClassINET,
}
if err := bld.Question(q); err != nil {
tb.Fatalf("bld.Question(%+v): %v", q, err)
}
hdr := ResourceHeader{
Name: name,
Class: ClassINET,
}
if err := bld.StartAnswers(); err != nil {
tb.Fatal("bld.StartQuestions():", err)
}
ar := AResource{[4]byte{}}
if err := bld.AResource(hdr, ar); err != nil {
tb.Fatalf("bld.AResource(%+v, %+v): %v", hdr, ar, err)
}
aaar := AAAAResource{[16]byte{}}
if err := bld.AAAAResource(hdr, aaar); err != nil {
tb.Fatalf("bld.AAAAResource(%+v, %+v): %v", hdr, aaar, err)
}
cnr := CNAMEResource{name}
if err := bld.CNAMEResource(hdr, cnr); err != nil {
tb.Fatalf("bld.CNAMEResource(%+v, %+v): %v", hdr, cnr, err)
}
nsr := NSResource{name}
if err := bld.NSResource(hdr, nsr); err != nil {
tb.Fatalf("bld.NSResource(%+v, %+v): %v", hdr, nsr, err)
}
if _, err := bld.Finish(); err != nil {
tb.Fatal("bld.Finish():", err)
}
}
func BenchmarkBuilding(b *testing.B) {
name, buf := benchmarkBuildingSetup()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkBuilding(b, name, buf)
}
}
func TestBuildingAllocs(t *testing.T) {
name, buf := benchmarkBuildingSetup()
if allocs := testing.AllocsPerRun(100, func() { benchmarkBuilding(t, name, buf) }); allocs > 0.5 {
t.Errorf("Allocations during building: got = %f, want ~0", allocs)
}
}
func smallTestMsg() Message {
name := mustNewName("example.com.")
return Message{
Header: Header{Response: true, Authoritative: true},
Questions: []Question{
{
Name: name,
Type: TypeA,
Class: ClassINET,
},
},
Answers: []Resource{
{
ResourceHeader{
Name: name,
Type: TypeA,
Class: ClassINET,
},
&AResource{[4]byte{127, 0, 0, 1}},
},
},
Authorities: []Resource{
{
ResourceHeader{
Name: name,
Type: TypeA,
Class: ClassINET,
},
&AResource{[4]byte{127, 0, 0, 1}},
},
},
Additionals: []Resource{
{
ResourceHeader{
Name: name,
Type: TypeA,
Class: ClassINET,
},
&AResource{[4]byte{127, 0, 0, 1}},
},
},
}
}
func BenchmarkPack(b *testing.B) {
msg := largeTestMsg()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := msg.Pack(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAppendPack(b *testing.B) {
msg := largeTestMsg()
buf := make([]byte, 0, packStartingCap)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := msg.AppendPack(buf[:0]); err != nil {
b.Fatal(err)
}
}
}
func largeTestMsg() Message {
name := mustNewName("foo.bar.example.com.")
return Message{
Header: Header{Response: true, Authoritative: true},
Questions: []Question{
{
Name: name,
Type: TypeA,
Class: ClassINET,
},
},
Answers: []Resource{
{
ResourceHeader{
Name: name,
Type: TypeA,
Class: ClassINET,
},
&AResource{[4]byte{127, 0, 0, 1}},
},
{
ResourceHeader{
Name: name,
Type: TypeA,
Class: ClassINET,
},
&AResource{[4]byte{127, 0, 0, 2}},
},
{
ResourceHeader{
Name: name,
Type: TypeAAAA,
Class: ClassINET,
},
&AAAAResource{[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}},
},
{
ResourceHeader{
Name: name,
Type: TypeCNAME,
Class: ClassINET,
},
&CNAMEResource{mustNewName("alias.example.com.")},
},
{
ResourceHeader{
Name: name,
Type: TypeSOA,
Class: ClassINET,
},
&SOAResource{
NS: mustNewName("ns1.example.com."),
MBox: mustNewName("mb.example.com."),
Serial: 1,
Refresh: 2,
Retry: 3,
Expire: 4,
MinTTL: 5,
},
},
{
ResourceHeader{
Name: name,
Type: TypePTR,
Class: ClassINET,
},
&PTRResource{mustNewName("ptr.example.com.")},
},
{
ResourceHeader{
Name: name,
Type: TypeMX,
Class: ClassINET,
},
&MXResource{
7,
mustNewName("mx.example.com."),
},
},
{
ResourceHeader{
Name: name,
Type: TypeSRV,
Class: ClassINET,
},
&SRVResource{
8,
9,
11,
mustNewName("srv.example.com."),
},
},
},
Authorities: []Resource{
{
ResourceHeader{
Name: name,
Type: TypeNS,
Class: ClassINET,
},
&NSResource{mustNewName("ns1.example.com.")},
},
{
ResourceHeader{
Name: name,
Type: TypeNS,
Class: ClassINET,
},
&NSResource{mustNewName("ns2.example.com.")},
},
},
Additionals: []Resource{
{
ResourceHeader{
Name: name,
Type: TypeTXT,
Class: ClassINET,
},
&TXTResource{[]string{"So Long, and Thanks for All the Fish"}},
},
{
ResourceHeader{
Name: name,
Type: TypeTXT,
Class: ClassINET,
},
&TXTResource{[]string{"Hamster Huey and the Gooey Kablooie"}},
},
},
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment