Commit c69e6869 authored by Caio Marcelo de Oliveira Filho's avatar Caio Marcelo de Oliveira Filho Committed by Brad Fitzpatrick

net/http/httptest: record trailing headers in ResponseRecorder

Trailers() returns the headers that were set by the handler after the
headers were written "to the wire" (in this case HeaderMap) and that
were also specified in a proper header called "Trailer".

Neither HeaderMap or trailerMap (used for Trailers()) are manipulated by
the handler code, instead a third stagingMap is given to the
handler. This avoid a reference kept by handler to affect the recorded
results.

If a handler just modify the header but doesn't call any Write or Flush
method from ResponseWriter (or Flusher) interface, HeaderMap will not be
updated. In this case, calling Flush in the recorder is enough to get
the HeaderMap filled.

Fixes #14531.
Fixes #8857.

Change-Id: I42842341ec3e95c7b87d7e6f178c65cd03d63cc3
Reviewed-on: https://go-review.googlesource.com/20047Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 4c8589c3
...@@ -18,6 +18,9 @@ type ResponseRecorder struct { ...@@ -18,6 +18,9 @@ type ResponseRecorder struct {
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool Flushed bool
stagingMap http.Header // map that handlers manipulate to set headers
trailerMap http.Header // lazily filled when Trailers() is called
wroteHeader bool wroteHeader bool
} }
...@@ -36,10 +39,10 @@ const DefaultRemoteAddr = "1.2.3.4" ...@@ -36,10 +39,10 @@ const DefaultRemoteAddr = "1.2.3.4"
// Header returns the response headers. // Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header { func (rw *ResponseRecorder) Header() http.Header {
m := rw.HeaderMap m := rw.stagingMap
if m == nil { if m == nil {
m = make(http.Header) m = make(http.Header)
rw.HeaderMap = m rw.stagingMap = m
} }
return m return m
} }
...@@ -59,16 +62,15 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) { ...@@ -59,16 +62,15 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
str = str[:512] str = str[:512]
} }
_, hasType := rw.HeaderMap["Content-Type"] m := rw.Header()
hasTE := rw.HeaderMap.Get("Transfer-Encoding") != ""
_, hasType := m["Content-Type"]
hasTE := m.Get("Transfer-Encoding") != ""
if !hasType && !hasTE { if !hasType && !hasTE {
if b == nil { if b == nil {
b = []byte(str) b = []byte(str)
} }
if rw.HeaderMap == nil { m.Set("Content-Type", http.DetectContentType(b))
rw.HeaderMap = make(http.Header)
}
rw.HeaderMap.Set("Content-Type", http.DetectContentType(b))
} }
rw.WriteHeader(200) rw.WriteHeader(200)
...@@ -92,11 +94,21 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) { ...@@ -92,11 +94,21 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) {
return len(str), nil return len(str), nil
} }
// WriteHeader sets rw.Code. // WriteHeader sets rw.Code. After it is called, changing rw.Header
// will not affect rw.HeaderMap.
func (rw *ResponseRecorder) WriteHeader(code int) { func (rw *ResponseRecorder) WriteHeader(code int) {
if !rw.wroteHeader { if rw.wroteHeader {
return
}
rw.Code = code rw.Code = code
rw.wroteHeader = true rw.wroteHeader = true
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
}
for k, vv := range rw.stagingMap {
vv2 := make([]string, len(vv))
copy(vv2, vv)
rw.HeaderMap[k] = vv2
} }
} }
...@@ -107,3 +119,33 @@ func (rw *ResponseRecorder) Flush() { ...@@ -107,3 +119,33 @@ func (rw *ResponseRecorder) Flush() {
} }
rw.Flushed = true rw.Flushed = true
} }
// Trailers returns any trailers set by the handler. It must be called
// after the handler finished running.
func (rw *ResponseRecorder) Trailers() http.Header {
if rw.trailerMap != nil {
return rw.trailerMap
}
trailers, ok := rw.HeaderMap["Trailer"]
if !ok {
rw.trailerMap = make(http.Header)
return rw.trailerMap
}
rw.trailerMap = make(http.Header, len(trailers))
for _, k := range trailers {
switch k {
case "Transfer-Encoding", "Content-Length", "Trailer":
// Ignore since forbidden by RFC 2616 14.40.
continue
}
k = http.CanonicalHeaderKey(k)
vv, ok := rw.stagingMap[k]
if !ok {
continue
}
vv2 := make([]string, len(vv))
copy(vv2, vv)
rw.trailerMap[k] = vv2
}
return rw.trailerMap
}
...@@ -47,6 +47,37 @@ func TestRecorder(t *testing.T) { ...@@ -47,6 +47,37 @@ func TestRecorder(t *testing.T) {
return nil return nil
} }
} }
hasNotHeaders := func(keys ...string) checkFunc {
return func(rec *ResponseRecorder) error {
for _, k := range keys {
_, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
if ok {
return fmt.Errorf("unexpected header %s", k)
}
}
return nil
}
}
hasTrailer := func(key, want string) checkFunc {
return func(rec *ResponseRecorder) error {
if got := rec.Trailers().Get(key); got != want {
return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
}
return nil
}
}
hasNotTrailers := func(keys ...string) checkFunc {
return func(rec *ResponseRecorder) error {
trailers := rec.Trailers()
for _, k := range keys {
_, ok := trailers[http.CanonicalHeaderKey(k)]
if ok {
return fmt.Errorf("unexpected trailer %s", k)
}
}
return nil
}
}
tests := []struct { tests := []struct {
name string name string
...@@ -130,6 +161,39 @@ func TestRecorder(t *testing.T) { ...@@ -130,6 +161,39 @@ func TestRecorder(t *testing.T) {
}, },
check(hasHeader("Content-Type", "text/html; charset=utf-8")), check(hasHeader("Content-Type", "text/html; charset=utf-8")),
}, },
{
"Header is not changed after write",
func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
hdr.Set("Key", "correct")
w.WriteHeader(200)
hdr.Set("Key", "incorrect")
},
check(hasHeader("Key", "correct")),
},
{
"Trailer headers are correctly recorded",
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Non-Trailer", "correct")
w.Header().Set("Trailer", "Trailer-A")
w.Header().Add("Trailer", "Trailer-B")
w.Header().Add("Trailer", "Trailer-C")
io.WriteString(w, "<html>")
w.Header().Set("Non-Trailer", "incorrect")
w.Header().Set("Trailer-A", "valuea")
w.Header().Set("Trailer-C", "valuec")
w.Header().Set("Trailer-NotDeclared", "should be omitted")
},
check(
hasStatus(200),
hasHeader("Content-Type", "text/html; charset=utf-8"),
hasHeader("Non-Trailer", "correct"),
hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
hasTrailer("Trailer-A", "valuea"),
hasTrailer("Trailer-C", "valuec"),
hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
),
},
} }
r, _ := http.NewRequest("GET", "http://foo.com/", nil) r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests { for _, tt := range tests {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment