Commit 43b9fcf6 authored by Michael Fraenkel's avatar Michael Fraenkel Committed by Brad Fitzpatrick

net/http: make Transport.MaxConnsPerHost work for HTTP/2

Treat HTTP/2 connections as an ongoing persistent connection. When we
are told there is no cached connections, cleanup the associated
connection and host connection count.

Fixes #27753

Change-Id: I6b7bd915fc7819617cb5d3b35e46e225c75eda29
Reviewed-on: https://go-review.googlesource.com/c/go/+/140357Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent c706d422
......@@ -506,8 +506,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
t.putOrCloseIdleConn(pconn)
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
......@@ -515,7 +515,10 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
if err == nil {
return resp, nil
}
if !pconn.shouldRetryRequest(req, err) {
if http2isNoCachedConnError(err) {
t.removeIdleConn(pconn)
t.decHostConnCount(cm.key()) // clean up the persistent connection
} else if !pconn.shouldRetryRequest(req, err) {
// Issue 16465: return underlying net.Conn.Read error from peek,
// as we've historically done.
if e, ok := err.(transportReadFromServerError); ok {
......@@ -778,9 +781,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if pconn.isBroken() {
return errConnBroken
}
if pconn.alt != nil {
return errNotCachingH2Conn
}
pconn.markReused()
key := pconn.cacheKey
......@@ -829,7 +829,10 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if pconn.idleTimer != nil {
pconn.idleTimer.Reset(t.IdleConnTimeout)
} else {
pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
// idleTimer does not apply to HTTP/2
if pconn.alt == nil {
pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
}
}
}
pconn.idleAt = time.Now()
......@@ -1377,7 +1380,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
return &persistConn{cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
}
}
......
......@@ -588,6 +588,107 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
<-reqComplete
}
func TestTransportMaxConnsPerHost(t *testing.T) {
defer afterTest(t)
if runtime.GOOS == "js" {
t.Skipf("skipping test on js/wasm")
}
h := HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("foo"))
if err != nil {
t.Fatalf("Write: %v", err)
}
})
testMaxConns := func(scheme string, ts *httptest.Server) {
defer ts.Close()
c := ts.Client()
tr := c.Transport.(*Transport)
tr.MaxConnsPerHost = 1
if err := ExportHttp2ConfigureTransport(tr); err != nil {
t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
}
connCh := make(chan net.Conn, 1)
var dialCnt, gotConnCnt, tlsHandshakeCnt int32
tr.Dial = func(network, addr string) (net.Conn, error) {
atomic.AddInt32(&dialCnt, 1)
c, err := net.Dial(network, addr)
connCh <- c
return c, err
}
doReq := func() {
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
if !connInfo.Reused {
atomic.AddInt32(&gotConnCnt, 1)
}
},
TLSHandshakeStart: func() {
atomic.AddInt32(&tlsHandshakeCnt, 1)
},
}
req, _ := NewRequest("GET", ts.URL, nil)
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
resp, err := c.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read body failed: %v", err)
}
}
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
doReq()
}()
}
wg.Wait()
expected := int32(tr.MaxConnsPerHost)
if dialCnt != expected {
t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
}
if gotConnCnt != expected {
t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
}
if ts.TLS != nil && tlsHandshakeCnt != expected {
t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
}
(<-connCh).Close()
doReq()
expected++
if dialCnt != expected {
t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
}
if gotConnCnt != expected {
t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
}
if ts.TLS != nil && tlsHandshakeCnt != expected {
t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
}
}
testMaxConns("http", httptest.NewServer(h))
testMaxConns("https", httptest.NewTLSServer(h))
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
ts.StartTLS()
testMaxConns("http2", ts)
}
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
setParallel(t)
defer afterTest(t)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment