Commit c8873c33 authored by Jim Fulton's avatar Jim Fulton

Merge remote-tracking branch 'origin/asyncio' into py2

Conflicts:
	setup.py
	src/ZEO/asyncio/base.py
parents 6fbe47eb 29bf7bc0
...@@ -41,6 +41,9 @@ tests_require = ['zope.testing', 'manuel', 'random2', 'mock'] ...@@ -41,6 +41,9 @@ tests_require = ['zope.testing', 'manuel', 'random2', 'mock']
if sys.version_info[:2] < (3, ): if sys.version_info[:2] < (3, ):
install_requires.extend(('futures', 'trollius')) install_requires.extend(('futures', 'trollius'))
elif sys.version_info >= (3, 5):
install_requires.append('uvloop')
classifiers = """\ classifiers = """\
Intended Audience :: Developers Intended Audience :: Developers
License :: OSI Approved :: Zope Public License License :: OSI Approved :: Zope Public License
......
...@@ -34,6 +34,7 @@ import BTrees.OOBTree ...@@ -34,6 +34,7 @@ import BTrees.OOBTree
import zc.lockfile import zc.lockfile
import ZODB import ZODB
import ZODB.BaseStorage import ZODB.BaseStorage
import ZODB.ConflictResolution
import ZODB.interfaces import ZODB.interfaces
import zope.interface import zope.interface
import six import six
...@@ -75,7 +76,7 @@ def get_timestamp(prev_ts=None): ...@@ -75,7 +76,7 @@ def get_timestamp(prev_ts=None):
MB = 1024**2 MB = 1024**2
@zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage) @zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage)
class ClientStorage(object): class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
"""A storage class that is a network client to a remote storage. """A storage class that is a network client to a remote storage.
This is a faithful implementation of the Storage API. This is a faithful implementation of the Storage API.
...@@ -331,6 +332,7 @@ class ClientStorage(object): ...@@ -331,6 +332,7 @@ class ClientStorage(object):
The storage isn't really ready to use until after this call. The storage isn't really ready to use until after this call.
""" """
super(ClientStorage, self).registerDB(db)
self._db = db self._db = db
def is_connected(self, test=False): def is_connected(self, test=False):
...@@ -722,13 +724,33 @@ class ClientStorage(object): ...@@ -722,13 +724,33 @@ class ClientStorage(object):
""" """
tbuf = self._check_trans(txn, 'tpc_vote') tbuf = self._check_trans(txn, 'tpc_vote')
try: try:
for oid in self._call('vote', id(txn)) or ():
tbuf.serial(oid, ResolvedSerial) conflicts = True
vote_attempts = 0
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
conflicts = False
for oid in self._call('vote', id(txn)) or ():
if isinstance(oid, dict):
# Conflict, let's try to resolve it
conflicts = True
conflict = oid
oid = conflict['oid']
committed, read = conflict['serials']
data = self.tryToResolveConflict(
oid, committed, read, conflict['data'])
self._async('storea', oid, committed, data, id(txn))
tbuf.resolve(oid, data)
else:
tbuf.serial(oid, ResolvedSerial)
vote_attempts += 1
except POSException.StorageTransactionError: except POSException.StorageTransactionError:
# Hm, we got disconnected and reconnected bwtween # Hm, we got disconnected and reconnected bwtween
# _check_trans and voting. Let's chack the transaction again: # _check_trans and voting. Let's chack the transaction again:
self._check_trans(txn, 'tpc_vote') self._check_trans(txn, 'tpc_vote')
raise raise
except POSException.ConflictError as err: except POSException.ConflictError as err:
oid = getattr(err, 'oid', None) oid = getattr(err, 'oid', None)
if oid is not None: if oid is not None:
...@@ -745,8 +767,8 @@ class ClientStorage(object): ...@@ -745,8 +767,8 @@ class ClientStorage(object):
if tbuf.exception: if tbuf.exception:
raise tbuf.exception raise tbuf.exception
if tbuf.resolved: if tbuf.server_resolved or tbuf.client_resolved:
return list(tbuf.resolved) return list(tbuf.server_resolved) + list(tbuf.client_resolved)
else: else:
return None return None
......
...@@ -88,6 +88,7 @@ class ZEOStorage: ...@@ -88,6 +88,7 @@ class ZEOStorage:
def __init__(self, server, read_only=0): def __init__(self, server, read_only=0):
self.server = server self.server = server
self.client_conflict_resolution = server.client_conflict_resolution
# timeout and stats will be initialized in register() # timeout and stats will be initialized in register()
self.read_only = read_only self.read_only = read_only
# The authentication protocol may define extra methods. # The authentication protocol may define extra methods.
...@@ -333,6 +334,7 @@ class ZEOStorage: ...@@ -333,6 +334,7 @@ class ZEOStorage:
t._extension = ext t._extension = ext
self.serials = [] self.serials = []
self.conflicts = {}
self.invalidated = [] self.invalidated = []
self.txnlog = CommitLog() self.txnlog = CommitLog()
self.blob_log = [] self.blob_log = []
...@@ -412,6 +414,7 @@ class ZEOStorage: ...@@ -412,6 +414,7 @@ class ZEOStorage:
self.locked, delay = self.server.lock_storage(self, delay) self.locked, delay = self.server.lock_storage(self, delay)
if self.locked: if self.locked:
result = None
try: try:
self.log( self.log(
"Preparing to commit transaction: %d objects, %d bytes" "Preparing to commit transaction: %d objects, %d bytes"
...@@ -432,13 +435,29 @@ class ZEOStorage: ...@@ -432,13 +435,29 @@ class ZEOStorage:
oid, oldserial, data, blobfilename = self.blob_log.pop() oid, oldserial, data, blobfilename = self.blob_log.pop()
self._store(oid, oldserial, data, blobfilename) self._store(oid, oldserial, data, blobfilename)
serials = self.storage.tpc_vote(self.transaction)
if serials:
if not isinstance(serials[0], bytes):
serials = (oid for (oid, serial) in serials
if serial == ResolvedSerial)
self.serials.extend(serials) if not self.conflicts:
try:
serials = self.storage.tpc_vote(self.transaction)
except ConflictError as err:
if (self.client_conflict_resolution and
err.oid and err.serials and err.data
):
self.conflicts[err.oid] = dict(
oid=err.oid, serials=err.serials, data=err.data)
else:
raise
else:
if serials:
self.serials.extend(serials)
result = self.serials
if self.conflicts:
result = list(self.conflicts.values())
self.storage.tpc_abort(self.transaction)
self.server.unlock_storage(self)
self.locked = False
self.server.stop_waiting(self)
except Exception as err: except Exception as err:
self.storage.tpc_abort(self.transaction) self.storage.tpc_abort(self.transaction)
...@@ -456,9 +475,9 @@ class ZEOStorage: ...@@ -456,9 +475,9 @@ class ZEOStorage:
raise raise
else: else:
if delay is not None: if delay is not None:
delay.reply(self.serials) delay.reply(result)
else: else:
return self.serials return result
else: else:
return delay return delay
...@@ -558,28 +577,24 @@ class ZEOStorage: ...@@ -558,28 +577,24 @@ class ZEOStorage:
oid, serial, self.transaction) oid, serial, self.transaction)
def _store(self, oid, serial, data, blobfile=None): def _store(self, oid, serial, data, blobfile=None):
if blobfile is None: try:
newserial = self.storage.store( if blobfile is None:
oid, serial, data, '', self.transaction) self.storage.store(oid, serial, data, '', self.transaction)
else:
self.storage.storeBlob(
oid, serial, data, blobfile, '', self.transaction)
except ConflictError as err:
if self.client_conflict_resolution and err.serials:
self.conflicts[oid] = dict(
oid=oid, serials=err.serials, data=data)
else:
raise
else: else:
newserial = self.storage.storeBlob( if oid in self.conflicts:
oid, serial, data, blobfile, '', self.transaction) del self.conflicts[oid]
if serial != b"\0\0\0\0\0\0\0\0":
self.invalidated.append(oid)
if newserial:
if isinstance(newserial, bytes): if serial != b"\0\0\0\0\0\0\0\0":
newserial = [(oid, newserial)] self.invalidated.append(oid)
for oid, s in newserial:
if s == ResolvedSerial:
self.stats.conflicts_resolved += 1
self.log("conflict resolved oid=%s"
% oid_repr(oid), BLATHER)
self.serials.append(oid)
def _restore(self, oid, serial, data, prev_txn): def _restore(self, oid, serial, data, prev_txn):
self.storage.restore(oid, serial, data, '', prev_txn, self.storage.restore(oid, serial, data, '', prev_txn,
...@@ -696,6 +711,7 @@ class StorageServer: ...@@ -696,6 +711,7 @@ class StorageServer:
invalidation_age=None, invalidation_age=None,
transaction_timeout=None, transaction_timeout=None,
ssl=None, ssl=None,
client_conflict_resolution=False,
): ):
"""StorageServer constructor. """StorageServer constructor.
...@@ -766,15 +782,23 @@ class StorageServer: ...@@ -766,15 +782,23 @@ class StorageServer:
for name, storage in storages.items(): for name, storage in storages.items():
self._setup_invq(name, storage) self._setup_invq(name, storage)
storage.registerDB(StorageServerDB(self, name)) storage.registerDB(StorageServerDB(self, name))
if client_conflict_resolution:
# XXX this may go away later, when storages grow
# configuration for this.
storage.tryToResolveConflict = never_resolve_conflict
self.invalidation_age = invalidation_age self.invalidation_age = invalidation_age
self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]} self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]}
self.acceptor = Acceptor(self, addr, ssl) self.client_conflict_resolution = client_conflict_resolution
if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr if addr is not None:
else: self.acceptor = Acceptor(self, addr, ssl)
self.addr = addr if isinstance(addr, tuple) and addr[0]:
self.loop = self.acceptor.loop self.addr = self.acceptor.addr
ZODB.event.notify(Serving(self, address=self.acceptor.addr)) else:
self.addr = addr
self.loop = self.acceptor.loop
ZODB.event.notify(Serving(self, address=self.acceptor.addr))
self.stats = {} self.stats = {}
self.timeouts = {} self.timeouts = {}
for name in self.storages.keys(): for name in self.storages.keys():
...@@ -1307,3 +1331,8 @@ class Serving(ServerEvent): ...@@ -1307,3 +1331,8 @@ class Serving(ServerEvent):
class Closed(ServerEvent): class Closed(ServerEvent):
pass pass
def never_resolve_conflict(oid, committedSerial, oldSerial, newpickle,
committedData=b''):
raise ConflictError(oid=oid, serials=(committedSerial, oldSerial),
data=newpickle)
...@@ -46,7 +46,8 @@ class TransactionBuffer: ...@@ -46,7 +46,8 @@ class TransactionBuffer:
# stored are builtin types -- strings or None. # stored are builtin types -- strings or None.
self.pickler = Pickler(self.file, 1) self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1 self.pickler.fast = 1
self.resolved = set() # {oid} self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number}
self.exception = None self.exception = None
def close(self): def close(self):
...@@ -59,11 +60,17 @@ class TransactionBuffer: ...@@ -59,11 +60,17 @@ class TransactionBuffer:
# Estimate per-record cache size # Estimate per-record cache size
self.size = self.size + (data and len(data) or 0) + 31 self.size = self.size + (data and len(data) or 0) + 31
def resolve(self, oid, data):
"""Record client-resolved data
"""
self.store(oid, data)
self.client_resolved[oid] = self.count - 1
def serial(self, oid, serial): def serial(self, oid, serial):
if isinstance(serial, Exception): if isinstance(serial, Exception):
self.exception = serial # This transaction will never be committed self.exception = serial # This transaction will never be committed
elif serial == ResolvedSerial: elif serial == ResolvedSerial:
self.resolved.add(oid) self.server_resolved.add(oid)
def storeBlob(self, oid, blobfilename): def storeBlob(self, oid, blobfilename):
self.blobs.append((oid, blobfilename)) self.blobs.append((oid, blobfilename))
...@@ -71,7 +78,8 @@ class TransactionBuffer: ...@@ -71,7 +78,8 @@ class TransactionBuffer:
def __iter__(self): def __iter__(self):
self.file.seek(0) self.file.seek(0)
unpickler = Unpickler(self.file) unpickler = Unpickler(self.file)
resolved = self.resolved server_resolved = self.server_resolved
client_resolved = self.client_resolved
# Gaaaa, this is awkward. There can be entries in serials that # Gaaaa, this is awkward. There can be entries in serials that
# aren't in the buffer, because undo. Entries can be repeated # aren't in the buffer, because undo. Entries can be repeated
...@@ -81,10 +89,11 @@ class TransactionBuffer: ...@@ -81,10 +89,11 @@ class TransactionBuffer:
seen = set() seen = set()
for i in range(self.count): for i in range(self.count):
oid, data = unpickler.load() oid, data = unpickler.load()
seen.add(oid) if client_resolved.get(oid, i) == i:
yield oid, data, oid in resolved seen.add(oid)
yield oid, data, oid in server_resolved
# We may have leftover oids because undo # We may have leftover oids because undo
for oid in resolved: for oid in server_resolved:
if oid not in seen: if oid not in seen:
yield oid, None, True yield oid, None, True
# import sys
if sys.version_info >= (3, 5):
import asyncio
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...@@ -6,12 +6,16 @@ else: ...@@ -6,12 +6,16 @@ else:
import trollius as asyncio import trollius as asyncio
import logging import logging
import socket
from struct import unpack from struct import unpack
import sys
from .marshal import encoder from .marshal import encoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6
class Protocol(asyncio.Protocol): class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface """asyncio low-level ZEO base interface
""" """
...@@ -47,7 +51,15 @@ class Protocol(asyncio.Protocol): ...@@ -47,7 +51,15 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) logger.info("Connected %s", self)
if sys.version_info < (3, 6):
sock = transport.get_extra_info('socket')
if sock is not None and sock.family in INET_FAMILIES:
# See https://bugs.python.org/issue27456 :(
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
self.transport = transport self.transport = transport
paused = self.paused paused = self.paused
output = self.output output = self.output
append = output.append append = output.append
......
...@@ -98,7 +98,7 @@ class Transport(object): ...@@ -98,7 +98,7 @@ class Transport(object):
capacity = 1 << 64 capacity = 1 << 64
paused = False paused = False
extra = dict(peername='1.2.3.4', sockname=('127.0.0.1', 4200)) extra = dict(peername='1.2.3.4', sockname=('127.0.0.1', 4200), socket=None)
def __init__(self, protocol): def __init__(self, protocol):
self.data = [] self.data = []
......
...@@ -98,6 +98,9 @@ class ZEOOptionsMixin: ...@@ -98,6 +98,9 @@ class ZEOOptionsMixin:
self.add("address", "zeo.address.address", self.add("address", "zeo.address.address",
required="no server address specified; use -a or -C") required="no server address specified; use -a or -C")
self.add("read_only", "zeo.read_only", default=0) self.add("read_only", "zeo.read_only", default=0)
self.add("client_conflict_resolution",
"zeo.client_conflict_resolution",
default=0)
self.add("invalidation_queue_size", "zeo.invalidation_queue_size", self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
default=100) default=100)
self.add("invalidation_age", "zeo.invalidation_age") self.add("invalidation_age", "zeo.invalidation_age")
...@@ -339,6 +342,7 @@ def create_server(storages, options): ...@@ -339,6 +342,7 @@ def create_server(storages, options):
options.address, options.address,
storages, storages,
read_only = options.read_only, read_only = options.read_only,
client_conflict_resolution=options.client_conflict_resolution,
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout = options.transaction_timeout,
......
...@@ -107,6 +107,14 @@ ...@@ -107,6 +107,14 @@
<metadefault>$INSTANCE/var/ZEO.pid (or $clienthome/ZEO.pid)</metadefault> <metadefault>$INSTANCE/var/ZEO.pid (or $clienthome/ZEO.pid)</metadefault>
</key> </key>
<key name="client-conflict-resolution" datatype="boolean"
required="no" default="false">
<description>
Flag indicating whether the server should return conflict
errors to the client, for resolution there.
</description>
</key>
</sectiontype> </sectiontype>
</component> </component>
...@@ -30,6 +30,8 @@ class DummyDB: ...@@ -30,6 +30,8 @@ class DummyDB:
def invalidate(self, *args, **kwargs): def invalidate(self, *args, **kwargs):
pass pass
transform_record_data = untransform_record_data = lambda self, data: data
class WorkerThread(TestThread): class WorkerThread(TestThread):
# run the entire test in a thread so that the blocking call for # run the entire test in a thread so that the blocking call for
......
...@@ -59,6 +59,9 @@ class DummyDB: ...@@ -59,6 +59,9 @@ class DummyDB:
def invalidateCache(self): def invalidateCache(self):
pass pass
transform_record_data = untransform_record_data = lambda self, data: data
class CommonSetupTearDown(StorageTestBase): class CommonSetupTearDown(StorageTestBase):
"""Common boilerplate""" """Common boilerplate"""
......
...@@ -324,8 +324,8 @@ class InvalidationTests: ...@@ -324,8 +324,8 @@ class InvalidationTests:
def checkConcurrentUpdates2Storages_emulated(self): def checkConcurrentUpdates2Storages_emulated(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
storage2 = self.openClientStorage()
db1 = DB(storage1) db1 = DB(storage1)
storage2 = self.openClientStorage()
db2 = DB(storage2) db2 = DB(storage2)
cn = db1.open() cn = db1.open()
...@@ -349,8 +349,8 @@ class InvalidationTests: ...@@ -349,8 +349,8 @@ class InvalidationTests:
def checkConcurrentUpdates2Storages(self): def checkConcurrentUpdates2Storages(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
storage2 = self.openClientStorage()
db1 = DB(storage1) db1 = DB(storage1)
storage2 = self.openClientStorage()
db2 = DB(storage2) db2 = DB(storage2)
stop = threading.Event() stop = threading.Event()
......
...@@ -33,7 +33,7 @@ logger = logging.getLogger('ZEO.tests.forker') ...@@ -33,7 +33,7 @@ logger = logging.getLogger('ZEO.tests.forker')
class ZEOConfig: class ZEOConfig:
"""Class to generate ZEO configuration file. """ """Class to generate ZEO configuration file. """
def __init__(self, addr): def __init__(self, addr, **options):
if isinstance(addr, str): if isinstance(addr, str):
self.logpath = addr+'.log' self.logpath = addr+'.log'
else: else:
...@@ -42,6 +42,7 @@ class ZEOConfig: ...@@ -42,6 +42,7 @@ class ZEOConfig:
self.address = addr self.address = addr
self.read_only = None self.read_only = None
self.loglevel = 'INFO' self.loglevel = 'INFO'
self.__dict__.update(options)
def dump(self, f): def dump(self, f):
print("<zeo>", file=f) print("<zeo>", file=f)
...@@ -52,7 +53,7 @@ class ZEOConfig: ...@@ -52,7 +53,7 @@ class ZEOConfig:
for name in ( for name in (
'invalidation_queue_size', 'invalidation_age', 'invalidation_queue_size', 'invalidation_age',
'transaction_timeout', 'pid_filename', 'transaction_timeout', 'pid_filename',
'ssl_certificate', 'ssl_key', 'ssl_certificate', 'ssl_key', 'client_conflict_resolution',
): ):
v = getattr(self, name, None) v = getattr(self, name, None)
if v: if v:
...@@ -159,7 +160,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None): ...@@ -159,7 +160,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None):
# The runner thread didn't stop. If it was a process, # The runner thread didn't stop. If it was a process,
# give it some time to exit # give it some time to exit
if hasattr(thread, 'pid') and thread.pid: if hasattr(thread, 'pid') and thread.pid:
os.waitpid(thread.pid) os.waitpid(thread.pid, 0)
else: else:
# Gaaaa, force gc in hopes of maybe getting the unclosed # Gaaaa, force gc in hopes of maybe getting the unclosed
# sockets to get GCed # sockets to get GCed
......
...@@ -52,6 +52,8 @@ class FakeServer: ...@@ -52,6 +52,8 @@ class FakeServer:
def register_connection(*args): def register_connection(*args):
return None, None return None, None
client_conflict_resolution = False
class FakeConnection: class FakeConnection:
protocol_version = b'Z4' protocol_version = b'Z4'
addr = 'test' addr = 'test'
......
...@@ -143,23 +143,9 @@ class MiscZEOTests: ...@@ -143,23 +143,9 @@ class MiscZEOTests:
self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction()) self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction())
storage3.close() storage3.close()
class GenericTests( class GenericTestBase(
# Base class for all ZODB tests # Base class for all ZODB tests
StorageTestBase.StorageTestBase, StorageTestBase.StorageTestBase):
# ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage,
PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage,
MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests,
):
"""Combine tests from various origins in one class."""
shared_blob_dir = False shared_blob_dir = False
blob_cache_dir = None blob_cache_dir = None
...@@ -200,14 +186,23 @@ class GenericTests( ...@@ -200,14 +186,23 @@ class GenericTests(
stop() stop()
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
def runTest(self): class GenericTests(
try: GenericTestBase,
super(GenericTests, self).runTest()
except: # ZODB test mixin classes (in the same order as imported)
self._failed = True BasicStorage.BasicStorage,
raise PackableStorage.PackableStorage,
else: Synchronization.SynchronizedStorage,
self._failed = False MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests,
):
"""Combine tests from various origins in one class.
"""
def open(self, read_only=0): def open(self, read_only=0):
# Needed to support ReadOnlyStorage tests. Ought to be a # Needed to support ReadOnlyStorage tests. Ought to be a
...@@ -394,7 +389,16 @@ class FileStorageClientHexTests(FileStorageHexTests): ...@@ -394,7 +389,16 @@ class FileStorageClientHexTests(FileStorageHexTests):
def _wrap_client(self, client): def _wrap_client(self, client):
return ZODB.tests.hexstorage.HexStorage(client) return ZODB.tests.hexstorage.HexStorage(client)
class ClientConflictResolutionTests(
GenericTestBase,
ConflictResolution.ConflictResolvingStorage,
):
def getConfig(self):
return '<mappingstorage>\n</mappingstorage>\n'
def getZEOConfig(self):
return forker.ZEOConfig(('', 0), client_conflict_resolution=True)
class MappingStorageTests(GenericTests): class MappingStorageTests(GenericTests):
"""ZEO backed by a Mapping storage.""" """ZEO backed by a Mapping storage."""
...@@ -492,6 +496,8 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown): ...@@ -492,6 +496,8 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
self._invalidatedCache += 1 self._invalidatedCache += 1
def invalidate(*a, **k): def invalidate(*a, **k):
pass pass
transform_record_data = untransform_record_data = \
lambda self, data: data
db = DummyDB() db = DummyDB()
storage.registerDB(db) storage.registerDB(db)
...@@ -928,14 +934,14 @@ def tpc_finish_error(): ...@@ -928,14 +934,14 @@ def tpc_finish_error():
buffer, sadly, using implementation details: buffer, sadly, using implementation details:
>>> tbuf = t.data(client) >>> tbuf = t.data(client)
>>> tbuf.resolved = None >>> tbuf.client_resolved = None
tpc_finish will fail: tpc_finish will fail:
>>> client.tpc_finish(t) # doctest: +ELLIPSIS >>> client.tpc_finish(t) # doctest: +ELLIPSIS
Traceback (most recent call last): Traceback (most recent call last):
... ...
TypeError: ... AttributeError: ...
>>> client.tpc_abort(t) >>> client.tpc_abort(t)
>>> t.abort() >>> t.abort()
...@@ -1587,6 +1593,7 @@ def test_suite(): ...@@ -1587,6 +1593,7 @@ def test_suite():
"ClientDisconnected"), "ClientDisconnected"),
)), )),
)) ))
zeo.addTest(unittest.makeSuite(ClientConflictResolutionTests, 'check'))
zeo.layer = ZODB.tests.util.MininalTestLayer('testZeo-misc') zeo.layer = ZODB.tests.util.MininalTestLayer('testZeo-misc')
suite.addTest(zeo) suite.addTest(zeo)
......
import unittest
import zope.testing.setupstack
from BTrees.Length import Length
from ZODB import serialize
from ZODB.DemoStorage import DemoStorage
from ZODB.utils import p64, z64, maxtid
from ZODB.broken import find_global
import ZEO
from .utils import StorageServer
class Var(object):
def __eq__(self, other):
self.value = other
return True
class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
def test_server_side(self):
# First, verify default conflict resolution.
server = StorageServer(self, DemoStorage())
zs = server.zs
reader = serialize.ObjectReader(
factory=lambda conn, *args: find_global(*args))
writer = serialize.ObjectWriter()
ob = Length(0)
ob._p_oid = z64
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
# Now, a cnflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns the object id, indicating that a conflict was resolved.
self.assertEqual(zs.vote(3), [ob._p_oid])
tid3 = server.unpack_result(zs.tpc_finish(3))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 2)
# Now, we'll create a server that expects the client to
# resolve conflicts:
server = StorageServer(
self, DemoStorage(), client_conflict_resolution=True)
zs = server.zs
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
# Now, a conflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns an object, indicating that a conflict was not resolved.
self.assertEqual(
zs.vote(3),
[dict(oid=ob._p_oid,
serials=(tid2, tid1),
data=writer.serialize(ob),
)],
)
# Now, it's up to the client to resolve the conflict. It can
# do this by making another store call. In this call, we use
# tid2 as the starting tid:
ob.change(1)
zs.storea(ob._p_oid, tid2, writer.serialize(ob), 3)
self.assertEqual(zs.vote(3), [])
tid3 = server.unpack_result(zs.tpc_finish(3))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 3)
def test_client_side(self):
# First, traditional:
addr, stop = ZEO.server('data.fs')
db = ZEO.DB(addr)
with db.transaction() as conn:
conn.root.l = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
db.close(); stop()
# Now, do conflict resolution on the client.
addr2, stop = ZEO.server(
storage_conf='<mappingstorage>\n</mappingstorage>\n',
zeo_conf=dict(client_conflict_resolution=True),
)
db = ZEO.DB(addr2)
with db.transaction() as conn:
conn.root.l = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
self.assertEqual(conn2.root.l.value, 1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
db.close(); stop()
def test_suite():
return unittest.makeSuite(ClientSideConflictResolutionTests)
"""Testing helpers
"""
import ZEO.StorageServer
from ..asyncio.server import best_protocol_version
class ServerProtocol:
method = ('register', )
def __init__(self, zs,
protocol_version=best_protocol_version,
addr='test-address'):
self.calls = []
self.addr = addr
self.zs = zs
self.protocol_version = protocol_version
zs.notify_connected(self)
closed = False
def close(self):
if not self.closed:
self.closed = True
self.zs.notify_disconnected()
def call_soon_threadsafe(self, func, *args):
func(*args)
def async(self, *args):
self.calls.append(args)
class StorageServer:
"""Create a client interface to a StorageServer.
This is for testing StorageServer. It interacts with the storgr
server through its network interface, but without creating a
network connection.
"""
def __init__(self, test, storage,
protocol_version=best_protocol_version,
**kw):
self.test = test
self.storage_server = ZEO.StorageServer.StorageServer(
None, {'1': storage}, **kw)
self.zs = self.storage_server.create_client_handler()
self.protocol = ServerProtocol(self.zs,
protocol_version=protocol_version)
self.zs.register('1', kw.get('read_only', False))
def assert_calls(self, test, *argss):
if argss:
for args in argss:
test.assertEqual(self.protocol.calls.pop(0), args)
else:
test.assertEqual(self.protocol.calls, ())
def unpack_result(self, result):
"""For methods that return Result objects, unwrap the results
"""
result, callback = result.args
callback()
return result
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment