Commit 59699aa1 authored by Jack's avatar Jack Committed by Brad Fitzpatrick

net/http/httptest: guarantee ResponseRecorder.Result returns a non-nil body

The doc for ResponseRecorder.Result guarantees that the body of the returned
http.Response will be non-nil, but this only holds true if the caller's body is
non-nil. With this change, if the caller's body is nil then the returned
response's body will be an empty io.ReadCloser.

Fixes #26442

Change-Id: I3b2fe4a2541caf9997dbb8978bbaf1f58cd1f471
GitHub-Last-Rev: d802967d89e89c50077fb2d0d455163dcea0eb43
GitHub-Pull-Request: golang/go#26453
Reviewed-on: https://go-review.googlesource.com/124875Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
parent f4c787b6
...@@ -184,6 +184,8 @@ func (rw *ResponseRecorder) Result() *http.Response { ...@@ -184,6 +184,8 @@ func (rw *ResponseRecorder) Result() *http.Response {
res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
if rw.Body != nil { if rw.Body != nil {
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
} else {
res.Body = http.NoBody
} }
res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
......
...@@ -7,6 +7,7 @@ package httptest ...@@ -7,6 +7,7 @@ package httptest
import ( import (
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"testing" "testing"
) )
...@@ -39,6 +40,19 @@ func TestRecorder(t *testing.T) { ...@@ -39,6 +40,19 @@ func TestRecorder(t *testing.T) {
return nil return nil
} }
} }
hasResultContents := func(want string) checkFunc {
return func(rec *ResponseRecorder) error {
contentBytes, err := ioutil.ReadAll(rec.Result().Body)
if err != nil {
return err
}
contents := string(contentBytes)
if contents != want {
return fmt.Errorf("Result().Body = %s; want %s", contents, want)
}
return nil
}
}
hasContents := func(want string) checkFunc { hasContents := func(want string) checkFunc {
return func(rec *ResponseRecorder) error { return func(rec *ResponseRecorder) error {
if rec.Body.String() != want { if rec.Body.String() != want {
...@@ -273,6 +287,15 @@ func TestRecorder(t *testing.T) { ...@@ -273,6 +287,15 @@ func TestRecorder(t *testing.T) {
}, },
check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
}, },
{
"nil ResponseRecorder.Body", // Issue 26642
func(w http.ResponseWriter, r *http.Request) {
w.(*ResponseRecorder).Body = nil
io.WriteString(w, "hi")
},
check(hasResultContents("")), // check we don't crash reading the body
},
} { } {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest("GET", "http://foo.com/", nil) r, _ := http.NewRequest("GET", "http://foo.com/", nil)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment