Commit ea534b05 authored by Julien Muchembled's avatar Julien Muchembled

protocol: switch to msgpack for packet serialization

Not only for performance reasons (which is significant in the case of
replication; tools/matrix is ~3% faster) but also because of several ugly
things in the way packets were defined:
- packet field names, which are only documentary; for roots fields,
  they even just duplicate the packet names
- a lot of repetitions for packet names, and even confusion between the name
  of the packet definition and the name of the actual notify/request packet
- the need to implement field types for anything, like PByte to support new
  compression formats, since PBoolean is not enough

neo/lib/protocol.py is now much smaller.
parent 38e98a12
...@@ -80,7 +80,7 @@ class Application(ThreadedApplication): ...@@ -80,7 +80,7 @@ class Application(ThreadedApplication):
self._cache = ClientCache() if cache_size is None else \ self._cache = ClientCache() if cache_size is None else \
ClientCache(max_size=cache_size) ClientCache(max_size=cache_size)
self._loading_oid = None self._loading_oid = None
self.new_oid_list = () self.new_oids = ()
Please register or sign in to reply
self.last_oid = '\0' * 8 self.last_oid = '\0' * 8
self.storage_event_handler = storage.StorageEventHandler(self) self.storage_event_handler = storage.StorageEventHandler(self)
self.storage_bootstrap_handler = storage.StorageBootstrapHandler(self) self.storage_bootstrap_handler = storage.StorageBootstrapHandler(self)
...@@ -187,7 +187,7 @@ class Application(ThreadedApplication): ...@@ -187,7 +187,7 @@ class Application(ThreadedApplication):
with self._connecting_to_master_node: with self._connecting_to_master_node:
result = self.master_conn result = self.master_conn
if result is None: if result is None:
self.new_oid_list = () self.new_oids = ()
result = self.master_conn = self._connectToPrimaryNode() result = self.master_conn = self._connectToPrimaryNode()
return result return result
...@@ -312,15 +312,19 @@ class Application(ThreadedApplication): ...@@ -312,15 +312,19 @@ class Application(ThreadedApplication):
"""Get a new OID.""" """Get a new OID."""
self._oid_lock_acquire() self._oid_lock_acquire()
try: try:
if not self.new_oid_list: for oid in self.new_oids:
break
else:
# Get new oid list from master node # Get new oid list from master node
# we manage a list of oid here to prevent # we manage a list of oid here to prevent
# from asking too many time new oid one by one # from asking too many time new oid one by one
# from master node # from master node
self._askPrimary(Packets.AskNewOIDs(100)) self._askPrimary(Packets.AskNewOIDs(100))
if not self.new_oid_list: for oid in self.new_oids:
break
else:
raise NEOStorageError('new_oid failed') raise NEOStorageError('new_oid failed')
self.last_oid = oid = self.new_oid_list.pop() self.last_oid = oid
return oid return oid
finally: finally:
self._oid_lock_release() self._oid_lock_release()
...@@ -612,7 +616,7 @@ class Application(ThreadedApplication): ...@@ -612,7 +616,7 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode. # user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB. # BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, txn_context.cache_dict) str(transaction.description), ext, list(txn_context.cache_dict))
queue = txn_context.queue queue = txn_context.queue
conn_dict = txn_context.conn_dict conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -697,7 +701,7 @@ class Application(ThreadedApplication): ...@@ -697,7 +701,7 @@ class Application(ThreadedApplication):
else: else:
try: try:
notify(Packets.AbortTransaction(txn_context.ttid, notify(Packets.AbortTransaction(txn_context.ttid,
txn_context.conn_dict)) list(txn_context.conn_dict)))
  • FWIW, the python coding style for multiline arguments is AFAIK either: all parameters on the right of the parenthesis, or all indented. Here the first argument follows the former, and the second follows the latter.

  • Let's relax rules about positioning arguments. There's nothing that's not ugly in some cases and I don't like modifying an extra line (i.e. breaking git-blame) for so little/questionable benefit. I'm also not fond of isolating each argument on its own line as it tends to explode the number of lines, and I see much fewer code on screen without scrolling.

Please register or sign in to reply
except ConnectionClosed: except ConnectionClosed:
pass pass
# We don't need to flush queue, as it won't be reused by future # We don't need to flush queue, as it won't be reused by future
...@@ -736,7 +740,8 @@ class Application(ThreadedApplication): ...@@ -736,7 +740,8 @@ class Application(ThreadedApplication):
for oid in checked_list: for oid in checked_list:
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list) p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
......
...@@ -164,8 +164,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -164,8 +164,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
self.app.setHandlerData(ttid) self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
oid_list.reverse() self.app.new_oids = iter(oid_list)
self.app.new_oid_list = oid_list
def incompleteTransaction(self, conn, message): def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed can not be" raise NEOStorageError("storage nodes for which vote failed can not be"
......
...@@ -26,7 +26,7 @@ from .exception import NEOStorageError ...@@ -26,7 +26,7 @@ from .exception import NEOStorageError
class _WakeupPacket(object): class _WakeupPacket(object):
handler_method_name = 'pong' handler_method_name = 'pong'
decode = tuple _args = ()
getId = int getId = int
class Transaction(object): class Transaction(object):
......
...@@ -16,12 +16,19 @@ ...@@ -16,12 +16,19 @@
from functools import wraps from functools import wraps
from time import time from time import time
import msgpack
from msgpack.exceptions import UnpackValueError
from . import attributeTracker, logging from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \
from .util import dummy_read_buffer, ReadBuffer Unpacker
@apply
class dummy_read_buffer(msgpack.Unpacker):
def feed(self, _):
pass
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
...@@ -310,12 +317,12 @@ class Connection(BaseConnection): ...@@ -310,12 +317,12 @@ class Connection(BaseConnection):
client = False client = False
server = False server = False
peer_id = None peer_id = None
_parser_state = None _total_unpacked = 0
_timeout = None _timeout = None
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer() self.read_buf = Unpacker()
self.cur_id = 0 self.cur_id = 0
self.aborted = False self.aborted = False
self.uuid = None self.uuid = None
...@@ -425,42 +432,36 @@ class Connection(BaseConnection): ...@@ -425,42 +432,36 @@ class Connection(BaseConnection):
self._closure() self._closure()
def _parse(self): def _parse(self):
read = self.read_buf.read from .protocol import ENCODED_VERSION, Packets
version = read(4) read_buf = self.read_buf
if version is None: version = read_buf.read_bytes(4)
return
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE,
PACKET_HEADER_FORMAT, Packets)
if version != ENCODED_VERSION: if version != ENCODED_VERSION:
if len(version) < 4: # unlikely so tested last
# Not enough data and there's no API to know it in advance.
# Put it back.
read_buf.feed(version)
return
logging.warning('Protocol version mismatch with %r', self) logging.warning('Protocol version mismatch with %r', self)
raise ConnectorException raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size read_next = read_buf.next
unpack = PACKET_HEADER_FORMAT.unpack read_pos = read_buf.tell
def parse(): def parse():
state = self._parser_state try:
if state is None: msg_id, msg_type, args = read_next()
header = read(header_size) except StopIteration:
if header is None: return
return except UnpackValueError as e:
msg_id, msg_type, msg_len = unpack(header) raise PacketMalformedError(str(e))
try: try:
packet_klass = Packets[msg_type] packet_klass = Packets[msg_type]
except KeyError: except KeyError:
raise PacketMalformedError('Unknown packet type') raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE: pos = read_pos()
  • Are we still able to enforce packet size with msgpack ? The intent of this check is to prevent remote-controlled memory explosion, and AFAICS each string exchanged may take 4GB as far as msgpack is concerned.

  • No (see added comment in neo.lib.protocol).

    prevent remote-controlled memory explosion

    Yes, if that's unintentional (bug). I mean that NEO was designed from the beginning as a single application that spans over several processes/hosts: each process fully trusts each other. Given the complexity of what we want to do, it looks impossible to open a NEO cluster to anyone. That probably even applies to all ZODB implementation because, for example, you couldn't prevent a malicious process to block all writes by not calling tpc_finish.

    From there, the check on packet size was only an assertion.

    It's ok to add checks when doable easily, but that should not block anything.

  • NEO was designed from the beginning as a single application

    Yes, and it's fine for me once the NEO protocol nature of both ends is established. But NEO must not die if an admin does a typo in a port number of an unrelated program which would happen to speak the same next-layer generic protocol as NEO (so TCP, with or without TLS, and now msgpack).

    So NEO protocol identification packet, at the very least, should not be allowed to blow memory up.

    Maybe this packet is already correctly structured, I do not know. Enforcing systematic max packet size avoids having to check, and is a generally good idea for anything network-related (hence my hope that msgpack has such feature).

  • AFAICS each string exchanged may take 4GB as far as msgpack is concerned

    I already use msgpack with max_buffer_size=64MB, so each item can't exceed this size (cf a comment in protocol.py). Which means that you can still reach 4GB by, for example, constructing a list of 1024 strings of 4MB. You can't just cause a DoS by only changing a few bytes: the following data must be valid msgpack data and for how long you want to DoS.

    Yes, and it's fine for me once the NEO protocol nature of both ends is established.

    Currently, it's only the first 4 bytes, for the version.

    So NEO protocol identification packet, at the very least, should not be allowed to blow memory up.

    The RequestIdentification has never been mandatory at the connection level. neoctl does not use it when connecting to an admin node and that was a problem when the version was checked via this packet. So in theory, enforcing a limit for the size of this packet is not enough.

    To sum up, the only way I see to improve things is to have more than the 4-bytes version at the beginning: some NEO signature. But that does not look useful to me.

  • I already use msgpack with max_buffer_size=64MB, so each item can't exceed this size

    This is good.

    Which means that you can still reach 4GB by, for example, constructing a list of 1024 strings of 4MB.

    Does msgpack validate received data against a schema while parsing ? Like "I'm nost supposed to receive a list at this place in this type of message".

    If it does, then if care is taken to not have a list of strings early in the protocol it's fine for me.

    If it doesn't, and sender can emit whatever, then I think we need to treat the initial handshake specially, outside of msgpack, with strict & defensive parsing: needs to send X bytes in a single network frame, not more, not less, and not any of the 99.99...% of values possible for these bytes. As you say below, we expect a version on 4 bytes, then as long as we do not accept many thousands of possible versions as valid at any point, then it should be fine. Adding a fixed prefix (ex: NEO\x00) would weed out other protocols which would use a similar pattern.

    You can't just cause a DoS by only changing a few bytes: the following data must be valid msgpack data and for how long you want to DoS.

    Again, I'm not talking about malicious DoS, but about "I crossed the streams !" accidents must not result in a trainwreck, database data loss, NEO process crash, etc.

    The RequestIdentification has never been mandatory at the connection level. neoctl does not use it when connecting to an admin node and that was a problem when the version was checked via this packet. So in theory, enforcing a limit for the size of this packet is not enough.

    Is it ? I seem to remember a time where connections received a minimal handler on establishment until identification happened, and only then switched to the full handler. This was IIRC needed at least on master and storage nodes as more than one type of node connect, and they need to be distinguished so the proper handler is chosen. Am I making memories up ?

    To sum up, the only way I see to improve things is to have more than the 4-bytes version at the beginning: some NEO signature.

    I agree.

    But that does not look useful to me.

    I disagree.

    Also, it should be rather easy to implement: I believe that now that SSL support is implemented the connection is being polled until SSL handshake is finished. This should be possible to extend to also cover emission & reception of such mandatory pair of packets with minimal code footprint, and a half-RTT of latency (which is as low as it gets) added on connection establishment (which should be a rare event, so pointless to micro-optimise for).

  • As you say below, we expect a version on 4 bytes, then as long as we do not accept many thousands of possible versions as valid at any point, then it should be fine.

    There's no plan to accept other values than the one defined in the code.

    You can't just cause a DoS by only changing a few bytes: the following data must be valid msgpack data and for how long you want to DoS.

    Again, I'm not talking about malicious DoS, but about "I crossed the streams !" accidents must not result in a trainwreck, database data loss, NEO process crash, etc.

    I wasn't talking about malicious DoS.

    Does msgpack validate received data against a schema while parsing ?

    It doesn't.

    Is it ? I seem to remember a time where connections received a minimal handler on establishment until identification happened, and only then switched to the full handler. This was IIRC needed at least on master and storage nodes as more than one type of node connect, and they need to be distinguished so the proper handler is chosen. Am I making memories up ?

    See a60e36e8, also the fact that there's no node type for neoctl (to fill the first field of RequestIdentification).

    I'm not saying a RequestIdentification is not mandatory for nodes that have a node type defined. What I wrote is that it's not mandatory at the connection level (neo.lib.connection.py).

Please register or sign in to reply
raise PacketMalformedError('message too big (%d)' % msg_len) packet = packet_klass(*args)
else: packet.setId(msg_id)
msg_id, packet_klass, msg_len = state packet.size = pos - self._total_unpacked
data = read(msg_len) self._total_unpacked = pos
if data is None: return packet
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet
self._parse = parse self._parse = parse
return parse() return parse()
...@@ -513,7 +514,7 @@ class Connection(BaseConnection): ...@@ -513,7 +514,7 @@ class Connection(BaseConnection):
def close(self): def close(self):
if self.connector is None: if self.connector is None:
assert self._on_close is None assert self._on_close is None
assert not self.read_buf assert not self.read_buf.read_bytes(1)
assert not self.isPending() assert not self.isPending()
return return
# process the network events with the last registered handler to # process the network events with the last registered handler to
...@@ -524,7 +525,7 @@ class Connection(BaseConnection): ...@@ -524,7 +525,7 @@ class Connection(BaseConnection):
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
self.read_buf.clear() self.read_buf = dummy_read_buffer
try: try:
if self.connecting: if self.connecting:
handler.connectionFailed(self) handler.connectionFailed(self)
......
...@@ -80,9 +80,8 @@ class SocketConnector(object): ...@@ -80,9 +80,8 @@ class SocketConnector(object):
def queue(self, data): def queue(self, data):
was_empty = not self.queued was_empty = not self.queued
self.queued += data self.queued.append(data)
for data in data: self.queue_size += len(data)
self.queue_size += len(data)
return was_empty return was_empty
def _error(self, op, exc=None): def _error(self, op, exc=None):
...@@ -172,7 +171,7 @@ class SocketConnector(object): ...@@ -172,7 +171,7 @@ class SocketConnector(object):
except socket.error, e: except socket.error, e:
self._error('recv', e) self._error('recv', e)
if data: if data:
read_buf.append(data) read_buf.feed(data)
return return
self._error('recv') self._error('recv')
...@@ -278,7 +277,7 @@ class _SSL: ...@@ -278,7 +277,7 @@ class _SSL:
def receive(self, read_buf): def receive(self, read_buf):
try: try:
while 1: while 1:
read_buf.append(self.socket.recv(4096)) read_buf.feed(self.socket.recv(4096))
except ssl.SSLWantReadError: except ssl.SSLWantReadError:
pass pass
except socket.error, e: except socket.error, e:
......
...@@ -23,7 +23,7 @@ NOBODY = [] ...@@ -23,7 +23,7 @@ NOBODY = []
class _ConnectionClosed(object): class _ConnectionClosed(object):
handler_method_name = 'connectionClosed' handler_method_name = 'connectionClosed'
decode = tuple _args = ()
class getId(object): class getId(object):
def __eq__(self, other): def __eq__(self, other):
......
...@@ -68,7 +68,7 @@ class EventHandler(object): ...@@ -68,7 +68,7 @@ class EventHandler(object):
method = getattr(self, packet.handler_method_name) method = getattr(self, packet.handler_method_name)
except AttributeError: except AttributeError:
raise UnexpectedPacketError('no handler found') raise UnexpectedPacketError('no handler found')
args = packet.decode() or () args = packet._args
method(conn, *args, **kw) method(conn, *args, **kw)
except DelayEvent, e: except DelayEvent, e:
assert not kw, kw assert not kw, kw
...@@ -76,9 +76,6 @@ class EventHandler(object): ...@@ -76,9 +76,6 @@ class EventHandler(object):
except UnexpectedPacketError, e: except UnexpectedPacketError, e:
if not conn.isClosed(): if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args) self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', conn, e)
conn.close()
except NotReadyError, message: except NotReadyError, message:
if not conn.isClosed(): if not conn.isClosed():
if not message.args: if not message.args:
......
...@@ -152,7 +152,8 @@ class NEOLogger(Logger): ...@@ -152,7 +152,8 @@ class NEOLogger(Logger):
def _setup(self, filename=None, reset=False): def _setup(self, filename=None, reset=False):
from . import protocol as p from . import protocol as p
global uuid_str global packb, uuid_str
packb = p.packb
uuid_str = p.uuid_str uuid_str = p.uuid_str
if self._db is not None: if self._db is not None:
self._db.close() self._db.close()
...@@ -250,7 +251,7 @@ class NEOLogger(Logger): ...@@ -250,7 +251,7 @@ class NEOLogger(Logger):
'>' if r.outgoing else '<', uuid_str(r.uuid), ip, port) '>' if r.outgoing else '<', uuid_str(r.uuid), ip, port)
msg = r.msg msg = r.msg
if msg is not None: if msg is not None:
msg = buffer(msg) msg = buffer(msg if type(msg) is bytes else packb(msg))
q = "INSERT INTO packet VALUES (?,?,?,?,?,?)" q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
x = [r.created, nid, r.msg_id, r.code, peer, msg] x = [r.created, nid, r.msg_id, r.code, peer, msg]
else: else:
...@@ -299,9 +300,14 @@ class NEOLogger(Logger): ...@@ -299,9 +300,14 @@ class NEOLogger(Logger):
def packet(self, connection, packet, outgoing): def packet(self, connection, packet, outgoing):
if self._db is not None: if self._db is not None:
body = packet._body if self._max_packet and self._max_packet < packet.size:
if self._max_packet and self._max_packet < len(body): args = None
body = None else:
args = packet._args
try:
hash(args)
except TypeError:
args = packb(args)
self._queue(PacketRecord( self._queue(PacketRecord(
created=time(), created=time(),
msg_id=packet._id, msg_id=packet._id,
...@@ -309,7 +315,7 @@ class NEOLogger(Logger): ...@@ -309,7 +315,7 @@ class NEOLogger(Logger):
outgoing=outgoing, outgoing=outgoing,
uuid=connection.getUUID(), uuid=connection.getUUID(),
addr=connection.getAddress(), addr=connection.getAddress(),
msg=body)) msg=args))
def node(self, *cluster_nid): def node(self, *cluster_nid):
name = self.name and str(self.name) name = self.name and str(self.name)
......
This diff is collapsed.
...@@ -166,65 +166,6 @@ def parseMasterList(masters): ...@@ -166,65 +166,6 @@ def parseMasterList(masters):
return map(parseNodeAddress, masters.split()) return map(parseNodeAddress, masters.split())
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose if to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
TODO: For better performance, use:
- socket.recv_into (64kiB blocks)
- struct.unpack_from
- and a circular buffer of dynamic size (initial size:
twice the length passed to socket.recv_into ?)
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
self.size += len(data)
self.content.append(data)
def __len__(self):
""" Return the current buffer size """
return self.size
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
to_read = size
# select required chunks
while to_read > 0:
chunk_data = pop_chunk()
to_read -= len(chunk_data)
append_data(chunk_data)
if to_read < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:to_read], last_chunk[to_read:]
self.content.appendleft(let)
chunk_list[-1] = keep
# join all chunks (one copy)
data = ''.join(chunk_list)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
dummy_read_buffer = ReadBuffer()
dummy_read_buffer.append = lambda _: None
class cached_property(object): class cached_property(object):
""" """
A property that is only computed once per instance and then replaces itself A property that is only computed once per instance and then replaces itself
......
...@@ -578,7 +578,9 @@ class Application(BaseApplication): ...@@ -578,7 +578,9 @@ class Application(BaseApplication):
self.tm.executeQueuedEvents() self.tm.executeQueuedEvents()
def startStorage(self, node): def startStorage(self, node):
node.send(Packets.StartOperation(self.backup_tid)) # XXX: Is this boolean 'backup' field needed ?
# Maybe this can be deduced from cluster state.
node.send(Packets.StartOperation(bool(self.backup_tid)))
uuid = node.getUUID() uuid = node.getUUID()
assert uuid not in self.storage_starting_set assert uuid not in self.storage_starting_set
if uuid not in self.storage_ready_dict: if uuid not in self.storage_ready_dict:
......
...@@ -157,27 +157,49 @@ class Log(object): ...@@ -157,27 +157,49 @@ class Log(object):
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {} x = {}
try:
Unpacker = g['Unpacker']
except KeyError:
unpackb = None
else:
from msgpack import ExtraData, UnpackException
def unpackb(data):
u = Unpacker()
u.feed(data)
data = u.unpack()
if u.read_bytes(1):
raise ExtraData
return data
self.PacketMalformedError = UnpackException
self.unpackb = unpackb
if self._decode > 1: if self._decode > 1:
PStruct = g['PStruct'] try:
PBoolean = g['PBoolean'] PStruct = g['PStruct']
def hasData(item): except KeyError:
items = item._items for p in self.Packets.itervalues():
for i, item in enumerate(items): data_path = getattr(p, 'data_path', (None,))
if isinstance(item, PStruct): if p._code >> 15 == data_path[0]:
j = hasData(item) x[p._code] = data_path[1:]
if j: else:
return (i,) + j PBoolean = g['PBoolean']
elif (isinstance(item, PBoolean) def hasData(item):
and item._name == 'compression' items = item._items
and i + 2 < len(items) for i, item in enumerate(items):
and items[i+2]._name == 'data'): if isinstance(item, PStruct):
return i, j = hasData(item)
for p in self.Packets.itervalues(): if j:
if p._fmt is not None: return (i,) + j
path = hasData(p._fmt) elif (isinstance(item, PBoolean)
if path: and item._name == 'compression'
assert not hasattr(p, '_neolog'), p and i + 2 < len(items)
x[p._code] = path and items[i+2]._name == 'data'):
return i,
for p in self.Packets.itervalues():
if p._fmt is not None:
path = hasData(p._fmt)
if path:
assert not hasattr(p, '_neolog'), p
x[p._code] = path
self._getDataPath = x.get self._getDataPath = x.get
try: try:
...@@ -215,11 +237,13 @@ class Log(object): ...@@ -215,11 +237,13 @@ class Log(object):
if body is not None: if body is not None:
log = getattr(p, '_neolog', None) log = getattr(p, '_neolog', None)
if log or self._decode: if log or self._decode:
p = p()
p._id = msg_id
p._body = body
try: try:
args = p.decode() if self.unpackb:
args = self.unpackb(body)
else:
p = p()
p._body = body
args = p.decode()
except self.PacketMalformedError: except self.PacketMalformedError:
msg.append("Can't decode packet") msg.append("Can't decode packet")
else: else:
......
...@@ -451,8 +451,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -451,8 +451,12 @@ class SQLiteDatabaseManager(DatabaseManager):
return r return r
def loadData(self, data_id): def loadData(self, data_id):
return self.query("SELECT compression, hash, value" compression, checksum, data = self.query(
" FROM data WHERE id=?", (data_id,)).fetchone() "SELECT compression, hash, value FROM data WHERE id=?",
(data_id,)).fetchone()
if checksum:
return compression, str(checksum), str(data)
return compression, checksum, data
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getReadablePartition(oid) partition = self._getReadablePartition(oid)
......
...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler): ...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid in object_list: for serial, oid in object_list:
oid_set = object_dict.get(serial) oid_set = object_dict.get(serial)
if oid_set: if oid_set:
if type(oid_set) is list: if type(oid_set) is tuple:
object_dict[serial] = oid_set = set(oid_set) object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set: if oid in oid_set:
oid_set.remove(oid) oid_set.remove(oid)
......
...@@ -73,7 +73,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -73,7 +73,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0] ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn) self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id) self.assertEqual(self.app.packing[1], peer_id)
...@@ -85,7 +85,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -85,7 +85,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0] status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status) self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -73,7 +73,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -73,7 +73,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.service.answerPack(conn2, False) self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack) packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id # TODO: verify packet peer id
self.assertTrue(packet.decode()[0]) self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase): ...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase):
def getFakePacket(self): def getFakePacket(self):
p = Mock({ p = Mock({
'decode': (),
'__repr__': 'Fake Packet', '__repr__': 'Fake Packet',
}) })
p._args = ()
p.handler_method_name = 'fake_method' p.handler_method_name = 'fake_method'
return p return p
...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase): ...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase):
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise PacketMalformedError
conn.mockCalledMethods = {}
def fake(c):
raise PacketMalformedError('message')
self.setFakeMethod(fake)
self.handler.dispatch(conn, packet)
self.checkClosed(conn)
# raise NotReadyError # raise NotReadyError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c): def fake(c):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import unittest import unittest
import socket import socket
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, parseNodeAddress from neo.lib.util import parseNodeAddress
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase): ...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase):
self.assertIn(parseNodeAddress('localhost'), local_address(0)) self.assertIn(parseNodeAddress('localhost'), local_address(0))
self.assertIn(parseNodeAddress('localhost:10'), local_address(10)) self.assertIn(parseNodeAddress('localhost:10'), local_address(10))
def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# no enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
This diff is collapsed.
...@@ -103,7 +103,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -103,7 +103,7 @@ class ReplicationTests(NEOThreadedTest):
importZODB(3) importZODB(3)
def delaySecondary(conn, packet): def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate): if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode() tid, upstream_name, source_dict = packet._args
return not upstream_name and all(source_dict.itervalues()) return not upstream_name and all(source_dict.itervalues())
with NEOCluster(partitions=np, replicas=nr-1, storage_count=5, with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream) as backup: upstream=upstream) as backup:
...@@ -443,7 +443,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -443,7 +443,7 @@ class ReplicationTests(NEOThreadedTest):
""" """
def delayAskFetch(conn, packet): def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and \ return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \ packet._args[0] == offset and \
conn in s1.getConnectionList(s0) conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, cell_list): def changePartitionTable(orig, ptid, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
...@@ -695,7 +695,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -695,7 +695,7 @@ class ReplicationTests(NEOThreadedTest):
def logReplication(conn, packet): def logReplication(conn, packet):
if isinstance(packet, (Packets.AskFetchTransactions, if isinstance(packet, (Packets.AskFetchTransactions,
Packets.AskFetchObjects)): Packets.AskFetchObjects)):
ask.append(packet.decode()[2:]) ask.append(packet._args[2:])
def getTIDList(): def getTIDList():
return [t.tid for t in c.db().storage.iterator()] return [t.tid for t in c.db().storage.iterator()]
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
...@@ -796,7 +796,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -796,7 +796,7 @@ class ReplicationTests(NEOThreadedTest):
return True return True
elif not isinstance(packet, Packets.AskFetchTransactions): elif not isinstance(packet, Packets.AskFetchTransactions):
return return
ask.append(packet.decode()) ask.append(packet._args)
conn, = upstream.master.getConnectionList(backup.master) conn, = upstream.master.getConnectionList(backup.master)
with ConnectionFilter() as f, Patch(replicator.Replicator, with ConnectionFilter() as f, Patch(replicator.Replicator,
_nextPartitionSortKey=lambda orig, self, offset: offset): _nextPartitionSortKey=lambda orig, self, offset: offset):
...@@ -857,11 +857,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -857,11 +857,11 @@ class ReplicationTests(NEOThreadedTest):
@f.add @f.add
def delayReplicate(conn, packet): def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchTransactions): if isinstance(packet, Packets.AskFetchTransactions):
trans.append(packet.decode()[2]) trans.append(packet._args[2])
elif isinstance(packet, Packets.AskFetchObjects): elif isinstance(packet, Packets.AskFetchObjects):
if obj: if obj:
return True return True
obj.append(packet.decode()[2]) obj.append(packet._args[2])
s2.start() s2.start()
self.tic() self.tic()
cluster.neoctl.enableStorageList([s2.uuid]) cluster.neoctl.enableStorageList([s2.uuid])
......
...@@ -53,7 +53,7 @@ extras_require = { ...@@ -53,7 +53,7 @@ extras_require = {
'master': [], 'master': [],
'storage-sqlite': [], 'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'], 'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'], 'storage-importer': zodb_require + ['setproctitle'],
} }
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2', extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)] 'neoppod[%s]' % ', '.join(extras_require)]
...@@ -108,6 +108,7 @@ setup( ...@@ -108,6 +108,7 @@ setup(
], ],
}, },
install_requires = [ install_requires = [
'msgpack>=0.5.6',
'python-dateutil', # neolog --from 'python-dateutil', # neolog --from
], ],
extras_require = extras_require, extras_require = extras_require,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment