Commit 3eaaed50 authored by ChaiShushan's avatar ChaiShushan Committed by Rob Pike

net/rpc: fix RegisterName rejects "." character.

Fixes #5617.

R=r, rsc
CC=gobot, golang-dev
https://golang.org/cl/10370043
parent b78aaec2
...@@ -560,20 +560,23 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt ...@@ -560,20 +560,23 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
// we can still recover and move on to the next request. // we can still recover and move on to the next request.
keepReading = true keepReading = true
serviceMethod := strings.Split(req.ServiceMethod, ".") dot := strings.LastIndex(req.ServiceMethod, ".")
if len(serviceMethod) != 2 { if dot < 0 {
err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
return return
} }
serviceName := req.ServiceMethod[:dot]
methodName := req.ServiceMethod[dot+1:]
// Look up the request. // Look up the request.
server.mu.RLock() server.mu.RLock()
service = server.serviceMap[serviceMethod[0]] service = server.serviceMap[serviceName]
server.mu.RUnlock() server.mu.RUnlock()
if service == nil { if service == nil {
err = errors.New("rpc: can't find service " + req.ServiceMethod) err = errors.New("rpc: can't find service " + req.ServiceMethod)
return return
} }
mtype = service.method[serviceMethod[1]] mtype = service.method[methodName]
if mtype == nil { if mtype == nil {
err = errors.New("rpc: can't find method " + req.ServiceMethod) err = errors.New("rpc: can't find method " + req.ServiceMethod)
} }
......
...@@ -84,6 +84,7 @@ func listenTCP() (net.Listener, string) { ...@@ -84,6 +84,7 @@ func listenTCP() (net.Listener, string) {
func startServer() { func startServer() {
Register(new(Arith)) Register(new(Arith))
RegisterName("net.rpc.Arith", new(Arith))
var l net.Listener var l net.Listener
l, serverAddr = listenTCP() l, serverAddr = listenTCP()
...@@ -97,6 +98,7 @@ func startServer() { ...@@ -97,6 +98,7 @@ func startServer() {
func startNewServer() { func startNewServer() {
newServer = NewServer() newServer = NewServer()
newServer.Register(new(Arith)) newServer.Register(new(Arith))
newServer.RegisterName("net.rpc.Arith", new(Arith))
var l net.Listener var l net.Listener
l, newServerAddr = listenTCP() l, newServerAddr = listenTCP()
...@@ -234,6 +236,17 @@ func testRPC(t *testing.T, addr string) { ...@@ -234,6 +236,17 @@ func testRPC(t *testing.T, addr string) {
if reply.C != args.A*args.B { if reply.C != args.A*args.B {
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
} }
// ServiceName contain "." character
args = &Args{7, 8}
reply = new(Reply)
err = client.Call("net.rpc.Arith.Add", args, reply)
if err != nil {
t.Errorf("Add: expected no error but got string %q", err.Error())
}
if reply.C != args.A+args.B {
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
}
} }
func TestHTTP(t *testing.T) { func TestHTTP(t *testing.T) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment