Commit 194bbe84 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: update bundled http2

Adds tests for #122590 and updates x/net/http2 to git rev 6a8eb5e2b1 for:

     http2: call httptrace.ClientTrace.GetConn in Transport when needed
     https://golang.org/cl/122590

     http2: fire httptrace.ClientTrace.WroteHeaderField if set
     https://golang.org/cl/122816

     http2: compare Connection header value case-insensitively
     https://golang.org/cl/122588

This also includes the code for graceful shutdown, but it has no
public API surface via net/http, and should not affect any existing
code paths, as it's purely new stuff:

     http2: implement client initiated graceful shutdown
     https://golang.org/cl/30076

Fixes #19761
Fixes #23041

Change-Id: I5558a84591014554cad15ee3852a349ed717530f
Reviewed-on: https://go-review.googlesource.com/122591
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
parent fb59bcce
...@@ -725,9 +725,31 @@ const ( ...@@ -725,9 +725,31 @@ const (
http2noDialOnMiss = false http2noDialOnMiss = false
) )
// shouldTraceGetConn reports whether getClientConn should call any
// ClientTrace.GetConn hook associated with the http.Request.
//
// This complexity is needed to avoid double calls of the GetConn hook
// during the back-and-forth between net/http and x/net/http2 (when the
// net/http.Transport is upgraded to also speak http2), as well as support
// the case where x/net/http2 is being used directly.
func (p *http2clientConnPool) shouldTraceGetConn(st http2clientConnIdleState) bool {
// If our Transport wasn't made via ConfigureTransport, always
// trace the GetConn hook if provided, because that means the
// http2 package is being used directly and it's the one
// dialing, as opposed to net/http.
if _, ok := p.t.ConnPool.(http2noDialClientConnPool); !ok {
return true
}
// Otherwise, only use the GetConn hook if this connection has
// been used previously for other requests. For fresh
// connections, the net/http package does the dialing.
return !st.freshConn
}
func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
if http2isConnectionCloseRequest(req) && dialOnMiss { if http2isConnectionCloseRequest(req) && dialOnMiss {
// It gets its own connection. // It gets its own connection.
http2traceGetConn(req, addr)
const singleUse = true const singleUse = true
cc, err := p.t.dialClientConn(addr, singleUse) cc, err := p.t.dialClientConn(addr, singleUse)
if err != nil { if err != nil {
...@@ -737,7 +759,10 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis ...@@ -737,7 +759,10 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis
} }
p.mu.Lock() p.mu.Lock()
for _, cc := range p.conns[addr] { for _, cc := range p.conns[addr] {
if cc.CanTakeNewRequest() { if st := cc.idleState(); st.canTakeNewRequest {
if p.shouldTraceGetConn(st) {
http2traceGetConn(req, addr)
}
p.mu.Unlock() p.mu.Unlock()
return cc, nil return cc, nil
} }
...@@ -746,6 +771,7 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis ...@@ -746,6 +771,7 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis
p.mu.Unlock() p.mu.Unlock()
return nil, http2ErrNoCachedConn return nil, http2ErrNoCachedConn
} }
http2traceGetConn(req, addr)
call := p.getStartDialLocked(addr) call := p.getStartDialLocked(addr)
p.mu.Unlock() p.mu.Unlock()
<-call.done <-call.done
...@@ -2861,6 +2887,16 @@ func http2summarizeFrame(f http2Frame) string { ...@@ -2861,6 +2887,16 @@ func http2summarizeFrame(f http2Frame) string {
return buf.String() return buf.String()
} }
func http2traceHasWroteHeaderField(trace *http2clientTrace) bool {
return trace != nil && trace.WroteHeaderField != nil
}
func http2traceWroteHeaderField(trace *http2clientTrace, k, v string) {
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(k, []string{v})
}
}
func http2transportExpectContinueTimeout(t1 *Transport) time.Duration { func http2transportExpectContinueTimeout(t1 *Transport) time.Duration {
return t1.ExpectContinueTimeout return t1.ExpectContinueTimeout
} }
...@@ -2869,6 +2905,8 @@ type http2contextContext interface { ...@@ -2869,6 +2905,8 @@ type http2contextContext interface {
context.Context context.Context
} }
var http2errCanceled = context.Canceled
func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) { func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) {
ctx, cancel = context.WithCancel(context.Background()) ctx, cancel = context.WithCancel(context.Background())
ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
...@@ -2899,6 +2937,14 @@ func (t *http2Transport) idleConnTimeout() time.Duration { ...@@ -2899,6 +2937,14 @@ func (t *http2Transport) idleConnTimeout() time.Duration {
func http2setResponseUncompressed(res *Response) { res.Uncompressed = true } func http2setResponseUncompressed(res *Response) { res.Uncompressed = true }
func http2traceGetConn(req *Request, hostPort string) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GetConn == nil {
return
}
trace.GetConn(hostPort)
}
func http2traceGotConn(req *Request, cc *http2ClientConn) { func http2traceGotConn(req *Request, cc *http2ClientConn) {
trace := httptrace.ContextClientTrace(req.Context()) trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GotConn == nil { if trace == nil || trace.GotConn == nil {
...@@ -2956,6 +3002,11 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { ...@@ -2956,6 +3002,11 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error {
return cc.ping(ctx) return cc.ping(ctx)
} }
// Shutdown gracefully closes the client connection, waiting for running streams to complete.
func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
return cc.shutdown(ctx)
}
func http2cloneTLSConfig(c *tls.Config) *tls.Config { func http2cloneTLSConfig(c *tls.Config) *tls.Config {
c2 := c.Clone() c2 := c.Clone()
c2.GetClientCertificate = c.GetClientCertificate // golang.org/issue/19264 c2.GetClientCertificate = c.GetClientCertificate // golang.org/issue/19264
...@@ -6698,6 +6749,7 @@ type http2ClientConn struct { ...@@ -6698,6 +6749,7 @@ type http2ClientConn struct {
cond *sync.Cond // hold mu; broadcast on flow/closed changes cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow http2flow // our conn-level flow control quota (cs.flow is per stream) flow http2flow // our conn-level flow control quota (cs.flow is per stream)
inflow http2flow // peer's conn-level flow control inflow http2flow // peer's conn-level flow control
closing bool
closed bool closed bool
wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received
...@@ -7170,12 +7222,32 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool { ...@@ -7170,12 +7222,32 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool {
return cc.canTakeNewRequestLocked() return cc.canTakeNewRequestLocked()
} }
func (cc *http2ClientConn) canTakeNewRequestLocked() bool { // clientConnIdleState describes the suitability of a client
// connection to initiate a new RoundTrip request.
type http2clientConnIdleState struct {
canTakeNewRequest bool
freshConn bool // whether it's unused by any previous request
}
func (cc *http2ClientConn) idleState() http2clientConnIdleState {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.idleStateLocked()
}
func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) {
if cc.singleUse && cc.nextStreamID > 1 { if cc.singleUse && cc.nextStreamID > 1 {
return false return
} }
return cc.goAway == nil && !cc.closed && st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing &&
int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32 int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
return
}
func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
st := cc.idleStateLocked()
return st.canTakeNewRequest
} }
// onIdleTimeout is called from a time.AfterFunc goroutine. It will // onIdleTimeout is called from a time.AfterFunc goroutine. It will
...@@ -7205,6 +7277,88 @@ func (cc *http2ClientConn) closeIfIdle() { ...@@ -7205,6 +7277,88 @@ func (cc *http2ClientConn) closeIfIdle() {
cc.tconn.Close() cc.tconn.Close()
} }
var http2shutdownEnterWaitStateHook = func() {}
// Shutdown gracefully close the client connection, waiting for running streams to complete.
// Public implementation is in go17.go and not_go17.go
func (cc *http2ClientConn) shutdown(ctx http2contextContext) error {
if err := cc.sendGoAway(); err != nil {
return err
}
// Wait for all in-flight streams to complete or connection to close
done := make(chan error, 1)
cancelled := false // guarded by cc.mu
go func() {
cc.mu.Lock()
defer cc.mu.Unlock()
for {
if len(cc.streams) == 0 || cc.closed {
cc.closed = true
done <- cc.tconn.Close()
break
}
if cancelled {
break
}
cc.cond.Wait()
}
}()
http2shutdownEnterWaitStateHook()
select {
case err := <-done:
return err
case <-ctx.Done():
cc.mu.Lock()
// Free the goroutine above
cancelled = true
cc.cond.Broadcast()
cc.mu.Unlock()
return ctx.Err()
}
}
func (cc *http2ClientConn) sendGoAway() error {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.wmu.Lock()
defer cc.wmu.Unlock()
if cc.closing {
// GOAWAY sent already
return nil
}
// Send a graceful shutdown frame to server
maxStreamID := cc.nextStreamID
if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil {
return err
}
if err := cc.bw.Flush(); err != nil {
return err
}
// Prevent new requests
cc.closing = true
return nil
}
// Close closes the client connection immediately.
//
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *http2ClientConn) Close() error {
cc.mu.Lock()
defer cc.cond.Broadcast()
defer cc.mu.Unlock()
err := errors.New("http2: client connection force closed via ClientConn.Close")
for id, cs := range cc.streams {
select {
case cs.resc <- http2resAndError{err: err}:
default:
}
cs.bufPipe.CloseWithError(err)
delete(cc.streams, id)
}
cc.closed = true
return cc.tconn.Close()
}
const http2maxAllocFrameSize = 512 << 10 const http2maxAllocFrameSize = 512 << 10
// frameBuffer returns a scratch buffer suitable for writing DATA frames. // frameBuffer returns a scratch buffer suitable for writing DATA frames.
...@@ -7287,7 +7441,7 @@ func http2checkConnHeaders(req *Request) error { ...@@ -7287,7 +7441,7 @@ func http2checkConnHeaders(req *Request) error {
if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
} }
if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") { if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !strings.EqualFold(vv[0], "close") && !strings.EqualFold(vv[0], "keep-alive")) {
return fmt.Errorf("http2: invalid Connection request header: %q", vv) return fmt.Errorf("http2: invalid Connection request header: %q", vv)
} }
return nil return nil
...@@ -7831,9 +7985,16 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail ...@@ -7831,9 +7985,16 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
return nil, http2errRequestHeaderListSize return nil, http2errRequestHeaderListSize
} }
trace := http2requestTrace(req)
traceHeaders := http2traceHasWroteHeaderField(trace)
// Header list size is ok. Write the headers. // Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) { enumerateHeaders(func(name, value string) {
cc.writeHeader(strings.ToLower(name), value) name = strings.ToLower(name)
cc.writeHeader(name, value)
if traceHeaders {
http2traceWroteHeaderField(trace, name, value)
}
}) })
return cc.hbuf.Bytes(), nil return cc.hbuf.Bytes(), nil
......
...@@ -3782,8 +3782,12 @@ func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace( ...@@ -3782,8 +3782,12 @@ func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(
func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
defer afterTest(t) defer afterTest(t)
const resBody = "some body" const resBody = "some body"
gotWroteReqEvent := make(chan struct{}) gotWroteReqEvent := make(chan struct{}, 500)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method == "GET" {
// Do nothing for the second request.
return
}
if _, err := ioutil.ReadAll(r.Body); err != nil { if _, err := ioutil.ReadAll(r.Body); err != nil {
t.Error(err) t.Error(err)
} }
...@@ -3851,7 +3855,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { ...@@ -3851,7 +3855,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
Got100Continue: func() { logf("Got100Continue") }, Got100Continue: func() { logf("Got100Continue") },
WroteRequest: func(e httptrace.WroteRequestInfo) { WroteRequest: func(e httptrace.WroteRequestInfo) {
logf("WroteRequest: %+v", e) logf("WroteRequest: %+v", e)
close(gotWroteReqEvent) gotWroteReqEvent <- struct{}{}
}, },
} }
if h2 { if h2 {
...@@ -3934,6 +3938,28 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { ...@@ -3934,6 +3938,28 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
if t.Failed() { if t.Failed() {
t.Errorf("Output:\n%s", got) t.Errorf("Output:\n%s", got)
} }
// And do a second request:
req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
res, err = cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != 200 {
t.Fatal(res.Status)
}
res.Body.Close()
mu.Lock()
got = buf.String()
mu.Unlock()
sub := "Getting conn for dns-is-faked.golang:"
if gotn, want := strings.Count(got, sub), 2; gotn != want {
t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
}
} }
func TestTransportEventTraceTLSVerify(t *testing.T) { func TestTransportEventTraceTLSVerify(t *testing.T) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment