Commit b9ad2787 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http: RoundTrippers shouldn't mutate Request

Fixes #2146

R=rsc
CC=golang-dev
https://golang.org/cl/5284041
parent 236aff31
...@@ -56,9 +56,10 @@ type RoundTripper interface { ...@@ -56,9 +56,10 @@ type RoundTripper interface {
// higher-level protocol details such as redirects, // higher-level protocol details such as redirects,
// authentication, or cookies. // authentication, or cookies.
// //
// RoundTrip may modify the request. The request Headers field is // RoundTrip should not modify the request, except for
// guaranteed to be initialized. // consuming the Body. The request's URL and Header fields
RoundTrip(req *Request) (resp *Response, err os.Error) // are guaranteed to be initialized.
RoundTrip(*Request) (*Response, os.Error)
} }
// Given a string of the form "host", "host:port", or "[ipv6::address]:port", // Given a string of the form "host", "host:port", or "[ipv6::address]:port",
...@@ -96,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) { ...@@ -96,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) {
if t == nil { if t == nil {
t = DefaultTransport t = DefaultTransport
if t == nil { if t == nil {
err = os.NewError("no http.Client.Transport or http.DefaultTransport") err = os.NewError("http: no Client.Transport or DefaultTransport")
return return
} }
} }
if req.URL == nil {
return nil, os.NewError("http: nil Request.URL")
}
// Most the callers of send (Get, Post, et al) don't need // Most the callers of send (Get, Post, et al) don't need
// Headers, leaving it uninitialized. We guarantee to the // Headers, leaving it uninitialized. We guarantee to the
// Transport that this has been initialized, though. // Transport that this has been initialized, though.
......
...@@ -275,7 +275,7 @@ const defaultUserAgent = "Go http package" ...@@ -275,7 +275,7 @@ const defaultUserAgent = "Go http package"
// hasn't been set to "identity", Write adds "Transfer-Encoding: // hasn't been set to "identity", Write adds "Transfer-Encoding:
// chunked" to the header. Body is closed after it is sent. // chunked" to the header. Body is closed after it is sent.
func (req *Request) Write(w io.Writer) os.Error { func (req *Request) Write(w io.Writer) os.Error {
return req.write(w, false) return req.write(w, false, nil)
} }
// WriteProxy is like Write but writes the request in the form // WriteProxy is like Write but writes the request in the form
...@@ -285,7 +285,7 @@ func (req *Request) Write(w io.Writer) os.Error { ...@@ -285,7 +285,7 @@ func (req *Request) Write(w io.Writer) os.Error {
// either case, WriteProxy also writes a Host header, using either // either case, WriteProxy also writes a Host header, using either
// req.Host or req.URL.Host. // req.Host or req.URL.Host.
func (req *Request) WriteProxy(w io.Writer) os.Error { func (req *Request) WriteProxy(w io.Writer) os.Error {
return req.write(w, true) return req.write(w, true, nil)
} }
func (req *Request) dumpWrite(w io.Writer) os.Error { func (req *Request) dumpWrite(w io.Writer) os.Error {
...@@ -333,7 +333,8 @@ func (req *Request) dumpWrite(w io.Writer) os.Error { ...@@ -333,7 +333,8 @@ func (req *Request) dumpWrite(w io.Writer) os.Error {
return nil return nil
} }
func (req *Request) write(w io.Writer, usingProxy bool) os.Error { // extraHeaders may be nil
func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error {
host := req.Host host := req.Host
if host == "" { if host == "" {
if req.URL == nil { if req.URL == nil {
...@@ -394,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { ...@@ -394,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
return err return err
} }
if extraHeaders != nil {
err = extraHeaders.Write(bw)
if err != nil {
return err
}
}
io.WriteString(bw, "\r\n") io.WriteString(bw, "\r\n")
// Write body and trailer // Write body and trailer
......
...@@ -100,11 +100,28 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) { ...@@ -100,11 +100,28 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) {
} }
} }
// transportRequest is a wrapper around a *Request that adds
// optional extra headers to write.
type transportRequest struct {
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
}
func (tr *transportRequest) extraHeaders() Header {
if tr.extra == nil {
tr.extra = make(Header)
}
return tr.extra
}
// RoundTrip implements the RoundTripper interface. // RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
if req.URL == nil { if req.URL == nil {
return nil, os.NewError("http: nil Request.URL") return nil, os.NewError("http: nil Request.URL")
} }
if req.Header == nil {
return nil, os.NewError("http: nil Request.Header")
}
if req.URL.Scheme != "http" && req.URL.Scheme != "https" { if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
t.lk.Lock() t.lk.Lock()
var rt RoundTripper var rt RoundTripper
...@@ -117,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { ...@@ -117,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
} }
return rt.RoundTrip(req) return rt.RoundTrip(req)
} }
treq := &transportRequest{Request: req}
cm, err := t.connectMethodForRequest(req) cm, err := t.connectMethodForRequest(treq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -132,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { ...@@ -132,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
return nil, err return nil, err
} }
return pconn.roundTrip(req) return pconn.roundTrip(treq)
} }
// RegisterProtocol registers a new protocol with scheme. // RegisterProtocol registers a new protocol with scheme.
...@@ -185,14 +202,14 @@ func getenvEitherCase(k string) string { ...@@ -185,14 +202,14 @@ func getenvEitherCase(k string) string {
return os.Getenv(strings.ToLower(k)) return os.Getenv(strings.ToLower(k))
} }
func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) { func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) {
cm := &connectMethod{ cm := &connectMethod{
targetScheme: req.URL.Scheme, targetScheme: treq.URL.Scheme,
targetAddr: canonicalAddr(req.URL), targetAddr: canonicalAddr(treq.URL),
} }
if t.Proxy != nil { if t.Proxy != nil {
var err os.Error var err os.Error
cm.proxyURL, err = t.Proxy(req) cm.proxyURL, err = t.Proxy(treq.Request)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -295,19 +312,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { ...@@ -295,19 +312,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
conn: conn, conn: conn,
reqch: make(chan requestAndChan, 50), reqch: make(chan requestAndChan, 50),
} }
newClientConnFunc := NewClientConn
switch { switch {
case cm.proxyURL == nil: case cm.proxyURL == nil:
// Do nothing. // Do nothing.
case cm.targetScheme == "http": case cm.targetScheme == "http":
newClientConnFunc = NewProxyClientConn pconn.isProxy = true
if pa != "" { if pa != "" {
pconn.mutateRequestFunc = func(req *Request) { pconn.mutateHeaderFunc = func(h Header) {
if req.Header == nil { h.Set("Proxy-Authorization", pa)
req.Header = make(Header)
}
req.Header.Set("Proxy-Authorization", pa)
} }
} }
case cm.targetScheme == "https": case cm.targetScheme == "https":
...@@ -351,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { ...@@ -351,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
} }
pconn.br = bufio.NewReader(pconn.conn) pconn.br = bufio.NewReader(pconn.conn)
pconn.cc = newClientConnFunc(conn, pconn.br) pconn.cc = NewClientConn(conn, pconn.br)
go pconn.readLoop() go pconn.readLoop()
return pconn, nil return pconn, nil
} }
...@@ -447,30 +460,21 @@ func (cm *connectMethod) tlsHost() string { ...@@ -447,30 +460,21 @@ func (cm *connectMethod) tlsHost() string {
return h return h
} }
type readResult struct {
res *Response // either res or err will be set
err os.Error
}
type writeRequest struct {
// Set by client (in pc.roundTrip)
req *Request
resch chan *readResult
// Set by writeLoop if an error writing headers.
writeErr os.Error
}
// persistConn wraps a connection, usually a persistent one // persistConn wraps a connection, usually a persistent one
// (but may be used for non-keep-alive requests as well) // (but may be used for non-keep-alive requests as well)
type persistConn struct { type persistConn struct {
t *Transport t *Transport
cacheKey string // its connectMethod.String() cacheKey string // its connectMethod.String()
conn net.Conn conn net.Conn
cc *ClientConn cc *ClientConn
br *bufio.Reader br *bufio.Reader
reqch chan requestAndChan // written by roundTrip(); read by readLoop() reqch chan requestAndChan // written by roundTrip(); read by readLoop()
mutateRequestFunc func(*Request) // nil or func to modify each outbound request isProxy bool
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
// original Request given to RoundTrip is not modified)
mutateHeaderFunc func(Header)
lk sync.Mutex // guards numExpectedResponses and broken lk sync.Mutex // guards numExpectedResponses and broken
numExpectedResponses int numExpectedResponses int
...@@ -526,9 +530,6 @@ func (pc *persistConn) readLoop() { ...@@ -526,9 +530,6 @@ func (pc *persistConn) readLoop() {
if err != nil || resp.ContentLength == 0 { if err != nil || resp.ContentLength == 0 {
return resp, err return resp, err
} }
if rc.addedGzip {
forReq.Header.Del("Accept-Encoding")
}
if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" { if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length") resp.Header.Del("Content-Length")
...@@ -604,9 +605,9 @@ type requestAndChan struct { ...@@ -604,9 +605,9 @@ type requestAndChan struct {
addedGzip bool addedGzip bool
} }
func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) {
if pc.mutateRequestFunc != nil { if pc.mutateHeaderFunc != nil {
pc.mutateRequestFunc(req) pc.mutateHeaderFunc(req.extraHeaders())
} }
// Ask for a compressed version if the caller didn't set their // Ask for a compressed version if the caller didn't set their
...@@ -616,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { ...@@ -616,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
requestedGzip := false requestedGzip := false
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
// Request gzip only, not deflate. Deflate is ambiguous and // Request gzip only, not deflate. Deflate is ambiguous and
// as universally supported anyway. // not as universally supported anyway.
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38 // See: http://www.gzip.org/zlib/zlib_faq.html#faq38
requestedGzip = true requestedGzip = true
req.Header.Set("Accept-Encoding", "gzip") req.extraHeaders().Set("Accept-Encoding", "gzip")
} }
pc.lk.Lock() pc.lk.Lock()
pc.numExpectedResponses++ pc.numExpectedResponses++
pc.lk.Unlock() pc.lk.Unlock()
err = pc.cc.Write(req) pc.cc.writeReq = func(r *Request, w io.Writer) os.Error {
return r.write(w, pc.isProxy, req.extra)
}
err = pc.cc.Write(req.Request)
if err != nil { if err != nil {
pc.close() pc.close()
return return
} }
ch := make(chan responseAndError, 1) ch := make(chan responseAndError, 1)
pc.reqch <- requestAndChan{req, ch, requestedGzip} pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
re := <-ch re := <-ch
pc.lk.Lock() pc.lk.Lock()
pc.numExpectedResponses-- pc.numExpectedResponses--
...@@ -648,7 +653,7 @@ func (pc *persistConn) close() { ...@@ -648,7 +653,7 @@ func (pc *persistConn) close() {
pc.broken = true pc.broken = true
pc.cc.Close() pc.cc.Close()
pc.conn.Close() pc.conn.Close()
pc.mutateRequestFunc = nil pc.mutateHeaderFunc = nil
} }
var portMap = map[string]string{ var portMap = map[string]string{
......
...@@ -372,7 +372,8 @@ var roundTripTests = []struct { ...@@ -372,7 +372,8 @@ var roundTripTests = []struct {
// Requests with other accept-encoding should pass through unmodified // Requests with other accept-encoding should pass through unmodified
{"foo", "foo", false}, {"foo", "foo", false},
// Requests with accept-encoding == gzip should be passed through // Requests with accept-encoding == gzip should be passed through
{"gzip", "gzip", true}} {"gzip", "gzip", true},
}
// Test that the modification made to the Request by the RoundTripper is cleaned up // Test that the modification made to the Request by the RoundTripper is cleaned up
func TestRoundTripGzip(t *testing.T) { func TestRoundTripGzip(t *testing.T) {
...@@ -380,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) { ...@@ -380,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
accept := req.Header.Get("Accept-Encoding") accept := req.Header.Get("Accept-Encoding")
if expect := req.FormValue("expect_accept"); accept != expect { if expect := req.FormValue("expect_accept"); accept != expect {
t.Errorf("Accept-Encoding = %q, want %q", accept, expect) t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
req.FormValue("testnum"), accept, expect)
} }
if accept == "gzip" { if accept == "gzip" {
rw.Header().Set("Content-Encoding", "gzip") rw.Header().Set("Content-Encoding", "gzip")
...@@ -396,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) { ...@@ -396,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) {
for i, test := range roundTripTests { for i, test := range roundTripTests {
// Test basic request (no accept-encoding) // Test basic request (no accept-encoding)
req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil) req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
req.Header.Set("Accept-Encoding", test.accept) if test.accept != "" {
req.Header.Set("Accept-Encoding", test.accept)
}
res, err := DefaultTransport.RoundTrip(req) res, err := DefaultTransport.RoundTrip(req)
var body []byte var body []byte
if test.compressed { if test.compressed {
...@@ -409,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) { ...@@ -409,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) {
} }
if err != nil { if err != nil {
t.Errorf("%d. Error: %q", i, err) t.Errorf("%d. Error: %q", i, err)
} else { continue
if g, e := string(body), responseBody; g != e { }
t.Errorf("%d. body = %q; want %q", i, g, e) if g, e := string(body), responseBody; g != e {
} t.Errorf("%d. body = %q; want %q", i, g, e)
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { }
t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e) if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
} t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { }
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
} t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment