Commit 1a058cd0 authored by Michael Fraenkel's avatar Michael Fraenkel Committed by Brad Fitzpatrick

net/http: only decrement connection count if we removed a connection

The connection count must only be decremented if the persistent
connection was also removed.

Fixes #34941

Change-Id: I5070717d5d9effec78016005fa4910593500c8cf
Reviewed-on: https://go-review.googlesource.com/c/go/+/202087Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent e7ce8627
...@@ -545,8 +545,9 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { ...@@ -545,8 +545,9 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
_, isH2DialError := pconn.alt.(http2erringRoundTripper) _, isH2DialError := pconn.alt.(http2erringRoundTripper)
if http2isNoCachedConnError(err) || isH2DialError { if http2isNoCachedConnError(err) || isH2DialError {
t.removeIdleConn(pconn) if t.removeIdleConn(pconn) {
t.decConnsPerHost(pconn.cacheKey) t.decConnsPerHost(pconn.cacheKey)
}
} }
if !pconn.shouldRetryRequest(req, err) { if !pconn.shouldRetryRequest(req, err) {
// Issue 16465: return underlying net.Conn.Read error from peek, // Issue 16465: return underlying net.Conn.Read error from peek,
...@@ -958,26 +959,28 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { ...@@ -958,26 +959,28 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
} }
// removeIdleConn marks pconn as dead. // removeIdleConn marks pconn as dead.
func (t *Transport) removeIdleConn(pconn *persistConn) { func (t *Transport) removeIdleConn(pconn *persistConn) bool {
t.idleMu.Lock() t.idleMu.Lock()
defer t.idleMu.Unlock() defer t.idleMu.Unlock()
t.removeIdleConnLocked(pconn) return t.removeIdleConnLocked(pconn)
} }
// t.idleMu must be held. // t.idleMu must be held.
func (t *Transport) removeIdleConnLocked(pconn *persistConn) { func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
if pconn.idleTimer != nil { if pconn.idleTimer != nil {
pconn.idleTimer.Stop() pconn.idleTimer.Stop()
} }
t.idleLRU.remove(pconn) t.idleLRU.remove(pconn)
key := pconn.cacheKey key := pconn.cacheKey
pconns := t.idleConn[key] pconns := t.idleConn[key]
var removed bool
switch len(pconns) { switch len(pconns) {
case 0: case 0:
// Nothing // Nothing
case 1: case 1:
if pconns[0] == pconn { if pconns[0] == pconn {
delete(t.idleConn, key) delete(t.idleConn, key)
removed = true
} }
default: default:
for i, v := range pconns { for i, v := range pconns {
...@@ -988,9 +991,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) { ...@@ -988,9 +991,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) {
// conns at the end. // conns at the end.
copy(pconns[i:], pconns[i+1:]) copy(pconns[i:], pconns[i+1:])
t.idleConn[key] = pconns[:len(pconns)-1] t.idleConn[key] = pconns[:len(pconns)-1]
removed = true
break break
} }
} }
return removed
} }
func (t *Transport) setReqCanceler(r *Request, fn func(error)) { func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
......
...@@ -5893,3 +5893,59 @@ func TestDontCacheBrokenHTTP2Conn(t *testing.T) { ...@@ -5893,3 +5893,59 @@ func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
t.Errorf("GotConn calls = %v; want %v", got, want) t.Errorf("GotConn calls = %v; want %v", got, want)
} }
} }
// Issue 34941
// When the client has too many concurrent requests on a single connection,
// http.http2noCachedConnError is reported on multiple requests. There should
// only be one decrement regardless of the number of failures.
func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
defer afterTest(t)
h := HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("foo"))
if err != nil {
t.Fatalf("Write: %v", err)
}
})
ts := httptest.NewUnstartedServer(h)
ts.EnableHTTP2 = true
ts.StartTLS()
defer ts.Close()
c := ts.Client()
tr := c.Transport.(*Transport)
tr.MaxConnsPerHost = 1
if err := ExportHttp2ConfigureTransport(tr); err != nil {
t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
}
errCh := make(chan error, 300)
doReq := func() {
resp, err := c.Get(ts.URL)
if err != nil {
errCh <- fmt.Errorf("request failed: %v", err)
return
}
defer resp.Body.Close()
_, err = ioutil.ReadAll(resp.Body)
if err != nil {
errCh <- fmt.Errorf("read body failed: %v", err)
}
}
var wg sync.WaitGroup
for i := 0; i < 300; i++ {
wg.Add(1)
go func() {
defer wg.Done()
doReq()
}()
}
wg.Wait()
close(errCh)
for err := range errCh {
t.Errorf("error occurred: %v", err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment