Commit 9d29be46 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: clean up Transport.RoundTrip error handling

If I put a 10 millisecond sleep at testHookWaitResLoop, before the big
select in (*persistConn).roundTrip, two flakes immediately started
happening, TestTransportBodyReadError (#19231) and
TestTransportPersistConnReadLoopEOF.

The problem was that there are many ways for a RoundTrip call to fail
(errors reading from Request.Body while writing the response, errors
writing the response, errors reading the response due to server
closes, errors due to servers sending malformed responses,
cancelations, timeouts, etc.), and many of those failures then tear
down the TCP connection, causing more failures, since there are always
at least three goroutines involved (reading, writing, RoundTripping).

Because the errors were communicated over buffered channels to a giant
select, the error returned to the caller was a function of which
random select case was called, which was why a 10ms delay before the
select brought out so many bugs. (several fixed in my previous CLs the past
few days).

Instead, track the error explicitly in the transportRequest, guarded
by a mutex.

In addition, this CL now:

* differentiates between the two ways writing a request can fail: the
  io.Copy reading from the Request.Body or the io.Copy writing to the
  network. A new io.Reader type notes read errors from the
  Request.Body. The read-from-body vs write-to-network errors are now
  prioritized differently.

* unifies the two mapRoundTripErrorFromXXX methods into one
  mapRoundTripError method since their logic is now the same.

* adds a (*Request).WithT(*testing.T) method in export_test.go, usable
  by tests, to call t.Logf at points during RoundTrip. This is disabled
  behind a constant except when debugging.

* documents and deflakes TestClientRedirectContext

I've tested this CL with high -count values, with/without -race,
with/without delays before the select, etc. So far it seems robust.

Fixes #19231 (TestTransportBodyReadError flake)
Updates #14203 (source of errors unclear; they're now tracked more)
Updates #15935 (document Transport errors more; at least understood more now)

Change-Id: I3cccc3607f369724b5344763e35ad2b7ea415738
Reviewed-on: https://go-review.googlesource.com/37495
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarJosh Bleecher Snyder <josharian@gmail.com>
parent 87649d32
......@@ -304,6 +304,7 @@ func TestClientRedirects(t *testing.T) {
}
}
// Tests that Client redirects' contexts are derived from the original request's context.
func TestClientRedirectContext(t *testing.T) {
setParallel(t)
defer afterTest(t)
......@@ -320,10 +321,12 @@ func TestClientRedirectContext(t *testing.T) {
Transport: tr,
CheckRedirect: func(req *Request, via []*Request) error {
cancel()
if len(via) > 2 {
return errors.New("too many redirects")
}
select {
case <-req.Context().Done():
return nil
case <-time.After(5 * time.Second):
return errors.New("redirected request's context never expired after root request canceled")
}
},
}
req, _ := NewRequest("GET", ts.URL, nil)
......@@ -1818,6 +1821,7 @@ func TestTransportBodyReadError(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req = req.WithT(t)
_, err = tr.RoundTrip(req)
if err != someErr {
t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr)
......
......@@ -8,9 +8,11 @@
package http
import (
"context"
"net"
"sort"
"sync"
"testing"
"time"
)
......@@ -199,3 +201,7 @@ func (s *Server) ExportAllConnsIdle() bool {
}
return true
}
func (r *Request) WithT(t *testing.T) *Request {
return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf))
}
......@@ -621,6 +621,9 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
// Write body and trailer
err = tw.WriteBody(w)
if err != nil {
if tw.bodyReadError == err {
err = requestBodyReadError{err}
}
return err
}
......@@ -630,6 +633,11 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
return nil
}
// requestBodyReadError wraps an error from (*Request).write to indicate
// that the error came from a Read call on the Request.Body.
// This error type should not escape the net/http package to users.
type requestBodyReadError struct{ error }
func idnaASCII(v string) (string, error) {
if isASCII(v) {
return v, nil
......
......@@ -51,6 +51,19 @@ func (br *byteReader) Read(p []byte) (n int, err error) {
return 1, io.EOF
}
// transferBodyReader is an io.Reader that reads from tw.Body
// and records any non-EOF error in tw.bodyReadError.
// It is exactly 1 pointer wide to avoid allocations into interfaces.
type transferBodyReader struct{ tw *transferWriter }
func (br transferBodyReader) Read(p []byte) (n int, err error) {
n, err = br.tw.Body.Read(p)
if err != nil && err != io.EOF {
br.tw.bodyReadError = err
}
return
}
// transferWriter inspects the fields of a user-supplied Request or Response,
// sanitizes them without changing the user object and provides methods for
// writing the respective header, body and trailer in wire format.
......@@ -64,6 +77,7 @@ type transferWriter struct {
TransferEncoding []string
Trailer Header
IsResponse bool
bodyReadError error // any non-EOF error from reading Body
FlushHeaders bool // flush headers to network before body
ByteReadCh chan readResult // non-nil if probeRequestBody called
......@@ -304,24 +318,25 @@ func (t *transferWriter) WriteBody(w io.Writer) error {
// Write body
if t.Body != nil {
var body = transferBodyReader{t}
if chunked(t.TransferEncoding) {
if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
w = &internal.FlushAfterChunkWriter{Writer: bw}
}
cw := internal.NewChunkedWriter(w)
_, err = io.Copy(cw, t.Body)
_, err = io.Copy(cw, body)
if err == nil {
err = cw.Close()
}
} else if t.ContentLength == -1 {
ncopy, err = io.Copy(w, t.Body)
ncopy, err = io.Copy(w, body)
} else {
ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength))
if err != nil {
return err
}
var nextra int64
nextra, err = io.Copy(ioutil.Discard, t.Body)
nextra, err = io.Copy(ioutil.Discard, body)
ncopy += nextra
}
if err != nil {
......
......@@ -303,11 +303,15 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
}
// transportRequest is a wrapper around a *Request that adds
// optional extra headers to write.
// optional extra headers to write and stores any error to return
// from roundTrip.
type transportRequest struct {
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional
mu sync.Mutex // guards err
err error // first setError value for mapRoundTripError to consider
}
func (tr *transportRequest) extraHeaders() Header {
......@@ -317,6 +321,14 @@ func (tr *transportRequest) extraHeaders() Header {
return tr.extra
}
func (tr *transportRequest) setError(err error) {
tr.mu.Lock()
if tr.err == nil {
tr.err = err
}
tr.mu.Unlock()
}
// RoundTrip implements the RoundTripper interface.
//
// For higher-level HTTP client support (such as handling of cookies
......@@ -1420,61 +1432,51 @@ func (pc *persistConn) closeConnIfStillIdle() {
pc.close(errIdleConnTimeout)
}
// mapRoundTripErrorFromReadLoop maps the provided readLoop error into
// the error value that should be returned from persistConn.roundTrip.
// mapRoundTripError returns the appropriate error value for
// persistConn.roundTrip.
//
// The provided err is the first error that (*persistConn).roundTrip
// happened to receive from its select statement.
//
// The startBytesWritten value should be the value of pc.nwrite before the roundTrip
// started writing the request.
func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) {
func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error {
if err == nil {
return nil
}
if err := pc.canceled(); err != nil {
return err
}
if err == errServerClosedIdle {
return err
}
if _, ok := err.(transportReadFromServerError); ok {
return err
}
if pc.isBroken() {
<-pc.writeLoopDone
if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
return nothingWrittenError{err}
}
// If the request was canceled, that's better than network
// failures that were likely the result of tearing down the
// connection.
if cerr := pc.canceled(); cerr != nil {
return cerr
}
return err
}
// mapRoundTripErrorAfterClosed returns the error value to be propagated
// up to Transport.RoundTrip method when persistConn.roundTrip sees
// its pc.closech channel close, indicating the persistConn is dead.
// (after closech is closed, pc.closed is valid).
func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error {
if err := pc.canceled(); err != nil {
return err
// See if an error was set explicitly.
req.mu.Lock()
reqErr := req.err
req.mu.Unlock()
if reqErr != nil {
return reqErr
}
err := pc.closed
if err == errServerClosedIdle {
// Don't decorate
return err
}
if _, ok := err.(transportReadFromServerError); ok {
// Don't decorate
return err
}
// Wait for the writeLoop goroutine to terminated, and then
// see if we actually managed to write anything. If not, we
// can retry the request.
if pc.isBroken() {
<-pc.writeLoopDone
if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
return nothingWrittenError{err}
}
return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err)
}
return err
}
func (pc *persistConn) readLoop() {
......@@ -1746,6 +1748,17 @@ func (pc *persistConn) writeLoop() {
case wr := <-pc.writech:
startBytesWritten := pc.nwrite
err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
if bre, ok := err.(requestBodyReadError); ok {
err = bre.error
// Errors reading from the user's
// Request.Body are high priority.
// Set it here before sending on the
// channels below or calling
// pc.close() which tears town
// connections and causes other
// errors.
wr.req.setError(err)
}
if err == nil {
err = pc.bw.Flush()
}
......@@ -1913,6 +1926,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
gone := make(chan struct{})
defer close(gone)
defer func() {
if err != nil {
pc.t.setReqCanceler(req.Request, nil)
}
}()
const debugRoundTrip = false
// Write the request concurrently with waiting for a response,
// in case the server decides to reply before reading our full
// request body.
......@@ -1929,38 +1950,50 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
callerGone: gone,
}
var re responseAndError
var respHeaderTimer <-chan time.Time
cancelChan := req.Request.Cancel
ctxDoneChan := req.Context().Done()
WaitResponse:
for {
testHookWaitResLoop()
select {
case err := <-writeErrCh:
if err != nil {
if cerr := pc.canceled(); cerr != nil {
err = cerr
if debugRoundTrip {
req.logf("writeErrCh resv: %T/%#v", err, err)
}
re = responseAndError{err: err}
if err != nil {
pc.close(fmt.Errorf("write error: %v", err))
break WaitResponse
return nil, pc.mapRoundTripError(req, startBytesWritten, err)
}
if d := pc.t.ResponseHeaderTimeout; d > 0 {
if debugRoundTrip {
req.logf("starting timer for %v", d)
}
timer := time.NewTimer(d)
defer timer.Stop() // prevent leaks
respHeaderTimer = timer.C
}
case <-pc.closech:
re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)}
break WaitResponse
if debugRoundTrip {
req.logf("closech recv: %T %#v", pc.closed, pc.closed)
}
return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
case <-respHeaderTimer:
if debugRoundTrip {
req.logf("timeout waiting for response headers.")
}
pc.close(errTimeout)
re = responseAndError{err: errTimeout}
break WaitResponse
case re = <-resc:
re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err)
break WaitResponse
return nil, errTimeout
case re := <-resc:
if (re.res == nil) == (re.err == nil) {
panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
}
if debugRoundTrip {
req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
}
if re.err != nil {
return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
}
return re.res, nil
case <-cancelChan:
pc.t.CancelRequest(req.Request)
cancelChan = nil
......@@ -1970,14 +2003,16 @@ WaitResponse:
ctxDoneChan = nil
}
}
}
if re.err != nil {
pc.t.setReqCanceler(req.Request, nil)
}
if (re.res == nil) == (re.err == nil) {
panic("internal error: exactly one of res or err should be set")
// tLogKey is a context WithValue key for test debugging contexts containing
// a t.Logf func. See export_test.go's Request.WithT method.
type tLogKey struct{}
func (r *transportRequest) logf(format string, args ...interface{}) {
if logf, ok := r.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok {
logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...)
}
return re.res, re.err
}
// markReused marks this connection as having been successfully used for a
......
......@@ -30,6 +30,7 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
tr := new(Transport)
req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
req = req.WithT(t)
treq := &transportRequest{Request: req}
cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
pc, err := tr.getConn(treq, cm)
......@@ -47,13 +48,13 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
_, err = pc.roundTrip(treq)
if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
}
<-pc.closech
err = pc.closed
if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
}
}
......
......@@ -1625,7 +1625,9 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
{path: "/fast", want: 200},
}
for i, tt := range tests {
res, err := c.Get(ts.URL + tt.path)
req, _ := NewRequest("GET", ts.URL+tt.path, nil)
req = req.WithT(t)
res, err := c.Do(req)
select {
case <-inHandler:
case <-time.After(5 * time.Second):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment