Commit 4f6d8a59 authored by Russ Cox's avatar Russ Cox

net/rpc: wait for responses to be written before closing Codec

If there are no more requests being made, wait to shut down
the response-writing codec until the pending requests are all
answered.

Fixes #17239.

Change-Id: Ie62c63ada536171df4e70b73c95f98f778069972
Reviewed-on: https://go-review.googlesource.com/79515
Run-TryBot: Russ Cox <rsc@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarRob Pike <r@golang.org>
parent 70ee9b4a
...@@ -372,7 +372,10 @@ func (m *methodType) NumCalls() (n uint) { ...@@ -372,7 +372,10 @@ func (m *methodType) NumCalls() (n uint) {
return n return n
} }
func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
if wg != nil {
defer wg.Done()
}
mtype.Lock() mtype.Lock()
mtype.numCalls++ mtype.numCalls++
mtype.Unlock() mtype.Unlock()
...@@ -456,6 +459,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) { ...@@ -456,6 +459,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
// decode requests and encode responses. // decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) { func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex) sending := new(sync.Mutex)
wg := new(sync.WaitGroup)
for { for {
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil { if err != nil {
...@@ -472,8 +476,12 @@ func (server *Server) ServeCodec(codec ServerCodec) { ...@@ -472,8 +476,12 @@ func (server *Server) ServeCodec(codec ServerCodec) {
} }
continue continue
} }
go service.call(server, sending, mtype, req, argv, replyv, codec) wg.Add(1)
go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
} }
// We've seen that there are no more requests.
// Wait for responses to be sent before closing codec.
wg.Wait()
codec.Close() codec.Close()
} }
...@@ -493,7 +501,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error { ...@@ -493,7 +501,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error {
} }
return err return err
} }
service.call(server, sending, mtype, req, argv, replyv, codec) service.call(server, sending, nil, mtype, req, argv, replyv, codec)
return nil return nil
} }
......
...@@ -75,6 +75,11 @@ func (t *Arith) Error(args *Args, reply *Reply) error { ...@@ -75,6 +75,11 @@ func (t *Arith) Error(args *Args, reply *Reply) error {
panic("ERROR") panic("ERROR")
} }
func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
time.Sleep(time.Duration(args.A) * time.Millisecond)
return nil
}
type hidden int type hidden int
func (t *hidden) Exported(args Args, reply *Reply) error { func (t *hidden) Exported(args Args, reply *Reply) error {
...@@ -693,6 +698,53 @@ func TestAcceptExitAfterListenerClose(t *testing.T) { ...@@ -693,6 +698,53 @@ func TestAcceptExitAfterListenerClose(t *testing.T) {
newServer.Accept(l) newServer.Accept(l)
} }
func TestShutdown(t *testing.T) {
var l net.Listener
l, _ = listenTCP()
ch := make(chan net.Conn, 1)
go func() {
defer l.Close()
c, err := l.Accept()
if err != nil {
t.Error(err)
}
ch <- c
}()
c, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
c1 := <-ch
if c1 == nil {
t.Fatal(err)
}
newServer := NewServer()
newServer.Register(new(Arith))
go newServer.ServeConn(c1)
args := &Args{7, 8}
reply := new(Reply)
client := NewClient(c)
err = client.Call("Arith.Add", args, reply)
if err != nil {
t.Fatal(err)
}
// On an unloaded system 10ms is usually enough to fail 100% of the time
// with a broken server. On a loaded system, a broken server might incorrectly
// be reported as passing, but we're OK with that kind of flakiness.
// If the code is correct, this test will never fail, regardless of timeout.
args.A = 10 // 10 ms
done := make(chan *Call, 1)
call := client.Go("Arith.SleepMilli", args, reply, done)
c.(*net.TCPConn).CloseWrite()
<-done
if call.Error != nil {
t.Fatal(err)
}
}
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
once.Do(startServer) once.Do(startServer)
client, err := dial() client, err := dial()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment