Commit 4d8031cf authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make the MaxBytesReader.Read error sticky

Fixes #14981

Change-Id: I39b906d119ca96815801a0fbef2dbe524a3246ff
Reviewed-on: https://go-review.googlesource.com/23009Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 20e362da
...@@ -885,68 +885,56 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { ...@@ -885,68 +885,56 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
} }
type maxBytesReader struct { type maxBytesReader struct {
w ResponseWriter w ResponseWriter
r io.ReadCloser // underlying reader r io.ReadCloser // underlying reader
n int64 // max bytes remaining n int64 // max bytes remaining
stopped bool err error // sticky error
sawEOF bool
} }
func (l *maxBytesReader) tooLarge() (n int, err error) { func (l *maxBytesReader) tooLarge() (n int, err error) {
if !l.stopped { l.err = errors.New("http: request body too large")
l.stopped = true return 0, l.err
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
}
return 0, errors.New("http: request body too large")
} }
func (l *maxBytesReader) Read(p []byte) (n int, err error) { func (l *maxBytesReader) Read(p []byte) (n int, err error) {
toRead := l.n if l.err != nil {
if l.n == 0 { return 0, l.err
if l.sawEOF { }
return l.tooLarge() if len(p) == 0 {
} return 0, nil
// The underlying io.Reader may not return (0, io.EOF)
// at EOF if the requested size is 0, so read 1 byte
// instead. The io.Reader docs are a bit ambiguous
// about the return value of Read when 0 bytes are
// requested, and {bytes,strings}.Reader gets it wrong
// too (it returns (0, nil) even at EOF).
toRead = 1
} }
if int64(len(p)) > toRead { // If they asked for a 32KB byte read but only 5 bytes are
p = p[:toRead] // remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
} }
n, err = l.r.Read(p) n, err = l.r.Read(p)
if err == io.EOF {
l.sawEOF = true if int64(n) <= l.n {
} l.n -= int64(n)
if l.n == 0 { l.err = err
// If we had zero bytes to read remaining (but hadn't seen EOF) return n, err
// and we get a byte here, that means we went over our limit.
if n > 0 {
return l.tooLarge()
}
return 0, err
} }
l.n -= int64(n)
if l.n < 0 { n = int(l.n)
l.n = 0 l.n = 0
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
} }
return if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = errors.New("http: request body too large")
return n, l.err
} }
func (l *maxBytesReader) Close() error { func (l *maxBytesReader) Close() error {
......
...@@ -679,6 +679,46 @@ func TestIssue10884_MaxBytesEOF(t *testing.T) { ...@@ -679,6 +679,46 @@ func TestIssue10884_MaxBytesEOF(t *testing.T) {
} }
} }
// Issue 14981: MaxBytesReader's return error wasn't sticky. It
// doesn't technically need to be, but people expected it to be.
func TestMaxBytesReaderStickyError(t *testing.T) {
isSticky := func(r io.Reader) error {
var log bytes.Buffer
buf := make([]byte, 1000)
var firstErr error
for {
n, err := r.Read(buf)
fmt.Fprintf(&log, "Read(%d) = %d, %v\n", len(buf), n, err)
if err == nil {
continue
}
if firstErr == nil {
firstErr = err
continue
}
if !reflect.DeepEqual(err, firstErr) {
return fmt.Errorf("non-sticky error. got log:\n%s", log.Bytes())
}
t.Logf("Got log: %s", log.Bytes())
return nil
}
}
tests := [...]struct {
readable int
limit int64
}{
0: {99, 100},
1: {100, 100},
2: {101, 100},
}
for i, tt := range tests {
rc := MaxBytesReader(nil, ioutil.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit)
if err := isSticky(rc); err != nil {
t.Errorf("%d. error: %v", i, err)
}
}
}
func testMissingFile(t *testing.T, req *Request) { func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing") f, fh, err := req.FormFile("missing")
if f != nil { if f != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment