Commit 5375c712 authored by Emmanuel T Odeke's avatar Emmanuel T Odeke Committed by Emmanuel Odeke

net/http/httptest: add EnableHTTP2 to Server

Adds a knob EnableHTTP2, that enables an unstarted
Server and its respective client to speak HTTP/2,
but only after StartTLS has been invoked.

Fixes #34939

Change-Id: I287c568b8708a4d3c03e7d9eca7c323b8f4c65b6
Reviewed-on: https://go-review.googlesource.com/c/go/+/201557
Run-TryBot: Emmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent dc3b7883
...@@ -55,6 +55,28 @@ func ExampleServer() { ...@@ -55,6 +55,28 @@ func ExampleServer() {
// Output: Hello, client // Output: Hello, client
} }
func ExampleServer_hTTP2() {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, %s", r.Proto)
}))
ts.EnableHTTP2 = true
ts.StartTLS()
defer ts.Close()
res, err := ts.Client().Get(ts.URL)
if err != nil {
log.Fatal(err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s", greeting)
// Output: Hello, HTTP/2.0
}
func ExampleNewTLSServer() { func ExampleNewTLSServer() {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, client") fmt.Fprintln(w, "Hello, client")
......
...@@ -27,6 +27,11 @@ type Server struct { ...@@ -27,6 +27,11 @@ type Server struct {
URL string // base URL of form http://ipaddr:port with no trailing slash URL string // base URL of form http://ipaddr:port with no trailing slash
Listener net.Listener Listener net.Listener
// EnableHTTP2 controls whether HTTP/2 is enabled
// on the server. It must be set between calling
// NewUnstartedServer and calling Server.StartTLS.
EnableHTTP2 bool
// TLS is the optional TLS configuration, populated with a new config // TLS is the optional TLS configuration, populated with a new config
// after TLS is started. If set on an unstarted server before StartTLS // after TLS is started. If set on an unstarted server before StartTLS
// is called, existing fields are copied into the new config. // is called, existing fields are copied into the new config.
...@@ -151,7 +156,11 @@ func (s *Server) StartTLS() { ...@@ -151,7 +156,11 @@ func (s *Server) StartTLS() {
s.TLS = new(tls.Config) s.TLS = new(tls.Config)
} }
if s.TLS.NextProtos == nil { if s.TLS.NextProtos == nil {
s.TLS.NextProtos = []string{"http/1.1"} nextProtos := []string{"http/1.1"}
if s.EnableHTTP2 {
nextProtos = []string{"h2"}
}
s.TLS.NextProtos = nextProtos
} }
if len(s.TLS.Certificates) == 0 { if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert} s.TLS.Certificates = []tls.Certificate{cert}
...@@ -166,6 +175,7 @@ func (s *Server) StartTLS() { ...@@ -166,6 +175,7 @@ func (s *Server) StartTLS() {
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
RootCAs: certpool, RootCAs: certpool,
}, },
ForceAttemptHTTP2: s.EnableHTTP2,
} }
s.Listener = tls.NewListener(s.Listener, s.TLS) s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String() s.URL = "https://" + s.Listener.Addr().String()
......
...@@ -202,3 +202,39 @@ func TestServerZeroValueClose(t *testing.T) { ...@@ -202,3 +202,39 @@ func TestServerZeroValueClose(t *testing.T) {
ts.Close() // tests that it doesn't panic ts.Close() // tests that it doesn't panic
} }
func TestTLSServerWithHTTP2(t *testing.T) {
modes := []struct {
name string
wantProto string
}{
{"http1", "HTTP/1.1"},
{"http2", "HTTP/2.0"},
}
for _, tt := range modes {
t.Run(tt.name, func(t *testing.T) {
cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Proto", r.Proto)
}))
switch tt.name {
case "http2":
cst.EnableHTTP2 = true
cst.StartTLS()
default:
cst.Start()
}
defer cst.Close()
res, err := cst.Client().Get(cst.URL)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w)
}
})
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment