Commit 9d0bf97a authored by Julien Muchembled's avatar Julien Muchembled

protocol: switch to msgpack for packet serialization

Not only for performance reasons (at least 3% faster) but also because of
several ugly things in the way packets were defined:
- packet field names, which are only documentary; for roots fields,
  they even just duplicate the packet names
- a lot of repetitions for packet names, and even confusion between the name
  of the packet definition and the name of the actual notify/request packet
- the need to implement field types for anything, like PByte to support new
  compression formats, since PBoolean is not enough

neo/lib/protocol.py is now much smaller.
parent 6332112c
......@@ -13,6 +13,13 @@
##############################################################################
def patch():
# For msgpack & Py2/ZODB5.
try:
from zodbpickle import binary
binary._pack = bytes.__str__
except ImportError:
pass
from hashlib import md5
from ZODB.Connection import Connection
......
......@@ -80,7 +80,7 @@ class Application(ThreadedApplication):
self._cache = ClientCache() if cache_size is None else \
ClientCache(max_size=cache_size)
self._loading_oid = None
self.new_oid_list = ()
self.new_oids = ()
self.last_oid = '\0' * 8
self.storage_event_handler = storage.StorageEventHandler(self)
self.storage_bootstrap_handler = storage.StorageBootstrapHandler(self)
......@@ -187,7 +187,7 @@ class Application(ThreadedApplication):
with self._connecting_to_master_node:
result = self.master_conn
if result is None:
self.new_oid_list = ()
self.new_oids = ()
result = self.master_conn = self._connectToPrimaryNode()
return result
......@@ -311,15 +311,19 @@ class Application(ThreadedApplication):
"""Get a new OID."""
self._oid_lock_acquire()
try:
if not self.new_oid_list:
for oid in self.new_oids:
break
else:
# Get new oid list from master node
# we manage a list of oid here to prevent
# from asking too many time new oid one by one
# from master node
self._askPrimary(Packets.AskNewOIDs(100))
if not self.new_oid_list:
for oid in self.new_oids:
break
else:
raise NEOStorageError('new_oid failed')
self.last_oid = oid = self.new_oid_list.pop()
self.last_oid = oid
return oid
finally:
self._oid_lock_release()
......@@ -611,7 +615,7 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, txn_context.cache_dict)
str(transaction.description), ext, list(txn_context.cache_dict))
queue = txn_context.queue
conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata.
......@@ -696,7 +700,7 @@ class Application(ThreadedApplication):
else:
try:
notify(Packets.AbortTransaction(txn_context.ttid,
txn_context.conn_dict))
list(txn_context.conn_dict)))
except ConnectionClosed:
pass
# We don't need to flush queue, as it won't be reused by future
......@@ -735,7 +739,8 @@ class Application(ThreadedApplication):
for oid in checked_list:
del cache_dict[oid]
ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list)
p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid
......
......@@ -160,8 +160,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list):
oid_list.reverse()
self.app.new_oid_list = oid_list
self.app.new_oids = iter(oid_list)
def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed can not be"
......
......@@ -26,7 +26,7 @@ from .exception import NEOStorageError
class _WakeupPacket(object):
handler_method_name = 'pong'
decode = tuple
_args = ()
getId = int
class Transaction(object):
......
......@@ -16,12 +16,19 @@
from functools import wraps
from time import time
import msgpack
from msgpack.exceptions import UnpackValueError
from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets
from .util import dummy_read_buffer, ReadBuffer
from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \
Unpacker
@apply
class dummy_read_buffer(msgpack.Unpacker):
def feed(self, _):
pass
class ConnectionClosed(Exception):
pass
......@@ -291,7 +298,7 @@ class ListeningConnection(BaseConnection):
# message.
else:
conn._connected()
self.em.addWriter(conn) # for ENCODED_VERSION
self.em.addWriter(conn) # for HANDSHAKE_PACKET
def getAddress(self):
return self.connector.getAddress()
......@@ -310,12 +317,12 @@ class Connection(BaseConnection):
client = False
server = False
peer_id = None
_parser_state = None
_total_unpacked = 0
_timeout = None
def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer()
self.read_buf = Unpacker()
self.cur_id = 0
self.aborted = False
self.uuid = None
......@@ -425,42 +432,39 @@ class Connection(BaseConnection):
self._closure()
def _parse(self):
read = self.read_buf.read
version = read(4)
if version is None:
return
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE,
PACKET_HEADER_FORMAT, Packets)
if version != ENCODED_VERSION:
logging.warning('Protocol version mismatch with %r', self)
from .protocol import HANDSHAKE_PACKET, MAGIC_SIZE, Packets
read_buf = self.read_buf
handshake = read_buf.read_bytes(len(HANDSHAKE_PACKET))
if handshake != HANDSHAKE_PACKET:
if HANDSHAKE_PACKET.startswith(handshake): # unlikely so tested last
# Not enough data and there's no API to know it in advance.
# Put it back.
read_buf.feed(handshake)
return
if HANDSHAKE_PACKET.startswith(handshake[:MAGIC_SIZE]):
logging.warning('Protocol version mismatch with %r', self)
else:
logging.debug('Rejecting non-NEO %r', self)
raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size
unpack = PACKET_HEADER_FORMAT.unpack
read_next = read_buf.next
read_pos = read_buf.tell
def parse():
state = self._parser_state
if state is None:
header = read(header_size)
if header is None:
return
msg_id, msg_type, msg_len = unpack(header)
try:
packet_klass = Packets[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
else:
msg_id, packet_klass, msg_len = state
data = read(msg_len)
if data is None:
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet
try:
msg_id, msg_type, args = read_next()
except StopIteration:
return
except UnpackValueError as e:
raise PacketMalformedError(str(e))
try:
packet_klass = Packets[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
pos = read_pos()
packet = packet_klass(*args)
packet.setId(msg_id)
packet.size = pos - self._total_unpacked
self._total_unpacked = pos
return packet
self._parse = parse
return parse()
......@@ -513,7 +517,7 @@ class Connection(BaseConnection):
def close(self):
if self.connector is None:
assert self._on_close is None
assert not self.read_buf
assert not self.read_buf.read_bytes(1)
assert not self.isPending()
return
# process the network events with the last registered handler to
......@@ -524,7 +528,7 @@ class Connection(BaseConnection):
if self._on_close is not None:
self._on_close()
self._on_close = None
self.read_buf.clear()
self.read_buf = dummy_read_buffer
try:
if self.connecting:
handler.connectionFailed(self)
......
......@@ -19,7 +19,7 @@ import ssl
import errno
from time import time
from . import logging
from .protocol import ENCODED_VERSION
from .protocol import HANDSHAKE_PACKET
# Global connector registry.
# Fill by calling registerConnectorHandler.
......@@ -74,15 +74,14 @@ class SocketConnector(object):
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION]
self.queue_size = len(ENCODED_VERSION)
self.queued = [HANDSHAKE_PACKET]
self.queue_size = len(HANDSHAKE_PACKET)
return self
def queue(self, data):
was_empty = not self.queued
self.queued += data
for data in data:
self.queue_size += len(data)
self.queued.append(data)
self.queue_size += len(data)
return was_empty
def _error(self, op, exc=None):
......@@ -172,7 +171,7 @@ class SocketConnector(object):
except socket.error, e:
self._error('recv', e)
if data:
read_buf.append(data)
read_buf.feed(data)
return
self._error('recv')
......@@ -278,7 +277,7 @@ class _SSL:
def receive(self, read_buf):
try:
while 1:
read_buf.append(self.socket.recv(4096))
read_buf.feed(self.socket.recv(4096))
except ssl.SSLWantReadError:
pass
except socket.error, e:
......
......@@ -23,7 +23,7 @@ NOBODY = []
class _ConnectionClosed(object):
handler_method_name = 'connectionClosed'
decode = tuple
_args = ()
class getId(object):
def __eq__(self, other):
......
......@@ -71,7 +71,7 @@ class EventHandler(object):
method = getattr(self, packet.handler_method_name)
except AttributeError:
raise UnexpectedPacketError('no handler found')
args = packet.decode() or ()
args = packet._args
method(conn, *args, **kw)
except DelayEvent, e:
assert not kw, kw
......@@ -79,9 +79,6 @@ class EventHandler(object):
except UnexpectedPacketError, e:
if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', conn, e)
conn.close()
except NotReadyError, message:
if not conn.isClosed():
if not message.args:
......
......@@ -152,7 +152,8 @@ class NEOLogger(Logger):
def _setup(self, filename=None, reset=False):
from . import protocol as p
global uuid_str
global packb, uuid_str
packb = p.packb
uuid_str = p.uuid_str
if self._db is not None:
self._db.close()
......@@ -250,7 +251,7 @@ class NEOLogger(Logger):
'>' if r.outgoing else '<', uuid_str(r.uuid), ip, port)
msg = r.msg
if msg is not None:
msg = buffer(msg)
msg = buffer(msg if type(msg) is bytes else packb(msg))
q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
x = [r.created, nid, r.msg_id, r.code, peer, msg]
else:
......@@ -299,9 +300,14 @@ class NEOLogger(Logger):
def packet(self, connection, packet, outgoing):
if self._db is not None:
body = packet._body
if self._max_packet and self._max_packet < len(body):
body = None
if self._max_packet and self._max_packet < packet.size:
args = None
else:
args = packet._args
try:
hash(args)
except TypeError:
args = packb(args)
self._queue(PacketRecord(
created=time(),
msg_id=packet._id,
......@@ -309,7 +315,7 @@ class NEOLogger(Logger):
outgoing=outgoing,
uuid=connection.getUUID(),
addr=connection.getAddress(),
msg=body))
msg=args))
def node(self, *cluster_nid):
name = self.name and str(self.name)
......
This diff is collapsed.
......@@ -166,65 +166,6 @@ def parseMasterList(masters):
return map(parseNodeAddress, masters.split())
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose if to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
TODO: For better performance, use:
- socket.recv_into (64kiB blocks)
- struct.unpack_from
- and a circular buffer of dynamic size (initial size:
twice the length passed to socket.recv_into ?)
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
self.size += len(data)
self.content.append(data)
def __len__(self):
""" Return the current buffer size """
return self.size
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
to_read = size
# select required chunks
while to_read > 0:
chunk_data = pop_chunk()
to_read -= len(chunk_data)
append_data(chunk_data)
if to_read < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:to_read], last_chunk[to_read:]
self.content.appendleft(let)
chunk_list[-1] = keep
# join all chunks (one copy)
data = ''.join(chunk_list)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
dummy_read_buffer = ReadBuffer()
dummy_read_buffer.append = lambda _: None
class cached_property(object):
"""
A property that is only computed once per instance and then replaces itself
......
......@@ -584,7 +584,9 @@ class Application(BaseApplication):
self.tm.executeQueuedEvents()
def startStorage(self, node):
node.send(Packets.StartOperation(self.backup_tid))
# XXX: Is this boolean 'backup' field needed ?
# Maybe this can be deduced from cluster state.
node.send(Packets.StartOperation(bool(self.backup_tid)))
uuid = node.getUUID()
assert uuid not in self.storage_starting_set
if uuid not in self.storage_ready_dict:
......
......@@ -157,27 +157,49 @@ class Log(object):
for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x])
x = {}
try:
Unpacker = g['Unpacker']
except KeyError:
unpackb = None
else:
from msgpack import ExtraData, UnpackException
def unpackb(data):
u = Unpacker()
u.feed(data)
data = u.unpack()
if u.read_bytes(1):
raise ExtraData
return data
self.PacketMalformedError = UnpackException
self.unpackb = unpackb
if self._decode > 1:
PStruct = g['PStruct']
PBoolean = g['PBoolean']
def hasData(item):
items = item._items
for i, item in enumerate(items):
if isinstance(item, PStruct):
j = hasData(item)
if j:
return (i,) + j
elif (isinstance(item, PBoolean)
and item._name == 'compression'
and i + 2 < len(items)
and items[i+2]._name == 'data'):
return i,
for p in self.Packets.itervalues():
if p._fmt is not None:
path = hasData(p._fmt)
if path:
assert not hasattr(p, '_neolog'), p
x[p._code] = path
try:
PStruct = g['PStruct']
except KeyError:
for p in self.Packets.itervalues():
data_path = getattr(p, 'data_path', (None,))
if p._code >> 15 == data_path[0]:
x[p._code] = data_path[1:]
else:
PBoolean = g['PBoolean']
def hasData(item):
items = item._items
for i, item in enumerate(items):
if isinstance(item, PStruct):
j = hasData(item)
if j:
return (i,) + j
elif (isinstance(item, PBoolean)
and item._name == 'compression'
and i + 2 < len(items)
and items[i+2]._name == 'data'):
return i,
for p in self.Packets.itervalues():
if p._fmt is not None:
path = hasData(p._fmt)
if path:
assert not hasattr(p, '_neolog'), p
x[p._code] = path
self._getDataPath = x.get
try:
......@@ -215,11 +237,13 @@ class Log(object):
if body is not None:
log = getattr(p, '_neolog', None)
if log or self._decode:
p = p()
p._id = msg_id
p._body = body
try:
args = p.decode()
if self.unpackb:
args = self.unpackb(body)
else:
p = p()
p._body = body
args = p.decode()
except self.PacketMalformedError:
msg.append("Can't decode packet")
else:
......
......@@ -461,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager):
return r
def loadData(self, data_id):
return self.query("SELECT compression, hash, value"
" FROM data WHERE id=?", (data_id,)).fetchone()
compression, checksum, data = self.query(
"SELECT compression, hash, value FROM data WHERE id=?",
(data_id,)).fetchone()
if checksum:
return compression, str(checksum), str(data)
return compression, checksum, data
def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getReadablePartition(oid)
......
......@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler):
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
bool(t[4]), t[0])
conn.answer(p)
def getEventQueue(self):
......
......@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler):
# Sending such packet does not mark the connection
# for writing if there's too little data in the buffer.
conn.send(Packets.AddTransaction(tid, user,
desc, ext, packed, ttid, oid_list), msg_id)
desc, ext, bool(packed), ttid, oid_list), msg_id)
# To avoid delaying several connections simultaneously,
# and also prevent the backend from scanning different
# parts of the DB at the same time, we ask the
......@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid in object_list:
oid_set = object_dict.get(serial)
if oid_set:
if type(oid_set) is list:
if type(oid_set) is tuple:
object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set:
oid_set.remove(oid)
......
......@@ -71,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0]
ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id)
......@@ -83,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0]
status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status)
if __name__ == '__main__':
......
......@@ -72,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id
self.assertTrue(packet.decode()[0])
self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None)
if __name__ == '__main__':
......
......@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase):
def getFakePacket(self):
p = Mock({
'decode': (),
'__repr__': 'Fake Packet',
})
p._args = ()
p.handler_method_name = 'fake_method'
return p
......@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase):
self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn)
self.checkAborted(conn)
# raise PacketMalformedError
conn.mockCalledMethods = {}
def fake(c):
raise PacketMalformedError('message')
self.setFakeMethod(fake)
self.handler.dispatch(conn, packet)
self.checkClosed(conn)
# raise NotReadyError
conn.mockCalledMethods = {}
def fake(c):
......
......@@ -17,7 +17,7 @@
import unittest
import socket
from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, parseNodeAddress
from neo.lib.util import parseNodeAddress
class UtilTests(NeoUnitTestBase):
......@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase):
self.assertIn(parseNodeAddress('localhost'), local_address(0))
self.assertIn(parseNodeAddress('localhost:10'), local_address(10))
def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# no enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
if __name__ == "__main__":
unittest.main()
This diff is collapsed.
......@@ -106,7 +106,7 @@ class ReplicationTests(NEOThreadedTest):
importZODB(3)
def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode()
tid, upstream_name, source_dict = packet._args
return not upstream_name and all(source_dict.itervalues())
with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream) as backup:
......@@ -446,7 +446,7 @@ class ReplicationTests(NEOThreadedTest):
"""
def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \
packet._args[0] == offset and \
conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, num_replicas, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
......@@ -701,7 +701,7 @@ class ReplicationTests(NEOThreadedTest):
def logReplication(conn, packet):
if isinstance(packet, (Packets.AskFetchTransactions,
Packets.AskFetchObjects)):
ask.append(packet.decode()[2:])
ask.append(packet._args[2:])
def getTIDList():
return [t.tid for t in c.db().storage.iterator()]
s0, s1 = cluster.storage_list
......@@ -802,7 +802,7 @@ class ReplicationTests(NEOThreadedTest):
return True
elif not isinstance(packet, Packets.AskFetchTransactions):
return
ask.append(packet.decode())
ask.append(packet._args)
conn, = upstream.master.getConnectionList(backup.master)
with ConnectionFilter() as f, Patch(replicator.Replicator,
_nextPartitionSortKey=lambda orig, self, offset: offset):
......@@ -863,11 +863,11 @@ class ReplicationTests(NEOThreadedTest):
@f.add
def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchTransactions):
trans.append(packet.decode()[2])
trans.append(packet._args[2])
elif isinstance(packet, Packets.AskFetchObjects):
if obj:
return True
obj.append(packet.decode()[2])
obj.append(packet._args[2])
s2.start()
self.tic()
cluster.neoctl.enableStorageList([s2.uuid])
......@@ -954,7 +954,7 @@ class ReplicationTests(NEOThreadedTest):
def expected(changed):
s0 = 1, CellStates.UP_TO_DATE
s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE
return changed, 3 * [[s0, (2, s)], [s0, (3, s)]]
return changed, 3 * ((s0, (2, s)), (s0, (3, s)))
for dry_run in True, False:
self.assertEqual(expected(True),
cluster.neoctl.tweakPartitionTable(drop_list, dry_run))
......
......@@ -53,7 +53,7 @@ extras_require = {
'master': [],
'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'],
'storage-importer': zodb_require + ['setproctitle'],
}
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)]
......@@ -108,6 +108,7 @@ setup(
],
},
install_requires = [
'msgpack>=0.5.6',
'python-dateutil', # neolog --from
],
extras_require = extras_require,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment