Commit 94bf9a8d authored by Bryan C. Mills

net/http: fix wantConnQueue memory leaks in Transport

I'm trying to keep the code changes minimal for backporting to Go 1.13,
so it is still possible for a handful of entries to leak,
but the leaks are now O(1) instead of O(N) in the steady state.
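
For illustration only, here is a minimal, self-contained sketch of the pruning idea, using hypothetical waiter and queue types rather than the real net/http internals: each push first pops abandoned waiters from the front of the queue, so dead entries no longer pile up request after request.

package main

import "fmt"

// waiter stands in for a request goroutine that is "waiting" until its
// done channel is closed (for example by a timeout or cancellation).
type waiter struct {
    done chan struct{}
}

func (w *waiter) waiting() bool {
    select {
    case <-w.done:
        return false
    default:
        return true
    }
}

// queue is a simplified FIFO of waiters.
type queue struct {
    entries []*waiter
}

// cleanFront drops waiters from the front of the queue that are no longer
// waiting, mirroring the wantConnQueue.cleanFront added by this commit.
func (q *queue) cleanFront() {
    for len(q.entries) > 0 && !q.entries[0].waiting() {
        q.entries[0] = nil // let the abandoned waiter be collected
        q.entries = q.entries[1:]
    }
}

// enqueue prunes dead entries before appending, so the queue length stays
// proportional to the number of live waiters plus at most a few stragglers.
func (q *queue) enqueue(w *waiter) {
    q.cleanFront()
    q.entries = append(q.entries, w)
}

func main() {
    var q queue
    for i := 0; i < 1000; i++ {
        w := &waiter{done: make(chan struct{})}
        q.enqueue(w)
        close(w.done) // the request gives up immediately
    }
    // Without cleanFront the queue would hold 1000 dead entries;
    // with it, only the most recently pushed entry can linger.
    fmt.Println("queue length:", len(q.entries)) // prints 1
}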

Longer-term, I think it would be a good idea to coalesce idleMu with
connsPerHostMu and clear entries out of both queues as soon as their
goroutines are done waiting.
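
To make that suggestion concrete, a rough sketch under the same caveat (hypothetical pool and waiter types, with a single mutex standing in for a coalesced idleMu/connsPerHostMu; not a proposed patch): a waiter that stops waiting deletes its own queue entry immediately, so nothing is left behind for a later cleanFront-style sweep.

package main

import (
    "fmt"
    "sync"
)

// waiter stands in for a request goroutine waiting on a connection.
type waiter struct {
    key string
}

// pool sketches a Transport whose wait queues are guarded by one mutex,
// playing the role of a coalesced idleMu/connsPerHostMu.
type pool struct {
    mu    sync.Mutex
    waits map[string][]*waiter
}

// enqueue registers w as waiting for a connection for its key.
func (p *pool) enqueue(w *waiter) {
    p.mu.Lock()
    defer p.mu.Unlock()
    if p.waits == nil {
        p.waits = make(map[string][]*waiter)
    }
    p.waits[w.key] = append(p.waits[w.key], w)
}

// abandon removes w's entry the moment its goroutine stops waiting
// (timeout or cancellation), instead of leaving a stale entry to be
// swept on some later enqueue.
func (p *pool) abandon(w *waiter) {
    p.mu.Lock()
    defer p.mu.Unlock()
    q := p.waits[w.key]
    for i, other := range q {
        if other == w {
            p.waits[w.key] = append(q[:i], q[i+1:]...)
            break
        }
    }
}

func main() {
    p := &pool{}
    w := &waiter{key: "example.com:443"}
    p.enqueue(w)
    p.abandon(w) // e.g. the request's context was canceled
    fmt.Println("entries left for key:", len(p.waits[w.key])) // prints 0
}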

Fixes #33849
Fixes #33850

Change-Id: Ia66bc64671eb1014369f2d3a01debfc023b44281
Reviewed-on: https://go-review.googlesource.com/c/go/+/191964
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
parent c5f142fa
@@ -953,6 +953,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
        t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
    }
    q := t.idleConnWait[w.key]
    q.cleanFront()
    q.pushBack(w)
    t.idleConnWait[w.key] = q
    return false
@@ -1137,7 +1138,7 @@ func (q *wantConnQueue) pushBack(w *wantConn) {
    q.tail = append(q.tail, w)
}

// popFront removes and returns the wantConn at the front of the queue.
func (q *wantConnQueue) popFront() *wantConn {
    if q.headPos >= len(q.head) {
        if len(q.tail) == 0 {
@@ -1152,6 +1153,30 @@ func (q *wantConnQueue) popFront() *wantConn {
    return w
}

// peekFront returns the wantConn at the front of the queue without removing it.
func (q *wantConnQueue) peekFront() *wantConn {
    if q.headPos < len(q.head) {
        return q.head[q.headPos]
    }
    if len(q.tail) > 0 {
        return q.tail[0]
    }
    return nil
}

// cleanFront pops any wantConns that are no longer waiting from the head of the
// queue, reporting whether any were popped.
func (q *wantConnQueue) cleanFront() (cleaned bool) {
    for {
        w := q.peekFront()
        if w == nil || w.waiting() {
            return cleaned
        }
        q.popFront()
        cleaned = true
    }
}

// getConn dials and creates a new persistConn to the target as
// specified in the connectMethod. This includes doing a proxy CONNECT
// and/or setting up TLS. If this doesn't return an error, the persistConn
@@ -1261,6 +1286,7 @@ func (t *Transport) queueForDial(w *wantConn) {
        t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
    }
    q := t.connsPerHostWait[w.key]
    q.cleanFront()
    q.pushBack(w)
    t.connsPerHostWait[w.key] = q
}

@@ -1658,6 +1658,176 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) {
    }
}

// A countedConn is a net.Conn that decrements an atomic counter when finalized.
type countedConn struct {
    net.Conn
}

// A countingDialer dials connections and counts the number that remain reachable.
type countingDialer struct {
    dialer      net.Dialer
    mu          sync.Mutex
    total, live int64
}

func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
    conn, err := d.dialer.DialContext(ctx, network, address)
    if err != nil {
        return nil, err
    }

    counted := new(countedConn)
    counted.Conn = conn

    d.mu.Lock()
    defer d.mu.Unlock()
    d.total++
    d.live++

    runtime.SetFinalizer(counted, d.decrement)
    return counted, nil
}

func (d *countingDialer) decrement(*countedConn) {
    d.mu.Lock()
    defer d.mu.Unlock()
    d.live--
}

func (d *countingDialer) Read() (total, live int64) {
    d.mu.Lock()
    defer d.mu.Unlock()
    return d.total, d.live
}

func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
    defer afterTest(t)
    ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
        // Close every connection so that it cannot be kept alive.
        conn, _, err := w.(Hijacker).Hijack()
        if err != nil {
            t.Errorf("Hijack failed unexpectedly: %v", err)
            return
        }
        conn.Close()
    }))
    defer ts.Close()

    var d countingDialer
    c := ts.Client()
    c.Transport.(*Transport).DialContext = d.DialContext

    body := []byte("Hello")
    for i := 0; ; i++ {
        total, live := d.Read()
        if live < total {
            break
        }
        if i >= 1<<12 {
            t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
        }

        req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
        if err != nil {
            t.Fatal(err)
        }
        _, err = c.Do(req)
        if err == nil {
            t.Fatal("expected broken connection")
        }

        runtime.GC()
    }
}

type countedContext struct {
    context.Context
}

type contextCounter struct {
    mu   sync.Mutex
    live int64
}

func (cc *contextCounter) Track(ctx context.Context) context.Context {
    counted := new(countedContext)
    counted.Context = ctx
    cc.mu.Lock()
    defer cc.mu.Unlock()
    cc.live++
    runtime.SetFinalizer(counted, cc.decrement)
    return counted
}

func (cc *contextCounter) decrement(*countedContext) {
    cc.mu.Lock()
    defer cc.mu.Unlock()
    cc.live--
}

func (cc *contextCounter) Read() (live int64) {
    cc.mu.Lock()
    defer cc.mu.Unlock()
    return cc.live
}

func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
    defer afterTest(t)
    ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
        runtime.Gosched()
        w.WriteHeader(StatusOK)
    }))
    defer ts.Close()

    c := ts.Client()
    c.Transport.(*Transport).MaxConnsPerHost = 1

    ctx := context.Background()
    body := []byte("Hello")
    doPosts := func(cc *contextCounter) {
        var wg sync.WaitGroup
        for n := 64; n > 0; n-- {
            wg.Add(1)
            go func() {
                defer wg.Done()

                ctx := cc.Track(ctx)
                req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
                if err != nil {
                    t.Error(err)
                }

                _, err = c.Do(req.WithContext(ctx))
                if err != nil {
                    t.Errorf("Do failed with error: %v", err)
                }
            }()
        }
        wg.Wait()
    }

    var initialCC contextCounter
    doPosts(&initialCC)

    // flushCC exists only to put pressure on the GC to finalize the initialCC
    // contexts: the flushCC allocations should eventually displace the initialCC
    // allocations.
    var flushCC contextCounter
    for i := 0; ; i++ {
        live := initialCC.Read()
        if live == 0 {
            break
        }
        if i >= 100 {
            t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
        }
        doPosts(&flushCC)
        runtime.GC()
    }
}

// This used to crash; https://golang.org/issue/3266
func TestTransportIdleConnCrash(t *testing.T) {
    defer afterTest(t)