Commit 281088b1 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Server.ErrorLog; log and test TLS handshake errors

Fixes #7291

LGTM=agl
R=golang-codereviews, agl
CC=agl, golang-codereviews
https://golang.org/cl/70250044
parent 91e36811
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
. "net/http" . "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -23,6 +24,7 @@ import ( ...@@ -23,6 +24,7 @@ import (
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time"
) )
var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
...@@ -54,6 +56,13 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { ...@@ -54,6 +56,13 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) {
} }
} }
type chanWriter chan string
func (w chanWriter) Write(p []byte) (n int, err error) {
w <- string(p)
return len(p), nil
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
defer afterTest(t) defer afterTest(t)
ts := httptest.NewServer(robotsTxtHandler) ts := httptest.NewServer(robotsTxtHandler)
...@@ -564,6 +573,8 @@ func TestClientInsecureTransport(t *testing.T) { ...@@ -564,6 +573,8 @@ func TestClientInsecureTransport(t *testing.T) {
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello")) w.Write([]byte("Hello"))
})) }))
errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ErrorLog = log.New(errc, "", 0)
defer ts.Close() defer ts.Close()
// TODO(bradfitz): add tests for skipping hostname checks too? // TODO(bradfitz): add tests for skipping hostname checks too?
...@@ -585,6 +596,16 @@ func TestClientInsecureTransport(t *testing.T) { ...@@ -585,6 +596,16 @@ func TestClientInsecureTransport(t *testing.T) {
res.Body.Close() res.Body.Close()
} }
} }
select {
case v := <-errc:
if !strings.Contains(v, "bad certificate") {
t.Errorf("expected an error log message containing 'bad certificate'; got %q", v)
}
case <-time.After(5 * time.Second):
t.Errorf("timeout waiting for logged error")
}
} }
func TestClientErrorWithRequestURI(t *testing.T) { func TestClientErrorWithRequestURI(t *testing.T) {
...@@ -635,6 +656,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { ...@@ -635,6 +656,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
defer afterTest(t) defer afterTest(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close() defer ts.Close()
errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ErrorLog = log.New(errc, "", 0)
trans := newTLSTransport(t, ts) trans := newTLSTransport(t, ts)
trans.TLSClientConfig.ServerName = "badserver" trans.TLSClientConfig.ServerName = "badserver"
...@@ -646,6 +669,14 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { ...@@ -646,6 +669,14 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
} }
select {
case v := <-errc:
if !strings.Contains(v, "bad certificate") {
t.Errorf("expected an error log message containing 'bad certificate'; got %q", v)
}
case <-time.After(5 * time.Second):
t.Errorf("timeout waiting for logged error")
}
} }
// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName // Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
......
...@@ -851,7 +851,9 @@ func TestTLSHandshakeTimeout(t *testing.T) { ...@@ -851,7 +851,9 @@ func TestTLSHandshakeTimeout(t *testing.T) {
} }
defer afterTest(t) defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ReadTimeout = 250 * time.Millisecond ts.Config.ReadTimeout = 250 * time.Millisecond
ts.Config.ErrorLog = log.New(errc, "", 0)
ts.StartTLS() ts.StartTLS()
defer ts.Close() defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String()) conn, err := net.Dial("tcp", ts.Listener.Addr().String())
...@@ -866,6 +868,14 @@ func TestTLSHandshakeTimeout(t *testing.T) { ...@@ -866,6 +868,14 @@ func TestTLSHandshakeTimeout(t *testing.T) {
t.Errorf("Read = %d, %v; want an error and no bytes", n, err) t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
} }
}) })
select {
case v := <-errc:
if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
t.Errorf("expected a TLS handshake timeout error; got %q", v)
}
case <-time.After(5 * time.Second):
t.Errorf("timeout waiting for logged error")
}
} }
func TestTLSServer(t *testing.T) { func TestTLSServer(t *testing.T) {
...@@ -878,6 +888,7 @@ func TestTLSServer(t *testing.T) { ...@@ -878,6 +888,7 @@ func TestTLSServer(t *testing.T) {
} }
} }
})) }))
ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
defer ts.Close() defer ts.Close()
// Connect an idle TCP connection to this server before we run // Connect an idle TCP connection to this server before we run
......
...@@ -615,11 +615,11 @@ const maxPostHandlerReadBytes = 256 << 10 ...@@ -615,11 +615,11 @@ const maxPostHandlerReadBytes = 256 << 10
func (w *response) WriteHeader(code int) { func (w *response) WriteHeader(code int) {
if w.conn.hijacked() { if w.conn.hijacked() {
log.Print("http: response.WriteHeader on hijacked connection") w.conn.server.logf("http: response.WriteHeader on hijacked connection")
return return
} }
if w.wroteHeader { if w.wroteHeader {
log.Print("http: multiple response.WriteHeader calls") w.conn.server.logf("http: multiple response.WriteHeader calls")
return return
} }
w.wroteHeader = true w.wroteHeader = true
...@@ -634,7 +634,7 @@ func (w *response) WriteHeader(code int) { ...@@ -634,7 +634,7 @@ func (w *response) WriteHeader(code int) {
if err == nil && v >= 0 { if err == nil && v >= 0 {
w.contentLength = v w.contentLength = v
} else { } else {
log.Printf("http: invalid Content-Length of %q", cl) w.conn.server.logf("http: invalid Content-Length of %q", cl)
w.handlerHeader.Del("Content-Length") w.handlerHeader.Del("Content-Length")
} }
} }
...@@ -817,7 +817,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { ...@@ -817,7 +817,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
if hasCL && hasTE && te != "identity" { if hasCL && hasTE && te != "identity" {
// TODO: return an error if WriteHeader gets a return parameter // TODO: return an error if WriteHeader gets a return parameter
// For now just ignore the Content-Length. // For now just ignore the Content-Length.
log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
te, w.contentLength) te, w.contentLength)
delHeader("Content-Length") delHeader("Content-Length")
hasCL = false hasCL = false
...@@ -963,7 +963,7 @@ func (w *response) WriteString(data string) (n int, err error) { ...@@ -963,7 +963,7 @@ func (w *response) WriteString(data string) (n int, err error) {
// either dataB or dataS is non-zero. // either dataB or dataS is non-zero.
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
if w.conn.hijacked() { if w.conn.hijacked() {
log.Print("http: response.Write on hijacked connection") w.conn.server.logf("http: response.Write on hijacked connection")
return 0, ErrHijacked return 0, ErrHijacked
} }
if !w.wroteHeader { if !w.wroteHeader {
...@@ -1096,7 +1096,7 @@ func (c *conn) serve() { ...@@ -1096,7 +1096,7 @@ func (c *conn) serve() {
const size = 64 << 10 const size = 64 << 10
buf := make([]byte, size) buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)] buf = buf[:runtime.Stack(buf, false)]
log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
} }
if !c.hijacked() { if !c.hijacked() {
c.close() c.close()
...@@ -1112,6 +1112,7 @@ func (c *conn) serve() { ...@@ -1112,6 +1112,7 @@ func (c *conn) serve() {
c.rwc.SetWriteDeadline(time.Now().Add(d)) c.rwc.SetWriteDeadline(time.Now().Add(d))
} }
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
return return
} }
c.tlsState = new(tls.ConnectionState) c.tlsState = new(tls.ConnectionState)
...@@ -1604,6 +1605,12 @@ type Server struct { ...@@ -1604,6 +1605,12 @@ type Server struct {
// ConnState type and associated constants for details. // ConnState type and associated constants for details.
ConnState func(net.Conn, ConnState) ConnState func(net.Conn, ConnState)
// ErrorLog specifies an optional logger for errors accepting
// connections and unexpected behavior from handlers.
// If nil, logging goes to os.Stderr via the log package's
// standard logger.
ErrorLog *log.Logger
disableKeepAlives int32 // accessed atomically. disableKeepAlives int32 // accessed atomically.
} }
...@@ -1704,7 +1711,7 @@ func (srv *Server) Serve(l net.Listener) error { ...@@ -1704,7 +1711,7 @@ func (srv *Server) Serve(l net.Listener) error {
if max := 1 * time.Second; tempDelay > max { if max := 1 * time.Second; tempDelay > max {
tempDelay = max tempDelay = max
} }
log.Printf("http: Accept error: %v; retrying in %v", e, tempDelay) srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay)
time.Sleep(tempDelay) time.Sleep(tempDelay)
continue continue
} }
...@@ -1735,6 +1742,14 @@ func (s *Server) SetKeepAlivesEnabled(v bool) { ...@@ -1735,6 +1742,14 @@ func (s *Server) SetKeepAlivesEnabled(v bool) {
} }
} }
func (s *Server) logf(format string, args ...interface{}) {
if s.ErrorLog != nil {
s.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
// ListenAndServe listens on the TCP network address addr // ListenAndServe listens on the TCP network address addr
// and then calls Serve with handler to handle requests // and then calls Serve with handler to handle requests
// on incoming connections. Handler is typically nil, // on incoming connections. Handler is typically nil,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment