Commit a3156aaa authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http/httptest: change Server to use http.Server.ConnState for accounting

With this CL, httptest.Server now uses connection-level accounting of
outstanding requests instead of ServeHTTP-level accounting. This is
more robust and results in a non-racy shutdown.

This is much easier now that net/http.Server has the ConnState hook.

Fixes #12789
Fixes #12781

Change-Id: I098cf334a6494316acb66cd07df90766df41764b
Reviewed-on: https://go-review.googlesource.com/15151Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 684218e1
......@@ -7,13 +7,17 @@
package httptest
import (
"bytes"
"crypto/tls"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"runtime"
"sync"
"time"
)
// A Server is an HTTP server listening on a system-chosen port on the
......@@ -34,24 +38,10 @@ type Server struct {
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
}
// historyListener keeps track of all connections that it's ever
// accepted.
type historyListener struct {
net.Listener
sync.Mutex // protects history
history []net.Conn
}
func (hs *historyListener) Accept() (c net.Conn, err error) {
c, err = hs.Listener.Accept()
if err == nil {
hs.Lock()
hs.history = append(hs.history, c)
hs.Unlock()
}
return
mu sync.Mutex // guards closed and conns
closed bool
conns map[net.Conn]http.ConnState // except terminal states
}
func newLocalListener() net.Listener {
......@@ -103,10 +93,9 @@ func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}
s.Listener = &historyListener{Listener: s.Listener}
s.URL = "http://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
s.wrap()
s.goServe()
if *serve != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
select {}
......@@ -134,23 +123,10 @@ func (s *Server) StartTLS() {
if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert}
}
tlsListener := tls.NewListener(s.Listener, s.TLS)
s.Listener = &historyListener{Listener: tlsListener}
s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
}
func (s *Server) wrapHandler() {
h := s.Config.Handler
if h == nil {
h = http.DefaultServeMux
}
s.Config.Handler = &waitGroupHandler{
s: s,
h: h,
}
s.wrap()
s.goServe()
}
// NewTLSServer starts and returns a new Server using TLS.
......@@ -161,43 +137,139 @@ func NewTLSServer(handler http.Handler) *Server {
return ts
}
type closeIdleTransport interface {
CloseIdleConnections()
}
// Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
func (s *Server) Close() {
s.Listener.Close()
s.wg.Wait()
s.CloseClientConnections()
if t, ok := http.DefaultTransport.(*http.Transport); ok {
s.mu.Lock()
if !s.closed {
s.closed = true
s.Listener.Close()
s.Config.SetKeepAlivesEnabled(false)
for c, st := range s.conns {
if st == http.StateIdle {
s.closeConn(c)
}
}
// If this server doesn't shut down in 5 seconds, tell the user why.
t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
defer t.Stop()
}
s.mu.Unlock()
// Not part of httptest.Server's correctness, but assume most
// users of httptest.Server will be using the standard
// transport, so help them out and close any idle connections for them.
if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
t.CloseIdleConnections()
}
s.wg.Wait()
}
// CloseClientConnections closes any currently open HTTP connections
// to the test Server.
func (s *Server) logCloseHangDebugInfo() {
s.mu.Lock()
defer s.mu.Unlock()
var buf bytes.Buffer
buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
for c, st := range s.conns {
fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
}
log.Print(buf.String())
}
// CloseClientConnections closes any open HTTP connections to the test Server.
func (s *Server) CloseClientConnections() {
hl, ok := s.Listener.(*historyListener)
if !ok {
return
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.conns {
s.closeConn(c)
}
hl.Lock()
for _, conn := range hl.history {
conn.Close()
}
func (s *Server) goServe() {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.Config.Serve(s.Listener)
}()
}
// wrap installs the connection state-tracking hook to know which
// connections are idle.
func (s *Server) wrap() {
oldHook := s.Config.ConnState
s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
s.mu.Lock()
defer s.mu.Unlock()
switch cs {
case http.StateNew:
s.wg.Add(1)
if _, exists := s.conns[c]; exists {
panic("invalid state transition")
}
if s.conns == nil {
s.conns = make(map[net.Conn]http.ConnState)
}
s.conns[c] = cs
if s.closed {
// Probably just a socket-late-binding dial from
// the default transport that lost the race (and
// thus this connection is now idle and will
// never be used).
s.closeConn(c)
}
case http.StateActive:
if oldState, ok := s.conns[c]; ok {
if oldState != http.StateNew && oldState != http.StateIdle {
panic("invalid state transition")
}
s.conns[c] = cs
}
case http.StateIdle:
if oldState, ok := s.conns[c]; ok {
if oldState != http.StateActive {
panic("invalid state transition")
}
s.conns[c] = cs
}
if s.closed {
s.closeConn(c)
}
case http.StateHijacked, http.StateClosed:
s.forgetConn(c)
}
if oldHook != nil {
oldHook(c, cs)
}
}
hl.Unlock()
}
// waitGroupHandler wraps a handler, incrementing and decrementing a
// sync.WaitGroup on each request, to enable Server.Close to block
// until outstanding requests are finished.
type waitGroupHandler struct {
s *Server
h http.Handler // non-nil
// closeConn closes c. Except on plan9, which is special. See comment below.
// s.mu must be held.
func (s *Server) closeConn(c net.Conn) {
if runtime.GOOS == "plan9" {
// Go's Plan 9 net package isn't great at unblocking reads when
// their underlying TCP connections are closed. Don't trust
// that that the ConnState state machine will get to
// StateClosed. Instead, just go there directly. Plan 9 may leak
// resources if the syscall doesn't end up returning. Oh well.
s.forgetConn(c)
}
go c.Close()
}
func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.s.wg.Add(1)
defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
h.h.ServeHTTP(w, r)
// forgetConn removes c from the set of tracked conns and decrements it from the
// waitgroup, unless it was previously removed.
// s.mu must be held.
func (s *Server) forgetConn(c net.Conn) {
if _, ok := s.conns[c]; ok {
delete(s.conns, c)
s.wg.Done()
}
}
// localhostCert is a PEM-encoded TLS cert with SAN IPs
......
......@@ -27,3 +27,30 @@ func TestServer(t *testing.T) {
t.Errorf("got %q, want hello", string(got))
}
}
// Issue 12781
func TestGetAfterClose(t *testing.T) {
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
res, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if string(got) != "hello" {
t.Fatalf("got %q, want hello", string(got))
}
ts.Close()
res, err = http.Get(ts.URL)
if err == nil {
body, _ := ioutil.ReadAll(res.Body)
t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment