Commit 51057bda authored by Albert Strasheim's avatar Albert Strasheim Committed by Russ Cox

net: fix "unexpected socket family" error from WriteToUDP.

R=rsc, iant, mikioh.mikioh
CC=golang-dev
https://golang.org/cl/5128048
parent 8219cc9a
...@@ -22,6 +22,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) { ...@@ -22,6 +22,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) {
return nil, os.NewSyscallError("getsockopt", errno) return nil, os.NewSyscallError("getsockopt", errno)
} }
family := syscall.AF_UNSPEC
toAddr := sockaddrToTCP toAddr := sockaddrToTCP
sa, _ := syscall.Getsockname(fd) sa, _ := syscall.Getsockname(fd)
switch sa.(type) { switch sa.(type) {
...@@ -29,18 +30,21 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) { ...@@ -29,18 +30,21 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) {
closesocket(fd) closesocket(fd)
return nil, os.EINVAL return nil, os.EINVAL
case *syscall.SockaddrInet4: case *syscall.SockaddrInet4:
family = syscall.AF_INET
if proto == syscall.SOCK_DGRAM { if proto == syscall.SOCK_DGRAM {
toAddr = sockaddrToUDP toAddr = sockaddrToUDP
} else if proto == syscall.SOCK_RAW { } else if proto == syscall.SOCK_RAW {
toAddr = sockaddrToIP toAddr = sockaddrToIP
} }
case *syscall.SockaddrInet6: case *syscall.SockaddrInet6:
family = syscall.AF_INET6
if proto == syscall.SOCK_DGRAM { if proto == syscall.SOCK_DGRAM {
toAddr = sockaddrToUDP toAddr = sockaddrToUDP
} else if proto == syscall.SOCK_RAW { } else if proto == syscall.SOCK_RAW {
toAddr = sockaddrToIP toAddr = sockaddrToIP
} }
case *syscall.SockaddrUnix: case *syscall.SockaddrUnix:
family = syscall.AF_UNIX
toAddr = sockaddrToUnix toAddr = sockaddrToUnix
if proto == syscall.SOCK_DGRAM { if proto == syscall.SOCK_DGRAM {
toAddr = sockaddrToUnixgram toAddr = sockaddrToUnixgram
...@@ -52,7 +56,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) { ...@@ -52,7 +56,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) {
sa, _ = syscall.Getpeername(fd) sa, _ = syscall.Getpeername(fd)
raddr := toAddr(sa) raddr := toAddr(sa)
if nfd, err = newFD(fd, 0, proto, laddr.Network()); err != nil { if nfd, err = newFD(fd, family, proto, laddr.Network()); err != nil {
return nil, err return nil, err
} }
nfd.setAddr(laddr, raddr) nfd.setAddr(laddr, raddr)
......
...@@ -73,7 +73,7 @@ func TestFileListener(t *testing.T) { ...@@ -73,7 +73,7 @@ func TestFileListener(t *testing.T) {
} }
} }
func testFilePacketConn(t *testing.T, pcf packetConnFile) { func testFilePacketConn(t *testing.T, pcf packetConnFile, listen bool) {
f, err := pcf.File() f, err := pcf.File()
if err != nil { if err != nil {
t.Fatalf("File failed: %v", err) t.Fatalf("File failed: %v", err)
...@@ -85,6 +85,11 @@ func testFilePacketConn(t *testing.T, pcf packetConnFile) { ...@@ -85,6 +85,11 @@ func testFilePacketConn(t *testing.T, pcf packetConnFile) {
if !reflect.DeepEqual(pcf.LocalAddr(), c.LocalAddr()) { if !reflect.DeepEqual(pcf.LocalAddr(), c.LocalAddr()) {
t.Fatalf("LocalAddrs not equal: %#v != %#v", pcf.LocalAddr(), c.LocalAddr()) t.Fatalf("LocalAddrs not equal: %#v != %#v", pcf.LocalAddr(), c.LocalAddr())
} }
if listen {
if _, err := c.WriteTo([]byte{}, c.LocalAddr()); err != nil {
t.Fatalf("WriteTo failed: %v", err)
}
}
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
t.Fatalf("Close failed: %v", err) t.Fatalf("Close failed: %v", err)
} }
...@@ -98,7 +103,7 @@ func testFilePacketConnListen(t *testing.T, net, laddr string) { ...@@ -98,7 +103,7 @@ func testFilePacketConnListen(t *testing.T, net, laddr string) {
if err != nil { if err != nil {
t.Fatalf("Listen failed: %v", err) t.Fatalf("Listen failed: %v", err)
} }
testFilePacketConn(t, l.(packetConnFile)) testFilePacketConn(t, l.(packetConnFile), true)
if err := l.Close(); err != nil { if err := l.Close(); err != nil {
t.Fatalf("Close failed: %v", err) t.Fatalf("Close failed: %v", err)
} }
...@@ -109,7 +114,7 @@ func testFilePacketConnDial(t *testing.T, net, raddr string) { ...@@ -109,7 +114,7 @@ func testFilePacketConnDial(t *testing.T, net, raddr string) {
if err != nil { if err != nil {
t.Fatalf("Dial failed: %v", err) t.Fatalf("Dial failed: %v", err)
} }
testFilePacketConn(t, c.(packetConnFile)) testFilePacketConn(t, c.(packetConnFile), false)
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
t.Fatalf("Close failed: %v", err) t.Fatalf("Close failed: %v", err)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment