Commit 740a209a authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: don't crash if Server.Server is called with non-comparable Listener

Fixes #24812

Change-Id: If8d496d61b1120233e44c72d854e80cb06bab970
Reviewed-on: https://go-review.googlesource.com/106657
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
parent 7a0bf943
......@@ -5843,6 +5843,19 @@ func TestServerValidatesMethod(t *testing.T) {
}
}
// Listener for TestServerListenNotComparableListener.
type eofListenerNotComparable []int
func (eofListenerNotComparable) Accept() (net.Conn, error) { return nil, io.EOF }
func (eofListenerNotComparable) Addr() net.Addr { return nil }
func (eofListenerNotComparable) Close() error { return nil }
// Issue 24812: don't crash on non-comparable Listener
func TestServerListenNotComparableListener(t *testing.T) {
var s Server
s.Serve(make(eofListenerNotComparable, 1)) // used to panic
}
func BenchmarkResponseStatusLine(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
......
......@@ -2485,7 +2485,7 @@ type Server struct {
nextProtoErr error // result of http2.ConfigureServer if used
mu sync.Mutex
listeners map[net.Listener]struct{}
listeners map[*net.Listener]struct{}
activeConn map[*conn]struct{}
doneChan chan struct{}
onShutdown []func()
......@@ -2621,7 +2621,7 @@ func (s *Server) closeIdleConns() bool {
func (s *Server) closeListenersLocked() error {
var err error
for ln := range s.listeners {
if cerr := ln.Close(); cerr != nil && err == nil {
if cerr := (*ln).Close(); cerr != nil && err == nil {
err = cerr
}
delete(s.listeners, ln)
......@@ -2765,8 +2765,8 @@ func (srv *Server) Serve(l net.Listener) error {
return err
}
srv.trackListener(l, true)
defer srv.trackListener(l, false)
srv.trackListener(&l, true)
defer srv.trackListener(&l, false)
baseCtx := context.Background() // base is always background, per Issue 16220
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
......@@ -2843,11 +2843,19 @@ func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
return srv.Serve(tlsListener)
}
func (s *Server) trackListener(ln net.Listener, add bool) {
// trackListener adds or removes a net.Listener to the set of tracked
// listeners.
//
// We store a pointer to interface in the map set, in case the
// net.Listener is not comparable. This is safe because we only call
// trackListener via Serve and can track+defer untrack the same
// pointer to local variable there. We never need to compare a
// Listener from another caller.
func (s *Server) trackListener(ln *net.Listener, add bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.listeners == nil {
s.listeners = make(map[net.Listener]struct{})
s.listeners = make(map[*net.Listener]struct{})
}
if add {
// If the *Server is being reused after a previous
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment