Commit 1abf3aa5 authored by Leo Antunes's avatar Leo Antunes Committed by Brad Fitzpatrick

net: add KeepAlive field to ListenConfig

This commit adds a KeepAlive field to ListenConfig and uses it
analogously to Dialer.KeepAlive to set TCP KeepAlives per default on
Accept()

Fixes #23378

Change-Id: I57eaf9508c979e7f0e2b8c5dd8e8901f6eb27fd6
GitHub-Last-Rev: e9e035d53ee8aa3d899d12db08b293f599daecb6
GitHub-Pull-Request: golang/go#31242
Reviewed-on: https://go-review.googlesource.com/c/go/+/170678
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 964fe4b8
...@@ -596,6 +596,14 @@ type ListenConfig struct { ...@@ -596,6 +596,14 @@ type ListenConfig struct {
// necessarily the ones passed to Listen. For example, passing "tcp" to // necessarily the ones passed to Listen. For example, passing "tcp" to
// Listen will cause the Control function to be called with "tcp4" or "tcp6". // Listen will cause the Control function to be called with "tcp4" or "tcp6".
Control func(network, address string, c syscall.RawConn) error Control func(network, address string, c syscall.RawConn) error
// KeepAlive specifies the keep-alive period for network
// connections accepted by this listener.
// If zero, keep-alives are enabled if supported by the protocol
// and operating system. Network protocols or operating systems
// that do not support keep-alives ignore this field.
// If negative, keep-alives are disabled.
KeepAlive time.Duration
} }
// Listen announces on the local network address. // Listen announces on the local network address.
......
...@@ -127,7 +127,7 @@ func fileListener(f *os.File) (Listener, error) { ...@@ -127,7 +127,7 @@ func fileListener(f *os.File) (Listener, error) {
return nil, errors.New("file does not represent a listener") return nil, errors.New("file does not represent a listener")
} }
return &TCPListener{fd}, nil return &TCPListener{fd: fd}, nil
} }
func filePacketConn(f *os.File) (PacketConn, error) { func filePacketConn(f *os.File) (PacketConn, error) {
......
...@@ -93,7 +93,7 @@ func fileListener(f *os.File) (Listener, error) { ...@@ -93,7 +93,7 @@ func fileListener(f *os.File) (Listener, error) {
} }
switch laddr := fd.laddr.(type) { switch laddr := fd.laddr.(type) {
case *TCPAddr: case *TCPAddr:
return &TCPListener{fd}, nil return &TCPListener{fd: fd}, nil
case *UnixAddr: case *UnixAddr:
return &UnixListener{fd: fd, path: laddr.Name, unlink: false}, nil return &UnixListener{fd: fd, path: laddr.Name, unlink: false}, nil
} }
......
...@@ -2792,7 +2792,7 @@ func (srv *Server) ListenAndServe() error { ...@@ -2792,7 +2792,7 @@ func (srv *Server) ListenAndServe() error {
if err != nil { if err != nil {
return err return err
} }
return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) return srv.Serve(ln)
} }
var testHookServerServe func(*Server, net.Listener) // used if non-nil var testHookServerServe func(*Server, net.Listener) // used if non-nil
...@@ -3076,7 +3076,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { ...@@ -3076,7 +3076,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
defer ln.Close() defer ln.Close()
return srv.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, certFile, keyFile) return srv.ServeTLS(ln, certFile, keyFile)
} }
// setupHTTP2_ServeTLS conditionally configures HTTP/2 on // setupHTTP2_ServeTLS conditionally configures HTTP/2 on
...@@ -3269,24 +3269,6 @@ func (tw *timeoutWriter) writeHeader(code int) { ...@@ -3269,24 +3269,6 @@ func (tw *timeoutWriter) writeHeader(code int) {
tw.code = code tw.code = code
} }
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// onceCloseListener wraps a net.Listener, protecting it from // onceCloseListener wraps a net.Listener, protecting it from
// multiple Close calls. // multiple Close calls.
type onceCloseListener struct { type onceCloseListener struct {
......
...@@ -224,6 +224,7 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) { ...@@ -224,6 +224,7 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
// use variables of type Listener instead of assuming TCP. // use variables of type Listener instead of assuming TCP.
type TCPListener struct { type TCPListener struct {
fd *netFD fd *netFD
lc ListenConfig
} }
// SyscallConn returns a raw network connection. // SyscallConn returns a raw network connection.
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"context" "context"
"io" "io"
"os" "os"
"time"
) )
func (c *TCPConn) readFrom(r io.Reader) (int64, error) { func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
...@@ -44,7 +45,16 @@ func (ln *TCPListener) accept() (*TCPConn, error) { ...@@ -44,7 +45,16 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newTCPConn(fd), nil tc := newTCPConn(fd)
if ln.lc.KeepAlive >= 0 {
setKeepAlive(fd, true)
ka := ln.lc.KeepAlive
if ln.lc.KeepAlive == 0 {
ka = 3 * time.Minute
}
setKeepAlivePeriod(fd, ka)
}
return tc, nil
} }
func (ln *TCPListener) close() error { func (ln *TCPListener) close() error {
...@@ -74,5 +84,5 @@ func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListe ...@@ -74,5 +84,5 @@ func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListe
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &TCPListener{fd}, nil return &TCPListener{fd: fd, lc: sl.ListenConfig}, nil
} }
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"io" "io"
"os" "os"
"syscall" "syscall"
"time"
) )
func sockaddrToTCP(sa syscall.Sockaddr) Addr { func sockaddrToTCP(sa syscall.Sockaddr) Addr {
...@@ -140,7 +141,16 @@ func (ln *TCPListener) accept() (*TCPConn, error) { ...@@ -140,7 +141,16 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newTCPConn(fd), nil tc := newTCPConn(fd)
if ln.lc.KeepAlive >= 0 {
setKeepAlive(fd, true)
ka := ln.lc.KeepAlive
if ln.lc.KeepAlive == 0 {
ka = 3 * time.Minute
}
setKeepAlivePeriod(fd, ka)
}
return tc, nil
} }
func (ln *TCPListener) close() error { func (ln *TCPListener) close() error {
...@@ -160,5 +170,5 @@ func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListe ...@@ -160,5 +170,5 @@ func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListe
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &TCPListener{fd}, nil return &TCPListener{fd: fd, lc: sl.ListenConfig}, nil
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment