Commit fd4b4b56 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Transport.TLSHandshakeTimeout; set it by default

Update #3362

LGTM=agl
R=agl
CC=golang-codereviews
https://golang.org/cl/68150045
parent abe53f87
......@@ -36,6 +36,7 @@ var DefaultTransport RoundTripper = &Transport{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
......@@ -69,6 +70,10 @@ type Transport struct {
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// TLSHandshakeTimeout specifies the maximum amount of time waiting to
// wait for a TLS handshake. Zero means no timeout.
TLSHandshakeTimeout time.Duration
// DisableKeepAlives, if true, prevents re-use of TCP connections
// between different HTTP requests.
DisableKeepAlives bool
......@@ -542,16 +547,33 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
cfg = &clone
}
}
conn = tls.Client(conn, cfg)
if err = conn.(*tls.Conn).Handshake(); err != nil {
plainConn := conn
tlsConn := tls.Client(plainConn, cfg)
errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake
if d := t.TLSHandshakeTimeout; d != 0 {
timer = time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
}
go func() {
err := tlsConn.Handshake()
if timer != nil {
timer.Stop()
}
errc <- err
}()
if err := <-errc; err != nil {
plainConn.Close()
return nil, err
}
if !cfg.InsecureSkipVerify {
if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
plainConn.Close()
return nil, err
}
}
pconn.conn = conn
pconn.conn = tlsConn
}
pconn.br = bufio.NewReader(pconn.conn)
......@@ -1084,3 +1106,9 @@ type readerAndCloser struct {
io.Reader
io.Closer
}
type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
......@@ -1722,6 +1722,73 @@ func TestTransportClosesRequestBody(t *testing.T) {
}
}
func TestTransportTLSHandshakeTimeout(t *testing.T) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping in short mode")
}
ln := newLocalListener(t)
defer ln.Close()
testdonec := make(chan struct{})
defer close(testdonec)
go func() {
c, err := ln.Accept()
if err != nil {
t.Error(err)
return
}
<-testdonec
c.Close()
}()
getdonec := make(chan struct{})
go func() {
defer close(getdonec)
tr := &Transport{
Dial: func(_, _ string) (net.Conn, error) {
return net.Dial("tcp", ln.Addr().String())
},
TLSHandshakeTimeout: 250 * time.Millisecond,
}
cl := &Client{Transport: tr}
_, err := cl.Get("https://dummy.tld/")
if err == nil {
t.Fatal("expected error")
}
ue, ok := err.(*url.Error)
if !ok {
t.Fatalf("expected url.Error; got %#v", err)
}
ne, ok := ue.Err.(net.Error)
if !ok {
t.Fatalf("expected net.Error; got %#v", err)
}
if !ne.Timeout() {
t.Error("expected timeout error; got %v", err)
}
if !strings.Contains(err.Error(), "handshake timeout") {
t.Error("expected 'handshake timeout' in error; got %v", err)
}
}()
select {
case <-getdonec:
case <-time.After(5 * time.Second):
t.Error("test timeout; TLS handshake hung?")
}
}
func newLocalListener(t *testing.T) net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
ln, err = net.Listen("tcp6", "[::1]:0")
}
if err != nil {
t.Fatal(err)
}
return ln
}
type countCloseReader struct {
n *int
io.Reader
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment