Commit 394842e2 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: add shutdown: TCPConn.CloseWrite and CloseRead

R=golang-dev, rsc, iant
CC=golang-dev
https://golang.org/cl/5136052
parent 260991ad
...@@ -358,6 +358,22 @@ func (fd *netFD) Close() os.Error { ...@@ -358,6 +358,22 @@ func (fd *netFD) Close() os.Error {
return nil return nil
} }
func (fd *netFD) CloseRead() os.Error {
if fd == nil || fd.sysfile == nil {
return os.EINVAL
}
syscall.Shutdown(fd.sysfd, syscall.SHUT_RD)
return nil
}
func (fd *netFD) CloseWrite() os.Error {
if fd == nil || fd.sysfile == nil {
return os.EINVAL
}
syscall.Shutdown(fd.sysfd, syscall.SHUT_WR)
return nil
}
func (fd *netFD) Read(p []byte) (n int, err os.Error) { func (fd *netFD) Read(p []byte) (n int, err os.Error) {
if fd == nil { if fd == nil {
return 0, os.EINVAL return 0, os.EINVAL
......
...@@ -312,6 +312,22 @@ func (fd *netFD) Close() os.Error { ...@@ -312,6 +312,22 @@ func (fd *netFD) Close() os.Error {
return nil return nil
} }
func (fd *netFD) CloseRead() os.Error {
if fd == nil || fd.sysfd == syscall.InvalidHandle {
return os.EINVAL
}
syscall.Shutdown(fd.sysfd, syscall.SHUT_RD)
return nil
}
func (fd *netFD) CloseWrite() os.Error {
if fd == nil || fd.sysfd == syscall.InvalidHandle {
return os.EINVAL
}
syscall.Shutdown(fd.sysfd, syscall.SHUT_WR)
return nil
}
// Read from network. // Read from network.
type readOp struct { type readOp struct {
......
...@@ -6,6 +6,7 @@ package net ...@@ -6,6 +6,7 @@ package net
import ( import (
"flag" "flag"
"os"
"regexp" "regexp"
"testing" "testing"
) )
...@@ -119,3 +120,46 @@ func TestReverseAddress(t *testing.T) { ...@@ -119,3 +120,46 @@ func TestReverseAddress(t *testing.T) {
} }
} }
} }
func TestShutdown(t *testing.T) {
l, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = Listen("tcp6", "[::1]:0"); err != nil {
t.Fatalf("ListenTCP on :0: %v", err)
}
}
go func() {
c, err := l.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
var buf [10]byte
n, err := c.Read(buf[:])
if n != 0 || err != os.EOF {
t.Fatalf("server Read = %d, %v; want 0, os.EOF", n, err)
}
c.Write([]byte("response"))
c.Close()
}()
c, err := Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c.Close()
err = c.(*TCPConn).CloseWrite()
if err != nil {
t.Fatalf("CloseWrite: %v", err)
}
var buf [10]byte
n, err := c.Read(buf[:])
if err != nil {
t.Fatalf("client Read: %d, %v", n, err)
}
got := string(buf[:n])
if got != "response" {
t.Errorf("read = %q, want \"response\"", got)
}
}
...@@ -100,6 +100,24 @@ func (c *TCPConn) Close() os.Error { ...@@ -100,6 +100,24 @@ func (c *TCPConn) Close() os.Error {
return err return err
} }
// CloseRead shuts down the reading side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseRead() os.Error {
if !c.ok() {
return os.EINVAL
}
return c.fd.CloseRead()
}
// CloseWrite shuts down the writing side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseWrite() os.Error {
if !c.ok() {
return os.EINVAL
}
return c.fd.CloseWrite()
}
// LocalAddr returns the local network address, a *TCPAddr. // LocalAddr returns the local network address, a *TCPAddr.
func (c *TCPConn) LocalAddr() Addr { func (c *TCPConn) LocalAddr() Addr {
if !c.ok() { if !c.ok() {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment