Commit c40a73d8 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make hidden http2 Transport respect remaining Transport fields

Updates x/net/http2 to git rev 72aa00c6 for https://golang.org/cl/18721
(but actually at https://golang.org/cl/18722 now)

Fixes #14008

Change-Id: If05d5ad51ec0ba5ba7e4fe16605c0a83f0484bc8
Reviewed-on: https://go-review.googlesource.com/18723
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 3208d92b
......@@ -47,7 +47,7 @@ const (
h2Mode = true
)
func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest {
func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
cst := &clientServerTest{
t: t,
h2: h2,
......@@ -55,6 +55,16 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest {
tr: &Transport{},
}
cst.c = &Client{Transport: cst.tr}
for _, opt := range opts {
switch opt := opt.(type) {
case func(*Transport):
opt(cst.tr)
default:
t.Fatalf("unhandled option type %T", opt)
}
}
if !h2 {
cst.ts = httptest.NewServer(h)
return cst
......@@ -139,6 +149,7 @@ type h12Compare struct {
Handler func(ResponseWriter, *Request) // required
ReqFunc reqFunc // optional
CheckResponse func(proto string, res *Response) // optional
Opts []interface{}
}
func (tt h12Compare) reqFunc() reqFunc {
......@@ -149,9 +160,9 @@ func (tt h12Compare) reqFunc() reqFunc {
}
func (tt h12Compare) run(t *testing.T) {
cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler))
cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
defer cst1.close()
cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler))
cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
defer cst2.close()
res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
......@@ -380,6 +391,20 @@ func TestH12_AutoGzip(t *testing.T) {
}.run(t)
}
func TestH12_AutoGzip_Disabled(t *testing.T) {
h12Compare{
Opts: []interface{}{
func(tr *Transport) { tr.DisableCompression = true },
},
Handler: func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
}
},
}.run(t)
}
// Test304Responses verifies that 304s don't declare that they're
// chunking in their response headers and aren't allowed to produce
// output.
......
......@@ -24,7 +24,6 @@ import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/net/http2/hpack"
"io"
"io/ioutil"
"log"
......@@ -38,6 +37,8 @@ import (
"strings"
"sync"
"time"
"golang.org/x/net/http2/hpack"
)
// ClientConnPool manages a pool of HTTP/2 client connections.
......@@ -248,7 +249,11 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [
func http2configureTransport(t1 *Transport) (*http2Transport, error) {
connPool := new(http2clientConnPool)
t2 := &http2Transport{ConnPool: http2noDialClientConnPool{connPool}}
t2 := &http2Transport{
ConnPool: http2noDialClientConnPool{connPool},
t1: t1,
}
connPool.t = t2
if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil {
return nil, err
}
......@@ -2184,6 +2189,19 @@ func http2bodyAllowedForStatus(status int) bool {
return true
}
type http2httpError struct {
msg string
timeout bool
}
func (e *http2httpError) Error() string { return e.msg }
func (e *http2httpError) Timeout() bool { return e.timeout }
func (e *http2httpError) Temporary() bool { return true }
var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true}
// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
// io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered)
......@@ -4320,6 +4338,11 @@ type http2Transport struct {
// to mean no limit.
MaxHeaderListSize uint32
// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
t1 *Transport
connPoolOnce sync.Once
connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
}
......@@ -4335,11 +4358,7 @@ func (t *http2Transport) maxHeaderListSize() uint32 {
}
func (t *http2Transport) disableCompression() bool {
if t.DisableCompression {
return true
}
return false
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}
var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
......@@ -4395,7 +4414,7 @@ type http2ClientConn struct {
henc *hpack.Encoder
freeBuf [][]byte
wmu sync.Mutex // held while writing; acquire AFTER wmu if holding both
wmu sync.Mutex // held while writing; acquire AFTER mu if holding both
werr error // first write error that has occurred
}
......@@ -4413,7 +4432,7 @@ type http2clientStream struct {
inflow http2flow // guarded by cc.mu
bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
readErr error // sticky read error; owned by transportResponseBody.Read
stopReqBody bool // stop writing req body; guarded by cc.mu
stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
peerReset chan struct{} // closed on peer reset
resetErr error // populated before peerReset is closed
......@@ -4456,10 +4475,13 @@ func (cs *http2clientStream) checkReset() error {
}
}
func (cs *http2clientStream) abortRequestBodyWrite() {
func (cs *http2clientStream) abortRequestBodyWrite(err error) {
if err == nil {
panic("nil error")
}
cc := cs.cc
cc.mu.Lock()
cs.stopReqBody = true
cs.stopReqBody = err
cc.cond.Broadcast()
cc.mu.Unlock()
}
......@@ -4598,6 +4620,12 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (
return cn, nil
}
// disableKeepAlives reports whether connections should be closed as
// soon as possible after handling the first request.
func (t *http2Transport) disableKeepAlives() bool {
return t.t1 != nil && t.t1.DisableKeepAlives
}
func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
if http2VerboseLogs {
t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
......@@ -4692,7 +4720,7 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool {
}
func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
return cc.goAway == nil &&
return cc.goAway == nil && !cc.closed &&
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
cc.nextStreamID < 2147483647
}
......@@ -4772,6 +4800,14 @@ func http2commaSeparatedTrailers(req *Request) (string, error) {
return "", nil
}
func (cc *http2ClientConn) responseHeaderTimeout() time.Duration {
if cc.t.t1 != nil {
return cc.t.t1.ResponseHeaderTimeout
}
return 0
}
func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
trailers, err := http2commaSeparatedTrailers(req)
if err != nil {
......@@ -4832,24 +4868,32 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
return nil, werr
}
var respHeaderTimer <-chan time.Time
var bodyCopyErrc chan error // result of body copy
if hasBody {
bodyCopyErrc = make(chan error, 1)
go func() {
bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
}()
} else {
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
}
}
readLoopResCh := cs.resc
requestCanceledCh := http2requestCancel(req)
requestCanceled := false
bodyWritten := false
for {
select {
case re := <-readLoopResCh:
res := re.res
if re.err != nil || res.StatusCode > 299 {
cs.abortRequestBodyWrite()
cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
}
if re.err != nil {
cc.forgetStreamID(cs.ID)
......@@ -4858,32 +4902,35 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
res.Request = req
res.TLS = cc.tlsState
return res, nil
case <-respHeaderTimer:
cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
} else {
cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
}
return nil, http2errTimeout
case <-requestCanceledCh:
cc.forgetStreamID(cs.ID)
cs.abortRequestBodyWrite()
if !hasBody {
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
return nil, http2errRequestCanceled
} else {
cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
}
requestCanceled = true
requestCanceledCh = nil
readLoopResCh = nil
return nil, http2errRequestCanceled
case <-cs.peerReset:
if requestCanceled {
return nil, http2errRequestCanceled
}
return nil, cs.resetErr
case err := <-bodyCopyErrc:
if requestCanceled {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
return nil, http2errRequestCanceled
}
if err != nil {
return nil, err
}
bodyWritten = true
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
}
}
}
}
......@@ -4916,9 +4963,14 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []
return cc.werr
}
// errAbortReqBodyWrite is an internal error value.
// It doesn't escape to callers.
var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write")
// internal error values; they don't escape to callers
var (
// abort request body write; don't send cancel
http2errStopReqBodyWrite = errors.New("http2: aborting request body write")
// abort request body write, but send stream reset of cancel.
http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
)
func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
cc := cs.cc
......@@ -4951,7 +5003,13 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos
for len(remain) > 0 && err == nil {
var allowed int32
allowed, err = cs.awaitFlowControl(len(remain))
if err != nil {
switch {
case err == http2errStopReqBodyWrite:
return err
case err == http2errStopReqBodyWriteAndCancel:
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
return err
case err != nil:
return err
}
cc.wmu.Lock()
......@@ -5005,8 +5063,8 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er
if cc.closed {
return 0, http2errClientConnClosed
}
if cs.stopReqBody {
return 0, http2errAbortReqBodyWrite
if cs.stopReqBody != nil {
return 0, cs.stopReqBody
}
if err := cs.checkReset(); err != nil {
return 0, err
......@@ -5074,7 +5132,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
cc.writeHeader(lowKey, v)
}
}
if contentLength >= 0 {
if http2shouldSendReqContentLength(req.Method, contentLength) {
cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
......@@ -5086,6 +5144,27 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
return cc.hbuf.Bytes()
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func http2shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
// requires cc.mu be held.
func (cc *http2ClientConn) encodeTrailers(req *Request) []byte {
cc.hbuf.Reset()
......@@ -5204,6 +5283,8 @@ func (rl *http2clientConnReadLoop) cleanup() {
func (rl *http2clientConnReadLoop) run() error {
cc := rl.cc
closeWhenIdle := cc.t.disableKeepAlives()
gotReply := false
for {
f, err := cc.fr.ReadFrame()
if err != nil {
......@@ -5218,18 +5299,25 @@ func (rl *http2clientConnReadLoop) run() error {
if http2VerboseLogs {
cc.vlogf("http2: Transport received %s", http2summarizeFrame(f))
}
maybeIdle := false
switch f := f.(type) {
case *http2HeadersFrame:
err = rl.processHeaders(f)
maybeIdle = true
gotReply = true
case *http2ContinuationFrame:
err = rl.processContinuation(f)
maybeIdle = true
case *http2DataFrame:
err = rl.processData(f)
maybeIdle = true
case *http2GoAwayFrame:
err = rl.processGoAway(f)
maybeIdle = true
case *http2RSTStreamFrame:
err = rl.processResetStream(f)
maybeIdle = true
case *http2SettingsFrame:
err = rl.processSettings(f)
case *http2PushPromiseFrame:
......@@ -5244,6 +5332,9 @@ func (rl *http2clientConnReadLoop) run() error {
if err != nil {
return err
}
if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
cc.closeIfIdle()
}
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment