Commit 7de71c85 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make Client use Request.Cancel for timeouts instead of CancelRequest

In the beginning, there was no way to cancel an HTTP request.

We later added Transport.CancelRequest to cancel an in-flight HTTP
request by breaking its underlying TCP connection, but it was hard to
use correctly and didn't work in all cases. And its error messages
were terrible. Some of those issues were fixed over time, but the most
unfixable problem was that it didn't compose well. All RoundTripper
implementations had to choose to whether to implement CancelRequest
and both decisions had negative consequences.

In Go 1.5 we added Request.Cancel, which composed well, worked in all
phases, had nice error messages, etc. But we forgot to use it in the
implementation of Client.Timeout (a timeout which spans multiple
requests and reading request bodies).

In Go 1.6 (upcoming), we added HTTP/2 support, but now Client.Timeout
didn't work because the http2.Transport didn't have a CancelRequest
method.

Rather than add a CancelRequest method to http2, officially deprecate
it and update the only caller (Client, for Client.Cancel) to use
Request.Cancel instead.

The http2 Client timeout tests are enabled now.

For compatibility, we still use CancelRequest in Client if we don't
recognize the RoundTripper type. But documentation has been updated to
tell people that CancelRequest is deprecated.

Fixes #13540

Change-Id: I15546b90825bb8b54905e17563eca55ea2642075
Reviewed-on: https://go-review.googlesource.com/18260Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 212bdd95
...@@ -20,7 +20,6 @@ import ( ...@@ -20,7 +20,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
...@@ -66,10 +65,15 @@ type Client struct { ...@@ -66,10 +65,15 @@ type Client struct {
// //
// A Timeout of zero means no timeout. // A Timeout of zero means no timeout.
// //
// The Client's Transport must support the CancelRequest // The Client cancels requests to the underlying Transport
// method or Client will return errors when attempting to make // using the Request.Cancel mechanism. Requests passed
// a request with Get, Head, Post, or Do. Client's default // to Client.Do may still set Request.Cancel; both will
// Transport (DefaultTransport) supports CancelRequest. // cancel the request.
//
// For compatibility, the Client will also use the deprecated
// CancelRequest method on Transport if found. New
// RoundTripper implementations should use Request.Cancel
// instead of implementing CancelRequest.
Timeout time.Duration Timeout time.Duration
} }
...@@ -142,13 +146,13 @@ type readClose struct { ...@@ -142,13 +146,13 @@ type readClose struct {
io.Closer io.Closer
} }
func (c *Client) send(req *Request) (*Response, error) { func (c *Client) send(req *Request, deadline time.Time) (*Response, error) {
if c.Jar != nil { if c.Jar != nil {
for _, cookie := range c.Jar.Cookies(req.URL) { for _, cookie := range c.Jar.Cookies(req.URL) {
req.AddCookie(cookie) req.AddCookie(cookie)
} }
} }
resp, err := send(req, c.transport()) resp, err := send(req, c.transport(), deadline)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -180,13 +184,20 @@ func (c *Client) send(req *Request) (*Response, error) { ...@@ -180,13 +184,20 @@ func (c *Client) send(req *Request) (*Response, error) {
// Generally Get, Post, or PostForm will be used instead of Do. // Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) { func (c *Client) Do(req *Request) (resp *Response, err error) {
method := valueOrDefault(req.Method, "GET") method := valueOrDefault(req.Method, "GET")
if method == "" || method == "GET" || method == "HEAD" { if method == "GET" || method == "HEAD" {
return c.doFollowingRedirects(req, shouldRedirectGet) return c.doFollowingRedirects(req, shouldRedirectGet)
} }
if method == "POST" || method == "PUT" { if method == "POST" || method == "PUT" {
return c.doFollowingRedirects(req, shouldRedirectPost) return c.doFollowingRedirects(req, shouldRedirectPost)
} }
return c.send(req) return c.send(req, c.deadline())
}
func (c *Client) deadline() time.Time {
if c.Timeout > 0 {
return time.Now().Add(c.Timeout)
}
return time.Time{}
} }
func (c *Client) transport() RoundTripper { func (c *Client) transport() RoundTripper {
...@@ -198,8 +209,10 @@ func (c *Client) transport() RoundTripper { ...@@ -198,8 +209,10 @@ func (c *Client) transport() RoundTripper {
// send issues an HTTP request. // send issues an HTTP request.
// Caller should close resp.Body when done reading from it. // Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) { func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) {
if t == nil { req := ireq // req is either the original request, or a modified fork
if rt == nil {
req.closeBody() req.closeBody()
return nil, errors.New("http: no Client.Transport or DefaultTransport") return nil, errors.New("http: no Client.Transport or DefaultTransport")
} }
...@@ -214,20 +227,39 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { ...@@ -214,20 +227,39 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
return nil, errors.New("http: Request.RequestURI can't be set in client requests.") return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
} }
// forkReq forks req into a shallow clone of ireq the first
// time it's called.
forkReq := func() {
if ireq == req {
req = new(Request)
*req = *ireq // shallow clone
}
}
// Most the callers of send (Get, Post, et al) don't need // Most the callers of send (Get, Post, et al) don't need
// Headers, leaving it uninitialized. We guarantee to the // Headers, leaving it uninitialized. We guarantee to the
// Transport that this has been initialized, though. // Transport that this has been initialized, though.
if req.Header == nil { if req.Header == nil {
forkReq()
req.Header = make(Header) req.Header = make(Header)
} }
if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" {
username := u.Username() username := u.Username()
password, _ := u.Password() password, _ := u.Password()
forkReq()
req.Header = cloneHeader(ireq.Header)
req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) req.Header.Set("Authorization", "Basic "+basicAuth(username, password))
} }
resp, err = t.RoundTrip(req)
if !deadline.IsZero() {
forkReq()
}
stopTimer, wasCanceled := setRequestCancel(req, rt, deadline)
resp, err := rt.RoundTrip(req)
if err != nil { if err != nil {
stopTimer()
if resp != nil { if resp != nil {
log.Printf("RoundTripper returned a response & error; ignoring response") log.Printf("RoundTripper returned a response & error; ignoring response")
} }
...@@ -241,9 +273,76 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { ...@@ -241,9 +273,76 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
} }
return nil, err return nil, err
} }
if !deadline.IsZero() {
resp.Body = &cancelTimerBody{
stop: stopTimer,
rc: resp.Body,
reqWasCanceled: wasCanceled,
}
}
return resp, nil return resp, nil
} }
// setRequestCancel sets the Cancel field of req, if deadline is
// non-zero. The RoundTripper's type is used to determine whether the legacy
// CancelRequest behavior should be used.
func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) {
if deadline.IsZero() {
return nop, alwaysFalse
}
initialReqCancel := req.Cancel // the user's original Request.Cancel, if any
cancel := make(chan struct{})
req.Cancel = cancel
wasCanceled = func() bool {
select {
case <-cancel:
return true
default:
return false
}
}
doCancel := func() {
// The new way:
close(cancel)
// The legacy compatibility way, used only
// for RoundTripper implementations written
// before Go 1.5 or Go 1.6.
type canceler interface {
CancelRequest(*Request)
}
switch v := rt.(type) {
case *Transport, *http2Transport:
// Do nothing. The net/http package's transports
// support the new Request.Cancel channel
case canceler:
v.CancelRequest(req)
}
}
stopTimerCh := make(chan struct{})
var once sync.Once
stopTimer = func() { once.Do(func() { close(stopTimerCh) }) }
timer := time.NewTimer(deadline.Sub(time.Now()))
go func() {
select {
case <-initialReqCancel:
doCancel()
case <-timer.C:
doCancel()
case <-stopTimerCh:
timer.Stop()
}
}()
return stopTimer, wasCanceled
}
// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt // See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
// "To receive authorization, the client sends the userid and password, // "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64 // separated by a single colon (":") character, within a base64
...@@ -338,28 +437,8 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo ...@@ -338,28 +437,8 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
return nil, errors.New("http: nil Request.URL") return nil, errors.New("http: nil Request.URL")
} }
var reqmu sync.Mutex // guards req
req := ireq req := ireq
deadline := c.deadline()
var timer *time.Timer
var atomicWasCanceled int32 // atomic bool (1 or 0)
var wasCanceled = alwaysFalse
if c.Timeout > 0 {
wasCanceled = func() bool { return atomic.LoadInt32(&atomicWasCanceled) != 0 }
type canceler interface {
CancelRequest(*Request)
}
tr, ok := c.transport().(canceler)
if !ok {
return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
}
timer = time.AfterFunc(c.Timeout, func() {
atomic.StoreInt32(&atomicWasCanceled, 1)
reqmu.Lock()
defer reqmu.Unlock()
tr.CancelRequest(req)
})
}
urlStr := "" // next relative or absolute URL to fetch (after first request) urlStr := "" // next relative or absolute URL to fetch (after first request)
redirectFailed := false redirectFailed := false
...@@ -388,14 +467,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo ...@@ -388,14 +467,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
break break
} }
} }
reqmu.Lock()
req = nreq req = nreq
reqmu.Unlock()
} }
urlStr = req.URL.String() urlStr = req.URL.String()
if resp, err = c.send(req); err != nil { if resp, err = c.send(req, deadline); err != nil {
if wasCanceled() { if !deadline.IsZero() && !time.Now().Before(deadline) {
err = &httpError{ err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
timeout: true, timeout: true,
...@@ -420,22 +497,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo ...@@ -420,22 +497,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
via = append(via, req) via = append(via, req)
continue continue
} }
if timer != nil {
resp.Body = &cancelTimerBody{
t: timer,
rc: resp.Body,
reqWasCanceled: wasCanceled,
}
}
return resp, nil return resp, nil
} }
method := ireq.Method method := valueOrDefault(ireq.Method, "GET")
if method == "" {
method = "GET"
}
urlErr := &url.Error{ urlErr := &url.Error{
Op: method[0:1] + strings.ToLower(method[1:]), Op: method[:1] + strings.ToLower(method[1:]),
URL: urlStr, URL: urlStr,
Err: err, Err: err,
} }
...@@ -548,30 +615,35 @@ func (c *Client) Head(url string) (resp *Response, err error) { ...@@ -548,30 +615,35 @@ func (c *Client) Head(url string) (resp *Response, err error) {
} }
// cancelTimerBody is an io.ReadCloser that wraps rc with two features: // cancelTimerBody is an io.ReadCloser that wraps rc with two features:
// 1) on Read EOF or Close, the timer t is Stopped, // 1) on Read error or close, the stop func is called.
// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and // 2) On Read failure, if reqWasCanceled is true, the error is wrapped and
// marked as net.Error that hit its timeout. // marked as net.Error that hit its timeout.
type cancelTimerBody struct { type cancelTimerBody struct {
t *time.Timer stop func() // stops the time.Timer waiting to cancel the request
rc io.ReadCloser rc io.ReadCloser
reqWasCanceled func() bool reqWasCanceled func() bool
} }
func (b *cancelTimerBody) Read(p []byte) (n int, err error) { func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
n, err = b.rc.Read(p) n, err = b.rc.Read(p)
if err == nil {
return n, nil
}
b.stop()
if err == io.EOF { if err == io.EOF {
b.t.Stop() return n, err
} else if err != nil && b.reqWasCanceled() { }
return n, &httpError{ if b.reqWasCanceled() {
err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while reading body)", err: err.Error() + " (Client.Timeout exceeded while reading body)",
timeout: true, timeout: true,
} }
} }
return return n, err
} }
func (b *cancelTimerBody) Close() error { func (b *cancelTimerBody) Close() error {
err := b.rc.Close() err := b.rc.Close()
b.t.Stop() b.stop()
return err return err
} }
...@@ -922,10 +922,7 @@ func TestBasicAuthHeadersPreserved(t *testing.T) { ...@@ -922,10 +922,7 @@ func TestBasicAuthHeadersPreserved(t *testing.T) {
} }
func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
func TestClientTimeout_h2(t *testing.T) { func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
t.Skip("skipping in http2 mode; golang.org/issue/13540")
testClientTimeout(t, h2Mode)
}
func testClientTimeout(t *testing.T, h2 bool) { func testClientTimeout(t *testing.T, h2 bool) {
if testing.Short() { if testing.Short() {
...@@ -999,10 +996,7 @@ func testClientTimeout(t *testing.T, h2 bool) { ...@@ -999,10 +996,7 @@ func testClientTimeout(t *testing.T, h2 bool) {
} }
func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) } func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) }
func TestClientTimeout_Headers_h2(t *testing.T) { func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) }
t.Skip("skipping in http2 mode; golang.org/issue/13540")
testClientTimeout_Headers(t, h2Mode)
}
// Client.Timeout firing before getting to the body // Client.Timeout firing before getting to the body
func testClientTimeout_Headers(t *testing.T, h2 bool) { func testClientTimeout_Headers(t *testing.T, h2 bool) {
......
...@@ -211,3 +211,13 @@ func hasToken(v, token string) bool { ...@@ -211,3 +211,13 @@ func hasToken(v, token string) bool {
func isTokenBoundary(b byte) bool { func isTokenBoundary(b byte) bool {
return b == ' ' || b == ',' || b == '\t' return b == ' ' || b == ',' || b == '\t'
} }
func cloneHeader(h Header) Header {
h2 := make(Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
...@@ -372,6 +372,8 @@ func (t *Transport) CloseIdleConnections() { ...@@ -372,6 +372,8 @@ func (t *Transport) CloseIdleConnections() {
// CancelRequest cancels an in-flight request by closing its connection. // CancelRequest cancels an in-flight request by closing its connection.
// CancelRequest should only be called after RoundTrip has returned. // CancelRequest should only be called after RoundTrip has returned.
//
// Deprecated: Use Request.Cancel instead.
func (t *Transport) CancelRequest(req *Request) { func (t *Transport) CancelRequest(req *Request) {
t.reqMu.Lock() t.reqMu.Lock()
cancel := t.reqCanceler[req] cancel := t.reqCanceler[req]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment