Commit 70ee5252 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix Transport crash when abandoning dial which upgrades protos

When the Transport was creating an bound HTTP connection (protocol
unknown initially) and then ends up deciding it doesn't need it, a
goroutine sits around to clean up whatever the result was. That
goroutine made the false assumption that the result was always an
HTTP/1 connection or an error. It may also be an alternate protocol
in which case the *persistConn.conn net.Conn field is nil, and the
alt field is non-nil.

Fixes #13839

Change-Id: Ia4972e5eb1ad53fa00410b3466d4129c753e0871
Reviewed-on: https://go-review.googlesource.com/18573Reviewed-by: default avatarRuss Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 4525571f
...@@ -21,6 +21,7 @@ var ( ...@@ -21,6 +21,7 @@ var (
ExportServerNewConn = (*Server).newConn ExportServerNewConn = (*Server).newConn
ExportCloseWriteAndWait = (*conn).closeWriteAndWait ExportCloseWriteAndWait = (*conn).closeWriteAndWait
ExportErrRequestCanceled = errRequestCanceled ExportErrRequestCanceled = errRequestCanceled
ExportErrRequestCanceledConn = errRequestCanceledConn
ExportServeFile = serveFile ExportServeFile = serveFile
ExportHttp2ConfigureTransport = http2ConfigureTransport ExportHttp2ConfigureTransport = http2ConfigureTransport
ExportHttp2ConfigureServer = http2ConfigureServer ExportHttp2ConfigureServer = http2ConfigureServer
......
...@@ -618,9 +618,13 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { ...@@ -618,9 +618,13 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
return true return true
} }
func (t *Transport) dial(network, addr string) (c net.Conn, err error) { func (t *Transport) dial(network, addr string) (net.Conn, error) {
if t.Dial != nil { if t.Dial != nil {
return t.Dial(network, addr) c, err := t.Dial(network, addr)
if c == nil && err == nil {
err = errors.New("net/http: Transport.Dial hook returned (nil, nil)")
}
return c, err
} }
return net.Dial(network, addr) return net.Dial(network, addr)
} }
...@@ -682,10 +686,10 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error ...@@ -682,10 +686,10 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
return pc, nil return pc, nil
case <-req.Cancel: case <-req.Cancel:
handlePendingDial() handlePendingDial()
return nil, errors.New("net/http: request canceled while waiting for connection") return nil, errRequestCanceledConn
case <-cancelc: case <-cancelc:
handlePendingDial() handlePendingDial()
return nil, errors.New("net/http: request canceled while waiting for connection") return nil, errRequestCanceledConn
} }
} }
...@@ -705,6 +709,9 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { ...@@ -705,6 +709,9 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if pconn.conn == nil {
return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
}
if tc, ok := pconn.conn.(*tls.Conn); ok { if tc, ok := pconn.conn.(*tls.Conn); ok {
cs := tc.ConnectionState() cs := tc.ConnectionState()
pconn.tlsState = &cs pconn.tlsState = &cs
...@@ -1326,6 +1333,7 @@ func (e *httpError) Temporary() bool { return true } ...@@ -1326,6 +1333,7 @@ func (e *httpError) Temporary() bool { return true }
var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
var errClosed error = &httpError{err: "net/http: server closed connection before response was received"} var errClosed error = &httpError{err: "net/http: server closed connection before response was received"}
var errRequestCanceled = errors.New("net/http: request canceled") var errRequestCanceled = errors.New("net/http: request canceled")
var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
func nop() {} func nop() {}
...@@ -1502,9 +1510,19 @@ func (pc *persistConn) closeLocked(err error) { ...@@ -1502,9 +1510,19 @@ func (pc *persistConn) closeLocked(err error) {
} }
pc.broken = true pc.broken = true
if pc.closed == nil { if pc.closed == nil {
pc.conn.Close()
pc.closed = err pc.closed = err
close(pc.closech) if pc.alt != nil {
// Do nothing; can only get here via getConn's
// handlePendingDial's putOrCloseIdleConn when
// it turns out the abandoned connection in
// flight ended up negotiating an alternate
// protocol. We don't use the connection
// freelist for http2. That's done by the
// alternate protocol's RoundTripper.
} else {
pc.conn.Close()
close(pc.closech)
}
} }
pc.mutateHeaderFunc = nil pc.mutateHeaderFunc = nil
} }
......
...@@ -24,6 +24,7 @@ import ( ...@@ -24,6 +24,7 @@ import (
. "net/http" . "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
"net/http/internal"
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
...@@ -2939,6 +2940,98 @@ func TestTransportReuseConnEmptyResponseBody(t *testing.T) { ...@@ -2939,6 +2940,98 @@ func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
} }
} }
// Issue 13839
func TestNoCrashReturningTransportAltConn(t *testing.T) {
cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
if err != nil {
t.Fatal(err)
}
ln := newLocalListener(t)
defer ln.Close()
handledPendingDial := make(chan bool, 1)
SetPendingDialHooks(nil, func() { handledPendingDial <- true })
defer SetPendingDialHooks(nil, nil)
testDone := make(chan struct{})
defer close(testDone)
go func() {
tln := tls.NewListener(ln, &tls.Config{
NextProtos: []string{"foo"},
Certificates: []tls.Certificate{cert},
})
sc, err := tln.Accept()
if err != nil {
t.Error(err)
return
}
if err := sc.(*tls.Conn).Handshake(); err != nil {
t.Error(err)
return
}
<-testDone
sc.Close()
}()
addr := ln.Addr().String()
req, _ := NewRequest("GET", "https://fake.tld/", nil)
cancel := make(chan struct{})
req.Cancel = cancel
doReturned := make(chan bool, 1)
madeRoundTripper := make(chan bool, 1)
tr := &Transport{
DisableKeepAlives: true,
TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
"foo": func(authority string, c *tls.Conn) RoundTripper {
madeRoundTripper <- true
return funcRoundTripper(func() {
t.Error("foo RoundTripper should not be called")
})
},
},
Dial: func(_, _ string) (net.Conn, error) {
panic("shouldn't be called")
},
DialTLS: func(_, _ string) (net.Conn, error) {
tc, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"foo"},
})
if err != nil {
return nil, err
}
if err := tc.Handshake(); err != nil {
return nil, err
}
close(cancel)
<-doReturned
return tc, nil
},
}
c := &Client{Transport: tr}
_, err = c.Do(req)
if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
}
doReturned <- true
<-madeRoundTripper
<-handledPendingDial
}
var errFakeRoundTrip = errors.New("fake roundtrip")
type funcRoundTripper func()
func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
fn()
return nil, errFakeRoundTrip
}
func wantBody(res *Response, err error, want string) error { func wantBody(res *Response, err error, want string) error {
if err != nil { if err != nil {
return err return err
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment