Commit abdd73cc authored by Edward Muller's avatar Edward Muller Committed by Brad Fitzpatrick

net/http/httptrace: add ClientTrace.TLSHandshakeStart & TLSHandshakeDone

Fixes #16965

Change-Id: I3638fe280a5b1063ff589e6e1ff8a97c74b77c66
Reviewed-on: https://go-review.googlesource.com/30359Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 12c9844c
...@@ -396,7 +396,7 @@ var pkgDeps = map[string][]string{ ...@@ -396,7 +396,7 @@ var pkgDeps = map[string][]string{
"runtime/debug", "runtime/debug",
}, },
"net/http/internal": {"L4"}, "net/http/internal": {"L4"},
"net/http/httptrace": {"context", "internal/nettrace", "net", "reflect", "time"}, "net/http/httptrace": {"context", "crypto/tls", "internal/nettrace", "net", "reflect", "time"},
// HTTP-using packages. // HTTP-using packages.
"expvar": {"L4", "OS", "encoding/json", "net/http"}, "expvar": {"L4", "OS", "encoding/json", "net/http"},
......
...@@ -8,6 +8,7 @@ package httptrace ...@@ -8,6 +8,7 @@ package httptrace
import ( import (
"context" "context"
"crypto/tls"
"internal/nettrace" "internal/nettrace"
"net" "net"
"reflect" "reflect"
...@@ -119,6 +120,16 @@ type ClientTrace struct { ...@@ -119,6 +120,16 @@ type ClientTrace struct {
// enabled, this may be called multiple times. // enabled, this may be called multiple times.
ConnectDone func(network, addr string, err error) ConnectDone func(network, addr string, err error)
// TLSHandshakeStart is called when the TLS handshake is started. When
// connecting to a HTTPS site via a HTTP proxy, the handshake happens after
// the CONNECT request is processed by the proxy.
TLSHandshakeStart func()
// TLSHandshakeDone is called after the TLS handshake with either the
// successful handshake's connection state, or a non-nil error on handshake
// failure.
TLSHandshakeDone func(tls.ConnectionState, error)
// WroteHeaders is called after the Transport has written // WroteHeaders is called after the Transport has written
// the request headers. // the request headers.
WroteHeaders func() WroteHeaders func()
......
...@@ -955,6 +955,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon ...@@ -955,6 +955,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
writeErrCh: make(chan error, 1), writeErrCh: make(chan error, 1),
writeLoopDone: make(chan struct{}), writeLoopDone: make(chan struct{}),
} }
trace := httptrace.ContextClientTrace(ctx)
tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
if tlsDial { if tlsDial {
var err error var err error
...@@ -968,11 +969,20 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon ...@@ -968,11 +969,20 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
if tc, ok := pconn.conn.(*tls.Conn); ok { if tc, ok := pconn.conn.(*tls.Conn); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below // Handshake here, in case DialTLS didn't. TLSNextProto below
// depends on it for knowing the connection state. // depends on it for knowing the connection state.
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
if err := tc.Handshake(); err != nil { if err := tc.Handshake(); err != nil {
go pconn.conn.Close() go pconn.conn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
return nil, err return nil, err
} }
cs := tc.ConnectionState() cs := tc.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
}
pconn.tlsState = &cs pconn.tlsState = &cs
} }
} else { } else {
...@@ -1042,6 +1052,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon ...@@ -1042,6 +1052,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
}) })
} }
go func() { go func() {
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := tlsConn.Handshake() err := tlsConn.Handshake()
if timer != nil { if timer != nil {
timer.Stop() timer.Stop()
...@@ -1050,6 +1063,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon ...@@ -1050,6 +1063,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
}() }()
if err := <-errc; err != nil { if err := <-errc; err != nil {
plainConn.Close() plainConn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
return nil, err return nil, err
} }
if !cfg.InsecureSkipVerify { if !cfg.InsecureSkipVerify {
...@@ -1059,6 +1075,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon ...@@ -1059,6 +1075,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
} }
} }
cs := tlsConn.ConnectionState() cs := tlsConn.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
}
pconn.tlsState = &cs pconn.tlsState = &cs
pconn.conn = tlsConn pconn.conn = tlsConn
} }
......
...@@ -3288,6 +3288,12 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { ...@@ -3288,6 +3288,12 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
close(gotWroteReqEvent) close(gotWroteReqEvent)
}, },
} }
if h2 {
trace.TLSHandshakeStart = func() { logf("tls handshake start") }
trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
}
}
if noHooks { if noHooks {
// zero out all func pointers, trying to get some path to crash // zero out all func pointers, trying to get some path to crash
*trace = httptrace.ClientTrace{} *trace = httptrace.ClientTrace{}
...@@ -3339,7 +3345,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { ...@@ -3339,7 +3345,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
wantOnceOrMore("connected to tcp " + addrStr + " = <nil>") wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
wantOnce("Reused:false WasIdle:false IdleTime:0s") wantOnce("Reused:false WasIdle:false IdleTime:0s")
wantOnce("first response byte") wantOnce("first response byte")
if !h2 { if h2 {
wantOnce("tls handshake start")
wantOnce("tls handshake done")
} else {
wantOnce("PutIdleConn = <nil>") wantOnce("PutIdleConn = <nil>")
} }
wantOnce("Wait100Continue") wantOnce("Wait100Continue")
...@@ -3411,6 +3420,55 @@ func TestTransportEventTraceRealDNS(t *testing.T) { ...@@ -3411,6 +3420,55 @@ func TestTransportEventTraceRealDNS(t *testing.T) {
} }
} }
// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1
// connections. The http2 test is done in TestTransportEventTrace_h2
func TestTLSHandshakeTrace(t *testing.T) {
defer afterTest(t)
s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer s.Close()
var mu sync.Mutex
var start, done bool
trace := &httptrace.ClientTrace{
TLSHandshakeStart: func() {
mu.Lock()
defer mu.Unlock()
start = true
},
TLSHandshakeDone: func(s tls.ConnectionState, err error) {
mu.Lock()
defer mu.Unlock()
done = true
if err != nil {
t.Fatal("Expected error to be nil but was:", err)
}
},
}
tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
req, err := NewRequest("GET", s.URL, nil)
if err != nil {
t.Fatal("Unable to construct test request:", err)
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
r, err := c.Do(req)
if err != nil {
t.Fatal("Unexpected error making request:", err)
}
r.Body.Close()
mu.Lock()
defer mu.Unlock()
if !start {
t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
}
if !done {
t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't")
}
}
func TestTransportMaxIdleConns(t *testing.T) { func TestTransportMaxIdleConns(t *testing.T) {
defer afterTest(t) defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment