Commit ae47e044 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Transport.DialTLS hook

Per discussions out of https://golang.org/cl/128930043/
and golang-nuts threads and with agl.

Fixes #8522

LGTM=agl, adg
R=agl, c, adg
CC=c, golang-codereviews
https://golang.org/cl/137940043
parent 13d0b82b
...@@ -43,8 +43,8 @@ var DefaultTransport RoundTripper = &Transport{ ...@@ -43,8 +43,8 @@ var DefaultTransport RoundTripper = &Transport{
// MaxIdleConnsPerHost. // MaxIdleConnsPerHost.
const DefaultMaxIdleConnsPerHost = 2 const DefaultMaxIdleConnsPerHost = 2
// Transport is an implementation of RoundTripper that supports http, // Transport is an implementation of RoundTripper that supports HTTP,
// https, and http proxies (for either http or https with CONNECT). // HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
// Transport can also cache connections for future re-use. // Transport can also cache connections for future re-use.
type Transport struct { type Transport struct {
idleMu sync.Mutex idleMu sync.Mutex
...@@ -61,11 +61,22 @@ type Transport struct { ...@@ -61,11 +61,22 @@ type Transport struct {
// If Proxy is nil or returns a nil *URL, no proxy is used. // If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*Request) (*url.URL, error) Proxy func(*Request) (*url.URL, error)
// Dial specifies the dial function for creating TCP // Dial specifies the dial function for creating unencrypted
// connections. // TCP connections.
// If Dial is nil, net.Dial is used. // If Dial is nil, net.Dial is used.
Dial func(network, addr string) (net.Conn, error) Dial func(network, addr string) (net.Conn, error)
// DialTLS specifies an optional dial function for creating
// TLS connections for non-proxied HTTPS requests.
//
// If DialTLS is nil, Dial and TLSClientConfig are used.
//
// If DialTLS is set, the Dial hook is not used for HTTPS
// requests and the TLSClientConfig and TLSHandshakeTimeout
// are ignored. The returned net.Conn is assumed to already be
// past the TLS handshake.
DialTLS func(network, addr string) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with // TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used. // tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
...@@ -504,44 +515,56 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error ...@@ -504,44 +515,56 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
} }
func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
conn, err := t.dial("tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
}
return nil, err
}
pa := cm.proxyAuth()
pconn := &persistConn{ pconn := &persistConn{
t: t, t: t,
cacheKey: cm.key(), cacheKey: cm.key(),
conn: conn,
reqch: make(chan requestAndChan, 1), reqch: make(chan requestAndChan, 1),
writech: make(chan writeRequest, 1), writech: make(chan writeRequest, 1),
closech: make(chan struct{}), closech: make(chan struct{}),
writeErrCh: make(chan error, 1), writeErrCh: make(chan error, 1),
} }
tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
if tlsDial {
var err error
pconn.conn, err = t.DialTLS("tcp", cm.addr())
if err != nil {
return nil, err
}
if tc, ok := pconn.conn.(*tls.Conn); ok {
cs := tc.ConnectionState()
pconn.tlsState = &cs
}
} else {
conn, err := t.dial("tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
}
return nil, err
}
pconn.conn = conn
}
// Proxy setup.
switch { switch {
case cm.proxyURL == nil: case cm.proxyURL == nil:
// Do nothing. // Do nothing. Not using a proxy.
case cm.targetScheme == "http": case cm.targetScheme == "http":
pconn.isProxy = true pconn.isProxy = true
if pa != "" { if pa := cm.proxyAuth(); pa != "" {
pconn.mutateHeaderFunc = func(h Header) { pconn.mutateHeaderFunc = func(h Header) {
h.Set("Proxy-Authorization", pa) h.Set("Proxy-Authorization", pa)
} }
} }
case cm.targetScheme == "https": case cm.targetScheme == "https":
conn := pconn.conn
connectReq := &Request{ connectReq := &Request{
Method: "CONNECT", Method: "CONNECT",
URL: &url.URL{Opaque: cm.targetAddr}, URL: &url.URL{Opaque: cm.targetAddr},
Host: cm.targetAddr, Host: cm.targetAddr,
Header: make(Header), Header: make(Header),
} }
if pa != "" { if pa := cm.proxyAuth(); pa != "" {
connectReq.Header.Set("Proxy-Authorization", pa) connectReq.Header.Set("Proxy-Authorization", pa)
} }
connectReq.Write(conn) connectReq.Write(conn)
...@@ -562,7 +585,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { ...@@ -562,7 +585,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
} }
} }
if cm.targetScheme == "https" { if cm.targetScheme == "https" && !tlsDial {
// Initiate TLS and check remote host name against certificate. // Initiate TLS and check remote host name against certificate.
cfg := t.TLSClientConfig cfg := t.TLSClientConfig
if cfg == nil || cfg.ServerName == "" { if cfg == nil || cfg.ServerName == "" {
...@@ -575,7 +598,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { ...@@ -575,7 +598,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
cfg = &clone cfg = &clone
} }
} }
plainConn := conn plainConn := pconn.conn
tlsConn := tls.Client(plainConn, cfg) tlsConn := tls.Client(plainConn, cfg)
errc := make(chan error, 2) errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake var timer *time.Timer // for canceling TLS handshake
......
...@@ -2096,6 +2096,46 @@ func TestTransportClosesBodyOnError(t *testing.T) { ...@@ -2096,6 +2096,46 @@ func TestTransportClosesBodyOnError(t *testing.T) {
} }
} }
func TestTransportDialTLS(t *testing.T) {
var mu sync.Mutex // guards following
var gotReq, didDial bool
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
gotReq = true
mu.Unlock()
}))
defer ts.Close()
tr := &Transport{
DialTLS: func(netw, addr string) (net.Conn, error) {
mu.Lock()
didDial = true
mu.Unlock()
c, err := tls.Dial(netw, addr, &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
return nil, err
}
return c, c.Handshake()
},
}
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
res, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
mu.Lock()
if !gotReq {
t.Error("didn't get request")
}
if !didDial {
t.Error("didn't use dial hook")
}
}
func wantBody(res *http.Response, err error, want string) error { func wantBody(res *http.Response, err error, want string) error {
if err != nil { if err != nil {
return err return err
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment