Commit 0ab78df9 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix a few crashes with a ClientTrace with nil funcs

And add a test.

Updates #12580

Change-Id: Ia7eaba09b8e7fd0eddbcaefb948d01ab10af876e
Reviewed-on: https://go-review.googlesource.com/22659Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent b3a130e8
...@@ -787,11 +787,11 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC ...@@ -787,11 +787,11 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
req := treq.Request req := treq.Request
trace := treq.trace trace := treq.trace
ctx := req.Context() ctx := req.Context()
if trace != nil { if trace != nil && trace.GetConn != nil {
trace.GetConn(cm.addr()) trace.GetConn(cm.addr())
} }
if pc, idleSince := t.getIdleConn(cm); pc != nil { if pc, idleSince := t.getIdleConn(cm); pc != nil {
if trace != nil { if trace != nil && trace.GotConn != nil {
trace.GotConn(pc.gotIdleConnTrace(idleSince)) trace.GotConn(pc.gotIdleConnTrace(idleSince))
} }
// set request canceler to some non-nil function so we // set request canceler to some non-nil function so we
...@@ -834,7 +834,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC ...@@ -834,7 +834,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
select { select {
case v := <-dialc: case v := <-dialc:
// Our dial finished. // Our dial finished.
if trace != nil && v.pc != nil { if trace != nil && trace.GotConn != nil && v.pc != nil {
trace.GotConn(httptrace.GotConnInfo{Conn: v.pc.conn}) trace.GotConn(httptrace.GotConnInfo{Conn: v.pc.conn})
} }
return v.pc, v.err return v.pc, v.err
......
...@@ -3193,7 +3193,12 @@ func TestTransportResponseHeaderLength(t *testing.T) { ...@@ -3193,7 +3193,12 @@ func TestTransportResponseHeaderLength(t *testing.T) {
} }
} }
func TestTransportEventTrace(t *testing.T) { func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, false) }
// test a non-nil httptrace.ClientTrace but with all hooks set to zero.
func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, true) }
func testTransportEventTrace(t *testing.T, noHooks bool) {
defer afterTest(t) defer afterTest(t)
const resBody = "some body" const resBody = "some body"
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
...@@ -3233,7 +3238,7 @@ func TestTransportEventTrace(t *testing.T) { ...@@ -3233,7 +3238,7 @@ func TestTransportEventTrace(t *testing.T) {
}) })
req, _ := NewRequest("POST", "http://dns-is-faked.golang:"+port, strings.NewReader("some body")) req, _ := NewRequest("POST", "http://dns-is-faked.golang:"+port, strings.NewReader("some body"))
req = req.WithContext(httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ trace := &httptrace.ClientTrace{
GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
GotFirstResponseByte: func() { logf("first response byte") }, GotFirstResponseByte: func() { logf("first response byte") },
...@@ -3250,7 +3255,12 @@ func TestTransportEventTrace(t *testing.T) { ...@@ -3250,7 +3255,12 @@ func TestTransportEventTrace(t *testing.T) {
Wait100Continue: func() { logf("Wait100Continue") }, Wait100Continue: func() { logf("Wait100Continue") },
Got100Continue: func() { logf("Got100Continue") }, Got100Continue: func() { logf("Got100Continue") },
WroteRequest: func(e httptrace.WroteRequestInfo) { logf("WroteRequest: %+v", e) }, WroteRequest: func(e httptrace.WroteRequestInfo) { logf("WroteRequest: %+v", e) },
})) }
if noHooks {
// zero out all func pointers, trying to get some path to crash
*trace = httptrace.ClientTrace{}
}
req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
req.Header.Set("Expect", "100-continue") req.Header.Set("Expect", "100-continue")
res, err := c.Do(req) res, err := c.Do(req)
...@@ -3266,6 +3276,13 @@ func TestTransportEventTrace(t *testing.T) { ...@@ -3266,6 +3276,13 @@ func TestTransportEventTrace(t *testing.T) {
} }
res.Body.Close() res.Body.Close()
if noHooks {
// Done at this point. Just testing a full HTTP
// requests can happen with a trace pointing to a zero
// ClientTrace, full of nil func pointers.
return
}
got := buf.String() got := buf.String()
wantSub := func(sub string) { wantSub := func(sub string) {
if !strings.Contains(got, sub) { if !strings.Contains(got, sub) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment