Commit e5653d43 authored by Jim Fulton's avatar Jim Fulton Committed by GitHub

Merge pull request #38 from zopefoundation/client-side-conflict-resolution

Client side conflict resolution
parents bff1a14a 10d7f9ff
...@@ -34,6 +34,7 @@ import BTrees.OOBTree ...@@ -34,6 +34,7 @@ import BTrees.OOBTree
import zc.lockfile import zc.lockfile
import ZODB import ZODB
import ZODB.BaseStorage import ZODB.BaseStorage
import ZODB.ConflictResolution
import ZODB.interfaces import ZODB.interfaces
import zope.interface import zope.interface
import six import six
...@@ -75,7 +76,7 @@ def get_timestamp(prev_ts=None): ...@@ -75,7 +76,7 @@ def get_timestamp(prev_ts=None):
MB = 1024**2 MB = 1024**2
@zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage) @zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage)
class ClientStorage(object): class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
"""A storage class that is a network client to a remote storage. """A storage class that is a network client to a remote storage.
This is a faithful implementation of the Storage API. This is a faithful implementation of the Storage API.
...@@ -331,6 +332,7 @@ class ClientStorage(object): ...@@ -331,6 +332,7 @@ class ClientStorage(object):
The storage isn't really ready to use until after this call. The storage isn't really ready to use until after this call.
""" """
super(ClientStorage, self).registerDB(db)
self._db = db self._db = db
def is_connected(self, test=False): def is_connected(self, test=False):
...@@ -722,13 +724,33 @@ class ClientStorage(object): ...@@ -722,13 +724,33 @@ class ClientStorage(object):
""" """
tbuf = self._check_trans(txn, 'tpc_vote') tbuf = self._check_trans(txn, 'tpc_vote')
try: try:
conflicts = True
vote_attempts = 0
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
conflicts = False
for oid in self._call('vote', id(txn)) or (): for oid in self._call('vote', id(txn)) or ():
if isinstance(oid, dict):
# Conflict, let's try to resolve it
conflicts = True
conflict = oid
oid = conflict['oid']
committed, read = conflict['serials']
data = self.tryToResolveConflict(
oid, committed, read, conflict['data'])
self._async('storea', oid, committed, data, id(txn))
tbuf.resolve(oid, data)
else:
tbuf.serial(oid, ResolvedSerial) tbuf.serial(oid, ResolvedSerial)
vote_attempts += 1
except POSException.StorageTransactionError: except POSException.StorageTransactionError:
# Hm, we got disconnected and reconnected bwtween # Hm, we got disconnected and reconnected bwtween
# _check_trans and voting. Let's chack the transaction again: # _check_trans and voting. Let's chack the transaction again:
self._check_trans(txn, 'tpc_vote') self._check_trans(txn, 'tpc_vote')
raise raise
except POSException.ConflictError as err: except POSException.ConflictError as err:
oid = getattr(err, 'oid', None) oid = getattr(err, 'oid', None)
if oid is not None: if oid is not None:
...@@ -745,8 +767,8 @@ class ClientStorage(object): ...@@ -745,8 +767,8 @@ class ClientStorage(object):
if tbuf.exception: if tbuf.exception:
raise tbuf.exception raise tbuf.exception
if tbuf.resolved: if tbuf.server_resolved or tbuf.client_resolved:
return list(tbuf.resolved) return list(tbuf.server_resolved) + list(tbuf.client_resolved)
else: else:
return None return None
......
...@@ -89,6 +89,7 @@ class ZEOStorage: ...@@ -89,6 +89,7 @@ class ZEOStorage:
def __init__(self, server, read_only=0): def __init__(self, server, read_only=0):
self.server = server self.server = server
self.client_conflict_resolution = server.client_conflict_resolution
# timeout and stats will be initialized in register() # timeout and stats will be initialized in register()
self.read_only = read_only self.read_only = read_only
# The authentication protocol may define extra methods. # The authentication protocol may define extra methods.
...@@ -334,6 +335,7 @@ class ZEOStorage: ...@@ -334,6 +335,7 @@ class ZEOStorage:
t._extension = ext t._extension = ext
self.serials = [] self.serials = []
self.conflicts = {}
self.invalidated = [] self.invalidated = []
self.txnlog = CommitLog() self.txnlog = CommitLog()
self.blob_log = [] self.blob_log = []
...@@ -413,6 +415,7 @@ class ZEOStorage: ...@@ -413,6 +415,7 @@ class ZEOStorage:
self.locked, delay = self.server.lock_storage(self, delay) self.locked, delay = self.server.lock_storage(self, delay)
if self.locked: if self.locked:
result = None
try: try:
self.log( self.log(
"Preparing to commit transaction: %d objects, %d bytes" "Preparing to commit transaction: %d objects, %d bytes"
...@@ -433,13 +436,29 @@ class ZEOStorage: ...@@ -433,13 +436,29 @@ class ZEOStorage:
oid, oldserial, data, blobfilename = self.blob_log.pop() oid, oldserial, data, blobfilename = self.blob_log.pop()
self._store(oid, oldserial, data, blobfilename) self._store(oid, oldserial, data, blobfilename)
if not self.conflicts:
try:
serials = self.storage.tpc_vote(self.transaction) serials = self.storage.tpc_vote(self.transaction)
except ConflictError as err:
if (self.client_conflict_resolution and
err.oid and err.serials and err.data
):
self.conflicts[err.oid] = dict(
oid=err.oid, serials=err.serials, data=err.data)
else:
raise
else:
if serials: if serials:
if not isinstance(serials[0], bytes):
serials = (oid for (oid, serial) in serials
if serial == ResolvedSerial)
self.serials.extend(serials) self.serials.extend(serials)
result = self.serials
if self.conflicts:
result = list(self.conflicts.values())
self.storage.tpc_abort(self.transaction)
self.server.unlock_storage(self)
self.locked = False
self.server.stop_waiting(self)
except Exception as err: except Exception as err:
self.storage.tpc_abort(self.transaction) self.storage.tpc_abort(self.transaction)
...@@ -457,9 +476,9 @@ class ZEOStorage: ...@@ -457,9 +476,9 @@ class ZEOStorage:
raise raise
else: else:
if delay is not None: if delay is not None:
delay.reply(self.serials) delay.reply(result)
else: else:
return self.serials return result
else: else:
return delay return delay
...@@ -559,29 +578,25 @@ class ZEOStorage: ...@@ -559,29 +578,25 @@ class ZEOStorage:
oid, serial, self.transaction) oid, serial, self.transaction)
def _store(self, oid, serial, data, blobfile=None): def _store(self, oid, serial, data, blobfile=None):
try:
if blobfile is None: if blobfile is None:
newserial = self.storage.store( self.storage.store(oid, serial, data, '', self.transaction)
oid, serial, data, '', self.transaction)
else: else:
newserial = self.storage.storeBlob( self.storage.storeBlob(
oid, serial, data, blobfile, '', self.transaction) oid, serial, data, blobfile, '', self.transaction)
except ConflictError as err:
if self.client_conflict_resolution and err.serials:
self.conflicts[oid] = dict(
oid=oid, serials=err.serials, data=data)
else:
raise
else:
if oid in self.conflicts:
del self.conflicts[oid]
if serial != b"\0\0\0\0\0\0\0\0": if serial != b"\0\0\0\0\0\0\0\0":
self.invalidated.append(oid) self.invalidated.append(oid)
if newserial:
if isinstance(newserial, bytes):
newserial = [(oid, newserial)]
for oid, s in newserial:
if s == ResolvedSerial:
self.stats.conflicts_resolved += 1
self.log("conflict resolved oid=%s"
% oid_repr(oid), BLATHER)
self.serials.append(oid)
def _restore(self, oid, serial, data, prev_txn): def _restore(self, oid, serial, data, prev_txn):
self.storage.restore(oid, serial, data, '', prev_txn, self.storage.restore(oid, serial, data, '', prev_txn,
self.transaction) self.transaction)
...@@ -697,6 +712,7 @@ class StorageServer: ...@@ -697,6 +712,7 @@ class StorageServer:
invalidation_age=None, invalidation_age=None,
transaction_timeout=None, transaction_timeout=None,
ssl=None, ssl=None,
client_conflict_resolution=False,
): ):
"""StorageServer constructor. """StorageServer constructor.
...@@ -767,8 +783,15 @@ class StorageServer: ...@@ -767,8 +783,15 @@ class StorageServer:
for name, storage in storages.items(): for name, storage in storages.items():
self._setup_invq(name, storage) self._setup_invq(name, storage)
storage.registerDB(StorageServerDB(self, name)) storage.registerDB(StorageServerDB(self, name))
if client_conflict_resolution:
# XXX this may go away later, when storages grow
# configuration for this.
storage.tryToResolveConflict = never_resolve_conflict
self.invalidation_age = invalidation_age self.invalidation_age = invalidation_age
self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]} self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]}
self.client_conflict_resolution = client_conflict_resolution
if addr is not None:
self.acceptor = Acceptor(self, addr, ssl) self.acceptor = Acceptor(self, addr, ssl)
if isinstance(addr, tuple) and addr[0]: if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr self.addr = self.acceptor.addr
...@@ -776,6 +799,7 @@ class StorageServer: ...@@ -776,6 +799,7 @@ class StorageServer:
self.addr = addr self.addr = addr
self.loop = self.acceptor.loop self.loop = self.acceptor.loop
ZODB.event.notify(Serving(self, address=self.acceptor.addr)) ZODB.event.notify(Serving(self, address=self.acceptor.addr))
self.stats = {} self.stats = {}
self.timeouts = {} self.timeouts = {}
for name in self.storages.keys(): for name in self.storages.keys():
...@@ -1308,3 +1332,8 @@ class Serving(ServerEvent): ...@@ -1308,3 +1332,8 @@ class Serving(ServerEvent):
class Closed(ServerEvent): class Closed(ServerEvent):
pass pass
def never_resolve_conflict(oid, committedSerial, oldSerial, newpickle,
committedData=b''):
raise ConflictError(oid=oid, serials=(committedSerial, oldSerial),
data=newpickle)
...@@ -46,7 +46,8 @@ class TransactionBuffer: ...@@ -46,7 +46,8 @@ class TransactionBuffer:
# stored are builtin types -- strings or None. # stored are builtin types -- strings or None.
self.pickler = Pickler(self.file, 1) self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1 self.pickler.fast = 1
self.resolved = set() # {oid} self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number}
self.exception = None self.exception = None
def close(self): def close(self):
...@@ -59,11 +60,17 @@ class TransactionBuffer: ...@@ -59,11 +60,17 @@ class TransactionBuffer:
# Estimate per-record cache size # Estimate per-record cache size
self.size = self.size + (data and len(data) or 0) + 31 self.size = self.size + (data and len(data) or 0) + 31
def resolve(self, oid, data):
"""Record client-resolved data
"""
self.store(oid, data)
self.client_resolved[oid] = self.count - 1
def serial(self, oid, serial): def serial(self, oid, serial):
if isinstance(serial, Exception): if isinstance(serial, Exception):
self.exception = serial # This transaction will never be committed self.exception = serial # This transaction will never be committed
elif serial == ResolvedSerial: elif serial == ResolvedSerial:
self.resolved.add(oid) self.server_resolved.add(oid)
def storeBlob(self, oid, blobfilename): def storeBlob(self, oid, blobfilename):
self.blobs.append((oid, blobfilename)) self.blobs.append((oid, blobfilename))
...@@ -71,7 +78,8 @@ class TransactionBuffer: ...@@ -71,7 +78,8 @@ class TransactionBuffer:
def __iter__(self): def __iter__(self):
self.file.seek(0) self.file.seek(0)
unpickler = Unpickler(self.file) unpickler = Unpickler(self.file)
resolved = self.resolved server_resolved = self.server_resolved
client_resolved = self.client_resolved
# Gaaaa, this is awkward. There can be entries in serials that # Gaaaa, this is awkward. There can be entries in serials that
# aren't in the buffer, because undo. Entries can be repeated # aren't in the buffer, because undo. Entries can be repeated
...@@ -81,10 +89,11 @@ class TransactionBuffer: ...@@ -81,10 +89,11 @@ class TransactionBuffer:
seen = set() seen = set()
for i in range(self.count): for i in range(self.count):
oid, data = unpickler.load() oid, data = unpickler.load()
if client_resolved.get(oid, i) == i:
seen.add(oid) seen.add(oid)
yield oid, data, oid in resolved yield oid, data, oid in server_resolved
# We may have leftover oids because undo # We may have leftover oids because undo
for oid in resolved: for oid in server_resolved:
if oid not in seen: if oid not in seen:
yield oid, None, True yield oid, None, True
...@@ -98,6 +98,9 @@ class ZEOOptionsMixin: ...@@ -98,6 +98,9 @@ class ZEOOptionsMixin:
self.add("address", "zeo.address.address", self.add("address", "zeo.address.address",
required="no server address specified; use -a or -C") required="no server address specified; use -a or -C")
self.add("read_only", "zeo.read_only", default=0) self.add("read_only", "zeo.read_only", default=0)
self.add("client_conflict_resolution",
"zeo.client_conflict_resolution",
default=0)
self.add("invalidation_queue_size", "zeo.invalidation_queue_size", self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
default=100) default=100)
self.add("invalidation_age", "zeo.invalidation_age") self.add("invalidation_age", "zeo.invalidation_age")
...@@ -339,6 +342,7 @@ def create_server(storages, options): ...@@ -339,6 +342,7 @@ def create_server(storages, options):
options.address, options.address,
storages, storages,
read_only = options.read_only, read_only = options.read_only,
client_conflict_resolution=options.client_conflict_resolution,
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout = options.transaction_timeout,
......
...@@ -107,6 +107,14 @@ ...@@ -107,6 +107,14 @@
<metadefault>$INSTANCE/var/ZEO.pid (or $clienthome/ZEO.pid)</metadefault> <metadefault>$INSTANCE/var/ZEO.pid (or $clienthome/ZEO.pid)</metadefault>
</key> </key>
<key name="client-conflict-resolution" datatype="boolean"
required="no" default="false">
<description>
Flag indicating whether the server should return conflict
errors to the client, for resolution there.
</description>
</key>
</sectiontype> </sectiontype>
</component> </component>
...@@ -30,6 +30,8 @@ class DummyDB: ...@@ -30,6 +30,8 @@ class DummyDB:
def invalidate(self, *args, **kwargs): def invalidate(self, *args, **kwargs):
pass pass
transform_record_data = untransform_record_data = lambda self, data: data
class WorkerThread(TestThread): class WorkerThread(TestThread):
# run the entire test in a thread so that the blocking call for # run the entire test in a thread so that the blocking call for
......
...@@ -59,6 +59,9 @@ class DummyDB: ...@@ -59,6 +59,9 @@ class DummyDB:
def invalidateCache(self): def invalidateCache(self):
pass pass
transform_record_data = untransform_record_data = lambda self, data: data
class CommonSetupTearDown(StorageTestBase): class CommonSetupTearDown(StorageTestBase):
"""Common boilerplate""" """Common boilerplate"""
......
...@@ -33,7 +33,7 @@ logger = logging.getLogger('ZEO.tests.forker') ...@@ -33,7 +33,7 @@ logger = logging.getLogger('ZEO.tests.forker')
class ZEOConfig: class ZEOConfig:
"""Class to generate ZEO configuration file. """ """Class to generate ZEO configuration file. """
def __init__(self, addr): def __init__(self, addr, **options):
if isinstance(addr, str): if isinstance(addr, str):
self.logpath = addr+'.log' self.logpath = addr+'.log'
else: else:
...@@ -42,6 +42,7 @@ class ZEOConfig: ...@@ -42,6 +42,7 @@ class ZEOConfig:
self.address = addr self.address = addr
self.read_only = None self.read_only = None
self.loglevel = 'INFO' self.loglevel = 'INFO'
self.__dict__.update(options)
def dump(self, f): def dump(self, f):
print("<zeo>", file=f) print("<zeo>", file=f)
...@@ -52,7 +53,7 @@ class ZEOConfig: ...@@ -52,7 +53,7 @@ class ZEOConfig:
for name in ( for name in (
'invalidation_queue_size', 'invalidation_age', 'invalidation_queue_size', 'invalidation_age',
'transaction_timeout', 'pid_filename', 'transaction_timeout', 'pid_filename',
'ssl_certificate', 'ssl_key', 'ssl_certificate', 'ssl_key', 'client_conflict_resolution',
): ):
v = getattr(self, name, None) v = getattr(self, name, None)
if v: if v:
...@@ -159,7 +160,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None): ...@@ -159,7 +160,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None):
# The runner thread didn't stop. If it was a process, # The runner thread didn't stop. If it was a process,
# give it some time to exit # give it some time to exit
if hasattr(thread, 'pid') and thread.pid: if hasattr(thread, 'pid') and thread.pid:
os.waitpid(thread.pid) os.waitpid(thread.pid, 0)
else: else:
# Gaaaa, force gc in hopes of maybe getting the unclosed # Gaaaa, force gc in hopes of maybe getting the unclosed
# sockets to get GCed # sockets to get GCed
......
...@@ -52,6 +52,8 @@ class FakeServer: ...@@ -52,6 +52,8 @@ class FakeServer:
def register_connection(*args): def register_connection(*args):
return None, None return None, None
client_conflict_resolution = False
class FakeConnection: class FakeConnection:
protocol_version = b'Z4' protocol_version = b'Z4'
addr = 'test' addr = 'test'
......
...@@ -143,23 +143,9 @@ class MiscZEOTests: ...@@ -143,23 +143,9 @@ class MiscZEOTests:
self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction()) self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction())
storage3.close() storage3.close()
class GenericTests( class GenericTestBase(
# Base class for all ZODB tests # Base class for all ZODB tests
StorageTestBase.StorageTestBase, StorageTestBase.StorageTestBase):
# ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage,
PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage,
MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests,
):
"""Combine tests from various origins in one class."""
shared_blob_dir = False shared_blob_dir = False
blob_cache_dir = None blob_cache_dir = None
...@@ -200,14 +186,23 @@ class GenericTests( ...@@ -200,14 +186,23 @@ class GenericTests(
stop() stop()
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
def runTest(self): class GenericTests(
try: GenericTestBase,
super(GenericTests, self).runTest()
except: # ZODB test mixin classes (in the same order as imported)
self._failed = True BasicStorage.BasicStorage,
raise PackableStorage.PackableStorage,
else: Synchronization.SynchronizedStorage,
self._failed = False MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests,
):
"""Combine tests from various origins in one class.
"""
def open(self, read_only=0): def open(self, read_only=0):
# Needed to support ReadOnlyStorage tests. Ought to be a # Needed to support ReadOnlyStorage tests. Ought to be a
...@@ -394,7 +389,16 @@ class FileStorageClientHexTests(FileStorageHexTests): ...@@ -394,7 +389,16 @@ class FileStorageClientHexTests(FileStorageHexTests):
def _wrap_client(self, client): def _wrap_client(self, client):
return ZODB.tests.hexstorage.HexStorage(client) return ZODB.tests.hexstorage.HexStorage(client)
class ClientConflictResolutionTests(
GenericTestBase,
ConflictResolution.ConflictResolvingStorage,
):
def getConfig(self):
return '<mappingstorage>\n</mappingstorage>\n'
def getZEOConfig(self):
return forker.ZEOConfig(('', 0), client_conflict_resolution=True)
class MappingStorageTests(GenericTests): class MappingStorageTests(GenericTests):
"""ZEO backed by a Mapping storage.""" """ZEO backed by a Mapping storage."""
...@@ -492,6 +496,8 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown): ...@@ -492,6 +496,8 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
self._invalidatedCache += 1 self._invalidatedCache += 1
def invalidate(*a, **k): def invalidate(*a, **k):
pass pass
transform_record_data = untransform_record_data = \
lambda self, data: data
db = DummyDB() db = DummyDB()
storage.registerDB(db) storage.registerDB(db)
...@@ -936,14 +942,14 @@ def tpc_finish_error(): ...@@ -936,14 +942,14 @@ def tpc_finish_error():
buffer, sadly, using implementation details: buffer, sadly, using implementation details:
>>> tbuf = t.data(client) >>> tbuf = t.data(client)
>>> tbuf.resolved = None >>> tbuf.client_resolved = None
tpc_finish will fail: tpc_finish will fail:
>>> client.tpc_finish(t) # doctest: +ELLIPSIS >>> client.tpc_finish(t) # doctest: +ELLIPSIS
Traceback (most recent call last): Traceback (most recent call last):
... ...
TypeError: ... AttributeError: ...
>>> client.tpc_abort(t) >>> client.tpc_abort(t)
>>> t.abort() >>> t.abort()
...@@ -1595,6 +1601,7 @@ def test_suite(): ...@@ -1595,6 +1601,7 @@ def test_suite():
"ClientDisconnected"), "ClientDisconnected"),
)), )),
)) ))
zeo.addTest(unittest.makeSuite(ClientConflictResolutionTests, 'check'))
zeo.layer = ZODB.tests.util.MininalTestLayer('testZeo-misc') zeo.layer = ZODB.tests.util.MininalTestLayer('testZeo-misc')
suite.addTest(zeo) suite.addTest(zeo)
......
import unittest
import zope.testing.setupstack
from BTrees.Length import Length
from ZODB import serialize
from ZODB.DemoStorage import DemoStorage
from ZODB.utils import p64, z64, maxtid
from ZODB.broken import find_global
import ZEO
from .utils import StorageServer
class Var(object):
def __eq__(self, other):
self.value = other
return True
class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
def test_server_side(self):
# First, verify default conflict resolution.
server = StorageServer(self, DemoStorage())
zs = server.zs
reader = serialize.ObjectReader(
factory=lambda conn, *args: find_global(*args))
writer = serialize.ObjectWriter()
ob = Length(0)
ob._p_oid = z64
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
# Now, a cnflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns the object id, indicating that a conflict was resolved.
self.assertEqual(zs.vote(3), [ob._p_oid])
tid3 = server.unpack_result(zs.tpc_finish(3))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 2)
# Now, we'll create a server that expects the client to
# resolve conflicts:
server = StorageServer(
self, DemoStorage(), client_conflict_resolution=True)
zs = server.zs
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
# Now, a conflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns an object, indicating that a conflict was not resolved.
self.assertEqual(
zs.vote(3),
[dict(oid=ob._p_oid,
serials=(tid2, tid1),
data=writer.serialize(ob),
)],
)
# Now, it's up to the client to resolve the conflict. It can
# do this by making another store call. In this call, we use
# tid2 as the starting tid:
ob.change(1)
zs.storea(ob._p_oid, tid2, writer.serialize(ob), 3)
self.assertEqual(zs.vote(3), [])
tid3 = server.unpack_result(zs.tpc_finish(3))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 3)
def test_client_side(self):
# First, traditional:
addr, stop = ZEO.server('data.fs')
db = ZEO.DB(addr)
with db.transaction() as conn:
conn.root.l = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
db.close(); stop()
# Now, do conflict resolution on the client.
addr2, stop = ZEO.server(
storage_conf='<mappingstorage>\n</mappingstorage>\n',
zeo_conf=dict(client_conflict_resolution=True),
)
db = ZEO.DB(addr2)
with db.transaction() as conn:
conn.root.l = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
self.assertEqual(conn2.root.l.value, 1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
db.close(); stop()
def test_suite():
return unittest.makeSuite(ClientSideConflictResolutionTests)
"""Testing helpers
"""
import ZEO.StorageServer
from ..asyncio.server import best_protocol_version
class ServerProtocol:
method = ('register', )
def __init__(self, zs,
protocol_version=best_protocol_version,
addr='test-address'):
self.calls = []
self.addr = addr
self.zs = zs
self.protocol_version = protocol_version
zs.notify_connected(self)
closed = False
def close(self):
if not self.closed:
self.closed = True
self.zs.notify_disconnected()
def call_soon_threadsafe(self, func, *args):
func(*args)
def async(self, *args):
self.calls.append(args)
class StorageServer:
"""Create a client interface to a StorageServer.
This is for testing StorageServer. It interacts with the storgr
server through its network interface, but without creating a
network connection.
"""
def __init__(self, test, storage,
protocol_version=best_protocol_version,
**kw):
self.test = test
self.storage_server = ZEO.StorageServer.StorageServer(
None, {'1': storage}, **kw)
self.zs = self.storage_server.create_client_handler()
self.protocol = ServerProtocol(self.zs,
protocol_version=protocol_version)
self.zs.register('1', kw.get('read_only', False))
def assert_calls(self, test, *argss):
if argss:
for args in argss:
test.assertEqual(self.protocol.calls.pop(0), args)
else:
test.assertEqual(self.protocol.calls, ())
def unpack_result(self, result):
"""For methods that return Result objects, unwrap the results
"""
result, callback = result.args
callback()
return result
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment