Commit 839d47ad authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Transport.ResponseHeaderTimeout

Update #3362

R=golang-dev, adg, rsc
CC=golang-dev
https://golang.org/cl/7369055
parent 37cb6f80
...@@ -73,6 +73,12 @@ type Transport struct { ...@@ -73,6 +73,12 @@ type Transport struct {
// (keep-alive) to keep per-host. If zero, // (keep-alive) to keep per-host. If zero,
// DefaultMaxIdleConnsPerHost is used. // DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int MaxIdleConnsPerHost int
// ResponseHeaderTimeout, if non-zero, specifies the amount of
// time to wait for a server's response headers after fully
// writing the request (including its body, if any). This
// time does not include the time to read the response body.
ResponseHeaderTimeout time.Duration
} }
// ProxyFromEnvironment returns the URL of the proxy to use for a // ProxyFromEnvironment returns the URL of the proxy to use for a
...@@ -743,6 +749,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err ...@@ -743,6 +749,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
var re responseAndError var re responseAndError
var pconnDeadCh = pc.closech var pconnDeadCh = pc.closech
var failTicker <-chan time.Time var failTicker <-chan time.Time
var respHeaderTimer <-chan time.Time
WaitResponse: WaitResponse:
for { for {
select { select {
...@@ -752,6 +759,9 @@ WaitResponse: ...@@ -752,6 +759,9 @@ WaitResponse:
pc.close() pc.close()
break WaitResponse break WaitResponse
} }
if d := pc.t.ResponseHeaderTimeout; d > 0 {
respHeaderTimer = time.After(d)
}
case <-pconnDeadCh: case <-pconnDeadCh:
// The persist connection is dead. This shouldn't // The persist connection is dead. This shouldn't
// usually happen (only with Connection: close responses // usually happen (only with Connection: close responses
...@@ -768,7 +778,11 @@ WaitResponse: ...@@ -768,7 +778,11 @@ WaitResponse:
pconnDeadCh = nil // avoid spinning pconnDeadCh = nil // avoid spinning
failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
case <-failTicker: case <-failTicker:
re = responseAndError{nil, errors.New("net/http: transport closed before response was received")} re = responseAndError{err: errors.New("net/http: transport closed before response was received")}
break WaitResponse
case <-respHeaderTimer:
pc.close()
re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")}
break WaitResponse break WaitResponse
case re = <-resc: case re = <-resc:
break WaitResponse break WaitResponse
......
...@@ -1113,6 +1113,54 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ...@@ -1113,6 +1113,54 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
ts.Close() ts.Close()
} }
func TestTransportResponseHeaderTimeout(t *testing.T) {
defer checkLeakedTransports(t)
if testing.Short() {
t.Skip("skipping timeout test in -short mode")
}
const debug = false
mux := NewServeMux()
mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
time.Sleep(2 * time.Second)
})
ts := httptest.NewServer(mux)
defer ts.Close()
tr := &Transport{
ResponseHeaderTimeout: 500 * time.Millisecond,
}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
tests := []struct {
path string
want int
wantErr string
}{
{path: "/fast", want: 200},
{path: "/slow", wantErr: "timeout awaiting response headers"},
{path: "/fast", want: 200},
}
for i, tt := range tests {
res, err := c.Get(ts.URL + tt.path)
if err != nil {
if strings.Contains(err.Error(), tt.wantErr) {
continue
}
t.Errorf("%d. unexpected error: %v", i, err)
continue
}
if tt.wantErr != "" {
t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
continue
}
if res.StatusCode != tt.want {
t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
}
}
}
type fooProto struct{} type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) { func (fooProto) RoundTrip(req *Request) (*Response, error) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment