Commit 0b5f2f0d authored by Ian Lance Taylor's avatar Ian Lance Taylor

net/http: if context is canceled, return its error

This permits the error message to distinguish between a context that was
canceled and a context that timed out.

Updates #16381.

Change-Id: I3994b98e32952abcd7ddb5fee08fa1535999be6d
Reviewed-on: https://go-review.googlesource.com/24978
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 643b9ec0
...@@ -313,8 +313,8 @@ func TestClientRedirectContext(t *testing.T) { ...@@ -313,8 +313,8 @@ func TestClientRedirectContext(t *testing.T) {
if !ok { if !ok {
t.Fatalf("got error %T; want *url.Error", err) t.Fatalf("got error %T; want *url.Error", err)
} }
if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn { if ue.Err != context.Canceled {
t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err) t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
} }
} }
......
...@@ -76,7 +76,7 @@ type Transport struct { ...@@ -76,7 +76,7 @@ type Transport struct {
idleLRU connLRU idleLRU connLRU
reqMu sync.Mutex reqMu sync.Mutex
reqCanceler map[*Request]func() reqCanceler map[*Request]func(error)
altMu sync.RWMutex altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
...@@ -498,12 +498,17 @@ func (t *Transport) CloseIdleConnections() { ...@@ -498,12 +498,17 @@ func (t *Transport) CloseIdleConnections() {
// cancelable context instead. CancelRequest cannot cancel HTTP/2 // cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests. // requests.
func (t *Transport) CancelRequest(req *Request) { func (t *Transport) CancelRequest(req *Request) {
t.cancelRequest(req, errRequestCanceled)
}
// Cancel an in-flight request, recording the error value.
func (t *Transport) cancelRequest(req *Request, err error) {
t.reqMu.Lock() t.reqMu.Lock()
cancel := t.reqCanceler[req] cancel := t.reqCanceler[req]
delete(t.reqCanceler, req) delete(t.reqCanceler, req)
t.reqMu.Unlock() t.reqMu.Unlock()
if cancel != nil { if cancel != nil {
cancel() cancel(err)
} }
} }
...@@ -783,11 +788,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) { ...@@ -783,11 +788,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) {
} }
} }
func (t *Transport) setReqCanceler(r *Request, fn func()) { func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
t.reqMu.Lock() t.reqMu.Lock()
defer t.reqMu.Unlock() defer t.reqMu.Unlock()
if t.reqCanceler == nil { if t.reqCanceler == nil {
t.reqCanceler = make(map[*Request]func()) t.reqCanceler = make(map[*Request]func(error))
} }
if fn != nil { if fn != nil {
t.reqCanceler[r] = fn t.reqCanceler[r] = fn
...@@ -800,7 +805,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) { ...@@ -800,7 +805,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
// for the request, we don't set the function and return false. // for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if // Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call. // the request was canceled since the last setReqCancel call.
func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
t.reqMu.Lock() t.reqMu.Lock()
defer t.reqMu.Unlock() defer t.reqMu.Unlock()
_, ok := t.reqCanceler[r] _, ok := t.reqCanceler[r]
...@@ -849,7 +854,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC ...@@ -849,7 +854,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
// set request canceler to some non-nil function so we // set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when // can detect whether it was cleared between now and when
// we enter roundTrip // we enter roundTrip
t.setReqCanceler(req, func() {}) t.setReqCanceler(req, func(error) {})
return pc, nil return pc, nil
} }
...@@ -874,8 +879,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC ...@@ -874,8 +879,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
}() }()
} }
cancelc := make(chan struct{}) cancelc := make(chan error, 1)
t.setReqCanceler(req, func() { close(cancelc) }) t.setReqCanceler(req, func(err error) { cancelc <- err })
go func() { go func() {
pc, err := t.dialConn(ctx, cm) pc, err := t.dialConn(ctx, cm)
...@@ -897,7 +902,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC ...@@ -897,7 +902,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
select { select {
case <-req.Cancel: case <-req.Cancel:
case <-req.Context().Done(): case <-req.Context().Done():
case <-cancelc: return nil, req.Context().Err()
case err := <-cancelc:
if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
default: default:
// It wasn't an error due to cancelation, so // It wasn't an error due to cancelation, so
// return the original error message: // return the original error message:
...@@ -922,10 +932,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC ...@@ -922,10 +932,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
return nil, errRequestCanceledConn return nil, errRequestCanceledConn
case <-req.Context().Done(): case <-req.Context().Done():
handlePendingDial() handlePendingDial()
return nil, errRequestCanceledConn return nil, req.Context().Err()
case <-cancelc: case err := <-cancelc:
handlePendingDial() handlePendingDial()
return nil, errRequestCanceledConn if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
} }
} }
...@@ -1231,8 +1244,8 @@ type persistConn struct { ...@@ -1231,8 +1244,8 @@ type persistConn struct {
mu sync.Mutex // guards following fields mu sync.Mutex // guards following fields
numExpectedResponses int numExpectedResponses int
closed error // set non-nil when conn is closed, before closech is closed closed error // set non-nil when conn is closed, before closech is closed
canceledErr error // set non-nil if conn is canceled
broken bool // an error has happened on this connection; marked broken so it's not reused. broken bool // an error has happened on this connection; marked broken so it's not reused.
canceled bool // whether this conn was broken due a CancelRequest
reused bool // whether conn has had successful request/response and is being reused. reused bool // whether conn has had successful request/response and is being reused.
// mutateHeaderFunc is an optional func to modify extra // mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the // headers on each outbound request before it's written. (the
...@@ -1270,11 +1283,12 @@ func (pc *persistConn) isBroken() bool { ...@@ -1270,11 +1283,12 @@ func (pc *persistConn) isBroken() bool {
return b return b
} }
// isCanceled reports whether this connection was closed due to CancelRequest. // canceled returns non-nil if the connection was closed due to
func (pc *persistConn) isCanceled() bool { // CancelRequest or due to context cancelation.
func (pc *persistConn) canceled() error {
pc.mu.Lock() pc.mu.Lock()
defer pc.mu.Unlock() defer pc.mu.Unlock()
return pc.canceled return pc.canceledErr
} }
// isReused reports whether this connection is in a known broken state. // isReused reports whether this connection is in a known broken state.
...@@ -1297,10 +1311,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn ...@@ -1297,10 +1311,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn
return return
} }
func (pc *persistConn) cancelRequest() { func (pc *persistConn) cancelRequest(err error) {
pc.mu.Lock() pc.mu.Lock()
defer pc.mu.Unlock() defer pc.mu.Unlock()
pc.canceled = true pc.canceledErr = err
pc.closeLocked(errRequestCanceled) pc.closeLocked(errRequestCanceled)
} }
...@@ -1328,8 +1342,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er ...@@ -1328,8 +1342,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
if err == nil { if err == nil {
return nil return nil
} }
if pc.isCanceled() { if err := pc.canceled(); err != nil {
return errRequestCanceled return err
} }
if err == errServerClosedIdle { if err == errServerClosedIdle {
return err return err
...@@ -1351,8 +1365,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er ...@@ -1351,8 +1365,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
// its pc.closech channel close, indicating the persistConn is dead. // its pc.closech channel close, indicating the persistConn is dead.
// (after closech is closed, pc.closed is valid). // (after closech is closed, pc.closed is valid).
func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error { func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error {
if pc.isCanceled() { if err := pc.canceled(); err != nil {
return errRequestCanceled return err
} }
err := pc.closed err := pc.closed
if err == errServerClosedIdle { if err == errServerClosedIdle {
...@@ -1509,8 +1523,10 @@ func (pc *persistConn) readLoop() { ...@@ -1509,8 +1523,10 @@ func (pc *persistConn) readLoop() {
waitForBodyRead <- isEOF waitForBodyRead <- isEOF
if isEOF { if isEOF {
<-eofc // see comment above eofc declaration <-eofc // see comment above eofc declaration
} else if err != nil && pc.isCanceled() { } else if err != nil {
return errRequestCanceled if cerr := pc.canceled(); cerr != nil {
return cerr
}
} }
return err return err
}, },
...@@ -1550,7 +1566,7 @@ func (pc *persistConn) readLoop() { ...@@ -1550,7 +1566,7 @@ func (pc *persistConn) readLoop() {
pc.t.CancelRequest(rc.req) pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done(): case <-rc.req.Context().Done():
alive = false alive = false
pc.t.CancelRequest(rc.req) pc.t.cancelRequest(rc.req, rc.req.Context().Err())
case <-pc.closech: case <-pc.closech:
alive = false alive = false
} }
...@@ -1836,8 +1852,8 @@ WaitResponse: ...@@ -1836,8 +1852,8 @@ WaitResponse:
select { select {
case err := <-writeErrCh: case err := <-writeErrCh:
if err != nil { if err != nil {
if pc.isCanceled() { if cerr := pc.canceled(); cerr != nil {
err = errRequestCanceled err = cerr
} }
re = responseAndError{err: err} re = responseAndError{err: err}
pc.close(fmt.Errorf("write error: %v", err)) pc.close(fmt.Errorf("write error: %v", err))
...@@ -1861,9 +1877,8 @@ WaitResponse: ...@@ -1861,9 +1877,8 @@ WaitResponse:
case <-cancelChan: case <-cancelChan:
pc.t.CancelRequest(req.Request) pc.t.CancelRequest(req.Request)
cancelChan = nil cancelChan = nil
ctxDoneChan = nil
case <-ctxDoneChan: case <-ctxDoneChan:
pc.t.CancelRequest(req.Request) pc.t.cancelRequest(req.Request, req.Context().Err())
cancelChan = nil cancelChan = nil
ctxDoneChan = nil ctxDoneChan = nil
} }
......
...@@ -1718,8 +1718,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { ...@@ -1718,8 +1718,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
} }
_, err := c.Do(req) _, err := c.Do(req)
if err == nil || !strings.Contains(err.Error(), "canceled") { if ue, ok := err.(*url.Error); ok {
t.Errorf("Do error = %v; want cancelation", err) err = ue.Err
}
if withCtx {
if err != context.Canceled {
t.Errorf("Do error = %v; want %v", err, context.Canceled)
}
} else {
if err == nil || !strings.Contains(err.Error(), "canceled") {
t.Errorf("Do error = %v; want cancelation", err)
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment