Commit 9e7b30b4 authored by Michael Fraenkel's avatar Michael Fraenkel Committed by Tom Bergan

net/http: Set a timeout on Request.Context when using TimeoutHandler

In TimeoutHandler, use a request whose context has been configured with
the handler's timeout

Fixes #20712

Change-Id: Ie670148f85fdad46841ff29232042309e15665ae
Reviewed-on: https://go-review.googlesource.com/46412
Run-TryBot: Tom Bergan <tombergan@google.com>
Reviewed-by: default avatarTom Bergan <tombergan@google.com>
parent 2b079c3c
...@@ -63,9 +63,14 @@ func SetPendingDialHooks(before, after func()) { ...@@ -63,9 +63,14 @@ func SetPendingDialHooks(before, after func()) {
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-ch
cancel()
}()
return &timeoutHandler{ return &timeoutHandler{
handler: handler, handler: handler,
testTimeout: ch, testContext: ctx,
// (no body) // (no body)
} }
} }
......
...@@ -3032,9 +3032,9 @@ type timeoutHandler struct { ...@@ -3032,9 +3032,9 @@ type timeoutHandler struct {
body string body string
dt time.Duration dt time.Duration
// When set, no timer will be created and this channel will // When set, no context will be created and this context will
// be used instead. // be used instead.
testTimeout <-chan time.Time testContext context.Context
} }
func (h *timeoutHandler) errorBody() string { func (h *timeoutHandler) errorBody() string {
...@@ -3045,12 +3045,13 @@ func (h *timeoutHandler) errorBody() string { ...@@ -3045,12 +3045,13 @@ func (h *timeoutHandler) errorBody() string {
} }
func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
var t *time.Timer ctx := h.testContext
timeout := h.testTimeout if ctx == nil {
if timeout == nil { var cancelCtx context.CancelFunc
t = time.NewTimer(h.dt) ctx, cancelCtx = context.WithTimeout(r.Context(), h.dt)
timeout = t.C defer cancelCtx()
} }
r = r.WithContext(ctx)
done := make(chan struct{}) done := make(chan struct{})
tw := &timeoutWriter{ tw := &timeoutWriter{
w: w, w: w,
...@@ -3073,10 +3074,7 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { ...@@ -3073,10 +3074,7 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
} }
w.WriteHeader(tw.code) w.WriteHeader(tw.code)
w.Write(tw.wbuf.Bytes()) w.Write(tw.wbuf.Bytes())
if t != nil { case <-ctx.Done():
t.Stop()
}
case <-timeout:
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
w.WriteHeader(StatusServiceUnavailable) w.WriteHeader(StatusServiceUnavailable)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment