Commit 5e404b36 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Transport.Clone

Fixes #26013

Change-Id: I2c82bd90ea7ce6f7a8e5b6c460d3982dca681a93
Reviewed-on: https://go-review.googlesource.com/c/go/+/174597Reviewed-by: default avatarAndrew Bonventre <andybons@golang.org>
parent 9f513254
......@@ -261,13 +261,14 @@ type Transport struct {
// ReadBufferSize specifies the size of the read buffer used
// when reading from the transport.
//If zero, a default (currently 4KB) is used.
// If zero, a default (currently 4KB) is used.
ReadBufferSize int
// nextProtoOnce guards initialization of TLSNextProto and
// h2transport (via onceSetNextProtoDefaults)
nextProtoOnce sync.Once
h2transport h2Transport // non-nil if http2 wired up
nextProtoOnce sync.Once
h2transport h2Transport // non-nil if http2 wired up
tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired
// ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero
// TLSClientConfig or Dial, DialTLS or DialContext func is provided. By default, use of any those fields conservatively
......@@ -290,6 +291,40 @@ func (t *Transport) readBufferSize() int {
return 4 << 10
}
// Clone returns a deep copy of t's exported fields.
func (t *Transport) Clone() *Transport {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
t2 := &Transport{
Proxy: t.Proxy,
DialContext: t.DialContext,
Dial: t.Dial,
DialTLS: t.DialTLS,
TLSClientConfig: t.TLSClientConfig.Clone(),
TLSHandshakeTimeout: t.TLSHandshakeTimeout,
DisableKeepAlives: t.DisableKeepAlives,
DisableCompression: t.DisableCompression,
MaxIdleConns: t.MaxIdleConns,
MaxIdleConnsPerHost: t.MaxIdleConnsPerHost,
MaxConnsPerHost: t.MaxConnsPerHost,
IdleConnTimeout: t.IdleConnTimeout,
ResponseHeaderTimeout: t.ResponseHeaderTimeout,
ExpectContinueTimeout: t.ExpectContinueTimeout,
ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
ForceAttemptHTTP2: t.ForceAttemptHTTP2,
WriteBufferSize: t.WriteBufferSize,
ReadBufferSize: t.ReadBufferSize,
}
if !t.tlsNextProtoWasNil {
npm := map[string]func(authority string, c *tls.Conn) RoundTripper{}
for k, v := range t.TLSNextProto {
npm[k] = v
}
t2.TLSNextProto = npm
}
return t2
}
// h2Transport is the interface we expect to be able to call from
// net/http against an *http2.Transport that's either bundled into
// h2_bundle.go or supplied by the user via x/net/http2.
......@@ -303,6 +338,7 @@ type h2Transport interface {
// onceSetNextProtoDefaults initializes TLSNextProto.
// It must be called via t.nextProtoOnce.Do.
func (t *Transport) onceSetNextProtoDefaults() {
t.tlsNextProtoWasNil = (t.TLSNextProto == nil)
if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") {
return
}
......
......@@ -20,6 +20,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"go/token"
"internal/nettrace"
"internal/testenv"
"io"
......@@ -5320,3 +5321,53 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) {
})
}
}
func TestTransportClone(t *testing.T) {
tr := &Transport{
Proxy: func(*Request) (*url.URL, error) { panic("") },
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
Dial: func(network, addr string) (net.Conn, error) { panic("") },
DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
TLSClientConfig: new(tls.Config),
TLSHandshakeTimeout: time.Second,
DisableKeepAlives: true,
DisableCompression: true,
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
MaxConnsPerHost: 1,
IdleConnTimeout: time.Second,
ResponseHeaderTimeout: time.Second,
ExpectContinueTimeout: time.Second,
ProxyConnectHeader: Header{},
MaxResponseHeaderBytes: 1,
ForceAttemptHTTP2: true,
TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
},
ReadBufferSize: 1,
WriteBufferSize: 1,
}
tr2 := tr.Clone()
rv := reflect.ValueOf(tr2).Elem()
rt := rv.Type()
for i := 0; i < rt.NumField(); i++ {
sf := rt.Field(i)
if !token.IsExported(sf.Name) {
continue
}
if rv.Field(i).IsZero() {
t.Errorf("cloned field t2.%s is zero", sf.Name)
}
}
if _, ok := tr2.TLSNextProto["foo"]; !ok {
t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
}
// But test that a nil TLSNextProto is kept nil:
tr = new(Transport)
tr2 = tr.Clone()
if tr2.TLSNextProto != nil {
t.Errorf("Transport.TLSNextProto unexpected non-nil")
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment