Commit e0e81079 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Move _afterLock to app and move logic from handlers to transaction manager.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2563 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 8978b179
...@@ -53,7 +53,7 @@ class Application(object): ...@@ -53,7 +53,7 @@ class Application(object):
# Internal attributes. # Internal attributes.
self.em = EventManager() self.em = EventManager()
self.nm = NodeManager() self.nm = NodeManager()
self.tm = TransactionManager() self.tm = TransactionManager(self.onTransactionCommitted)
self.name = config.getCluster() self.name = config.getCluster()
self.server = config.getBind() self.server = config.getBind()
...@@ -562,6 +562,32 @@ class Application(object): ...@@ -562,6 +562,32 @@ class Application(object):
neo.logging.info('Accept a storage %s (%s)' % (dump(uuid), state)) neo.logging.info('Accept a storage %s (%s)' % (dump(uuid), state))
return (uuid, node, state, handler, node_ctor) return (uuid, node, state, handler, node_ctor)
def onTransactionCommitted(self, tid, txn):
# I have received all the lock answers now:
# - send a Notify Transaction Finished to the initiated client node
# - Invalidate Objects to the other client nodes
ttid = txn.getTTID()
transaction_node = txn.getNode()
invalidate_objects = Packets.InvalidateObjects(tid, txn.getOIDList())
transaction_finished = Packets.AnswerTransactionFinished(ttid, tid)
for client_node in self.nm.getClientList(only_identified=True):
c = client_node.getConnection()
if client_node is transaction_node:
c.answer(transaction_finished, msg_id=txn.getMessageId())
else:
c.notify(invalidate_objects)
# Unlock Information to relevant storage nodes.
notify_unlock = Packets.NotifyUnlockInformation(ttid)
getByUUID = self.nm.getByUUID
for storage_uuid in txn.getUUIDList():
getByUUID(storage_uuid).getConnection().notify(notify_unlock)
# remove transaction from manager
self.tm.remove(tid)
self.setLastTransaction(tid)
self.executeQueuedEvent()
def getLastTransaction(self): def getLastTransaction(self):
return self.last_transaction return self.last_transaction
......
...@@ -48,13 +48,8 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -48,13 +48,8 @@ class StorageServiceHandler(BaseServiceHandler):
# this is intentionaly placed after the raise because the last cell in a # this is intentionaly placed after the raise because the last cell in a
# partition must not oudated to allows a cluster restart. # partition must not oudated to allows a cluster restart.
self.app.outdateAndBroadcastPartition() self.app.outdateAndBroadcastPartition()
uuid = conn.getUUID() self.app.tm.forget(conn.getUUID())
for tid, transaction in self.app.tm.items(): if self.app.packing is not None:
# if a transaction is known, this means that it's being committed
if transaction.forget(uuid):
self._afterLock(tid)
packing = self.app.packing
if packing is not None:
self.answerPack(conn, False) self.answerPack(conn, False)
def askLastIDs(self, conn): def askLastIDs(self, conn):
...@@ -68,9 +63,7 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -68,9 +63,7 @@ class StorageServiceHandler(BaseServiceHandler):
conn.answer(p) conn.answer(p)
def answerInformationLocked(self, conn, tid): def answerInformationLocked(self, conn, tid):
uuid = conn.getUUID() tm = self.app.tm
app = self.app
tm = app.tm
# If the given transaction ID is later than the last TID, the peer # If the given transaction ID is later than the last TID, the peer
# is crazy. # is crazy.
...@@ -78,38 +71,7 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -78,38 +71,7 @@ class StorageServiceHandler(BaseServiceHandler):
raise ProtocolError('TID too big') raise ProtocolError('TID too big')
# transaction locked on this storage node # transaction locked on this storage node
if tm.lock(tid, uuid): tm.lock(tid, conn.getUUID())
self._afterLock(tid)
def _afterLock(self, tid):
# I have received all the lock answers now:
# - send a Notify Transaction Finished to the initiated client node
# - Invalidate Objects to the other client nodes
app = self.app
tm = app.tm
t = tm[tid]
ttid = t.getTTID()
nm = app.nm
transaction_node = t.getNode()
invalidate_objects = Packets.InvalidateObjects(tid, t.getOIDList())
answer_transaction_finished = Packets.AnswerTransactionFinished(ttid,
tid)
for client_node in nm.getClientList(only_identified=True):
c = client_node.getConnection()
if client_node is transaction_node:
c.answer(answer_transaction_finished, msg_id=t.getMessageId())
else:
c.notify(invalidate_objects)
# - Unlock Information to relevant storage nodes.
notify_unlock = Packets.NotifyUnlockInformation(ttid)
for storage_uuid in t.getUUIDList():
nm.getByUUID(storage_uuid).getConnection().notify(notify_unlock)
# remove transaction from manager
tm.remove(tid)
app.setLastTransaction(tid)
app.executeQueuedEvent()
def notifyReplicationDone(self, conn, offset): def notifyReplicationDone(self, conn, offset):
node = self.app.nm.getByUUID(conn.getUUID()) node = self.app.nm.getByUUID(conn.getUUID())
......
...@@ -193,17 +193,19 @@ class TransactionManager(object): ...@@ -193,17 +193,19 @@ class TransactionManager(object):
_next_ttid = 0 _next_ttid = 0
def __init__(self): def __init__(self, on_commit):
# tid -> transaction # tid -> transaction
self._tid_dict = {} self._tid_dict = {}
# node -> transactions mapping # node -> transactions mapping
self._node_dict = {} self._node_dict = {}
self._last_oid = None self._last_oid = None
self._on_commit = on_commit
def __getitem__(self, tid): def __getitem__(self, tid):
""" """
Return the transaction object for this TID Return the transaction object for this TID
""" """
# XXX: used by unit tests only
return self._tid_dict[tid] return self._tid_dict[tid]
def __contains__(self, tid): def __contains__(self, tid):
...@@ -213,6 +215,7 @@ class TransactionManager(object): ...@@ -213,6 +215,7 @@ class TransactionManager(object):
return tid in self._tid_dict return tid in self._tid_dict
def items(self): def items(self):
# XXX: used by unit tests only
return self._tid_dict.items() return self._tid_dict.items()
def getNextOIDList(self, num_oids): def getNextOIDList(self, num_oids):
...@@ -386,7 +389,19 @@ class TransactionManager(object): ...@@ -386,7 +389,19 @@ class TransactionManager(object):
Returns True if all are now locked Returns True if all are now locked
""" """
assert tid in self._tid_dict, "Transaction not started" assert tid in self._tid_dict, "Transaction not started"
return self._tid_dict[tid].lock(uuid) txn = self._tid_dict[tid]
if txn.lock(uuid):
# all storage are locked
self._on_commit(tid, txn)
def forget(self, uuid):
"""
A storage node has been lost, don't expect a reply from it for
current transactions
"""
for tid, txn in self._tid_dict.items():
if txn.forget(uuid):
self._on_commit(tid, txn)
def abortFor(self, node): def abortFor(self, node):
""" """
......
...@@ -61,7 +61,8 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -61,7 +61,8 @@ class testTransactionManager(NeoUnitTestBase):
uuid_list = (uuid1, uuid2) = self.makeUUID(1), self.makeUUID(2) uuid_list = (uuid1, uuid2) = self.makeUUID(1), self.makeUUID(2)
client_uuid = self.makeUUID(3) client_uuid = self.makeUUID(3)
# create transaction manager # create transaction manager
txnman = TransactionManager() callback = Mock()
txnman = TransactionManager(on_commit=callback)
self.assertFalse(txnman.hasPending()) self.assertFalse(txnman.hasPending())
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.getPendingList(), [])
# begin the transaction # begin the transaction
...@@ -78,8 +79,10 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -78,8 +79,10 @@ class testTransactionManager(NeoUnitTestBase):
self.assertEqual(txn.getUUIDList(), list(uuid_list)) self.assertEqual(txn.getUUIDList(), list(uuid_list))
self.assertEqual(txn.getOIDList(), list(oid_list)) self.assertEqual(txn.getOIDList(), list(oid_list))
# lock nodes # lock nodes
self.assertFalse(txnman.lock(tid, uuid1)) txnman.lock(tid, uuid1)
self.assertTrue(txnman.lock(tid, uuid2)) self.assertEqual(len(callback.getNamedCalls('__call__')), 0)
txnman.lock(tid, uuid2)
self.assertEqual(len(callback.getNamedCalls('__call__')), 1)
# transaction finished # transaction finished
txnman.remove(tid) txnman.remove(tid)
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.getPendingList(), [])
...@@ -91,7 +94,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -91,7 +94,7 @@ class testTransactionManager(NeoUnitTestBase):
storage_1_uuid = self.makeUUID(1) storage_1_uuid = self.makeUUID(1)
storage_2_uuid = self.makeUUID(2) storage_2_uuid = self.makeUUID(2)
client_uuid = self.makeUUID(3) client_uuid = self.makeUUID(3)
txnman = TransactionManager() txnman = TransactionManager(lambda tid, txn: None)
# register 4 transactions made by two nodes # register 4 transactions made by two nodes
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.getPendingList(), [])
ttid1 = txnman.begin(client_uuid) ttid1 = txnman.begin(client_uuid)
...@@ -108,7 +111,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -108,7 +111,7 @@ class testTransactionManager(NeoUnitTestBase):
txnman.begin(client_uuid, self.getNextTID()) txnman.begin(client_uuid, self.getNextTID())
def test_getNextOIDList(self): def test_getNextOIDList(self):
txnman = TransactionManager() txnman = TransactionManager(lambda tid, txn: None)
# must raise as we don"t have one # must raise as we don"t have one
self.assertEqual(txnman.getLastOID(), None) self.assertEqual(txnman.getLastOID(), None)
self.assertRaises(RuntimeError, txnman.getNextOIDList, 1) self.assertRaises(RuntimeError, txnman.getNextOIDList, 1)
...@@ -129,7 +132,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -129,7 +132,7 @@ class testTransactionManager(NeoUnitTestBase):
oid_list = [self.makeOID(1), ] oid_list = [self.makeOID(1), ]
client_uuid = self.makeUUID(3) client_uuid = self.makeUUID(3)
tm = TransactionManager() tm = TransactionManager(lambda tid, txn: None)
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
...@@ -206,7 +209,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -206,7 +209,7 @@ class testTransactionManager(NeoUnitTestBase):
Note: this implementation might change later, to allow more paralelism. Note: this implementation might change later, to allow more paralelism.
""" """
client_uuid = self.makeUUID(3) client_uuid = self.makeUUID(3)
tm = TransactionManager() tm = TransactionManager(lambda tid, txn: None)
# With a requested TID, lock spans from begin to remove # With a requested TID, lock spans from begin to remove
ttid1 = self.getNextTID() ttid1 = self.getNextTID()
ttid2 = self.getNextTID() ttid2 = self.getNextTID()
...@@ -227,7 +230,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -227,7 +230,7 @@ class testTransactionManager(NeoUnitTestBase):
def testClientDisconectsAfterBegin(self): def testClientDisconectsAfterBegin(self):
client1_uuid = self.makeUUID(1) client1_uuid = self.makeUUID(1)
client2_uuid = self.makeUUID(2) client2_uuid = self.makeUUID(2)
tm = TransactionManager() tm = TransactionManager(lambda tid, txn: None)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tm.begin(client1_uuid, tid1) tm.begin(client1_uuid, tid1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment