##############################################################################
#
# Copyright (c) 2001, 2002 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
import asyncore
import errno
import select
import sys
import threading
import types

import ThreadedAsync
from ZEO.zrpc import smac
from ZEO.zrpc.error import ZRPCError, DisconnectedError, DecodingError
from ZEO.zrpc.log import log, short_repr
from ZEO.zrpc.marshal import Marshaller
from ZEO.zrpc.trigger import trigger
import zLOG
from ZODB import POSException

# Method name carried by reply messages; Connection.message_input()
# routes messages with this name to handle_reply().
REPLY = ".reply" # message name used for replies
# Flag bit for asynchronous calls.  No reply is sent for a request with
# this bit set, and the invoked method must return None
# (see Connection.handle_request()).
ASYNC = 1

class Delay:
    """Defer the reply to a synchronous call.

    A request handler that cannot answer right away returns an instance
    of this class.  Connection.handle_request() then hands it the reply
    channel via set_sender() instead of replying itself, and the answer
    goes out later when reply() or error() is called.
    """

    def set_sender(self, msgid, send_reply, return_error):
        """Remember the request id and the callbacks used to answer it."""
        self.msgid, self.send_reply, self.return_error = (
            msgid, send_reply, return_error)

    def reply(self, obj):
        """Send obj back as the return value of the deferred call."""
        send = self.send_reply
        send(self.msgid, obj)

    def error(self, exc_info):
        """Report an exception to the caller of the deferred call.

        exc_info is a sys.exc_info()-style tuple; only its first two
        items (type and value) are forwarded.
        """
        log("Error raised in delayed method", zLOG.ERROR, error=exc_info)
        report = self.return_error
        report(self.msgid, 0, *exc_info[:2])

class MTDelay(Delay):
    """Delay whose answer may be ready before set_sender() is called.

    reply() and error() block on a threading.Event until set_sender()
    has stored the request id and reply callbacks.  Presumably this lets
    another thread finish the work and answer without racing the
    mainloop -- confirm against callers.
    """

    def __init__(self):
        # Set by set_sender() once the reply callbacks are in place.
        self.ready = threading.Event()

    def set_sender(self, msgid, send_reply, return_error):
        Delay.set_sender(self, msgid, send_reply, return_error)
        # Release any reply()/error() already blocked on the event.
        self.ready.set()

    def reply(self, obj):
        ready = self.ready
        ready.wait()
        Delay.reply(self, obj)

    def error(self, exc_info):
        ready = self.ready
        ready.wait()
        Delay.error(self, exc_info)

class Connection(smac.SizedMessageAsyncConnection):
    """Dispatcher for RPC on object on both sides of socket.

    The connection supports synchronous calls, which expect a return,
74
    and asynchronous calls, which do not.
75 76

    It uses the Marshaller class to handle encoding and decoding of
77
    method calls and arguments.  Marshaller uses pickle to encode
78 79
    arbitrary Python objects.  The code here doesn't ever see the wire
    format.
80 81 82 83 84 85 86

    A Connection is designed for use in a multithreaded application,
    where a synchronous call must block until a response is ready.

    A socket connection between a client and a server allows either
    side to invoke methods on the other side.  The processes on each
    end of the socket use a Connection object to manage communication.
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112

    The Connection deals with decoded RPC messages.  They are
    represented as four-tuples containing: msgid, flags, method name,
    and a tuple of method arguments.

    The msgid starts at zero and is incremented by one each time a
    method call message is sent.  Each side of the connection has a
    separate msgid state.

    When one side of the connection (the client) calls a method, it
    sends a message with a new msgid.  The other side (the server),
    replies with a message that has the same msgid, the string
    ".reply" (the global variable REPLY) as the method name, and the
    actual return value in the args position.  Note that each side of
    the Connection can initiate a call, in which case it will be the
    client for that particular call.

    The protocol also supports asynchronous calls.  The client does
    not wait for a return value for an asynchronous call.  The only
    defined flag is ASYNC.  If a method call message has the ASYNC
    flag set, the server will raise an exception.

    If a method call raises an exception, the exception is propagated
    back to the client via the REPLY message.  The client side will
    raise any exception it receives instead of returning the value to
    the caller.
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
    """

    __super_init = smac.SizedMessageAsyncConnection.__init__
    __super_close = smac.SizedMessageAsyncConnection.close

    protocol_version = "Z200"

    def __init__(self, sock, addr, obj=None):
        self.obj = None
        self.marshal = Marshaller()
        self.closed = 0
        self.msgid = 0
        self.__super_init(sock, addr)
        # A Connection either uses asyncore directly or relies on an
        # asyncore mainloop running in a separate thread.  If
        # thr_async is true, then the mainloop is running in a
        # separate thread.  If thr_async is true, then the asyncore
        # trigger (self.trigger) is used to notify that thread of
        # activity on the current thread.
        self.thr_async = 0
        self.trigger = None
        self._prepare_async()
        self._map = {self._fileno: self}
136 137
        # __msgid_lock guards access to msgid
        self.__msgid_lock = threading.Lock()
138
        # __replies_cond is used to block when a synchronous call is
139
        # waiting for a response
140 141
        self.__replies_cond = threading.Condition()
        self.__replies = {}
142 143 144 145 146 147 148 149 150
        self.register_object(obj)
        self.handshake()

    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.addr)

    def close(self):
        if self.closed:
            return
151
        self._map.clear()
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
        self.closed = 1
        self.close_trigger()
        self.__super_close()

    def close_trigger(self):
        # overridden by ManagedConnection
        if self.trigger is not None:
            self.trigger.close()

    def register_object(self, obj):
        """Register obj as the true object to invoke methods on"""
        self.obj = obj

    def handshake(self):
        # When a connection is created the first message sent is a
        # 4-byte protocol version.  This mechanism should allow the
        # protocol to evolve over time, and let servers handle clients
        # using multiple versions of the protocol.

171
        # The mechanism replaces the message_input() method for the
172 173 174 175 176 177 178 179 180 181
        # first message received.

        # The client sends the protocol version it is using.
        self._message_input = self.message_input
        self.message_input = self.recv_handshake
        self.message_output(self.protocol_version)

    def recv_handshake(self, message):
        if message == self.protocol_version:
            self.message_input = self._message_input
182 183 184
        else:
            log("recv_handshake: bad handshake %s" % repr(message),
                level=zLOG.ERROR)
185
        # otherwise do something else...
Guido van Rossum's avatar
Guido van Rossum committed
186

187 188 189 190 191 192 193 194 195 196 197 198 199
    def message_input(self, message):
        """Decoding an incoming message and dispatch it"""
        # XXX Not sure what to do with errors that reach this level.
        # Need to catch ZRPCErrors in handle_reply() and
        # handle_request() so that they get back to the client.
        try:
            msgid, flags, name, args = self.marshal.decode(message)
        except DecodingError, msg:
            return self.return_error(None, None, DecodingError, msg)

        if __debug__:
            log("recv msg: %s, %s, %s, %s" % (msgid, flags, name,
                                              short_repr(args)),
200
                level=zLOG.TRACE)
201 202 203 204 205 206 207
        if name == REPLY:
            self.handle_reply(msgid, flags, args)
        else:
            self.handle_request(msgid, flags, name, args)

    def handle_reply(self, msgid, flags, args):
        if __debug__:
208
            log("recv reply: %s, %s, %s" % (msgid, flags, short_repr(args)),
209
                level=zLOG.DEBUG)
210 211 212 213 214 215
        self.__replies_cond.acquire()
        try:
            self.__replies[msgid] = flags, args
            self.__replies_cond.notifyAll()
        finally:
            self.__replies_cond.release()
216 217 218 219 220 221

    def handle_request(self, msgid, flags, name, args):
        if not self.check_method(name):
            msg = "Invalid method name: %s on %s" % (name, repr(self.obj))
            raise ZRPCError(msg)
        if __debug__:
222
            log("calling %s%s" % (name, short_repr(args)), level=zLOG.BLATHER)
223 224 225 226

        meth = getattr(self.obj, name)
        try:
            ret = meth(*args)
227 228
        except (SystemExit, KeyboardInterrupt):
            raise
229
        except Exception, msg:
230
            error = sys.exc_info()
231 232 233 234
            # XXX Since we're just passing this on to the caller, and
            # there are several cases where this happens during the
            # normal course of action, shouldn't this be logged at the
            # INFO level?
235
            log("%s() raised exception: %s" % (name, msg), zLOG.INFO,
236 237 238
                error=error)
            error = error[:2]
            return self.return_error(msgid, flags, *error)
239 240 241 242 243 244 245

        if flags & ASYNC:
            if ret is not None:
                raise ZRPCError("async method %s returned value %s" %
                                (name, repr(ret)))
        else:
            if __debug__:
246
                log("%s returns %s" % (name, short_repr(ret)), zLOG.DEBUG)
247
            if isinstance(ret, Delay):
248
                ret.set_sender(msgid, self.send_reply, self.return_error)
249 250 251 252
            else:
                self.send_reply(msgid, ret)

    def handle_error(self):
253 254 255
        if sys.exc_info()[0] == SystemExit:
            raise sys.exc_info()
        self.log_error("Error caught in asyncore")
256 257 258 259 260 261 262 263 264 265 266 267 268 269
        self.close()

    def log_error(self, msg="No error message supplied"):
        log(msg, zLOG.ERROR, error=sys.exc_info())

    def check_method(self, name):
        # XXX Is this sufficient "security" for now?
        if name.startswith('_'):
            return None
        return hasattr(self.obj, name)

    def send_reply(self, msgid, ret):
        msg = self.marshal.encode(msgid, 0, REPLY, ret)
        self.message_output(msg)
270
        self.poll()
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287

    def return_error(self, msgid, flags, err_type, err_value):
        if flags is None:
            self.log_error("Exception raised during decoding")
            return
        if flags & ASYNC:
            self.log_error("Asynchronous call raised exception: %s" % self)
            return
        if type(err_value) is not types.InstanceType:
            err_value = err_type, err_value

        try:
            msg = self.marshal.encode(msgid, 0, REPLY, (err_type, err_value))
        except self.marshal.errors:
            err = ZRPCError("Couldn't pickle error %s" % `err_value`)
            msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
        self.message_output(msg)
288
        self.poll()
289 290 291 292

    # The next two public methods (call and callAsync) are used by
    # clients to invoke methods on remote objects

293 294 295 296 297 298 299 300 301
    def send_call(self, method, args, flags):
        # send a message and return its msgid
        self.__msgid_lock.acquire()
        try:
            msgid = self.msgid
            self.msgid = self.msgid + 1
        finally:
            self.__msgid_lock.release()
        if __debug__:
302 303
            log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
                zLOG.TRACE)
304 305 306 307
        buf = self.marshal.encode(msgid, flags, method, args)
        self.message_output(buf)
        return msgid

308
    def call(self, method, *args):
309
        self.__replies_cond.acquire()
310
        try:
311 312 313 314 315 316 317 318
            while self.__replies and not self.closed:
                log("waiting for previous call to finish %s" %
                    repr(self.__replies.values()[0]))
                self.__replies_cond.wait(30)
            if self.closed:
                raise DisconnectedError()
            msgid = self.send_call(method, args, 0)
            self.__replies[msgid] = None
319
        finally:
320 321
            self.__replies_cond.release()
        r_flags, r_args = self.wait(msgid)
322
        if (isinstance(r_args, types.TupleType)
323
            and type(r_args[0]) == types.ClassType
324 325 326 327 328
            and issubclass(r_args[0], Exception)):
            inst = r_args[1]
            raise inst # error raised by server
        else:
            return r_args
329 330 331

    def callAsync(self, method, *args):
        if self.closed:
332
            raise DisconnectedError()
333
        self.send_call(method, args, ASYNC)
334
        self.poll()
335 336 337 338 339 340 341 342 343 344 345 346 347 348

    # handle IO, possibly in async mode

    def _prepare_async(self):
        self.thr_async = 0
        ThreadedAsync.register_loop_callback(self.set_async)
        # XXX If we are not in async mode, this will cause dead
        # Connections to be leaked.

    def set_async(self, map):
        self.trigger = trigger()
        self.thr_async = 1

    def is_async(self):
349
        # overridden for ManagedConnection
350 351 352 353 354
        if self.thr_async:
            return 1
        else:
            return 0

355
    def wait(self, msgid):
356
        """Invoke asyncore mainloop and wait for reply."""
357
        if __debug__:
358
            log("wait() async=%d" % self.is_async(), level=zLOG.TRACE)
359 360
        if self.is_async():
            self.trigger.pull_trigger()
361 362 363 364
            
        self.__replies_cond.acquire()
        try:
            while 1:
365 366
                if self.closed:
                    raise DisconnectedError()
367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
                reply = self.__replies.get(msgid)
                if reply is not None:
                    del self.__replies[msgid]
                    assert len(self.__replies) == 0
                    self.__replies_cond.notifyAll()
                    return reply
                if self.is_async():
                    self.__replies_cond.wait(10.0)
                else:
                    self.__replies_cond.release()
                    try:
                        try:
                            asyncore.poll(10.0, self._map)
                        except select.error, err:
                            log("Closing.  asyncore.poll() raised %s." % err,
                                level=zLOG.BLATHER)
                            self.close()
                    finally:
                        self.__replies_cond.acquire()
        finally:
            self.__replies_cond.release()
388

389
    def poll(self):
390
        """Invoke asyncore mainloop to get pending message out."""
391
        if __debug__:
392
            log("poll(), async=%d" % self.is_async(), level=zLOG.TRACE)
393 394 395 396 397
        if self.is_async():
            self.trigger.pull_trigger()
        else:
            asyncore.poll(0.0, self._map)

398 399 400 401 402 403 404
    def pending(self):
        """Invoke mainloop until any pending messages are handled."""
        if __debug__:
            log("pending(), async=%d" % self.is_async(), level=zLOG.TRACE)
        if self.is_async():
            return
        # Inline the asyncore poll3 function to know whether any input
405
        # was actually read.  Repeat until no input is ready.
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427
        # XXX This only does reads.
        poll = select.poll()
        poll.register(self._fileno, select.POLLIN)
        # put dummy value in r so we enter the while loop the first time
        r = [(self._fileno, None)]
        while r:
            try:
                r = poll.poll()
            except select.error, err:
                if err[0] == errno.EINTR:
                    continue
                else:
                    raise
            if r:
                try:
                    self.handle_read_event()
                except asyncore.ExitNow:
                    raise
                else:
                    self.handle_error()
                    

428 429 430 431 432 433 434 435 436 437 438 439 440 441
class ServerConnection(Connection):
    """A Connection for the server end of a socket."""

    # NOTE(review): this class was documented as not sending a protocol
    # message and instead adapting to whatever the client sends.  It
    # overrides nothing, though, so Connection.__init__() still calls
    # handshake(), which sends protocol_version.  Confirm which
    # behavior is intended.

class ManagedServerConnection(ServerConnection):
    """A connection that notifies its ConnectionManager of closing.

    On construction it calls obj.notifyConnected(self).  On the first
    close() it calls obj.notifyDisconnected(), closes the underlying
    Connection, and reports itself to the manager via close_conn().
    """
    __super_init = Connection.__init__
    __super_close = Connection.close

    def __init__(self, sock, addr, obj, mgr):
        self.__mgr = mgr
        self.__super_init(sock, addr, obj)
        self.obj.notifyConnected(self)

    def close(self):
        # Connection.close() ignores repeated calls.  Do the same here so
        # obj and the manager are told about the disconnect only once.
        if self.closed:
            return
        try:
            self.obj.notifyDisconnected()
        finally:
            # Release the socket and deregister from the manager even if
            # the notification hook raises.  That exception still
            # propagates to the caller.
            self.__super_close()
            self.__mgr.close_conn(self)


class ManagedConnection(Connection):
    """A connection that notifies its ConnectionManager of closing.

    A managed connection also defers the ThreadedAsync work to its
    manager.  It registers no loop callback of its own and adopts the
    manager's trigger once the manager reports thr_async.
    """
    __super_init = Connection.__init__
    __super_close = Connection.close

    def __init__(self, sock, addr, obj, mgr):
        self.__mgr = mgr
        self.__super_init(sock, addr, obj)
        self.check_mgr_async()

    def close_trigger(self):
        # The manager owns the trigger and should actually close it, so
        # only drop this connection's reference.  Rebinding to None
        # (instead of deleting the attribute) keeps self.trigger
        # defined, matching its initial value in Connection.__init__().
        self.trigger = None

    def set_async(self, map):
        # Never registered as a loop callback (see _prepare_async), so
        # there is nothing to do.
        pass

    def _prepare_async(self):
        # Don't do the register_loop_callback that the superclass does
        pass

    def check_mgr_async(self):
        """Adopt the manager's trigger once the manager is in async mode.

        Returns 1 if this call switched the connection to async mode,
        and 0 otherwise.
        """
        if not self.thr_async and self.__mgr.thr_async:
            assert self.__mgr.trigger is not None, \
                   "manager (%s) has no trigger" % self.__mgr
            self.thr_async = 1
            self.trigger = self.__mgr.trigger
            return 1
        return 0

    def is_async(self):
        # XXX could the check_mgr_async() be avoided on each test?
        if self.thr_async:
            return 1
        return self.check_mgr_async()

    def close(self):
        """Close the connection and notify the manager, once."""
        # Connection.close() ignores repeated calls.  Do the same here so
        # the manager's notify_closed() hook runs only once per
        # connection.
        if self.closed:
            return
        self.__super_close()
        self.__mgr.notify_closed()