Commit 5e03c84a authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: update http2 bundle to rev d62542

Updates to use new client pool abstraction.

Change-Id: I3552018038ee8394d313d3253af337b07be211f6
Reviewed-on: https://go-review.googlesource.com/16730Reviewed-by: default avatarBlake Mizerany <blake.mizerany@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent e884334b
...@@ -34,6 +34,102 @@ import ( ...@@ -34,6 +34,102 @@ import (
"time" "time"
) )
// ClientConnPool manages a pool of HTTP/2 client connections.
type http2ClientConnPool interface {
GetClientConn(req *Request, addr string) (*http2ClientConn, error)
MarkDead(*http2ClientConn)
}
type http2clientConnPool struct {
t *http2Transport
mu sync.Mutex // TODO: switch to RWMutex
// TODO: add support for sharing conns based on cert names
// (e.g. share conn for googleapis.com and appspot.com)
conns map[string][]*http2ClientConn // key is host:port
keys map[*http2ClientConn][]string
}
func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
return p.getClientConn(req, addr, true)
}
func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
p.mu.Lock()
for _, cc := range p.conns[addr] {
if cc.CanTakeNewRequest() {
p.mu.Unlock()
return cc, nil
}
}
p.mu.Unlock()
if !dialOnMiss {
return nil, http2ErrNoCachedConn
}
cc, err := p.t.dialClientConn(addr)
if err != nil {
return nil, err
}
p.addConn(addr, cc)
return cc, nil
}
func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) {
p.mu.Lock()
defer p.mu.Unlock()
if p.conns == nil {
p.conns = make(map[string][]*http2ClientConn)
}
if p.keys == nil {
p.keys = make(map[*http2ClientConn][]string)
}
p.conns[key] = append(p.conns[key], cc)
p.keys[cc] = append(p.keys[cc], key)
}
func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) {
p.mu.Lock()
defer p.mu.Unlock()
for _, key := range p.keys[cc] {
vv, ok := p.conns[key]
if !ok {
continue
}
newList := http2filterOutClientConn(vv, cc)
if len(newList) > 0 {
p.conns[key] = newList
} else {
delete(p.conns, key)
}
}
delete(p.keys, cc)
}
func (p *http2clientConnPool) closeIdleConnections() {
p.mu.Lock()
defer p.mu.Unlock()
for _, vv := range p.conns {
for _, cc := range vv {
cc.closeIfIdle()
}
}
}
func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn {
out := in[:0]
for _, v := range in {
if v != exclude {
out = append(out, v)
}
}
if len(in) != len(out) {
in[len(in)-1] = nil
}
return out
}
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
type http2ErrCode uint32 type http2ErrCode uint32
...@@ -3503,20 +3599,33 @@ type http2Transport struct { ...@@ -3503,20 +3599,33 @@ type http2Transport struct {
// tls.Client. If nil, the default configuration is used. // tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
// TODO: switch to RWMutex // ConnPool optionally specifies an alternate connection pool to use.
// TODO: add support for sharing conns based on cert names // If nil, the default is used.
// (e.g. share conn for googleapis.com and appspot.com) ConnPool http2ClientConnPool
connMu sync.Mutex
conns map[string][]*http2clientConn // key is host:port connPoolOnce sync.Once
connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
}
func (t *http2Transport) connPool() http2ClientConnPool {
t.connPoolOnce.Do(t.initConnPool)
return t.connPoolOrDef
}
func (t *http2Transport) initConnPool() {
if t.ConnPool != nil {
t.connPoolOrDef = t.ConnPool
} else {
t.connPoolOrDef = &http2clientConnPool{t: t}
}
} }
// clientConn is the state of a single HTTP/2 client connection to an // ClientConn is the state of a single HTTP/2 client connection to an
// HTTP/2 server. // HTTP/2 server.
type http2clientConn struct { type http2ClientConn struct {
t *http2Transport t *http2Transport
tconn net.Conn tconn net.Conn // usually *tls.Conn, except specialized impls
tlsState *tls.ConnectionState tlsState *tls.ConnectionState // nil only for specialized impls
connKey []string // key(s) this connection is cached in, in t.conns
// readLoop goroutine fields: // readLoop goroutine fields:
readerDone chan struct{} // closed on error readerDone chan struct{} // closed on error
...@@ -3548,7 +3657,7 @@ type http2clientConn struct { ...@@ -3548,7 +3657,7 @@ type http2clientConn struct {
// clientStream is the state for a single HTTP/2 stream. One of these // clientStream is the state for a single HTTP/2 stream. One of these
// is created for each Transport.RoundTrip call. // is created for each Transport.RoundTrip call.
type http2clientStream struct { type http2clientStream struct {
cc *http2clientConn cc *http2ClientConn
ID uint32 ID uint32
resc chan http2resAndError resc chan http2resAndError
bufPipe http2pipe // buffered pipe with the flow-controlled response payload bufPipe http2pipe // buffered pipe with the flow-controlled response payload
...@@ -3600,24 +3709,28 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { ...@@ -3600,24 +3709,28 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) {
return t.RoundTripOpt(req, http2RoundTripOpt{}) return t.RoundTripOpt(req, http2RoundTripOpt{})
} }
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func http2authorityAddr(authority string) (addr string) {
if _, _, err := net.SplitHostPort(authority); err == nil {
return authority
}
return net.JoinHostPort(authority, "443")
}
// RoundTripOpt is like RoundTrip, but takes options. // RoundTripOpt is like RoundTrip, but takes options.
func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) { func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) {
if req.URL.Scheme != "https" { if req.URL.Scheme != "https" {
return nil, errors.New("http2: unsupported scheme") return nil, errors.New("http2: unsupported scheme")
} }
host, port, err := net.SplitHostPort(req.URL.Host) addr := http2authorityAddr(req.URL.Host)
if err != nil {
host = req.URL.Host
port = "443"
}
for { for {
cc, err := t.getClientConn(host, port, opt.OnlyCachedConn) cc, err := t.connPool().GetClientConn(req, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
res, err := cc.roundTrip(req) res, err := cc.RoundTrip(req)
if http2shouldRetryRequest(err) { if http2shouldRetryRequest(err) {
continue continue
} }
...@@ -3632,12 +3745,8 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res ...@@ -3632,12 +3745,8 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res
// connected from previous requests but are now sitting idle. // connected from previous requests but are now sitting idle.
// It does not interrupt any connections currently in use. // It does not interrupt any connections currently in use.
func (t *http2Transport) CloseIdleConnections() { func (t *http2Transport) CloseIdleConnections() {
t.connMu.Lock() if cp, ok := t.connPool().(*http2clientConnPool); ok {
defer t.connMu.Unlock() cp.closeIdleConnections()
for _, vv := range t.conns {
for _, cc := range vv {
cc.closeIfIdle()
}
} }
} }
...@@ -3648,95 +3757,16 @@ func http2shouldRetryRequest(err error) bool { ...@@ -3648,95 +3757,16 @@ func http2shouldRetryRequest(err error) bool {
return err == http2errClientConnClosed return err == http2errClientConnClosed
} }
func (t *http2Transport) removeClientConn(cc *http2clientConn) { func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) {
t.connMu.Lock() host, _, err := net.SplitHostPort(addr)
defer t.connMu.Unlock()
for _, key := range cc.connKey {
vv, ok := t.conns[key]
if !ok {
continue
}
newList := http2filterOutClientConn(vv, cc)
if len(newList) > 0 {
t.conns[key] = newList
} else {
delete(t.conns, key)
}
}
}
func http2filterOutClientConn(in []*http2clientConn, exclude *http2clientConn) []*http2clientConn {
out := in[:0]
for _, v := range in {
if v != exclude {
out = append(out, v)
}
}
if len(in) != len(out) {
in[len(in)-1] = nil
}
return out
}
// AddIdleConn adds c as an idle conn for Transport.
// It assumes that c has not yet exchanged SETTINGS frames.
// The addr maybe be either "host" or "host:port".
func (t *http2Transport) AddIdleConn(addr string, c *tls.Conn) error {
var key string
_, _, err := net.SplitHostPort(addr)
if err == nil {
key = addr
} else {
key = addr + ":443"
}
cc, err := t.newClientConn(key, c)
if err != nil {
return err
}
t.addConn(key, cc)
return nil
}
func (t *http2Transport) addConn(key string, cc *http2clientConn) {
t.connMu.Lock()
defer t.connMu.Unlock()
if t.conns == nil {
t.conns = make(map[string][]*http2clientConn)
}
t.conns[key] = append(t.conns[key], cc)
}
func (t *http2Transport) getClientConn(host, port string, onlyCached bool) (*http2clientConn, error) {
key := net.JoinHostPort(host, port)
t.connMu.Lock()
for _, cc := range t.conns[key] {
if cc.canTakeNewRequest() {
t.connMu.Unlock()
return cc, nil
}
}
t.connMu.Unlock()
if onlyCached {
return nil, http2ErrNoCachedConn
}
cc, err := t.dialClientConn(host, port, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.addConn(key, cc) tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
return cc, nil
}
func (t *http2Transport) dialClientConn(host, port, key string) (*http2clientConn, error) {
tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return t.newClientConn(key, tconn) return t.NewClientConn(tconn)
} }
func (t *http2Transport) newTLSConfig(host string) *tls.Config { func (t *http2Transport) newTLSConfig(host string) *tls.Config {
...@@ -3779,15 +3809,14 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) ( ...@@ -3779,15 +3809,14 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (
return cn, nil return cn, nil
} }
func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2clientConn, error) { func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
if _, err := tconn.Write(http2clientPreface); err != nil { if _, err := c.Write(http2clientPreface); err != nil {
return nil, err return nil, err
} }
cc := &http2clientConn{ cc := &http2ClientConn{
t: t, t: t,
tconn: tconn, tconn: c,
connKey: []string{key},
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
nextStreamID: 1, nextStreamID: 1,
maxFrameSize: 16 << 10, maxFrameSize: 16 << 10,
...@@ -3798,15 +3827,15 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client ...@@ -3798,15 +3827,15 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client
cc.cond = sync.NewCond(&cc.mu) cc.cond = sync.NewCond(&cc.mu)
cc.flow.add(int32(http2initialWindowSize)) cc.flow.add(int32(http2initialWindowSize))
cc.bw = bufio.NewWriter(http2stickyErrWriter{tconn, &cc.werr}) cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr})
cc.br = bufio.NewReader(tconn) cc.br = bufio.NewReader(c)
cc.fr = http2NewFramer(cc.bw, cc.br) cc.fr = http2NewFramer(cc.bw, cc.br)
cc.henc = hpack.NewEncoder(&cc.hbuf) cc.henc = hpack.NewEncoder(&cc.hbuf)
type connectionStater interface { type connectionStater interface {
ConnectionState() tls.ConnectionState ConnectionState() tls.ConnectionState
} }
if cs, ok := tconn.(connectionStater); ok { if cs, ok := c.(connectionStater); ok {
state := cs.ConnectionState() state := cs.ConnectionState()
cc.tlsState = &state cc.tlsState = &state
} }
...@@ -3852,13 +3881,13 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client ...@@ -3852,13 +3881,13 @@ func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2client
return cc, nil return cc, nil
} }
func (cc *http2clientConn) setGoAway(f *http2GoAwayFrame) { func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
cc.goAway = f cc.goAway = f
} }
func (cc *http2clientConn) canTakeNewRequest() bool { func (cc *http2ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
return cc.goAway == nil && return cc.goAway == nil &&
...@@ -3866,7 +3895,7 @@ func (cc *http2clientConn) canTakeNewRequest() bool { ...@@ -3866,7 +3895,7 @@ func (cc *http2clientConn) canTakeNewRequest() bool {
cc.nextStreamID < 2147483647 cc.nextStreamID < 2147483647
} }
func (cc *http2clientConn) closeIfIdle() { func (cc *http2ClientConn) closeIfIdle() {
cc.mu.Lock() cc.mu.Lock()
if len(cc.streams) > 0 { if len(cc.streams) > 0 {
cc.mu.Unlock() cc.mu.Unlock()
...@@ -3885,7 +3914,7 @@ const http2maxAllocFrameSize = 512 << 10 ...@@ -3885,7 +3914,7 @@ const http2maxAllocFrameSize = 512 << 10
// They're capped at the min of the peer's max frame size or 512KB // They're capped at the min of the peer's max frame size or 512KB
// (kinda arbitrarily), but definitely capped so we don't allocate 4GB // (kinda arbitrarily), but definitely capped so we don't allocate 4GB
// bufers. // bufers.
func (cc *http2clientConn) frameScratchBuffer() []byte { func (cc *http2ClientConn) frameScratchBuffer() []byte {
cc.mu.Lock() cc.mu.Lock()
size := cc.maxFrameSize size := cc.maxFrameSize
if size > http2maxAllocFrameSize { if size > http2maxAllocFrameSize {
...@@ -3902,7 +3931,7 @@ func (cc *http2clientConn) frameScratchBuffer() []byte { ...@@ -3902,7 +3931,7 @@ func (cc *http2clientConn) frameScratchBuffer() []byte {
return make([]byte, size) return make([]byte, size)
} }
func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) { func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate. const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
...@@ -3919,7 +3948,7 @@ func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) { ...@@ -3919,7 +3948,7 @@ func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) {
} }
func (cc *http2clientConn) roundTrip(req *Request) (*Response, error) { func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
cc.mu.Lock() cc.mu.Lock()
if cc.closed { if cc.closed {
...@@ -4086,7 +4115,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int32) (taken int32, err ...@@ -4086,7 +4115,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int32) (taken int32, err
} }
// requires cc.mu be held. // requires cc.mu be held.
func (cc *http2clientConn) encodeHeaders(req *Request) []byte { func (cc *http2ClientConn) encodeHeaders(req *Request) []byte {
cc.hbuf.Reset() cc.hbuf.Reset()
host := req.Host host := req.Host
...@@ -4111,7 +4140,7 @@ func (cc *http2clientConn) encodeHeaders(req *Request) []byte { ...@@ -4111,7 +4140,7 @@ func (cc *http2clientConn) encodeHeaders(req *Request) []byte {
return cc.hbuf.Bytes() return cc.hbuf.Bytes()
} }
func (cc *http2clientConn) writeHeader(name, value string) { func (cc *http2ClientConn) writeHeader(name, value string) {
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
} }
...@@ -4121,7 +4150,7 @@ type http2resAndError struct { ...@@ -4121,7 +4150,7 @@ type http2resAndError struct {
} }
// requires cc.mu be held. // requires cc.mu be held.
func (cc *http2clientConn) newStream() *http2clientStream { func (cc *http2ClientConn) newStream() *http2clientStream {
cs := &http2clientStream{ cs := &http2clientStream{
cc: cc, cc: cc,
ID: cc.nextStreamID, ID: cc.nextStreamID,
...@@ -4137,7 +4166,7 @@ func (cc *http2clientConn) newStream() *http2clientStream { ...@@ -4137,7 +4166,7 @@ func (cc *http2clientConn) newStream() *http2clientStream {
return cs return cs
} }
func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStream { func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
cs := cc.streams[id] cs := cc.streams[id]
...@@ -4149,7 +4178,7 @@ func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStr ...@@ -4149,7 +4178,7 @@ func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStr
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
type http2clientConnReadLoop struct { type http2clientConnReadLoop struct {
cc *http2clientConn cc *http2ClientConn
activeRes map[uint32]*http2clientStream // keyed by streamID activeRes map[uint32]*http2clientStream // keyed by streamID
// continueStreamID is the stream ID we're waiting for // continueStreamID is the stream ID we're waiting for
...@@ -4165,7 +4194,7 @@ type http2clientConnReadLoop struct { ...@@ -4165,7 +4194,7 @@ type http2clientConnReadLoop struct {
} }
// readLoop runs in its own goroutine and reads and dispatches frames. // readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *http2clientConn) readLoop() { func (cc *http2ClientConn) readLoop() {
rl := &http2clientConnReadLoop{ rl := &http2clientConnReadLoop{
cc: cc, cc: cc,
activeRes: make(map[uint32]*http2clientStream), activeRes: make(map[uint32]*http2clientStream),
...@@ -4185,7 +4214,7 @@ func (cc *http2clientConn) readLoop() { ...@@ -4185,7 +4214,7 @@ func (cc *http2clientConn) readLoop() {
func (rl *http2clientConnReadLoop) cleanup() { func (rl *http2clientConnReadLoop) cleanup() {
cc := rl.cc cc := rl.cc
defer cc.tconn.Close() defer cc.tconn.Close()
defer cc.t.removeClientConn(cc) defer cc.t.connPool().MarkDead(cc)
defer close(cc.readerDone) defer close(cc.readerDone)
err := cc.readerErr err := cc.readerErr
...@@ -4397,7 +4426,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { ...@@ -4397,7 +4426,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
cc := rl.cc cc := rl.cc
cc.t.removeClientConn(cc) cc.t.connPool().MarkDead(cc)
if f.ErrCode != 0 { if f.ErrCode != 0 {
cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
...@@ -4472,7 +4501,7 @@ func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) ...@@ -4472,7 +4501,7 @@ func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame)
return http2ConnectionError(http2ErrCodeProtocol) return http2ConnectionError(http2ErrCodeProtocol)
} }
func (cc *http2clientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
cc.wmu.Lock() cc.wmu.Lock()
cc.fr.WriteRSTStream(streamID, code) cc.fr.WriteRSTStream(streamID, code)
...@@ -4511,11 +4540,11 @@ func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) { ...@@ -4511,11 +4540,11 @@ func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) {
} }
} }
func (cc *http2clientConn) logf(format string, args ...interface{}) { func (cc *http2ClientConn) logf(format string, args ...interface{}) {
cc.t.logf(format, args...) cc.t.logf(format, args...)
} }
func (cc *http2clientConn) vlogf(format string, args ...interface{}) { func (cc *http2ClientConn) vlogf(format string, args ...interface{}) {
cc.t.vlogf(format, args...) cc.t.vlogf(format, args...)
} }
......
...@@ -25,9 +25,27 @@ import ( ...@@ -25,9 +25,27 @@ import (
"time" "time"
) )
// h2DefaultTransport is the HTTP/2 version of DefaultTransport. // HTTP/2 transport, integrated with the DefaultTransport.
// DefaultTransport and h2DefaultTransport are wired up together. var (
var h2DefaultTransport = &http2Transport{} // h2ConnPool is the connection pool for HTTP/2 connections.
h2ConnPool = &http2clientConnPool{}
// h2Transport is the HTTP/2 version of DefaultTransport.
h2Transport = &http2Transport{ConnPool: noDialClientConnPool{h2ConnPool}}
)
func init() {
h2ConnPool.t = h2Transport // avoid decalaration loop
}
// noDialClientConnPool is an implementation of http2.ClientConnPool
// which never dials. We let the HTTP/1.1 client dial and use its TLS
// connection instead.
type noDialClientConnPool struct{ *http2clientConnPool }
func (p noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
const doDial = false
return p.getClientConn(req, addr, doDial)
}
// DefaultTransport is the default implementation of Transport and is // DefaultTransport is the default implementation of Transport and is
// used by DefaultClient. It establishes network connections as needed // used by DefaultClient. It establishes network connections as needed
...@@ -50,24 +68,33 @@ func init() { ...@@ -50,24 +68,33 @@ func init() {
return return
} }
t := DefaultTransport.(*Transport) t := DefaultTransport.(*Transport)
t.RegisterProtocol("https", noDialH2Transport{h2DefaultTransport}) t.RegisterProtocol("https", noDialH2RoundTripper{})
t.TLSClientConfig = &tls.Config{ t.TLSClientConfig = &tls.Config{
NextProtos: []string{"h2"}, NextProtos: []string{"h2"},
} }
t.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{ t.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{
"h2": func(authority string, c *tls.Conn) RoundTripper { "h2": func(authority string, c *tls.Conn) RoundTripper {
h2DefaultTransport.AddIdleConn(authority, c) cc, err := h2Transport.NewClientConn(c)
return h2DefaultTransport if err != nil {
c.Close()
return erringRoundTripper{err}
}
h2ConnPool.addConn(http2authorityAddr(authority), cc)
return h2Transport
}, },
} }
} }
// noDialH2Transport is a RoundTripper which only tries to complete the request if type erringRoundTripper struct{ err error }
// the wrapped *http2Transport already has a cached connection to the host.
type noDialH2Transport struct{ rt *http2Transport } func (rt erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err }
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
// if there's already has a cached connection to the host.
type noDialH2RoundTripper struct{}
func (t noDialH2Transport) RoundTrip(req *Request) (*Response, error) { func (noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
res, err := t.rt.RoundTripOpt(req, http2RoundTripOpt{OnlyCachedConn: true}) res, err := h2Transport.RoundTrip(req)
if err == http2ErrNoCachedConn { if err == http2ErrNoCachedConn {
return nil, ErrSkipAltProtocol return nil, ErrSkipAltProtocol
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment