Commit 5b588e66 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make http2 Transport send Content Length

Updates x/net/http2 to git rev 5c0dae8 for https://golang.org/cl/18709

Fixes #14003

Change-Id: I8bc205d6d089107b017e3458bbc7e05f6d0cae60
Reviewed-on: https://go-review.googlesource.com/18730Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 49234ee2
...@@ -421,6 +421,35 @@ func TestH12_ServerEmptyContentLength(t *testing.T) { ...@@ -421,6 +421,35 @@ func TestH12_ServerEmptyContentLength(t *testing.T) {
}.run(t) }.run(t)
} }
func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
}
func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0)
}
func TestH12_RequestContentLength_Unknown(t *testing.T) {
h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
}
func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
h12Compare{
Handler: func(w ResponseWriter, r *Request) {
w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
},
ReqFunc: func(c *Client, url string) (*Response, error) {
return c.Post(url, "text/plain", bodyfn())
},
CheckResponse: func(proto string, res *Response) {
if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
t.Errorf("Proto %q got length %q; want %q", proto, got, want)
}
},
}.run(t)
}
// Tests that closing the Request.Cancel channel also while still // Tests that closing the Request.Cancel channel also while still
// reading the response body. Issue 13159. // reading the response body. Issue 13159.
func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) }
......
...@@ -4779,6 +4779,25 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { ...@@ -4779,6 +4779,25 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
} }
hasTrailers := trailers != "" hasTrailers := trailers != ""
var body io.Reader = req.Body
contentLen := req.ContentLength
if req.Body != nil && contentLen == 0 {
// Test to see if it's actually zero or just unset.
var buf [1]byte
n, rerr := io.ReadFull(body, buf[:])
if rerr != nil && rerr != io.EOF {
contentLen = -1
body = http2errorReader{rerr}
} else if n == 1 {
contentLen = -1
body = io.MultiReader(bytes.NewReader(buf[:]), body)
} else {
body = nil
}
}
cc.mu.Lock() cc.mu.Lock()
if cc.closed || !cc.canTakeNewRequestLocked() { if cc.closed || !cc.canTakeNewRequestLocked() {
cc.mu.Unlock() cc.mu.Unlock()
...@@ -4787,7 +4806,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { ...@@ -4787,7 +4806,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
cs := cc.newStream() cs := cc.newStream()
cs.req = req cs.req = req
hasBody := req.Body != nil hasBody := body != nil
if !cc.t.disableCompression() && if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Accept-Encoding") == "" &&
...@@ -4797,7 +4816,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { ...@@ -4797,7 +4816,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
cs.requestedGzip = true cs.requestedGzip = true
} }
hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers) hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
cc.wmu.Lock() cc.wmu.Lock()
endStream := !hasBody && !hasTrailers endStream := !hasBody && !hasTrailers
werr := cc.writeHeaders(cs.ID, endStream, hdrs) werr := cc.writeHeaders(cs.ID, endStream, hdrs)
...@@ -4817,7 +4836,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { ...@@ -4817,7 +4836,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
if hasBody { if hasBody {
bodyCopyErrc = make(chan error, 1) bodyCopyErrc = make(chan error, 1)
go func() { go func() {
bodyCopyErrc <- cs.writeRequestBody(req.Body) bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
}() }()
} }
...@@ -4901,7 +4920,7 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs [] ...@@ -4901,7 +4920,7 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []
// It doesn't escape to callers. // It doesn't escape to callers.
var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write") var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write")
func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) { func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
cc := cs.cc cc := cs.cc
sentEnd := false sentEnd := false
buf := cc.frameScratchBuffer() buf := cc.frameScratchBuffer()
...@@ -4909,7 +4928,7 @@ func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) { ...@@ -4909,7 +4928,7 @@ func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) {
defer func() { defer func() {
cerr := body.Close() cerr := bodyCloser.Close()
if err == nil { if err == nil {
err = cerr err = cerr
} }
...@@ -5016,7 +5035,7 @@ type http2badStringError struct { ...@@ -5016,7 +5035,7 @@ type http2badStringError struct {
func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
// requires cc.mu be held. // requires cc.mu be held.
func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string) []byte { func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) []byte {
cc.hbuf.Reset() cc.hbuf.Reset()
host := req.Host host := req.Host
...@@ -5037,7 +5056,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail ...@@ -5037,7 +5056,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
var didUA bool var didUA bool
for k, vv := range req.Header { for k, vv := range req.Header {
lowKey := strings.ToLower(k) lowKey := strings.ToLower(k)
if lowKey == "host" { if lowKey == "host" || lowKey == "content-length" {
continue continue
} }
if lowKey == "user-agent" { if lowKey == "user-agent" {
...@@ -5055,6 +5074,9 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail ...@@ -5055,6 +5074,9 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
cc.writeHeader(lowKey, v) cc.writeHeader(lowKey, v)
} }
} }
if contentLength >= 0 {
cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader { if addGzipHeader {
cc.writeHeader("accept-encoding", "gzip") cc.writeHeader("accept-encoding", "gzip")
} }
...@@ -5745,6 +5767,10 @@ func (gz *http2gzipReader) Close() error { ...@@ -5745,6 +5767,10 @@ func (gz *http2gzipReader) Close() error {
return gz.body.Close() return gz.body.Close()
} }
type http2errorReader struct{ err error }
func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err }
// writeFramer is implemented by any type that is used to write frames. // writeFramer is implemented by any type that is used to write frames.
type http2writeFramer interface { type http2writeFramer interface {
writeFrame(http2writeContext) error writeFrame(http2writeContext) error
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment