Commit 281088b1 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Server.ErrorLog; log and test TLS handshake errors

Fixes #7291

LGTM=agl
R=golang-codereviews, agl
CC=agl, golang-codereviews
https://golang.org/cl/70250044
parent 91e36811
......@@ -15,6 +15,7 @@ import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
. "net/http"
"net/http/httptest"
......@@ -23,6 +24,7 @@ import (
"strings"
"sync"
"testing"
"time"
)
var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
......@@ -54,6 +56,13 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) {
}
}
type chanWriter chan string
func (w chanWriter) Write(p []byte) (n int, err error) {
w <- string(p)
return len(p), nil
}
func TestClient(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(robotsTxtHandler)
......@@ -564,6 +573,8 @@ func TestClientInsecureTransport(t *testing.T) {
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello"))
}))
errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ErrorLog = log.New(errc, "", 0)
defer ts.Close()
// TODO(bradfitz): add tests for skipping hostname checks too?
......@@ -585,6 +596,16 @@ func TestClientInsecureTransport(t *testing.T) {
res.Body.Close()
}
}
select {
case v := <-errc:
if !strings.Contains(v, "bad certificate") {
t.Errorf("expected an error log message containing 'bad certificate'; got %q", v)
}
case <-time.After(5 * time.Second):
t.Errorf("timeout waiting for logged error")
}
}
func TestClientErrorWithRequestURI(t *testing.T) {
......@@ -635,6 +656,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
defer afterTest(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close()
errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ErrorLog = log.New(errc, "", 0)
trans := newTLSTransport(t, ts)
trans.TLSClientConfig.ServerName = "badserver"
......@@ -646,6 +669,14 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
}
select {
case v := <-errc:
if !strings.Contains(v, "bad certificate") {
t.Errorf("expected an error log message containing 'bad certificate'; got %q", v)
}
case <-time.After(5 * time.Second):
t.Errorf("timeout waiting for logged error")
}
}
// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
......
......@@ -851,7 +851,9 @@ func TestTLSHandshakeTimeout(t *testing.T) {
}
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ReadTimeout = 250 * time.Millisecond
ts.Config.ErrorLog = log.New(errc, "", 0)
ts.StartTLS()
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
......@@ -866,6 +868,14 @@ func TestTLSHandshakeTimeout(t *testing.T) {
t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
}
})
select {
case v := <-errc:
if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
t.Errorf("expected a TLS handshake timeout error; got %q", v)
}
case <-time.After(5 * time.Second):
t.Errorf("timeout waiting for logged error")
}
}
func TestTLSServer(t *testing.T) {
......@@ -878,6 +888,7 @@ func TestTLSServer(t *testing.T) {
}
}
}))
ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
defer ts.Close()
// Connect an idle TCP connection to this server before we run
......
......@@ -615,11 +615,11 @@ const maxPostHandlerReadBytes = 256 << 10
func (w *response) WriteHeader(code int) {
if w.conn.hijacked() {
log.Print("http: response.WriteHeader on hijacked connection")
w.conn.server.logf("http: response.WriteHeader on hijacked connection")
return
}
if w.wroteHeader {
log.Print("http: multiple response.WriteHeader calls")
w.conn.server.logf("http: multiple response.WriteHeader calls")
return
}
w.wroteHeader = true
......@@ -634,7 +634,7 @@ func (w *response) WriteHeader(code int) {
if err == nil && v >= 0 {
w.contentLength = v
} else {
log.Printf("http: invalid Content-Length of %q", cl)
w.conn.server.logf("http: invalid Content-Length of %q", cl)
w.handlerHeader.Del("Content-Length")
}
}
......@@ -817,7 +817,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
if hasCL && hasTE && te != "identity" {
// TODO: return an error if WriteHeader gets a return parameter
// For now just ignore the Content-Length.
log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
te, w.contentLength)
delHeader("Content-Length")
hasCL = false
......@@ -963,7 +963,7 @@ func (w *response) WriteString(data string) (n int, err error) {
// either dataB or dataS is non-zero.
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
if w.conn.hijacked() {
log.Print("http: response.Write on hijacked connection")
w.conn.server.logf("http: response.Write on hijacked connection")
return 0, ErrHijacked
}
if !w.wroteHeader {
......@@ -1096,7 +1096,7 @@ func (c *conn) serve() {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
}
if !c.hijacked() {
c.close()
......@@ -1112,6 +1112,7 @@ func (c *conn) serve() {
c.rwc.SetWriteDeadline(time.Now().Add(d))
}
if err := tlsConn.Handshake(); err != nil {
c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
return
}
c.tlsState = new(tls.ConnectionState)
......@@ -1604,6 +1605,12 @@ type Server struct {
// ConnState type and associated constants for details.
ConnState func(net.Conn, ConnState)
// ErrorLog specifies an optional logger for errors accepting
// connections and unexpected behavior from handlers.
// If nil, logging goes to os.Stderr via the log package's
// standard logger.
ErrorLog *log.Logger
disableKeepAlives int32 // accessed atomically.
}
......@@ -1704,7 +1711,7 @@ func (srv *Server) Serve(l net.Listener) error {
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
log.Printf("http: Accept error: %v; retrying in %v", e, tempDelay)
srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay)
time.Sleep(tempDelay)
continue
}
......@@ -1735,6 +1742,14 @@ func (s *Server) SetKeepAlivesEnabled(v bool) {
}
}
func (s *Server) logf(format string, args ...interface{}) {
if s.ErrorLog != nil {
s.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
// ListenAndServe listens on the TCP network address addr
// and then calls Serve with handler to handle requests
// on incoming connections. Handler is typically nil,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment