Commit 5fa3aeb1 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: check read and write deadlines before doing syscalls

Otherwise a fast sender or receiver can make sockets always
readable or writable, preventing deadline checks from ever
occuring.

Update #4191 (fixes it with other CL, coming separately)
Fixes #4403

R=golang-dev, alex.brainman, dave, mikioh.mikioh
CC=golang-dev
https://golang.org/cl/6851096
parent 314fd624
...@@ -423,6 +423,12 @@ func (fd *netFD) Read(p []byte) (n int, err error) { ...@@ -423,6 +423,12 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
} }
defer fd.decref() defer fd.decref()
for { for {
if fd.rdeadline > 0 {
if time.Now().UnixNano() >= fd.rdeadline {
err = errTimeout
break
}
}
n, err = syscall.Read(int(fd.sysfd), p) n, err = syscall.Read(int(fd.sysfd), p)
if err == syscall.EAGAIN { if err == syscall.EAGAIN {
err = errTimeout err = errTimeout
...@@ -453,6 +459,12 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { ...@@ -453,6 +459,12 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
} }
defer fd.decref() defer fd.decref()
for { for {
if fd.rdeadline > 0 {
if time.Now().UnixNano() >= fd.rdeadline {
err = errTimeout
break
}
}
n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
if err == syscall.EAGAIN { if err == syscall.EAGAIN {
err = errTimeout err = errTimeout
...@@ -481,6 +493,12 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S ...@@ -481,6 +493,12 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
} }
defer fd.decref() defer fd.decref()
for { for {
if fd.rdeadline > 0 {
if time.Now().UnixNano() >= fd.rdeadline {
err = errTimeout
break
}
}
n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
if err == syscall.EAGAIN { if err == syscall.EAGAIN {
err = errTimeout err = errTimeout
...@@ -512,6 +530,12 @@ func (fd *netFD) Write(p []byte) (int, error) { ...@@ -512,6 +530,12 @@ func (fd *netFD) Write(p []byte) (int, error) {
var err error var err error
nn := 0 nn := 0
for { for {
if fd.wdeadline > 0 {
if time.Now().UnixNano() >= fd.wdeadline {
err = errTimeout
break
}
}
var n int var n int
n, err = syscall.Write(int(fd.sysfd), p[nn:]) n, err = syscall.Write(int(fd.sysfd), p[nn:])
if n > 0 { if n > 0 {
...@@ -551,6 +575,12 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { ...@@ -551,6 +575,12 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
} }
defer fd.decref() defer fd.decref()
for { for {
if fd.wdeadline > 0 {
if time.Now().UnixNano() >= fd.wdeadline {
err = errTimeout
break
}
}
err = syscall.Sendto(fd.sysfd, p, 0, sa) err = syscall.Sendto(fd.sysfd, p, 0, sa)
if err == syscall.EAGAIN { if err == syscall.EAGAIN {
err = errTimeout err = errTimeout
...@@ -578,6 +608,12 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob ...@@ -578,6 +608,12 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
} }
defer fd.decref() defer fd.decref()
for { for {
if fd.wdeadline > 0 {
if time.Now().UnixNano() >= fd.wdeadline {
err = errTimeout
break
}
}
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
if err == syscall.EAGAIN { if err == syscall.EAGAIN {
err = errTimeout err = errTimeout
......
...@@ -6,11 +6,24 @@ package net ...@@ -6,11 +6,24 @@ package net
import ( import (
"fmt" "fmt"
"io"
"io/ioutil"
"runtime" "runtime"
"testing" "testing"
"time" "time"
) )
func isTimeout(err error) bool {
e, ok := err.(Error)
return ok && e.Timeout()
}
type copyRes struct {
n int64
err error
d time.Duration
}
func testTimeout(t *testing.T, net, addr string, readFrom bool) { func testTimeout(t *testing.T, net, addr string, readFrom bool) {
c, err := Dial(net, addr) c, err := Dial(net, addr)
if err != nil { if err != nil {
...@@ -230,3 +243,191 @@ func TestReadWriteDeadline(t *testing.T) { ...@@ -230,3 +243,191 @@ func TestReadWriteDeadline(t *testing.T) {
<-quit <-quit
<-lnquit <-lnquit
} }
type neverEnding byte
func (b neverEnding) Read(p []byte) (n int, err error) {
for i := range p {
p[i] = byte(b)
}
return len(p), nil
}
func TestVariousDeadlines1Proc(t *testing.T) {
testVariousDeadlines(t, 1)
}
func TestVariousDeadlines4Proc(t *testing.T) {
testVariousDeadlines(t, 4)
}
func testVariousDeadlines(t *testing.T, maxProcs int) {
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
ln := newLocalListener(t)
defer ln.Close()
donec := make(chan struct{})
defer close(donec)
testsDone := func() bool {
select {
case <-donec:
return true
}
return false
}
// The server, with no timeouts of its own, sending bytes to clients
// as fast as it can.
servec := make(chan copyRes)
go func() {
for {
c, err := ln.Accept()
if err != nil {
if !testsDone() {
t.Fatalf("Accept: %v", err)
}
return
}
go func() {
t0 := time.Now()
n, err := io.Copy(c, neverEnding('a'))
d := time.Since(t0)
c.Close()
servec <- copyRes{n, err, d}
}()
}
}()
for _, timeout := range []time.Duration{
1 * time.Nanosecond,
2 * time.Nanosecond,
5 * time.Nanosecond,
50 * time.Nanosecond,
100 * time.Nanosecond,
200 * time.Nanosecond,
500 * time.Nanosecond,
750 * time.Nanosecond,
1 * time.Microsecond,
5 * time.Microsecond,
25 * time.Microsecond,
250 * time.Microsecond,
500 * time.Microsecond,
1 * time.Millisecond,
5 * time.Millisecond,
100 * time.Millisecond,
250 * time.Millisecond,
500 * time.Millisecond,
1 * time.Second,
} {
numRuns := 3
if testing.Short() {
numRuns = 1
if timeout > 500*time.Microsecond {
continue
}
}
for run := 0; run < numRuns; run++ {
name := fmt.Sprintf("%v run %d/%d", timeout, run+1, numRuns)
t.Log(name)
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
clientc := make(chan copyRes)
go func() {
t0 := time.Now()
c.SetDeadline(t0.Add(timeout))
n, err := io.Copy(ioutil.Discard, c)
d := time.Since(t0)
c.Close()
clientc <- copyRes{n, err, d}
}()
const tooLong = 2000 * time.Millisecond
select {
case res := <-clientc:
if isTimeout(res.err) {
t.Logf("for %v, good client timeout after %v, reading %d bytes", name, res.d, res.n)
} else {
t.Fatalf("for %v: client Copy = %d, %v (want timeout)", name, res.n, res.err)
}
case <-time.After(tooLong):
t.Fatalf("for %v: timeout (%v) waiting for client to timeout (%v) reading", name, tooLong, timeout)
}
select {
case res := <-servec:
t.Logf("for %v: server in %v wrote %d, %v", name, res.d, res.n, res.err)
case <-time.After(tooLong):
t.Fatalf("for %v, timeout waiting for server to finish writing", name)
}
}
}
}
// TestReadDeadlineDataAvailable tests that read deadlines work, even
// if there's data ready to be read.
func TestReadDeadlineDataAvailable(t *testing.T) {
ln := newLocalListener(t)
defer ln.Close()
servec := make(chan copyRes)
const msg = "data client shouldn't read, even though it it'll be waiting"
go func() {
c, err := ln.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
defer c.Close()
n, err := c.Write([]byte(msg))
servec <- copyRes{n: int64(n), err: err}
}()
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c.Close()
if res := <-servec; res.err != nil || res.n != int64(len(msg)) {
t.Fatalf("unexpected server Write: n=%d, err=%d; want n=%d, err=nil", res.n, res.err, len(msg))
}
c.SetReadDeadline(time.Now().Add(-5 * time.Second)) // in the psat.
buf := make([]byte, len(msg)/2)
n, err := c.Read(buf)
if n > 0 || !isTimeout(err) {
t.Fatalf("client read = %d (%q) err=%v; want 0, timeout", n, buf[:n], err)
}
}
// TestWriteDeadlineBufferAvailable tests that write deadlines work, even
// if there's buffer space available to write.
func TestWriteDeadlineBufferAvailable(t *testing.T) {
ln := newLocalListener(t)
defer ln.Close()
servec := make(chan copyRes)
go func() {
c, err := ln.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
defer c.Close()
c.SetWriteDeadline(time.Now().Add(-5 * time.Second)) // in the past
n, err := c.Write([]byte{'x'})
servec <- copyRes{n: int64(n), err: err}
}()
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c.Close()
res := <-servec
if res.n != 0 {
t.Errorf("Write = %d; want 0", res.n)
}
if !isTimeout(res.err) {
t.Errorf("Write error = %v; want timeout", res.err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment