Commit 3c0e51bd authored by Jim Fulton's avatar Jim Fulton

Refactored storage server to support multiple client threads.

Changed ZEO undo protocol. (Undo is disabled with older clients.)
Now use one-way undoa.  Undone oids are now returned by (tpc_)vote for
ZEO. Undo no-longer gets commit lock.
parent 2e618202
...@@ -14,6 +14,10 @@ New Features ...@@ -14,6 +14,10 @@ New Features
database's undo method multiple times in the same transaction now database's undo method multiple times in the same transaction now
raises an exception. raises an exception.
- The ZEO protocol for undo has changed. The only user-visible
consequence of this is that when ZODB 3.10 ZEO servers won't support
undo for older clients.
- The storage API (IStorage) has been tightened. Now, storages should - The storage API (IStorage) has been tightened. Now, storages should
raise a StorageTransactionError when invalid transactions are passed raise a StorageTransactionError when invalid transactions are passed
to tpc_begin, tpc_vote, or tpc_finish. to tpc_begin, tpc_vote, or tpc_finish.
......
...@@ -1198,14 +1198,19 @@ class ClientStorage(object): ...@@ -1198,14 +1198,19 @@ class ClientStorage(object):
if self._cache is None: if self._cache is None:
return return
for oid, data in self._tbuf: for oid, _ in self._seriald.iteritems():
self._cache.invalidate(oid, tid, False) self._cache.invalidate(oid, tid, False)
for oid, data in self._tbuf:
# If data is None, we just invalidate. # If data is None, we just invalidate.
if data is not None: if data is not None:
s = self._seriald[oid] s = self._seriald[oid]
if s != ResolvedSerial: if s != ResolvedSerial:
assert s == tid, (s, tid) assert s == tid, (s, tid)
self._cache.store(oid, s, None, data) self._cache.store(oid, s, None, data)
else:
# object deletion
self._cache.invalidate(oid, tid, False)
if self.fshelper is not None: if self.fshelper is not None:
blobs = self._tbuf.blobs blobs = self._tbuf.blobs
...@@ -1241,10 +1246,7 @@ class ClientStorage(object): ...@@ -1241,10 +1246,7 @@ class ClientStorage(object):
""" """
self._check_trans(txn) self._check_trans(txn)
tid, oids = self._server.undo(trans_id, id(txn)) self._server.undoa(trans_id, id(txn))
for oid in oids:
self._tbuf.invalidate(oid)
return tid, oids
def undoInfo(self, first=0, last=-20, specification=None): def undoInfo(self, first=0, last=-20, specification=None):
"""Storage API: return undo information.""" """Storage API: return undo information."""
......
...@@ -272,8 +272,8 @@ class StorageServer: ...@@ -272,8 +272,8 @@ class StorageServer:
def new_oid(self): def new_oid(self):
return self.rpc.call('new_oid') return self.rpc.call('new_oid')
def undo(self, trans_id, trans): def undoa(self, trans_id, trans):
return self.rpc.call('undo', trans_id, trans) self.rpc.callAsync('undoa', trans_id, trans)
def undoLog(self, first, last): def undoLog(self, first, last):
return self.rpc.call('undoLog', first, last) return self.rpc.call('undoLog', first, last)
......
...@@ -20,6 +20,8 @@ TODO: Need some basic access control-- a declaration of the methods ...@@ -20,6 +20,8 @@ TODO: Need some basic access control-- a declaration of the methods
exported for invocation by the server. exported for invocation by the server.
""" """
from __future__ import with_statement
import asyncore import asyncore
import cPickle import cPickle
import logging import logging
...@@ -32,6 +34,7 @@ import itertools ...@@ -32,6 +34,7 @@ import itertools
import transaction import transaction
import ZODB.blob
import ZODB.serialize import ZODB.serialize
import ZODB.TimeStamp import ZODB.TimeStamp
import ZEO.zrpc.error import ZEO.zrpc.error
...@@ -40,7 +43,7 @@ import zope.interface ...@@ -40,7 +43,7 @@ import zope.interface
from ZEO.CommitLog import CommitLog from ZEO.CommitLog import CommitLog
from ZEO.monitor import StorageStats, StatsServer from ZEO.monitor import StorageStats, StatsServer
from ZEO.zrpc.server import Dispatcher from ZEO.zrpc.server import Dispatcher
from ZEO.zrpc.connection import ManagedServerConnection, Delay, MTDelay from ZEO.zrpc.connection import ManagedServerConnection, Delay, MTDelay, Result
from ZEO.zrpc.trigger import trigger from ZEO.zrpc.trigger import trigger
from ZEO.Exceptions import AuthError from ZEO.Exceptions import AuthError
...@@ -48,7 +51,7 @@ from ZODB.ConflictResolution import ResolvedSerial ...@@ -48,7 +51,7 @@ from ZODB.ConflictResolution import ResolvedSerial
from ZODB.POSException import StorageError, StorageTransactionError from ZODB.POSException import StorageError, StorageTransactionError
from ZODB.POSException import TransactionError, ReadOnlyError, ConflictError from ZODB.POSException import TransactionError, ReadOnlyError, ConflictError
from ZODB.serialize import referencesf from ZODB.serialize import referencesf
from ZODB.utils import u64, p64, oid_repr, mktemp from ZODB.utils import oid_repr, p64, u64, z64
from ZODB.loglevels import BLATHER from ZODB.loglevels import BLATHER
...@@ -87,7 +90,6 @@ class ZEOStorage: ...@@ -87,7 +90,6 @@ class ZEOStorage:
def __init__(self, server, read_only=0, auth_realm=None): def __init__(self, server, read_only=0, auth_realm=None):
self.server = server self.server = server
# timeout and stats will be initialized in register() # timeout and stats will be initialized in register()
self.timeout = None
self.stats = None self.stats = None
self.connection = None self.connection = None
self.client = None self.client = None
...@@ -95,14 +97,13 @@ class ZEOStorage: ...@@ -95,14 +97,13 @@ class ZEOStorage:
self.storage_id = "uninitialized" self.storage_id = "uninitialized"
self.transaction = None self.transaction = None
self.read_only = read_only self.read_only = read_only
self.locked = 0 self.locked = False # Don't have storage lock
self.verifying = 0 self.verifying = 0
self.store_failed = 0 self.store_failed = 0
self.log_label = _label self.log_label = _label
self.authenticated = 0 self.authenticated = 0
self.auth_realm = auth_realm self.auth_realm = auth_realm
self.blob_tempfile = None self.blob_tempfile = None
self.blob_log = []
# The authentication protocol may define extra methods. # The authentication protocol may define extra methods.
self._extensions = {} self._extensions = {}
for func in self.extensions: for func in self.extensions:
...@@ -139,24 +140,13 @@ class ZEOStorage: ...@@ -139,24 +140,13 @@ class ZEOStorage:
self.log_label = _label + "/" + label self.log_label = _label + "/" + label
def notifyDisconnected(self): def notifyDisconnected(self):
self.connection = None
# When this storage closes, we must ensure that it aborts # When this storage closes, we must ensure that it aborts
# any pending transaction. # any pending transaction.
if self.transaction is not None: if self.transaction is not None:
self.log("disconnected during transaction %s" % self.transaction) self.log("disconnected during transaction %s" % self.transaction)
if not self.locked:
# Delete (d, zeo_storage) from the _waiting list, if found.
waiting = self.storage._waiting
for i in range(len(waiting)):
d, z = waiting[i]
if z is self:
del waiting[i]
self.log("Closed connection removed from waiting list."
" Clients waiting: %d." % len(waiting))
break
if self.transaction:
self.tpc_abort(self.transaction.id) self.tpc_abort(self.transaction.id)
else: else:
self.log("disconnected") self.log("disconnected")
...@@ -176,6 +166,7 @@ class ZEOStorage: ...@@ -176,6 +166,7 @@ class ZEOStorage:
def setup_delegation(self): def setup_delegation(self):
"""Delegate several methods to the storage """Delegate several methods to the storage
""" """
# Called from register
storage = self.storage storage = self.storage
...@@ -183,9 +174,6 @@ class ZEOStorage: ...@@ -183,9 +174,6 @@ class ZEOStorage:
if not info['supportsUndo']: if not info['supportsUndo']:
self.undoLog = self.undoInfo = lambda *a,**k: () self.undoLog = self.undoInfo = lambda *a,**k: ()
def undo(*a, **k):
raise NotImplementedError
self.undo = undo
self.getTid = storage.getTid self.getTid = storage.getTid
self.load = storage.load self.load = storage.load
...@@ -268,6 +256,7 @@ class ZEOStorage: ...@@ -268,6 +256,7 @@ class ZEOStorage:
if self.storage is not None: if self.storage is not None:
self.log("duplicate register() call") self.log("duplicate register() call")
raise ValueError("duplicate register() call") raise ValueError("duplicate register() call")
storage = self.server.storages.get(storage_id) storage = self.server.storages.get(storage_id)
if storage is None: if storage is None:
self.log("unknown storage_id: %s" % storage_id) self.log("unknown storage_id: %s" % storage_id)
...@@ -280,18 +269,14 @@ class ZEOStorage: ...@@ -280,18 +269,14 @@ class ZEOStorage:
self.storage_id = storage_id self.storage_id = storage_id
self.storage = storage self.storage = storage
self.setup_delegation() self.setup_delegation()
self.timeout, self.stats = self.server.register_connection(storage_id, self.stats = self.server.register_connection(storage_id, self)
self)
def get_info(self): def get_info(self):
storage = self.storage storage = self.storage
try:
supportsUndo = storage.supportsUndo supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)()
except AttributeError: and self.connection.peer_protocol_version >= 'Z310')
supportsUndo = False
else:
supportsUndo = supportsUndo()
# Communicate the backend storage interfaces to the client # Communicate the backend storage interfaces to the client
storage_provides = zope.interface.providedBy(storage) storage_provides = zope.interface.providedBy(storage)
...@@ -419,6 +404,7 @@ class ZEOStorage: ...@@ -419,6 +404,7 @@ class ZEOStorage:
self.serials = [] self.serials = []
self.invalidated = [] self.invalidated = []
self.txnlog = CommitLog() self.txnlog = CommitLog()
self.blob_log = []
self.tid = tid self.tid = tid
self.status = status self.status = status
self.store_failed = 0 self.store_failed = 0
...@@ -437,19 +423,23 @@ class ZEOStorage: ...@@ -437,19 +423,23 @@ class ZEOStorage:
def tpc_finish(self, id): def tpc_finish(self, id):
if not self._check_tid(id): if not self._check_tid(id):
return return
assert self.locked assert self.locked, "finished called wo lock"
self.stats.commits += 1 self.stats.commits += 1
self.storage.tpc_finish(self.transaction) self.storage.tpc_finish(self.transaction, self._invalidate)
# Note that the tid is still current because we still hold the
# commit lock. We'll relinquish it in _clear_transaction.
tid = self.storage.lastTransaction() tid = self.storage.lastTransaction()
# Return the tid, for cache invalidation optimization
return Result(tid, self._clear_transaction)
def _invalidate(self, tid):
if self.invalidated: if self.invalidated:
self.server.invalidate(self, self.storage_id, tid, self.server.invalidate(self, self.storage_id, tid,
self.invalidated, self.get_size_info()) self.invalidated, self.get_size_info())
self._clear_transaction()
# Return the tid, for cache invalidation optimization
return tid
def tpc_abort(self, id): def tpc_abort(self, tid):
if not self._check_tid(id): if not self._check_tid(tid):
return return
self.stats.aborts += 1 self.stats.aborts += 1
if self.locked: if self.locked:
...@@ -458,111 +448,68 @@ class ZEOStorage: ...@@ -458,111 +448,68 @@ class ZEOStorage:
def _clear_transaction(self): def _clear_transaction(self):
# Common code at end of tpc_finish() and tpc_abort() # Common code at end of tpc_finish() and tpc_abort()
self.stats.active_txns -= 1 if self.locked:
self.server.unlock_storage(self)
self.locked = 0
self.transaction = None self.transaction = None
self.stats.active_txns -= 1
if self.txnlog is not None:
self.txnlog.close() self.txnlog.close()
self.txnlog = None
for oid, oldserial, data, blobfilename in self.blob_log:
ZODB.blob.remove_committed(blobfilename)
del self.blob_log
def vote(self, tid):
self._check_tid(tid, exc=StorageTransactionError)
return self._try_to_vote()
def _try_to_vote(self, delay=None):
if self.connection is None:
return # We're disconnected
self.locked = self.server.lock_storage(self)
if self.locked: if self.locked:
self.locked = 0
self.timeout.end(self)
self.stats.lock_time = None
self.log("Transaction released storage lock", BLATHER)
# Restart any client waiting for the storage lock.
while self.storage._waiting:
delay, zeo_storage = self.storage._waiting.pop(0)
try: try:
zeo_storage._restart(delay) self._vote()
except: except Exception:
self.log("Unexpected error handling waiting transaction", if delay is not None:
level=logging.WARNING, exc_info=True) delay.error()
zeo_storage.connection.close()
continue
if self.storage._waiting:
n = len(self.storage._waiting)
self.log("Blocked transaction restarted. "
"Clients waiting: %d" % n)
else:
self.log("Blocked transaction restarted.")
break
# The following two methods return values, so they must acquire
# the storage lock and begin the transaction before returning.
# It's a bit vile that undo can cause us to get the lock before vote.
def undo(self, trans_id, id):
self._check_tid(id, exc=StorageTransactionError)
if self.locked:
return self._undo(trans_id)
else: else:
return self._wait(lambda: self._undo(trans_id)) raise
def vote(self, id):
self._check_tid(id, exc=StorageTransactionError)
if self.locked:
return self._vote()
else: else:
return self._wait(lambda: self._vote()) if delay is not None:
delay.reply(None)
# When a delayed transaction is restarted, the dance is
# complicated. The restart occurs when one ZEOStorage instance
# finishes as a transaction and finds another instance is in the
# _waiting list.
# It might be better to have a mechanism to explicitly send
# the finishing transaction's reply before restarting the waiting
# transaction. If the restart takes a long time, the previous
# client will be blocked until it finishes.
def _wait(self, thunk):
# Wait for the storage lock to be acquired.
self._thunk = thunk
if self.tpc_transaction():
d = Delay()
self.storage._waiting.append((d, self))
self.log("Transaction blocked waiting for storage. "
"Clients waiting: %d." % len(self.storage._waiting))
return d
else: else:
self.log("Transaction acquired storage lock.", BLATHER) if delay == None:
return self._restart() self.log("(%r) queue lock: transactions waiting: %s"
% (self.storage_id, self.server.waiting(self)+1))
delay = Delay()
self.server.unlock_callback(self, delay)
return delay
def _unlock_callback(self, delay):
connection = self.connection
if connection is not None:
connection.call_from_thread(self._try_to_vote, delay)
def _vote(self):
def _restart(self, delay=None):
# Restart when the storage lock is available.
if self.txnlog.stores == 1: if self.txnlog.stores == 1:
template = "Preparing to commit transaction: %d object, %d bytes" template = "Preparing to commit transaction: %d object, %d bytes"
else: else:
template = "Preparing to commit transaction: %d objects, %d bytes" template = "Preparing to commit transaction: %d objects, %d bytes"
self.log(template % (self.txnlog.stores, self.txnlog.size()), self.log(template % (self.txnlog.stores, self.txnlog.size()),
level=BLATHER) level=BLATHER)
self.locked = 1
self.timeout.begin(self)
self.stats.lock_time = time.time()
if (self.tid is not None) or (self.status != ' '): if (self.tid is not None) or (self.status != ' '):
self.storage.tpc_begin(self.transaction, self.tid, self.status) self.storage.tpc_begin(self.transaction, self.tid, self.status)
else: else:
self.storage.tpc_begin(self.transaction) self.storage.tpc_begin(self.transaction)
try: try:
loads, loader = self.txnlog.get_loader() for op, args in self.txnlog:
for i in range(loads): if not getattr(self, op)(*args):
store = loader.load()
store_type = store[0]
store_args = store[1:]
if store_type == 'd':
do_store = self._delete
elif store_type == 's':
do_store = self._store
elif store_type == 'r':
do_store = self._restore
else:
raise ValueError('Invalid store type: %r' % store_type)
if not do_store(*store_args):
break break
# Blob support # Blob support
...@@ -575,11 +522,16 @@ class ZEOStorage: ...@@ -575,11 +522,16 @@ class ZEOStorage:
self._clear_transaction() self._clear_transaction()
raise raise
resp = self._thunk()
if delay is not None: if not self.store_failed:
delay.reply(resp) # Only call tpc_vote of no store call failed, otherwise
else: # the serialnos() call will deliver an exception that will be
return resp # handled by the client in its tpc_vote() method.
serials = self.storage.tpc_vote(self.transaction)
if serials:
self.serials.extend(serials)
self.client.serialnos(self.serials)
# The public methods of the ZEO client API do not do the real work. # The public methods of the ZEO client API do not do the real work.
# They defer work until after the storage lock has been acquired. # They defer work until after the storage lock has been acquired.
...@@ -610,14 +562,18 @@ class ZEOStorage: ...@@ -610,14 +562,18 @@ class ZEOStorage:
os.write(self.blob_tempfile[0], chunk) os.write(self.blob_tempfile[0], chunk)
def storeBlobEnd(self, oid, serial, data, id): def storeBlobEnd(self, oid, serial, data, id):
self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
fd, tempname = self.blob_tempfile fd, tempname = self.blob_tempfile
self.blob_tempfile = None self.blob_tempfile = None
os.close(fd) os.close(fd)
self.blob_log.append((oid, serial, data, tempname)) self.blob_log.append((oid, serial, data, tempname))
def storeBlobShared(self, oid, serial, data, filename, id): def storeBlobShared(self, oid, serial, data, filename, id):
# Reconstruct the full path from the filename in the OID directory self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
# Reconstruct the full path from the filename in the OID directory
if (os.path.sep in filename if (os.path.sep in filename
or not (filename.endswith('.tmp') or not (filename.endswith('.tmp')
or filename[:-1].endswith('.tmp') or filename[:-1].endswith('.tmp')
...@@ -635,6 +591,13 @@ class ZEOStorage: ...@@ -635,6 +591,13 @@ class ZEOStorage:
def sendBlob(self, oid, serial): def sendBlob(self, oid, serial):
self.client.storeBlob(oid, serial, self.storage.loadBlob(oid, serial)) self.client.storeBlob(oid, serial, self.storage.loadBlob(oid, serial))
def undo(*a, **k):
raise NotImplementedError
def undoa(self, trans_id, tid):
self._check_tid(tid, exc=StorageTransactionError)
self.txnlog.undo(trans_id)
def _delete(self, oid, serial): def _delete(self, oid, serial):
err = None err = None
try: try:
...@@ -721,6 +684,27 @@ class ZEOStorage: ...@@ -721,6 +684,27 @@ class ZEOStorage:
return err is None return err is None
def _undo(self, trans_id):
err = None
try:
tid, oids = self.storage.undo(trans_id, self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception, err:
self.store_failed = 1
if not isinstance(err, TransactionError):
# Unexpected errors are logged and passed to the client
self.log("store error: %s, %s" % sys.exc_info()[:2],
logging.ERROR, exc_info=True)
err = self._marshal_error(err)
# The exception is reported back as newserial for this oid
self.serials.append((z64, err))
else:
self.invalidated.extend(oids)
self.serials.extend((oid, ResolvedSerial) for oid in oids)
return err is None
def _marshal_error(self, error): def _marshal_error(self, error):
# Try to pickle the exception. If it can't be pickled, # Try to pickle the exception. If it can't be pickled,
# the RPC response would fail, so use something that can be pickled. # the RPC response would fail, so use something that can be pickled.
...@@ -734,23 +718,6 @@ class ZEOStorage: ...@@ -734,23 +718,6 @@ class ZEOStorage:
error = StorageServerError(msg) error = StorageServerError(msg)
return error return error
def _vote(self):
if not self.store_failed:
# Only call tpc_vote of no store call failed, otherwise
# the serialnos() call will deliver an exception that will be
# handled by the client in its tpc_vote() method.
serials = self.storage.tpc_vote(self.transaction)
if serials:
self.serials.extend(serials)
self.client.serialnos(self.serials)
return
def _undo(self, trans_id):
tid, oids = self.storage.undo(trans_id, self.transaction)
self.invalidated.extend(oids)
return tid, oids
# IStorageIteration support # IStorageIteration support
def iterator_start(self, start, stop): def iterator_start(self, start, stop):
...@@ -929,8 +896,12 @@ class StorageServer: ...@@ -929,8 +896,12 @@ class StorageServer:
for name, storage in storages.items()]) for name, storage in storages.items()])
log("%s created %s with storages: %s" % log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg)) (self.__class__.__name__, read_only and "RO" or "RW", msg))
for s in storages.values():
s._waiting = []
self._lock = threading.Lock()
self._commit_locks = {}
self._unlock_callbacks = dict((name, []) for name in storages)
self.read_only = read_only self.read_only = read_only
self.auth_protocol = auth_protocol self.auth_protocol = auth_protocol
self.auth_database = auth_database self.auth_database = auth_database
...@@ -1044,7 +1015,7 @@ class StorageServer: ...@@ -1044,7 +1015,7 @@ class StorageServer:
Returns the timeout and stats objects for the appropriate storage. Returns the timeout and stats objects for the appropriate storage.
""" """
self.connections[storage_id].append(conn) self.connections[storage_id].append(conn)
return self.timeouts[storage_id], self.stats[storage_id] return self.stats[storage_id]
def _invalidateCache(self, storage_id): def _invalidateCache(self, storage_id):
"""We need to invalidate any caches we have. """We need to invalidate any caches we have.
...@@ -1195,8 +1166,6 @@ class StorageServer: ...@@ -1195,8 +1166,6 @@ class StorageServer:
self.dispatcher.close() self.dispatcher.close()
if self.monitor is not None: if self.monitor is not None:
self.monitor.close() self.monitor.close()
for storage in self.storages.values():
storage.close()
# Force the asyncore mainloop to exit by hackery, i.e. close # Force the asyncore mainloop to exit by hackery, i.e. close
# every socket in the map. loop() will return when the map is # every socket in the map. loop() will return when the map is
# empty. # empty.
...@@ -1206,6 +1175,8 @@ class StorageServer: ...@@ -1206,6 +1175,8 @@ class StorageServer:
except: except:
pass pass
asyncore.socket_map.clear() asyncore.socket_map.clear()
for storage in self.storages.values():
storage.close()
def close_conn(self, conn): def close_conn(self, conn):
"""Internal: remove the given connection from self.connections. """Internal: remove the given connection from self.connections.
...@@ -1216,6 +1187,45 @@ class StorageServer: ...@@ -1216,6 +1187,45 @@ class StorageServer:
if conn.obj in cl: if conn.obj in cl:
cl.remove(conn.obj) cl.remove(conn.obj)
def lock_storage(self, zeostore):
storage_id = zeostore.storage_id
with self._lock:
if storage_id in self._commit_locks:
return False
self._commit_locks[storage_id] = zeostore
self.timeouts[storage_id].begin(zeostore)
self.stats[storage_id].lock_time = time.time()
return True
def unlock_storage(self, zeostore):
storage_id = zeostore.storage_id
with self._lock:
assert self._commit_locks[storage_id] is zeostore
del self._commit_locks[storage_id]
self.timeouts[storage_id].end(zeostore)
self.stats[storage_id].lock_time = None
callbacks = self._unlock_callbacks[storage_id][:]
del self._unlock_callbacks[storage_id][:]
if callbacks:
zeostore.log("(%r) unlock: transactions waiting: %s"
% (storage_id, len(callbacks)-1))
for zeostore, delay in callbacks:
try:
zeostore._unlock_callback(delay)
except (SystemExit, KeyboardInterrupt):
raise
except Exception:
logger.exception("Calling unlock callback")
def unlock_callback(self, zeostore, delay):
storage_id = zeostore.storage_id
with self._lock:
self._unlock_callbacks[storage_id].append((zeostore, delay))
def waiting(self, zeostore):
return len(self._unlock_callbacks[zeostore.storage_id])
class StubTimeoutThread: class StubTimeoutThread:
...@@ -1238,7 +1248,6 @@ class TimeoutThread(threading.Thread): ...@@ -1238,7 +1248,6 @@ class TimeoutThread(threading.Thread):
self._client = None self._client = None
self._deadline = None self._deadline = None
self._cond = threading.Condition() # Protects _client and _deadline self._cond = threading.Condition() # Protects _client and _deadline
self._trigger = trigger()
def begin(self, client): def begin(self, client):
# Called from the restart code the "main" thread, whenever the # Called from the restart code the "main" thread, whenever the
...@@ -1281,7 +1290,7 @@ class TimeoutThread(threading.Thread): ...@@ -1281,7 +1290,7 @@ class TimeoutThread(threading.Thread):
if howlong <= 0: if howlong <= 0:
client.log("Transaction timeout after %s seconds" % client.log("Transaction timeout after %s seconds" %
self._timeout) self._timeout)
self._trigger.pull_trigger(lambda: client.connection.close()) client.connection.trigger.pull_trigger(client.connection.close)
else: else:
time.sleep(howlong) time.sleep(howlong)
......
...@@ -37,14 +37,11 @@ class TransUndoStorageWithCache: ...@@ -37,14 +37,11 @@ class TransUndoStorageWithCache:
# Now start an undo transaction # Now start an undo transaction
t = Transaction() t = Transaction()
t.note('undo1') t.note('undo1')
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
tid, oids = self._storage.undo(tid, t)
# Make sure this doesn't load invalid data into the cache # Make sure this doesn't load invalid data into the cache
self._storage.load(oid, '') self._storage.load(oid, '')
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
assert len(oids) == 1 assert len(oids) == 1
......
...@@ -181,64 +181,3 @@ class CommitLockVoteTests(CommitLockTests): ...@@ -181,64 +181,3 @@ class CommitLockVoteTests(CommitLockTests):
self._finish_threads() self._finish_threads()
self._cleanup() self._cleanup()
class CommitLockUndoTests(CommitLockTests):
def _get_trans_id(self):
self._dostore()
L = self._storage.undoInfo()
return L[0]['id']
def _begin_undo(self, trans_id, txn):
rpc = self._storage._server.rpc
return rpc._deferred_call('undo', trans_id, id(txn))
def _finish_undo(self, msgid):
return self._storage._server.rpc._deferred_wait(msgid)
def checkCommitLockUndoFinish(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
self._finish_undo(msgid)
self._storage.tpc_vote(txn)
self._storage.tpc_finish(txn)
self._storage.load(oid, '')
self._finish_threads()
self._dostore()
self._cleanup()
def checkCommitLockUndoAbort(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
self._finish_undo(msgid)
self._storage.tpc_vote(txn)
self._storage.tpc_abort(txn)
self._finish_threads()
self._dostore()
self._cleanup()
def checkCommitLockUndoClose(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
self._finish_undo(msgid)
self._storage.tpc_vote(txn)
self._storage.close()
self._finish_threads()
self._cleanup()
...@@ -318,9 +318,9 @@ class InvalidationTests: ...@@ -318,9 +318,9 @@ class InvalidationTests:
# tearDown then immediately, but if other threads are still # tearDown then immediately, but if other threads are still
# running that can lead to a cascade of spurious exceptions. # running that can lead to a cascade of spurious exceptions.
for t in threads: for t in threads:
t.join(10) t.join(30)
for t in threads: for t in threads:
t.cleanup() t.cleanup(10)
def checkConcurrentUpdates2Storages_emulated(self): def checkConcurrentUpdates2Storages_emulated(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
...@@ -378,6 +378,34 @@ class InvalidationTests: ...@@ -378,6 +378,34 @@ class InvalidationTests:
db1.close() db1.close()
db2.close() db2.close()
def checkConcurrentUpdates19Storages(self):
n = 19
dbs = [DB(self.openClientStorage()) for i in range(n)]
self._storage = dbs[0].storage
stop = threading.Event()
cn = dbs[0].open()
tree = cn.root()["tree"] = OOBTree()
transaction.commit()
cn.close()
# Run threads that update the BTree
cd = {}
threads = [self.StressThread(dbs[i], stop, i, cd, i, n)
for i in range(n)]
self.go(stop, cd, *threads)
while len(set(db.lastTransaction() for db in dbs)) > 1:
_ = [db._storage.sync() for db in dbs]
cn = dbs[0].open()
tree = cn.root()["tree"]
self._check_tree(cn, tree)
self._check_threads(tree, *threads)
cn.close()
_ = [db.close() for db in dbs]
def checkConcurrentUpdates1Storage(self): def checkConcurrentUpdates1Storage(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
db1 = DB(storage1) db1 = DB(storage1)
......
...@@ -58,3 +58,7 @@ class Connection: ...@@ -58,3 +58,7 @@ class Connection:
print self.name, 'callAsync', meth, repr(args) print self.name, 'callAsync', meth, repr(args)
callAsyncNoPoll = callAsync callAsyncNoPoll = callAsync
def call_from_thread(self, *args):
if args:
args[0](*args[1:])
...@@ -25,7 +25,6 @@ from ZODB.tests import StorageTestBase, BasicStorage, \ ...@@ -25,7 +25,6 @@ from ZODB.tests import StorageTestBase, BasicStorage, \
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_unpickle from ZODB.tests.StorageTestBase import zodb_unpickle
import asyncore
import doctest import doctest
import logging import logging
import os import os
...@@ -244,7 +243,6 @@ class GenericTests( ...@@ -244,7 +243,6 @@ class GenericTests(
class FullGenericTests( class FullGenericTests(
GenericTests, GenericTests,
Cache.TransUndoStorageWithCache, Cache.TransUndoStorageWithCache,
CommitLockTests.CommitLockUndoTests,
ConflictResolution.ConflictResolvingStorage, ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage, ConflictResolution.ConflictResolvingTransUndoStorage,
PackableStorage.PackableUndoStorage, PackableStorage.PackableUndoStorage,
...@@ -727,6 +725,10 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests): ...@@ -727,6 +725,10 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests):
blob_cache_dir = 'blobs' blob_cache_dir = 'blobs'
shared_blob_dir = True shared_blob_dir = True
class FauxConn:
addr = 'x'
peer_protocol_version = ZEO.zrpc.connection.Connection.current_protocol
class StorageServerClientWrapper: class StorageServerClientWrapper:
def __init__(self): def __init__(self):
...@@ -743,8 +745,8 @@ class StorageServerWrapper: ...@@ -743,8 +745,8 @@ class StorageServerWrapper:
def __init__(self, server, storage_id): def __init__(self, server, storage_id):
self.storage_id = storage_id self.storage_id = storage_id
self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only) self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only)
self.server.notifyConnected(FauxConn())
self.server.register(storage_id, False) self.server.register(storage_id, False)
self.server._thunk = lambda : None
self.server.client = StorageServerClientWrapper() self.server.client = StorageServerClientWrapper()
def sortKey(self): def sortKey(self):
...@@ -766,8 +768,7 @@ class StorageServerWrapper: ...@@ -766,8 +768,7 @@ class StorageServerWrapper:
self.server.tpc_begin(id(transaction), '', '', {}, None, ' ') self.server.tpc_begin(id(transaction), '', '', {}, None, ' ')
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
self.server._restart() assert self.server.vote(id(transaction)) is None
self.server.vote(id(transaction))
result = self.server.client.serials[:] result = self.server.client.serials[:]
del self.server.client.serials[:] del self.server.client.serials[:]
return result return result
...@@ -775,8 +776,11 @@ class StorageServerWrapper: ...@@ -775,8 +776,11 @@ class StorageServerWrapper:
def store(self, oid, serial, data, version_ignored, transaction): def store(self, oid, serial, data, version_ignored, transaction):
self.server.storea(oid, serial, data, id(transaction)) self.server.storea(oid, serial, data, id(transaction))
def send_reply(self, *args): # Masquerade as conn
pass
def tpc_finish(self, transaction, func = lambda: None): def tpc_finish(self, transaction, func = lambda: None):
self.server.tpc_finish(id(transaction)) self.server.tpc_finish(id(transaction)).set_sender(0, self)
def multiple_storages_invalidation_queue_is_not_insane(): def multiple_storages_invalidation_queue_is_not_insane():
...@@ -849,6 +853,7 @@ Now we'll open a storage server on the data, simulating a restart: ...@@ -849,6 +853,7 @@ Now we'll open a storage server on the data, simulating a restart:
>>> fs = FileStorage('t.fs') >>> fs = FileStorage('t.fs')
>>> sv = StorageServer(('', get_port()), dict(fs=fs)) >>> sv = StorageServer(('', get_port()), dict(fs=fs))
>>> s = ZEOStorage(sv, sv.read_only) >>> s = ZEOStorage(sv, sv.read_only)
>>> s.notifyConnected(FauxConn())
>>> s.register('fs', False) >>> s.register('fs', False)
If we ask for the last transaction, we should get the last transaction If we ask for the last transaction, we should get the last transaction
...@@ -941,7 +946,7 @@ def tpc_finish_error(): ...@@ -941,7 +946,7 @@ def tpc_finish_error():
... def close(self): ... def close(self):
... print 'connection closed' ... print 'connection closed'
... trigger = property(lambda self: self) ... trigger = property(lambda self: self)
... pull_trigger = lambda self, func: func() ... pull_trigger = lambda self, func, *args: func(*args)
>>> class ConnectionManager: >>> class ConnectionManager:
... def __init__(self, addr, client, tmin, tmax): ... def __init__(self, addr, client, tmin, tmax):
...@@ -1251,6 +1256,8 @@ Invalidations could cause errors when closing client storages, ...@@ -1251,6 +1256,8 @@ Invalidations could cause errors when closing client storages,
>>> thread.join(1) >>> thread.join(1)
""" """
if sys.version_info >= (2, 6): if sys.version_info >= (2, 6):
import multiprocessing import multiprocessing
...@@ -1259,28 +1266,32 @@ if sys.version_info >= (2, 6): ...@@ -1259,28 +1266,32 @@ if sys.version_info >= (2, 6):
q.put((name, conn.root.x)) q.put((name, conn.root.x))
conn.close() conn.close()
def work_with_multiprocessing(): class MultiprocessingTests(unittest.TestCase):
"""Client storage should work with multi-processing.
>>> import StringIO def test_work_with_multiprocessing(self):
>>> sys.stdin = StringIO.StringIO() "Client storage should work with multi-processing."
>>> addr, _ = start_server()
>>> conn = ZEO.connection(addr)
>>> conn.root.x = 1
>>> transaction.commit()
>>> q = multiprocessing.Queue()
>>> processes = [multiprocessing.Process(
... target=work_with_multiprocessing_process,
... args=(i, addr, q))
... for i in range(3)]
>>> _ = [p.start() for p in processes]
>>> sorted(q.get(timeout=60) for p in processes)
[(0, 1), (1, 1), (2, 1)]
>>> _ = [p.join(30) for p in processes]
>>> conn.close()
"""
self.globs = {}
forker.setUp(self)
addr, adminaddr = self.globs['start_server']()
conn = ZEO.connection(addr)
conn.root.x = 1
transaction.commit()
q = multiprocessing.Queue()
processes = [multiprocessing.Process(
target=work_with_multiprocessing_process,
args=(i, addr, q))
for i in range(3)]
_ = [p.start() for p in processes]
self.assertEqual(sorted(q.get(timeout=300) for p in processes),
[(0, 1), (1, 1), (2, 1)])
_ = [p.join(30) for p in processes]
conn.close()
zope.testing.setupstack.tearDown(self)
else:
class MultiprocessingTests(unittest.TestCase):
pass
slow_test_classes = [ slow_test_classes = [
BlobAdaptedFileStorageTests, BlobWritableCacheTests, BlobAdaptedFileStorageTests, BlobWritableCacheTests,
...@@ -1353,6 +1364,7 @@ def test_suite(): ...@@ -1353,6 +1364,7 @@ def test_suite():
# unit test layer # unit test layer
zeo = unittest.TestSuite() zeo = unittest.TestSuite()
zeo.addTest(unittest.makeSuite(ZODB.tests.util.AAAA_Test_Runner_Hack)) zeo.addTest(unittest.makeSuite(ZODB.tests.util.AAAA_Test_Runner_Hack))
zeo.addTest(unittest.makeSuite(MultiprocessingTests))
zeo.addTest(doctest.DocTestSuite( zeo.addTest(doctest.DocTestSuite(
setUp=forker.setUp, tearDown=zope.testing.setupstack.tearDown)) setUp=forker.setUp, tearDown=zope.testing.setupstack.tearDown))
zeo.addTest(doctest.DocTestSuite(ZEO.tests.IterationTests, zeo.addTest(doctest.DocTestSuite(ZEO.tests.IterationTests,
......
...@@ -93,9 +93,9 @@ client will be restarted. It will get a conflict error, that is ...@@ -93,9 +93,9 @@ client will be restarted. It will get a conflict error, that is
handled correctly: handled correctly:
>>> zs1.tpc_abort('0') # doctest: +ELLIPSIS >>> zs1.tpc_abort('0') # doctest: +ELLIPSIS
(511/test-addr) ('1') unlock: transactions waiting: 0
2 callAsync serialnos ... 2 callAsync serialnos ...
reply 1 None reply 1 None
(511/test-addr) Blocked transaction restarted.
>>> fs.tpc_transaction() is not None >>> fs.tpc_transaction() is not None
True True
......
...@@ -55,6 +55,16 @@ class Delay: ...@@ -55,6 +55,16 @@ class Delay:
log("Error raised in delayed method", logging.ERROR, exc_info=True) log("Error raised in delayed method", logging.ERROR, exc_info=True)
self.conn.return_error(self.msgid, *exc_info[:2]) self.conn.return_error(self.msgid, *exc_info[:2])
class Result(Delay):
def __init__(self, *args):
self.args = args
def set_sender(self, msgid, conn):
reply, callback = self.args
conn.send_reply(msgid, reply, False)
callback()
class MTDelay(Delay): class MTDelay(Delay):
def __init__(self): def __init__(self):
...@@ -218,18 +228,25 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -218,18 +228,25 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# restorea, iterator_start, iterator_next, # restorea, iterator_start, iterator_next,
# iterator_record_start, iterator_record_next, # iterator_record_start, iterator_record_next,
# iterator_gc # iterator_gc
#
# Z310 -- named after the ZODB release 3.10
# New server methods:
# undoa
# Doesn't support undo for older clients.
# Undone oid info returned by vote.
# Protocol variables: # Protocol variables:
# Our preferred protocol. # Our preferred protocol.
current_protocol = "Z309" current_protocol = "Z310"
# If we're a client, an exhaustive list of the server protocols we # If we're a client, an exhaustive list of the server protocols we
# can accept. # can accept.
servers_we_can_talk_to = ["Z308", current_protocol] servers_we_can_talk_to = ["Z308", "Z309", current_protocol]
# If we're a server, an exhaustive list of the client protocols we # If we're a server, an exhaustive list of the client protocols we
# can accept. # can accept.
clients_we_can_talk_to = ["Z200", "Z201", "Z303", "Z308", current_protocol] clients_we_can_talk_to = [
"Z200", "Z201", "Z303", "Z308", "Z309", current_protocol]
# This is pretty excruciating. Details: # This is pretty excruciating. Details:
# #
......
...@@ -666,32 +666,11 @@ class Connection(ExportImport, object): ...@@ -666,32 +666,11 @@ class Connection(ExportImport, object):
self._cache.update_object_size_estimation(oid, len(p)) self._cache.update_object_size_estimation(oid, len(p))
obj._p_estimated_size = len(p) obj._p_estimated_size = len(p)
self._handle_serial(s, oid) self._handle_serial(oid, s)
def _handle_serial(self, store_return, oid=None, change=1): def _handle_serial(self, oid, serial, change=True):
"""Handle the returns from store() and tpc_vote() calls.""" if not serial:
# These calls can return different types depending on whether
# ZEO is used. ZEO uses asynchronous returns that may be
# returned in batches by the ClientStorage. ZEO1 can also
# return an exception object and expect that the Connection
# will raise the exception.
# When conflict resolution occurs, the object state held by
# the connection does not match what is written to the
# database. Invalidate the object here to guarantee that
# the new state is read the next time the object is used.
if not store_return:
return return
if isinstance(store_return, str):
assert oid is not None
self._handle_one_serial(oid, store_return, change)
else:
for oid, serial in store_return:
self._handle_one_serial(oid, serial, change)
def _handle_one_serial(self, oid, serial, change):
if not isinstance(serial, str): if not isinstance(serial, str):
raise serial raise serial
obj = self._cache.get(oid, None) obj = self._cache.get(oid, None)
...@@ -757,7 +736,9 @@ class Connection(ExportImport, object): ...@@ -757,7 +736,9 @@ class Connection(ExportImport, object):
except AttributeError: except AttributeError:
return return
s = vote(transaction) s = vote(transaction)
self._handle_serial(s) if s:
for oid, serial in s:
self._handle_serial(oid, serial)
def tpc_finish(self, transaction): def tpc_finish(self, transaction):
"""Indicate confirmation that the transaction is done.""" """Indicate confirmation that the transaction is done."""
...@@ -1171,7 +1152,7 @@ class Connection(ExportImport, object): ...@@ -1171,7 +1152,7 @@ class Connection(ExportImport, object):
s = self._storage.store(oid, serial, data, s = self._storage.store(oid, serial, data,
'', transaction) '', transaction)
self._handle_serial(s, oid, change=False) self._handle_serial(oid, s, change=False)
src.close() src.close()
def _abort_savepoint(self): def _abort_savepoint(self):
......
...@@ -158,6 +158,7 @@ class ConflictResolvingTransUndoStorage: ...@@ -158,6 +158,7 @@ class ConflictResolvingTransUndoStorage:
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
self._storage.undo(tid, t) self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
def checkUndoUnresolvable(self): def checkUndoUnresolvable(self):
...@@ -177,7 +178,5 @@ class ConflictResolvingTransUndoStorage: ...@@ -177,7 +178,5 @@ class ConflictResolvingTransUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self.assertRaises(UndoError, self._begin_undos_vote, t, tid)
self.assertRaises(UndoError, self._storage.undo,
tid, t)
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
...@@ -122,7 +122,7 @@ class RevisionStorage: ...@@ -122,7 +122,7 @@ class RevisionStorage:
tid = info[0]["id"] tid = info[0]["id"]
# Always undo the most recent txn, so the value will # Always undo the most recent txn, so the value will
# alternate between 3 and 4. # alternate between 3 and 4.
self._undo(tid, [oid], note="undo %d" % i) self._undo(tid, note="undo %d" % i)
revs.append(self._storage.load(oid, "")) revs.append(self._storage.load(oid, ""))
prev_tid = None prev_tid = None
......
...@@ -209,10 +209,12 @@ class StorageTestBase(ZODB.tests.util.TestCase): ...@@ -209,10 +209,12 @@ class StorageTestBase(ZODB.tests.util.TestCase):
t = transaction.Transaction() t = transaction.Transaction()
t.note(note or "undo") t.note(note or "undo")
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
tid, oids = self._storage.undo(tid, t) undo_result = self._storage.undo(tid, t)
self._storage.tpc_vote(t) vote_result = self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
if expected_oids is not None: if expected_oids is not None:
oids = undo_result and undo_result[1] or []
oids.extend(oid for (oid, _) in vote_result or ())
self.assertEqual(len(oids), len(expected_oids), repr(oids)) self.assertEqual(len(oids), len(expected_oids), repr(oids))
for oid in expected_oids: for oid in expected_oids:
self.assert_(oid in oids) self.assert_(oid in oids)
......
...@@ -101,12 +101,20 @@ class TransactionalUndoStorage: ...@@ -101,12 +101,20 @@ class TransactionalUndoStorage:
for rec in txn: for rec in txn:
pass pass
def _begin_undos_vote(self, t, *tids):
self._storage.tpc_begin(t)
oids = []
for tid in tids:
undo_result = self._storage.undo(tid, t)
if undo_result:
oids.extend(undo_result[1])
oids.extend(oid for (oid, _) in self._storage.tpc_vote(t) or ())
return oids
def undo(self, tid, note): def undo(self, tid, note):
t = Transaction() t = Transaction()
t.note(note) t.note(note)
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
oids = self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
return oids return oids
...@@ -152,9 +160,7 @@ class TransactionalUndoStorage: ...@@ -152,9 +160,7 @@ class TransactionalUndoStorage:
tid = info[0]['id'] tid = info[0]['id']
t = Transaction() t = Transaction()
t.note('undo1') t.note('undo1')
self._storage.tpc_begin(t) self._begin_undos_vote(t, tid)
self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
# Check that calling getTid on an uncreated object raises a KeyError # Check that calling getTid on an uncreated object raises a KeyError
# The current version of FileStorage fails this test # The current version of FileStorage fails this test
...@@ -281,14 +287,10 @@ class TransactionalUndoStorage: ...@@ -281,14 +287,10 @@ class TransactionalUndoStorage:
tid = info[0]['id'] tid = info[0]['id']
tid1 = info[1]['id'] tid1 = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid, tid1)
tid, oids = self._storage.undo(tid, t)
tid, oids1 = self._storage.undo(tid1, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
# We get the finalization stuff called an extra time: # We get the finalization stuff called an extra time:
eq(len(oids), 2) eq(len(oids), 4)
eq(len(oids1), 2)
unless(oid1 in oids) unless(oid1 in oids)
unless(oid2 in oids) unless(oid2 in oids)
data, revid1 = self._storage.load(oid1, '') data, revid1 = self._storage.load(oid1, '')
...@@ -355,9 +357,7 @@ class TransactionalUndoStorage: ...@@ -355,9 +357,7 @@ class TransactionalUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
tid, oids = self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
eq(len(oids), 1) eq(len(oids), 1)
self.failUnless(oid1 in oids) self.failUnless(oid1 in oids)
...@@ -368,7 +368,6 @@ class TransactionalUndoStorage: ...@@ -368,7 +368,6 @@ class TransactionalUndoStorage:
eq(zodb_unpickle(data), MinPO(54)) eq(zodb_unpickle(data), MinPO(54))
self._iterate() self._iterate()
def checkNotUndoable(self): def checkNotUndoable(self):
eq = self.assertEqual eq = self.assertEqual
# Set things up so we've got a transaction that can't be undone # Set things up so we've got a transaction that can't be undone
...@@ -380,10 +379,7 @@ class TransactionalUndoStorage: ...@@ -380,10 +379,7 @@ class TransactionalUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self.assertRaises(POSException.UndoError, self._begin_undos_vote, t, tid)
self.assertRaises(POSException.UndoError,
self._storage.undo,
tid, t)
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
# Now have more fun: object1 and object2 are in the same transaction, # Now have more fun: object1 and object2 are in the same transaction,
# which we'll try to undo to, but one of them has since modified in # which we'll try to undo to, but one of them has since modified in
...@@ -419,10 +415,7 @@ class TransactionalUndoStorage: ...@@ -419,10 +415,7 @@ class TransactionalUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self.assertRaises(POSException.UndoError, self._begin_undos_vote, t, tid)
self.assertRaises(POSException.UndoError,
self._storage.undo,
tid, t)
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
self._iterate() self._iterate()
...@@ -462,9 +455,7 @@ class TransactionalUndoStorage: ...@@ -462,9 +455,7 @@ class TransactionalUndoStorage:
self.assertEqual(len(info2), 2) self.assertEqual(len(info2), 2)
# And now attempt to undo the last transaction # And now attempt to undo the last transaction
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
tid, oids = self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
self.assertEqual(len(oids), 1) self.assertEqual(len(oids), 1)
self.assertEqual(oids[0], oid) self.assertEqual(oids[0], oid)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment