Commit 711bd878 authored by Julien Muchembled's avatar Julien Muchembled Committed by Levin Zimmermann

protocol: switch to msgpack for packet serialization

Not only for performance reasons (at least 3% faster) but also because of
several ugly things in the way packets were defined:
- packet field names, which are only documentary; for roots fields,
  they even just duplicate the packet names
- a lot of repetitions for packet names, and even confusion between the name
  of the packet definition and the name of the actual notify/request packet
- the need to implement field types for anything, like PByte to support new
  compression formats, since PBoolean is not enough

neo/lib/protocol.py is now much smaller.
parent b89b447c
...@@ -13,6 +13,13 @@ ...@@ -13,6 +13,13 @@
############################################################################## ##############################################################################
def patch(): def patch():
# For msgpack & Py2/ZODB5.
try:
from zodbpickle import binary
binary._pack = bytes.__str__
except ImportError:
pass
from hashlib import md5 from hashlib import md5
from ZODB.Connection import Connection from ZODB.Connection import Connection
......
...@@ -181,7 +181,7 @@ class Application(ThreadedApplication): ...@@ -181,7 +181,7 @@ class Application(ThreadedApplication):
with self._connecting_to_master_node: with self._connecting_to_master_node:
result = self.master_conn result = self.master_conn
if result is None: if result is None:
self.new_oid_list = () self.new_oids = ()
result = self.master_conn = self._connectToPrimaryNode() result = self.master_conn = self._connectToPrimaryNode()
return result return result
...@@ -305,15 +305,19 @@ class Application(ThreadedApplication): ...@@ -305,15 +305,19 @@ class Application(ThreadedApplication):
"""Get a new OID.""" """Get a new OID."""
self._oid_lock_acquire() self._oid_lock_acquire()
try: try:
if not self.new_oid_list: for oid in self.new_oids:
break
else:
# Get new oid list from master node # Get new oid list from master node
# we manage a list of oid here to prevent # we manage a list of oid here to prevent
# from asking too many time new oid one by one # from asking too many time new oid one by one
# from master node # from master node
self._askPrimary(Packets.AskNewOIDs(100)) self._askPrimary(Packets.AskNewOIDs(100))
if not self.new_oid_list: for oid in self.new_oids:
break
else:
raise NEOStorageError('new_oid failed') raise NEOStorageError('new_oid failed')
self.last_oid = oid = self.new_oid_list.pop() self.last_oid = oid
return oid return oid
finally: finally:
self._oid_lock_release() self._oid_lock_release()
...@@ -612,7 +616,7 @@ class Application(ThreadedApplication): ...@@ -612,7 +616,7 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode. # user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB. # BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, txn_context.cache_dict) str(transaction.description), ext, list(txn_context.cache_dict))
queue = txn_context.queue queue = txn_context.queue
conn_dict = txn_context.conn_dict conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -697,7 +701,7 @@ class Application(ThreadedApplication): ...@@ -697,7 +701,7 @@ class Application(ThreadedApplication):
else: else:
try: try:
notify(Packets.AbortTransaction(txn_context.ttid, notify(Packets.AbortTransaction(txn_context.ttid,
txn_context.conn_dict)) list(txn_context.conn_dict)))
except ConnectionClosed: except ConnectionClosed:
pass pass
# No need to flush queue, as it will be destroyed on return, # No need to flush queue, as it will be destroyed on return,
...@@ -731,7 +735,8 @@ class Application(ThreadedApplication): ...@@ -731,7 +735,8 @@ class Application(ThreadedApplication):
for oid in checked_list: for oid in checked_list:
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list) p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
......
...@@ -163,8 +163,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -163,8 +163,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
self.app.setHandlerData(ttid) self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
oid_list.reverse() self.app.new_oids = iter(oid_list)
self.app.new_oid_list = oid_list
def incompleteTransaction(self, conn, message): def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed can not be" raise NEOStorageError("storage nodes for which vote failed can not be"
......
...@@ -26,7 +26,7 @@ from .exception import NEOStorageError ...@@ -26,7 +26,7 @@ from .exception import NEOStorageError
class _WakeupPacket(object): class _WakeupPacket(object):
handler_method_name = 'pong' handler_method_name = 'pong'
decode = tuple _args = ()
getId = int getId = int
class Transaction(object): class Transaction(object):
......
...@@ -16,12 +16,19 @@ ...@@ -16,12 +16,19 @@
from functools import wraps from functools import wraps
from time import time from time import time
import msgpack
from msgpack.exceptions import UnpackValueError
from . import attributeTracker, logging from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \
from .util import dummy_read_buffer, ReadBuffer Unpacker
@apply
class dummy_read_buffer(msgpack.Unpacker):
def feed(self, _):
pass
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
...@@ -292,7 +299,7 @@ class ListeningConnection(BaseConnection): ...@@ -292,7 +299,7 @@ class ListeningConnection(BaseConnection):
# message. # message.
else: else:
conn._connected() conn._connected()
self.em.addWriter(conn) # for ENCODED_VERSION self.em.addWriter(conn) # for HANDSHAKE_PACKET
def getAddress(self): def getAddress(self):
return self.connector.getAddress() return self.connector.getAddress()
...@@ -311,12 +318,12 @@ class Connection(BaseConnection): ...@@ -311,12 +318,12 @@ class Connection(BaseConnection):
client = False client = False
server = False server = False
peer_id = None peer_id = None
_parser_state = None _total_unpacked = 0
_timeout = None _timeout = None
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer() self.read_buf = Unpacker()
# NOTE cur_id will be set in Server|Client to maintain `cur_id % 2 == const` invariant # NOTE cur_id will be set in Server|Client to maintain `cur_id % 2 == const` invariant
#self.cur_id = 0 #self.cur_id = 0
self.aborted = False self.aborted = False
...@@ -429,41 +436,38 @@ class Connection(BaseConnection): ...@@ -429,41 +436,38 @@ class Connection(BaseConnection):
self._closure() self._closure()
def _parse(self): def _parse(self):
read = self.read_buf.read from .protocol import HANDSHAKE_PACKET, MAGIC_SIZE, Packets
version = read(4) read_buf = self.read_buf
if version is None: handshake = read_buf.read_bytes(len(HANDSHAKE_PACKET))
if handshake != HANDSHAKE_PACKET:
if HANDSHAKE_PACKET.startswith(handshake): # unlikely so tested last
# Not enough data and there's no API to know it in advance.
# Put it back.
read_buf.feed(handshake)
return return
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE, if HANDSHAKE_PACKET.startswith(handshake[:MAGIC_SIZE]):
PACKET_HEADER_FORMAT, Packets)
if version != ENCODED_VERSION:
logging.warning('Protocol version mismatch with %r', self) logging.warning('Protocol version mismatch with %r', self)
else:
logging.debug('Rejecting non-NEO %r', self)
raise ConnectorException raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size read_next = read_buf.next
unpack = PACKET_HEADER_FORMAT.unpack read_pos = read_buf.tell
def parse(): def parse():
state = self._parser_state try:
if state is None: msg_id, msg_type, args = read_next()
header = read(header_size) except StopIteration:
if header is None:
return return
msg_id, msg_type, msg_len = unpack(header) except UnpackValueError as e:
raise PacketMalformedError(str(e))
try: try:
packet_klass = Packets[msg_type] packet_klass = Packets[msg_type]
except KeyError: except KeyError:
raise PacketMalformedError('Unknown packet type') raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE: pos = read_pos()
raise PacketMalformedError('message too big (%d)' % msg_len) packet = packet_klass(*args)
else: packet.setId(msg_id)
msg_id, packet_klass, msg_len = state packet.size = pos - self._total_unpacked
data = read(msg_len) self._total_unpacked = pos
if data is None:
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet return packet
self._parse = parse self._parse = parse
return parse() return parse()
...@@ -517,7 +521,7 @@ class Connection(BaseConnection): ...@@ -517,7 +521,7 @@ class Connection(BaseConnection):
def close(self): def close(self):
if self.connector is None: if self.connector is None:
assert self._on_close is None assert self._on_close is None
assert not self.read_buf assert not self.read_buf.read_bytes(1)
assert not self.isPending() assert not self.isPending()
return return
# process the network events with the last registered handler to # process the network events with the last registered handler to
...@@ -528,7 +532,7 @@ class Connection(BaseConnection): ...@@ -528,7 +532,7 @@ class Connection(BaseConnection):
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
self.read_buf.clear() self.read_buf = dummy_read_buffer
try: try:
if self.connecting: if self.connecting:
handler.connectionFailed(self) handler.connectionFailed(self)
......
...@@ -19,7 +19,7 @@ import ssl ...@@ -19,7 +19,7 @@ import ssl
import errno import errno
from time import time from time import time
from . import logging from . import logging
from .protocol import ENCODED_VERSION from .protocol import HANDSHAKE_PACKET
# Global connector registry. # Global connector registry.
# Fill by calling registerConnectorHandler. # Fill by calling registerConnectorHandler.
...@@ -74,14 +74,13 @@ class SocketConnector(object): ...@@ -74,14 +74,13 @@ class SocketConnector(object):
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION] self.queued = [HANDSHAKE_PACKET]
self.queue_size = len(ENCODED_VERSION) self.queue_size = len(HANDSHAKE_PACKET)
return self return self
def queue(self, data): def queue(self, data):
was_empty = not self.queued was_empty = not self.queued
self.queued += data self.queued.append(data)
for data in data:
self.queue_size += len(data) self.queue_size += len(data)
return was_empty return was_empty
...@@ -172,7 +171,7 @@ class SocketConnector(object): ...@@ -172,7 +171,7 @@ class SocketConnector(object):
except socket.error, e: except socket.error, e:
self._error('recv', e) self._error('recv', e)
if data: if data:
read_buf.append(data) read_buf.feed(data)
return return
self._error('recv') self._error('recv')
...@@ -283,7 +282,7 @@ class _SSL: ...@@ -283,7 +282,7 @@ class _SSL:
# non-ragged EOF (peer properly closed its side of connection) # non-ragged EOF (peer properly closed its side of connection)
self._error('recv', None) self._error('recv', None)
return return
read_buf.append(data) read_buf.feed(data)
except ssl.SSLWantReadError: except ssl.SSLWantReadError:
pass pass
except socket.error, e: except socket.error, e:
......
...@@ -23,7 +23,7 @@ NOBODY = [] ...@@ -23,7 +23,7 @@ NOBODY = []
class _ConnectionClosed(object): class _ConnectionClosed(object):
handler_method_name = 'connectionClosed' handler_method_name = 'connectionClosed'
decode = tuple _args = ()
class getId(object): class getId(object):
def __eq__(self, other): def __eq__(self, other):
......
...@@ -71,7 +71,7 @@ class EventHandler(object): ...@@ -71,7 +71,7 @@ class EventHandler(object):
method = getattr(self, packet.handler_method_name) method = getattr(self, packet.handler_method_name)
except AttributeError: except AttributeError:
raise UnexpectedPacketError('no handler found') raise UnexpectedPacketError('no handler found')
args = packet.decode() or () args = packet._args
method(conn, *args, **kw) method(conn, *args, **kw)
except DelayEvent, e: except DelayEvent, e:
assert not kw, kw assert not kw, kw
...@@ -79,9 +79,6 @@ class EventHandler(object): ...@@ -79,9 +79,6 @@ class EventHandler(object):
except UnexpectedPacketError, e: except UnexpectedPacketError, e:
if not conn.isClosed(): if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args) self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', conn, e)
conn.close()
except NotReadyError, message: except NotReadyError, message:
if not conn.isClosed(): if not conn.isClosed():
if not message.args: if not message.args:
......
...@@ -154,7 +154,8 @@ class NEOLogger(Logger): ...@@ -154,7 +154,8 @@ class NEOLogger(Logger):
def _setup(self, filename=None, reset=False): def _setup(self, filename=None, reset=False):
from . import protocol as p from . import protocol as p
global uuid_str global packb, uuid_str
packb = p.packb
uuid_str = p.uuid_str uuid_str = p.uuid_str
if self._db is not None: if self._db is not None:
self._db.close() self._db.close()
...@@ -257,7 +258,7 @@ class NEOLogger(Logger): ...@@ -257,7 +258,7 @@ class NEOLogger(Logger):
pktcls.__name__, peer, r.pkt.decode()) pktcls.__name__, peer, r.pkt.decode())
""" """
if msg is not None: if msg is not None:
msg = buffer(msg) msg = buffer(msg if type(msg) is bytes else packb(msg))
q = "INSERT INTO packet VALUES (?,?,?,?,?,?)" q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
x = [r.created, nid, r.msg_id, r.code, peer, msg] x = [r.created, nid, r.msg_id, r.code, peer, msg]
else: else:
...@@ -307,9 +308,14 @@ class NEOLogger(Logger): ...@@ -307,9 +308,14 @@ class NEOLogger(Logger):
def packet(self, connection, packet, outgoing): def packet(self, connection, packet, outgoing):
#if True or self._db is not None: #if True or self._db is not None:
if self._db is not None: if self._db is not None:
body = packet._body if self._max_packet and self._max_packet < packet.size:
if self._max_packet and self._max_packet < len(body): args = None
body = None else:
args = packet._args
try:
hash(args)
except TypeError:
args = packb(args)
self._queue(PacketRecord( self._queue(PacketRecord(
pkt=packet, pkt=packet,
created=time(), created=time(),
...@@ -318,7 +324,7 @@ class NEOLogger(Logger): ...@@ -318,7 +324,7 @@ class NEOLogger(Logger):
outgoing=outgoing, outgoing=outgoing,
uuid=connection.getUUID(), uuid=connection.getUUID(),
addr=connection.getAddress(), addr=connection.getAddress(),
msg=body)) msg=args))
def node(self, *cluster_nid): def node(self, *cluster_nid):
name = self.name and str(self.name) name = self.name and str(self.name)
......
...@@ -14,27 +14,63 @@ ...@@ -14,27 +14,63 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys import threading
import traceback from functools import partial
from cStringIO import StringIO from msgpack import packb
from struct import Struct
# The protocol version must be increased whenever upgrading a node may require # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and # to upgrade other nodes.
# the high order byte 0 is different from TLS Handshake (0x16). PROTOCOL_VERSION = 0
PROTOCOL_VERSION = 6 # By encoding the handshake packet with msgpack, the whole NEO stream can be
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION) # decoded with msgpack. The first byte is 0x92, which is different from TLS
# Handshake (0x16).
HANDSHAKE_PACKET = packb(('NEO', PROTOCOL_VERSION))
# Used to distinguish non-NEO stream from version mismatch.
MAGIC_SIZE = len(HANDSHAKE_PACKET) - len(packb(PROTOCOL_VERSION))
# Avoid memory errors on corrupted data.
MAX_PACKET_SIZE = 0x4000000
PACKET_HEADER_FORMAT = Struct('!LHL')
RESPONSE_MASK = 0x8000 RESPONSE_MASK = 0x8000
# Avoid some memory errors on corrupted data.
# Before we use msgpack, we limited the size of a whole packet. That's not
# possible anymore because the size is not known in advance. Packets bigger
# than the buffer size are possible (e.g. a huge list of small items) and for
# that we could compare the stream position (Unpacker.tell); it's not worth it.
UNPACK_BUFFER_SIZE = 0x4000000
@apply
def Unpacker():
global registerExtType, packb
from msgpack import ExtType, unpackb, Packer, Unpacker
ext_type_dict = []
kw = dict(use_bin_type=True)
pack_ext = Packer(**kw).pack
def registerExtType(getstate, make):
code = len(ext_type_dict)
ext_type_dict.append(lambda data: make(unpackb(data, use_list=False)))
return lambda obj: ExtType(code, pack_ext(getstate(obj)))
iterable_types = set, tuple
def default(obj):
try:
pack = obj._pack
except AttributeError:
assert type(obj) in iterable_types, type(obj)
return list(obj)
return pack()
lock = threading.Lock()
pack = Packer(default, strict_types=True, **kw).pack
def packb(obj):
with lock: # in case that 'default' is called
return pack(obj)
return partial(Unpacker, use_list=False, max_buffer_size=UNPACK_BUFFER_SIZE,
ext_hook=lambda code, data: ext_type_dict[code](data))
class Enum(tuple): class Enum(tuple):
class Item(int): class Item(int):
__slots__ = '_name', '_enum' __slots__ = '_name', '_enum', '_pack'
def __str__(self): def __str__(self):
return self._name return self._name
def __repr__(self): def __repr__(self):
...@@ -49,31 +85,38 @@ class Enum(tuple): ...@@ -49,31 +85,38 @@ class Enum(tuple):
names = func.func_code.co_names names = func.func_code.co_names
self = tuple.__new__(cls, map(cls.Item, xrange(len(names)))) self = tuple.__new__(cls, map(cls.Item, xrange(len(names))))
self._name = func.__name__ self._name = func.__name__
pack = registerExtType(int, self.__getitem__)
for item, name in zip(self, names): for item, name in zip(self, names):
setattr(self, name, item) setattr(self, name, item)
item._name = name item._name = name
item._enum = self item._enum = self
item._pack = (lambda x: lambda: x)(pack(item))
return self return self
def __repr__(self): def __repr__(self):
return "<Enum %s>" % self._name return "<Enum %s>" % self._name
# The order of extension type is important.
# Enum types first, sorted alphabetically.
@Enum @Enum
def ErrorCodes(): def CellStates():
ACK # Write-only cell. Last transactions are missing because storage is/was down
DENIED # for a while, or because it is new for the partition. It usually becomes
NOT_READY # UP_TO_DATE when replication is done.
OID_NOT_FOUND OUT_OF_DATE
TID_NOT_FOUND # Normal state: cell is writable/readable, and it isn't planned to drop it.
OID_DOES_NOT_EXIST UP_TO_DATE
PROTOCOL_ERROR # Same as UP_TO_DATE, except that it will be discarded as soon as another
REPLICATION_ERROR # node finishes to replicate it. It means a partition is moved from 1 node
CHECKING_ERROR # to another. It is also discarded immediately if out-of-date.
BACKEND_NOT_IMPLEMENTED FEEDING
NON_READABLE_CELL # A check revealed that data differs from other replicas. Cell is neither
READ_ONLY_ACCESS # readable nor writable.
INCOMPLETE_TRANSACTION CORRUPTED
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
@Enum @Enum
def ClusterStates(): def ClusterStates():
...@@ -108,11 +151,20 @@ def ClusterStates(): ...@@ -108,11 +151,20 @@ def ClusterStates():
STOPPING_BACKUP STOPPING_BACKUP
@Enum @Enum
def NodeTypes(): def ErrorCodes():
MASTER ACK
STORAGE DENIED
CLIENT NOT_READY
ADMIN OID_NOT_FOUND
TID_NOT_FOUND
OID_DOES_NOT_EXIST
PROTOCOL_ERROR
REPLICATION_ERROR
CHECKING_ERROR
BACKEND_NOT_IMPLEMENTED
NON_READABLE_CELL
READ_ONLY_ACCESS
INCOMPLETE_TRANSACTION
@Enum @Enum
def NodeStates(): def NodeStates():
...@@ -122,23 +174,11 @@ def NodeStates(): ...@@ -122,23 +174,11 @@ def NodeStates():
PENDING PENDING
@Enum @Enum
def CellStates(): def NodeTypes():
# Write-only cell. Last transactions are missing because storage is/was down MASTER
# for a while, or because it is new for the partition. It usually becomes STORAGE
# UP_TO_DATE when replication is done. CLIENT
OUT_OF_DATE ADMIN
# Normal state: cell is writable/readable, and it isn't planned to drop it.
UP_TO_DATE
# Same as UP_TO_DATE, except that it will be discarded as soon as another
# node finishes to replicate it. It means a partition is moved from 1 node
# to another. It is also discarded immediately if out-of-date.
FEEDING
# A check revealed that data differs from other replicas. Cell is neither
# readable nor writable.
CORRUPTED
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
# used for logging # used for logging
node_state_prefix_dict = { node_state_prefix_dict = {
...@@ -214,45 +254,24 @@ class NonReadableCell(Exception): ...@@ -214,45 +254,24 @@ class NonReadableCell(Exception):
On such event, the client must retry, preferably another cell. On such event, the client must retry, preferably another cell.
""" """
class Packet(object): class Packet(object):
""" """
Base class for any packet definition. The _fmt class attribute must be Base class for any packet definition.
defined for any non-empty packet.
""" """
_ignore_when_closed = False _ignore_when_closed = False
_request = None _request = None
_answer = None _answer = None
_body = None
_code = None _code = None
_fmt = None
_id = None _id = None
allow_dict = False
nodelay = True nodelay = True
poll_thread = False poll_thread = False
def __init__(self, *args): def __init__(self, *args):
assert self._code is not None, "Packet class not registered" assert self._code is not None, "Packet class not registered"
if args: assert self.allow_dict or dict not in map(type, args), args
buf = StringIO() self._args = args
self._fmt.encode(buf.write, args)
self._body = buf.getvalue()
else:
self._body = ''
def decode(self):
assert self._body is not None
if self._fmt is None:
return ()
buf = StringIO(self._body)
try:
return self._fmt.decode(buf.read)
except ParseError, msg:
name = self.__class__.__name__
raise PacketMalformedError("%s fail (%s)" % (name, msg))
def setContent(self, msg_id, body):
""" Register the packet content for future decoding """
self._id = msg_id
self._body = body
def setId(self, value): def setId(self, value):
self._id = value self._id = value
...@@ -261,14 +280,11 @@ class Packet(object): ...@@ -261,14 +280,11 @@ class Packet(object):
assert self._id is not None, "No identifier applied on the packet" assert self._id is not None, "No identifier applied on the packet"
return self._id return self._id
def encode(self): def encode(self, packb=packb):
""" Encode a packet as a string to send it over the network """ """ Encode a packet as a string to send it over the network """
content = self._body r = packb((self._id, self._code, self._args))
return (PACKET_HEADER_FORMAT.pack(self._id, self._code, len(content)), self.size = len(r)
content) return r
def __len__(self):
return PACKET_HEADER_FORMAT.size + len(self._body)
def __repr__(self): def __repr__(self):
return '%s[%r]' % (self.__class__.__name__, self._id) return '%s[%r]' % (self.__class__.__name__, self._id)
...@@ -281,10 +297,10 @@ class Packet(object): ...@@ -281,10 +297,10 @@ class Packet(object):
return self._code == other._code return self._code == other._code
def isError(self): def isError(self):
return isinstance(self, Error) return self._code == RESPONSE_MASK
def isResponse(self): def isResponse(self):
return self._code & RESPONSE_MASK == RESPONSE_MASK return self._code & RESPONSE_MASK
def getAnswerClass(self): def getAnswerClass(self):
return self._answer return self._answer
...@@ -296,719 +312,242 @@ class Packet(object): ...@@ -296,719 +312,242 @@ class Packet(object):
""" """
return self._ignore_when_closed return self._ignore_when_closed
class ParseError(Exception):
"""
An exception that encapsulate another and build the 'path' of the
packet item that generate the error.
"""
def __init__(self, item, trace):
Exception.__init__(self)
self._trace = trace
self._items = [item]
def append(self, item):
self._items.append(item)
def __repr__(self):
chain = '/'.join([item.getName() for item in reversed(self._items)])
return 'at %s:\n%s' % (chain, self._trace)
__str__ = __repr__
# packet parsers
class PItem(object):
"""
Base class for any packet item, _encode and _decode must be overridden
by subclasses.
"""
def __init__(self, name):
self._name = name
def __repr__(self):
return self.__class__.__name__
def getName(self):
return self._name
def _trace(self, method, *args):
try:
return method(*args)
except ParseError, e:
# trace and forward exception
e.append(self)
raise
except Exception:
# original exception, encapsulate it
trace = ''.join(traceback.format_exception(*sys.exc_info())[2:])
raise ParseError(self, trace)
def encode(self, writer, items):
return self._trace(self._encode, writer, items)
def decode(self, reader):
return self._trace(self._decode, reader)
def _encode(self, writer, items):
raise NotImplementedError, self.__class__.__name__
def _decode(self, reader):
raise NotImplementedError, self.__class__.__name__
class PStruct(PItem):
"""
Aggregate other items
"""
def __init__(self, name, *items):
PItem.__init__(self, name)
self._items = items
def _encode(self, writer, items):
assert len(self._items) == len(items), (items, self._items)
for item, value in zip(self._items, items):
item.encode(writer, value)
def _decode(self, reader): class PacketRegistryFactory(dict):
return tuple([item.decode(reader) for item in self._items])
class PStructItem(PItem): def __call__(self, name, base, d):
""" for k, v in d.items():
A single value encoded with struct if isinstance(v, type) and issubclass(v, Packet):
""" v.__name__ = k
def __init__(self, name): v.handler_method_name = k[0].lower() + k[1:]
PItem.__init__(self, name) # this builds a "singleton"
struct = Struct(self._fmt) return type('PacketRegistry', base, d)(self)
self.pack = struct.pack
self.unpack = struct.unpack
self.size = struct.size
def _encode(self, writer, value):
writer(self.pack(value))
def _decode(self, reader):
return self.unpack(reader(self.size))[0]
class PStructItemOrNone(PStructItem):
def _encode(self, writer, value):
return writer(self._None if value is None else self.pack(value))
def _decode(self, reader):
value = reader(self.size)
return None if value == self._None else self.unpack(value)[0]
class POption(PStruct):
def _encode(self, writer, value): def register(self, doc, ignore_when_closed=None, request=False, error=False,
if value is None: _base=(Packet,), **kw):
writer('\0') """ Register a packet in the packet registry """
code = len(self)
if doc is None:
self[code] = None
return # None registered only to skip a code number (for compatibility)
if error and not request:
assert not code
code = RESPONSE_MASK
kw.update(__doc__=doc, _code=code)
packet = type('', _base, kw)
# register the request
self[code] = packet
if request:
if ignore_when_closed is None:
# By default, on a closed connection:
# - request: ignore
# - answer: keep
# - notification: keep
packet._ignore_when_closed = True
else: else:
writer('\1') assert ignore_when_closed is False
PStruct._encode(self, writer, value) if error:
packet._answer = self[RESPONSE_MASK]
def _decode(self, reader):
if '\0\1'.index(reader(1)):
return PStruct._decode(self, reader)
class PList(PStructItem):
"""
A list of homogeneous items
"""
_fmt = '!L'
def __init__(self, name, item):
PStructItem.__init__(self, name)
self._item = item
def _encode(self, writer, items):
writer(self.pack(len(items)))
item = self._item
for value in items:
item.encode(writer, value)
def _decode(self, reader):
length = self.unpack(reader(self.size))[0]
item = self._item
return [item.decode(reader) for _ in xrange(length)]
class PDict(PStructItem):
"""
A dictionary with custom key and value formats
"""
_fmt = '!L'
def __init__(self, name, key, value):
PStructItem.__init__(self, name)
self._key = key
self._value = value
def _encode(self, writer, item):
assert isinstance(item , dict), (type(item), item)
writer(self.pack(len(item)))
key, value = self._key, self._value
for k, v in item.iteritems():
key.encode(writer, k)
value.encode(writer, v)
def _decode(self, reader):
length = self.unpack(reader(self.size))[0]
key, value = self._key, self._value
new_dict = {}
for _ in xrange(length):
k = key.decode(reader)
v = value.decode(reader)
new_dict[k] = v
return new_dict
class PEnum(PStructItem):
"""
Encapsulate an enumeration value
"""
_fmt = 'b'
def __init__(self, name, enum):
PStructItem.__init__(self, name)
self._enum = enum
def _encode(self, writer, item):
if item is None:
item = -1
writer(self.pack(item))
def _decode(self, reader):
code = self.unpack(reader(self.size))[0]
if code == -1:
return None
try:
return self._enum[code]
except KeyError:
enum = self._enum.__class__.__name__
raise ValueError, 'Invalid code for %s enum: %r' % (enum, code)
class PString(PStructItem):
"""
A variable-length string
"""
_fmt = '!L'
def _encode(self, writer, value):
writer(self.pack(len(value)))
writer(value)
def _decode(self, reader):
length = self.unpack(reader(self.size))[0]
return reader(length)
class PAddress(PString):
"""
An host address (IPv4/IPv6)
"""
def __init__(self, name):
PString.__init__(self, name)
self._port = Struct('!H')
def _encode(self, writer, address):
if address:
host, port = address
PString._encode(self, writer, host)
writer(self._port.pack(port))
else: else:
PString._encode(self, writer, '') # build a class for the answer
code |= RESPONSE_MASK
def _decode(self, reader): kw['_code'] = code
host = PString._decode(self, reader) answer = packet._answer = self[code] = type('', _base, kw)
if host: return packet, answer
p = self._port else:
return host, p.unpack(reader(p.size))[0] assert ignore_when_closed is None
return packet
class PBoolean(PStructItem):
"""
A boolean value, encoded as a single byte
"""
_fmt = '!?'
class PNumber(PStructItem):
"""
A integer number (4-bytes length)
"""
_fmt = '!L'
class PIndex(PStructItem):
"""
A big integer to defined indexes in a huge list.
"""
_fmt = '!Q'
class PPTID(PStructItemOrNone):
"""
A None value means an invalid PTID
"""
_fmt = '!Q'
_None = Struct(_fmt).pack(0)
class PChecksum(PItem):
"""
A hash (SHA1)
"""
def _encode(self, writer, checksum):
assert len(checksum) == 20, (len(checksum), checksum)
writer(checksum)
def _decode(self, reader):
return reader(20)
class PSignedNull(PStructItemOrNone):
_fmt = '!l'
_None = Struct(_fmt).pack(0)
class PUUID(PSignedNull): class Packets(dict):
""" """
An UUID (node identifier, 4-bytes signed integer) Packet registry that checks packet code uniqueness and provides an index
""" """
__metaclass__ = PacketRegistryFactory()
notify = __metaclass__.register
request = partial(notify, request=True)
class PTID(PItem): Error = notify("""
"""
A transaction identifier
"""
def _encode(self, writer, tid):
if tid is None:
tid = INVALID_TID
assert len(tid) == 8, (len(tid), tid)
writer(tid)
def _decode(self, reader):
tid = reader(8)
if tid == INVALID_TID:
tid = None
return tid
# same definition, for now
POID = PTID
class PFloat(PStructItemOrNone):
"""
A float number (8-bytes length)
"""
_fmt = '!d'
_None = '\xff' * 8
# common definitions
PFEmpty = PStruct('no_content')
PFNodeType = PEnum('type', NodeTypes)
PFNodeState = PEnum('state', NodeStates)
PFCellState = PEnum('state', CellStates)
PFNodeList = PList('node_list',
PStruct('node',
PFNodeType,
PAddress('address'),
PUUID('uuid'),
PFNodeState,
PFloat('id_timestamp'),
),
)
PFCellList = PList('cell_list',
PStruct('cell',
PUUID('uuid'),
PFCellState,
),
)
PFRowList = PList('row_list',
PFCellList,
)
PFHistoryList = PList('history_list',
PStruct('history_entry',
PTID('serial'),
PNumber('size'),
),
)
PFUUIDList = PList('uuid_list',
PUUID('uuid'),
)
PFTidList = PList('tid_list',
PTID('tid'),
)
PFOidList = PList('oid_list',
POID('oid'),
)
# packets definition
class Error(Packet):
"""
Error is a special type of message, because this can be sent against Error is a special type of message, because this can be sent against
any other message, even if such a message does not expect a reply any other message, even if such a message does not expect a reply
usually. usually.
:nodes: * -> * :nodes: * -> *
""" """, error=True)
_fmt = PStruct('error',
PNumber('code'),
PString('message'),
)
class Ping(Packet): RequestIdentification, AcceptIdentification = request("""
""" Request a node identification. This must be the first packet for any
Empty request used as network barrier. connection.
:nodes: * -> * :nodes: * -> *
""" """, poll_thread=True)
_answer = PFEmpty
class CloseClient(Packet): Ping, Pong = request("""
""" Empty request used as network barrier.
Tell peer that it can close the connection if it has finished with us.
:nodes: * -> * :nodes: * -> *
""" """)
class RequestIdentification(Packet): CloseClient = notify("""
""" Tell peer that it can close the connection if it has finished with us.
Request a node identification. This must be the first packet for any
connection.
:nodes: * -> * :nodes: * -> *
""" """)
poll_thread = True
AskPrimary, AnswerPrimary = request("""
_fmt = PStruct('request_identification',
PFNodeType,
PUUID('uuid'),
PAddress('address'),
PString('name'),
PFloat('id_timestamp'),
# storage:
PList('devpath', PString('devid')),
PList('new_nid', PNumber('offset')),
)
_answer = PStruct('accept_identification',
PFNodeType,
PUUID('my_uuid'),
PUUID('your_uuid'),
)
class PrimaryMaster(Packet):
"""
Ask node identier of the current primary master. Ask node identier of the current primary master.
:nodes: ctl -> A :nodes: ctl -> A
""" """)
_answer = PStruct('answer_primary',
PUUID('primary_uuid'),
)
class NotPrimaryMaster(Packet): NotPrimaryMaster = notify("""
""" Notify peer that I'm not the primary master. Attach any extra
Notify peer that I'm not the primary master. Attach any extra information information to help the peer joining the cluster.
to help the peer joining the cluster.
:nodes: SM -> * :nodes: SM -> *
""" """)
_fmt = PStruct('not_primary_master',
PSignedNull('primary'), NotifyNodeInformation = notify("""
PList('known_master_list', Notify information about one or more nodes.
PAddress('address'),
), :nodes: M -> *
) """)
class Recovery(Packet): AskRecovery, AnswerRecovery = request("""
"""
Ask storage nodes data needed by master to recover. Ask storage nodes data needed by master to recover.
Reused by `neoctl print ids`. Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M :nodes: M -> S; ctl -> A -> M
""" """)
_answer = PStruct('answer_recovery',
PPTID('ptid'),
PTID('backup_tid'),
PTID('truncate_tid'),
)
class LastIDs(Packet): AskLastIDs, AnswerLastIDs = request("""
""" Ask the last OID/TID so that a master can initialize its
Ask the last OID/TID so that a master can initialize its TransactionManager. TransactionManager. Reused by `neoctl print ids`.
Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M :nodes: M -> S; ctl -> A -> M
""" """)
_answer = PStruct('answer_last_ids',
POID('last_oid'),
PTID('last_tid'),
)
class PartitionTable(Packet): AskPartitionTable, AnswerPartitionTable = request("""
"""
Ask storage node the remaining data needed by master to recover. Ask storage node the remaining data needed by master to recover.
:nodes: M -> S :nodes: M -> S
""" """)
_answer = PStruct('answer_partition_table',
PPTID('ptid'),
PNumber('num_replicas'),
PFRowList,
)
class NotifyPartitionTable(Packet): SendPartitionTable = notify("""
""" Send the full partition table to admin/client/storage nodes on
Send the full partition table to admin/client/storage nodes on connection. connection.
:nodes: M -> A, C, S :nodes: M -> A, C, S
""" """)
_fmt = PStruct('send_partition_table',
PPTID('ptid'),
PNumber('num_replicas'),
PFRowList,
)
class PartitionChanges(Packet): NotifyPartitionChanges = notify("""
"""
Notify about changes in the partition table. Notify about changes in the partition table.
:nodes: M -> * :nodes: M -> *
""" """)
_fmt = PStruct('notify_partition_changes',
PPTID('ptid'), StartOperation = notify("""
PNumber('num_replicas'), Tell a storage node to start operation. Before this message,
PList('cell_list', it must only communicate with the primary master.
PStruct('cell',
PNumber('offset'),
PUUID('uuid'),
PFCellState,
),
),
)
class StartOperation(Packet):
"""
Tell a storage node to start operation. Before this message, it must only
communicate with the primary master.
:nodes: M -> S :nodes: M -> S
""" """)
_fmt = PStruct('start_operation',
# XXX: Is this boolean needed ? Maybe this
# can be deduced from cluster state.
PBoolean('backup'),
)
class StopOperation(Packet): StopOperation = notify("""
""" Notify that the cluster is not operational anymore.
Notify that the cluster is not operational anymore. Any operation between Any operation between nodes must be aborted.
nodes must be aborted.
:nodes: M -> S, C :nodes: M -> S, C
""" """)
class UnfinishedTransactions(Packet): AskUnfinishedTransactions, AnswerUnfinishedTransactions = request("""
""" Ask unfinished transactions, which will be replicated
Ask unfinished transactions, which will be replicated when they're finished. when they're finished.
:nodes: S -> M :nodes: S -> M
""" """)
_fmt = PStruct('ask_unfinished_transactions',
PList('row_list', AskLockedTransactions, AnswerLockedTransactions = request("""
PNumber('offset'), Ask locked transactions to replay committed transactions
), that haven't been unlocked.
)
_answer = PStruct('answer_unfinished_transactions',
PTID('max_tid'),
PList('tid_list',
PTID('unfinished_tid'),
),
)
class LockedTransactions(Packet):
"""
Ask locked transactions to replay committed transactions that haven't been
unlocked.
:nodes: M -> S :nodes: M -> S
""" """, allow_dict=True)
_answer = PStruct('answer_locked_transactions',
PDict('tid_dict', AskFinalTID, AnswerFinalTID = request("""
PTID('ttid'),
PTID('tid'),
),
)
class FinalTID(Packet):
"""
Return final tid if ttid has been committed, to recover from certain Return final tid if ttid has been committed, to recover from certain
failures during tpc_finish. failures during tpc_finish.
:nodes: M -> S; C -> M, S :nodes: M -> S; C -> M, S
""" """)
_fmt = PStruct('final_tid',
PTID('ttid'),
)
_answer = PStruct('final_tid', ValidateTransaction = notify("""
PTID('tid'),
)
class ValidateTransaction(Packet):
"""
Do replay a committed transaction that was not unlocked. Do replay a committed transaction that was not unlocked.
:nodes: M -> S :nodes: M -> S
""" """)
_fmt = PStruct('validate_transaction',
PTID('ttid'),
PTID('tid'),
)
class BeginTransaction(Packet): AskBeginTransaction, AnswerBeginTransaction = request("""
"""
Ask to begin a new transaction. This maps to `tpc_begin`. Ask to begin a new transaction. This maps to `tpc_begin`.
:nodes: C -> M :nodes: C -> M
""" """)
_fmt = PStruct('ask_begin_transaction',
PTID('tid'),
)
_answer = PStruct('answer_begin_transaction', FailedVote = request("""
PTID('tid'),
)
class FailedVote(Packet):
"""
Report storage nodes for which vote failed. Report storage nodes for which vote failed.
True is returned if it's still possible to finish the transaction. True is returned if it's still possible to finish the transaction.
:nodes: C -> M :nodes: C -> M
""" """, error=True)
_fmt = PStruct('failed_vote',
PTID('tid'),
PFUUIDList,
)
_answer = Error AskFinishTransaction, AnswerTransactionFinished = request("""
class FinishTransaction(Packet):
"""
Finish a transaction. Return the TID of the committed transaction. Finish a transaction. Return the TID of the committed transaction.
This maps to `tpc_finish`. This maps to `tpc_finish`.
:nodes: C -> M :nodes: C -> M
""" """, ignore_when_closed=False, poll_thread=True)
poll_thread = True
_fmt = PStruct('ask_finish_transaction',
PTID('tid'),
PFOidList,
PList('checked_list',
POID('oid'),
),
)
_answer = PStruct('answer_information_locked',
PTID('ttid'),
PTID('tid'),
)
class NotifyTransactionFinished(Packet):
"""
Notify that a transaction blocking a replication is now finished.
:nodes: M -> S AskLockInformation, AnswerInformationLocked = request("""
"""
_fmt = PStruct('notify_transaction_finished',
PTID('ttid'),
PTID('max_tid'),
)
class LockInformation(Packet):
"""
Commit a transaction. The new data is read-locked. Commit a transaction. The new data is read-locked.
:nodes: M -> S :nodes: M -> S
""" """, ignore_when_closed=False)
_fmt = PStruct('ask_lock_informations',
PTID('ttid'),
PTID('tid'),
)
_answer = PStruct('answer_information_locked', InvalidateObjects = notify("""
PTID('ttid'),
)
class InvalidateObjects(Packet):
"""
Notify about a new transaction modifying objects, Notify about a new transaction modifying objects,
invalidating client caches. invalidating client caches.
:nodes: M -> C :nodes: M -> C
""" """)
_fmt = PStruct('ask_finish_transaction',
PTID('tid'),
PFOidList,
)
class UnlockInformation(Packet): NotifyUnlockInformation = notify("""
"""
Notify about a successfully committed transaction. The new data can be Notify about a successfully committed transaction. The new data can be
unlocked. unlocked.
:nodes: M -> S :nodes: M -> S
""" """)
_fmt = PStruct('notify_unlock_information',
PTID('ttid'),
)
class GenerateOIDs(Packet): AskNewOIDs, AnswerNewOIDs = request("""
"""
Ask new OIDs to create objects. Ask new OIDs to create objects.
:nodes: C -> M :nodes: C -> M
""" """)
_fmt = PStruct('ask_new_oids',
PNumber('num_oids'),
)
_answer = PStruct('answer_new_oids',
PFOidList,
)
class Deadlock(Packet): NotifyDeadlock = notify("""
""" Ask master to generate a new TTID that will be used by the client to
Ask master to generate a new TTID that will be used by the client to solve solve a deadlock by rebasing the transaction on top of concurrent
a deadlock by rebasing the transaction on top of concurrent changes. changes.
:nodes: S -> M -> C :nodes: S -> M -> C
""" """)
_fmt = PStruct('notify_deadlock',
PTID('ttid'),
PTID('locking_tid'),
)
class RebaseTransaction(Packet): AskRebaseTransaction, AnswerRebaseTransaction = request("""
"""
Rebase a transaction to solve a deadlock. Rebase a transaction to solve a deadlock.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_rebase_transaction',
PTID('ttid'),
PTID('locking_tid'),
)
_answer = PStruct('answer_rebase_transaction',
PFOidList,
)
class RebaseObject(Packet): AskRebaseObject, AnswerRebaseObject = request("""
"""
Rebase an object change to solve a deadlock. Rebase an object change to solve a deadlock.
:nodes: C -> S :nodes: C -> S
...@@ -1017,341 +556,135 @@ class RebaseObject(Packet): ...@@ -1017,341 +556,135 @@ class RebaseObject(Packet):
efficiency, this should be turned into a notification, and the efficiency, this should be turned into a notification, and the
RebaseTransaction should answered once all objects are rebased RebaseTransaction should answered once all objects are rebased
(so that the client can still wait on something). (so that the client can still wait on something).
""" """, data_path=(1, 0, 2, 0))
_fmt = PStruct('ask_rebase_object',
PTID('ttid'), AskStoreObject, AnswerStoreObject = request("""
PTID('oid'),
)
_answer = PStruct('answer_rebase_object',
POption('conflict',
PTID('serial'),
PTID('conflict_serial'),
POption('data',
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
),
)
)
class StoreObject(Packet):
"""
Ask to create/modify an object. This maps to `store`. Ask to create/modify an object. This maps to `store`.
As for IStorage, 'serial' is ZERO_TID for new objects. As for IStorage, 'serial' is ZERO_TID for new objects.
:nodes: C -> S :nodes: C -> S
""" """, data_path=(0, 2))
_fmt = PStruct('ask_store_object',
POID('oid'), AbortTransaction = notify("""
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
PTID('tid'),
)
_answer = PStruct('answer_store_object',
PTID('conflict'),
)
class AbortTransaction(Packet):
"""
Abort a transaction. This maps to `tpc_abort`. Abort a transaction. This maps to `tpc_abort`.
:nodes: C -> S; C -> M -> S :nodes: C -> S; C -> M -> S
""" """)
_fmt = PStruct('abort_transaction',
PTID('tid'),
PFUUIDList, # unused for * -> S
)
class StoreTransaction(Packet): AskStoreTransaction, AnswerStoreTransaction = request("""
"""
Ask to store a transaction. Implies vote. Ask to store a transaction. Implies vote.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_store_transaction',
PTID('tid'), AskVoteTransaction, AnswerVoteTransaction = request("""
PString('user'),
PString('description'),
PString('extension'),
PFOidList,
)
_answer = PFEmpty
class VoteTransaction(Packet):
"""
Ask to vote a transaction. Ask to vote a transaction.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_vote_transaction',
PTID('tid'),
)
_answer = PFEmpty
class GetObject(Packet): AskObject, AnswerObject = request("""
"""
Ask a stored object by its OID, optionally at/before a specific tid. Ask a stored object by its OID, optionally at/before a specific tid.
This maps to `load/loadBefore/loadSerial`. This maps to `load/loadBefore/loadSerial`.
:nodes: C -> S :nodes: C -> S
""" """, data_path=(1, 3))
_fmt = PStruct('ask_object',
POID('oid'),
PTID('at'),
PTID('before'),
)
_answer = PStruct('answer_object',
POID('oid'),
PTID('serial_start'),
PTID('serial_end'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class TIDList(Packet):
"""
Ask for TIDs between a range of offsets. The order of TIDs is descending,
and the range is [first, last). This maps to `undoLog`.
:nodes: C -> S AskTIDs, AnswerTIDs = request("""
""" Ask for TIDs between a range of offsets. The order of TIDs is
_fmt = PStruct('ask_tids', descending, and the range is [first, last). This maps to `undoLog`.
PIndex('first'),
PIndex('last'),
PNumber('partition'),
)
_answer = PStruct('answer_tids',
PFTidList,
)
class TIDListFrom(Packet):
"""
Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
Used by `iterator`.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('tid_list_from',
PTID('min_tid'), AskTransactionInformation, AnswerTransactionInformation = request("""
PTID('max_tid'),
PNumber('length'),
PNumber('partition'),
)
_answer = PStruct('answer_tids',
PFTidList,
)
class TransactionInformation(Packet):
"""
Ask for transaction metadata. Ask for transaction metadata.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_transaction_information',
PTID('tid'), AskObjectHistory, AnswerObjectHistory = request("""
)
_answer = PStruct('answer_transaction_information',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PFOidList,
)
class ObjectHistory(Packet):
"""
Ask history information for a given object. The order of serials is Ask history information for a given object. The order of serials is
descending, and the range is [first, last]. This maps to `history`. descending, and the range is [first, last]. This maps to `history`.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_object_history',
POID('oid'), AskPartitionList, AnswerPartitionList = request("""
PIndex('first'),
PIndex('last'),
)
_answer = PStruct('answer_object_history',
POID('oid'),
PFHistoryList,
)
class PartitionList(Packet):
"""
Ask information about partitions. Ask information about partitions.
:nodes: ctl -> A :nodes: ctl -> A
""" """)
_fmt = PStruct('ask_partition_list',
PNumber('min_offset'), AskNodeList, AnswerNodeList = request("""
PNumber('max_offset'),
PUUID('uuid'),
)
_answer = PStruct('answer_partition_list',
PPTID('ptid'),
PNumber('num_replicas'),
PFRowList,
)
class NodeList(Packet):
"""
Ask information about nodes. Ask information about nodes.
:nodes: ctl -> A :nodes: ctl -> A
""" """)
_fmt = PStruct('ask_node_list',
PFNodeType,
)
_answer = PStruct('answer_node_list',
PFNodeList,
)
class SetNodeState(Packet): SetNodeState = request("""
"""
Change the state of a node. Change the state of a node.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
_fmt = PStruct('set_node_state',
PUUID('uuid'),
PFNodeState,
)
_answer = Error
class AddPendingNodes(Packet): AddPendingNodes = request("""
"""
Mark given pending nodes as running, for future inclusion when tweaking Mark given pending nodes as running, for future inclusion when tweaking
the partition table. the partition table.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
_fmt = PStruct('add_pending_nodes',
PFUUIDList,
)
_answer = Error
class TweakPartitionTable(Packet): TweakPartitionTable, AnswerTweakPartitionTable = request("""
"""
Ask the master to balance the partition table, optionally excluding Ask the master to balance the partition table, optionally excluding
specific nodes in anticipation of removing them. specific nodes in anticipation of removing them.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """)
_fmt = PStruct('tweak_partition_table',
PBoolean('dry_run'),
PFUUIDList,
)
_answer = PStruct('answer_tweak_partition_table',
PBoolean('changed'),
PFRowList,
)
class NotifyNodeInformation(Packet): SetNumReplicas = request("""
"""
Notify information about one or more nodes.
:nodes: M -> *
"""
_fmt = PStruct('notify_node_informations',
PFloat('id_timestamp'),
PFNodeList,
)
class SetNumReplicas(Packet):
"""
Set the number of replicas. Set the number of replicas.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
_fmt = PStruct('set_num_replicas',
PNumber('num_replicas'),
)
_answer = Error
class SetClusterState(Packet): SetClusterState = request("""
"""
Set the cluster state. Set the cluster state.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
_fmt = PStruct('set_cluster_state',
PEnum('state', ClusterStates),
)
_answer = Error
class Repair(Packet): Repair = request("""
"""
Ask storage nodes to repair their databases. Ask storage nodes to repair their databases.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """, error=True)
_flags = map(PBoolean, ('dry_run',
# 'prune_orphan' (commented because it's the only option for the moment)
))
_fmt = PStruct('repair',
PFUUIDList,
*_flags)
_answer = Error
class RepairOne(Packet): NotifyRepair = notify("""
"""
Repair is translated to this message, asking a specific storage node to Repair is translated to this message, asking a specific storage node to
repair its database. repair its database.
:nodes: M -> S :nodes: M -> S
""" """)
_fmt = PStruct('repair', *Repair._flags)
class ClusterInformation(Packet): NotifyClusterInformation = notify("""
"""
Notify about a cluster state change. Notify about a cluster state change.
:nodes: M -> * :nodes: M -> *
""" """)
_fmt = PStruct('notify_cluster_information',
PEnum('state', ClusterStates),
)
class ClusterState(Packet): AskClusterState, AnswerClusterState = request("""
"""
Ask the state of the cluster Ask the state of the cluster
:nodes: ctl -> A; A -> M :nodes: ctl -> A; A -> M
""" """)
_answer = PStruct('answer_cluster_state',
PEnum('state', ClusterStates),
)
class ObjectUndoSerial(Packet): AskObjectUndoSerial, AnswerObjectUndoSerial = request("""
""" Ask storage the serial where object data is when undoing given
Ask storage the serial where object data is when undoing given transaction, transaction, for a list of OIDs.
for a list of OIDs.
object_tid_dict has the following format: Answer a dict mapping oids to 3-tuples:
key: oid
value: 3-tuple
current_serial (TID) current_serial (TID)
The latest serial visible to the undoing transaction. The latest serial visible to the undoing transaction.
undo_serial (TID) undo_serial (TID)
...@@ -1360,484 +693,149 @@ class ObjectUndoSerial(Packet): ...@@ -1360,484 +693,149 @@ class ObjectUndoSerial(Packet):
If current_serial's data is current on storage. If current_serial's data is current on storage.
:nodes: C -> S :nodes: C -> S
""" """, allow_dict=True)
_fmt = PStruct('ask_undo_transaction',
PTID('tid'),
PTID('ltid'),
PTID('undone_tid'),
PFOidList,
)
_answer = PStruct('answer_undo_transaction',
PDict('object_tid_dict',
POID('oid'),
PStruct('object_tid_value',
PTID('current_serial'),
PTID('undo_serial'),
PBoolean('is_current'),
),
),
)
class CheckCurrentSerial(Packet):
"""
Check if given serial is current for the given oid, and lock it so that
this state is not altered until transaction ends.
This maps to `checkCurrentSerialInTransaction`.
:nodes: C -> S AskTIDsFrom, AnswerTIDsFrom = request("""
""" Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
_fmt = PStruct('ask_check_current_serial', Used by `iterator`.
PTID('tid'),
POID('oid'),
PTID('serial'),
)
_answer = StoreObject._answer :nodes: C -> S
""")
class Pack(Packet): AskPack, AnswerPack = request("""
"""
Request a pack at given TID. Request a pack at given TID.
:nodes: C -> M -> S :nodes: C -> M -> S
""" """, ignore_when_closed=False)
_fmt = PStruct('ask_pack',
PTID('tid'),
)
_answer = PStruct('answer_pack',
PBoolean('status'),
)
class CheckReplicas(Packet): CheckReplicas = request("""
""" Ask the cluster to search for mismatches between replicas, metadata
Ask the cluster to search for mismatches between replicas, metadata only, only, and optionally within a specific range. Reference nodes can be
and optionally within a specific range. Reference nodes can be specified. specified.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """, error=True, allow_dict=True)
_fmt = PStruct('check_replicas',
PDict('partition_dict', CheckPartition = notify("""
PNumber('partition'),
PUUID('source'),
),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = Error
class CheckPartition(Packet):
"""
Ask a storage node to compare a partition with all other nodes. Ask a storage node to compare a partition with all other nodes.
Like for CheckReplicas, only metadata are checked, optionally within a Like for CheckReplicas, only metadata are checked, optionally within a
specific range. A reference node can be specified. specific range. A reference node can be specified.
:nodes: M -> S :nodes: M -> S
""" """)
_fmt = PStruct('check_partition',
PNumber('partition'), AskCheckTIDRange, AnswerCheckTIDRange = request("""
PStruct('source',
PString('upstream_name'),
PAddress('address'),
),
PTID('min_tid'),
PTID('max_tid'),
)
class CheckTIDRange(Packet):
"""
Ask some stats about a range of transactions. Ask some stats about a range of transactions.
Used to know if there are differences between a replicating node and Used to know if there are differences between a replicating node and
reference node. reference node.
:nodes: S -> S :nodes: S -> S
""" """)
_fmt = PStruct('ask_check_tid_range',
PNumber('partition'), AskCheckSerialRange, AnswerCheckSerialRange = request("""
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = PStruct('answer_check_tid_range',
PNumber('count'),
PChecksum('checksum'),
PTID('max_tid'),
)
class CheckSerialRange(Packet):
"""
Ask some stats about a range of object history. Ask some stats about a range of object history.
Used to know if there are differences between a replicating node and Used to know if there are differences between a replicating node and
reference node. reference node.
:nodes: S -> S :nodes: S -> S
""" """)
_fmt = PStruct('ask_check_serial_range',
PNumber('partition'), NotifyPartitionCorrupted = notify("""
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
)
_answer = PStruct('answer_check_serial_range',
PNumber('count'),
PChecksum('tid_checksum'),
PTID('max_tid'),
PChecksum('oid_checksum'),
POID('max_oid'),
)
class PartitionCorrupted(Packet):
"""
Notify that mismatches were found while check replicas for a partition. Notify that mismatches were found while check replicas for a partition.
:nodes: S -> M :nodes: S -> M
""" """)
_fmt = PStruct('partition_corrupted',
PNumber('partition'), NotifyReady = notify("""
PList('cell_list', Notify that we're ready to serve requests.
PUUID('uuid'),
), :nodes: S -> M
) """)
class LastTransaction(Packet): AskLastTransaction, AnswerLastTransaction = request("""
"""
Ask last committed TID. Ask last committed TID.
:nodes: C -> M; ctl -> A -> M :nodes: C -> M; ctl -> A -> M
""" """, poll_thread=True)
poll_thread = True
_answer = PStruct('answer_last_transaction', AskCheckCurrentSerial, AnswerCheckCurrentSerial = request("""
PTID('tid'), Check if given serial is current for the given oid, and lock it so that
) this state is not altered until transaction ends.
This maps to `checkCurrentSerialInTransaction`.
class NotifyReady(Packet): :nodes: C -> S
""" """)
Notify that we're ready to serve requests.
NotifyTransactionFinished = notify("""
Notify that a transaction blocking a replication is now finished.
:nodes: M -> S
""")
Replicate = notify("""
Notify a storage node to replicate partitions up to given 'tid'
and from given sources.
args: tid, upstream_name, {partition: address}
- upstream_name: replicate from an upstream cluster
- address: address of the source storage node, or None if there's
no new data up to 'tid' for the given partition
:nodes: M -> S
""", allow_dict=True)
NotifyReplicationDone = notify("""
Notify the master node that a partition has been successfully
replicated from a storage to another.
:nodes: S -> M :nodes: S -> M
""" """)
class FetchTransactions(Packet): AskFetchTransactions, AnswerFetchTransactions = request("""
"""
Ask a storage node to send all transaction data we don't have, Ask a storage node to send all transaction data we don't have,
and reply with the list of transactions we should not have. and reply with the list of transactions we should not have.
:nodes: S -> S :nodes: S -> S
""" """)
_fmt = PStruct('ask_transaction_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
PFTidList, # already known transactions
)
_answer = PStruct('answer_transaction_list',
PTID('pack_tid'),
PTID('next_tid'),
PFTidList, # transactions to delete
)
class AddTransaction(Packet):
"""
Send metadata of a transaction to a node that do not have them.
:nodes: S -> S AskFetchObjects, AnswerFetchObjects = request("""
"""
nodelay = False
_fmt = PStruct('add_transaction',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PTID('ttid'),
PFOidList,
)
class FetchObjects(Packet):
"""
Ask a storage node to send object records we don't have, Ask a storage node to send object records we don't have,
and reply with the list of records we should not have. and reply with the list of records we should not have.
:nodes: S -> S :nodes: S -> S
""" """, allow_dict=True)
_fmt = PStruct('ask_object_list',
PNumber('partition'), AddTransaction = notify("""
PNumber('length'), Send metadata of a transaction to a node that does not have them.
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
PDict('object_dict', # already known objects
PTID('serial'),
PFOidList,
),
)
_answer = PStruct('answer_object_list',
PTID('pack_tid'),
PTID('next_tid'),
POID('next_oid'),
PDict('object_dict', # objects to delete
PTID('serial'),
PFOidList,
),
)
class AddObject(Packet):
"""
Send an object record to a node that do not have it.
:nodes: S -> S :nodes: S -> S
""" """, nodelay=False)
nodelay = False
_fmt = PStruct('add_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class Replicate(Packet):
"""
Notify a storage node to replicate partitions up to given 'tid'
and from given sources.
- upstream_name: replicate from an upstream cluster AddObject = notify("""
- address: address of the source storage node, or None if there's no new Send an object record to a node that does not have it.
data up to 'tid' for the given partition
:nodes: M -> S :nodes: S -> S
""" """, nodelay=False, data_path=(0, 2))
_fmt = PStruct('replicate',
PTID('tid'),
PString('upstream_name'),
PDict('source_dict',
PNumber('partition'),
PAddress('address'),
)
)
class ReplicationDone(Packet):
"""
Notify the master node that a partition has been successfully replicated
from a storage to another.
:nodes: S -> M
"""
_fmt = PStruct('notify_replication_done',
PNumber('offset'),
PTID('tid'),
)
class Truncate(Packet): Truncate = request("""
"""
Request DB to be truncated. Also used to leave backup mode. Request DB to be truncated. Also used to leave backup mode.
:nodes: ctl -> A -> M; M -> S :nodes: ctl -> A -> M; M -> S
""" """, error=True)
_fmt = PStruct('truncate',
PTID('tid'),
)
_answer = Error
class FlushLog(Packet): FlushLog = notify("""
"""
Request all nodes to flush their logs. Request all nodes to flush their logs.
:nodes: ctl -> A -> M -> * :nodes: ctl -> A -> M -> *
""" """)
_next_code = 0 del notify, request
def register(request, ignore_when_closed=None):
""" Register a packet in the packet registry """
global _next_code
code = _next_code
assert code < RESPONSE_MASK
_next_code = code + 1
if request is Error:
code |= RESPONSE_MASK
# register the request
request._code = code
answer = request._answer
if ignore_when_closed is None:
# By default, on a closed connection:
# - request: ignore
# - answer: keep
# - notification: keep
ignore_when_closed = answer is not None
request._ignore_when_closed = ignore_when_closed
if answer in (Error, None):
return request
# build a class for the answer
answer = type('Answer' + request.__name__, (Packet, ), {})
answer._fmt = request._answer
answer.poll_thread = request.poll_thread
answer._request = request
assert answer._code is None, "Answer of %s is already used" % (request, )
answer._code = code | RESPONSE_MASK
request._answer = answer
return request, answer
class Packets(dict):
"""
Packet registry that checks packet code uniqueness and provides an index
"""
def __metaclass__(name, base, d):
# this builds a "singleton"
cls = type('PacketRegistry', base, d)()
for k, v in d.iteritems():
if isinstance(v, type) and issubclass(v, Packet):
v.handler_method_name = k[0].lower() + k[1:]
cls[v._code] = v
return cls
Error = register(
Error)
RequestIdentification, AcceptIdentification = register(
RequestIdentification, ignore_when_closed=True)
Ping, Pong = register(
Ping)
CloseClient = register(
CloseClient)
AskPrimary, AnswerPrimary = register(
PrimaryMaster)
NotPrimaryMaster = register(
NotPrimaryMaster)
NotifyNodeInformation = register(
NotifyNodeInformation)
AskRecovery, AnswerRecovery = register(
Recovery)
AskLastIDs, AnswerLastIDs = register(
LastIDs)
AskPartitionTable, AnswerPartitionTable = register(
PartitionTable)
SendPartitionTable = register(
NotifyPartitionTable)
NotifyPartitionChanges = register(
PartitionChanges)
StartOperation = register(
StartOperation)
StopOperation = register(
StopOperation)
AskUnfinishedTransactions, AnswerUnfinishedTransactions = register(
UnfinishedTransactions)
AskLockedTransactions, AnswerLockedTransactions = register(
LockedTransactions)
AskFinalTID, AnswerFinalTID = register(
FinalTID)
ValidateTransaction = register(
ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction)
FailedVote = register(
FailedVote)
AskFinishTransaction, AnswerTransactionFinished = register(
FinishTransaction, ignore_when_closed=False)
AskLockInformation, AnswerInformationLocked = register(
LockInformation, ignore_when_closed=False)
InvalidateObjects = register(
InvalidateObjects)
NotifyUnlockInformation = register(
UnlockInformation)
AskNewOIDs, AnswerNewOIDs = register(
GenerateOIDs)
NotifyDeadlock = register(
Deadlock)
AskRebaseTransaction, AnswerRebaseTransaction = register(
RebaseTransaction)
AskRebaseObject, AnswerRebaseObject = register(
RebaseObject)
AskStoreObject, AnswerStoreObject = register(
StoreObject)
AbortTransaction = register(
AbortTransaction)
AskStoreTransaction, AnswerStoreTransaction = register(
StoreTransaction)
AskVoteTransaction, AnswerVoteTransaction = register(
VoteTransaction)
AskObject, AnswerObject = register(
GetObject)
AskTIDs, AnswerTIDs = register(
TIDList)
AskTransactionInformation, AnswerTransactionInformation = register(
TransactionInformation)
AskObjectHistory, AnswerObjectHistory = register(
ObjectHistory)
AskPartitionList, AnswerPartitionList = register(
PartitionList)
AskNodeList, AnswerNodeList = register(
NodeList)
SetNodeState = register(
SetNodeState, ignore_when_closed=False)
AddPendingNodes = register(
AddPendingNodes, ignore_when_closed=False)
TweakPartitionTable, AnswerTweakPartitionTable = register(
TweakPartitionTable)
SetNumReplicas = register(
SetNumReplicas, ignore_when_closed=False)
SetClusterState = register(
SetClusterState, ignore_when_closed=False)
Repair = register(
Repair)
NotifyRepair = register(
RepairOne)
NotifyClusterInformation = register(
ClusterInformation)
AskClusterState, AnswerClusterState = register(
ClusterState)
AskObjectUndoSerial, AnswerObjectUndoSerial = register(
ObjectUndoSerial)
AskTIDsFrom, AnswerTIDsFrom = register(
TIDListFrom)
AskPack, AnswerPack = register(
Pack, ignore_when_closed=False)
CheckReplicas = register(
CheckReplicas)
CheckPartition = register(
CheckPartition)
AskCheckTIDRange, AnswerCheckTIDRange = register(
CheckTIDRange)
AskCheckSerialRange, AnswerCheckSerialRange = register(
CheckSerialRange)
NotifyPartitionCorrupted = register(
PartitionCorrupted)
NotifyReady = register(
NotifyReady)
AskLastTransaction, AnswerLastTransaction = register(
LastTransaction)
AskCheckCurrentSerial, AnswerCheckCurrentSerial = register(
CheckCurrentSerial)
NotifyTransactionFinished = register(
NotifyTransactionFinished)
Replicate = register(
Replicate)
NotifyReplicationDone = register(
ReplicationDone)
AskFetchTransactions, AnswerFetchTransactions = register(
FetchTransactions)
AskFetchObjects, AnswerFetchObjects = register(
FetchObjects)
AddTransaction = register(
AddTransaction)
AddObject = register(
AddObject)
Truncate = register(
Truncate)
FlushLog = register(
FlushLog)
def Errors(): def Errors():
registry_dict = {} registry_dict = {}
handler_method_name_dict = {} handler_method_name_dict = {}
Error = Packets.Error
def register_error(code): def register_error(code):
return lambda self, message='': Error(code, message) return lambda self, message='': Error(code, message)
for error in ErrorCodes: for error in ErrorCodes:
...@@ -1856,19 +854,20 @@ from operator import itemgetter ...@@ -1856,19 +854,20 @@ from operator import itemgetter
def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)): def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)):
if node_list: if node_list:
node_list.sort(key=_sort_key)
node_list = [( node_list = [(
uuid_str(uuid), str(node_type), uuid_str(uuid), str(node_type),
('[%s]:%s' if ':' in addr[0] else '%s:%s') ('[%s]:%s' if ':' in addr[0] else '%s:%s')
% addr if addr else '', str(state), % addr if addr else '', str(state),
str(id_timestamp and datetime.utcfromtimestamp(id_timestamp))) str(id_timestamp and datetime.utcfromtimestamp(id_timestamp)))
for node_type, addr, uuid, state, id_timestamp in node_list] for node_type, addr, uuid, state, id_timestamp
in sorted(node_list, key=_sort_key)]
t = ''.join('%%-%us | ' % max(len(x[i]) for x in node_list) t = ''.join('%%-%us | ' % max(len(x[i]) for x in node_list)
for i in xrange(len(node_list[0]) - 1)) for i in xrange(len(node_list[0]) - 1))
return map((prefix + t + '%s').__mod__, node_list) return map((prefix + t + '%s').__mod__, node_list)
return () return ()
NotifyNodeInformation._neolog = staticmethod(lambda timestamp, node_list: Packets.NotifyNodeInformation._neolog = staticmethod(
lambda timestamp, node_list:
((timestamp,), formatNodeList(node_list, ' ! '))) ((timestamp,), formatNodeList(node_list, ' ! ')))
Error._neolog = staticmethod(lambda *args: ((), ("%s (%s)" % args,))) Packets.Error._neolog = staticmethod(lambda *args: ((), ("%s (%s)" % args,)))
...@@ -166,65 +166,6 @@ def parseMasterList(masters): ...@@ -166,65 +166,6 @@ def parseMasterList(masters):
return map(parseNodeAddress, masters.split()) return map(parseNodeAddress, masters.split())
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose if to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
TODO: For better performance, use:
- socket.recv_into (64kiB blocks)
- struct.unpack_from
- and a circular buffer of dynamic size (initial size:
twice the length passed to socket.recv_into ?)
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
self.size += len(data)
self.content.append(data)
def __len__(self):
""" Return the current buffer size """
return self.size
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
to_read = size
# select required chunks
while to_read > 0:
chunk_data = pop_chunk()
to_read -= len(chunk_data)
append_data(chunk_data)
if to_read < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:to_read], last_chunk[to_read:]
self.content.appendleft(let)
chunk_list[-1] = keep
# join all chunks (one copy)
data = ''.join(chunk_list)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
dummy_read_buffer = ReadBuffer()
dummy_read_buffer.append = lambda _: None
class cached_property(object): class cached_property(object):
""" """
A property that is only computed once per instance and then replaces itself A property that is only computed once per instance and then replaces itself
......
...@@ -585,7 +585,9 @@ class Application(BaseApplication): ...@@ -585,7 +585,9 @@ class Application(BaseApplication):
self.tm.executeQueuedEvents() self.tm.executeQueuedEvents()
def startStorage(self, node): def startStorage(self, node):
node.send(Packets.StartOperation(self.backup_tid)) # XXX: Is this boolean 'backup' field needed ?
# Maybe this can be deduced from cluster state.
node.send(Packets.StartOperation(bool(self.backup_tid)))
uuid = node.getUUID() uuid = node.getUUID()
assert uuid not in self.storage_starting_set assert uuid not in self.storage_starting_set
assert uuid not in self.storage_ready_dict assert uuid not in self.storage_ready_dict
......
...@@ -157,8 +157,30 @@ class Log(object): ...@@ -157,8 +157,30 @@ class Log(object):
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {} x = {}
try:
Unpacker = g['Unpacker']
except KeyError:
unpackb = None
else:
from msgpack import ExtraData, UnpackException
def unpackb(data):
u = Unpacker()
u.feed(data)
data = u.unpack()
if u.read_bytes(1):
raise ExtraData
return data
self.PacketMalformedError = UnpackException
self.unpackb = unpackb
if self._decode > 1: if self._decode > 1:
try:
PStruct = g['PStruct'] PStruct = g['PStruct']
except KeyError:
for p in self.Packets.itervalues():
data_path = getattr(p, 'data_path', (None,))
if p._code >> 15 == data_path[0]:
x[p._code] = data_path[1:]
else:
PBoolean = g['PBoolean'] PBoolean = g['PBoolean']
def hasData(item): def hasData(item):
items = item._items items = item._items
...@@ -215,10 +237,12 @@ class Log(object): ...@@ -215,10 +237,12 @@ class Log(object):
if body is not None: if body is not None:
log = getattr(p, '_neolog', None) log = getattr(p, '_neolog', None)
if log or self._decode: if log or self._decode:
try:
if self.unpackb:
args = self.unpackb(body)
else:
p = p() p = p()
p._id = msg_id
p._body = body p._body = body
try:
args = p.decode() args = p.decode()
except self.PacketMalformedError: except self.PacketMalformedError:
msg.append("Can't decode packet") msg.append("Can't decode packet")
......
...@@ -461,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -461,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager):
return r return r
def loadData(self, data_id): def loadData(self, data_id):
return self.query("SELECT compression, hash, value" compression, checksum, data = self.query(
" FROM data WHERE id=?", (data_id,)).fetchone() "SELECT compression, hash, value FROM data WHERE id=?",
(data_id,)).fetchone()
if checksum:
return compression, str(checksum), str(data)
return compression, checksum, data
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getReadablePartition(oid) partition = self._getReadablePartition(oid)
......
...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler): ...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler):
p = Errors.TidNotFound('%s does not exist' % dump(tid)) p = Errors.TidNotFound('%s does not exist' % dump(tid))
else: else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3], p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0]) bool(t[4]), t[0])
conn.answer(p) conn.answer(p)
def getEventQueue(self): def getEventQueue(self):
......
...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler): ...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler):
# Sending such packet does not mark the connection # Sending such packet does not mark the connection
# for writing if there's too little data in the buffer. # for writing if there's too little data in the buffer.
conn.send(Packets.AddTransaction(tid, user, conn.send(Packets.AddTransaction(tid, user,
desc, ext, packed, ttid, oid_list), msg_id) desc, ext, bool(packed), ttid, oid_list), msg_id)
# To avoid delaying several connections simultaneously, # To avoid delaying several connections simultaneously,
# and also prevent the backend from scanning different # and also prevent the backend from scanning different
# parts of the DB at the same time, we ask the # parts of the DB at the same time, we ask the
...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler): ...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid in object_list: for serial, oid in object_list:
oid_set = object_dict.get(serial) oid_set = object_dict.get(serial)
if oid_set: if oid_set:
if type(oid_set) is list: if type(oid_set) is tuple:
object_dict[serial] = oid_set = set(oid_set) object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set: if oid in oid_set:
oid_set.remove(oid) oid_set.remove(oid)
......
...@@ -71,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -71,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0] ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn) self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id) self.assertEqual(self.app.packing[1], peer_id)
...@@ -83,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -83,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0] status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status) self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -72,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -72,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.service.answerPack(conn2, False) self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack) packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id # TODO: verify packet peer id
self.assertTrue(packet.decode()[0]) self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase): ...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase):
def getFakePacket(self): def getFakePacket(self):
p = Mock({ p = Mock({
'decode': (),
'__repr__': 'Fake Packet', '__repr__': 'Fake Packet',
}) })
p._args = ()
p.handler_method_name = 'fake_method' p.handler_method_name = 'fake_method'
return p return p
...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase): ...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase):
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise PacketMalformedError
conn.mockCalledMethods = {}
def fake(c):
raise PacketMalformedError('message')
self.setFakeMethod(fake)
self.handler.dispatch(conn, packet)
self.checkClosed(conn)
# raise NotReadyError # raise NotReadyError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c): def fake(c):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import unittest import unittest
import socket import socket
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, parseNodeAddress from neo.lib.util import parseNodeAddress
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase): ...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase):
self.assertIn(parseNodeAddress('localhost'), local_address(0)) self.assertIn(parseNodeAddress('localhost'), local_address(0))
self.assertIn(parseNodeAddress('localhost:10'), local_address(10)) self.assertIn(parseNodeAddress('localhost:10'), local_address(10))
def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# no enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -1340,7 +1340,7 @@ class Test(NEOThreadedTest): ...@@ -1340,7 +1340,7 @@ class Test(NEOThreadedTest):
# Also check that the master reset the last oid to a correct value. # Also check that the master reset the last oid to a correct value.
t.begin() t.begin()
self.assertEqual(1, u64(c.root()['x']._p_oid)) self.assertEqual(1, u64(c.root()['x']._p_oid))
self.assertFalse(cluster.client.new_oid_list) self.assertFalse(cluster.client.new_oids)
self.assertEqual(2, u64(cluster.client.new_oid())) self.assertEqual(2, u64(cluster.client.new_oid()))
@with_cluster() @with_cluster()
...@@ -2106,7 +2106,7 @@ class Test(NEOThreadedTest): ...@@ -2106,7 +2106,7 @@ class Test(NEOThreadedTest):
except threading.ThreadError: except threading.ThreadError:
l[j].acquire() l[j].acquire()
threads[j-1].start() threads[j-1].start()
if x != 'StoreTransaction': if x != 'AskStoreTransaction':
try: try:
l[i].acquire() l[i].acquire()
except IndexError: except IndexError:
...@@ -2183,15 +2183,16 @@ class Test(NEOThreadedTest): ...@@ -2183,15 +2183,16 @@ class Test(NEOThreadedTest):
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes, x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 1), (1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 2, 3, 'tpc_begin', 1, 2, 4, 3, 4, ('tpc_begin', 'tpc_begin', 1, 2, 3, 'tpc_begin', 1, 2, 4, 3, 4,
'StoreTransaction', 'RebaseTransaction', 'RebaseTransaction', 'AskStoreTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AskRebaseTransaction', 'AnswerRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'), 'AnswerRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction'),
[4, 6, 2, 6]) [4, 6, 2, 6])
try: try:
x[1].remove(1) x[1].remove(1)
except ValueError: except ValueError:
pass pass
self.assertEqual(x, {0: [2, 'StoreTransaction'], 1: ['tpc_abort']}) self.assertEqual(x, {0: [2, 'AskStoreTransaction'], 1: ['tpc_abort']})
def testCascadedDeadlockAvoidanceWithOneStorage2(self): def testCascadedDeadlockAvoidanceWithOneStorage2(self):
def changes(r1, r2, r3): def changes(r1, r2, r3):
...@@ -2214,8 +2215,8 @@ class Test(NEOThreadedTest): ...@@ -2214,8 +2215,8 @@ class Test(NEOThreadedTest):
(0, 1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, (0, 1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1,
0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1), 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1),
('tpc_begin', 1, 'tpc_begin', 1, 2, 3, 'tpc_begin', ('tpc_begin', 1, 'tpc_begin', 1, 2, 3, 'tpc_begin',
2, 3, 4, 3, 4, 'StoreTransaction', 'RebaseTransaction', 2, 3, 4, 3, 4, 'AskStoreTransaction', 'AskRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'), 'AskRebaseTransaction', 'AnswerRebaseTransaction'),
[1, 7, 9, 0]) [1, 7, 9, 0])
x[0].sort(key=str) x[0].sort(key=str)
try: try:
...@@ -2224,8 +2225,8 @@ class Test(NEOThreadedTest): ...@@ -2224,8 +2225,8 @@ class Test(NEOThreadedTest):
pass pass
self.assertEqual(x, { self.assertEqual(x, {
0: [2, 3, 'AnswerRebaseTransaction', 0: [2, 3, 'AnswerRebaseTransaction',
'RebaseTransaction', 'StoreTransaction'], 'AskRebaseTransaction', 'AskStoreTransaction'],
1: ['AnswerRebaseTransaction','RebaseTransaction', 1: ['AnswerRebaseTransaction','AskRebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'], 'AnswerRebaseTransaction', 'tpc_abort'],
}) })
...@@ -2258,7 +2259,7 @@ class Test(NEOThreadedTest): ...@@ -2258,7 +2259,7 @@ class Test(NEOThreadedTest):
end = self._testComplexDeadlockAvoidanceWithOneStorage(changes, end = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(0, 1, 1, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, vote_t2, tic_t1), (0, 1, 1, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, vote_t2, tic_t1),
('tpc_begin', 1) * 2, [3, 0, 0, 0], None) ('tpc_begin', 1) * 2, [3, 0, 0, 0], None)
self.assertLessEqual(2, end[0].count('RebaseTransaction')) self.assertLessEqual(2, end[0].count('AskRebaseTransaction'))
def testFailedConflictOnBigValueDuringDeadlockAvoidance(self): def testFailedConflictOnBigValueDuringDeadlockAvoidance(self):
def changes(r1, r2, r3): def changes(r1, r2, r3):
...@@ -2274,10 +2275,10 @@ class Test(NEOThreadedTest): ...@@ -2274,10 +2275,10 @@ class Test(NEOThreadedTest):
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes, x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 1, 2, 2, 2, 1, 2, 2, 0, 0, 1, 1, 1, 0), (1, 1, 1, 2, 2, 2, 1, 2, 2, 0, 0, 1, 1, 1, 0),
('tpc_begin', 'tpc_begin', 1, 2, 'tpc_begin', 1, 3, 3, 4, ('tpc_begin', 'tpc_begin', 1, 2, 'tpc_begin', 1, 3, 3, 4,
'StoreTransaction', 2, 4, 'RebaseTransaction', 'AskStoreTransaction', 2, 4, 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'), 'AnswerRebaseTransaction', 'tpc_abort'),
[5, 1, 0, 2], POSException.ConflictError) [5, 1, 0, 2], POSException.ConflictError)
self.assertEqual(x, {0: ['StoreTransaction']}) self.assertEqual(x, {0: ['AskStoreTransaction']})
@with_cluster(replicas=1, partitions=4) @with_cluster(replicas=1, partitions=4)
def testNotifyReplicated(self, cluster): def testNotifyReplicated(self, cluster):
...@@ -2364,7 +2365,7 @@ class Test(NEOThreadedTest): ...@@ -2364,7 +2365,7 @@ class Test(NEOThreadedTest):
def delayConflict(conn, packet): def delayConflict(conn, packet):
app = self.getConnectionApp(conn) app = self.getConnectionApp(conn)
if (isinstance(packet, Packets.AnswerStoreObject) if (isinstance(packet, Packets.AnswerStoreObject)
and packet.decode()[0]): and packet._args[0]):
conn, = cluster.client.getConnectionList(app) conn, = cluster.client.getConnectionList(app)
kw = conn._handlers._pending[0][0][packet._id][1] kw = conn._handlers._pending[0][0][packet._id][1]
return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop() return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop()
...@@ -2382,8 +2383,9 @@ class Test(NEOThreadedTest): ...@@ -2382,8 +2383,9 @@ class Test(NEOThreadedTest):
self.thread_switcher(threads, self.thread_switcher(threads,
(1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0, (1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0,
t1_rebase, 2, t3_b, 3, t4_d, 0, 2, 2), t1_rebase, 2, t3_b, 3, t4_d, 0, 2, 2),
('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin', 2, 1, 1, ('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin',
3, 3, 4, 4, 3, 1, 'RebaseTransaction', 'RebaseTransaction', 2, 1, 1, 3, 3, 4, 4, 3, 1,
'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2
)) as end: )) as end:
delay = f.delayAskFetchTransactions() delay = f.delayAskFetchTransactions()
...@@ -2395,11 +2397,11 @@ class Test(NEOThreadedTest): ...@@ -2395,11 +2397,11 @@ class Test(NEOThreadedTest):
t4.begin() t4.begin()
self.assertEqual([15, 11, 13, 16], [r[x].value for x in 'abcd']) self.assertEqual([15, 11, 13, 16], [r[x].value for x in 'abcd'])
self.assertEqual([2, 2], map(end.pop(2).count, self.assertEqual([2, 2], map(end.pop(2).count,
['RebaseTransaction', 'AnswerRebaseTransaction'])) ['AskRebaseTransaction', 'AnswerRebaseTransaction']))
self.assertEqual(end, { self.assertEqual(end, {
0: [1, 'StoreTransaction'], 0: [1, 'AskStoreTransaction'],
1: ['StoreTransaction'], 1: ['AskStoreTransaction'],
3: [4, 'StoreTransaction'], 3: [4, 'AskStoreTransaction'],
}) })
self.assertFalse(s1.dm.getOrphanList()) self.assertFalse(s1.dm.getOrphanList())
...@@ -2435,7 +2437,8 @@ class Test(NEOThreadedTest): ...@@ -2435,7 +2437,8 @@ class Test(NEOThreadedTest):
self.thread_switcher((thread,), self.thread_switcher((thread,),
(1, 0, 1, 1, t2_b, 0, 0, 1, t2_vote, 0, 0), (1, 0, 1, 1, t2_b, 0, 0, 1, t2_vote, 0, 0),
('tpc_begin', 'tpc_begin', 1, 1, 2, 2, ('tpc_begin', 'tpc_begin', 1, 1, 2, 2,
'RebaseTransaction', 'RebaseTransaction', 'StoreTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AskStoreTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
)) as end: )) as end:
delay = f.delayAskFetchTransactions() delay = f.delayAskFetchTransactions()
...@@ -2498,9 +2501,10 @@ class Test(NEOThreadedTest): ...@@ -2498,9 +2501,10 @@ class Test(NEOThreadedTest):
with self.thread_switcher((commit23,), with self.thread_switcher((commit23,),
(1, 1, 0, 0, t1_rebase, 0, 0, 0, 1, 1, 1, 1, 0), (1, 1, 0, 0, t1_rebase, 0, 0, 0, 1, 1, 1, 1, 0),
('tpc_begin', 'tpc_begin', 0, 1, 0, ('tpc_begin', 'tpc_begin', 0, 1, 0,
'RebaseTransaction', 'RebaseTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction', 'tpc_begin', 1, 'tpc_abort')) as end: 'AskStoreTransaction', 'tpc_begin', 1, 'tpc_abort',
)) as end:
self.assertRaises(POSException.ConflictError, t1.commit) self.assertRaises(POSException.ConflictError, t1.commit)
commit23.join() commit23.join()
self.assertEqual(end, {0: ['tpc_abort']}) self.assertEqual(end, {0: ['tpc_abort']})
...@@ -2587,9 +2591,9 @@ class Test(NEOThreadedTest): ...@@ -2587,9 +2591,9 @@ class Test(NEOThreadedTest):
self.thread_switcher((commit2,), self.thread_switcher((commit2,),
(1, 1, 0, 0, t1_b, t1_resolve, 0, 0, 0, 0, 1, t2_vote, t1_end), (1, 1, 0, 0, t1_b, t1_resolve, 0, 0, 0, 0, 1, t2_vote, t1_end),
('tpc_begin', 'tpc_begin', 2, 1, 2, 1, 1, ('tpc_begin', 'tpc_begin', 2, 1, 2, 1, 1,
'RebaseTransaction', 'RebaseTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction')) as end: 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
commit2.join() commit2.join()
t1.begin() t1.begin()
...@@ -2597,7 +2601,7 @@ class Test(NEOThreadedTest): ...@@ -2597,7 +2601,7 @@ class Test(NEOThreadedTest):
self.assertEqual(r['a'].value, 9) self.assertEqual(r['a'].value, 9)
self.assertEqual(r['b'].value, 6) self.assertEqual(r['b'].value, 6)
t1 = end.pop(0) t1 = end.pop(0)
self.assertEqual(t1.pop(), 'StoreTransaction') self.assertEqual(t1.pop(), 'AskStoreTransaction')
self.assertEqual(sorted(t1), [1, 2]) self.assertEqual(sorted(t1), [1, 2])
self.assertFalse(end) self.assertFalse(end)
self.assertPartitionTable(cluster, 'UU|UU') self.assertPartitionTable(cluster, 'UU|UU')
...@@ -2699,9 +2703,9 @@ class Test(NEOThreadedTest): ...@@ -2699,9 +2703,9 @@ class Test(NEOThreadedTest):
with Patch(cluster.client, _loadFromStorage=load) as p, \ with Patch(cluster.client, _loadFromStorage=load) as p, \
self.thread_switcher((commit2,), self.thread_switcher((commit2,),
(1, 0, tic1, 0, t1_resolve, 1, t2_begin, 0, 1, 1, 0), (1, 0, tic1, 0, t1_resolve, 1, t2_begin, 0, 1, 1, 0),
('tpc_begin', 'tpc_begin', 1, 1, 1, 'StoreTransaction', ('tpc_begin', 'tpc_begin', 1, 1, 1, 'AskStoreTransaction',
'tpc_begin', 'RebaseTransaction', 'RebaseTransaction', 1, 'tpc_begin', 'AskRebaseTransaction', 'AskRebaseTransaction',
'StoreTransaction')) as end: 1, 'AskStoreTransaction')) as end:
self.assertRaisesRegexp(NEOStorageError, self.assertRaisesRegexp(NEOStorageError,
'^partition 0 not fully write-locked$', '^partition 0 not fully write-locked$',
t1.commit) t1.commit)
...@@ -2754,13 +2758,14 @@ class Test(NEOThreadedTest): ...@@ -2754,13 +2758,14 @@ class Test(NEOThreadedTest):
f.remove(delayFinish) f.remove(delayFinish)
with self.thread_switcher((commit2,), with self.thread_switcher((commit2,),
(1, 0, 0, 1, t2_b, 0, t1_resolve), (1, 0, 0, 1, t2_b, 0, t1_resolve),
('tpc_begin', 'tpc_begin', 0, 2, 2, 'StoreTransaction')) as end: ('tpc_begin', 'tpc_begin', 0, 2, 2, 'AskStoreTransaction')
) as end:
t1.commit() t1.commit()
commit2.join() commit2.join()
t1.begin() t1.begin()
self.assertEqual(c1.root()['b'].value, 6) self.assertEqual(c1.root()['b'].value, 6)
self.assertPartitionTable(cluster, 'UU|UU') self.assertPartitionTable(cluster, 'UU|UU')
self.assertEqual(end, {0: [2, 2, 'StoreTransaction']}) self.assertEqual(end, {0: [2, 2, 'AskStoreTransaction']})
self.assertFalse(s1.dm.getOrphanList()) self.assertFalse(s1.dm.getOrphanList())
@with_cluster(storage_count=2, partitions=2) @with_cluster(storage_count=2, partitions=2)
...@@ -2783,19 +2788,19 @@ class Test(NEOThreadedTest): ...@@ -2783,19 +2788,19 @@ class Test(NEOThreadedTest):
yield 1 yield 1
self.tic() self.tic()
with self.thread_switcher((t,), (1, 0, 1, 0, t1_b, 0, 0, 0, 1), with self.thread_switcher((t,), (1, 0, 1, 0, t1_b, 0, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 3, 3, 1, 'RebaseTransaction', ('tpc_begin', 'tpc_begin', 1, 3, 3, 1, 'AskRebaseTransaction',
2, 'AnswerRebaseTransaction')) as end: 2, 'AnswerRebaseTransaction')) as end:
t1.commit() t1.commit()
t.join() t.join()
t2.begin() t2.begin()
self.assertEqual([6, 9, 6], [r[x].value for x in 'abc']) self.assertEqual([6, 9, 6], [r[x].value for x in 'abc'])
self.assertEqual([2, 2], map(end.pop(1).count, self.assertEqual([2, 2], map(end.pop(1).count,
['RebaseTransaction', 'AnswerRebaseTransaction'])) ['AskRebaseTransaction', 'AnswerRebaseTransaction']))
# Rarely, there's an extra deadlock for t1: # Rarely, there's an extra deadlock for t1:
# 0: ['AnswerRebaseTransaction', 'RebaseTransaction', # 0: ['AnswerRebaseTransaction', 'AskRebaseTransaction',
# 'RebaseTransaction', 'AnswerRebaseTransaction', # 'AskRebaseTransaction', 'AnswerRebaseTransaction',
# 'AnswerRebaseTransaction', 2, 3, 1, # 'AnswerRebaseTransaction', 2, 3, 1,
# 'StoreTransaction', 'VoteTransaction'] # 'AskStoreTransaction', 'VoteTransaction']
self.assertEqual(end.pop(0)[0], 'AnswerRebaseTransaction') self.assertEqual(end.pop(0)[0], 'AnswerRebaseTransaction')
self.assertFalse(end) self.assertFalse(end)
...@@ -2825,13 +2830,13 @@ class Test(NEOThreadedTest): ...@@ -2825,13 +2830,13 @@ class Test(NEOThreadedTest):
threads = map(self.newPausedThread, (t2.commit, t3.commit)) threads = map(self.newPausedThread, (t2.commit, t3.commit))
with self.thread_switcher(threads, (1, 2, 0, 1, 2, 1, 0, 2, 0, 1, 2), with self.thread_switcher(threads, (1, 2, 0, 1, 2, 1, 0, 2, 0, 1, 2),
('tpc_begin', 'tpc_begin', 'tpc_begin', 1, 2, 3, 4, 4, 4, ('tpc_begin', 'tpc_begin', 'tpc_begin', 1, 2, 3, 4, 4, 4,
'RebaseTransaction', 'StoreTransaction')) as end: 'AskRebaseTransaction', 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
for t in threads: for t in threads:
t.join() t.join()
self.assertEqual(end, { self.assertEqual(end, {
0: ['AnswerRebaseTransaction', 'StoreTransaction'], 0: ['AnswerRebaseTransaction', 'AskStoreTransaction'],
2: ['StoreTransaction']}) 2: ['AskStoreTransaction']})
@with_cluster(replicas=1) @with_cluster(replicas=1)
def testConflictAfterDeadlockWithSlowReplica1(self, cluster, def testConflictAfterDeadlockWithSlowReplica1(self, cluster,
...@@ -2874,16 +2879,16 @@ class Test(NEOThreadedTest): ...@@ -2874,16 +2879,16 @@ class Test(NEOThreadedTest):
order[-1] = t1_resolve order[-1] = t1_resolve
delay = f.delayAskStoreObject() delay = f.delayAskStoreObject()
with self.thread_switcher((t,), order, with self.thread_switcher((t,), order,
('tpc_begin', 'tpc_begin', 1, 1, 2, 2, 'RebaseTransaction', ('tpc_begin', 'tpc_begin', 1, 1, 2, 2, 'AskRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction', 'AskRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction')) as end: 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
t.join() t.join()
self.assertNotIn(delay, f) self.assertNotIn(delay, f)
t2.begin() t2.begin()
end[0].sort(key=str) end[0].sort(key=str)
self.assertEqual(end, {0: [1, 'AnswerRebaseTransaction', self.assertEqual(end, {0: [1, 'AnswerRebaseTransaction',
'StoreTransaction']}) 'AskStoreTransaction']})
self.assertEqual([4, 2], [r[x].value for x in 'ab']) self.assertEqual([4, 2], [r[x].value for x in 'ab'])
def testConflictAfterDeadlockWithSlowReplica2(self): def testConflictAfterDeadlockWithSlowReplica2(self):
...@@ -2934,7 +2939,7 @@ class Test(NEOThreadedTest): ...@@ -2934,7 +2939,7 @@ class Test(NEOThreadedTest):
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.add(lambda conn, packet:
isinstance(packet, Packets.RequestIdentification) isinstance(packet, Packets.RequestIdentification)
and packet.decode()[0] == NodeTypes.STORAGE) and packet._args[0] == NodeTypes.STORAGE)
self.tic() self.tic()
m2.start() m2.start()
self.tic() self.tic()
...@@ -2974,7 +2979,7 @@ class Test(NEOThreadedTest): ...@@ -2974,7 +2979,7 @@ class Test(NEOThreadedTest):
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.add(lambda conn, packet:
isinstance(packet, Packets.RequestIdentification) isinstance(packet, Packets.RequestIdentification)
and packet.decode()[0] == NodeTypes.MASTER) and packet._args[0] == NodeTypes.MASTER)
cluster.start(recovering=True) cluster.start(recovering=True)
neoctl = cluster.neoctl neoctl = cluster.neoctl
getClusterState = neoctl.getClusterState getClusterState = neoctl.getClusterState
......
...@@ -113,7 +113,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -113,7 +113,7 @@ class ReplicationTests(NEOThreadedTest):
importZODB(3) importZODB(3)
def delaySecondary(conn, packet): def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate): if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode() tid, upstream_name, source_dict = packet._args
return not upstream_name and all(source_dict.itervalues()) return not upstream_name and all(source_dict.itervalues())
# U -> B propagation # U -> B propagation
with NEOCluster(partitions=np, replicas=nr-1, storage_count=5, with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
...@@ -513,7 +513,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -513,7 +513,7 @@ class ReplicationTests(NEOThreadedTest):
""" """
def delayAskFetch(conn, packet): def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and \ return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \ packet._args[0] == offset and \
conn in s1.getConnectionList(s0) conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, num_replicas, cell_list): def changePartitionTable(orig, ptid, num_replicas, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
...@@ -768,7 +768,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -768,7 +768,7 @@ class ReplicationTests(NEOThreadedTest):
def logReplication(conn, packet): def logReplication(conn, packet):
if isinstance(packet, (Packets.AskFetchTransactions, if isinstance(packet, (Packets.AskFetchTransactions,
Packets.AskFetchObjects)): Packets.AskFetchObjects)):
ask.append(packet.decode()[2:]) ask.append(packet._args[2:])
def getTIDList(): def getTIDList():
return [t.tid for t in c.db().storage.iterator()] return [t.tid for t in c.db().storage.iterator()]
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
...@@ -869,7 +869,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -869,7 +869,7 @@ class ReplicationTests(NEOThreadedTest):
return True return True
elif not isinstance(packet, Packets.AskFetchTransactions): elif not isinstance(packet, Packets.AskFetchTransactions):
return return
ask.append(packet.decode()) ask.append(packet._args)
conn, = upstream.master.getConnectionList(backup.master) conn, = upstream.master.getConnectionList(backup.master)
with ConnectionFilter() as f, Patch(replicator.Replicator, with ConnectionFilter() as f, Patch(replicator.Replicator,
_nextPartitionSortKey=lambda orig, self, offset: offset): _nextPartitionSortKey=lambda orig, self, offset: offset):
...@@ -930,11 +930,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -930,11 +930,11 @@ class ReplicationTests(NEOThreadedTest):
@f.add @f.add
def delayReplicate(conn, packet): def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchTransactions): if isinstance(packet, Packets.AskFetchTransactions):
trans.append(packet.decode()[2]) trans.append(packet._args[2])
elif isinstance(packet, Packets.AskFetchObjects): elif isinstance(packet, Packets.AskFetchObjects):
if obj: if obj:
return True return True
obj.append(packet.decode()[2]) obj.append(packet._args[2])
s2.start() s2.start()
self.tic() self.tic()
cluster.neoctl.enableStorageList([s2.uuid]) cluster.neoctl.enableStorageList([s2.uuid])
...@@ -1021,7 +1021,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -1021,7 +1021,7 @@ class ReplicationTests(NEOThreadedTest):
def expected(changed): def expected(changed):
s0 = 1, CellStates.UP_TO_DATE s0 = 1, CellStates.UP_TO_DATE
s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE
return changed, 3 * [[s0, (2, s)], [s0, (3, s)]] return changed, 3 * ((s0, (2, s)), (s0, (3, s)))
for dry_run in True, False: for dry_run in True, False:
self.assertEqual(expected(True), self.assertEqual(expected(True),
cluster.neoctl.tweakPartitionTable(drop_list, dry_run)) cluster.neoctl.tweakPartitionTable(drop_list, dry_run))
......
...@@ -53,7 +53,7 @@ extras_require = { ...@@ -53,7 +53,7 @@ extras_require = {
'master': [], 'master': [],
'storage-sqlite': [], 'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'], 'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'], 'storage-importer': zodb_require + ['setproctitle'],
} }
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2', extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)] 'neoppod[%s]' % ', '.join(extras_require)]
...@@ -109,6 +109,7 @@ setup( ...@@ -109,6 +109,7 @@ setup(
], ],
}, },
install_requires = [ install_requires = [
'msgpack>=0.5.6',
'python-dateutil', # neolog --from 'python-dateutil', # neolog --from
], ],
extras_require = extras_require, extras_require = extras_require,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment