Commit 3aa3c052 authored by Aleksandr Razumov's avatar Aleksandr Razumov Committed by Brad Fitzpatrick

net/http: rewind request body unconditionally

When http2 fails with ErrNoCachedConn the request is retried with body
that has already been read.

Fixes #25009

Change-Id: I51ed5c8cf469dd8b17c73fff6140ab80162bf267
Reviewed-on: https://go-review.googlesource.com/c/131755
Run-TryBot: Iskander Sharipov <iskander.sharipov@intel.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 0906d648
...@@ -478,9 +478,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { ...@@ -478,9 +478,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
} }
testHookRoundTripRetried() testHookRoundTripRetried()
// Rewind the body if we're able to. (HTTP/2 does this itself so we only // Rewind the body if we're able to.
// need to do it for HTTP/1.1 connections.) if req.GetBody != nil {
if req.GetBody != nil && pconn.alt == nil {
newReq := *req newReq := *req
var err error var err error
newReq.Body, err = req.GetBody() newReq.Body, err = req.GetBody()
......
...@@ -7,8 +7,13 @@ ...@@ -7,8 +7,13 @@
package http package http
import ( import (
"bytes"
"crypto/tls"
"errors" "errors"
"io"
"io/ioutil"
"net" "net"
"net/http/internal"
"strings" "strings"
"testing" "testing"
) )
...@@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) { ...@@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) {
} }
} }
} }
type roundTripFunc func(r *Request) (*Response, error)
func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
return f(r)
}
// Issue 25009
func TestTransportBodyAltRewind(t *testing.T) {
cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
if err != nil {
t.Fatal(err)
}
ln := newLocalListener(t)
defer ln.Close()
go func() {
tln := tls.NewListener(ln, &tls.Config{
NextProtos: []string{"foo"},
Certificates: []tls.Certificate{cert},
})
for i := 0; i < 2; i++ {
sc, err := tln.Accept()
if err != nil {
t.Error(err)
return
}
if err := sc.(*tls.Conn).Handshake(); err != nil {
t.Error(err)
return
}
sc.Close()
}
}()
addr := ln.Addr().String()
req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
roundTripped := false
tr := &Transport{
DisableKeepAlives: true,
TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
"foo": func(authority string, c *tls.Conn) RoundTripper {
return roundTripFunc(func(r *Request) (*Response, error) {
n, _ := io.Copy(ioutil.Discard, r.Body)
if n == 0 {
t.Error("body length is zero")
}
if roundTripped {
return &Response{
Body: NoBody,
StatusCode: 200,
}, nil
}
roundTripped = true
return nil, http2noCachedConnError{}
})
},
},
DialTLS: func(_, _ string) (net.Conn, error) {
tc, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"foo"},
})
if err != nil {
return nil, err
}
if err := tc.Handshake(); err != nil {
return nil, err
}
return tc, nil
},
}
c := &Client{Transport: tr}
_, err = c.Do(req)
if err != nil {
t.Error(err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment