Commit fa23a700 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick Committed by Russ Cox

http: fix two Transport gzip+persist crashes

There were a couple issues:

-- HEAD requests were attempting to be ungzipped,
   despite having no content.  That was fixed in
   the previous patch version, but ultimately was
   fixed as a result of other refactoring:

-- persist.go's ClientConn "lastbody" field was
   remembering the wrong body, since we were
   mucking with it later. Instead, ditch
   ClientConn's readRes func field and add a new
   method passing it in, so we can use a closure
   and do all our bodyEOFSignal + gunzip stuff
   in one place, simplifying a lot of code and
   not requiring messing with ClientConn's innards.

-- closing the gzip reader didn't consume its
   contents.  if the caller wasn't done reading
   all the response body and ClientConn closed it
   (thinking it'd move past those bytes in the
   TCP stream), it actually wouldn't.  so introduce
   a new wrapper just for gzip reader to have its
   Close method do an ioutil.Discard on its body
   first, before the close.

Fixes #1725
Fixes #1804

R=rsc, eivind
CC=golang-dev
https://golang.org/cl/4523058
parent d080a1cf
...@@ -222,7 +222,6 @@ type ClientConn struct { ...@@ -222,7 +222,6 @@ type ClientConn struct {
pipe textproto.Pipeline pipe textproto.Pipeline
writeReq func(*Request, io.Writer) os.Error writeReq func(*Request, io.Writer) os.Error
readRes func(buf *bufio.Reader, method string) (*Response, os.Error)
} }
// NewClientConn returns a new ClientConn reading and writing c. If r is not // NewClientConn returns a new ClientConn reading and writing c. If r is not
...@@ -236,7 +235,6 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { ...@@ -236,7 +235,6 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
r: r, r: r,
pipereq: make(map[*Request]uint), pipereq: make(map[*Request]uint),
writeReq: (*Request).Write, writeReq: (*Request).Write,
readRes: ReadResponse,
} }
} }
...@@ -339,8 +337,13 @@ func (cc *ClientConn) Pending() int { ...@@ -339,8 +337,13 @@ func (cc *ClientConn) Pending() int {
// returned together with an ErrPersistEOF, which means that the remote // returned together with an ErrPersistEOF, which means that the remote
// requested that this be the last request serviced. Read can be called // requested that this be the last request serviced. Read can be called
// concurrently with Write, but not with another Read. // concurrently with Write, but not with another Read.
func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) { func (cc *ClientConn) Read(req *Request) (*Response, os.Error) {
return cc.readUsing(req, ReadResponse)
}
// readUsing is the implementation of Read with a replaceable
// ReadResponse-like function, used by the Transport.
func (cc *ClientConn) readUsing(req *Request, readRes func(buf *bufio.Reader, method string) (*Response, os.Error)) (resp *Response, err os.Error) {
// Retrieve the pipeline ID of this request/response pair // Retrieve the pipeline ID of this request/response pair
cc.lk.Lock() cc.lk.Lock()
id, ok := cc.pipereq[req] id, ok := cc.pipereq[req]
...@@ -383,7 +386,7 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) { ...@@ -383,7 +386,7 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) {
} }
} }
resp, err = cc.readRes(r, req.Method) resp, err = readRes(r, req.Method)
cc.lk.Lock() cc.lk.Lock()
defer cc.lk.Unlock() defer cc.lk.Unlock()
if err != nil { if err != nil {
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"os" "os"
...@@ -285,7 +286,6 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { ...@@ -285,7 +286,6 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
pconn.br = bufio.NewReader(pconn.conn) pconn.br = bufio.NewReader(pconn.conn)
pconn.cc = newClientConnFunc(conn, pconn.br) pconn.cc = newClientConnFunc(conn, pconn.br)
pconn.cc.readRes = readResponseWithEOFSignal
go pconn.readLoop() go pconn.readLoop()
return pconn, nil return pconn, nil
} }
...@@ -447,7 +447,25 @@ func (pc *persistConn) readLoop() { ...@@ -447,7 +447,25 @@ func (pc *persistConn) readLoop() {
} }
rc := <-pc.reqch rc := <-pc.reqch
resp, err := pc.cc.Read(rc.req) resp, err := pc.cc.readUsing(rc.req, func(buf *bufio.Reader, reqMethod string) (*Response, os.Error) {
resp, err := ReadResponse(buf, reqMethod)
if err != nil || resp.ContentLength == 0 {
return resp, err
}
if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
pc.close()
return nil, err
}
resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body}
}
resp.Body = &bodyEOFSignal{body: resp.Body}
return resp, err
})
if err == ErrPersistEOF { if err == ErrPersistEOF {
// Succeeded, but we can't send any more // Succeeded, but we can't send any more
...@@ -502,6 +520,11 @@ type responseAndError struct { ...@@ -502,6 +520,11 @@ type responseAndError struct {
type requestAndChan struct { type requestAndChan struct {
req *Request req *Request
ch chan responseAndError ch chan responseAndError
// did the Transport (as opposed to the client code) add an
// Accept-Encoding gzip header? only if it we set it do
// we transparently decode the gzip.
addedGzip bool
} }
func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
...@@ -533,25 +556,12 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { ...@@ -533,25 +556,12 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
} }
ch := make(chan responseAndError, 1) ch := make(chan responseAndError, 1)
pc.reqch <- requestAndChan{req, ch} pc.reqch <- requestAndChan{req, ch, requestedGzip}
re := <-ch re := <-ch
pc.lk.Lock() pc.lk.Lock()
pc.numExpectedResponses-- pc.numExpectedResponses--
pc.lk.Unlock() pc.lk.Unlock()
if re.err == nil && requestedGzip && re.res.Header.Get("Content-Encoding") == "gzip" {
re.res.Header.Del("Content-Encoding")
re.res.Header.Del("Content-Length")
re.res.ContentLength = -1
esb := re.res.Body.(*bodyEOFSignal)
gzReader, err := gzip.NewReader(esb.body)
if err != nil {
pc.close()
return nil, err
}
esb.body = &readFirstCloseBoth{gzReader, esb.body}
}
return re.res, re.err return re.res, re.err
} }
...@@ -583,16 +593,6 @@ func responseIsKeepAlive(res *Response) bool { ...@@ -583,16 +593,6 @@ func responseIsKeepAlive(res *Response) bool {
return false return false
} }
// readResponseWithEOFSignal is a wrapper around ReadResponse that replaces
// the response body with a bodyEOFSignal-wrapped version.
func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) {
resp, err = ReadResponse(r, requestMethod)
if err == nil && resp.ContentLength != 0 {
resp.Body = &bodyEOFSignal{body: resp.Body}
}
return
}
// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
// once, right before the final Read() or Close() call returns, but after // once, right before the final Read() or Close() call returns, but after
// EOF has been seen. // EOF has been seen.
...@@ -615,6 +615,9 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { ...@@ -615,6 +615,9 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) {
} }
func (es *bodyEOFSignal) Close() (err os.Error) { func (es *bodyEOFSignal) Close() (err os.Error) {
if es.isClosed {
return nil
}
es.isClosed = true es.isClosed = true
err = es.body.Close() err = es.body.Close()
if err == nil && es.fn != nil { if err == nil && es.fn != nil {
...@@ -639,3 +642,13 @@ func (r *readFirstCloseBoth) Close() os.Error { ...@@ -639,3 +642,13 @@ func (r *readFirstCloseBoth) Close() os.Error {
} }
return nil return nil
} }
// discardOnCloseReadCloser consumes all its input on Close.
type discardOnCloseReadCloser struct {
io.ReadCloser
}
func (d *discardOnCloseReadCloser) Close() os.Error {
io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed
return d.ReadCloser.Close()
}
...@@ -394,6 +394,9 @@ func TestTransportGzip(t *testing.T) { ...@@ -394,6 +394,9 @@ func TestTransportGzip(t *testing.T) {
t.Errorf("Accept-Encoding = %q, want %q", g, e) t.Errorf("Accept-Encoding = %q, want %q", g, e)
} }
rw.Header().Set("Content-Encoding", "gzip") rw.Header().Set("Content-Encoding", "gzip")
if req.Method == "HEAD" {
return
}
var w io.Writer = rw var w io.Writer = rw
var buf bytes.Buffer var buf bytes.Buffer
...@@ -463,6 +466,16 @@ func TestTransportGzip(t *testing.T) { ...@@ -463,6 +466,16 @@ func TestTransportGzip(t *testing.T) {
t.Errorf("expected Read error after Close; got %d, %v", n, err) t.Errorf("expected Read error after Close; got %d, %v", n, err)
} }
} }
// And a HEAD request too, because they're always weird.
c := &Client{Transport: &Transport{}}
res, err := c.Head(ts.URL)
if err != nil {
t.Fatalf("Head: %v", err)
}
if res.StatusCode != 200 {
t.Errorf("Head status=%d; want=200", res.StatusCode)
}
} }
// TestTransportGzipRecursive sends a gzip quine and checks that the // TestTransportGzipRecursive sends a gzip quine and checks that the
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment