Commit 0c2d3f73 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick Committed by Russ Cox

net/http: on Transport body write error, wait briefly for a response

From https://github.com/golang/go/issues/11745#issuecomment-123555313 :

The http.RoundTripper interface says you get either a *Response, or an
error.

But in the case of a client writing a large request and the server
replying prematurely (e.g. 403 Forbidden) and closing the connection
without reading the request body, what does the client want? The 403
response, or the error that the body couldn't be copied?

This CL implements the aforementioned comment's option c), making the
Transport give an N millisecond advantage to responses over body write
errors.

Updates #11745

Change-Id: I4485a782505d54de6189f6856a7a1f33ce4d5e5e
Reviewed-on: https://go-review.googlesource.com/12590Reviewed-by: default avatarRuss Cox <rsc@golang.org>
parent a60c5366
......@@ -1164,6 +1164,19 @@ WaitResponse:
for {
select {
case err := <-writeErrCh:
if isSyscallWriteError(err) {
// Issue 11745. If we failed to write the request
// body, it's possible the server just heard enough
// and already wrote to us. Prioritize the server's
// response over returning a body write error.
select {
case re = <-resc:
pc.close()
break WaitResponse
case <-time.After(50 * time.Millisecond):
// Fall through.
}
}
if err != nil {
re = responseAndError{nil, err}
pc.close()
......@@ -1366,3 +1379,16 @@ type fakeLocker struct{}
func (fakeLocker) Lock() {}
func (fakeLocker) Unlock() {}
func isSyscallWriteError(err error) bool {
switch e := err.(type) {
case *url.Error:
return isSyscallWriteError(e.Err)
case *net.OpError:
return e.Op == "write" && isSyscallWriteError(e.Err)
case *os.SyscallError:
return e.Syscall == "write"
default:
return false
}
}
......@@ -18,7 +18,6 @@ import (
"io/ioutil"
"log"
"net"
"net/http"
. "net/http"
"net/http/httptest"
"net/url"
......@@ -320,7 +319,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) {
addrSeen[r.RemoteAddr]++
if r.URL.Path == "/chunked/" {
w.WriteHeader(200)
w.(http.Flusher).Flush()
w.(Flusher).Flush()
} else {
w.Header().Set("Content-Type", strconv.Itoa(len(msg)))
w.WriteHeader(200)
......@@ -335,7 +334,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) {
wantLen := []int{len(msg), -1}[pi]
addrSeen = make(map[string]int)
for i := 0; i < 3; i++ {
res, err := http.Get(ts.URL + path)
res, err := Get(ts.URL + path)
if err != nil {
t.Errorf("Get %s: %v", path, err)
continue
......@@ -1976,7 +1975,7 @@ func TestIdleConnChannelLeak(t *testing.T) {
// then closes it.
func TestTransportClosesRequestBody(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(http.HandlerFunc(func(w ResponseWriter, r *Request) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
io.Copy(ioutil.Discard, r.Body)
}))
defer ts.Close()
......@@ -2346,13 +2345,13 @@ func TestTransportDialTLS(t *testing.T) {
// Test for issue 8755
// Ensure that if a proxy returns an error, it is exposed by RoundTrip
func TestRoundTripReturnsProxyError(t *testing.T) {
badProxy := func(*http.Request) (*url.URL, error) {
badProxy := func(*Request) (*url.URL, error) {
return nil, errors.New("errorMessage")
}
tr := &Transport{Proxy: badProxy}
req, _ := http.NewRequest("GET", "http://example.com", nil)
req, _ := NewRequest("GET", "http://example.com", nil)
_, err := tr.RoundTrip(req)
......@@ -2644,7 +2643,57 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
}
}
func wantBody(res *http.Response, err error, want string) error {
// Issue 11745.
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
const contentLengthLimit = 1024 * 1024 // 1MB
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.ContentLength >= contentLengthLimit {
w.WriteHeader(StatusBadRequest)
r.Body.Close()
return
}
w.WriteHeader(StatusOK)
}))
defer ts.Close()
fail := 0
count := 100
bigBody := strings.Repeat("a", contentLengthLimit*2)
for i := 0; i < count; i++ {
req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
if err != nil {
t.Fatal(err)
}
tr := new(Transport)
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
resp, err := client.Do(req)
if err != nil {
fail++
t.Logf("%d = %#v", i, err)
if ue, ok := err.(*url.Error); ok {
t.Logf("urlErr = %#v", ue.Err)
if ne, ok := ue.Err.(*net.OpError); ok {
t.Logf("netOpError = %#v", ne.Err)
}
}
} else {
resp.Body.Close()
if resp.StatusCode != 400 {
t.Errorf("Expected status code 400, got %v", resp.Status)
}
}
}
if fail > 0 {
t.Errorf("Failed %v out of %v\n", fail, count)
}
}
func wantBody(res *Response, err error, want string) error {
if err != nil {
return err
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment