Commit ae080c1a authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: handle "close" amongst multiple Connection tokens

Fixes #8840

Change-Id: I194d0248734c15336f91a6bcf57ffcc9c0a3a435
Reviewed-on: https://go-review.googlesource.com/9434Reviewed-by: default avatarDavid Crawshaw <crawshaw@golang.org>
parent 0774f6db
...@@ -4,6 +4,11 @@ ...@@ -4,6 +4,11 @@
package http package http
import (
"strings"
"unicode/utf8"
)
// This file deals with lexical matters of HTTP // This file deals with lexical matters of HTTP
var isTokenTable = [127]bool{ var isTokenTable = [127]bool{
...@@ -94,3 +99,71 @@ func isToken(r rune) bool { ...@@ -94,3 +99,71 @@ func isToken(r rune) bool {
func isNotToken(r rune) bool { func isNotToken(r rune) bool {
return !isToken(r) return !isToken(r)
} }
// headerValuesContainsToken reports whether any string in values
// contains the provided token, ASCII case-insensitively.
func headerValuesContainsToken(values []string, token string) bool {
for _, v := range values {
if headerValueContainsToken(v, token) {
return true
}
}
return false
}
// isOWS reports whether b is an optional whitespace byte, as defined
// by RFC 7230 section 3.2.3.
func isOWS(b byte) bool { return b == ' ' || b == '\t' }
// trimOWS returns x with all optional whitespace removes from the
// beginning and end.
func trimOWS(x string) string {
// TODO: consider using strings.Trim(x, " \t") instead,
// if and when it's fast enough. See issue 10292.
// But this ASCII-only code will probably always beat UTF-8
// aware code.
for len(x) > 0 && isOWS(x[0]) {
x = x[1:]
}
for len(x) > 0 && isOWS(x[len(x)-1]) {
x = x[:len(x)-1]
}
return x
}
// headerValueContainsToken reports whether v (assumed to be a
// 0#element, in the ABNF extension described in RFC 7230 section 7)
// contains token amongst its comma-separated tokens, ASCII
// case-insensitively.
func headerValueContainsToken(v string, token string) bool {
v = trimOWS(v)
if comma := strings.IndexByte(v, ','); comma != -1 {
return tokenEqual(trimOWS(v[:comma]), token) || headerValueContainsToken(v[comma+1:], token)
}
return tokenEqual(v, token)
}
// lowerASCII returns the ASCII lowercase version of b.
func lowerASCII(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively.
func tokenEqual(t1, t2 string) bool {
if len(t1) != len(t2) {
return false
}
for i, b := range t1 {
if b >= utf8.RuneSelf {
// No UTF-8 or non-ASCII allowed in tokens.
return false
}
if lowerASCII(byte(b)) != lowerASCII(t2[i]) {
return false
}
}
return true
}
...@@ -29,3 +29,73 @@ func TestIsToken(t *testing.T) { ...@@ -29,3 +29,73 @@ func TestIsToken(t *testing.T) {
} }
} }
} }
func TestHeaderValuesContainsToken(t *testing.T) {
tests := []struct {
vals []string
token string
want bool
}{
{
vals: []string{"foo"},
token: "foo",
want: true,
},
{
vals: []string{"bar", "foo"},
token: "foo",
want: true,
},
{
vals: []string{"foo"},
token: "FOO",
want: true,
},
{
vals: []string{"foo"},
token: "bar",
want: false,
},
{
vals: []string{" foo "},
token: "FOO",
want: true,
},
{
vals: []string{"foo,bar"},
token: "FOO",
want: true,
},
{
vals: []string{"bar,foo,bar"},
token: "FOO",
want: true,
},
{
vals: []string{"bar , foo"},
token: "FOO",
want: true,
},
{
vals: []string{"foo ,bar "},
token: "FOO",
want: true,
},
{
vals: []string{"bar, foo ,bar"},
token: "FOO",
want: true,
},
{
vals: []string{"bar , foo"},
token: "FOO",
want: true,
},
}
for _, tt := range tests {
got := headerValuesContainsToken(tt.vals, tt.token)
if got != tt.want {
t.Errorf("headerValuesContainsToken(%q, %q) = %v; want %v", tt.vals, tt.token, got, tt.want)
}
}
}
...@@ -405,6 +405,57 @@ some body`, ...@@ -405,6 +405,57 @@ some body`,
"foobar", "foobar",
}, },
// Both keep-alive and close, on the same Connection line. (Issue 8840)
{
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 256\r\n" +
"Connection: keep-alive, close\r\n" +
"\r\n",
Response{
Status: "200 OK",
StatusCode: 200,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq("HEAD"),
Header: Header{
"Content-Length": {"256"},
},
TransferEncoding: nil,
Close: true,
ContentLength: 256,
},
"",
},
// Both keep-alive and close, on different Connection lines. (Issue 8840)
{
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 256\r\n" +
"Connection: keep-alive\r\n" +
"Connection: close\r\n" +
"\r\n",
Response{
Status: "200 OK",
StatusCode: 200,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq("HEAD"),
Header: Header{
"Content-Length": {"256"},
},
TransferEncoding: nil,
Close: true,
ContentLength: 256,
},
"",
},
} }
func TestReadResponse(t *testing.T) { func TestReadResponse(t *testing.T) {
......
...@@ -508,14 +508,13 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { ...@@ -508,14 +508,13 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
if major < 1 { if major < 1 {
return true return true
} else if major == 1 && minor == 0 { } else if major == 1 && minor == 0 {
if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") { vv := header["Connection"]
if headerValuesContainsToken(vv, "close") || !headerValuesContainsToken(vv, "keep-alive") {
return true return true
} }
return false return false
} else { } else {
// TODO: Should split on commas, toss surrounding white space, if headerValuesContainsToken(header["Connection"], "close") {
// and check each field.
if strings.ToLower(header.get("Connection")) == "close" {
if removeCloseHeader { if removeCloseHeader {
header.Del("Connection") header.Del("Connection")
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment