Commit 8f0bfc5a authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http/httptest: make Server.Close wait for outstanding requests to finish

Might fix issue 3050

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/5708066
parent 357b257c
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"sync"
) )
// A Server is an HTTP server listening on a system-chosen port on the // A Server is an HTTP server listening on a system-chosen port on the
...@@ -25,6 +26,10 @@ type Server struct { ...@@ -25,6 +26,10 @@ type Server struct {
// Config may be changed after calling NewUnstartedServer and // Config may be changed after calling NewUnstartedServer and
// before Start or StartTLS. // before Start or StartTLS.
Config *http.Server Config *http.Server
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
} }
// historyListener keeps track of all connections that it's ever // historyListener keeps track of all connections that it's ever
...@@ -93,6 +98,7 @@ func (s *Server) Start() { ...@@ -93,6 +98,7 @@ func (s *Server) Start() {
} }
s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)} s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)}
s.URL = "http://" + s.Listener.Addr().String() s.URL = "http://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener) go s.Config.Serve(s.Listener)
if *serve != "" { if *serve != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
...@@ -118,9 +124,21 @@ func (s *Server) StartTLS() { ...@@ -118,9 +124,21 @@ func (s *Server) StartTLS() {
s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)}
s.URL = "https://" + s.Listener.Addr().String() s.URL = "https://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener) go s.Config.Serve(s.Listener)
} }
func (s *Server) wrapHandler() {
h := s.Config.Handler
if h == nil {
h = http.DefaultServeMux
}
s.Config.Handler = &waitGroupHandler{
s: s,
h: h,
}
}
// NewTLSServer starts and returns a new Server using TLS. // NewTLSServer starts and returns a new Server using TLS.
// The caller should call Close when finished, to shut it down. // The caller should call Close when finished, to shut it down.
func NewTLSServer(handler http.Handler) *Server { func NewTLSServer(handler http.Handler) *Server {
...@@ -129,9 +147,11 @@ func NewTLSServer(handler http.Handler) *Server { ...@@ -129,9 +147,11 @@ func NewTLSServer(handler http.Handler) *Server {
return ts return ts
} }
// Close shuts down the server. // Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
func (s *Server) Close() { func (s *Server) Close() {
s.Listener.Close() s.Listener.Close()
s.wg.Wait()
} }
// CloseClientConnections closes any currently open HTTP connections // CloseClientConnections closes any currently open HTTP connections
...@@ -146,6 +166,20 @@ func (s *Server) CloseClientConnections() { ...@@ -146,6 +166,20 @@ func (s *Server) CloseClientConnections() {
} }
} }
// waitGroupHandler wraps a handler, incrementing and decrementing a
// sync.WaitGroup on each request, to enable Server.Close to block
// until outstanding requests are finished.
type waitGroupHandler struct {
s *Server
h http.Handler // non-nil
}
func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.s.wg.Add(1)
defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
h.h.ServeHTTP(w, r)
}
// localhostCert is a PEM-encoded TLS cert with SAN DNS names // localhostCert is a PEM-encoded TLS cert with SAN DNS names
// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
// of ASN.1 time). // of ASN.1 time).
......
...@@ -129,9 +129,10 @@ func TestSniffWriteSize(t *testing.T) { ...@@ -129,9 +129,10 @@ func TestSniffWriteSize(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} {
_, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size)) res, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size))
if err != nil { if err != nil {
t.Fatalf("size %d: %v", size, err) t.Fatalf("size %d: %v", size, err)
} }
res.Body.Close()
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment