Commit 5694ebf0 authored by Colby Ranger's avatar Colby Ranger Committed by Brad Fitzpatrick

net/http/httputil: Clean up ReverseProxy maxLatencyWriter goroutines.

When FlushInterval is specified on ReverseProxy, the ResponseWriter is
wrapped with a maxLatencyWriter that periodically flushes in a
goroutine. That goroutine was not being cleaned up at the end of the
request. This resulted in a panic when Flush() was being called on a
ResponseWriter that was closed.

The code was updated to always send the done message to the flushLoop()
goroutine after copying the body. Futhermore, the code was refactored to
allow the test to verify the maxLatencyWriter behavior.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/6033043
parent 6742d0a0
......@@ -17,6 +17,10 @@ import (
"time"
)
// beforeCopyResponse is a callback set by tests to intercept the state of the
// output io.Writer before the data is copied to it.
var beforeCopyResponse func(dst io.Writer)
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
......@@ -112,20 +116,32 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
return
}
defer res.Body.Close()
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
p.copyResponse(rw, res.Body)
}
if res.Body != nil {
var dst io.Writer = rw
if p.FlushInterval != 0 {
if wf, ok := rw.(writeFlusher); ok {
dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
io.Copy(dst, res.Body)
}
if beforeCopyResponse != nil {
beforeCopyResponse(dst)
}
io.Copy(dst, src)
}
type writeFlusher interface {
......@@ -137,22 +153,14 @@ type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
lk sync.Mutex // protects init of done, as well Write + Flush
lk sync.Mutex // protects Write + Flush
done chan bool
}
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
if m.done == nil {
m.done = make(chan bool)
go m.flushLoop()
}
n, err = m.dst.Write(p)
if err != nil {
m.done <- true
}
return
return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
......@@ -160,13 +168,15 @@ func (m *maxLatencyWriter) flushLoop() {
defer t.Stop()
for {
select {
case <-m.done:
return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
case <-m.done:
return
}
}
panic("unreached")
}
func (m *maxLatencyWriter) stop() { m.done <- true }
......@@ -7,11 +7,14 @@
package httputil
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"testing"
"time"
)
func TestReverseProxy(t *testing.T) {
......@@ -107,3 +110,58 @@ func TestReverseProxyQuery(t *testing.T) {
frontend.Close()
}
}
func TestReverseProxyFlushInterval(t *testing.T) {
if testing.Short() {
return
}
const expected = "hi"
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(expected))
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxyHandler := NewSingleHostReverseProxy(backendURL)
proxyHandler.FlushInterval = time.Microsecond
dstChan := make(chan io.Writer, 1)
beforeCopyResponse = func(dst io.Writer) { dstChan <- dst }
defer func() { beforeCopyResponse = nil }()
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
initGoroutines := runtime.NumGoroutine()
for i := 0; i < 100; i++ {
req, _ := http.NewRequest("GET", frontend.URL, nil)
req.Close = true
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Get: %v", err)
}
if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
t.Errorf("got body %q; expected %q", bodyBytes, expected)
}
select {
case dst := <-dstChan:
if _, ok := dst.(*maxLatencyWriter); !ok {
t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{})
}
default:
t.Error("maxLatencyWriter Write() was never called")
}
res.Body.Close()
}
// Allow up to 50 additional goroutines over 100 requests.
if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 {
t.Errorf("grew %d goroutines; leak?", delta)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment