Commit fbf4dd91 authored by Johan Brandhorst's avatar Johan Brandhorst Committed by Brad Fitzpatrick

net/http/httptest: add Client and Certificate methods to Server

Adds a function for easily accessing the x509.Certificate
of a Server, if there is one. Also adds a helper function
for getting a http.Client suitable for use with the server.

This makes the steps required to test a httptest
TLS server simpler.

Fixes #18411

Change-Id: I2e78fe1e54e31bed9c641be2d9a099f698c7bbde
Reviewed-on: https://go-review.googlesource.com/34639Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 02e36f8c
...@@ -411,7 +411,7 @@ var pkgDeps = map[string][]string{ ...@@ -411,7 +411,7 @@ var pkgDeps = map[string][]string{
"net/http/cgi": {"L4", "NET", "OS", "crypto/tls", "net/http", "regexp"}, "net/http/cgi": {"L4", "NET", "OS", "crypto/tls", "net/http", "regexp"},
"net/http/cookiejar": {"L4", "NET", "net/http"}, "net/http/cookiejar": {"L4", "NET", "net/http"},
"net/http/fcgi": {"L4", "NET", "OS", "net/http", "net/http/cgi"}, "net/http/fcgi": {"L4", "NET", "OS", "net/http", "net/http/cgi"},
"net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal"}, "net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509"},
"net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"}, "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"},
"net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"}, "net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
"net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"}, "net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"},
......
...@@ -54,3 +54,25 @@ func ExampleServer() { ...@@ -54,3 +54,25 @@ func ExampleServer() {
fmt.Printf("%s", greeting) fmt.Printf("%s", greeting)
// Output: Hello, client // Output: Hello, client
} }
func ExampleNewTLSServer() {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, client")
}))
defer ts.Close()
client := ts.Client()
res, err := client.Get(ts.URL)
if err != nil {
log.Fatal(err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s", greeting)
// Output: Hello, client
}
...@@ -9,6 +9,7 @@ package httptest ...@@ -9,6 +9,7 @@ package httptest
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"crypto/x509"
"flag" "flag"
"fmt" "fmt"
"log" "log"
...@@ -35,6 +36,9 @@ type Server struct { ...@@ -35,6 +36,9 @@ type Server struct {
// before Start or StartTLS. // before Start or StartTLS.
Config *http.Server Config *http.Server
// certificate is a parsed version of the TLS config certificate, if present.
certificate *x509.Certificate
// wg counts the number of outstanding HTTP requests on this server. // wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished. // Close blocks until all requests are finished.
wg sync.WaitGroup wg sync.WaitGroup
...@@ -42,6 +46,10 @@ type Server struct { ...@@ -42,6 +46,10 @@ type Server struct {
mu sync.Mutex // guards closed and conns mu sync.Mutex // guards closed and conns
closed bool closed bool
conns map[net.Conn]http.ConnState // except terminal states conns map[net.Conn]http.ConnState // except terminal states
// client is configured for use with the server.
// Its transport is automatically closed when Close is called.
client *http.Client
} }
func newLocalListener() net.Listener { func newLocalListener() net.Listener {
...@@ -85,6 +93,7 @@ func NewUnstartedServer(handler http.Handler) *Server { ...@@ -85,6 +93,7 @@ func NewUnstartedServer(handler http.Handler) *Server {
return &Server{ return &Server{
Listener: newLocalListener(), Listener: newLocalListener(),
Config: &http.Server{Handler: handler}, Config: &http.Server{Handler: handler},
client: &http.Client{},
} }
} }
...@@ -124,6 +133,17 @@ func (s *Server) StartTLS() { ...@@ -124,6 +133,17 @@ func (s *Server) StartTLS() {
if len(s.TLS.Certificates) == 0 { if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert} s.TLS.Certificates = []tls.Certificate{cert}
} }
s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
if err != nil {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
certpool := x509.NewCertPool()
certpool.AddCert(s.certificate)
s.client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
}
s.Listener = tls.NewListener(s.Listener, s.TLS) s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String() s.URL = "https://" + s.Listener.Addr().String()
s.wrap() s.wrap()
...@@ -186,6 +206,11 @@ func (s *Server) Close() { ...@@ -186,6 +206,11 @@ func (s *Server) Close() {
t.CloseIdleConnections() t.CloseIdleConnections()
} }
// Also close the client idle connections.
if t, ok := s.client.Transport.(closeIdleTransport); ok {
t.CloseIdleConnections()
}
s.wg.Wait() s.wg.Wait()
} }
...@@ -228,6 +253,19 @@ func (s *Server) CloseClientConnections() { ...@@ -228,6 +253,19 @@ func (s *Server) CloseClientConnections() {
} }
} }
// Certificate returns the certificate used by the server, or nil if
// the server doesn't use TLS.
func (s *Server) Certificate() *x509.Certificate {
return s.certificate
}
// Client returns an HTTP client configured for making requests to the server.
// It is configured to trust the server's TLS test certificate and will
// close its idle connections on Server.Close.
func (s *Server) Client() *http.Client {
return s.client
}
func (s *Server) goServe() { func (s *Server) goServe() {
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
......
...@@ -22,6 +22,7 @@ func TestServer(t *testing.T) { ...@@ -22,6 +22,7 @@ func TestServer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
got, err := ioutil.ReadAll(res.Body) got, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -98,3 +99,25 @@ func TestServerCloseClientConnections(t *testing.T) { ...@@ -98,3 +99,25 @@ func TestServerCloseClientConnections(t *testing.T) {
t.Fatalf("Unexpected response: %#v", res) t.Fatalf("Unexpected response: %#v", res)
} }
} }
// Tests that the Server.Client method works and returns an http.Client that can hit
// NewTLSServer without cert warnings.
func TestServerClient(t *testing.T) {
ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
defer ts.Close()
client := ts.Client()
res, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
got, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatal(err)
}
if string(got) != "hello" {
t.Errorf("got %q, want hello", string(got))
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment