Commit 6e11f45e authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make Server validate Host headers

Fixes #11206 (that we accept invalid bytes)
Fixes #13624 (that we don't require a Host header in HTTP/1.1 per spec)

Change-Id: I4138281d513998789163237e83bb893aeda43336
Reviewed-on: https://go-review.googlesource.com/17892Reviewed-by: default avatarRuss Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent c2fb457e
...@@ -689,8 +689,9 @@ func putTextprotoReader(r *textproto.Reader) { ...@@ -689,8 +689,9 @@ func putTextprotoReader(r *textproto.Reader) {
} }
// ReadRequest reads and parses an incoming request from b. // ReadRequest reads and parses an incoming request from b.
func ReadRequest(b *bufio.Reader) (req *Request, err error) { func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, true) }
func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) {
tp := newTextprotoReader(b) tp := newTextprotoReader(b)
req = new(Request) req = new(Request)
...@@ -757,7 +758,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { ...@@ -757,7 +758,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
if req.Host == "" { if req.Host == "" {
req.Host = req.Header.get("Host") req.Host = req.Header.get("Host")
} }
delete(req.Header, "Host") if deleteHostHeader {
delete(req.Header, "Host")
}
fixPragmaCacheControl(req.Header) fixPragmaCacheControl(req.Header)
...@@ -1060,3 +1063,59 @@ func (r *Request) isReplayable() bool { ...@@ -1060,3 +1063,59 @@ func (r *Request) isReplayable() bool {
r.Method == "OPTIONS" || r.Method == "OPTIONS" ||
r.Method == "TRACE") r.Method == "TRACE")
} }
func validHostHeader(h string) bool {
// The latests spec is actually this:
//
// http://tools.ietf.org/html/rfc7230#section-5.4
// Host = uri-host [ ":" port ]
//
// Where uri-host is:
// http://tools.ietf.org/html/rfc3986#section-3.2.2
//
// But we're going to be much more lenient for now and just
// search for any byte that's not a valid byte in any of those
// expressions.
for i := 0; i < len(h); i++ {
if !validHostByte[h[i]] {
return false
}
}
return true
}
// See the validHostHeader comment.
var validHostByte = [256]bool{
'0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
'8': true, '9': true,
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true,
'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true,
'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true,
'y': true, 'z': true,
'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true,
'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true,
'Y': true, 'Z': true,
'!': true, // sub-delims
'$': true, // sub-delims
'%': true, // pct-encoded (and used in IPv6 zones)
'&': true, // sub-delims
'(': true, // sub-delims
')': true, // sub-delims
'*': true, // sub-delims
'+': true, // sub-delims
',': true, // sub-delims
'-': true, // unreserved
'.': true, // unreserved
':': true, // IPv6address + Host expression's optional port
';': true, // sub-delims
'=': true, // sub-delims
'[': true,
'\'': true, // sub-delims
']': true,
'_': true, // unreserved
'~': true, // unreserved
}
...@@ -2201,7 +2201,7 @@ func TestClientWriteShutdown(t *testing.T) { ...@@ -2201,7 +2201,7 @@ func TestClientWriteShutdown(t *testing.T) {
// buffered before chunk headers are added, not after chunk headers. // buffered before chunk headers are added, not after chunk headers.
func TestServerBufferedChunking(t *testing.T) { func TestServerBufferedChunking(t *testing.T) {
conn := new(testConn) conn := new(testConn)
conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n")) conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
conn.closec = make(chan bool, 1) conn.closec = make(chan bool, 1)
ls := &oneConnListener{conn} ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
...@@ -2934,9 +2934,9 @@ func TestCodesPreventingContentTypeAndBody(t *testing.T) { ...@@ -2934,9 +2934,9 @@ func TestCodesPreventingContentTypeAndBody(t *testing.T) {
"GET / HTTP/1.0", "GET / HTTP/1.0",
"GET /header HTTP/1.0", "GET /header HTTP/1.0",
"GET /more HTTP/1.0", "GET /more HTTP/1.0",
"GET / HTTP/1.1", "GET / HTTP/1.1\nHost: foo",
"GET /header HTTP/1.1", "GET /header HTTP/1.1\nHost: foo",
"GET /more HTTP/1.1", "GET /more HTTP/1.1\nHost: foo",
} { } {
got := ht.rawResponse(req) got := ht.rawResponse(req)
wantStatus := fmt.Sprintf("%d %s", code, StatusText(code)) wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
...@@ -2957,7 +2957,7 @@ func TestContentTypeOkayOn204(t *testing.T) { ...@@ -2957,7 +2957,7 @@ func TestContentTypeOkayOn204(t *testing.T) {
w.Header().Set("Content-Type", "foo/bar") w.Header().Set("Content-Type", "foo/bar")
w.WriteHeader(204) w.WriteHeader(204)
})) }))
got := ht.rawResponse("GET / HTTP/1.1") got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
if !strings.Contains(got, "Content-Type: foo/bar") { if !strings.Contains(got, "Content-Type: foo/bar") {
t.Errorf("Response = %q; want Content-Type: foo/bar", got) t.Errorf("Response = %q; want Content-Type: foo/bar", got)
} }
...@@ -3628,6 +3628,54 @@ func testHandlerSetsBodyNil(t *testing.T, h2 bool) { ...@@ -3628,6 +3628,54 @@ func testHandlerSetsBodyNil(t *testing.T, h2 bool) {
} }
} }
// Test that we validate the Host header.
func TestServerValidatesHostHeader(t *testing.T) {
tests := []struct {
proto string
host string
want int
}{
{"HTTP/1.1", "", 400},
{"HTTP/1.1", "Host: \r\n", 200},
{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
{"HTTP/1.1", "Host: foo.com\r\n", 200},
{"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
{"HTTP/1.1", "Host: foo.com:80\r\n", 200},
{"HTTP/1.1", "Host: ::1\r\n", 200},
{"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it
{"HTTP/1.1", "Host: [::1]:80\r\n", 200},
{"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
{"HTTP/1.1", "Host: \x06\r\n", 400},
{"HTTP/1.1", "Host: \xff\r\n", 400},
{"HTTP/1.1", "Host: {\r\n", 400},
{"HTTP/1.1", "Host: }\r\n", 400},
{"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
// HTTP/1.0 can lack a host header, but if present
// must play by the rules too:
{"HTTP/1.0", "", 200},
{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
{"HTTP/1.0", "Host: \xff\r\n", 400},
}
for _, tt := range tests {
conn := &testConn{closec: make(chan bool)}
io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n")
ln := &oneConnListener{conn}
go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
<-conn.closec
res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
if err != nil {
t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
continue
}
if res.StatusCode != tt.want {
t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
}
}
}
func BenchmarkClientServer(b *testing.B) { func BenchmarkClientServer(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.StopTimer() b.StopTimer()
......
...@@ -686,7 +686,7 @@ func (c *conn) readRequest() (w *response, err error) { ...@@ -686,7 +686,7 @@ func (c *conn) readRequest() (w *response, err error) {
peek, _ := c.bufr.Peek(4) // ReadRequest will get err below peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
c.bufr.Discard(numLeadingCRorLF(peek)) c.bufr.Discard(numLeadingCRorLF(peek))
} }
req, err := ReadRequest(c.bufr) req, err := readRequest(c.bufr, false)
c.mu.Unlock() c.mu.Unlock()
if err != nil { if err != nil {
if c.r.hitReadLimit() { if c.r.hitReadLimit() {
...@@ -697,6 +697,18 @@ func (c *conn) readRequest() (w *response, err error) { ...@@ -697,6 +697,18 @@ func (c *conn) readRequest() (w *response, err error) {
c.lastMethod = req.Method c.lastMethod = req.Method
c.r.setInfiniteReadLimit() c.r.setInfiniteReadLimit()
hosts, haveHost := req.Header["Host"]
if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) {
return nil, badRequestError("missing required Host header")
}
if len(hosts) > 1 {
return nil, badRequestError("too many Host headers")
}
if len(hosts) == 1 && !validHostHeader(hosts[0]) {
return nil, badRequestError("malformed Host header")
}
delete(req.Header, "Host")
req.RemoteAddr = c.remoteAddr req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState req.TLS = c.tlsState
if body, ok := req.Body.(*body); ok { if body, ok := req.Body.(*body); ok {
...@@ -1334,6 +1346,13 @@ func (c *conn) setState(nc net.Conn, state ConnState) { ...@@ -1334,6 +1346,13 @@ func (c *conn) setState(nc net.Conn, state ConnState) {
} }
} }
// badRequestError is a literal string (used by in the server in HTML,
// unescaped) to tell the user why their request was bad. It should
// be plain text without user info or other embeddded errors.
type badRequestError string
func (e badRequestError) Error() string { return "Bad Request: " + string(e) }
// Serve a new connection. // Serve a new connection.
func (c *conn) serve() { func (c *conn) serve() {
c.remoteAddr = c.rwc.RemoteAddr().String() c.remoteAddr = c.rwc.RemoteAddr().String()
...@@ -1399,7 +1418,11 @@ func (c *conn) serve() { ...@@ -1399,7 +1418,11 @@ func (c *conn) serve() {
if neterr, ok := err.(net.Error); ok && neterr.Timeout() { if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return // don't reply return // don't reply
} }
io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request") var publicErr string
if v, ok := err.(badRequestError); ok {
publicErr = ": " + string(v)
}
io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr)
return return
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment