Commit 022504b3 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix when server deadlines get extended

Deadlines should be extended at the beginning of
a request, not at the beginning of a connection.

Fixes #4676

R=golang-dev, fullung, patrick.allen.higgins, adg
CC=golang-dev
https://golang.org/cl/7220076
parent 49fe632c
......@@ -256,28 +256,20 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
}
func TestServerTimeouts(t *testing.T) {
// TODO(bradfitz): convert this to use httptest.Server
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error: %v", err)
}
addr, _ := l.Addr().(*net.TCPAddr)
reqNum := 0
handler := HandlerFunc(func(res ResponseWriter, req *Request) {
ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
reqNum++
fmt.Fprintf(res, "req=%d", reqNum)
})
server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond}
go server.Serve(l)
url := fmt.Sprintf("http://%s/", addr)
}))
ts.Config.ReadTimeout = 250 * time.Millisecond
ts.Config.WriteTimeout = 250 * time.Millisecond
ts.Start()
defer ts.Close()
// Hit the HTTP server successfully.
tr := &Transport{DisableKeepAlives: true} // they interfere with this test
c := &Client{Transport: tr}
r, err := c.Get(url)
r, err := c.Get(ts.URL)
if err != nil {
t.Fatalf("http Get #1: %v", err)
}
......@@ -290,13 +282,13 @@ func TestServerTimeouts(t *testing.T) {
// Slow client that should timeout.
t1 := time.Now()
conn, err := net.Dial("tcp", addr.String())
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
buf := make([]byte, 1)
n, err := conn.Read(buf)
latency := time.Now().Sub(t1)
latency := time.Since(t1)
if n != 0 || err != io.EOF {
t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
}
......@@ -307,7 +299,7 @@ func TestServerTimeouts(t *testing.T) {
// Hit the HTTP server successfully again, verifying that the
// previous slow connection didn't run our handler. (that we
// get "req=2", not "req=3")
r, err = Get(url)
r, err = Get(ts.URL)
if err != nil {
t.Fatalf("http Get #2: %v", err)
}
......@@ -317,7 +309,21 @@ func TestServerTimeouts(t *testing.T) {
t.Errorf("Get #2 got %q, want %q", string(got), expected)
}
l.Close()
if !testing.Short() {
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
go io.Copy(ioutil.Discard, conn)
for i := 0; i < 5; i++ {
_, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
if err != nil {
t.Fatalf("on write %d: %v", i, err)
}
time.Sleep(ts.Config.ReadTimeout / 2)
}
}
}
// TestIdentityResponse verifies that a handler can unset
......
......@@ -416,6 +416,16 @@ func (c *conn) readRequest() (w *response, err error) {
if c.hijacked() {
return nil, ErrHijacked
}
if d := c.server.ReadTimeout; d != 0 {
c.rwc.SetReadDeadline(time.Now().Add(d))
}
if d := c.server.WriteTimeout; d != 0 {
defer func() {
c.rwc.SetWriteDeadline(time.Now().Add(d))
}()
}
c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
var req *Request
if req, err = ReadRequest(c.buf.Reader); err != nil {
......@@ -779,6 +789,12 @@ func (c *conn) serve() {
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
if d := c.server.ReadTimeout; d != 0 {
c.rwc.SetReadDeadline(time.Now().Add(d))
}
if d := c.server.WriteTimeout; d != 0 {
c.rwc.SetWriteDeadline(time.Now().Add(d))
}
if err := tlsConn.Handshake(); err != nil {
return
}
......@@ -1274,12 +1290,6 @@ func (srv *Server) Serve(l net.Listener) error {
return e
}
tempDelay = 0
if srv.ReadTimeout != 0 {
rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
}
if srv.WriteTimeout != 0 {
rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
}
c, err := srv.newConn(rw)
if err != nil {
continue
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment