Commit d24a9785 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http: configurable and default request header size limit

This addresses the biggest DoS in issue 2093

R=golang-dev, dsymonds
CC=golang-dev
https://golang.org/cl/4841050
parent 29df7bb7
...@@ -873,6 +873,28 @@ func TestStripPrefix(t *testing.T) { ...@@ -873,6 +873,28 @@ func TestStripPrefix(t *testing.T) {
} }
} }
func TestRequestLimit(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
t.Fatalf("didn't expect to get request in Handler")
}))
defer ts.Close()
req, _ := NewRequest("GET", ts.URL, nil)
var bytesPerHeader = len("header12345: val12345\r\n")
for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
}
res, err := DefaultClient.Do(req)
if err != nil {
// Some HTTP clients may fail on this undefined behavior (server replying and
// closing the connection while the request is still being written), but
// we do support it (at least currently), so we expect a response below.
t.Fatalf("Do: %v", err)
}
if res.StatusCode != 400 {
t.Fatalf("expected 400 response status; got: %d %s", res.StatusCode, res.Status)
}
}
type errorListener struct { type errorListener struct {
errs []os.Error errs []os.Error
} }
......
...@@ -94,9 +94,10 @@ type Hijacker interface { ...@@ -94,9 +94,10 @@ type Hijacker interface {
// A conn represents the server side of an HTTP connection. // A conn represents the server side of an HTTP connection.
type conn struct { type conn struct {
remoteAddr string // network address of remote side remoteAddr string // network address of remote side
handler Handler // request handler server *Server // the Server on which the connection arrived
rwc net.Conn // i/o connection rwc net.Conn // i/o connection
buf *bufio.ReadWriter // buffered rwc lr *io.LimitedReader // io.LimitReader(rwc)
buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc
hijacked bool // connection has been hijacked by handler hijacked bool // connection has been hijacked by handler
tlsState *tls.ConnectionState // or nil when not using TLS tlsState *tls.ConnectionState // or nil when not using TLS
body []byte body []byte
...@@ -143,14 +144,18 @@ func (r *response) ReadFrom(src io.Reader) (n int64, err os.Error) { ...@@ -143,14 +144,18 @@ func (r *response) ReadFrom(src io.Reader) (n int64, err os.Error) {
return io.Copy(writerOnly{r}, src) return io.Copy(writerOnly{r}, src)
} }
// noLimit is an effective infinite upper bound for io.LimitedReader
const noLimit int64 = (1 << 63) - 1
// Create new connection from rwc. // Create new connection from rwc.
func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { func (srv *Server) newConn(rwc net.Conn) (c *conn, err os.Error) {
c = new(conn) c = new(conn)
c.remoteAddr = rwc.RemoteAddr().String() c.remoteAddr = rwc.RemoteAddr().String()
c.handler = handler c.server = srv
c.rwc = rwc c.rwc = rwc
c.body = make([]byte, sniffLen) c.body = make([]byte, sniffLen)
br := bufio.NewReader(rwc) c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
br := bufio.NewReader(c.lr)
bw := bufio.NewWriter(rwc) bw := bufio.NewWriter(rwc)
c.buf = bufio.NewReadWriter(br, bw) c.buf = bufio.NewReadWriter(br, bw)
...@@ -163,6 +168,18 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { ...@@ -163,6 +168,18 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) {
return c, nil return c, nil
} }
// DefaultMaxHeaderBytes is the maximum permitted size of the headers
// in an HTTP request.
// This can be overridden by setting Server.MaxHeaderBytes.
const DefaultMaxHeaderBytes = 1 << 20 // 1 MB
func (srv *Server) maxHeaderBytes() int {
if srv.MaxHeaderBytes > 0 {
return srv.MaxHeaderBytes
}
return DefaultMaxHeaderBytes
}
// wrapper around io.ReaderCloser which on first read, sends an // wrapper around io.ReaderCloser which on first read, sends an
// HTTP/1.1 100 Continue header // HTTP/1.1 100 Continue header
type expectContinueReader struct { type expectContinueReader struct {
...@@ -194,15 +211,22 @@ func (ecr *expectContinueReader) Close() os.Error { ...@@ -194,15 +211,22 @@ func (ecr *expectContinueReader) Close() os.Error {
// It is like time.RFC1123 but hard codes GMT as the time zone. // It is like time.RFC1123 but hard codes GMT as the time zone.
const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
var errTooLarge = os.NewError("http: request too large")
// Read next request from connection. // Read next request from connection.
func (c *conn) readRequest() (w *response, err os.Error) { func (c *conn) readRequest() (w *response, err os.Error) {
if c.hijacked { if c.hijacked {
return nil, ErrHijacked return nil, ErrHijacked
} }
c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
var req *Request var req *Request
if req, err = ReadRequest(c.buf.Reader); err != nil { if req, err = ReadRequest(c.buf.Reader); err != nil {
if c.lr.N == 0 {
return nil, errTooLarge
}
return nil, err return nil, err
} }
c.lr.N = noLimit
req.RemoteAddr = c.remoteAddr req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState req.TLS = c.tlsState
...@@ -567,6 +591,14 @@ func (c *conn) serve() { ...@@ -567,6 +591,14 @@ func (c *conn) serve() {
for { for {
w, err := c.readRequest() w, err := c.readRequest()
if err != nil { if err != nil {
if err == errTooLarge {
// Their HTTP client may or may not be
// able to read this if we're
// responding to them and hanging up
// while they're still writing their
// request. Undefined behavior.
fmt.Fprintf(c.rwc, "HTTP/1.1 400 Request Too Large\r\n\r\n")
}
break break
} }
...@@ -603,12 +635,17 @@ func (c *conn) serve() { ...@@ -603,12 +635,17 @@ func (c *conn) serve() {
break break
} }
handler := c.server.Handler
if handler == nil {
handler = DefaultServeMux
}
// HTTP cannot have multiple simultaneous active requests.[*] // HTTP cannot have multiple simultaneous active requests.[*]
// Until the server replies to this request, it can't read another, // Until the server replies to this request, it can't read another,
// so we might as well run the handler in this goroutine. // so we might as well run the handler in this goroutine.
// [*] Not strictly true: HTTP pipelining. We could let them all process // [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized. // in parallel even if their responses need to be serialized.
c.handler.ServeHTTP(w, w.req) handler.ServeHTTP(w, w.req)
if c.hijacked { if c.hijacked {
return return
} }
...@@ -906,10 +943,11 @@ func Serve(l net.Listener, handler Handler) os.Error { ...@@ -906,10 +943,11 @@ func Serve(l net.Listener, handler Handler) os.Error {
// A Server defines parameters for running an HTTP server. // A Server defines parameters for running an HTTP server.
type Server struct { type Server struct {
Addr string // TCP address to listen on, ":http" if empty Addr string // TCP address to listen on, ":http" if empty
Handler Handler // handler to invoke, http.DefaultServeMux if nil Handler Handler // handler to invoke, http.DefaultServeMux if nil
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0
} }
// ListenAndServe listens on the TCP network address srv.Addr and then // ListenAndServe listens on the TCP network address srv.Addr and then
...@@ -932,10 +970,6 @@ func (srv *Server) ListenAndServe() os.Error { ...@@ -932,10 +970,6 @@ func (srv *Server) ListenAndServe() os.Error {
// then call srv.Handler to reply to them. // then call srv.Handler to reply to them.
func (srv *Server) Serve(l net.Listener) os.Error { func (srv *Server) Serve(l net.Listener) os.Error {
defer l.Close() defer l.Close()
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
for { for {
rw, e := l.Accept() rw, e := l.Accept()
if e != nil { if e != nil {
...@@ -951,7 +985,7 @@ func (srv *Server) Serve(l net.Listener) os.Error { ...@@ -951,7 +985,7 @@ func (srv *Server) Serve(l net.Listener) os.Error {
if srv.WriteTimeout != 0 { if srv.WriteTimeout != 0 {
rw.SetWriteTimeout(srv.WriteTimeout) rw.SetWriteTimeout(srv.WriteTimeout)
} }
c, err := newConn(rw, handler) c, err := srv.newConn(rw)
if err != nil { if err != nil {
continue continue
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment