Commit 36feb1a0 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: limit Transport's reading of response header bytes from servers

The default is 10MB, like http2, but can be configured with a new
field http.Transport.MaxResponseHeaderBytes.

Fixes #9115

Change-Id: I01808ac631ce4794ef2b0dfc391ed51cf951ceb1
Reviewed-on: https://go-review.googlesource.com/21329
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: default avatarEmmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
parent 7a4211bc
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
// maxInt64 is the effective "infinite" value for the Server and
// Transport's byte-limiting readers.
const maxInt64 = 1<<63 - 1
// TODO(bradfitz): move common stuff here. The other files have accumulated
// generic http stuff in random places.
...@@ -497,7 +497,7 @@ type connReader struct { ...@@ -497,7 +497,7 @@ type connReader struct {
} }
func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 } func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
func (cr *connReader) Read(p []byte) (n int, err error) { func (cr *connReader) Read(p []byte) (n int, err error) {
......
...@@ -146,6 +146,13 @@ type Transport struct { ...@@ -146,6 +146,13 @@ type Transport struct {
// If TLSNextProto is nil, HTTP/2 support is enabled automatically. // If TLSNextProto is nil, HTTP/2 support is enabled automatically.
TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
// MaxResponseHeaderBytes specifies a limit on how many
// response bytes are allowed in the server's response
// header.
//
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
// nextProtoOnce guards initialization of TLSNextProto and // nextProtoOnce guards initialization of TLSNextProto and
// h2transport (via onceSetNextProtoDefaults) // h2transport (via onceSetNextProtoDefaults)
nextProtoOnce sync.Once nextProtoOnce sync.Once
...@@ -188,8 +195,23 @@ func (t *Transport) onceSetNextProtoDefaults() { ...@@ -188,8 +195,23 @@ func (t *Transport) onceSetNextProtoDefaults() {
t2, err := http2configureTransport(t) t2, err := http2configureTransport(t)
if err != nil { if err != nil {
log.Printf("Error enabling Transport HTTP/2 support: %v", err) log.Printf("Error enabling Transport HTTP/2 support: %v", err)
} else { return
t.h2transport = t2 }
t.h2transport = t2
// Auto-configure the http2.Transport's MaxHeaderListSize from
// the http.Transport's MaxResponseHeaderBytes. They don't
// exactly mean the same thing, but they're close.
//
// TODO: also add this to x/net/http2.Configure Transport, behind
// a +build go1.7 build tag:
if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 {
const h2max = 1<<32 - 1
if limit1 >= h2max {
t2.MaxHeaderListSize = h2max
} else {
t2.MaxHeaderListSize = uint32(limit1)
}
} }
} }
...@@ -351,7 +373,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { ...@@ -351,7 +373,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
// resent on a new connection. The non-nil input error is the error from // resent on a new connection. The non-nil input error is the error from
// roundTrip, which might be wrapped in a beforeRespHeaderError error. // roundTrip, which might be wrapped in a beforeRespHeaderError error.
// //
// The return value is err or the unwrapped error inside a // The return value is either nil to retry the request, the provided
// err unmodified, or the unwrapped error inside a
// beforeRespHeaderError. // beforeRespHeaderError.
func checkTransportResend(err error, req *Request, pconn *persistConn) error { func checkTransportResend(err error, req *Request, pconn *persistConn) error {
brhErr, ok := err.(beforeRespHeaderError) brhErr, ok := err.(beforeRespHeaderError)
...@@ -864,7 +887,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { ...@@ -864,7 +887,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
} }
} }
pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF}) pconn.br = bufio.NewReader(pconn)
pconn.bw = bufio.NewWriter(pconn.conn) pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop() go pconn.readLoop()
go pconn.writeLoop() go pconn.writeLoop()
...@@ -998,17 +1021,18 @@ type persistConn struct { ...@@ -998,17 +1021,18 @@ type persistConn struct {
// If it's non-nil, the rest of the fields are unused. // If it's non-nil, the rest of the fields are unused.
alt RoundTripper alt RoundTripper
t *Transport t *Transport
cacheKey connectMethodKey cacheKey connectMethodKey
conn net.Conn conn net.Conn
tlsState *tls.ConnectionState tlsState *tls.ConnectionState
br *bufio.Reader // from conn br *bufio.Reader // from conn
sawEOF bool // whether we've seen EOF from conn; owned by readLoop bw *bufio.Writer // to conn
bw *bufio.Writer // to conn reqch chan requestAndChan // written by roundTrip; read by readLoop
reqch chan requestAndChan // written by roundTrip; read by readLoop writech chan writeRequest // written by roundTrip; read by writeLoop
writech chan writeRequest // written by roundTrip; read by writeLoop closech chan struct{} // closed when conn closed
closech chan struct{} // closed when conn closed isProxy bool
isProxy bool sawEOF bool // whether we've seen EOF from conn; owned by readLoop
readLimit int64 // bytes allowed to be read; owned by readLoop
// writeErrCh passes the request write error (usually nil) // writeErrCh passes the request write error (usually nil)
// from the writeLoop goroutine to the readLoop which passes // from the writeLoop goroutine to the readLoop which passes
// it off to the res.Body reader, which then uses it to decide // it off to the res.Body reader, which then uses it to decide
...@@ -1027,6 +1051,28 @@ type persistConn struct { ...@@ -1027,6 +1051,28 @@ type persistConn struct {
mutateHeaderFunc func(Header) mutateHeaderFunc func(Header)
} }
func (pc *persistConn) maxHeaderResponseSize() int64 {
if v := pc.t.MaxResponseHeaderBytes; v != 0 {
return v
}
return 10 << 20 // conservative default; same as http2
}
func (pc *persistConn) Read(p []byte) (n int, err error) {
if pc.readLimit <= 0 {
return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize())
}
if int64(len(p)) > pc.readLimit {
p = p[:pc.readLimit]
}
n, err = pc.conn.Read(p)
if err == io.EOF {
pc.sawEOF = true
}
pc.readLimit -= int64(n)
return
}
// isBroken reports whether this connection is in a known broken state. // isBroken reports whether this connection is in a known broken state.
func (pc *persistConn) isBroken() bool { func (pc *persistConn) isBroken() bool {
pc.mu.Lock() pc.mu.Lock()
...@@ -1082,6 +1128,7 @@ func (pc *persistConn) readLoop() { ...@@ -1082,6 +1128,7 @@ func (pc *persistConn) readLoop() {
alive := true alive := true
for alive { for alive {
pc.readLimit = pc.maxHeaderResponseSize()
_, err := pc.br.Peek(1) _, err := pc.br.Peek(1)
if err != nil { if err != nil {
err = beforeRespHeaderError{err} err = beforeRespHeaderError{err}
...@@ -1103,6 +1150,9 @@ func (pc *persistConn) readLoop() { ...@@ -1103,6 +1150,9 @@ func (pc *persistConn) readLoop() {
} }
if err != nil { if err != nil {
if pc.readLimit <= 0 {
err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize())
}
// If we won't be able to retry this request later (from the // If we won't be able to retry this request later (from the
// roundTrip goroutine), mark it as done now. // roundTrip goroutine), mark it as done now.
// BEFORE the send on rc.ch, as the client might re-use the // BEFORE the send on rc.ch, as the client might re-use the
...@@ -1120,6 +1170,7 @@ func (pc *persistConn) readLoop() { ...@@ -1120,6 +1170,7 @@ func (pc *persistConn) readLoop() {
} }
return return
} }
pc.readLimit = maxInt64 // effictively no limit for response bodies
pc.mu.Lock() pc.mu.Lock()
pc.numExpectedResponses-- pc.numExpectedResponses--
...@@ -1251,6 +1302,7 @@ func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err erro ...@@ -1251,6 +1302,7 @@ func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err erro
} }
} }
if resp.StatusCode == 100 { if resp.StatusCode == 100 {
pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
resp, err = ReadResponse(pc.br, rc.req) resp, err = ReadResponse(pc.br, rc.req)
if err != nil { if err != nil {
return return
...@@ -1706,19 +1758,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true } ...@@ -1706,19 +1758,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true } func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
type noteEOFReader struct {
r io.Reader
sawEOF *bool
}
func (nr noteEOFReader) Read(p []byte) (n int, err error) {
n, err = nr.r.Read(p)
if err == io.EOF {
*nr.sawEOF = true
}
return
}
// fakeLocker is a sync.Locker which does nothing. It's used to guard // fakeLocker is a sync.Locker which does nothing. It's used to guard
// test-only fields when not under test, to avoid runtime atomic // test-only fields when not under test, to avoid runtime atomic
// overhead. // overhead.
......
...@@ -3090,6 +3090,42 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { ...@@ -3090,6 +3090,42 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
} }
} }
func TestTransportResponseHeaderLength(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.URL.Path == "/long" {
w.Header().Set("Long", strings.Repeat("a", 1<<20))
}
}))
defer ts.Close()
tr := &Transport{
MaxResponseHeaderBytes: 512 << 10,
}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
if res, err := c.Get(ts.URL); err != nil {
t.Fatal(err)
} else {
res.Body.Close()
}
res, err := c.Get(ts.URL + "/long")
if err == nil {
defer res.Body.Close()
var n int64
for k, vv := range res.Header {
for _, v := range vv {
n += int64(len(k)) + int64(len(v))
}
}
t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
}
if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
t.Errorf("got error: %v; want %q", err, want)
}
}
var errFakeRoundTrip = errors.New("fake roundtrip") var errFakeRoundTrip = errors.New("fake roundtrip")
type funcRoundTripper func() type funcRoundTripper func()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment