Commit 76214fbb authored by Vincent Pelletier's avatar Vincent Pelletier

Shorten 2PC lock to cover only tpc_finish.

This allows parallel execution of tpc_begin, stores & related conflict
resolution and tpc_vote for different transactions.
This requires an extension to ZODB allowing to keep TID secret until
tpc_finish (ie, so that it doesn't require tpc_vote to return tid for each
stored object).

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2534 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 76a6eef2
...@@ -882,8 +882,6 @@ class Application(object): ...@@ -882,8 +882,6 @@ class Application(object):
raise NEOStorageError('tpc_store failed') raise NEOStorageError('tpc_store failed')
elif oid in resolved_oid_set: elif oid in resolved_oid_set:
append((oid, ResolvedSerial)) append((oid, ResolvedSerial))
else:
append((oid, tid))
return result return result
@profiler_decorator @profiler_decorator
...@@ -975,16 +973,17 @@ class Application(object): ...@@ -975,16 +973,17 @@ class Application(object):
self.tpc_vote(transaction, tryToResolveConflict) self.tpc_vote(transaction, tryToResolveConflict)
self._load_lock_acquire() self._load_lock_acquire()
try: try:
# Call finish on master
oid_list = local_var.data_list
p = Packets.AskFinishTransaction(local_var.tid, oid_list)
self._askPrimary(p)
# From now on, self.local_var.tid holds the "real" TID.
tid = local_var.tid tid = local_var.tid
# Call function given by ZODB # Call function given by ZODB
if f is not None: if f is not None:
f(tid) f(tid)
# Call finish on master
oid_list = local_var.data_list
p = Packets.AskFinishTransaction(tid, oid_list)
self._askPrimary(p)
# Update cache # Update cache
self._cache_lock_acquire() self._cache_lock_acquire()
try: try:
...@@ -1280,26 +1279,20 @@ class Application(object): ...@@ -1280,26 +1279,20 @@ class Application(object):
@profiler_decorator @profiler_decorator
def importFrom(self, source, start, stop, tryToResolveConflict): def importFrom(self, source, start, stop, tryToResolveConflict):
serials = {} serials = {}
def updateLastSerial(oid, result):
if result:
if isinstance(result, str):
assert oid is not None
serials[oid] = result
else:
for oid, serial in result:
assert isinstance(serial, str), serial
serials[oid] = serial
transaction_iter = source.iterator(start, stop) transaction_iter = source.iterator(start, stop)
for transaction in transaction_iter: for transaction in transaction_iter:
self.tpc_begin(transaction, transaction.tid, transaction.status) tid = transaction.tid
self.tpc_begin(transaction, tid, transaction.status)
for r in transaction: for r in transaction:
pre = serials.get(r.oid, None) oid = r.oid
pre = serials.get(oid, None)
# TODO: bypass conflict resolution, locks... # TODO: bypass conflict resolution, locks...
result = self.store(r.oid, pre, r.data, r.version, transaction) self.store(oid, pre, r.data, r.version, transaction)
updateLastSerial(r.oid, result) serials[oid] = tid
updateLastSerial(None, self.tpc_vote(transaction, conflicted = self.tpc_vote(transaction, tryToResolveConflict)
tryToResolveConflict)) assert not conflicted, conflicted
self.tpc_finish(transaction, tryToResolveConflict) real_tid = self.tpc_finish(transaction, tryToResolveConflict)
assert real_tid == tid, (real_tid, tid)
transaction_iter.close() transaction_iter.close()
def iterator(self, start=None, stop=None): def iterator(self, start=None, stop=None):
......
...@@ -163,9 +163,10 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -163,9 +163,10 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
self.app.new_oid_list = oid_list self.app.new_oid_list = oid_list
def answerTransactionFinished(self, conn, tid): def answerTransactionFinished(self, conn, ttid, tid):
if tid != self.app.getTID(): if ttid != self.app.getTID():
raise NEOStorageError('Wrong TID, transaction not started') raise NEOStorageError('Wrong TID, transaction not started')
self.app.setTID(tid)
def answerPack(self, conn, status): def answerPack(self, conn, status):
if not status: if not status:
......
...@@ -211,10 +211,10 @@ class EventHandler(object): ...@@ -211,10 +211,10 @@ class EventHandler(object):
def askFinishTransaction(self, conn, tid, oid_list): def askFinishTransaction(self, conn, tid, oid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerTransactionFinished(self, conn, tid): def answerTransactionFinished(self, conn, ttid, tid):
raise UnexpectedPacketError raise UnexpectedPacketError
def askLockInformation(self, conn, tid, oid_list): def askLockInformation(self, conn, ttid, tid, oid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerInformationLocked(self, conn, tid): def answerInformationLocked(self, conn, tid):
......
...@@ -61,13 +61,13 @@ class ClientServiceHandler(MasterHandler): ...@@ -61,13 +61,13 @@ class ClientServiceHandler(MasterHandler):
conn.answer(Packets.AnswerNewOIDs(app.tm.getNextOIDList(num_oids))) conn.answer(Packets.AnswerNewOIDs(app.tm.getNextOIDList(num_oids)))
app.broadcastLastOID() app.broadcastLastOID()
def askFinishTransaction(self, conn, tid, oid_list): def askFinishTransaction(self, conn, ttid, oid_list):
app = self.app app = self.app
# Collect partitions related to this transaction. # Collect partitions related to this transaction.
getPartition = app.pt.getPartition getPartition = app.pt.getPartition
partition_set = set() partition_set = set()
partition_set.add(getPartition(tid)) partition_set.add(getPartition(ttid))
partition_set.update((getPartition(oid) for oid in oid_list)) partition_set.update((getPartition(oid) for oid in oid_list))
# Collect the UUIDs of nodes related to this transaction. # Collect the UUIDs of nodes related to this transaction.
...@@ -79,6 +79,22 @@ class ClientServiceHandler(MasterHandler): ...@@ -79,6 +79,22 @@ class ClientServiceHandler(MasterHandler):
if cell.getNodeState() != NodeStates.HIDDEN) if cell.getNodeState() != NodeStates.HIDDEN)
if isStorageReady(uuid))) if isStorageReady(uuid)))
if not uuid_set:
raise ProtocolError('No storage node ready for transaction')
identified_node_list = app.nm.getIdentifiedList(pool_set=uuid_set)
usable_uuid_set = set((x.getUUID() for x in identified_node_list))
partitions = app.pt.getPartitions()
peer_id = conn.getPeerId()
node = app.nm.getByUUID(conn.getUUID())
try:
tid = app.tm.prepare(node, ttid, partitions, oid_list,
usable_uuid_set, peer_id)
except DelayedError:
app.queueEvent(self.askFinishTransaction, conn, ttid,
oid_list)
return
# check if greater and foreign OID was stored # check if greater and foreign OID was stored
if app.tm.updateLastOID(oid_list): if app.tm.updateLastOID(oid_list):
app.broadcastLastOID() app.broadcastLastOID()
...@@ -86,14 +102,9 @@ class ClientServiceHandler(MasterHandler): ...@@ -86,14 +102,9 @@ class ClientServiceHandler(MasterHandler):
# Request locking data. # Request locking data.
# build a new set as we may not send the message to all nodes as some # build a new set as we may not send the message to all nodes as some
# might be not reachable at that time # might be not reachable at that time
p = Packets.AskLockInformation(tid, oid_list) p = Packets.AskLockInformation(ttid, tid, oid_list)
used_uuid_set = set() for node in identified_node_list:
for node in app.nm.getIdentifiedList(pool_set=uuid_set):
node.ask(p, timeout=60) node.ask(p, timeout=60)
used_uuid_set.add(node.getUUID())
node = app.nm.getByUUID(conn.getUUID())
app.tm.prepare(node, tid, oid_list, used_uuid_set, conn.getPeerId())
def askPack(self, conn, tid): def askPack(self, conn, tid):
app = self.app app = self.app
......
...@@ -88,10 +88,12 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -88,10 +88,12 @@ class StorageServiceHandler(BaseServiceHandler):
app = self.app app = self.app
tm = app.tm tm = app.tm
t = tm[tid] t = tm[tid]
ttid = t.getTTID()
nm = app.nm nm = app.nm
transaction_node = t.getNode() transaction_node = t.getNode()
invalidate_objects = Packets.InvalidateObjects(tid, t.getOIDList()) invalidate_objects = Packets.InvalidateObjects(tid, t.getOIDList())
answer_transaction_finished = Packets.AnswerTransactionFinished(tid) answer_transaction_finished = Packets.AnswerTransactionFinished(ttid,
tid)
for client_node in nm.getClientList(only_identified=True): for client_node in nm.getClientList(only_identified=True):
c = client_node.getConnection() c = client_node.getConnection()
if client_node is transaction_node: if client_node is transaction_node:
...@@ -100,7 +102,7 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -100,7 +102,7 @@ class StorageServiceHandler(BaseServiceHandler):
c.notify(invalidate_objects) c.notify(invalidate_objects)
# - Unlock Information to relevant storage nodes. # - Unlock Information to relevant storage nodes.
notify_unlock = Packets.NotifyUnlockInformation(tid) notify_unlock = Packets.NotifyUnlockInformation(ttid)
for storage_uuid in t.getUUIDList(): for storage_uuid in t.getUUIDList():
nm.getByUUID(storage_uuid).getConnection().notify(notify_unlock) nm.getByUUID(storage_uuid).getConnection().notify(notify_unlock)
......
...@@ -19,7 +19,7 @@ from time import time, gmtime ...@@ -19,7 +19,7 @@ from time import time, gmtime
from struct import pack, unpack from struct import pack, unpack
from neo.protocol import ZERO_TID from neo.protocol import ZERO_TID
from datetime import timedelta, datetime from datetime import timedelta, datetime
from neo.util import dump from neo.util import dump, u64, p64
import neo import neo
TID_LOW_OVERFLOW = 2**32 TID_LOW_OVERFLOW = 2**32
...@@ -92,11 +92,12 @@ class Transaction(object): ...@@ -92,11 +92,12 @@ class Transaction(object):
A pending transaction A pending transaction
""" """
def __init__(self, node, tid, oid_list, uuid_list, msg_id): def __init__(self, node, ttid, tid, oid_list, uuid_list, msg_id):
""" """
Prepare the transaction, set OIDs and UUIDs related to it Prepare the transaction, set OIDs and UUIDs related to it
""" """
self._node = node self._node = node
self._ttid = ttid
self._tid = tid self._tid = tid
self._oid_list = oid_list self._oid_list = oid_list
self._msg_id = msg_id self._msg_id = msg_id
...@@ -122,6 +123,12 @@ class Transaction(object): ...@@ -122,6 +123,12 @@ class Transaction(object):
""" """
return self._node return self._node
def getTTID(self):
"""
Return the temporary transaction ID.
"""
return self._ttid
def getTID(self): def getTID(self):
""" """
Return the transaction ID Return the transaction ID
...@@ -184,6 +191,8 @@ class TransactionManager(object): ...@@ -184,6 +191,8 @@ class TransactionManager(object):
# We don't need to use a real lock, as we are mono-threaded. # We don't need to use a real lock, as we are mono-threaded.
_locked = None _locked = None
_next_ttid = 0
def __init__(self): def __init__(self):
# tid -> transaction # tid -> transaction
self._tid_dict = {} self._tid_dict = {}
...@@ -232,8 +241,18 @@ class TransactionManager(object): ...@@ -232,8 +241,18 @@ class TransactionManager(object):
def getLastOID(self): def getLastOID(self):
return self._last_oid return self._last_oid
def _nextTID(self): def _nextTID(self, ttid, divisor):
""" Compute the next TID based on the current time and check collisions """ """
Compute the next TID based on the current time and check collisions.
Also, adjust it so that
tid % divisor == ttid % divisor
while preserving
min_tid < tid
When constraints allow, prefer decreasing generated TID, to avoid
fast-forwarding to future dates.
"""
assert isinstance(ttid, basestring), repr(ttid)
assert isinstance(divisor, (int, long)), repr(divisor)
tm = time() tm = time()
gmt = gmtime(tm) gmt = gmtime(tm)
tid = packTID(( tid = packTID((
...@@ -241,8 +260,28 @@ class TransactionManager(object): ...@@ -241,8 +260,28 @@ class TransactionManager(object):
gmt.tm_min), gmt.tm_min),
int((gmt.tm_sec % 60 + (tm - int(tm))) / SECOND_PER_TID_LOW) int((gmt.tm_sec % 60 + (tm - int(tm))) / SECOND_PER_TID_LOW)
)) ))
if tid <= self._last_tid: min_tid = self._last_tid
tid = addTID(self._last_tid, 1) if tid <= min_tid:
tid = addTID(min_tid, 1)
# We know we won't have room to adjust by decreasing.
try_decrease = False
else:
try_decrease = True
ref_remainder = u64(ttid) % divisor
remainder = u64(tid) % divisor
if ref_remainder != remainder:
if try_decrease:
new_tid = addTID(tid, ref_remainder - divisor - remainder)
assert u64(new_tid) % divisor == ref_remainder, (dump(new_tid),
ref_remainder)
if new_tid <= min_tid:
new_tid = addTID(new_tid, divisor)
else:
if ref_remainder > remainder:
ref_remainder += divisor
new_tid = addTID(tid, ref_remainder - remainder)
assert min_tid < new_tid, (dump(min_tid), dump(tid), dump(new_tid))
tid = new_tid
self._last_tid = tid self._last_tid = tid
return self._last_tid return self._last_tid
...@@ -258,6 +297,14 @@ class TransactionManager(object): ...@@ -258,6 +297,14 @@ class TransactionManager(object):
""" """
self._last_tid = max(self._last_tid, tid) self._last_tid = max(self._last_tid, tid)
def getTTID(self):
"""
Generate a temporary TID, to be used only during a single node's
2PC.
"""
self._next_ttid += 1
return p64(self._next_ttid)
def reset(self): def reset(self):
""" """
Discard all manager content Discard all manager content
...@@ -282,28 +329,47 @@ class TransactionManager(object): ...@@ -282,28 +329,47 @@ class TransactionManager(object):
""" """
Generate a new TID Generate a new TID
""" """
if self._locked is not None:
raise DelayedError()
if tid is None: if tid is None:
tid = self._nextTID() # No TID requested, generate a temporary one
self._locked = tid tid = self.getTTID()
else:
# TID requested, take commit lock immediately
if self._locked is not None:
raise DelayedError()
self._locked = tid
return tid return tid
def prepare(self, node, tid, oid_list, uuid_list, msg_id): def prepare(self, node, ttid, divisor, oid_list, uuid_list, msg_id):
""" """
Prepare a transaction to be finished Prepare a transaction to be finished
""" """
locked = self._locked
if locked == ttid:
# Transaction requested some TID upon begin, and it owns the commit
# lock since then.
tid = ttid
else:
# Otherwise, acquire lock and allocate a new TID.
if locked is not None:
raise DelayedError()
tid = self._nextTID(ttid, divisor)
self._locked = tid
self.setLastTID(tid) self.setLastTID(tid)
txn = Transaction(node, tid, oid_list, uuid_list, msg_id) txn = Transaction(node, ttid, tid, oid_list, uuid_list, msg_id)
self._tid_dict[tid] = txn self._tid_dict[tid] = txn
self._node_dict.setdefault(node, {})[tid] = txn self._node_dict.setdefault(node, {})[tid] = txn
return tid
def remove(self, tid): def remove(self, tid):
""" """
Remove a transaction, commited or aborted Remove a transaction, commited or aborted
""" """
assert self._locked == tid, (self._locked, tid) if tid == self._locked:
self._locked = None # If TID has the lock, release it.
# It might legitimately not have the lock (ex: a transaction
# aborting, which didn't request a TID upon begin)
self._locked = None
tid_dict = self._tid_dict tid_dict = self._tid_dict
if tid in tid_dict: if tid in tid_dict:
# ...and tried to finish # ...and tried to finish
......
...@@ -789,30 +789,30 @@ class AnswerTransactionFinished(Packet): ...@@ -789,30 +789,30 @@ class AnswerTransactionFinished(Packet):
""" """
Answer when a transaction is finished. PM -> C. Answer when a transaction is finished. PM -> C.
""" """
def _encode(self, tid): def _encode(self, ttid, tid):
return _encodeTID(tid) return _encodeTID(ttid) + _encodeTID(tid)
def _decode(self, body): def _decode(self, body):
(tid, ) = unpack('8s', body) (ttid, tid) = unpack('8s8s', body)
return (_decodeTID(tid), ) return (_decodeTID(ttid), _decodeTID(tid))
class AskLockInformation(Packet): class AskLockInformation(Packet):
""" """
Lock information on a transaction. PM -> S. Lock information on a transaction. PM -> S.
""" """
# XXX: Identical to InvalidateObjects and AskFinishTransaction # XXX: Identical to InvalidateObjects and AskFinishTransaction
_header_format = '!8sL' _header_format = '!8s8sL'
_list_entry_format = '8s' _list_entry_format = '8s'
_list_entry_len = calcsize(_list_entry_format) _list_entry_len = calcsize(_list_entry_format)
def _encode(self, tid, oid_list): def _encode(self, ttid, tid, oid_list):
body = [pack(self._header_format, tid, len(oid_list))] body = [pack(self._header_format, ttid, tid, len(oid_list))]
body.extend(oid_list) body.extend(oid_list)
return ''.join(body) return ''.join(body)
def _decode(self, body): def _decode(self, body):
offset = self._header_len offset = self._header_len
(tid, n) = unpack(self._header_format, body[:offset]) (ttid, tid, n) = unpack(self._header_format, body[:offset])
oid_list = [] oid_list = []
list_entry_format = self._list_entry_format list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len list_entry_len = self._list_entry_len
...@@ -821,7 +821,7 @@ class AskLockInformation(Packet): ...@@ -821,7 +821,7 @@ class AskLockInformation(Packet):
oid = unpack(list_entry_format, body[offset:next_offset])[0] oid = unpack(list_entry_format, body[offset:next_offset])[0]
offset = next_offset offset = next_offset
oid_list.append(oid) oid_list.append(oid)
return (tid, oid_list) return (ttid, tid, oid_list)
class AnswerInformationLocked(Packet): class AnswerInformationLocked(Packet):
""" """
......
...@@ -53,10 +53,10 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -53,10 +53,10 @@ class MasterOperationHandler(BaseMasterHandler):
elif state == CellStates.OUT_OF_DATE: elif state == CellStates.OUT_OF_DATE:
app.replicator.addPartition(offset) app.replicator.addPartition(offset)
def askLockInformation(self, conn, tid, oid_list): def askLockInformation(self, conn, ttid, tid, oid_list):
if not tid in self.app.tm: if not ttid in self.app.tm:
raise ProtocolError('Unknown transaction') raise ProtocolError('Unknown transaction')
self.app.tm.lock(tid, oid_list) self.app.tm.lock(ttid, tid, oid_list)
if not conn.isClosed(): if not conn.isClosed():
conn.answer(Packets.AnswerInformationLocked(tid)) conn.answer(Packets.AnswerInformationLocked(tid))
......
...@@ -44,10 +44,11 @@ class Transaction(object): ...@@ -44,10 +44,11 @@ class Transaction(object):
""" """
Container for a pending transaction Container for a pending transaction
""" """
_tid = None
def __init__(self, uuid, tid): def __init__(self, uuid, ttid):
self._uuid = uuid self._uuid = uuid
self._tid = tid self._ttid = ttid
self._object_dict = {} self._object_dict = {}
self._transaction = None self._transaction = None
self._locked = False self._locked = False
...@@ -55,8 +56,9 @@ class Transaction(object): ...@@ -55,8 +56,9 @@ class Transaction(object):
self._checked_set = set() self._checked_set = set()
def __repr__(self): def __repr__(self):
return "<%s(tid=%r, uuid=%r, locked=%r, age=%.2fs)> at %x" % ( return "<%s(ttid=%r, tid=%r, uuid=%r, locked=%r, age=%.2fs)> at %x" % (
self.__class__.__name__, self.__class__.__name__,
dump(self._ttid),
dump(self._tid), dump(self._tid),
dump(self._uuid), dump(self._uuid),
self.isLocked(), self.isLocked(),
...@@ -67,6 +69,14 @@ class Transaction(object): ...@@ -67,6 +69,14 @@ class Transaction(object):
def addCheckedObject(self, oid): def addCheckedObject(self, oid):
self._checked_set.add(oid) self._checked_set.add(oid)
def getTTID(self):
return self._ttid
def setTID(self, tid):
assert self._tid is None, dump(self._tid)
assert tid is not None
self._tid = tid
def getTID(self): def getTID(self):
return self._tid return self._tid
...@@ -158,20 +168,21 @@ class TransactionManager(object): ...@@ -158,20 +168,21 @@ class TransactionManager(object):
self._load_lock_dict.clear() self._load_lock_dict.clear()
self._uuid_dict.clear() self._uuid_dict.clear()
def lock(self, tid, oid_list): def lock(self, ttid, tid, oid_list):
""" """
Lock a transaction Lock a transaction
""" """
transaction = self._transaction_dict[tid] transaction = self._transaction_dict[ttid]
# remember that the transaction has been locked # remember that the transaction has been locked
transaction.lock() transaction.lock()
for oid in transaction.getOIDList(): for oid in transaction.getOIDList():
self._load_lock_dict[oid] = tid self._load_lock_dict[oid] = ttid
# check every object that should be locked # check every object that should be locked
uuid = transaction.getUUID() uuid = transaction.getUUID()
is_assigned = self._app.pt.isAssigned is_assigned = self._app.pt.isAssigned
for oid in oid_list: for oid in oid_list:
if is_assigned(oid, uuid) and self._load_lock_dict.get(oid) != tid: if is_assigned(oid, uuid) and \
self._load_lock_dict.get(oid) != ttid:
raise ValueError, 'Some locks are not held' raise ValueError, 'Some locks are not held'
object_list = transaction.getObjectList() object_list = transaction.getObjectList()
# txn_info is None is the transaction information is not stored on # txn_info is None is the transaction information is not stored on
...@@ -179,13 +190,18 @@ class TransactionManager(object): ...@@ -179,13 +190,18 @@ class TransactionManager(object):
txn_info = transaction.getTransactionInformations() txn_info = transaction.getTransactionInformations()
# store data from memory to temporary table # store data from memory to temporary table
self._app.dm.storeTransaction(tid, object_list, txn_info) self._app.dm.storeTransaction(tid, object_list, txn_info)
# ...and remember its definitive TID
transaction.setTID(tid)
def getTIDFromTTID(self, ttid):
return self._transaction_dict[ttid].getTID()
def unlock(self, tid): def unlock(self, ttid):
""" """
Unlock transaction Unlock transaction
""" """
self._app.dm.finishTransaction(tid) self._app.dm.finishTransaction(self.getTIDFromTTID(ttid))
self.abort(tid, even_if_locked=True) self.abort(ttid, even_if_locked=True)
def storeTransaction(self, tid, oid_list, user, desc, ext, packed): def storeTransaction(self, tid, oid_list, user, desc, ext, packed):
""" """
...@@ -283,8 +299,8 @@ class TransactionManager(object): ...@@ -283,8 +299,8 @@ class TransactionManager(object):
Abort any non-locked transaction of a node Abort any non-locked transaction of a node
""" """
# abort any non-locked transaction of this node # abort any non-locked transaction of this node
for tid in [x.getTID() for x in self._uuid_dict.get(uuid, [])]: for ttid in [x.getTTID() for x in self._uuid_dict.get(uuid, [])]:
self.abort(tid) self.abort(ttid)
# cleanup _uuid_dict if no transaction remains for this node # cleanup _uuid_dict if no transaction remains for this node
transaction_set = self._uuid_dict.get(uuid) transaction_set = self._uuid_dict.get(uuid)
if transaction_set is not None and not transaction_set: if transaction_set is not None and not transaction_set:
......
...@@ -705,7 +705,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -705,7 +705,7 @@ class ClientApplicationTests(NeoUnitTestBase):
def hook(tid): def hook(tid):
self.f_called = True self.f_called = True
self.f_called_with_tid = tid self.f_called_with_tid = tid
packet = Packets.AnswerTransactionFinished(INVALID_TID) packet = Packets.AnswerTransactionFinished(INVALID_TID, INVALID_TID)
packet.setId(0) packet.setId(0)
app.master_conn = Mock({ app.master_conn = Mock({
'getNextId': 1, 'getNextId': 1,
...@@ -722,8 +722,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -722,8 +722,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.local_var.txn_voted = True app.local_var.txn_voted = True
self.assertRaises(NEOStorageError, app.tpc_finish, txn, self.assertRaises(NEOStorageError, app.tpc_finish, txn,
dummy_tryToResolveConflict, hook) dummy_tryToResolveConflict, hook)
self.assertTrue(self.f_called) self.assertFalse(self.f_called)
self.assertEquals(self.f_called_with_tid, tid)
self.assertEqual(self.vote_params, None) self.assertEqual(self.vote_params, None)
self.checkAskFinishTransaction(app.master_conn) self.checkAskFinishTransaction(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn) self.checkDispatcherRegisterCalled(app, app.master_conn)
...@@ -732,7 +731,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -732,7 +731,7 @@ class ClientApplicationTests(NeoUnitTestBase):
self.f_called = False self.f_called = False
self.assertRaises(NEOStorageError, app.tpc_finish, txn, self.assertRaises(NEOStorageError, app.tpc_finish, txn,
dummy_tryToResolveConflict, hook) dummy_tryToResolveConflict, hook)
self.assertTrue(self.f_called) self.assertFalse(self.f_called)
self.assertTrue(self.vote_params[0] is txn) self.assertTrue(self.vote_params[0] is txn)
self.assertTrue(self.vote_params[1] is dummy_tryToResolveConflict) self.assertTrue(self.vote_params[1] is dummy_tryToResolveConflict)
app.tpc_vote = tpc_vote app.tpc_vote = tpc_vote
...@@ -741,14 +740,15 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -741,14 +740,15 @@ class ClientApplicationTests(NeoUnitTestBase):
# transaction is finished # transaction is finished
app = self.getApp() app = self.getApp()
tid = self.makeTID() tid = self.makeTID()
ttid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, tid app.local_var.txn, app.local_var.tid = txn, ttid
self.f_called = False self.f_called = False
self.f_called_with_tid = None self.f_called_with_tid = None
def hook(tid): def hook(tid):
self.f_called = True self.f_called = True
self.f_called_with_tid = tid self.f_called_with_tid = tid
packet = Packets.AnswerTransactionFinished(tid) packet = Packets.AnswerTransactionFinished(ttid, tid)
packet.setId(0) packet.setId(0)
app.master_conn = Mock({ app.master_conn = Mock({
'getNextId': 1, 'getNextId': 1,
......
...@@ -248,17 +248,18 @@ class MasterAnswersHandlerTests(MasterHandlerTests): ...@@ -248,17 +248,18 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
def test_answerTransactionFinished(self): def test_answerTransactionFinished(self):
conn = self.getConnection() conn = self.getConnection()
tid1 = self.getNextTID() ttid1 = self.getNextTID()
tid2 = self.getNextTID(tid1) ttid2 = self.getNextTID(ttid1)
tid2 = self.getNextTID(ttid2)
# wrong TID # wrong TID
self.app = Mock({'getTID': tid1}) self.app = Mock({'getTID': ttid1})
self.assertRaises(NEOStorageError, self.assertRaises(NEOStorageError,
self.handler.answerTransactionFinished, self.handler.answerTransactionFinished,
conn, tid2) conn, ttid2, tid2)
# matching TID # matching TID
app = Mock({'getTID': tid2}) app = Mock({'getTID': ttid2})
handler = PrimaryAnswersHandler(app=app) handler = PrimaryAnswersHandler(app=app)
handler.answerTransactionFinished(conn, tid2) handler.answerTransactionFinished(conn, ttid2, tid2)
def test_answerPack(self): def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False) self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
......
...@@ -65,13 +65,33 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -65,13 +65,33 @@ class MasterClientHandlerTests(NeoUnitTestBase):
# Tests # Tests
def test_07_askBeginTransaction(self): def test_07_askBeginTransaction(self):
tid1 = self.getNextTID()
tid2 = self.getNextTID()
service = self.service service = self.service
ltid = self.app.tm.getLastTID() tm_org = self.app.tm
self.app.tm = tm = Mock({
'begin': '\x00\x00\x00\x00\x00\x00\x00\x01',
})
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
service.askBeginTransaction(conn, None) service.askBeginTransaction(conn, None)
self.assertTrue(ltid < self.app.tm.getLastTID()) calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(None)
# Client asks for a TID
self.app.tm = tm_org
service.askBeginTransaction(conn, tid1)
# If asking again for a TID, call is queued
call_marker = []
def queueEvent(*args, **kw):
call_marker.append((args, kw))
self.app.queueEvent = queueEvent
service.askBeginTransaction(conn, tid2)
self.assertEqual(len(call_marker), 1)
args, kw = call_marker[0]
self.assertEqual(kw, {})
self.assertEqual(args, (service.askBeginTransaction, conn, tid2))
def test_08_askNewOIDs(self): def test_08_askNewOIDs(self):
service = self.service service = self.service
...@@ -96,11 +116,19 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -96,11 +116,19 @@ class MasterClientHandlerTests(NeoUnitTestBase):
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
storage_uuid = self.identifyToMasterNode() storage_uuid = self.identifyToMasterNode()
storage_conn = self.getFakeConnection(storage_uuid, self.storage_address) storage_conn = self.getFakeConnection(storage_uuid, self.storage_address)
storage2_uuid = self.identifyToMasterNode()
storage2_conn = self.getFakeConnection(storage2_uuid,
(self.storage_address[0], self.storage_address[1] + 1))
self.app.setStorageReady(storage2_uuid)
self.assertNotEquals(uuid, client_uuid) self.assertNotEquals(uuid, client_uuid)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
self.app.pt = Mock({ self.app.pt = Mock({
'getPartition': 0, 'getPartition': 0,
'getCellList': [Mock({'getUUID': storage_uuid})], 'getCellList': [
Mock({'getUUID': storage_uuid}),
Mock({'getUUID': storage2_uuid}),
],
'getPartitions': 2,
}) })
service.askBeginTransaction(conn, None) service.askBeginTransaction(conn, None)
oid_list = [] oid_list = []
...@@ -111,6 +139,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -111,6 +139,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.assertFalse(self.app.isStorageReady(storage_uuid)) self.assertFalse(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, tid, oid_list) service.askFinishTransaction(conn, tid, oid_list)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
self.app.tm.abortFor(self.app.nm.getByUUID(client_uuid))
# ...but AskLockInformation is sent if it is ready # ...but AskLockInformation is sent if it is ready
self.app.setStorageReady(storage_uuid) self.app.setStorageReady(storage_uuid)
self.assertTrue(self.app.isStorageReady(storage_uuid)) self.assertTrue(self.app.isStorageReady(storage_uuid))
...@@ -118,8 +147,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -118,8 +147,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.checkAskLockInformation(storage_conn) self.checkAskLockInformation(storage_conn)
self.assertEquals(len(self.app.tm.getPendingList()), 1) self.assertEquals(len(self.app.tm.getPendingList()), 1)
apptid = self.app.tm.getPendingList()[0] apptid = self.app.tm.getPendingList()[0]
self.assertEquals(tid, apptid) txn = self.app.tm[apptid]
txn = self.app.tm[tid]
self.assertEquals(len(txn.getOIDList()), 0) self.assertEquals(len(txn.getOIDList()), 0)
self.assertEquals(len(txn.getUUIDList()), 1) self.assertEquals(len(txn.getUUIDList()), 1)
......
...@@ -101,8 +101,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -101,8 +101,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
oid_list = self.getOID(), self.getOID() oid_list = self.getOID(), self.getOID()
msg_id = 1 msg_id = 1
# register a transaction # register a transaction
tid = self.app.tm.begin() ttid = self.app.tm.begin()
self.app.tm.prepare(client_1, tid, oid_list, uuid_list, msg_id) tid = self.app.tm.prepare(client_1, ttid, 1, oid_list, uuid_list,
msg_id)
self.assertTrue(tid in self.app.tm) self.assertTrue(tid in self.app.tm)
# the first storage acknowledge the lock # the first storage acknowledge the lock
self.service.answerInformationLocked(storage_conn_1, tid) self.service.answerInformationLocked(storage_conn_1, tid)
...@@ -145,18 +146,13 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -145,18 +146,13 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# create some transaction # create some transaction
node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port=self.client_port) port=self.client_port)
def create_transaction(index): self.app.tm.prepare(node, self.getNextTID(), 1,
tid = self.getNextTID() [self.getOID(1)], [node.getUUID()], 1)
oid_list = [self.getOID(index)]
self.app.tm.prepare(node, tid, oid_list, [node.getUUID()], index)
create_transaction(1)
create_transaction(2)
create_transaction(3)
conn = self.getFakeConnection(node.getUUID(), self.storage_address) conn = self.getFakeConnection(node.getUUID(), self.storage_address)
service.askUnfinishedTransactions(conn) service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn) packet = self.checkAnswerUnfinishedTransactions(conn)
(tid_list, ) = packet.decode() (tid_list, ) = packet.decode()
self.assertEqual(len(tid_list), 3) self.assertEqual(len(tid_list), 1)
def _testWithMethod(self, method, state): def _testWithMethod(self, method, state):
# define two nodes # define two nodes
...@@ -212,10 +208,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -212,10 +208,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
tid1 = tm.begin() ttid1 = tm.begin()
tm.prepare(client1, tid1, oid_list, tid1 = tm.prepare(client1, ttid1, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_1) [node1.getUUID(), node2.getUUID()], msg_id_1)
tm.lock(tid1, node2.getUUID()) tm.lock(tid1, node2.getUUID())
self.checkNoPacketSent(cconn1)
# Storage 1 dies # Storage 1 dies
node1.setTemporarilyDown() node1.setTemporarilyDown()
self.service.nodeLost(conn1, node1) self.service.nodeLost(conn1, node1)
...@@ -229,8 +226,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -229,8 +226,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# Transaction 2: 2 storage nodes involved, one will die # Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2 msg_id_2 = 2
tid2 = tm.begin() ttid2 = tm.begin()
tm.prepare(client2, tid2, oid_list, tid2 = tm.prepare(client2, ttid2, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_2) [node1.getUUID(), node2.getUUID()], msg_id_2)
# T2: pending locking answer, client keeps waiting # T2: pending locking answer, client keeps waiting
self.checkNoPacketSent(cconn2, check_notify=False) self.checkNoPacketSent(cconn2, check_notify=False)
...@@ -238,8 +235,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -238,8 +235,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# Transaction 3: 1 storage node involved, which won't die # Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3 msg_id_3 = 3
tid3 = tm.begin() ttid3 = tm.begin()
tm.prepare(client3, tid3, oid_list, tid3 = tm.prepare(client3, ttid3, 1, oid_list,
[node2.getUUID(), ], msg_id_3) [node2.getUUID(), ], msg_id_3)
# T3: action not significant to this transacion, so no response # T3: action not significant to this transacion, so no response
self.checkNoPacketSent(cconn3, check_notify=False) self.checkNoPacketSent(cconn3, check_notify=False)
......
...@@ -39,11 +39,12 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -39,11 +39,12 @@ class testTransactionManager(NeoUnitTestBase):
# test data # test data
node = Mock({'__repr__': 'Node'}) node = Mock({'__repr__': 'Node'})
tid = self.makeTID(1) tid = self.makeTID(1)
ttid = self.makeTID(2)
oid_list = (oid1, oid2) = [self.makeOID(1), self.makeOID(2)] oid_list = (oid1, oid2) = [self.makeOID(1), self.makeOID(2)]
uuid_list = (uuid1, uuid2) = [self.makeUUID(1), self.makeUUID(2)] uuid_list = (uuid1, uuid2) = [self.makeUUID(1), self.makeUUID(2)]
msg_id = 1 msg_id = 1
# create transaction object # create transaction object
txn = Transaction(node, tid, oid_list, uuid_list, msg_id) txn = Transaction(node, ttid, tid, oid_list, uuid_list, msg_id)
self.assertEqual(txn.getUUIDList(), uuid_list) self.assertEqual(txn.getUUIDList(), uuid_list)
self.assertEqual(txn.getOIDList(), oid_list) self.assertEqual(txn.getOIDList(), oid_list)
# lock nodes one by one # lock nodes one by one
...@@ -63,12 +64,12 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -63,12 +64,12 @@ class testTransactionManager(NeoUnitTestBase):
self.assertFalse(txnman.hasPending()) self.assertFalse(txnman.hasPending())
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.getPendingList(), [])
# begin the transaction # begin the transaction
tid = txnman.begin() ttid = txnman.begin()
self.assertTrue(tid is not None) self.assertTrue(ttid is not None)
self.assertFalse(txnman.hasPending()) self.assertFalse(txnman.hasPending())
self.assertEqual(len(txnman.getPendingList()), 0) self.assertEqual(len(txnman.getPendingList()), 0)
# prepare the transaction # prepare the transaction
txnman.prepare(node, tid, oid_list, uuid_list, msg_id) tid = txnman.prepare(node, ttid, 1, oid_list, uuid_list, msg_id)
self.assertTrue(txnman.hasPending()) self.assertTrue(txnman.hasPending())
self.assertEqual(txnman.getPendingList()[0], tid) self.assertEqual(txnman.getPendingList()[0], tid)
self.assertEqual(txnman[tid].getTID(), tid) self.assertEqual(txnman[tid].getTID(), tid)
...@@ -91,8 +92,8 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -91,8 +92,8 @@ class testTransactionManager(NeoUnitTestBase):
txnman = TransactionManager() txnman = TransactionManager()
# register 4 transactions made by two nodes # register 4 transactions made by two nodes
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.getPendingList(), [])
tid1 = txnman.begin() ttid1 = txnman.begin()
txnman.prepare(node1, tid1, oid_list, [storage_1_uuid], 1) tid1 = txnman.prepare(node1, ttid1, 1, oid_list, [storage_1_uuid], 1)
self.assertEqual(txnman.getPendingList(), [tid1]) self.assertEqual(txnman.getPendingList(), [tid1])
# abort transactions of another node, transaction stays # abort transactions of another node, transaction stays
txnman.abortFor(node2) txnman.abortFor(node2)
...@@ -101,8 +102,8 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -101,8 +102,8 @@ class testTransactionManager(NeoUnitTestBase):
txnman.abortFor(node1) txnman.abortFor(node1)
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.getPendingList(), [])
self.assertFalse(txnman.hasPending()) self.assertFalse(txnman.hasPending())
# ...and we can start another transaction # ...and the lock is available
tid2 = txnman.begin() txnman.begin(self.getNextTID())
def test_getNextOIDList(self): def test_getNextOIDList(self):
txnman = TransactionManager() txnman = TransactionManager()
...@@ -117,29 +118,6 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -117,29 +118,6 @@ class testTransactionManager(NeoUnitTestBase):
for i, oid in zip(xrange(len(oid_list)), oid_list): for i, oid in zip(xrange(len(oid_list)), oid_list):
self.assertEqual(oid, self.getOID(i+2)) self.assertEqual(oid, self.getOID(i+2))
def test_getNextTID(self):
txnman = TransactionManager()
# no previous TID
self.assertEqual(txnman.getLastTID(), ZERO_TID)
# first transaction
node1 = Mock({'__hash__': 1})
tid1 = txnman.begin()
self.assertTrue(tid1 is not None)
self.assertEqual(txnman.getLastTID(), tid1)
# set a new last TID
ntid = pack('!Q', unpack('!Q', tid1)[0] + 10)
txnman.setLastTID(ntid)
self.assertEqual(txnman.getLastTID(), ntid)
self.assertTrue(ntid > tid1)
# If a new TID is generated, DelayedError is raised
self.assertRaises(DelayedError, txnman.begin)
txnman.remove(tid1)
# new trancation
node2 = Mock({'__hash__': 2})
tid2 = txnman.begin()
self.assertTrue(tid2 is not None)
self.assertTrue(tid2 > ntid > tid1)
def test_forget(self): def test_forget(self):
client1 = Mock({'__hash__': 1}) client1 = Mock({'__hash__': 1})
client2 = Mock({'__hash__': 2}) client2 = Mock({'__hash__': 2})
...@@ -152,8 +130,9 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -152,8 +130,9 @@ class testTransactionManager(NeoUnitTestBase):
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
tid1 = tm.begin() ttid1 = tm.begin()
tm.prepare(client1, tid1, oid_list, [storage_1_uuid, storage_2_uuid], msg_id_1) tid1 = tm.prepare(client1, ttid1, 1, oid_list,
[storage_1_uuid, storage_2_uuid], msg_id_1)
tm.lock(tid1, storage_2_uuid) tm.lock(tid1, storage_2_uuid)
t1 = tm[tid1] t1 = tm[tid1]
self.assertFalse(t1.locked()) self.assertFalse(t1.locked())
...@@ -165,8 +144,9 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -165,8 +144,9 @@ class testTransactionManager(NeoUnitTestBase):
# Transaction 2: 2 storage nodes involved, one will die # Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2 msg_id_2 = 2
tid2 = tm.begin() ttid2 = tm.begin()
tm.prepare(client2, tid2, oid_list, [storage_1_uuid, storage_2_uuid], msg_id_2) tid2 = tm.prepare(client2, ttid2, 1, oid_list,
[storage_1_uuid, storage_2_uuid], msg_id_2)
t2 = tm[tid2] t2 = tm[tid2]
self.assertFalse(t2.locked()) self.assertFalse(t2.locked())
# Storage 1 dies: # Storage 1 dies:
...@@ -178,8 +158,9 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -178,8 +158,9 @@ class testTransactionManager(NeoUnitTestBase):
# Transaction 3: 1 storage node involved, which won't die # Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3 msg_id_3 = 3
tid3 = tm.begin() ttid3 = tm.begin()
tm.prepare(client3, tid3, oid_list, [storage_2_uuid, ], msg_id_3) tid3 = tm.prepare(client3, ttid3, 1, oid_list, [storage_2_uuid, ],
msg_id_3)
t3 = tm[tid3] t3 = tm[tid3]
self.assertFalse(t3.locked()) self.assertFalse(t3.locked())
# Storage 1 dies: # Storage 1 dies:
...@@ -222,12 +203,21 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -222,12 +203,21 @@ class testTransactionManager(NeoUnitTestBase):
Note: this implementation might change later, to allow more paralelism. Note: this implementation might change later, to allow more paralelism.
""" """
tm = TransactionManager() tm = TransactionManager()
tid1 = tm.begin() # With a requested TID, lock spans from begin to remove
# Further calls fail with DelayedError ttid1 = self.getNextTID()
self.assertRaises(DelayedError, tm.begin) ttid2 = self.getNextTID()
# ...until tid1 gets removed tid1 = tm.begin(ttid1)
self.assertEqual(tid1, ttid1)
self.assertRaises(DelayedError, tm.begin, ttid2)
tm.remove(tid1) tm.remove(tid1)
tid2 = tm.begin() tm.remove(tm.begin(ttid2))
# Without a requested TID, lock spans from prepare to remove only
ttid3 = tm.begin()
ttid4 = tm.begin() # Doesn't raise
tid4 = tm.prepare(None, ttid4, 1, [], [], 0)
self.assertRaises(DelayedError, tm.prepare, None, ttid3, 1, [], [], 0)
tm.remove(tid4)
tm.prepare(None, ttid3, 1, [], [], 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -132,20 +132,22 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -132,20 +132,22 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
conn = self._getConnection() conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)] oid_list = [self.getOID(1), self.getOID(2)]
tid = self.getNextTID() tid = self.getNextTID()
ttid = self.getNextTID()
handler = self.operation handler = self.operation
self.assertRaises(ProtocolError, handler.askLockInformation, conn, tid, self.assertRaises(ProtocolError, handler.askLockInformation, conn,
oid_list) ttid, tid, oid_list)
def test_askLockInformation2(self): def test_askLockInformation2(self):
""" Lock transaction """ """ Lock transaction """
self.app.tm = Mock({'__contains__': True}) self.app.tm = Mock({'__contains__': True})
conn = self._getConnection() conn = self._getConnection()
tid = self.getNextTID() tid = self.getNextTID()
ttid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)] oid_list = [self.getOID(1), self.getOID(2)]
self.operation.askLockInformation(conn, tid, oid_list) self.operation.askLockInformation(conn, ttid, tid, oid_list)
calls = self.app.tm.mockGetNamedCalls('lock') calls = self.app.tm.mockGetNamedCalls('lock')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, oid_list) calls[0].checkArgs(ttid, tid, oid_list)
self.checkAnswerInformationLocked(conn) self.checkAnswerInformationLocked(conn)
def test_notifyUnlockInformation1(self): def test_notifyUnlockInformation1(self):
......
This diff is collapsed.
...@@ -278,18 +278,22 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -278,18 +278,22 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
def test_37_answerTransactionFinished(self): def test_37_answerTransactionFinished(self):
ttid = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
p = Packets.AnswerTransactionFinished(tid) p = Packets.AnswerTransactionFinished(ttid, tid)
ptid = p.decode()[0] pttid, ptid = p.decode()
self.assertEqual(pttid, ttid)
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
def test_38_askLockInformation(self): def test_38_askLockInformation(self):
oid1 = self.getNextTID() oid1 = self.getNextTID()
oid2 = self.getNextTID() oid2 = self.getNextTID()
oid_list = [oid1, oid2] oid_list = [oid1, oid2]
ttid = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
p = Packets.AskLockInformation(tid, oid_list) p = Packets.AskLockInformation(ttid, tid, oid_list)
ptid, p_oid_list = p.decode() pttid, ptid, p_oid_list = p.decode()
self.assertEqual(pttid, ttid)
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertEqual(oid_list, p_oid_list) self.assertEqual(oid_list, p_oid_list)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment