Commit da7a51d1 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: don't error when marshalling nil IP addresses

See https://code.google.com/p/go/issues/detail?id=6339#c3

Fixes #6339

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/13553044
parent 3b6b53f4
...@@ -315,6 +315,9 @@ func (ip IP) String() string { ...@@ -315,6 +315,9 @@ func (ip IP) String() string {
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by String. // The encoding is the same as returned by String.
func (ip IP) MarshalText() ([]byte, error) { func (ip IP) MarshalText() ([]byte, error) {
if len(ip) == 0 {
return []byte(""), nil
}
if len(ip) != IPv4len && len(ip) != IPv6len { if len(ip) != IPv4len && len(ip) != IPv6len {
return nil, errors.New("invalid IP address") return nil, errors.New("invalid IP address")
} }
...@@ -324,6 +327,10 @@ func (ip IP) MarshalText() ([]byte, error) { ...@@ -324,6 +327,10 @@ func (ip IP) MarshalText() ([]byte, error) {
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
// The IP address is expected in a form accepted by ParseIP. // The IP address is expected in a form accepted by ParseIP.
func (ip *IP) UnmarshalText(text []byte) error { func (ip *IP) UnmarshalText(text []byte) error {
if len(text) == 0 {
*ip = nil
return nil
}
s := string(text) s := string(text)
x := ParseIP(s) x := ParseIP(s)
if x == nil { if x == nil {
......
...@@ -32,14 +32,35 @@ func TestParseIP(t *testing.T) { ...@@ -32,14 +32,35 @@ func TestParseIP(t *testing.T) {
if out := ParseIP(tt.in); !reflect.DeepEqual(out, tt.out) { if out := ParseIP(tt.in); !reflect.DeepEqual(out, tt.out) {
t.Errorf("ParseIP(%q) = %v, want %v", tt.in, out, tt.out) t.Errorf("ParseIP(%q) = %v, want %v", tt.in, out, tt.out)
} }
if tt.in == "" {
// Tested in TestMarshalEmptyIP below.
continue
}
var out IP var out IP
if err := out.UnmarshalText([]byte(tt.in)); !reflect.DeepEqual(out, tt.out) || (tt.out == nil) != (err != nil) { if err := out.UnmarshalText([]byte(tt.in)); !reflect.DeepEqual(out, tt.out) || (tt.out == nil) != (err != nil) {
t.Errorf("IP.UnmarshalText(%q) = %v, %v, want %v", tt.in, out, err, tt.out) t.Errorf("IP.UnmarshalText(%q) = %v, %v, want %v", tt.in, out, err, tt.out)
} }
} }
} }
// Issue 6339
func TestMarshalEmptyIP(t *testing.T) {
for _, in := range [][]byte{nil, []byte("")} {
var out = IP{1, 2, 3, 4}
if err := out.UnmarshalText(in); err != nil || out != nil {
t.Errorf("UnmarshalText(%v) = %v, %v; want nil, nil", in, out, err)
}
}
var ip IP
got, err := ip.MarshalText()
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, []byte("")) {
t.Errorf(`got %#v, want []byte("")`, got)
}
}
var ipStringTests = []struct { var ipStringTests = []struct {
in IP in IP
out string // see RFC 5952 out string // see RFC 5952
...@@ -53,23 +74,19 @@ var ipStringTests = []struct { ...@@ -53,23 +74,19 @@ var ipStringTests = []struct {
{IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0x1}, "2001:db8::1:0:0:1"}, {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0x1}, "2001:db8::1:0:0:1"},
{IP{0x20, 0x1, 0xD, 0xB8, 0, 0, 0, 0, 0, 0xA, 0, 0xB, 0, 0xC, 0, 0xD}, "2001:db8::a:b:c:d"}, {IP{0x20, 0x1, 0xD, 0xB8, 0, 0, 0, 0, 0, 0xA, 0, 0xB, 0, 0xC, 0, 0xD}, "2001:db8::a:b:c:d"},
{IPv4(192, 168, 0, 1), "192.168.0.1"}, {IPv4(192, 168, 0, 1), "192.168.0.1"},
{nil, "<nil>"}, {nil, ""},
} }
func TestIPString(t *testing.T) { func TestIPString(t *testing.T) {
for _, tt := range ipStringTests { for _, tt := range ipStringTests {
if out := tt.in.String(); out != tt.out {
t.Errorf("IP.String(%v) = %q, want %q", tt.in, out, tt.out)
}
if tt.in != nil { if tt.in != nil {
if out, err := tt.in.MarshalText(); string(out) != tt.out || err != nil { if out := tt.in.String(); out != tt.out {
t.Errorf("IP.MarshalText(%v) = %q, %v, want %q, nil", out, err, tt.out) t.Errorf("IP.String(%v) = %q, want %q", tt.in, out, tt.out)
}
} else {
if _, err := tt.in.MarshalText(); err == nil {
t.Errorf("IP.MarshalText(nil) succeeded, want failure")
} }
} }
if out, err := tt.in.MarshalText(); string(out) != tt.out || err != nil {
t.Errorf("IP.MarshalText(%v) = %q, %v, want %q, nil", out, err, tt.out)
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment