Commit a73d8f5a authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make Transport send WebSocket upgrade requests over HTTP/1

WebSockets requires HTTP/1 in practice (no spec or implementations
work over HTTP/2), so if we get an HTTP request that looks like it's
trying to initiate WebSockets, use HTTP/1, like browsers do.

This is part of a series of commits to make WebSockets work over
httputil.ReverseProxy. See #26937.

Updates #26937

Change-Id: I6ad3df9b0a21fddf62fa7d9cacef48e7d5d9585b
Reviewed-on: https://go-review.googlesource.com/c/137437
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarDmitri Shuralyov <dmitshur@golang.org>
parent 3aa3c052
...@@ -252,7 +252,7 @@ type slurpResult struct { ...@@ -252,7 +252,7 @@ type slurpResult struct {
func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
if res.Proto == wantProto { if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
} else { } else {
t.Errorf("got %q response; want %q", res.Proto, wantProto) t.Errorf("got %q response; want %q", res.Proto, wantProto)
...@@ -1546,3 +1546,25 @@ func TestBidiStreamReverseProxy(t *testing.T) { ...@@ -1546,3 +1546,25 @@ func TestBidiStreamReverseProxy(t *testing.T) {
} }
} }
// Always use HTTP/1.1 for WebSocket upgrades.
func TestH12_WebSocketUpgrade(t *testing.T) {
h12Compare{
Handler: func(w ResponseWriter, r *Request) {
h := w.Header()
h.Set("Foo", "bar")
},
ReqFunc: func(c *Client, url string) (*Response, error) {
req, _ := NewRequest("GET", url, nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "WebSocket")
return c.Do(req)
},
EarlyCheckResponse: func(proto string, res *Response) {
if res.Proto != "HTTP/1.1" {
t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
}
res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
},
}.run(t)
}
...@@ -155,7 +155,7 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string { ...@@ -155,7 +155,7 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string {
func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
t.idleMu.Lock() t.idleMu.Lock()
defer t.idleMu.Unlock() defer t.idleMu.Unlock()
key := connectMethodKey{"", scheme, addr} key := connectMethodKey{"", scheme, addr, false}
cacheKey := key.String() cacheKey := key.String()
for k, conns := range t.idleConn { for k, conns := range t.idleConn {
if k.String() == cacheKey { if k.String() == cacheKey {
...@@ -178,12 +178,12 @@ func (t *Transport) IsIdleForTesting() bool { ...@@ -178,12 +178,12 @@ func (t *Transport) IsIdleForTesting() bool {
} }
func (t *Transport) RequestIdleConnChForTesting() { func (t *Transport) RequestIdleConnChForTesting() {
t.getIdleConnCh(connectMethod{nil, "http", "example.com"}) t.getIdleConnCh(connectMethod{nil, "http", "example.com", false})
} }
func (t *Transport) PutIdleTestConn(scheme, addr string) bool { func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
c, _ := net.Pipe() c, _ := net.Pipe()
key := connectMethodKey{"", scheme, addr} key := connectMethodKey{"", scheme, addr, false}
select { select {
case <-t.incHostConnCount(key): case <-t.incHostConnCount(key):
default: default:
......
...@@ -35,7 +35,7 @@ func TestCacheKeys(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestCacheKeys(t *testing.T) {
} }
proxy = u proxy = u
} }
cm := connectMethod{proxy, tt.scheme, tt.addr} cm := connectMethod{proxy, tt.scheme, tt.addr, false}
if got := cm.key().String(); got != tt.key { if got := cm.key().String(); got != tt.key {
t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key) t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key)
} }
......
...@@ -1371,3 +1371,10 @@ func requestMethodUsuallyLacksBody(method string) bool { ...@@ -1371,3 +1371,10 @@ func requestMethodUsuallyLacksBody(method string) bool {
} }
return false return false
} }
// requiresHTTP1 reports whether this request requires being sent on
// an HTTP/1 connection.
func (r *Request) requiresHTTP1() bool {
return hasToken(r.Header.Get("Connection"), "upgrade") &&
strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
...@@ -382,6 +382,19 @@ func (tr *transportRequest) setError(err error) { ...@@ -382,6 +382,19 @@ func (tr *transportRequest) setError(err error) {
tr.mu.Unlock() tr.mu.Unlock()
} }
// useRegisteredProtocol reports whether an alternate protocol (as reqistered
// with Transport.RegisterProtocol) should be respected for this request.
func (t *Transport) useRegisteredProtocol(req *Request) bool {
if req.URL.Scheme == "https" && req.requiresHTTP1() {
// If this request requires HTTP/1, don't use the
// "https" alternate protocol, which is used by the
// HTTP/2 code to take over requests if there's an
// existing cached HTTP/2 connection.
return false
}
return true
}
// roundTrip implements a RoundTripper over HTTP. // roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (*Response, error) { func (t *Transport) roundTrip(req *Request) (*Response, error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
...@@ -411,10 +424,12 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { ...@@ -411,10 +424,12 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
} }
} }
altProto, _ := t.altProto.Load().(map[string]RoundTripper) if t.useRegisteredProtocol(req) {
if altRT := altProto[scheme]; altRT != nil { altProto, _ := t.altProto.Load().(map[string]RoundTripper)
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { if altRT := altProto[scheme]; altRT != nil {
return resp, err if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
return resp, err
}
} }
} }
if !isHTTP { if !isHTTP {
...@@ -653,6 +668,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM ...@@ -653,6 +668,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM
} }
} }
} }
cm.onlyH1 = treq.requiresHTTP1()
return cm, err return cm, err
} }
...@@ -1155,6 +1171,9 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro ...@@ -1155,6 +1171,9 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro
if cfg.ServerName == "" { if cfg.ServerName == "" {
cfg.ServerName = name cfg.ServerName = name
} }
if pconn.cacheKey.onlyH1 {
cfg.NextProtos = nil
}
plainConn := pconn.conn plainConn := pconn.conn
tlsConn := tls.Client(plainConn, cfg) tlsConn := tls.Client(plainConn, cfg)
errc := make(chan error, 2) errc := make(chan error, 2)
...@@ -1361,10 +1380,11 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) { ...@@ -1361,10 +1380,11 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) {
// //
// A connect method may be of the following types: // A connect method may be of the following types:
// //
// Cache key form Description // connectMethod.key().String() Description
// ----------------- ------------------------- // ------------------------------ -------------------------
// |http|foo.com http directly to server, no proxy // |http|foo.com http directly to server, no proxy
// |https|foo.com https directly to server, no proxy // |https|foo.com https directly to server, no proxy
// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy
// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com // http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com
// http://proxy.com|http http to proxy, http to anywhere after that // http://proxy.com|http http to proxy, http to anywhere after that
// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com // socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com
...@@ -1379,6 +1399,7 @@ type connectMethod struct { ...@@ -1379,6 +1399,7 @@ type connectMethod struct {
// then targetAddr is not included in the connect method key, because the socket can // then targetAddr is not included in the connect method key, because the socket can
// be reused for different targetAddr values. // be reused for different targetAddr values.
targetAddr string targetAddr string
onlyH1 bool // whether to disable HTTP/2 and force HTTP/1
} }
func (cm *connectMethod) key() connectMethodKey { func (cm *connectMethod) key() connectMethodKey {
...@@ -1394,6 +1415,7 @@ func (cm *connectMethod) key() connectMethodKey { ...@@ -1394,6 +1415,7 @@ func (cm *connectMethod) key() connectMethodKey {
proxy: proxyStr, proxy: proxyStr,
scheme: cm.targetScheme, scheme: cm.targetScheme,
addr: targetAddr, addr: targetAddr,
onlyH1: cm.onlyH1,
} }
} }
...@@ -1428,11 +1450,16 @@ func (cm *connectMethod) tlsHost() string { ...@@ -1428,11 +1450,16 @@ func (cm *connectMethod) tlsHost() string {
// a URL. // a URL.
type connectMethodKey struct { type connectMethodKey struct {
proxy, scheme, addr string proxy, scheme, addr string
onlyH1 bool
} }
func (k connectMethodKey) String() string { func (k connectMethodKey) String() string {
// Only used by tests. // Only used by tests.
return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr) var h1 string
if k.onlyH1 {
h1 = ",h1"
}
return fmt.Sprintf("%s|%s%s|%s", k.proxy, k.scheme, h1, k.addr)
} }
// persistConn wraps a connection, usually a persistent one // persistConn wraps a connection, usually a persistent one
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment