Commit 67b8bf3e authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: add optional Server.ConnState callback

Update #4674

This allows for all sorts of graceful shutdown policies,
without picking a policy (e.g. lameduck period) and without
adding lots of locking to the server core. That policy and
locking can be implemented outside of net/http now.

LGTM=adg
R=golang-codereviews, josharian, r, adg, dvyukov
CC=golang-codereviews
https://golang.org/cl/69260044
parent d9c6ae6a
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"testing" "testing"
...@@ -2243,6 +2244,120 @@ func TestAppendTime(t *testing.T) { ...@@ -2243,6 +2244,120 @@ func TestAppendTime(t *testing.T) {
} }
} }
func TestServerConnState(t *testing.T) {
defer afterTest(t)
handler := map[string]func(w ResponseWriter, r *Request){
"/": func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "Hello.")
},
"/close": func(w ResponseWriter, r *Request) {
w.Header().Set("Connection", "close")
fmt.Fprintf(w, "Hello.")
},
"/hijack": func(w ResponseWriter, r *Request) {
c, _, _ := w.(Hijacker).Hijack()
c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
c.Close()
},
}
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
handler[r.URL.Path](w, r)
}))
defer ts.Close()
type connIDAndState struct {
connID int
state ConnState
}
var mu sync.Mutex // guard stateLog and connID
var stateLog []connIDAndState
var connID = map[net.Conn]int{}
ts.Config.ConnState = func(c net.Conn, state ConnState) {
if c == nil {
t.Error("nil conn seen in state %s", state)
return
}
mu.Lock()
defer mu.Unlock()
id, ok := connID[c]
if !ok {
id = len(connID) + 1
connID[c] = id
}
stateLog = append(stateLog, connIDAndState{id, state})
}
ts.Start()
mustGet(t, ts.URL+"/")
mustGet(t, ts.URL+"/close")
mustGet(t, ts.URL+"/")
mustGet(t, ts.URL+"/", "Connection", "close")
mustGet(t, ts.URL+"/hijack")
want := []connIDAndState{
{1, StateNew},
{1, StateActive},
{1, StateIdle},
{1, StateActive},
{1, StateClosed},
{2, StateNew},
{2, StateActive},
{2, StateIdle},
{2, StateActive},
{2, StateClosed},
{3, StateNew},
{3, StateActive},
{3, StateHijacked},
}
logString := func(l []connIDAndState) string {
var b bytes.Buffer
for _, cs := range l {
fmt.Fprintf(&b, "[%d %s] ", cs.connID, cs.state)
}
return b.String()
}
for i := 0; i < 5; i++ {
time.Sleep(time.Duration(i) * 50 * time.Millisecond)
mu.Lock()
match := reflect.DeepEqual(stateLog, want)
mu.Unlock()
if match {
return
}
}
mu.Lock()
t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want))
mu.Unlock()
}
func mustGet(t *testing.T, url string, headers ...string) {
req, err := NewRequest("GET", url, nil)
if err != nil {
t.Fatal(err)
}
for len(headers) > 0 {
req.Header.Add(headers[0], headers[1])
headers = headers[2:]
}
res, err := DefaultClient.Do(req)
if err != nil {
t.Errorf("Error fetching %s: %v", url, err)
return
}
_, err = ioutil.ReadAll(res.Body)
defer res.Body.Close()
if err != nil {
t.Errorf("Error reading %s: %v", url, err)
}
}
func BenchmarkClientServer(b *testing.B) { func BenchmarkClientServer(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.StopTimer() b.StopTimer()
......
...@@ -1079,8 +1079,16 @@ func validNPN(proto string) bool { ...@@ -1079,8 +1079,16 @@ func validNPN(proto string) bool {
return true return true
} }
func (c *conn) setState(nc net.Conn, state ConnState) {
if hook := c.server.ConnState; hook != nil {
hook(nc, state)
}
}
// Serve a new connection. // Serve a new connection.
func (c *conn) serve() { func (c *conn) serve() {
origConn := c.rwc // copy it before it's set nil on Close or Hijack
c.setState(origConn, StateNew)
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
const size = 64 << 10 const size = 64 << 10
...@@ -1090,6 +1098,7 @@ func (c *conn) serve() { ...@@ -1090,6 +1098,7 @@ func (c *conn) serve() {
} }
if !c.hijacked() { if !c.hijacked() {
c.close() c.close()
c.setState(origConn, StateClosed)
} }
}() }()
...@@ -1116,6 +1125,10 @@ func (c *conn) serve() { ...@@ -1116,6 +1125,10 @@ func (c *conn) serve() {
for { for {
w, err := c.readRequest() w, err := c.readRequest()
// TODO(bradfitz): could push this StateActive
// earlier, but in practice header will be all in one
// packet/Read:
c.setState(c.rwc, StateActive)
if err != nil { if err != nil {
if err == errTooLarge { if err == errTooLarge {
// Their HTTP client may or may not be // Their HTTP client may or may not be
...@@ -1161,6 +1174,7 @@ func (c *conn) serve() { ...@@ -1161,6 +1174,7 @@ func (c *conn) serve() {
// in parallel even if their responses need to be serialized. // in parallel even if their responses need to be serialized.
serverHandler{c.server}.ServeHTTP(w, w.req) serverHandler{c.server}.ServeHTTP(w, w.req)
if c.hijacked() { if c.hijacked() {
c.setState(origConn, StateHijacked)
return return
} }
w.finishRequest() w.finishRequest()
...@@ -1170,6 +1184,7 @@ func (c *conn) serve() { ...@@ -1170,6 +1184,7 @@ func (c *conn) serve() {
} }
break break
} }
c.setState(c.rwc, StateIdle)
} }
} }
...@@ -1580,6 +1595,58 @@ type Server struct { ...@@ -1580,6 +1595,58 @@ type Server struct {
// and RemoteAddr if not already set. The connection is // and RemoteAddr if not already set. The connection is
// automatically closed when the function returns. // automatically closed when the function returns.
TLSNextProto map[string]func(*Server, *tls.Conn, Handler) TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
// ConnState specifies an optional callback function that is
// called when a client connection changes state. See the
// ConnState type and associated constants for details.
ConnState func(net.Conn, ConnState)
}
// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
const (
// StateNew represents a new connection that is expected to
// send a request immediately. Connections begin at this
// state and then transition to either StateActive or
// StateClosed.
StateNew ConnState = iota
// StateActive represents a connection that has read 1 or more
// bytes of a request. The Server.ConnState hook for
// StateActive fires before the request has entered a handler
// and doesn't fire again until the request has been
// handled. After the request is handled, the state
// transitions to StateClosed, StateHijacked, or StateIdle.
StateActive
// StateIdle represents a connection that has finished
// handling a request and is in the keep-alive state, waiting
// for a new request. Connections transition from StateIdle
// to either StateActive or StateClosed.
StateIdle
// StateHijacked represents a hijacked connection.
// This is a terminal state. It does not transition to StateClosed.
StateHijacked
// StateClosed represents a closed connection.
// This is a terminal state. Hijacked connections do not
// transition to StateClosed.
StateClosed
)
var stateName = map[ConnState]string{
StateNew: "new",
StateActive: "active",
StateIdle: "idle",
StateHijacked: "hijacked",
StateClosed: "closed",
}
func (c ConnState) String() string {
return stateName[c]
} }
// serverHandler delegates to either the server's Handler or // serverHandler delegates to either the server's Handler or
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment