Commit 11776a39 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add Transport.CancelRequest

Permits all sorts of custom HTTP timeout policies without
adding a new Transport timeout Duration for each combination
of HTTP phases.

This keeps track internally of which TCP connection a given
Request is on, and lets callers forcefully close the TCP
connection for a given request, without actually getting
the net.Conn directly.

Additionally, a future CL will implement res.Body.Close (Issue
3672) in terms of this.

Update #3362
Update #3672

R=golang-dev, rsc, adg
CC=golang-dev
https://golang.org/cl/7372054
parent e97aa82c
...@@ -16,10 +16,16 @@ func NewLoggingConn(baseName string, c net.Conn) net.Conn { ...@@ -16,10 +16,16 @@ func NewLoggingConn(baseName string, c net.Conn) net.Conn {
return newLoggingConn(baseName, c) return newLoggingConn(baseName, c)
} }
func (t *Transport) NumPendingRequestsForTesting() int {
t.reqMu.Lock()
defer t.reqMu.Unlock()
return len(t.reqConn)
}
func (t *Transport) IdleConnKeysForTesting() (keys []string) { func (t *Transport) IdleConnKeysForTesting() (keys []string) {
keys = make([]string, 0) keys = make([]string, 0)
t.idleLk.Lock() t.idleMu.Lock()
defer t.idleLk.Unlock() defer t.idleMu.Unlock()
if t.idleConn == nil { if t.idleConn == nil {
return return
} }
...@@ -30,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { ...@@ -30,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
} }
func (t *Transport) IdleConnCountForTesting(cacheKey string) int { func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
t.idleLk.Lock() t.idleMu.Lock()
defer t.idleLk.Unlock() defer t.idleMu.Unlock()
if t.idleConn == nil { if t.idleConn == nil {
return 0 return 0
} }
......
...@@ -42,9 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2 ...@@ -42,9 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2
// https, and http proxies (for either http or https with CONNECT). // https, and http proxies (for either http or https with CONNECT).
// Transport can also cache connections for future re-use. // Transport can also cache connections for future re-use.
type Transport struct { type Transport struct {
idleLk sync.Mutex idleMu sync.Mutex
idleConn map[string][]*persistConn idleConn map[string][]*persistConn
altLk sync.RWMutex reqMu sync.Mutex
reqConn map[*Request]*persistConn
altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
// TODO: tunable on global max cached connections // TODO: tunable on global max cached connections
...@@ -139,12 +141,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { ...@@ -139,12 +141,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
return nil, errors.New("http: nil Request.Header") return nil, errors.New("http: nil Request.Header")
} }
if req.URL.Scheme != "http" && req.URL.Scheme != "https" { if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
t.altLk.RLock() t.altMu.RLock()
var rt RoundTripper var rt RoundTripper
if t.altProto != nil { if t.altProto != nil {
rt = t.altProto[req.URL.Scheme] rt = t.altProto[req.URL.Scheme]
} }
t.altLk.RUnlock() t.altMu.RUnlock()
if rt == nil { if rt == nil {
return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
} }
...@@ -181,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { ...@@ -181,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
if scheme == "http" || scheme == "https" { if scheme == "http" || scheme == "https" {
panic("protocol " + scheme + " already registered") panic("protocol " + scheme + " already registered")
} }
t.altLk.Lock() t.altMu.Lock()
defer t.altLk.Unlock() defer t.altMu.Unlock()
if t.altProto == nil { if t.altProto == nil {
t.altProto = make(map[string]RoundTripper) t.altProto = make(map[string]RoundTripper)
} }
...@@ -197,10 +199,10 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { ...@@ -197,10 +199,10 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
// a "keep-alive" state. It does not interrupt any connections currently // a "keep-alive" state. It does not interrupt any connections currently
// in use. // in use.
func (t *Transport) CloseIdleConnections() { func (t *Transport) CloseIdleConnections() {
t.idleLk.Lock() t.idleMu.Lock()
m := t.idleConn m := t.idleConn
t.idleConn = nil t.idleConn = nil
t.idleLk.Unlock() t.idleMu.Unlock()
if m == nil { if m == nil {
return return
} }
...@@ -211,6 +213,17 @@ func (t *Transport) CloseIdleConnections() { ...@@ -211,6 +213,17 @@ func (t *Transport) CloseIdleConnections() {
} }
} }
// CancelRequest cancels an in-flight request by closing its
// connection.
func (t *Transport) CancelRequest(req *Request) {
t.reqMu.Lock()
pc := t.reqConn[req]
t.reqMu.Unlock()
if pc != nil {
pc.conn.Close()
}
}
// //
// Private implementation past this point. // Private implementation past this point.
// //
...@@ -266,12 +279,12 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { ...@@ -266,12 +279,12 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
if max == 0 { if max == 0 {
max = DefaultMaxIdleConnsPerHost max = DefaultMaxIdleConnsPerHost
} }
t.idleLk.Lock() t.idleMu.Lock()
if t.idleConn == nil { if t.idleConn == nil {
t.idleConn = make(map[string][]*persistConn) t.idleConn = make(map[string][]*persistConn)
} }
if len(t.idleConn[key]) >= max { if len(t.idleConn[key]) >= max {
t.idleLk.Unlock() t.idleMu.Unlock()
pconn.close() pconn.close()
return false return false
} }
...@@ -281,14 +294,14 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { ...@@ -281,14 +294,14 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
} }
} }
t.idleConn[key] = append(t.idleConn[key], pconn) t.idleConn[key] = append(t.idleConn[key], pconn)
t.idleLk.Unlock() t.idleMu.Unlock()
return true return true
} }
func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
key := cm.String() key := cm.String()
t.idleLk.Lock() t.idleMu.Lock()
defer t.idleLk.Unlock() defer t.idleMu.Unlock()
if t.idleConn == nil { if t.idleConn == nil {
return nil return nil
} }
...@@ -313,6 +326,19 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { ...@@ -313,6 +326,19 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
panic("unreachable") panic("unreachable")
} }
func (t *Transport) setReqConn(r *Request, pc *persistConn) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqConn == nil {
t.reqConn = make(map[*Request]*persistConn)
}
if pc != nil {
t.reqConn[r] = pc
} else {
delete(t.reqConn, r)
}
}
func (t *Transport) dial(network, addr string) (c net.Conn, err error) { func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
if t.Dial != nil { if t.Dial != nil {
return t.Dial(network, addr) return t.Dial(network, addr)
...@@ -662,6 +688,8 @@ func (pc *persistConn) readLoop() { ...@@ -662,6 +688,8 @@ func (pc *persistConn) readLoop() {
alive = <-waitForBodyRead alive = <-waitForBodyRead
} }
pc.t.setReqConn(rc.req, nil)
if !alive { if !alive {
pc.close() pc.close()
} }
...@@ -715,6 +743,7 @@ type writeRequest struct { ...@@ -715,6 +743,7 @@ type writeRequest struct {
} }
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
pc.t.setReqConn(req.Request, pc)
pc.lk.Lock() pc.lk.Lock()
pc.numExpectedResponses++ pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc headerFn := pc.mutateHeaderFunc
...@@ -793,6 +822,9 @@ WaitResponse: ...@@ -793,6 +822,9 @@ WaitResponse:
pc.numExpectedResponses-- pc.numExpectedResponses--
pc.lk.Unlock() pc.lk.Unlock()
if re.err != nil {
pc.t.setReqConn(req.Request, nil)
}
return re.res, re.err return re.res, re.err
} }
......
...@@ -1118,7 +1118,6 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { ...@@ -1118,7 +1118,6 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping timeout test in -short mode") t.Skip("skipping timeout test in -short mode")
} }
const debug = false
mux := NewServeMux() mux := NewServeMux()
mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
...@@ -1161,6 +1160,60 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { ...@@ -1161,6 +1160,60 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
} }
} }
func TestTransportCancelRequest(t *testing.T) {
defer checkLeakedTransports(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
}
unblockc := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "Hello")
w.(Flusher).Flush() // send headers and some body
<-unblockc
}))
defer ts.Close()
defer close(unblockc)
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
req, _ := NewRequest("GET", ts.URL, nil)
res, err := c.Do(req)
if err != nil {
t.Fatal(err)
}
go func() {
time.Sleep(1 * time.Second)
tr.CancelRequest(req)
}()
t0 := time.Now()
body, err := ioutil.ReadAll(res.Body)
d := time.Since(t0)
if err == nil {
t.Error("expected an error reading the body")
}
if string(body) != "Hello" {
t.Errorf("Body = %q; want Hello", body)
}
if d < 500*time.Millisecond {
t.Errorf("expected ~1 second delay; got %v", d)
}
// Verify no outstanding requests after readLoop/writeLoop
// goroutines shut down.
for tries := 3; tries > 0; tries-- {
n := tr.NumPendingRequestsForTesting()
if n == 0 {
break
}
time.Sleep(100 * time.Millisecond)
if tries == 1 {
t.Errorf("pending requests = %d; want 0", n)
}
}
}
type fooProto struct{} type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) { func (fooProto) RoundTrip(req *Request) (*Response, error) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment