Commit 952c2fd6 authored by Ian Gudger's avatar Ian Gudger Committed by Matthew Dempsky

net: fix packDomainName encoding of root and invalid names

Fixes #14372

Change-Id: I40d594582639e87ef2574d37ac868e37ffaa17dc
Reviewed-on: https://go-review.googlesource.com/19623Reviewed-by: default avatarMatthew Dempsky <mdempsky@google.com>
parent 3ddfaa56
...@@ -406,6 +406,13 @@ func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { ...@@ -406,6 +406,13 @@ func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
s += "." s += "."
} }
// Allow root domain.
if s == "." {
msg[off] = 0
off++
return off, true
}
// Each dot ends a segment of the name. // Each dot ends a segment of the name.
// We trade each dot byte for a length byte. // We trade each dot byte for a length byte.
// There is also a trailing zero. // There is also a trailing zero.
...@@ -422,8 +429,13 @@ func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { ...@@ -422,8 +429,13 @@ func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
if i-begin >= 1<<6 { // top two bits of length must be clear if i-begin >= 1<<6 { // top two bits of length must be clear
return len(msg), false return len(msg), false
} }
if i-begin == 0 {
return len(msg), false
}
msg[off] = byte(i - begin) msg[off] = byte(i - begin)
off++ off++
for j := begin; j < i; j++ { for j := begin; j < i; j++ {
msg[off] = s[j] msg[off] = s[j]
off++ off++
...@@ -494,6 +506,9 @@ Loop: ...@@ -494,6 +506,9 @@ Loop:
return "", len(msg), false return "", len(msg), false
} }
} }
if len(s) == 0 {
s = "."
}
if ptr == 0 { if ptr == 0 {
off1 = off off1 = off
} }
...@@ -803,20 +818,32 @@ func (dns *dnsMsg) Pack() (msg []byte, ok bool) { ...@@ -803,20 +818,32 @@ func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
// Pack it in: header and then the pieces. // Pack it in: header and then the pieces.
off := 0 off := 0
off, ok = packStruct(&dh, msg, off) off, ok = packStruct(&dh, msg, off)
if !ok {
return nil, false
}
for i := 0; i < len(question); i++ { for i := 0; i < len(question); i++ {
off, ok = packStruct(&question[i], msg, off) off, ok = packStruct(&question[i], msg, off)
if !ok {
return nil, false
}
} }
for i := 0; i < len(answer); i++ { for i := 0; i < len(answer); i++ {
off, ok = packRR(answer[i], msg, off) off, ok = packRR(answer[i], msg, off)
if !ok {
return nil, false
}
} }
for i := 0; i < len(ns); i++ { for i := 0; i < len(ns); i++ {
off, ok = packRR(ns[i], msg, off) off, ok = packRR(ns[i], msg, off)
if !ok {
return nil, false
}
} }
for i := 0; i < len(extra); i++ { for i := 0; i < len(extra); i++ {
off, ok = packRR(extra[i], msg, off) off, ok = packRR(extra[i], msg, off)
} if !ok {
if !ok { return nil, false
return nil, false }
} }
return msg[0:off], true return msg[0:off], true
} }
...@@ -848,6 +875,9 @@ func (dns *dnsMsg) Unpack(msg []byte) bool { ...@@ -848,6 +875,9 @@ func (dns *dnsMsg) Unpack(msg []byte) bool {
for i := 0; i < len(dns.question); i++ { for i := 0; i < len(dns.question); i++ {
off, ok = unpackStruct(&dns.question[i], msg, off) off, ok = unpackStruct(&dns.question[i], msg, off)
if !ok {
return false
}
} }
for i := 0; i < int(dh.Ancount); i++ { for i := 0; i < int(dh.Ancount); i++ {
rec, off, ok = unpackRR(msg, off) rec, off, ok = unpackRR(msg, off)
......
...@@ -10,6 +10,103 @@ import ( ...@@ -10,6 +10,103 @@ import (
"testing" "testing"
) )
func TestStructPackUnpack(t *testing.T) {
want := dnsQuestion{
Name: ".",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
}
buf := make([]byte, 50)
n, ok := packStruct(&want, buf, 0)
if !ok {
t.Fatal("packing failed")
}
buf = buf[:n]
got := dnsQuestion{}
n, ok = unpackStruct(&got, buf, 0)
if !ok {
t.Fatal("unpacking failed")
}
if n != len(buf) {
t.Error("unpacked different amount than packed: got n = %d, want = %d", n, len(buf))
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %+v, want = %+v", got, want)
}
}
func TestDomainNamePackUnpack(t *testing.T) {
tests := []struct {
in string
want string
ok bool
}{
{"", ".", true},
{".", ".", true},
{"google..com", "", false},
{"google.com", "google.com.", true},
{"google..com.", "", false},
{"google.com.", "google.com.", true},
{".google.com.", "", false},
{"www..google.com.", "", false},
{"www.google.com.", "www.google.com.", true},
}
for _, test := range tests {
buf := make([]byte, 30)
n, ok := packDomainName(test.in, buf, 0)
if ok != test.ok {
t.Errorf("packing of %s: got ok = %t, want = %t", test.in, ok, test.ok)
continue
}
if !test.ok {
continue
}
buf = buf[:n]
got, n, ok := unpackDomainName(buf, 0)
if !ok {
t.Errorf("unpacking for %s failed", test.in)
continue
}
if n != len(buf) {
t.Error(
"unpacked different amount than packed for %s: got n = %d, want = %d",
test.in,
n,
len(buf),
)
}
if got != test.want {
t.Errorf("unpacking packing of %s: got = %s, want = %s", test.in, got, test.want)
}
}
}
func TestDNSPackUnpack(t *testing.T) {
want := dnsMsg{
question: []dnsQuestion{{
Name: ".",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
}},
answer: []dnsRR{},
ns: []dnsRR{},
extra: []dnsRR{},
}
b, ok := want.Pack()
if !ok {
t.Fatal("packing failed")
}
var got dnsMsg
ok = got.Unpack(b)
if !ok {
t.Fatal("unpacking failed")
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got = %+v, want = %+v", got, want)
}
}
func TestDNSParseSRVReply(t *testing.T) { func TestDNSParseSRVReply(t *testing.T) {
data, err := hex.DecodeString(dnsSRVReply) data, err := hex.DecodeString(dnsSRVReply)
if err != nil { if err != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment