Commit eb55a0c8 authored by Gabriel Rosenhouse's avatar Gabriel Rosenhouse Committed by Brad Fitzpatrick

net/http: add DialTLSContext hook to Transport

Fixes #21526

Change-Id: I2f8215cd671641cddfa8499f8a8c0130db93dbc6
Reviewed-on: https://go-review.googlesource.com/c/go/+/61291Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
parent c9a4b01f
...@@ -142,15 +142,24 @@ type Transport struct { ...@@ -142,15 +142,24 @@ type Transport struct {
// If both are set, DialContext takes priority. // If both are set, DialContext takes priority.
Dial func(network, addr string) (net.Conn, error) Dial func(network, addr string) (net.Conn, error)
// DialTLS specifies an optional dial function for creating // DialTLSContext specifies an optional dial function for creating
// TLS connections for non-proxied HTTPS requests. // TLS connections for non-proxied HTTPS requests.
// //
// If DialTLS is nil, Dial and TLSClientConfig are used. // If DialTLSContext is nil (and the deprecated DialTLS below is also nil),
// DialContext and TLSClientConfig are used.
// //
// If DialTLS is set, the Dial hook is not used for HTTPS // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS
// requests and the TLSClientConfig and TLSHandshakeTimeout // requests and the TLSClientConfig and TLSHandshakeTimeout
// are ignored. The returned net.Conn is assumed to already be // are ignored. The returned net.Conn is assumed to already be
// past the TLS handshake. // past the TLS handshake.
DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// DialTLS specifies an optional dial function for creating
// TLS connections for non-proxied HTTPS requests.
//
// Deprecated: Use DialTLSContext instead, which allows the transport
// to cancel dials as soon as they are no longer needed.
// If both are set, DialTLSContext takes priority.
DialTLS func(network, addr string) (net.Conn, error) DialTLS func(network, addr string) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with // TLSClientConfig specifies the TLS configuration to use with
...@@ -286,6 +295,7 @@ func (t *Transport) Clone() *Transport { ...@@ -286,6 +295,7 @@ func (t *Transport) Clone() *Transport {
DialContext: t.DialContext, DialContext: t.DialContext,
Dial: t.Dial, Dial: t.Dial,
DialTLS: t.DialTLS, DialTLS: t.DialTLS,
DialTLSContext: t.DialTLSContext,
TLSHandshakeTimeout: t.TLSHandshakeTimeout, TLSHandshakeTimeout: t.TLSHandshakeTimeout,
DisableKeepAlives: t.DisableKeepAlives, DisableKeepAlives: t.DisableKeepAlives,
DisableCompression: t.DisableCompression, DisableCompression: t.DisableCompression,
...@@ -324,6 +334,10 @@ type h2Transport interface { ...@@ -324,6 +334,10 @@ type h2Transport interface {
CloseIdleConnections() CloseIdleConnections()
} }
func (t *Transport) hasCustomTLSDialer() bool {
return t.DialTLS != nil || t.DialTLSContext != nil
}
// onceSetNextProtoDefaults initializes TLSNextProto. // onceSetNextProtoDefaults initializes TLSNextProto.
// It must be called via t.nextProtoOnce.Do. // It must be called via t.nextProtoOnce.Do.
func (t *Transport) onceSetNextProtoDefaults() { func (t *Transport) onceSetNextProtoDefaults() {
...@@ -352,7 +366,7 @@ func (t *Transport) onceSetNextProtoDefaults() { ...@@ -352,7 +366,7 @@ func (t *Transport) onceSetNextProtoDefaults() {
// Transport. // Transport.
return return
} }
if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) { if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) {
// Be conservative and don't automatically enable // Be conservative and don't automatically enable
// http2 if they've specified a custom TLS config or // http2 if they've specified a custom TLS config or
// custom dialers. Let them opt-in themselves via // custom dialers. Let them opt-in themselves via
...@@ -1185,6 +1199,18 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { ...@@ -1185,6 +1199,18 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) {
} }
} }
func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if t.DialTLSContext != nil {
conn, err = t.DialTLSContext(ctx, network, addr)
} else {
conn, err = t.DialTLS(network, addr)
}
if conn == nil && err == nil {
err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)")
}
return
}
// getConn dials and creates a new persistConn to the target as // getConn dials and creates a new persistConn to the target as
// specified in the connectMethod. This includes doing a proxy CONNECT // specified in the connectMethod. This includes doing a proxy CONNECT
// and/or setting up TLS. If this doesn't return an error, the persistConn // and/or setting up TLS. If this doesn't return an error, the persistConn
...@@ -1435,15 +1461,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers ...@@ -1435,15 +1461,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
} }
return err return err
} }
if cm.scheme() == "https" && t.DialTLS != nil { if cm.scheme() == "https" && t.hasCustomTLSDialer() {
var err error var err error
pconn.conn, err = t.DialTLS("tcp", cm.addr()) pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr())
if err != nil { if err != nil {
return nil, wrapErr(err) return nil, wrapErr(err)
} }
if pconn.conn == nil {
return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)"))
}
if tc, ok := pconn.conn.(*tls.Conn); ok { if tc, ok := pconn.conn.(*tls.Conn); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below // Handshake here, in case DialTLS didn't. TLSNextProto below
// depends on it for knowing the connection state. // depends on it for knowing the connection state.
......
...@@ -3506,6 +3506,90 @@ func TestTransportDialTLS(t *testing.T) { ...@@ -3506,6 +3506,90 @@ func TestTransportDialTLS(t *testing.T) {
} }
} }
func TestTransportDialContext(t *testing.T) {
setParallel(t)
defer afterTest(t)
var mu sync.Mutex // guards following
var gotReq bool
var receivedContext context.Context
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
gotReq = true
mu.Unlock()
}))
defer ts.Close()
c := ts.Client()
c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
mu.Lock()
receivedContext = ctx
mu.Unlock()
return net.Dial(netw, addr)
}
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
ctx := context.WithValue(context.Background(), "some-key", "some-value")
res, err := c.Do(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
}
res.Body.Close()
mu.Lock()
if !gotReq {
t.Error("didn't get request")
}
if receivedContext != ctx {
t.Error("didn't receive correct context")
}
}
func TestTransportDialTLSContext(t *testing.T) {
setParallel(t)
defer afterTest(t)
var mu sync.Mutex // guards following
var gotReq bool
var receivedContext context.Context
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
gotReq = true
mu.Unlock()
}))
defer ts.Close()
c := ts.Client()
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
mu.Lock()
receivedContext = ctx
mu.Unlock()
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
if err != nil {
return nil, err
}
return c, c.Handshake()
}
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
ctx := context.WithValue(context.Background(), "some-key", "some-value")
res, err := c.Do(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
}
res.Body.Close()
mu.Lock()
if !gotReq {
t.Error("didn't get request")
}
if receivedContext != ctx {
t.Error("didn't receive correct context")
}
}
// Test for issue 8755 // Test for issue 8755
// Ensure that if a proxy returns an error, it is exposed by RoundTrip // Ensure that if a proxy returns an error, it is exposed by RoundTrip
func TestRoundTripReturnsProxyError(t *testing.T) { func TestRoundTripReturnsProxyError(t *testing.T) {
...@@ -5577,6 +5661,7 @@ func TestTransportClone(t *testing.T) { ...@@ -5577,6 +5661,7 @@ func TestTransportClone(t *testing.T) {
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
Dial: func(network, addr string) (net.Conn, error) { panic("") }, Dial: func(network, addr string) (net.Conn, error) { panic("") },
DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
TLSClientConfig: new(tls.Config), TLSClientConfig: new(tls.Config),
TLSHandshakeTimeout: time.Second, TLSHandshakeTimeout: time.Second,
DisableKeepAlives: true, DisableKeepAlives: true,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment