Commit cc2c5fc3 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: don't re-use Transport connections if we've seen an EOF

This the second part of making persistent HTTPS connections to
certain servers (notably Amazon) robust.

See the story in part 1: https://golang.org/cl/76400046/

This is the http Transport change that notes whether our
net.Conn.Read has ever seen an EOF. If it has, then we use
that as an additional signal to not re-use that connection (in
addition to the HTTP response headers)

Fixes #3514

LGTM=rsc
R=agl, rsc
CC=golang-codereviews
https://golang.org/cl/79240044
parent f61f18d6
...@@ -588,7 +588,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { ...@@ -588,7 +588,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
pconn.conn = tlsConn pconn.conn = tlsConn
} }
pconn.br = bufio.NewReader(pconn.conn) pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
pconn.bw = bufio.NewWriter(pconn.conn) pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop() go pconn.readLoop()
go pconn.writeLoop() go pconn.writeLoop()
...@@ -723,6 +723,7 @@ type persistConn struct { ...@@ -723,6 +723,7 @@ type persistConn struct {
tlsState *tls.ConnectionState tlsState *tls.ConnectionState
closed bool // whether conn has been closed closed bool // whether conn has been closed
br *bufio.Reader // from conn br *bufio.Reader // from conn
sawEOF bool // whether we've seen EOF from conn; owned by readLoop
bw *bufio.Writer // to conn bw *bufio.Writer // to conn
reqch chan requestAndChan // written by roundTrip; read by readLoop reqch chan requestAndChan // written by roundTrip; read by readLoop
writech chan writeRequest // written by roundTrip; read by writeLoop writech chan writeRequest // written by roundTrip; read by writeLoop
...@@ -841,6 +842,9 @@ func (pc *persistConn) readLoop() { ...@@ -841,6 +842,9 @@ func (pc *persistConn) readLoop() {
if err != nil { if err != nil {
alive1 = false alive1 = false
} }
if alive1 && pc.sawEOF {
alive1 = false
}
if alive1 && !pc.t.putIdleConn(pc) { if alive1 && !pc.t.putIdleConn(pc) {
alive1 = false alive1 = false
} }
...@@ -1134,3 +1138,16 @@ type tlsHandshakeTimeoutError struct{} ...@@ -1134,3 +1138,16 @@ type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true } func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true } func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
type noteEOFReader struct {
r io.Reader
sawEOF *bool
}
func (nr noteEOFReader) Read(p []byte) (n int, err error) {
n, err = nr.r.Read(p)
if err == io.EOF {
*nr.sawEOF = true
}
return
}
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"crypto/rand" "crypto/rand"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -1836,6 +1837,71 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { ...@@ -1836,6 +1837,71 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) {
} }
} }
// Trying to repro golang.org/issue/3514
func TestTLSServerClosesConnection(t *testing.T) {
defer afterTest(t)
closedc := make(chan bool, 1)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
conn, _, _ := w.(Hijacker).Hijack()
conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
conn.Close()
closedc <- true
return
}
fmt.Fprintf(w, "hello")
}))
defer ts.Close()
tr := &Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
var nSuccess = 0
var errs []error
const trials = 20
for i := 0; i < trials; i++ {
tr.CloseIdleConnections()
res, err := client.Get(ts.URL + "/keep-alive-then-die")
if err != nil {
t.Fatal(err)
}
<-closedc
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if string(slurp) != "foo" {
t.Errorf("Got %q, want foo", slurp)
}
// Now try again and see if we successfully
// pick a new connection.
res, err = client.Get(ts.URL + "/")
if err != nil {
errs = append(errs, err)
continue
}
slurp, err = ioutil.ReadAll(res.Body)
if err != nil {
errs = append(errs, err)
continue
}
nSuccess++
}
if nSuccess > 0 {
t.Logf("successes = %d of %d", nSuccess, trials)
} else {
t.Errorf("All runs failed:")
}
for _, err := range errs {
t.Logf(" err: %v", err)
}
}
func newLocalListener(t *testing.T) net.Listener { func newLocalListener(t *testing.T) net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0") ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment