Commit accd8c02 authored by Jim Fulton's avatar Jim Fulton

server side

parent 02bb6611
...@@ -89,6 +89,8 @@ class ZEOStorage: ...@@ -89,6 +89,8 @@ class ZEOStorage:
def __init__(self, server, read_only=0): def __init__(self, server, read_only=0):
self.server = server self.server = server
self.client_side_conflict_resolution = (
server.client_side_conflict_resolution)
# timeout and stats will be initialized in register() # timeout and stats will be initialized in register()
self.read_only = read_only self.read_only = read_only
# The authentication protocol may define extra methods. # The authentication protocol may define extra methods.
...@@ -334,6 +336,7 @@ class ZEOStorage: ...@@ -334,6 +336,7 @@ class ZEOStorage:
t._extension = ext t._extension = ext
self.serials = [] self.serials = []
self.conflicts = {}
self.invalidated = [] self.invalidated = []
self.txnlog = CommitLog() self.txnlog = CommitLog()
self.blob_log = [] self.blob_log = []
...@@ -413,6 +416,7 @@ class ZEOStorage: ...@@ -413,6 +416,7 @@ class ZEOStorage:
self.locked, delay = self.server.lock_storage(self, delay) self.locked, delay = self.server.lock_storage(self, delay)
if self.locked: if self.locked:
result = None
try: try:
self.log( self.log(
"Preparing to commit transaction: %d objects, %d bytes" "Preparing to commit transaction: %d objects, %d bytes"
...@@ -433,13 +437,29 @@ class ZEOStorage: ...@@ -433,13 +437,29 @@ class ZEOStorage:
oid, oldserial, data, blobfilename = self.blob_log.pop() oid, oldserial, data, blobfilename = self.blob_log.pop()
self._store(oid, oldserial, data, blobfilename) self._store(oid, oldserial, data, blobfilename)
if not self.conflicts:
try:
serials = self.storage.tpc_vote(self.transaction) serials = self.storage.tpc_vote(self.transaction)
except ConflictError as err:
if (self.client_side_conflict_resolution and
err.oid and err.serials and err.data
):
self.conflicts[err.oid] = dict(
oid=err.oid, serials=err.serials, data=err.data)
else:
raise
else:
if serials: if serials:
if not isinstance(serials[0], bytes):
serials = (oid for (oid, serial) in serials
if serial == ResolvedSerial)
self.serials.extend(serials) self.serials.extend(serials)
result = self.serials
if self.conflicts:
result = list(self.conflicts.values())
self.storage.tpc_abort(self.transaction)
self.server.unlock_storage(self)
self.locked = False
self.server.stop_waiting(self)
except Exception as err: except Exception as err:
self.storage.tpc_abort(self.transaction) self.storage.tpc_abort(self.transaction)
...@@ -457,9 +477,9 @@ class ZEOStorage: ...@@ -457,9 +477,9 @@ class ZEOStorage:
raise raise
else: else:
if delay is not None: if delay is not None:
delay.reply(self.serials) delay.reply(result)
else: else:
return self.serials return result
else: else:
return delay return delay
...@@ -559,29 +579,26 @@ class ZEOStorage: ...@@ -559,29 +579,26 @@ class ZEOStorage:
oid, serial, self.transaction) oid, serial, self.transaction)
def _store(self, oid, serial, data, blobfile=None): def _store(self, oid, serial, data, blobfile=None):
try:
if blobfile is None: if blobfile is None:
newserial = self.storage.store( self.storage.store(oid, serial, data, '', self.transaction)
oid, serial, data, '', self.transaction)
else: else:
newserial = self.storage.storeBlob( self.storage.storeBlob(
oid, serial, data, blobfile, '', self.transaction) oid, serial, data, blobfile, '', self.transaction)
except ConflictError as err:
if self.client_side_conflict_resolution and err.serials:
self.conflicts[oid] = dict(
oid=oid, serials=err.serials, data=data)
else:
raise
else:
if oid in self.conflicts:
del self.conflicts[oid]
self.serials.append(oid)
if serial != b"\0\0\0\0\0\0\0\0": if serial != b"\0\0\0\0\0\0\0\0":
self.invalidated.append(oid) self.invalidated.append(oid)
if newserial:
if isinstance(newserial, bytes):
newserial = [(oid, newserial)]
for oid, s in newserial:
if s == ResolvedSerial:
self.stats.conflicts_resolved += 1
self.log("conflict resolved oid=%s"
% oid_repr(oid), BLATHER)
self.serials.append(oid)
def _restore(self, oid, serial, data, prev_txn): def _restore(self, oid, serial, data, prev_txn):
self.storage.restore(oid, serial, data, '', prev_txn, self.storage.restore(oid, serial, data, '', prev_txn,
self.transaction) self.transaction)
...@@ -697,6 +714,7 @@ class StorageServer: ...@@ -697,6 +714,7 @@ class StorageServer:
invalidation_age=None, invalidation_age=None,
transaction_timeout=None, transaction_timeout=None,
ssl=None, ssl=None,
client_side_conflict_resolution=False,
): ):
"""StorageServer constructor. """StorageServer constructor.
...@@ -767,8 +785,15 @@ class StorageServer: ...@@ -767,8 +785,15 @@ class StorageServer:
for name, storage in storages.items(): for name, storage in storages.items():
self._setup_invq(name, storage) self._setup_invq(name, storage)
storage.registerDB(StorageServerDB(self, name)) storage.registerDB(StorageServerDB(self, name))
if client_side_conflict_resolution:
# XXX this may go away later, when storages grow
# configuration for this.
storage.tryToResolveConflict = never_resolve_conflict
self.invalidation_age = invalidation_age self.invalidation_age = invalidation_age
self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]} self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]}
self.client_side_conflict_resolution = client_side_conflict_resolution
if addr is not None:
self.acceptor = Acceptor(self, addr, ssl) self.acceptor = Acceptor(self, addr, ssl)
if isinstance(addr, tuple) and addr[0]: if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr self.addr = self.acceptor.addr
...@@ -776,6 +801,7 @@ class StorageServer: ...@@ -776,6 +801,7 @@ class StorageServer:
self.addr = addr self.addr = addr
self.loop = self.acceptor.loop self.loop = self.acceptor.loop
ZODB.event.notify(Serving(self, address=self.acceptor.addr)) ZODB.event.notify(Serving(self, address=self.acceptor.addr))
self.stats = {} self.stats = {}
self.timeouts = {} self.timeouts = {}
for name in self.storages.keys(): for name in self.storages.keys():
...@@ -1313,3 +1339,8 @@ default_cert_authenticate = 'SIGNED' ...@@ -1313,3 +1339,8 @@ default_cert_authenticate = 'SIGNED'
def ssl_config(section): def ssl_config(section):
from .sslconfig import ssl_config from .sslconfig import ssl_config
return ssl_config(section, True) return ssl_config(section, True)
def never_resolve_conflict(oid, committedSerial, oldSerial, newpickle,
committedData=b''):
raise ConflictError(oid=oid, serials=(committedSerial, oldSerial),
data=newpickle)
import unittest
import zope.testing.setupstack
from BTrees.Length import Length
from ZODB import serialize
from ZODB.DemoStorage import DemoStorage
from ZODB.utils import p64, z64, maxtid
from ZODB.broken import find_global
from .utils import StorageServer
class Var(object):
def __eq__(self, other):
self.value = other
return True
class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
def test_server_side(self):
# First, verify default conflict resolution.
server = StorageServer(self, DemoStorage())
zs = server.zs
reader = serialize.ObjectReader(
factory=lambda conn, *args: find_global(*args))
writer = serialize.ObjectWriter()
ob = Length(0)
ob._p_oid = z64
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
# Now, a cnflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns the object id, indicating that a conflict was resolved.
self.assertEqual(zs.vote(3), [ob._p_oid])
tid3 = server.unpack_result(zs.tpc_finish(3))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 2)
# Now, we'll create a server that expects the client to
# resolve conflicts:
server = StorageServer(
self, DemoStorage(), client_side_conflict_resolution=True)
zs = server.zs
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
# Now, a conflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns an object, indicating that a conflict was not resolved.
self.assertEqual(
zs.vote(3),
[dict(oid=ob._p_oid,
serials=(tid2, tid1),
data=writer.serialize(ob),
)],
)
# Now, it's up to the client to resolve the conflict. It can
# do this by making another store call. In this call, we use
# tid2 as the starting tid:
ob.change(1)
zs.storea(ob._p_oid, tid2, writer.serialize(ob), 3)
self.assertEqual(zs.vote(3), [ob._p_oid])
tid3 = server.unpack_result(zs.tpc_finish(3))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 3)
def test_suite():
return unittest.makeSuite(ClientSideConflictResolutionTests)
"""Testing helpers
"""
import ZEO.StorageServer
from ..asyncio.server import best_protocol_version
class ServerProtocol:
method = ('register', )
def __init__(self, zs,
protocol_version=best_protocol_version,
addr='test-address'):
self.calls = []
self.addr = addr
self.zs = zs
self.protocol_version = protocol_version
zs.notify_connected(self)
closed = False
def close(self):
if not self.closed:
self.closed = True
self.zs.notify_disconnected()
def call_soon_threadsafe(self, func, *args):
func(*args)
def async(self, *args):
self.calls.append(args)
class StorageServer:
"""Create a client interface to a StorageServer.
This is for testing StorageServer. It interacts with the storgr
server through its network interface, but without creating a
network connection.
"""
def __init__(self, test, storage,
protocol_version=best_protocol_version,
**kw):
self.test = test
self.storage_server = ZEO.StorageServer.StorageServer(
None, {'1': storage}, **kw)
self.zs = self.storage_server.create_client_handler()
self.protocol = ServerProtocol(self.zs,
protocol_version=protocol_version)
self.zs.register('1', kw.get('read_only', False))
def assert_calls(self, test, *argss):
if argss:
for args in argss:
test.assertEqual(self.protocol.calls.pop(0), args)
else:
test.assertEqual(self.protocol.calls, ())
def unpack_result(self, result):
"""For methods that return Result objects, unwrap the results
"""
result, callback = result.args
callback()
return result
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment