Commit ba8f014f authored by Vincent Pelletier's avatar Vincent Pelletier

Pipeline "store" action on client side.

Storage.store calls can be pipelined when implementation can take advantage of
it (as Zeo does). This allows reducing the impact of (network-induced, mainly)
latency by sending all objects to storages without waiting for storage answer.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1788 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent ce162b67
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
from ZODB import BaseStorage, ConflictResolution, POSException from ZODB import BaseStorage, ConflictResolution, POSException
from neo.client.app import Application from neo.client.app import Application
from neo.client.exception import NEOStorageConflictError, \ from neo.client.exception import NEOStorageNotFoundError
NEOStorageNotFoundError
class Storage(BaseStorage.BaseStorage, class Storage(BaseStorage.BaseStorage,
ConflictResolution.ConflictResolvingStorage): ConflictResolution.ConflictResolvingStorage):
...@@ -61,7 +60,8 @@ class Storage(BaseStorage.BaseStorage, ...@@ -61,7 +60,8 @@ class Storage(BaseStorage.BaseStorage,
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
if self._is_read_only: if self._is_read_only:
raise POSException.ReadOnlyError() raise POSException.ReadOnlyError()
return self.app.tpc_vote(transaction=transaction) return self.app.tpc_vote(transaction=transaction,
tryToResolveConflict=self.tryToResolveConflict)
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
if self._is_read_only: if self._is_read_only:
...@@ -72,30 +72,11 @@ class Storage(BaseStorage.BaseStorage, ...@@ -72,30 +72,11 @@ class Storage(BaseStorage.BaseStorage,
return self.app.tpc_finish(transaction=transaction, f=f) return self.app.tpc_finish(transaction=transaction, f=f)
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
app = self.app
if self._is_read_only: if self._is_read_only:
raise POSException.ReadOnlyError() raise POSException.ReadOnlyError()
try: return self.app.store(oid=oid, serial=serial,
return app.store(oid = oid, serial = serial, data=data, version=version, transaction=transaction,
data = data, version = version, tryToResolveConflict=self.tryToResolveConflict)
transaction = transaction)
except NEOStorageConflictError:
conflict_serial = app.getConflictSerial()
tid = app.getTID()
if conflict_serial <= tid:
# Try to resolve conflict only if conflicting serial is older
# than the current transaction ID
new_data = self.tryToResolveConflict(oid,
conflict_serial,
serial, data)
if new_data is not None:
# Try again after conflict resolution
self.store(oid, conflict_serial,
new_data, version, transaction)
return ConflictResolution.ResolvedSerial
raise POSException.ConflictError(oid=oid,
serials=(tid,
serial),data=data)
def getSerial(self, oid): def getSerial(self, oid):
try: try:
...@@ -123,11 +104,8 @@ class Storage(BaseStorage.BaseStorage, ...@@ -123,11 +104,8 @@ class Storage(BaseStorage.BaseStorage,
def undo(self, transaction_id, txn): def undo(self, transaction_id, txn):
if self._is_read_only: if self._is_read_only:
raise POSException.ReadOnlyError() raise POSException.ReadOnlyError()
try: return self.app.undo(transaction_id=transaction_id, txn=txn,
return self.app.undo(transaction_id = transaction_id, tryToResolveConflict=self.tryToResolveConflict)
txn = txn, wrapper = self)
except NEOStorageConflictError:
raise POSException.ConflictError
def undoLog(self, first, last, filter): def undoLog(self, first, last, filter):
......
...@@ -23,6 +23,7 @@ from random import shuffle ...@@ -23,6 +23,7 @@ from random import shuffle
from time import sleep from time import sleep
from ZODB.POSException import UndoError, StorageTransactionError, ConflictError from ZODB.POSException import UndoError, StorageTransactionError, ConflictError
from ZODB.ConflictResolution import ResolvedSerial
from neo import setupLog from neo import setupLog
setupLog('CLIENT', verbose=True) setupLog('CLIENT', verbose=True)
...@@ -36,7 +37,7 @@ from neo.locking import Lock ...@@ -36,7 +37,7 @@ from neo.locking import Lock
from neo.connection import MTClientConnection from neo.connection import MTClientConnection
from neo.node import NodeManager from neo.node import NodeManager
from neo.connector import getConnectorHandler from neo.connector import getConnectorHandler
from neo.client.exception import NEOStorageError, NEOStorageConflictError from neo.client.exception import NEOStorageError
from neo.client.exception import NEOStorageNotFoundError, ConnectionClosed from neo.client.exception import NEOStorageNotFoundError, ConnectionClosed
from neo.exception import NeoException from neo.exception import NeoException
from neo.client.handlers import storage, master from neo.client.handlers import storage, master
...@@ -80,6 +81,9 @@ class ThreadContext(object): ...@@ -80,6 +81,9 @@ class ThreadContext(object):
'tid': None, 'tid': None,
'txn': None, 'txn': None,
'data_dict': {}, 'data_dict': {},
'object_serial_dict': {},
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'object_stored': 0, 'object_stored': 0,
'txn_voted': False, 'txn_voted': False,
'txn_finished': False, 'txn_finished': False,
...@@ -88,9 +92,7 @@ class ThreadContext(object): ...@@ -88,9 +92,7 @@ class ThreadContext(object):
'history': None, 'history': None,
'node_tids': {}, 'node_tids': {},
'node_ready': False, 'node_ready': False,
'conflict_serial': 0,
'asked_object': 0, 'asked_object': 0,
'object_stored_counter': 0,
} }
...@@ -534,7 +536,8 @@ class Application(object): ...@@ -534,7 +536,8 @@ class Application(object):
self.local_var.txn = transaction self.local_var.txn = transaction
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction,
tryToResolveConflict):
"""Store object.""" """Store object."""
if transaction is not self.local_var.txn: if transaction is not self.local_var.txn:
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
...@@ -551,49 +554,100 @@ class Application(object): ...@@ -551,49 +554,100 @@ class Application(object):
checksum = makeChecksum(compressed_data) checksum = makeChecksum(compressed_data)
p = Packets.AskStoreObject(oid, serial, 1, p = Packets.AskStoreObject(oid, serial, 1,
checksum, compressed_data, self.local_var.tid) checksum, compressed_data, self.local_var.tid)
# Store object in tmp cache
self.local_var.data_dict[oid] = data
# Store data on each node # Store data on each node
self.local_var.object_stored_counter = 0 self.local_var.object_stored_counter_dict[oid] = 0
self.local_var.object_serial_dict[oid] = (serial, version)
local_queue = self.local_var.queue
for cell in cell_list: for cell in cell_list:
conn = self.cp.getConnForCell(cell) conn = self.cp.getConnForCell(cell)
if conn is None: if conn is None:
continue continue
self.local_var.object_stored = 0
try: try:
self._askStorage(conn, p) try:
conn.ask(local_queue, p)
finally:
conn.unlock()
except ConnectionClosed: except ConnectionClosed:
continue continue
# Check we don't get any conflict self._waitAnyMessage(False)
if self.local_var.object_stored[0] == -1: return None
if self.local_var.data_dict.has_key(oid):
# One storage already accept the object, is it normal ??
# remove from dict and raise ConflictError, don't care of
# previous node which already store data as it would be
# resent again if conflict is resolved or txn will be
# aborted
del self.local_var.data_dict[oid]
self.local_var.conflict_serial = self.local_var.object_stored[1]
raise NEOStorageConflictError
# increase counter so that we know if a node has stored the object
# or not
self.local_var.object_stored_counter += 1
if self.local_var.object_stored_counter == 0:
# no storage nodes were available
raise NEOStorageError('tpc_store failed')
# Store object in tmp cache def _handleConflicts(self, tryToResolveConflict):
self.local_var.data_dict[oid] = data result = []
append = result.append
local_var = self.local_var
# Check for conflicts
data_dict = local_var.data_dict
object_serial_dict = local_var.object_serial_dict
for oid, conflict_serial in local_var.conflict_serial_dict.items():
serial, version = object_serial_dict[oid]
data = data_dict[oid]
tid = local_var.tid
resolved = False
if conflict_serial <= tid:
new_data = tryToResolveConflict(oid, conflict_serial, serial,
data)
if new_data is not None:
# Forget this conflict
del local_var.conflict_serial_dict[oid]
# Try to store again
self.store(oid, conflict_serial, new_data, version,
local_var.txn, tryToResolveConflict)
append(oid)
resolved = True
if not resolved:
# XXX: Is it really required to remove from data_dict ?
del data_dict[oid]
raise ConflictError(oid=oid,
serials=(tid, serial), data=data)
return result
return self.local_var.tid def waitStoreResponses(self, tryToResolveConflict):
result = []
append = result.append
resolved_oid_set = set()
update = resolved_oid_set.update
local_var = self.local_var
queue = self.local_var.queue
tid = local_var.tid
_waitAnyMessage = self._waitAnyMessage
_handleConflicts = self._handleConflicts
pending = self.dispatcher.pending
while True:
# Wait for all requests to be answered (or their connection to be
# dected as closed)
while pending(queue):
_waitAnyMessage()
conflicts = _handleConflicts(tryToResolveConflict)
if conflicts:
update(conflicts)
else:
# No more conflict resolutions to do, no more pending store
# requests
break
# Check for never-stored objects, and update result for all others
for oid, store_count in \
local_var.object_stored_counter_dict.iteritems():
if store_count == 0:
raise NEOStorageError('tpc_store failed')
elif oid in resolved_oid_set:
append((oid, ResolvedSerial))
else:
append((oid, tid))
return result
def tpc_vote(self, transaction): def tpc_vote(self, transaction, tryToResolveConflict):
"""Store current transaction.""" """Store current transaction."""
local_var = self.local_var local_var = self.local_var
if transaction is not local_var.txn: if transaction is not local_var.txn:
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
result = self.waitStoreResponses(tryToResolveConflict)
tid = local_var.tid tid = local_var.tid
# Store data on each node # Store data on each node
voted_counter = 0 voted_counter = 0
...@@ -626,6 +680,8 @@ class Application(object): ...@@ -626,6 +680,8 @@ class Application(object):
# tpc_finish. # tpc_finish.
self._getMasterConnection() self._getMasterConnection()
return result
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
"""Abort current transaction.""" """Abort current transaction."""
if transaction is not self.local_var.txn: if transaction is not self.local_var.txn:
...@@ -690,7 +746,7 @@ class Application(object): ...@@ -690,7 +746,7 @@ class Application(object):
finally: finally:
self._load_lock_release() self._load_lock_release()
def undo(self, transaction_id, txn, wrapper): def undo(self, transaction_id, txn, tryToResolveConflict):
if txn is not self.local_var.txn: if txn is not self.local_var.txn:
raise StorageTransactionError(self, transaction_id) raise StorageTransactionError(self, transaction_id)
...@@ -739,19 +795,9 @@ class Application(object): ...@@ -739,19 +795,9 @@ class Application(object):
# Third do transaction with old data # Third do transaction with old data
oid_list = data_dict.keys() oid_list = data_dict.keys()
for oid in oid_list: for oid in oid_list:
data = data_dict[oid] self.store(oid, transaction_id, data_dict[oid], None, txn,
try: tryToResolveConflict)
self.store(oid, transaction_id, data, None, txn) self.waitStoreResponses(tryToResolveConflict)
except NEOStorageConflictError, serial:
if serial <= self.local_var.tid:
new_data = wrapper.tryToResolveConflict(oid,
self.local_var.tid, serial, data)
if new_data is not None:
self.store(oid, self.local_var.tid, new_data, None, txn)
continue
raise ConflictError(oid = oid, serials = (self.local_var.tid,
serial),
data = data)
return self.local_var.tid, oid_list return self.local_var.tid, oid_list
def __undoLog(self, first, last, filter=None, block=0, with_oids=False): def __undoLog(self, first, last, filter=None, block=0, with_oids=False):
...@@ -930,9 +976,6 @@ class Application(object): ...@@ -930,9 +976,6 @@ class Application(object):
def getTID(self): def getTID(self):
return self.local_var.tid return self.local_var.tid
def getConflictSerial(self):
return self.local_var.conflict_serial
def setTransactionFinished(self): def setTransactionFinished(self):
self.local_var.txn_finished = True self.local_var.txn_finished = True
......
...@@ -23,8 +23,5 @@ class ConnectionClosed(Exception): ...@@ -23,8 +23,5 @@ class ConnectionClosed(Exception):
class NEOStorageError(POSException.StorageError): class NEOStorageError(POSException.StorageError):
pass pass
class NEOStorageConflictError(NEOStorageError):
pass
class NEOStorageNotFoundError(NEOStorageError): class NEOStorageNotFoundError(NEOStorageError):
pass pass
...@@ -63,10 +63,18 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -63,10 +63,18 @@ class StorageAnswersHandler(AnswerBaseHandler):
compression, checksum, data) compression, checksum, data)
def answerStoreObject(self, conn, conflicting, oid, serial): def answerStoreObject(self, conn, conflicting, oid, serial):
local_var = self.app.local_var
object_stored_counter_dict = local_var.object_stored_counter_dict
if conflicting: if conflicting:
self.app.local_var.object_stored = -1, serial assert object_stored_counter_dict[oid] == 0, \
object_stored_counter_dict[oid]
previous_conflict_serial = local_var.conflict_serial_dict.get(oid,
None)
assert previous_conflict_serial in (None, serial), \
(previous_conflict_serial, serial)
local_var.conflict_serial_dict[oid] = serial
else: else:
self.app.local_var.object_stored = oid, serial object_stored_counter_dict[oid] += 1
def answerStoreTransaction(self, conn, tid): def answerStoreTransaction(self, conn, tid):
if tid != self.app.getTID(): if tid != self.app.getTID():
......
...@@ -20,8 +20,7 @@ from mock import Mock, ReturnValues ...@@ -20,8 +20,7 @@ from mock import Mock, ReturnValues
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.client.app import Application from neo.client.app import Application
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError, \ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
NEOStorageConflictError
from neo import protocol from neo import protocol
from neo.protocol import Packets, INVALID_TID, INVALID_SERIAL from neo.protocol import Packets, INVALID_TID, INVALID_SERIAL
from neo.util import makeChecksum from neo.util import makeChecksum
...@@ -49,6 +48,12 @@ def _waitMessage(self, conn, msg_id, handler=None): ...@@ -49,6 +48,12 @@ def _waitMessage(self, conn, msg_id, handler=None):
else: else:
handler.dispatch(conn, conn.fakeReceived()) handler.dispatch(conn, conn.fakeReceived())
def resolving_tryToResolveConflict(oid, conflict_serial, serial, data):
return data
def failing_tryToResolveConflict(oid, conflict_serial, serial, data):
return None
class ClientApplicationTests(NeoTestBase): class ClientApplicationTests(NeoTestBase):
def setUp(self): def setUp(self):
...@@ -130,7 +135,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -130,7 +135,7 @@ class ClientApplicationTests(NeoTestBase):
cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', }) cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', })
app.pt = Mock({ 'getCellListForTID': (cell, cell, ) }) app.pt = Mock({ 'getCellListForTID': (cell, cell, ) })
app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn), }) app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn), })
app.tpc_vote(txn) app.tpc_vote(txn, resolving_tryToResolveConflict)
def askFinishTransaction(self, app): def askFinishTransaction(self, app):
txn = app.local_var.txn txn = app.local_var.txn
...@@ -399,14 +404,16 @@ class ClientApplicationTests(NeoTestBase): ...@@ -399,14 +404,16 @@ class ClientApplicationTests(NeoTestBase):
# invalid transaction > StorageTransactionError # invalid transaction > StorageTransactionError
app.local_var.txn = old_txn = object() app.local_var.txn = old_txn = object()
self.assertTrue(app.local_var.txn is not txn) self.assertTrue(app.local_var.txn is not txn)
self.assertRaises(StorageTransactionError, app.store, oid, tid, '', None, txn) self.assertRaises(StorageTransactionError, app.store, oid, tid, '',
None, txn, resolving_tryToResolveConflict)
self.assertEquals(app.local_var.txn, old_txn) self.assertEquals(app.local_var.txn, old_txn)
# check partition_id and an empty cell list -> NEOStorageError # check partition_id and an empty cell list -> NEOStorageError
app.local_var.txn = txn app.local_var.txn = txn
app.local_var.tid = tid app.local_var.tid = tid
app.pt = Mock({ 'getCellListForOID': (), }) app.pt = Mock({ 'getCellListForOID': (), })
app.num_partitions = 2 app.num_partitions = 2
self.assertRaises(NEOStorageError, app.store, oid, tid, '', None, txn) self.assertRaises(NEOStorageError, app.store, oid, tid, '', None,
txn, resolving_tryToResolveConflict)
calls = app.pt.mockGetNamedCalls('getCellListForOID') calls = app.pt.mockGetNamedCalls('getCellListForOID')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
self.assertEquals(calls[0].getParam(0), oid) # oid=11 self.assertEquals(calls[0].getParam(0), oid) # oid=11
...@@ -421,9 +428,10 @@ class ClientApplicationTests(NeoTestBase): ...@@ -421,9 +428,10 @@ class ClientApplicationTests(NeoTestBase):
app.local_var.tid = tid app.local_var.tid = tid
packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid) packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid)
packet.setId(0) packet.setId(0)
storage_address = ('127.0.0.1', 10020)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': packet, 'getAddress': storage_address,
}) })
cell = Mock({ cell = Mock({
'getAddress': 'FakeServer', 'getAddress': 'FakeServer',
...@@ -431,15 +439,21 @@ class ClientApplicationTests(NeoTestBase): ...@@ -431,15 +439,21 @@ class ClientApplicationTests(NeoTestBase):
}) })
app.pt = Mock({ 'getCellListForOID': (cell, cell, )}) app.pt = Mock({ 'getCellListForOID': (cell, cell, )})
app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn)}) app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn)})
app.dispatcher = Mock({}) class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
app.local_var.object_stored = (oid, tid) app.local_var.object_stored = (oid, tid)
app.local_var.data_dict[oid] = 'BEFORE' app.local_var.data_dict[oid] = 'BEFORE'
self.assertRaises(NEOStorageConflictError, app.store, oid, tid, '', None, txn) app.store(oid, tid, '', None, txn, failing_tryToResolveConflict)
app.local_var.queue.put((conn, packet))
self.assertRaises(ConflictError, app.waitStoreResponses,
failing_tryToResolveConflict)
self.assertTrue(oid not in app.local_var.data_dict) self.assertTrue(oid not in app.local_var.data_dict)
self.assertEquals(app.getConflictSerial(), tid) self.assertEquals(app.local_var.conflict_serial_dict[oid], tid)
self.assertEquals(app.local_var.object_stored, (-1, tid)) self.assertEquals(app.local_var.object_stored_counter_dict[oid], 0)
self.checkAskStoreObject(conn) self.checkAskStoreObject(conn)
self.checkDispatcherRegisterCalled(app, conn)
def test_store3(self): def test_store3(self):
app = self.getApp() app = self.getApp()
...@@ -451,9 +465,10 @@ class ClientApplicationTests(NeoTestBase): ...@@ -451,9 +465,10 @@ class ClientApplicationTests(NeoTestBase):
app.local_var.tid = tid app.local_var.tid = tid
packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid) packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid)
packet.setId(0) packet.setId(0)
storage_address = ('127.0.0.1', 10020)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': packet, 'getAddress': storage_address,
}) })
app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn, ) }) app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn, ) })
cell = Mock({ cell = Mock({
...@@ -461,15 +476,18 @@ class ClientApplicationTests(NeoTestBase): ...@@ -461,15 +476,18 @@ class ClientApplicationTests(NeoTestBase):
'getState': 'FakeState', 'getState': 'FakeState',
}) })
app.pt = Mock({ 'getCellListForOID': (cell, cell, ) }) app.pt = Mock({ 'getCellListForOID': (cell, cell, ) })
app.dispatcher = Mock({}) class Dispatcher(object):
app.conflict_serial = None # reset by hand def pending(self, queue):
app.local_var.object_stored = () return not queue.empty()
app.store(oid, tid, 'DATA', None, txn) app.dispatcher = Dispatcher()
self.assertEquals(app.local_var.object_stored, (oid, tid)) app.nm.createStorage(address=storage_address)
self.assertEquals(app.local_var.data_dict.get(oid, None), 'DATA') app.store(oid, tid, 'DATA', None, txn, resolving_tryToResolveConflict)
self.assertNotEquals(app.conflict_serial, tid)
self.checkAskStoreObject(conn) self.checkAskStoreObject(conn)
self.checkDispatcherRegisterCalled(app, conn) app.local_var.queue.put((conn, packet))
app.waitStoreResponses(resolving_tryToResolveConflict)
self.assertEquals(app.local_var.object_stored_counter_dict[oid], 1)
self.assertEquals(app.local_var.data_dict.get(oid, None), 'DATA')
self.assertFalse(oid in app.local_var.conflict_serial_dict)
def test_tpc_vote1(self): def test_tpc_vote1(self):
app = self.getApp() app = self.getApp()
...@@ -478,7 +496,8 @@ class ClientApplicationTests(NeoTestBase): ...@@ -478,7 +496,8 @@ class ClientApplicationTests(NeoTestBase):
# invalid transaction > StorageTransactionError # invalid transaction > StorageTransactionError
app.local_var.txn = old_txn = object() app.local_var.txn = old_txn = object()
self.assertTrue(app.local_var.txn is not txn) self.assertTrue(app.local_var.txn is not txn)
self.assertRaises(StorageTransactionError, app.tpc_vote, txn) self.assertRaises(StorageTransactionError, app.tpc_vote, txn,
resolving_tryToResolveConflict)
self.assertEquals(app.local_var.txn, old_txn) self.assertEquals(app.local_var.txn, old_txn)
def test_tpc_vote2(self): def test_tpc_vote2(self):
...@@ -504,7 +523,8 @@ class ClientApplicationTests(NeoTestBase): ...@@ -504,7 +523,8 @@ class ClientApplicationTests(NeoTestBase):
app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn), }) app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn), })
app.dispatcher = Mock() app.dispatcher = Mock()
app.tpc_begin(txn, tid) app.tpc_begin(txn, tid)
self.assertRaises(NEOStorageError, app.tpc_vote, txn) self.assertRaises(NEOStorageError, app.tpc_vote, txn,
resolving_tryToResolveConflict)
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
packet = calls[0].getParam(1) packet = calls[0].getParam(1)
...@@ -531,7 +551,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -531,7 +551,7 @@ class ClientApplicationTests(NeoTestBase):
app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn), }) app.cp = Mock({ 'getConnForCell': ReturnValues(None, conn), })
app.dispatcher = Mock() app.dispatcher = Mock()
app.tpc_begin(txn, tid) app.tpc_begin(txn, tid)
app.tpc_vote(txn) app.tpc_vote(txn, resolving_tryToResolveConflict)
self.checkAskStoreTransaction(conn) self.checkAskStoreTransaction(conn)
self.checkDispatcherRegisterCalled(app, conn) self.checkDispatcherRegisterCalled(app, conn)
...@@ -668,18 +688,21 @@ class ClientApplicationTests(NeoTestBase): ...@@ -668,18 +688,21 @@ class ClientApplicationTests(NeoTestBase):
app = self.getApp() app = self.getApp()
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
wrapper = Mock() marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data):
marker.append(1)
app.local_var.txn = old_txn = object() app.local_var.txn = old_txn = object()
app.master_conn = Mock() app.master_conn = Mock()
self.assertFalse(app.local_var.txn is txn) self.assertFalse(app.local_var.txn is txn)
conn = Mock() conn = Mock()
cell = Mock() cell = Mock()
self.assertRaises(StorageTransactionError, app.undo, tid, txn, wrapper) self.assertRaises(StorageTransactionError, app.undo, tid, txn,
tryToResolveConflict)
# no packet sent # no packet sent
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
# nothing done # nothing done
self.assertEquals(len(wrapper.mockGetNamedCalls('tryToResolveConflict')), 0) self.assertEquals(marker, [])
self.assertEquals(app.local_var.txn, old_txn) self.assertEquals(app.local_var.txn, old_txn)
def test_undo2(self): def test_undo2(self):
...@@ -724,13 +747,19 @@ class ClientApplicationTests(NeoTestBase): ...@@ -724,13 +747,19 @@ class ClientApplicationTests(NeoTestBase):
u4p2 = Packets.AnswerObject(oid2, tid3, tid3, 0, makeChecksum('O2V2'), 'O2V2') u4p2 = Packets.AnswerObject(oid2, tid3, tid3, 0, makeChecksum('O2V2'), 'O2V2')
u4p3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid2) u4p3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid2)
# test logic # test logic
packets = (u1p1, u1p2, u2p1, u2p2, u3p1, u3p2, u3p3, u3p1, u4p2, u4p3) packets = (u1p1, u1p2, u2p1, u2p2, u3p1, u3p2, u3p3, u4p1, u4p2, u4p3)
for i, p in enumerate(packets): for i, p in enumerate(packets):
p.setId(p) p.setId(p)
storage_address = ('127.0.0.1', 10010)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': ReturnValues(*packets), 'fakeReceived': ReturnValues(
'getAddress': ('127.0.0.1', 10010), u1p1, u1p2,
u2p1, u2p2,
u4p1, u4p2,
u3p1, u3p2,
),
'getAddress': storage_address,
}) })
cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', }) cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', })
app.pt = Mock({ app.pt = Mock({
...@@ -738,14 +767,27 @@ class ClientApplicationTests(NeoTestBase): ...@@ -738,14 +767,27 @@ class ClientApplicationTests(NeoTestBase):
'getCellListForOID': (cell, ), 'getCellListForOID': (cell, ),
}) })
app.cp = Mock({ 'getConnForCell': conn}) app.cp = Mock({ 'getConnForCell': conn})
wrapper = Mock({'tryToResolveConflict': None}) marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data):
marker.append(1)
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
txn4 = self.beginTransaction(app, tid=tid4) txn4 = self.beginTransaction(app, tid=tid4)
# all start here # all start here
self.assertRaises(UndoError, app.undo, tid1, txn4, wrapper) self.assertRaises(UndoError, app.undo, tid1, txn4,
self.assertRaises(UndoError, app.undo, tid2, txn4, wrapper) tryToResolveConflict)
self.assertRaises(ConflictError, app.undo, tid3, txn4, wrapper) self.assertRaises(UndoError, app.undo, tid2, txn4,
self.assertEquals(len(wrapper.mockGetNamedCalls('tryToResolveConflict')), 1) tryToResolveConflict)
self.assertEquals(app.undo(tid3, txn4, wrapper), (tid4, [oid2, ])) app.local_var.queue.put((conn, u4p3))
self.assertEquals(app.undo(tid3, txn4, tryToResolveConflict),
(tid4, [oid2, ]))
app.local_var.queue.put((conn, u3p3))
self.assertRaises(ConflictError, app.undo, tid3, txn4,
tryToResolveConflict)
self.assertEquals(marker, [1])
self.askFinishTransaction(app) self.askFinishTransaction(app)
def test_undoLog(self): def test_undoLog(self):
......
...@@ -86,13 +86,18 @@ class StorageAnswerHandlerTests(NeoTestBase): ...@@ -86,13 +86,18 @@ class StorageAnswerHandlerTests(NeoTestBase):
oid = self.getOID(0) oid = self.getOID(0)
tid = self.getNextTID() tid = self.getNextTID()
# conflict # conflict
self.app.local_var.object_stored = None local_var = self.app.local_var
local_var.object_stored_counter_dict = {oid: 0}
local_var.conflict_serial_dict = {}
self.handler.answerStoreObject(conn, 1, oid, tid) self.handler.answerStoreObject(conn, 1, oid, tid)
self.assertEqual(self.app.local_var.object_stored, (-1, tid)) self.assertEqual(local_var.conflict_serial_dict[oid], tid)
self.assertFalse(local_var.object_stored_counter_dict[oid], 0)
# no conflict # no conflict
self.app.local_var.object_stored = None local_var.object_stored_counter_dict = {oid: 0}
local_var.conflict_serial_dict = {}
self.handler.answerStoreObject(conn, 0, oid, tid) self.handler.answerStoreObject(conn, 0, oid, tid)
self.assertEqual(self.app.local_var.object_stored, (oid, tid)) self.assertFalse(oid in local_var.conflict_serial_dict)
self.assertEqual(local_var.object_stored_counter_dict[oid], 1)
def test_answerStoreTransaction(self): def test_answerStoreTransaction(self):
conn = self.getConnection() conn = self.getConnection()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment