Commit 6bddb13b authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: populate Request.Close in ReadRequest

Fixes #8261

LGTM=r
R=r
CC=golang-codereviews
https://golang.org/cl/126620043
parent a6cd7334
......@@ -11,6 +11,7 @@ import (
"io"
"net/url"
"reflect"
"strings"
"testing"
)
......@@ -295,14 +296,39 @@ var reqTests = []reqTest{
noTrailer,
noError,
},
// Connection: close. golang.org/issue/8261
{
"GET / HTTP/1.1\r\nHost: issue8261.com\r\nConnection: close\r\n\r\n",
&Request{
Method: "GET",
URL: &url.URL{
Path: "/",
},
Header: Header{
// This wasn't removed from Go 1.0 to
// Go 1.3, so locking it in that we
// keep this:
"Connection": []string{"close"},
},
Host: "issue8261.com",
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Close: true,
RequestURI: "/",
},
noBody,
noTrailer,
noError,
},
}
func TestReadRequest(t *testing.T) {
for i := range reqTests {
tt := &reqTests[i]
var braw bytes.Buffer
braw.WriteString(tt.Raw)
req, err := ReadRequest(bufio.NewReader(&braw))
req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.Raw)))
if err != nil {
if err.Error() != tt.Error {
t.Errorf("#%d: error %q, want error %q", i, err.Error(), tt.Error)
......@@ -311,21 +337,22 @@ func TestReadRequest(t *testing.T) {
}
rbody := req.Body
req.Body = nil
diff(t, fmt.Sprintf("#%d Request", i), req, tt.Req)
testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw)
diff(t, testName, req, tt.Req)
var bout bytes.Buffer
if rbody != nil {
_, err := io.Copy(&bout, rbody)
if err != nil {
t.Fatalf("#%d. copying body: %v", i, err)
t.Fatalf("%s: copying body: %v", testName, err)
}
rbody.Close()
}
body := bout.String()
if body != tt.Body {
t.Errorf("#%d: Body = %q want %q", i, body, tt.Body)
t.Errorf("%s: Body = %q want %q", testName, body, tt.Body)
}
if !reflect.DeepEqual(tt.Trailer, req.Trailer) {
t.Errorf("#%d. Trailers differ.\n got: %v\nwant: %v", i, req.Trailer, tt.Trailer)
t.Errorf("%s: Trailers differ.\n got: %v\nwant: %v", testName, req.Trailer, tt.Trailer)
}
}
}
......@@ -635,6 +635,7 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
return nil, err
}
req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false)
return req, nil
}
......
......@@ -303,7 +303,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
t.StatusCode = rr.StatusCode
t.ProtoMajor = rr.ProtoMajor
t.ProtoMinor = rr.ProtoMinor
t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header)
t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true)
isResponse = true
if rr.Request != nil {
t.RequestMethod = rr.Request.Method
......@@ -502,7 +502,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
// Determine whether to hang up after sending a request and body, or
// receiving a response and body
// 'header' is the request headers
func shouldClose(major, minor int, header Header) bool {
func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
if major < 1 {
return true
} else if major == 1 && minor == 0 {
......@@ -514,7 +514,9 @@ func shouldClose(major, minor int, header Header) bool {
// TODO: Should split on commas, toss surrounding white space,
// and check each field.
if strings.ToLower(header.get("Connection")) == "close" {
header.Del("Connection")
if removeCloseHeader {
header.Del("Connection")
}
return true
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment