Commit e38fa916 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix TimeoutHandler data races; hold lock longer

The existing lock needed to be held longer. If a timeout occured
while writing (but after the guarded timeout check), the writes
would clobber a future connection's buffer.

Also remove a harmless warning by making Write also set the
flag that headers were sent (implicitly), so we don't try to
write headers later (a no-op + warning) on timeout after we've
started writing.

Fixes #8414
Fixes #8209

LGTM=ruiu, adg
R=adg, ruiu
CC=golang-codereviews
https://golang.org/cl/123610043
parent 339a24da
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"math/rand"
"net" "net"
. "net/http" . "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -1188,6 +1189,82 @@ func TestTimeoutHandler(t *testing.T) { ...@@ -1188,6 +1189,82 @@ func TestTimeoutHandler(t *testing.T) {
} }
} }
// See issues 8209 and 8414.
func TestTimeoutHandlerRace(t *testing.T) {
defer afterTest(t)
delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
ms, _ := strconv.Atoi(r.URL.Path[1:])
if ms == 0 {
ms = 1
}
for i := 0; i < ms; i++ {
w.Write([]byte("hi"))
time.Sleep(time.Millisecond)
}
})
ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
defer ts.Close()
var wg sync.WaitGroup
gate := make(chan bool, 10)
n := 50
if testing.Short() {
n = 10
gate = make(chan bool, 3)
}
for i := 0; i < n; i++ {
gate <- true
wg.Add(1)
go func() {
defer wg.Done()
defer func() { <-gate }()
res, err := Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
if err == nil {
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
}
}()
}
wg.Wait()
}
// See issues 8209 and 8414.
func TestTimeoutHandlerRaceHeader(t *testing.T) {
defer afterTest(t)
delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
w.WriteHeader(204)
})
ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, ""))
defer ts.Close()
var wg sync.WaitGroup
gate := make(chan bool, 50)
n := 500
if testing.Short() {
n = 10
}
for i := 0; i < n; i++ {
gate <- true
wg.Add(1)
go func() {
defer wg.Done()
defer func() { <-gate }()
res, err := Get(ts.URL)
if err != nil {
t.Error(err)
return
}
defer res.Body.Close()
io.Copy(ioutil.Discard, res.Body)
}()
}
wg.Wait()
}
// Verifies we don't path.Clean() on the wrong parts in redirects. // Verifies we don't path.Clean() on the wrong parts in redirects.
func TestRedirectMunging(t *testing.T) { func TestRedirectMunging(t *testing.T) {
req, _ := NewRequest("GET", "http://example.com/", nil) req, _ := NewRequest("GET", "http://example.com/", nil)
......
...@@ -1916,9 +1916,9 @@ func (tw *timeoutWriter) Header() Header { ...@@ -1916,9 +1916,9 @@ func (tw *timeoutWriter) Header() Header {
func (tw *timeoutWriter) Write(p []byte) (int, error) { func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock() tw.mu.Lock()
timedOut := tw.timedOut defer tw.mu.Unlock()
tw.mu.Unlock() tw.wroteHeader = true // implicitly at least
if timedOut { if tw.timedOut {
return 0, ErrHandlerTimeout return 0, ErrHandlerTimeout
} }
return tw.w.Write(p) return tw.w.Write(p)
...@@ -1926,12 +1926,11 @@ func (tw *timeoutWriter) Write(p []byte) (int, error) { ...@@ -1926,12 +1926,11 @@ func (tw *timeoutWriter) Write(p []byte) (int, error) {
func (tw *timeoutWriter) WriteHeader(code int) { func (tw *timeoutWriter) WriteHeader(code int) {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut || tw.wroteHeader { if tw.timedOut || tw.wroteHeader {
tw.mu.Unlock()
return return
} }
tw.wroteHeader = true tw.wroteHeader = true
tw.mu.Unlock()
tw.w.WriteHeader(code) tw.w.WriteHeader(code)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment