Commit b31fed14 authored by Jim Fulton's avatar Jim Fulton

Implemented msgpack as an optional ZEO message encoding with basic tests.

parent c3183420
...@@ -289,6 +289,13 @@ client-conflict-resolution ...@@ -289,6 +289,13 @@ client-conflict-resolution
Flag indicating that clients should perform conflict Flag indicating that clients should perform conflict
resolution. This option defaults to false. resolution. This option defaults to false.
msgpack
Use msgpack to serialize and de-serialize ZEO protocol messages.
An advantage of using msgpack for ZEO communication is that
it's a little bit faster and a ZEO server can support Python 2
or Python 3 clients (but not both).
Server SSL configuration Server SSL configuration
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -36,7 +36,7 @@ install_requires = [ ...@@ -36,7 +36,7 @@ install_requires = [
'zope.interface', 'zope.interface',
] ]
tests_require = ['zope.testing', 'manuel', 'random2', 'mock'] tests_require = ['zope.testing', 'manuel', 'random2', 'mock', 'msgpack-python']
if sys.version_info[:2] < (3, ): if sys.version_info[:2] < (3, ):
install_requires.extend(('futures', 'trollius')) install_requires.extend(('futures', 'trollius'))
...@@ -128,7 +128,11 @@ setup(name="ZEO", ...@@ -128,7 +128,11 @@ setup(name="ZEO",
classifiers = classifiers, classifiers = classifiers,
test_suite="__main__.alltests", # to support "setup.py test" test_suite="__main__.alltests", # to support "setup.py test"
tests_require = tests_require, tests_require = tests_require,
extras_require = dict(test=tests_require, uvloop=['uvloop >=0.5.1']), extras_require = dict(
test=tests_require,
uvloop=['uvloop >=0.5.1'],
msgpack=['msgpack-python'],
),
install_requires = install_requires, install_requires = install_requires,
zip_safe = False, zip_safe = False,
entry_points = """ entry_points = """
......
...@@ -663,6 +663,7 @@ class StorageServer: ...@@ -663,6 +663,7 @@ class StorageServer:
ssl=None, ssl=None,
client_conflict_resolution=False, client_conflict_resolution=False,
Acceptor=Acceptor, Acceptor=Acceptor,
msgpack=False,
): ):
"""StorageServer constructor. """StorageServer constructor.
...@@ -757,7 +758,7 @@ class StorageServer: ...@@ -757,7 +758,7 @@ class StorageServer:
self.client_conflict_resolution = client_conflict_resolution self.client_conflict_resolution = client_conflict_resolution
if addr is not None: if addr is not None:
self.acceptor = Acceptor(self, addr, ssl) self.acceptor = Acceptor(self, addr, ssl, msgpack)
if isinstance(addr, tuple) and addr[0]: if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr self.addr = self.acceptor.addr
else: else:
......
...@@ -10,8 +10,6 @@ import socket ...@@ -10,8 +10,6 @@ import socket
from struct import unpack from struct import unpack
import sys import sys
from .marshal import encoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6 INET_FAMILIES = socket.AF_INET, socket.AF_INET6
...@@ -129,13 +127,13 @@ class Protocol(asyncio.Protocol): ...@@ -129,13 +127,13 @@ class Protocol(asyncio.Protocol):
self.getting_size = True self.getting_size = True
self.message_received(collected) self.message_received(collected)
except Exception: except Exception:
#import traceback; traceback.print_exc()
logger.exception("data_received %s %s %s", logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size) self.want, self.got, self.getting_size)
def first_message_received(self, protocol_version): def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__ # Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on del self.message_received # use default handler from here on
self.encode = encoder()
self.finish_connect(protocol_version) self.finish_connect(protocol_version)
def call_async(self, method, args): def call_async(self, method, args):
......
...@@ -13,7 +13,7 @@ import ZEO.interfaces ...@@ -13,7 +13,7 @@ import ZEO.interfaces
from . import base from . import base
from .compat import asyncio, new_event_loop from .compat import asyncio, new_event_loop
from .marshal import decode from .marshal import encoder, decoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -63,7 +63,7 @@ class Protocol(base.Protocol): ...@@ -63,7 +63,7 @@ class Protocol(base.Protocol):
# One place where special care was required was in cache setup on # One place where special care was required was in cache setup on
# connect. See finish connect below. # connect. See finish connect below.
protocols = b'Z309', b'Z310', b'Z3101', b'Z4', b'Z5' protocols = b'309', b'310', b'3101', b'4', b'5'
def __init__(self, loop, def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1, addr, client, storage_key, read_only, connect_poll=1,
...@@ -150,6 +150,8 @@ class Protocol(base.Protocol): ...@@ -150,6 +150,8 @@ class Protocol(base.Protocol):
# We have to be careful processing the futures, because # We have to be careful processing the futures, because
# exception callbacks might modufy them. # exception callbacks might modufy them.
for f in self.pop_futures(): for f in self.pop_futures():
if isinstance(f, tuple):
continue
f.set_exception(ClientDisconnected(exc or 'connection lost')) f.set_exception(ClientDisconnected(exc or 'connection lost'))
self.closed = True self.closed = True
self.client.disconnected(self) self.client.disconnected(self)
...@@ -165,13 +167,17 @@ class Protocol(base.Protocol): ...@@ -165,13 +167,17 @@ class Protocol(base.Protocol):
# lastTid before processing (and possibly missing) subsequent # lastTid before processing (and possibly missing) subsequent
# invalidations. # invalidations.
self.protocol_version = min(protocol_version, self.protocols[-1]) version = min(protocol_version[1:], self.protocols[-1])
if version not in self.protocols:
if self.protocol_version not in self.protocols:
self.client.register_failed( self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version)) self, ZEO.Exceptions.ProtocolError(protocol_version))
return return
self.protocol_version = protocol_version[:1] + version
self.encode = encoder(protocol_version)
self.decode = decoder(protocol_version)
self.heartbeat_bytes = self.encode(-1, 0, '.reply', None)
self._write(self.protocol_version) self._write(self.protocol_version)
credentials = (self.credentials,) if self.credentials else () credentials = (self.credentials,) if self.credentials else ()
...@@ -199,9 +205,12 @@ class Protocol(base.Protocol): ...@@ -199,9 +205,12 @@ class Protocol(base.Protocol):
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
msgid, async, name, args = decode(data) msgid, async, name, args = self.decode(data)
if name == '.reply': if name == '.reply':
future = self.futures.pop(msgid) future = self.futures.pop(msgid)
if isinstance(future, tuple):
future = self.futures.pop(future)
if (async): # ZEO 5 exception if (async): # ZEO 5 exception
class_, args = args class_, args = args
factory = exc_factories.get(class_) factory = exc_factories.get(class_)
...@@ -245,13 +254,15 @@ class Protocol(base.Protocol): ...@@ -245,13 +254,15 @@ class Protocol(base.Protocol):
def load_before(self, oid, tid): def load_before(self, oid, tid):
# Special-case loadBefore, so we collapse outstanding requests # Special-case loadBefore, so we collapse outstanding requests
message_id = (oid, tid) oid_tid = (oid, tid)
future = self.futures.get(message_id) future = self.futures.get(oid_tid)
if future is None: if future is None:
future = asyncio.Future(loop=self.loop) future = asyncio.Future(loop=self.loop)
self.futures[message_id] = future self.futures[oid_tid] = future
self.message_id += 1
self.futures[self.message_id] = oid_tid
self._write( self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid))) self.encode(self.message_id, False, 'loadBefore', (oid, tid)))
return future return future
# Methods called by the server. # Methods called by the server.
...@@ -267,7 +278,7 @@ class Protocol(base.Protocol): ...@@ -267,7 +278,7 @@ class Protocol(base.Protocol):
def heartbeat(self, write=True): def heartbeat(self, write=True):
if write: if write:
self._write(b'(J\xff\xff\xff\xffK\x00U\x06.replyNt.') self._write(self.heartbeat_bytes)
self.heartbeat_handle = self.loop.call_later( self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat) self.heartbeat_interval, self.heartbeat)
......
...@@ -26,10 +26,18 @@ from ..shortrepr import short_repr ...@@ -26,10 +26,18 @@ from ..shortrepr import short_repr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def encoder(): def encoder(protocol):
"""Return a non-thread-safe encoder """Return a non-thread-safe encoder
""" """
if protocol[:1] == b'M':
from msgpack import packb
def encode(*args):
return packb(args, use_bin_type=True)
return encode
else:
assert protocol[:1] == b'Z'
if PY3 or PYPY: if PY3 or PYPY:
f = BytesIO() f = BytesIO()
getvalue = f.getvalue getvalue = f.getvalue
...@@ -54,9 +62,20 @@ def encoder(): ...@@ -54,9 +62,20 @@ def encoder():
def encode(*args): def encode(*args):
return encoder()(*args) return encoder(b'Z')(*args)
def decode(msg): def decoder(protocol):
if protocol[:1] == b'M':
from msgpack import unpackb
def msgpack_decode(data):
"""Decodes msg and returns its parts"""
return unpackb(data, encoding='utf-8')
return msgpack_decode
else:
assert protocol[:1] == b'Z'
return pickle_decode
def pickle_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global unpickler.find_global = find_global
...@@ -71,7 +90,14 @@ def decode(msg): ...@@ -71,7 +90,14 @@ def decode(msg):
logger.error("can't decode message: %s" % short_repr(msg)) logger.error("can't decode message: %s" % short_repr(msg))
raise raise
def server_decode(msg): def server_decoder(protocol):
if protocol[:1] == b'M':
return decoder(protocol)
else:
assert protocol[:1] == b'Z'
return pickle_server_decode
def pickle_server_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global unpickler.find_global = server_find_global
......
...@@ -76,13 +76,14 @@ class Acceptor(asyncore.dispatcher): ...@@ -76,13 +76,14 @@ class Acceptor(asyncore.dispatcher):
And creates a separate thread for each. And creates a separate thread for each.
""" """
def __init__(self, storage_server, addr, ssl): def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server self.storage_server = storage_server
self.addr = addr self.addr = addr
self.__socket_map = {} self.__socket_map = {}
asyncore.dispatcher.__init__(self, map=self.__socket_map) asyncore.dispatcher.__init__(self, map=self.__socket_map)
self.ssl_context = ssl self.ssl_context = ssl
self.msgpack = msgpack
self._open_socket() self._open_socket()
def _open_socket(self): def _open_socket(self):
...@@ -165,7 +166,7 @@ class Acceptor(asyncore.dispatcher): ...@@ -165,7 +166,7 @@ class Acceptor(asyncore.dispatcher):
def run(): def run():
loop = new_event_loop() loop = new_event_loop()
zs = self.storage_server.create_client_handler() zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(loop, self.addr, zs) protocol = ServerProtocol(loop, self.addr, zs, self.msgpack)
protocol.stop = loop.stop protocol.stop = loop.stop
if self.ssl_context is None: if self.ssl_context is None:
......
...@@ -11,13 +11,13 @@ from ..shortrepr import short_repr ...@@ -11,13 +11,13 @@ from ..shortrepr import short_repr
from . import base from . import base
from .compat import asyncio, new_event_loop from .compat import asyncio, new_event_loop
from .marshal import server_decode from .marshal import server_decoder, encoder
class ServerProtocol(base.Protocol): class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface """asyncio low-level ZEO server interface
""" """
protocols = (b'Z5', ) protocols = (b'5', )
name = 'server protocol' name = 'server protocol'
methods = set(('register', )) methods = set(('register', ))
...@@ -26,12 +26,16 @@ class ServerProtocol(base.Protocol): ...@@ -26,12 +26,16 @@ class ServerProtocol(base.Protocol):
ZODB.POSException.POSKeyError, ZODB.POSException.POSKeyError,
) )
def __init__(self, loop, addr, zeo_storage): def __init__(self, loop, addr, zeo_storage, msgpack):
"""Create a server's client interface """Create a server's client interface
""" """
super(ServerProtocol, self).__init__(loop, addr) super(ServerProtocol, self).__init__(loop, addr)
self.zeo_storage = zeo_storage self.zeo_storage = zeo_storage
self.announce_protocol = (
(b'M' if msgpack else b'Z') + best_protocol_version
)
closed = False closed = False
def close(self): def close(self):
logger.debug("Closing server protocol") logger.debug("Closing server protocol")
...@@ -44,7 +48,7 @@ class ServerProtocol(base.Protocol): ...@@ -44,7 +48,7 @@ class ServerProtocol(base.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
self.connected = True self.connected = True
super(ServerProtocol, self).connection_made(transport) super(ServerProtocol, self).connection_made(transport)
self._write(best_protocol_version) self._write(self.announce_protocol)
def connection_lost(self, exc): def connection_lost(self, exc):
self.connected = False self.connected = False
...@@ -61,10 +65,13 @@ class ServerProtocol(base.Protocol): ...@@ -61,10 +65,13 @@ class ServerProtocol(base.Protocol):
self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii")) self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii"))
self.close() self.close()
else: else:
if protocol_version in self.protocols: version = protocol_version[1:]
if version in self.protocols:
logger.info("received handshake %r" % logger.info("received handshake %r" %
str(protocol_version.decode('ascii'))) str(protocol_version.decode('ascii')))
self.protocol_version = protocol_version self.protocol_version = protocol_version
self.encode = encoder(protocol_version)
self.decode = server_decoder(protocol_version)
self.zeo_storage.notify_connected(self) self.zeo_storage.notify_connected(self)
else: else:
logger.error("bad handshake %s" % short_repr(protocol_version)) logger.error("bad handshake %s" % short_repr(protocol_version))
...@@ -79,7 +86,7 @@ class ServerProtocol(base.Protocol): ...@@ -79,7 +86,7 @@ class ServerProtocol(base.Protocol):
def message_received(self, message): def message_received(self, message):
try: try:
message_id, async, name, args = server_decode(message) message_id, async, name, args = self.decode(message)
except Exception: except Exception:
logger.exception("Can't deserialize message") logger.exception("Can't deserialize message")
self.close() self.close()
...@@ -144,8 +151,8 @@ best_protocol_version = os.environ.get( ...@@ -144,8 +151,8 @@ best_protocol_version = os.environ.get(
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8') ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage): def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage) protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket) cr = loop.create_connection((lambda : protocol), sock=socket)
asyncio.async(cr, loop=loop) asyncio.async(cr, loop=loop)
...@@ -213,10 +220,11 @@ class MTDelay(Delay): ...@@ -213,10 +220,11 @@ class MTDelay(Delay):
class Acceptor(object): class Acceptor(object):
def __init__(self, storage_server, addr, ssl): def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server self.storage_server = storage_server
self.addr = addr self.addr = addr
self.ssl_context = ssl self.ssl_context = ssl
self.msgpack = msgpack
self.event_loop = loop = new_event_loop() self.event_loop = loop = new_event_loop()
if isinstance(addr, tuple): if isinstance(addr, tuple):
...@@ -243,7 +251,8 @@ class Acceptor(object): ...@@ -243,7 +251,8 @@ class Acceptor(object):
try: try:
logger.debug("Accepted connection") logger.debug("Accepted connection")
zs = self.storage_server.create_client_handler() zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(self.event_loop, self.addr, zs) protocol = ServerProtocol(
self.event_loop, self.addr, zs, self.msgpack)
except Exception: except Exception:
logger.exception("Failure in protocol factory") logger.exception("Failure in protocol factory")
......
This diff is collapsed.
...@@ -100,6 +100,7 @@ class ZEOOptionsMixin: ...@@ -100,6 +100,7 @@ class ZEOOptionsMixin:
self.add("client_conflict_resolution", self.add("client_conflict_resolution",
"zeo.client_conflict_resolution", "zeo.client_conflict_resolution",
default=0) default=0)
self.add("msgpack", "zeo.msgpack", default=0)
self.add("invalidation_queue_size", "zeo.invalidation_queue_size", self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
default=100) default=100)
self.add("invalidation_age", "zeo.invalidation_age") self.add("invalidation_age", "zeo.invalidation_age")
...@@ -342,6 +343,7 @@ def create_server(storages, options): ...@@ -342,6 +343,7 @@ def create_server(storages, options):
storages, storages,
read_only = options.read_only, read_only = options.read_only,
client_conflict_resolution=options.client_conflict_resolution, client_conflict_resolution=options.client_conflict_resolution,
msgpack=options.msgpack,
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout = options.transaction_timeout,
......
...@@ -115,6 +115,16 @@ ...@@ -115,6 +115,16 @@
</description> </description>
</key> </key>
<key name="msgpack" datatype="boolean" required="no" default="false">
<description>
Use msgpack to serialize and de-serialize ZEO protocol messages.
An advantage of using msgpack for ZEO communication is that
it's a little bit faster and a ZEO server can support Python 2
or Python 3 clients (but not both).
</description>
</key>
</sectiontype> </sectiontype>
</component> </component>
...@@ -17,7 +17,7 @@ Let's start a Z4 server ...@@ -17,7 +17,7 @@ Let's start a Z4 server
... ''' ... '''
>>> addr, stop = start_server( >>> addr, stop = start_server(
... storage_conf, dict(invalidation_queue_size=5), protocol=b'Z4') ... storage_conf, dict(invalidation_queue_size=5), protocol=b'4')
A current client should be able to connect to a old server: A current client should be able to connect to a old server:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment