Commit a3156aaa authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http/httptest: change Server to use http.Server.ConnState for accounting

With this CL, httptest.Server now uses connection-level accounting of
outstanding requests instead of ServeHTTP-level accounting. This is
more robust and results in a non-racy shutdown.

This is much easier now that net/http.Server has the ConnState hook.

Fixes #12789
Fixes #12781

Change-Id: I098cf334a6494316acb66cd07df90766df41764b
Reviewed-on: https://go-review.googlesource.com/15151Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 684218e1
...@@ -7,13 +7,17 @@ ...@@ -7,13 +7,17 @@
package httptest package httptest
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt" "fmt"
"log"
"net" "net"
"net/http" "net/http"
"os" "os"
"runtime"
"sync" "sync"
"time"
) )
// A Server is an HTTP server listening on a system-chosen port on the // A Server is an HTTP server listening on a system-chosen port on the
...@@ -34,24 +38,10 @@ type Server struct { ...@@ -34,24 +38,10 @@ type Server struct {
// wg counts the number of outstanding HTTP requests on this server. // wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished. // Close blocks until all requests are finished.
wg sync.WaitGroup wg sync.WaitGroup
}
// historyListener keeps track of all connections that it's ever
// accepted.
type historyListener struct {
net.Listener
sync.Mutex // protects history
history []net.Conn
}
func (hs *historyListener) Accept() (c net.Conn, err error) { mu sync.Mutex // guards closed and conns
c, err = hs.Listener.Accept() closed bool
if err == nil { conns map[net.Conn]http.ConnState // except terminal states
hs.Lock()
hs.history = append(hs.history, c)
hs.Unlock()
}
return
} }
func newLocalListener() net.Listener { func newLocalListener() net.Listener {
...@@ -103,10 +93,9 @@ func (s *Server) Start() { ...@@ -103,10 +93,9 @@ func (s *Server) Start() {
if s.URL != "" { if s.URL != "" {
panic("Server already started") panic("Server already started")
} }
s.Listener = &historyListener{Listener: s.Listener}
s.URL = "http://" + s.Listener.Addr().String() s.URL = "http://" + s.Listener.Addr().String()
s.wrapHandler() s.wrap()
go s.Config.Serve(s.Listener) s.goServe()
if *serve != "" { if *serve != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
select {} select {}
...@@ -134,23 +123,10 @@ func (s *Server) StartTLS() { ...@@ -134,23 +123,10 @@ func (s *Server) StartTLS() {
if len(s.TLS.Certificates) == 0 { if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert} s.TLS.Certificates = []tls.Certificate{cert}
} }
tlsListener := tls.NewListener(s.Listener, s.TLS) s.Listener = tls.NewListener(s.Listener, s.TLS)
s.Listener = &historyListener{Listener: tlsListener}
s.URL = "https://" + s.Listener.Addr().String() s.URL = "https://" + s.Listener.Addr().String()
s.wrapHandler() s.wrap()
go s.Config.Serve(s.Listener) s.goServe()
}
func (s *Server) wrapHandler() {
h := s.Config.Handler
if h == nil {
h = http.DefaultServeMux
}
s.Config.Handler = &waitGroupHandler{
s: s,
h: h,
}
} }
// NewTLSServer starts and returns a new Server using TLS. // NewTLSServer starts and returns a new Server using TLS.
...@@ -161,43 +137,139 @@ func NewTLSServer(handler http.Handler) *Server { ...@@ -161,43 +137,139 @@ func NewTLSServer(handler http.Handler) *Server {
return ts return ts
} }
type closeIdleTransport interface {
CloseIdleConnections()
}
// Close shuts down the server and blocks until all outstanding // Close shuts down the server and blocks until all outstanding
// requests on this server have completed. // requests on this server have completed.
func (s *Server) Close() { func (s *Server) Close() {
s.Listener.Close() s.mu.Lock()
s.wg.Wait() if !s.closed {
s.CloseClientConnections() s.closed = true
if t, ok := http.DefaultTransport.(*http.Transport); ok { s.Listener.Close()
s.Config.SetKeepAlivesEnabled(false)
for c, st := range s.conns {
if st == http.StateIdle {
s.closeConn(c)
}
}
// If this server doesn't shut down in 5 seconds, tell the user why.
t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
defer t.Stop()
}
s.mu.Unlock()
// Not part of httptest.Server's correctness, but assume most
// users of httptest.Server will be using the standard
// transport, so help them out and close any idle connections for them.
if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
t.CloseIdleConnections() t.CloseIdleConnections()
} }
s.wg.Wait()
} }
// CloseClientConnections closes any currently open HTTP connections func (s *Server) logCloseHangDebugInfo() {
// to the test Server. s.mu.Lock()
defer s.mu.Unlock()
var buf bytes.Buffer
buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
for c, st := range s.conns {
fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
}
log.Print(buf.String())
}
// CloseClientConnections closes any open HTTP connections to the test Server.
func (s *Server) CloseClientConnections() { func (s *Server) CloseClientConnections() {
hl, ok := s.Listener.(*historyListener) s.mu.Lock()
if !ok { defer s.mu.Unlock()
return for c := range s.conns {
s.closeConn(c)
} }
hl.Lock() }
for _, conn := range hl.history {
conn.Close() func (s *Server) goServe() {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.Config.Serve(s.Listener)
}()
}
// wrap installs the connection state-tracking hook to know which
// connections are idle.
func (s *Server) wrap() {
oldHook := s.Config.ConnState
s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
s.mu.Lock()
defer s.mu.Unlock()
switch cs {
case http.StateNew:
s.wg.Add(1)
if _, exists := s.conns[c]; exists {
panic("invalid state transition")
}
if s.conns == nil {
s.conns = make(map[net.Conn]http.ConnState)
}
s.conns[c] = cs
if s.closed {
// Probably just a socket-late-binding dial from
// the default transport that lost the race (and
// thus this connection is now idle and will
// never be used).
s.closeConn(c)
}
case http.StateActive:
if oldState, ok := s.conns[c]; ok {
if oldState != http.StateNew && oldState != http.StateIdle {
panic("invalid state transition")
}
s.conns[c] = cs
}
case http.StateIdle:
if oldState, ok := s.conns[c]; ok {
if oldState != http.StateActive {
panic("invalid state transition")
}
s.conns[c] = cs
}
if s.closed {
s.closeConn(c)
}
case http.StateHijacked, http.StateClosed:
s.forgetConn(c)
}
if oldHook != nil {
oldHook(c, cs)
}
} }
hl.Unlock()
} }
// waitGroupHandler wraps a handler, incrementing and decrementing a // closeConn closes c. Except on plan9, which is special. See comment below.
// sync.WaitGroup on each request, to enable Server.Close to block // s.mu must be held.
// until outstanding requests are finished. func (s *Server) closeConn(c net.Conn) {
type waitGroupHandler struct { if runtime.GOOS == "plan9" {
s *Server // Go's Plan 9 net package isn't great at unblocking reads when
h http.Handler // non-nil // their underlying TCP connections are closed. Don't trust
// that that the ConnState state machine will get to
// StateClosed. Instead, just go there directly. Plan 9 may leak
// resources if the syscall doesn't end up returning. Oh well.
s.forgetConn(c)
}
go c.Close()
} }
func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // forgetConn removes c from the set of tracked conns and decrements it from the
h.s.wg.Add(1) // waitgroup, unless it was previously removed.
defer h.s.wg.Done() // a defer, in case ServeHTTP below panics // s.mu must be held.
h.h.ServeHTTP(w, r) func (s *Server) forgetConn(c net.Conn) {
if _, ok := s.conns[c]; ok {
delete(s.conns, c)
s.wg.Done()
}
} }
// localhostCert is a PEM-encoded TLS cert with SAN IPs // localhostCert is a PEM-encoded TLS cert with SAN IPs
......
...@@ -27,3 +27,30 @@ func TestServer(t *testing.T) { ...@@ -27,3 +27,30 @@ func TestServer(t *testing.T) {
t.Errorf("got %q, want hello", string(got)) t.Errorf("got %q, want hello", string(got))
} }
} }
// Issue 12781
func TestGetAfterClose(t *testing.T) {
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
res, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if string(got) != "hello" {
t.Fatalf("got %q, want hello", string(got))
}
ts.Close()
res, err = http.Get(ts.URL)
if err == nil {
body, _ := ioutil.ReadAll(res.Body)
t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment