Commit 7835c09e authored by Vincent Pelletier's avatar Vincent Pelletier

Make undo implementation work with replication.

The problem with previous implementation was that each storage locally
decided what undo actually did to data. This causes problems when a
storage doesn't have a complete view of past transaction but accepts write
queries, ie when it replicates.
This implementation reduces the decision to a readable subset of storage
nodes (which are hence not replicating), and then sends that decision to
all storage nodes, hence fixing the issue.
Also, DatabaseManager.storeTransaction now consistently expects object's
value_serial to be packed, not an integer.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2285 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent af92b534
...@@ -114,8 +114,7 @@ class ThreadContext(object): ...@@ -114,8 +114,7 @@ class ThreadContext(object):
'node_tids': {}, 'node_tids': {},
'node_ready': False, 'node_ready': False,
'asked_object': 0, 'asked_object': 0,
'undo_conflict_oid_list': [], 'undo_object_tid_dict': {},
'undo_error_oid_list': [],
'involved_nodes': set(), 'involved_nodes': set(),
} }
...@@ -905,55 +904,77 @@ class Application(object): ...@@ -905,55 +904,77 @@ class Application(object):
raise NEOStorageError('undo failed') raise NEOStorageError('undo failed')
tid = self.local_var.tid tid = self.local_var.tid
oid_list = self.local_var.txn_info['oids']
undo_conflict_oid_list = self.local_var.undo_conflict_oid_list = [] # Regroup objects per partition, to ask a minimum set of storage.
undo_error_oid_list = self.local_var.undo_error_oid_list = [] partition_oid_dict = {}
ask_undo_transaction = Packets.AskUndoTransaction(tid, undone_tid) pt = self._getPartitionTable()
getConnForNode = self.cp.getConnForNode getPartitionFromIndex = pt.getPartitionFromIndex
for oid in oid_list:
partition = pt.getPartitionFromIndex(oid)
try:
oid_list = partition_oid_dict[partition]
except KeyError:
oid_list = partition_oid_dict[partition] = []
oid_list.append(oid)
# Ask storage the undo serial (serial at which object's previous data
# is)
getCellList = pt.getCellList
getCellSortKey = self.cp.getCellSortKey
queue = self.local_var.queue queue = self.local_var.queue
for storage_node in self.nm.getStorageList(): undo_object_tid_dict = self.local_var.undo_object_tid_dict = {}
storage_conn = getConnForNode(storage_node) for partition, oid_list in partition_oid_dict.iteritems():
storage_conn.ask(ask_undo_transaction, queue=queue) cell_list = getCellList(partition, readable=True)
# Wait for all AnswerUndoTransaction. shuffle(cell_list)
cell_list.sort(key=getCellSortKey)
storage_conn = getConnForCell(cell_list[0])
storage_conn.ask(Packets.AskObjectUndoSerial(tid, undone_tid,
oid_list), queue=queue)
# Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError,
# meaning that objects in transaction's oid_list do not exist any
# longer. This is the symptom of a pack, so forbid undoing transaction
# when it happens, but sill keep waiting for answers.
failed = False
while True:
try:
self.waitResponses() self.waitResponses()
except NEOStorageNotFoundError:
failed = True
else:
break
if failed:
raise UndoError('non-undoable transaction')
# Don't do any handling for "live" conflicts, raise # Send undo data to all storage nodes.
if undo_conflict_oid_list: for oid in oid_list:
raise ConflictError(oid=undo_conflict_oid_list[0], serials=(tid, current_serial, undo_serial, is_current = undo_object_tid_dict[oid]
undone_tid), data=None) if is_current:
data = None
# Try to resolve undo conflicts else:
for oid in undo_error_oid_list: # Serial being undone is not the latest version for this
def loadBefore(oid, tid): # object. This is an undo conflict, try to resolve it.
try: try:
result = self._load(oid, tid=tid)
except NEOStorageNotFoundError:
raise UndoError("Object not found while resolving undo " \
"conflict")
return result[:2]
# Load the latest version we are supposed to see # Load the latest version we are supposed to see
data, data_tid = loadBefore(oid, tid) data = self.loadSerial(oid, current_serial)
# Load the version we were undoing to # Load the version we were undoing to
undo_data, _ = loadBefore(oid, undone_tid) undo_data = self.loadSerial(oid, undo_serial)
except NEOStorageNotFoundError:
raise UndoError('Object not found while resolving undo '
'conflict')
# Resolve conflict # Resolve conflict
try: try:
new_data = tryToResolveConflict(oid, data_tid, undone_tid, data = tryToResolveConflict(oid, current_serial,
undo_data, data) undone_tid, undo_data, data)
except ConflictError: except ConflictError:
new_data = None data = None
if new_data is None: if data is None:
raise UndoError('Some data were modified by a later ' \ raise UndoError('Some data were modified by a later ' \
'transaction', oid) 'transaction', oid)
else: undo_serial = None
self._store(oid, data_tid, new_data) self._store(oid, current_serial, data, undo_serial)
return tid, oid_list
oid_list = self.local_var.txn_info['oids']
# Consistency checking: all oids of the transaction must have been
# reported as undone
data_dict = self.local_var.data_dict
for oid in oid_list:
assert oid in data_dict, repr(oid)
return self.local_var.tid, oid_list
def _insertMetadata(self, txn_info, extension): def _insertMetadata(self, txn_info, extension):
for k, v in loads(extension).items(): for k, v in loads(extension).items():
......
...@@ -127,14 +127,8 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -127,14 +127,8 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerTIDs(self, conn, tid_list): def answerTIDs(self, conn, tid_list):
self.app.local_var.node_tids[conn.getUUID()] = tid_list self.app.local_var.node_tids[conn.getUUID()] = tid_list
def answerUndoTransaction(self, conn, oid_list, error_oid_list, def answerObjectUndoSerial(self, conn, object_tid_dict):
conflict_oid_list): self.app.local_var.undo_object_tid_dict.update(object_tid_dict)
local_var = self.app.local_var
local_var.undo_conflict_oid_list.extend(conflict_oid_list)
local_var.undo_error_oid_list.extend(error_oid_list)
data_dict = local_var.data_dict
for oid in oid_list:
data_dict[oid] = ''
def answerHasLock(self, conn, oid, status): def answerHasLock(self, conn, oid, status):
if status == LockState.GRANTED_TO_OTHER: if status == LockState.GRANTED_TO_OTHER:
......
...@@ -335,10 +335,10 @@ class EventHandler(object): ...@@ -335,10 +335,10 @@ class EventHandler(object):
def notifyReplicationDone(self, conn, offset): def notifyReplicationDone(self, conn, offset):
raise UnexpectedPacketError raise UnexpectedPacketError
def askUndoTransaction(self, conn, tid, undone_tid): def askObjectUndoSerial(self, conn, tid, undone_tid, oid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerUndoTransaction(self, conn, oid_list, error_oid_list, conflict_oid_list): def answerObjectUndoSerial(self, conn, object_tid_dict):
raise UnexpectedPacketError raise UnexpectedPacketError
def askHasLock(self, conn, tid, oid): def askHasLock(self, conn, tid, oid):
...@@ -456,8 +456,8 @@ class EventHandler(object): ...@@ -456,8 +456,8 @@ class EventHandler(object):
d[Packets.NotifyClusterInformation] = self.notifyClusterInformation d[Packets.NotifyClusterInformation] = self.notifyClusterInformation
d[Packets.NotifyLastOID] = self.notifyLastOID d[Packets.NotifyLastOID] = self.notifyLastOID
d[Packets.NotifyReplicationDone] = self.notifyReplicationDone d[Packets.NotifyReplicationDone] = self.notifyReplicationDone
d[Packets.AskUndoTransaction] = self.askUndoTransaction d[Packets.AskObjectUndoSerial] = self.askObjectUndoSerial
d[Packets.AnswerUndoTransaction] = self.answerUndoTransaction d[Packets.AnswerObjectUndoSerial] = self.answerObjectUndoSerial
d[Packets.AskHasLock] = self.askHasLock d[Packets.AskHasLock] = self.askHasLock
d[Packets.AnswerHasLock] = self.answerHasLock d[Packets.AnswerHasLock] = self.answerHasLock
......
...@@ -1528,55 +1528,76 @@ class NotifyLastOID(Packet): ...@@ -1528,55 +1528,76 @@ class NotifyLastOID(Packet):
(loid, ) = unpack('8s', body) (loid, ) = unpack('8s', body)
return (loid, ) return (loid, )
class AskUndoTransaction(Packet): class AskObjectUndoSerial(Packet):
""" """
Ask storage to undo given transaction Ask storage the serial where object data is when undoing given transaction,
for a list of OIDs.
C -> S C -> S
""" """
def _encode(self, tid, undone_tid): _header_format = '!8s8sL'
return _encodeTID(tid) + _encodeTID(undone_tid)
def _decode(self, body): def _encode(self, tid, undone_tid, oid_list):
tid = _decodeTID(body[:8]) body = StringIO()
undone_tid = _decodeTID(body[8:]) write = body.write
return (tid, undone_tid) write(pack(self._header_format, tid, undone_tid, len(oid_list)))
for oid in oid_list:
write(oid)
return body.getvalue()
class AnswerUndoTransaction(Packet): def _decode(self, body):
""" body = StringIO(body)
Answer an undo request, telling if undo could be done, with an oid list. read = body.read
If undo failed, the list contains oid(s) causing problems. tid, undone_tid, oid_list_len = unpack(self._header_format,
If undo succeeded; the list contains all undone oids for given storage. read(self._header_len))
oid_list = [read(8) for _ in xrange(oid_list_len)]
return tid, undone_tid, oid_list
class AnswerObjectUndoSerial(Packet):
"""
Answer serials at which object data is when undoing a given transaction.
object_tid_dict has the following format:
key: oid
value: 3-tuple
current_serial (TID)
The latest serial visible to the undoing transaction.
undo_serial (TID)
Where undone data is (tid at which data is before given undo).
is_current (bool)
If current_serial's data is current on storage.
S -> C S -> C
""" """
_header_format = '!LLL' _header_format = '!L'
_list_entry_format = '!8s8s8sB'
_list_entry_len = calcsize(_list_entry_format)
def _encode(self, oid_list, error_oid_list, conflict_oid_list): def _encode(self, object_tid_dict):
body = StringIO() body = StringIO()
write = body.write write = body.write
oid_list_list = [oid_list, error_oid_list, conflict_oid_list] write(pack(self._header_format, len(object_tid_dict)))
write(pack(self._header_format, *[len(x) for x in oid_list_list])) list_entry_format = self._list_entry_format
for oid_list in oid_list_list: for oid, (current_serial, undo_serial, is_current) in \
for oid in oid_list: object_tid_dict.iteritems():
write(oid) if undo_serial is None:
undo_serial = ZERO_TID
write(pack(list_entry_format, oid, current_serial, undo_serial,
is_current))
return body.getvalue() return body.getvalue()
def _decode(self, body): def _decode(self, body):
body = StringIO(body) body = StringIO(body)
read = body.read read = body.read
oid_list_len, error_oid_list_len, conflict_oid_list_len = unpack( object_tid_dict = {}
self._header_format, read(self._header_len)) list_entry_format = self._list_entry_format
oid_list = [] list_entry_len = self._list_entry_len
error_oid_list = [] object_tid_len = unpack(self._header_format, read(self._header_len))[0]
conflict_oid_list = [] for _ in xrange(object_tid_len):
for some_list, some_list_len in ( oid, current_serial, undo_serial, is_current = unpack(
(oid_list, oid_list_len), list_entry_format, read(list_entry_len))
(error_oid_list, error_oid_list_len), if undo_serial == ZERO_TID:
(conflict_oid_list, conflict_oid_list_len), undo_serial = None
): object_tid_dict[oid] = (current_serial, undo_serial,
append = some_list.append bool(is_current))
for _ in xrange(some_list_len): return (object_tid_dict, )
append(read(OID_LEN))
return (oid_list, error_oid_list, conflict_oid_list)
class AskHasLock(Packet): class AskHasLock(Packet):
""" """
...@@ -1821,10 +1842,10 @@ class PacketRegistry(dict): ...@@ -1821,10 +1842,10 @@ class PacketRegistry(dict):
AnswerClusterState) AnswerClusterState)
NotifyLastOID = register(0x0030, NotifyLastOID) NotifyLastOID = register(0x0030, NotifyLastOID)
NotifyReplicationDone = register(0x0031, NotifyReplicationDone) NotifyReplicationDone = register(0x0031, NotifyReplicationDone)
AskUndoTransaction, AnswerUndoTransaction = register( AskObjectUndoSerial, AnswerObjectUndoSerial = register(
0x0033, 0x0033,
AskUndoTransaction, AskObjectUndoSerial,
AnswerUndoTransaction) AnswerObjectUndoSerial)
AskHasLock, AnswerHasLock = register( AskHasLock, AnswerHasLock = register(
0x0034, 0x0034,
AskHasLock, AskHasLock,
......
...@@ -234,13 +234,32 @@ class DatabaseManager(object): ...@@ -234,13 +234,32 @@ class DatabaseManager(object):
pack state (True for packed).""" pack state (True for packed)."""
raise NotImplementedError raise NotImplementedError
def getTransactionUndoData(self, tid, undone_tid, def findUndoTID(self, oid, tid, undone_tid, transaction_object):
getObjectFromTransaction): """
"""Undo transaction with "undone_tid" tid. "tid" is the tid of the oid
transaction in which the undo happens. Object OID
getObjectFromTransaction is a callback allowing to find object data tid
stored to this storage in the same transaction (it is useful for Transation doing the undo
example when undoing twice in the same transaction). undone_tid
Transaction to undo
transaction_object
Object data from memory, if it was modified by running
transaction.
None if is was not modified by running transaction.
Returns a 3-tuple:
current_tid (p64)
TID of most recent version of the object client's transaction can
see. This is used later to detect current conflicts (eg, another
client modifying the same object in parallel)
data_tid (int)
TID containing (without indirection) the data prior to undone
transaction.
None if object doesn't exist prior to transaction being undone
(its creation is being undone).
is_current (bool)
False if object was modified by later transaction (ie, data_tid is
not current), True otherwise.
""" """
raise NotImplementedError raise NotImplementedError
......
...@@ -453,7 +453,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -453,7 +453,7 @@ class MySQLDatabaseManager(DatabaseManager):
if value_serial is None: if value_serial is None:
value_serial = 'NULL' value_serial = 'NULL'
else: else:
value_serial = '%d' % (value_serial, ) value_serial = '%d' % (u64(value_serial), )
q("""REPLACE INTO %s VALUES (%d, %d, %s, %s, %s, %s)""" \ q("""REPLACE INTO %s VALUES (%d, %d, %s, %s, %s, %s)""" \
% (obj_table, oid, tid, compression, checksum, data, % (obj_table, oid, tid, compression, checksum, data,
value_serial)) value_serial))
...@@ -506,71 +506,33 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -506,71 +506,33 @@ class MySQLDatabaseManager(DatabaseManager):
result = self._getDataTIDFromData(oid, result) result = self._getDataTIDFromData(oid, result)
return result return result
def _findUndoTID(self, oid, tid, undone_tid, transaction_object): def findUndoTID(self, oid, tid, undone_tid, transaction_object):
""" u64 = util.u64
oid, undone_tid (ints) p64 = util.p64
Object to undo for given transaction oid = u64(oid)
tid (int) tid = u64(tid)
Client's transaction (he can't see objects past this value). undone_tid = u64(undone_tid)
Return a 2-tuple:
current_tid (p64)
TID of most recent version of the object client's transaction can
see. This is used later to detect current conflicts (eg, another
client modifying the same object in parallel)
data_tid (int)
TID containing (without indirection) the data prior to undone
transaction.
-1 if object was modified by later transaction.
None if object doesn't exist prior to transaction being undone
(its creation is being undone).
"""
_getDataTID = self._getDataTID _getDataTID = self._getDataTID
if transaction_object is not None: if transaction_object is not None:
# transaction_object: toid, tcompression, tchecksum, tdata, tvalue_serial = \
# oid, compression, ... transaction_object
# Expected value:
# serial, next_serial, compression, ...
current_tid, current_data_tid = self._getDataTIDFromData(oid, current_tid, current_data_tid = self._getDataTIDFromData(oid,
(tid, None) + transaction_object[1:]) (tid, None, tcompression, tchecksum, tdata,
u64(tvalue_serial)))
else: else:
current_tid, current_data_tid = _getDataTID(oid, before_tid=tid) current_tid, current_data_tid = _getDataTID(oid, before_tid=tid)
assert current_tid is not None, (oid, tid, transaction_object) if current_tid is None:
return (None, None, False)
found_undone_tid, undone_data_tid = _getDataTID(oid, tid=undone_tid) found_undone_tid, undone_data_tid = _getDataTID(oid, tid=undone_tid)
assert found_undone_tid is not None, (oid, undone_tid) assert found_undone_tid is not None, (oid, undone_tid)
if undone_data_tid not in (current_data_tid, tid): is_current = undone_data_tid in (current_data_tid, tid)
# data from the transaction we want to undo is modified by a later
# transaction. It is up to the client node to decide what to do
# (undo error of conflict resolution).
data_tid = -1
else:
# Load object data as it was before given transaction. # Load object data as it was before given transaction.
# It can be None, in which case it means we are undoing object # It can be None, in which case it means we are undoing object
# creation. # creation.
_, data_tid = _getDataTID(oid, before_tid=undone_tid) _, data_tid = _getDataTID(oid, before_tid=undone_tid)
return util.p64(current_tid), data_tid if data_tid is not None:
data_tid = p64(data_tid)
def getTransactionUndoData(self, tid, undone_tid, return p64(current_tid), data_tid, is_current
getObjectFromTransaction):
q = self.query
p64 = util.p64
u64 = util.u64
_findUndoTID = self._findUndoTID
p_tid = tid
tid = u64(tid)
undone_tid = u64(undone_tid)
if undone_tid > tid:
# Replace with an exception reaching client (TIDNotFound)
raise ValueError, 'Can\'t undo in future: %d > %d' % (
undone_tid, tid)
result = {}
for (oid, ) in q("""SELECT oid FROM obj WHERE serial = %d""" % (
undone_tid, )):
p_oid = p64(oid)
result[p_oid] = _findUndoTID(oid, tid, undone_tid,
getObjectFromTransaction(p_tid, p_oid))
return result
def finishTransaction(self, tid): def finishTransaction(self, tid):
q = self.query q = self.query
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from neo import logging from neo import logging
from neo import protocol from neo import protocol
from neo.util import dump from neo.util import dump
from neo.protocol import Packets, LockState from neo.protocol import Packets, LockState, Errors
from neo.storage.handlers import BaseClientAndStorageOperationHandler from neo.storage.handlers import BaseClientAndStorageOperationHandler
from neo.storage.transactions import ConflictError, DelayedError from neo.storage.transactions import ConflictError, DelayedError
import time import time
...@@ -78,6 +78,11 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -78,6 +78,11 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
compression, checksum, data, data_serial, tid): compression, checksum, data, data_serial, tid):
# register the transaction # register the transaction
self.app.tm.register(conn.getUUID(), tid) self.app.tm.register(conn.getUUID(), tid)
if data_serial is not None:
assert data == '', repr(data)
# Change data to None here, to do it only once, even if store gets
# delayed.
data = None
self._askStoreObject(conn, oid, serial, compression, checksum, data, self._askStoreObject(conn, oid, serial, compression, checksum, data,
data_serial, tid, time.time()) data_serial, tid, time.time())
...@@ -97,41 +102,21 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -97,41 +102,21 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
app.pt.getPartitions(), partition_list) app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerTIDs(tid_list)) conn.answer(Packets.AnswerTIDs(tid_list))
def askUndoTransaction(self, conn, tid, undone_tid): def askObjectUndoSerial(self, conn, tid, undone_tid, oid_list):
app = self.app app = self.app
tm = app.tm findUndoTID = app.dm.findUndoTID
storeObject = tm.storeObject getObjectFromTransaction = app.tm.getObjectFromTransaction
uuid = conn.getUUID() object_tid_dict = {}
oid_list = [] for oid in oid_list:
error_oid_list = [] current_serial, undo_serial, is_current = findUndoTID(oid, tid,
conflict_oid_list = [] undone_tid, getObjectFromTransaction(tid, oid))
if current_serial is None:
undo_tid_dict = app.dm.getTransactionUndoData(tid, undone_tid, p = Errors.OidNotFound(dump(oid))
tm.getObjectFromTransaction) break
for oid, (current_serial, undone_value_serial) in \ object_tid_dict[oid] = (current_serial, undo_serial, is_current)
undo_tid_dict.iteritems():
if undone_value_serial == -1:
# Some data were modified by a later transaction
# This must be propagated to client, who will
# attempt a conflict resolution, and store resolved
# data.
to_append_list = error_oid_list
else:
try:
self.app.tm.register(uuid, tid)
storeObject(tid, current_serial, oid, None,
None, None, undone_value_serial)
except ConflictError:
to_append_list = conflict_oid_list
except DelayedError:
app.queueEvent(self.askUndoTransaction, conn, tid,
undone_tid)
return
else: else:
to_append_list = oid_list p = Packets.AnswerObjectUndoSerial(object_tid_dict)
to_append_list.append(oid) conn.answer(p)
conn.answer(Packets.AnswerUndoTransaction(oid_list, error_oid_list,
conflict_oid_list))
def askHasLock(self, conn, tid, oid): def askHasLock(self, conn, tid, oid):
locking_tid = self.app.tm.getLockingTID(oid) locking_tid = self.app.tm.getLockingTID(oid)
......
This diff is collapsed.
...@@ -240,28 +240,25 @@ class StorageAnswerHandlerTests(NeoTestBase): ...@@ -240,28 +240,25 @@ class StorageAnswerHandlerTests(NeoTestBase):
self.assertTrue(uuid in self.app.local_var.node_tids) self.assertTrue(uuid in self.app.local_var.node_tids)
self.assertEqual(self.app.local_var.node_tids[uuid], tid_list) self.assertEqual(self.app.local_var.node_tids[uuid], tid_list)
def test_answerUndoTransaction(self): def test_answerObjectUndoSerial(self):
local_var = self.app.local_var uuid = self.getNewUUID()
undo_conflict_oid_list = local_var.undo_conflict_oid_list = [] conn = self.getFakeConnection(uuid=uuid)
undo_error_oid_list = local_var.undo_error_oid_list = [] oid1 = self.getOID(1)
data_dict = local_var.data_dict = {} oid2 = self.getOID(2)
conn = None # Nothing is done on connection in this handler tid0 = self.getNextTID()
tid1 = self.getNextTID()
# Nothing undone, check nothing changed tid2 = self.getNextTID()
self.handler.answerUndoTransaction(conn, [], [], []) tid3 = self.getNextTID()
self.assertEqual(undo_conflict_oid_list, []) self.app.local_var.undo_object_tid_dict = undo_dict = {
self.assertEqual(undo_error_oid_list, []) oid1: [tid0, tid1],
self.assertEqual(data_dict, {}) }
self.handler.answerObjectUndoSerial(conn, {
# One OID for each case, check they are inserted in expected local_var oid2: [tid2, tid3],
# entries. })
oid_1 = self.getOID(0) self.assertEqual(undo_dict, {
oid_2 = self.getOID(1) oid1: [tid0, tid1],
oid_3 = self.getOID(2) oid2: [tid2, tid3],
self.handler.answerUndoTransaction(conn, [oid_1], [oid_2], [oid_3]) })
self.assertEqual(undo_conflict_oid_list, [oid_3])
self.assertEqual(undo_error_oid_list, [oid_2])
self.assertEqual(data_dict, {oid_1: ''})
def test_answerHasLock(self): def test_answerHasLock(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest import unittest
from mock import Mock from mock import Mock, ReturnValues
from collections import deque from collections import deque
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.storage.app import Application from neo.storage.app import Application
...@@ -215,11 +215,27 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -215,11 +215,27 @@ class StorageClientHandlerTests(NeoTestBase):
conn = self._getConnection(uuid=uuid) conn = self._getConnection(uuid=uuid)
tid = self.getNextTID() tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject() oid, serial, comp, checksum, data = self._getObject()
self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, None, tid)
self._checkStoreObjectCalled(tid, serial, oid, comp,
checksum, data, None)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True)
self.assertEqual(pconflicting, 0)
self.assertEqual(poid, oid)
self.assertEqual(pserial, serial)
def test_askStoreObjectWithDataTID(self):
# same as test_askStoreObject1, but with a non-None data_tid value
uuid = self.getNewUUID()
conn = self._getConnection(uuid=uuid)
tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID() data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, checksum, self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, data_tid, tid) '', data_tid, tid)
self._checkStoreObjectCalled(tid, serial, oid, comp, self._checkStoreObjectCalled(tid, serial, oid, comp,
checksum, data, data_tid) checksum, None, data_tid)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn, pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True) decode=True)
self.assertEqual(pconflicting, 0) self.assertEqual(pconflicting, 0)
...@@ -236,9 +252,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -236,9 +252,8 @@ class StorageClientHandlerTests(NeoTestBase):
raise ConflictError(locking_tid) raise ConflictError(locking_tid)
self.app.tm.storeObject = fakeStoreObject self.app.tm.storeObject = fakeStoreObject
oid, serial, comp, checksum, data = self._getObject() oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, checksum, self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, data_tid, tid) data, None, tid)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn, pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True) decode=True)
self.assertEqual(pconflicting, 1) self.assertEqual(pconflicting, 1)
...@@ -253,44 +268,22 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -253,44 +268,22 @@ class StorageClientHandlerTests(NeoTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid) calls[0].checkArgs(tid)
def test_askUndoTransaction(self): def test_askObjectUndoSerial(self):
conn = self._getConnection() uuid = self.getNewUUID()
conn = self._getConnection(uuid=uuid)
tid = self.getNextTID() tid = self.getNextTID()
undone_tid = self.getNextTID() undone_tid = self.getNextTID()
oid_1 = self.getNextTID() # Keep 2 entries here, so we check findUndoTID is called only once.
oid_2 = self.getNextTID() oid_list = [self.getOID(1), self.getOID(2)]
oid_3 = self.getNextTID() obj2_data = [] # Marker
oid_4 = self.getNextTID() self.app.tm = Mock({
def getTransactionUndoData(tid, undone_tid, getObjectFromTransaction): 'getObjectFromTransaction': None,
return { })
oid_1: (1, 1), self.app.dm = Mock({
oid_2: (1, -1), 'findUndoTID': ReturnValues((None, None, False), )
oid_3: (1, 2), })
oid_4: (1, 3), self.operation.askObjectUndoSerial(conn, tid, undone_tid, oid_list)
} self.checkErrorPacket(conn)
self.app.dm.getTransactionUndoData = getTransactionUndoData
original_storeObject = self.app.tm.storeObject
def storeObject(tid, serial, oid, *args, **kw):
if oid == oid_3:
raise ConflictError(0)
elif oid == oid_4 and delay_store:
raise DelayedError
return original_storeObject(tid, serial, oid, *args, **kw)
self.app.tm.storeObject = storeObject
# Check if delaying a store (of oid_4) is supported
delay_store = True
self.operation.askUndoTransaction(conn, tid, undone_tid)
self.checkNoPacketSent(conn)
delay_store = False
self.operation.askUndoTransaction(conn, tid, undone_tid)
oid_list_1, oid_list_2, oid_list_3 = self.checkAnswerPacket(conn,
Packets.AnswerUndoTransaction, decode=True)
# Compare sets as order doens't matter here.
self.assertEqual(set(oid_list_1), set([oid_1, oid_4]))
self.assertEqual(oid_list_2, [oid_2])
self.assertEqual(oid_list_3, [oid_3])
def test_askHasLock(self): def test_askHasLock(self):
tid_1 = self.getNextTID() tid_1 = self.getNextTID()
......
...@@ -605,13 +605,13 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -605,13 +605,13 @@ class StorageMySQSLdbTests(NeoTestBase):
db.storeTransaction( db.storeTransaction(
tid1, ( tid1, (
(oid1, 0, 0, 'foo', None), (oid1, 0, 0, 'foo', None),
(oid2, None, None, None, u64(tid0)), (oid2, None, None, None, tid0),
(oid3, None, None, None, u64(tid2)), (oid3, None, None, None, tid2),
), None, temporary=False) ), None, temporary=False)
db.storeTransaction( db.storeTransaction(
tid2, ( tid2, (
(oid1, None, None, None, u64(tid1)), (oid1, None, None, None, tid1),
(oid2, None, None, None, u64(tid1)), (oid2, None, None, None, tid1),
(oid3, 0, 0, 'bar', None), (oid3, 0, 0, 'bar', None),
), None, temporary=False) ), None, temporary=False)
...@@ -689,7 +689,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -689,7 +689,7 @@ class StorageMySQSLdbTests(NeoTestBase):
), None, temporary=False) ), None, temporary=False)
db.storeTransaction( db.storeTransaction(
tid2, ( tid2, (
(oid1, None, None, None, u64(tid1)), (oid1, None, None, None, tid1),
), None, temporary=False) ), None, temporary=False)
self.assertEqual( self.assertEqual(
...@@ -713,7 +713,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -713,7 +713,7 @@ class StorageMySQSLdbTests(NeoTestBase):
), None, temporary=False) ), None, temporary=False)
db.storeTransaction( db.storeTransaction(
tid2, ( tid2, (
(oid1, None, None, None, u64(tid1)), (oid1, None, None, None, tid1),
), None, temporary=False) ), None, temporary=False)
self.assertEqual( self.assertEqual(
...@@ -723,7 +723,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -723,7 +723,7 @@ class StorageMySQSLdbTests(NeoTestBase):
db._getDataTID(u64(oid1), tid=u64(tid2)), db._getDataTID(u64(oid1), tid=u64(tid2)),
(u64(tid2), u64(tid1))) (u64(tid2), u64(tid1)))
def test__findUndoTID(self): def test_findUndoTID(self):
db = self.db db = self.db
db.setup(reset=True) db.setup(reset=True)
tid1 = self.getNextTID() tid1 = self.getNextTID()
...@@ -740,8 +740,8 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -740,8 +740,8 @@ class StorageMySQSLdbTests(NeoTestBase):
# Result: current tid is tid1, data_tid is None (undoing object # Result: current tid is tid1, data_tid is None (undoing object
# creation) # creation)
self.assertEqual( self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1), None), db.findUndoTID(oid1, tid4, tid1, None),
(tid1, None)) (tid1, None, True))
# Store a new transaction # Store a new transaction
db.storeTransaction( db.storeTransaction(
...@@ -752,14 +752,14 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -752,14 +752,14 @@ class StorageMySQSLdbTests(NeoTestBase):
# Undoing oid1 tid2, OK: tid2 is latest # Undoing oid1 tid2, OK: tid2 is latest
# Result: current tid is tid2, data_tid is tid1 # Result: current tid is tid2, data_tid is tid1
self.assertEqual( self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid2), None), db.findUndoTID(oid1, tid4, tid2, None),
(tid2, u64(tid1))) (tid2, tid1, True))
# Undoing oid1 tid1, Error: tid2 is latest # Undoing oid1 tid1, Error: tid2 is latest
# Result: current tid is tid2, data_tid is -1 # Result: current tid is tid2, data_tid is -1
self.assertEqual( self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1), None), db.findUndoTID(oid1, tid4, tid1, None),
(tid2, -1)) (tid2, None, False))
# Undoing oid1 tid1 with tid2 being undone in same transaction, # Undoing oid1 tid1 with tid2 being undone in same transaction,
# OK: tid1 is latest # OK: tid1 is latest
...@@ -768,71 +768,22 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -768,71 +768,22 @@ class StorageMySQSLdbTests(NeoTestBase):
# Explanation of transaction_object: oid1, no data but a data serial # Explanation of transaction_object: oid1, no data but a data serial
# to tid1 # to tid1
self.assertEqual( self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1), db.findUndoTID(oid1, tid4, tid1,
(u64(oid1), None, None, None, u64(tid1))), (u64(oid1), None, None, None, tid1)),
(tid4, None)) (tid4, None, True))
# Store a new transaction # Store a new transaction
db.storeTransaction( db.storeTransaction(
tid3, ( tid3, (
(oid1, None, None, None, u64(tid1)), (oid1, None, None, None, tid1),
), None, temporary=False) ), None, temporary=False)
# Undoing oid1 tid1, OK: tid3 is latest with tid1 data # Undoing oid1 tid1, OK: tid3 is latest with tid1 data
# Result: current tid is tid2, data_tid is None (undoing object # Result: current tid is tid2, data_tid is None (undoing object
# creation) # creation)
self.assertEqual( self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1), None), db.findUndoTID(oid1, tid4, tid1, None),
(tid3, None)) (tid3, None, True))
def test_getTransactionUndoData(self):
db = self.db
db.setup(reset=True)
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
tid4 = self.getNextTID()
tid5 = self.getNextTID()
assert tid1 < tid2 < tid3 < tid4 < tid5
oid1 = self.getOID(1)
oid2 = self.getOID(2)
oid3 = self.getOID(3)
oid4 = self.getOID(4)
oid5 = self.getOID(5)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo1', None),
(oid2, 0, 0, 'foo2', None),
(oid3, 0, 0, 'foo3', None),
(oid4, 0, 0, 'foo5', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, 0, 0, 'bar1', None),
(oid2, None, None, None, None),
(oid3, 0, 0, 'bar3', None),
), None, temporary=False)
db.storeTransaction(
tid3, (
(oid3, 0, 0, 'baz3', None),
(oid5, 0, 0, 'foo6', None),
), None, temporary=False)
def getObjectFromTransaction(tid, oid):
return None
self.assertEqual(
db.getTransactionUndoData(tid4, tid2, getObjectFromTransaction),
{
oid1: (tid2, u64(tid1)), # can be undone
oid2: (tid2, u64(tid1)), # can be undone (creation redo)
oid3: (tid3, -1), # cannot be undone
# oid4 & oid5: not present because not ins undone transaction
})
# Cannot undo future transaction
self.assertRaises(ValueError, db.getTransactionUndoData, tid4, tid5,
getObjectFromTransaction)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -482,23 +482,29 @@ class ProtocolTests(NeoTestBase): ...@@ -482,23 +482,29 @@ class ProtocolTests(NeoTestBase):
p_offset = p.decode()[0] p_offset = p.decode()[0]
self.assertEqual(p_offset, offset) self.assertEqual(p_offset, offset)
def test_askUndoTransaction(self): def test_askObjectUndoSerial(self):
tid = self.getNextTID() tid = self.getNextTID()
undone_tid = self.getNextTID() undone_tid = self.getNextTID()
p = Packets.AskUndoTransaction(tid, undone_tid) oid_list = [self.getOID(x) for x in xrange(4)]
p_tid, p_undone_tid = p.decode() p = Packets.AskObjectUndoSerial(tid, undone_tid, oid_list)
self.assertEqual(p_tid, tid) ptid, pundone_tid, poid_list = p.decode()
self.assertEqual(p_undone_tid, undone_tid) self.assertEqual(tid, ptid)
self.assertEqual(undone_tid, pundone_tid)
def test_answerUndoTransaction(self): self.assertEqual(oid_list, poid_list)
oid_list_1 = [self.getNextTID()]
oid_list_2 = [self.getNextTID(), self.getNextTID()] def test_answerObjectUndoSerial(self):
oid_list_3 = [self.getNextTID(), self.getNextTID(), self.getNextTID()] oid1 = self.getNextTID()
p = Packets.AnswerUndoTransaction(oid_list_1, oid_list_2, oid_list_3) oid2 = self.getNextTID()
p_oid_list_1, p_oid_list_2, p_oid_list_3 = p.decode() tid1 = self.getNextTID()
self.assertEqual(p_oid_list_1, oid_list_1) tid2 = self.getNextTID()
self.assertEqual(p_oid_list_2, oid_list_2) tid3 = self.getNextTID()
self.assertEqual(p_oid_list_3, oid_list_3) object_tid_dict = {
oid1: (tid1, tid2, True),
oid2: (tid3, None, False),
}
p = Packets.AnswerObjectUndoSerial(object_tid_dict)
pobject_tid_dict = p.decode()[0]
self.assertEqual(object_tid_dict, pobject_tid_dict)
def test_NotifyLastOID(self): def test_NotifyLastOID(self):
oid = self.getOID(1) oid = self.getOID(1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment