Commit 61cbc091 authored by Jeremy Hylton

Backport atomic invalidations code from Zope3.

The DB's invalidate() method takes a set of oids corresponding to all
the changes from a data manager for one transaction.  All the objects
are invalidated at once.

Add a few tests in testZODB of the new code.  The tests just cover
corner cases, because I can't think of a sensible way to test the
atomicity.  When it has failed in the past, it's been caused by
nearly-impossible-to-reproduce data races.

This fix needs to be backported to Zope 2.6, but only after assessing
how significant an impact the API change will have.
parent 38dddc5d
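
To make the new contract concrete before reading the diff, here is a minimal sketch of the invalidation plumbing the Connection.py changes below introduce. The class name and the `cache` parameter are illustrative stand-ins, not the real ZODB API; the real Connection uses its own `_cache` attribute.

import threading

class ConnectionSketch:
    # Hypothetical stand-in for ZODB's Connection; models only the
    # invalidation plumbing added by this commit.

    def __init__(self, cache):
        self._cache = cache
        self._inv_lock = threading.Lock()
        self._invalidated = {}  # oid -> 1, filled by DB.invalidate()

    def invalidate(self, oids):
        # The DB delivers *all* oids modified by one transaction in a
        # single call, so readers never observe a partial set.
        self._inv_lock.acquire()
        try:
            self._invalidated.update(oids)
        finally:
            self._inv_lock.release()

    def _flush_invalidations(self):
        # Applied at transaction boundaries; holding the lock keeps a
        # concurrent update() from interleaving with the flush.
        self._inv_lock.acquire()
        try:
            self._cache.invalidate(self._invalidated)
            self._invalidated.clear()
        finally:
            self._inv_lock.release()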
@@ -90,7 +90,7 @@ process must skip such objects, rather than deactivating them.
 static char cPickleCache_doc_string[] =
 "Defines the PickleCache used by ZODB Connection objects.\n"
 "\n"
-"$Id: cPickleCache.c,v 1.80 2003/04/02 16:50:49 jeremy Exp $\n";
+"$Id: cPickleCache.c,v 1.81 2003/04/08 15:55:44 jeremy Exp $\n";

 #define ASSIGN(V,E) {PyObject *__e; __e=(E); Py_XDECREF(V); (V)=__e;}
 #define UNLESS(E) if(!(E))
@@ -352,6 +352,7 @@ cc_invalidate(ccobject *self, PyObject *args)
             _invalidate(self, key);
             Py_DECREF(key);
         }
+        /* XXX Do we really want to modify the input? */
         PySequence_DelSlice(inv, 0, l);
     }
 }
@@ -13,10 +13,9 @@
 ##############################################################################
 """Database connection support

-$Id: Connection.py,v 1.87 2003/03/04 21:00:23 jeremy Exp $"""
+$Id: Connection.py,v 1.88 2003/04/08 15:55:44 jeremy Exp $"""

 from cPickleCache import PickleCache
-from POSException import ConflictError, ReadConflictError
+from POSException import ConflictError, ReadConflictError, TransactionError
 from ExtensionClass import Base
 import ExportImport, TmpStore
 from zLOG import LOG, ERROR, BLATHER, WARNING
@@ -27,6 +26,7 @@ from Transaction import Transaction, get_transaction
 from cPickle import Unpickler, Pickler
 from cStringIO import StringIO
 import sys
+import threading
 from time import time
 from types import StringType, ClassType
@@ -73,14 +73,29 @@ class Connection(ExportImport.ExportImport):
             # XXX Why do we want version caches to behave this way?
             self._cache.cache_drain_resistance = 100
-        self._incrgc=self.cacheGC=cache.incrgc
-        self._invalidated=d={}
-        self._invalid=d.has_key
-        self._committed=[]
+        self._incrgc = self.cacheGC = cache.incrgc
+        self._committed = []
         self._code_timestamp = global_code_timestamp
         self._load_count = 0   # Number of objects unghosted
         self._store_count = 0  # Number of objects stored
+        # _invalidated queues invalidate messages delivered from the DB
+        # _inv_lock prevents one thread from modifying the set while
+        # another is processing invalidations.  All the invalidations
+        # from a single transaction should be applied atomically, so
+        # the lock must be held when reading _invalidated.
+
+        # XXX It sucks that we have to hold the lock to read
+        # _invalidated.  Normally, _invalidated is written by calling
+        # dict.update, which will execute atomically by virtue of the
+        # GIL.  But some storage might generate oids where hash or
+        # compare invokes Python code.  In that case, the GIL can't
+        # save us.
+        self._inv_lock = threading.Lock()
+        self._invalidated = d = {}
+        self._invalid = d.has_key
+        self._conflicts = {}

     def getTransaction(self):
         t = self._transaction
         if t is None:
@@ -91,8 +106,6 @@ class Connection(ExportImport.ExportImport):
     def setLocalTransaction(self):
         """Use a transaction bound to the connection rather than the thread"""
         if self._transaction is None:
-            # XXX The connection may already be registered with a
-            # transaction.  I guess we should abort that transaction.
             self._transaction = Transaction()
         return self._transaction
@@ -150,7 +163,7 @@
                 not args and not hasattr(klass,'__getinitargs__')):
                 object=klass.__basicnew__()
             else:
-                object=apply(klass,args)
+                object = klass(*args)

             if klass is not ExtensionKlass:
                 object.__dict__.clear()
@@ -221,7 +234,7 @@
             # New code is in place.  Start a new cache.
             self._resetCache()
         else:
-            self._cache.invalidate(self._invalidated)
+            self._flush_invalidations()
         self._opened=time()

         return self
@@ -242,7 +255,7 @@
         This just deactivates the thing.
         """
         if object is self:
-            self._cache.invalidate(self._invalidated)
+            self._flush_invalidations()
         else:
             assert object._p_oid is not None
             self._cache.invalidate(object._p_oid)
@@ -263,7 +276,6 @@
     def close(self):
         self._incrgc() # This is a good time to do some GC
-        db=self._db

         # Call the close callbacks.
         if self.__onCloseCallbacks is not None:
@@ -274,10 +286,10 @@
                     LOG('ZODB',ERROR, 'Close callback failed for %s' % f,
                         error=sys.exc_info())
             self.__onCloseCallbacks = None
-        self._db=self._storage=self._tmp=self.new_oid=self._opened=None
-        self._debug_info=()
+        self._storage = self._tmp = self.new_oid = self._opened = None
+        self._debug_info = ()
         # Return the connection to the pool.
-        db._closeConnection(self)
+        self._db._closeConnection(self)

     __onCommitActions = None
@@ -292,10 +304,13 @@
             # We registered ourself.  Execute a commit action, if any.
             if self.__onCommitActions is not None:
                 method_name, args, kw = self.__onCommitActions.pop(0)
-                apply(getattr(self, method_name), (transaction,) + args, kw)
+                getattr(self, method_name)(transaction, *args, **kw)
                 return
         oid = object._p_oid
+        if self._conflicts.has_key(oid):
+            raise ReadConflictError(oid)
+
         invalid = self._invalid
         if oid is None or object._p_jar is not self:
             # new object
@@ -305,9 +320,11 @@
             self._creating.append(oid)
         elif object._p_changed:
-            if invalid(oid) and not hasattr(object, '_p_resolveConflict'):
-                raise ConflictError(object=object)
-            self._invalidating.append(oid)
+            if invalid(oid):
+                resolve = getattr(object, "_p_resolveConflict", None)
+                if resolve is None:
+                    raise ConflictError(object=object)
+            self._modified.append(oid)
         else:
             # Nothing to do
@@ -369,7 +386,7 @@
             #XXX We should never get here
             if invalid(oid) and not hasattr(object, '_p_resolveConflict'):
                 raise ConflictError(object=object)
-            self._invalidating.append(oid)
+            self._modified.append(oid)

         klass = object.__class__
@@ -433,9 +450,9 @@
         oids=src._index.keys()

         # Copy invalidating and creating info from temporary storage:
-        invalidating=self._invalidating
-        invalidating[len(invalidating):]=oids
-        creating=self._creating
+        modified = self._modified
+        modified[len(modified):] = oids
+        creating = self._creating
         creating[len(creating):]=src._creating

         for oid in oids:
@@ -479,15 +496,28 @@ class Connection(ExportImport.ExportImport):
     def isReadOnly(self):
         return self._storage.isReadOnly()

-    def invalidate(self, oid):
-        """Invalidate a particular oid
+    def invalidate(self, oids):
+        """Invalidate a set of oids.

         This marks the oid as invalid, but doesn't actually invalidate
         it.  The object data will be actually invalidated at certain
         transaction boundaries.
         """
-        assert oid is not None
-        self._invalidated[oid] = 1
+        self._inv_lock.acquire()
+        try:
+            self._invalidated.update(oids)
+        finally:
+            self._inv_lock.release()
+
+    def _flush_invalidations(self):
+        self._inv_lock.acquire()
+        try:
+            self._cache.invalidate(self._invalidated)
+            self._invalidated.clear()
+        finally:
+            self._inv_lock.release()
+        # Now is a good time to collect some garbage
+        self._cache.incrgc()

     def modifiedInVersion(self, oid):
         try: return self._db.modifiedInVersion(oid)
@@ -508,8 +538,8 @@ class Connection(ExportImport.ExportImport):
     def root(self):
         return self['\0\0\0\0\0\0\0\0']

-    def setstate(self, object):
-        oid = object._p_oid
+    def setstate(self, obj):
+        oid = obj._p_oid

         if self._storage is None:
             msg = ("Shouldn't load state for %s "
@@ -518,54 +548,20 @@ class Connection(ExportImport.ExportImport):
             raise RuntimeError(msg)

         try:
+            # Avoid reading data from a transaction that committed
+            # after the current transaction started, as that might
+            # lead to mixing of cached data from earlier transactions
+            # and new inconsistent data.
+            #
+            # Wait for check until after data is loaded from storage
+            # to avoid time-of-check to time-of-use race.
             p, serial = self._storage.load(oid, self._version)
             self._load_count = self._load_count + 1
+            invalid = self._is_invalidated(obj)
+            self._set_ghost_state(obj, p)
+            obj._p_serial = serial
-
-            # XXX this is quite conservative!
-            # We need, however, to avoid reading data from a transaction
-            # that committed after the current "session" started, as
-            # that might lead to mixing of cached data from earlier
-            # transactions and new inconsistent data.
-            #
-            # Note that we (carefully) wait until after we call the
-            # storage to make sure that we don't miss invalidation
-            # notifications between the time we check and the time we
-            # read.
-
-            # XXX Need unit tests for _p_independent.
-
-            if self._invalid(oid):
-                if not hasattr(object.__class__, '_p_independent'):
-                    self.getTransaction().register(self)
-                    raise ReadConflictError(object=object)
-                invalid = 1
-            else:
-                invalid = 0
-
-            file = StringIO(p)
-            unpickler = Unpickler(file)
-            unpickler.persistent_load = self._persistent_load
-            unpickler.load()
-            state = unpickler.load()
-
-            if hasattr(object, '__setstate__'):
-                object.__setstate__(state)
-            else:
-                d = object.__dict__
-                for k, v in state.items():
-                    d[k] = v
-
-            object._p_serial = serial

             if invalid:
-                if object._p_independent():
-                    try:
-                        del self._invalidated[oid]
-                    except KeyError:
-                        pass
-                else:
-                    self.getTransaction().register(self)
-                    raise ConflictError(object=object)
+                self._handle_independent(obj)
         except ConflictError:
             raise
         except:
@@ -573,6 +569,59 @@ class Connection(ExportImport.ExportImport):
                 error=sys.exc_info())
             raise
+    def _is_invalidated(self, obj):
+        # Helper method for setstate(); covers three cases:
+        # returns false if obj is valid
+        # returns true if obj was invalidated, but is independent
+        # otherwise, raises ConflictError for invalidated objects
+        self._inv_lock.acquire()
+        try:
+            if self._invalidated.has_key(obj._p_oid):
+                ind = getattr(obj, "_p_independent", None)
+                if ind is not None:
+                    # Defer the _p_independent() call until the
+                    # object's state is loaded.
+                    return 1
+                else:
+                    self.getTransaction().register(self)
+                    self._conflicts[obj._p_oid] = 1
+                    raise ReadConflictError(object=obj)
+            else:
+                return 0
+        finally:
+            self._inv_lock.release()
+
+    def _set_ghost_state(self, obj, p):
+        file = StringIO(p)
+        unpickler = Unpickler(file)
+        unpickler.persistent_load = self._persistent_load
+        unpickler.load()
+        state = unpickler.load()
+
+        setstate = getattr(obj, "__setstate__", None)
+        if setstate is None:
+            obj.update(state)
+        else:
+            setstate(state)
+
+    def _handle_independent(self, obj):
+        # Helper method for setstate() handles possibly independent objects
+        # Call _p_independent(); if it returns True, setstate() wins.
+        # Otherwise, raise a ConflictError.
+        if obj._p_independent():
+            self._inv_lock.acquire()
+            try:
+                try:
+                    del self._invalidated[obj._p_oid]
+                except KeyError:
+                    pass
+            finally:
+                self._inv_lock.release()
+        else:
+            self.getTransaction().register(obj)
+            raise ReadConflictError(object=obj)
     def oldstate(self, object, serial):
         oid=object._p_oid
         p = self._storage.loadSerial(oid, serial)
@@ -601,7 +650,7 @@ class Connection(ExportImport.ExportImport):
                    % getattr(object,'__name__','(?)'))
                 return

-            copy=apply(klass,args)
+            copy = klass(*args)
            object.__dict__.clear()
            object.__dict__.update(copy.__dict__)
@@ -617,12 +666,13 @@ class Connection(ExportImport.ExportImport):
         if self.__onCommitActions is not None:
             del self.__onCommitActions
         self._storage.tpc_abort(transaction)
-        self._cache.invalidate(self._invalidated)
-        self._cache.invalidate(self._invalidating)
+        self._cache.invalidate(self._modified)
+        self._flush_invalidations()
+        self._conflicts.clear()
         self._invalidate_creating()

     def tpc_begin(self, transaction, sub=None):
-        self._invalidating = []
+        self._modified = []
         self._creating = []
         if sub:
             # Sub-transaction!
@@ -688,10 +738,10 @@ class Connection(ExportImport.ExportImport):
     def tpc_finish(self, transaction):
         # It's important that the storage call the function we pass
-        # (self._invalidate_invalidating) while it still has its
-        # lock.  We don't want another thread to be able to read any
-        # updated data until we've had a chance to send an
-        # invalidation message to all of the other connections!
+        # while it still has its lock.  We don't want another thread
+        # to be able to read any updated data until we've had a chance
+        # to send an invalidation message to all of the other
+        # connections!

         if self._tmp is not None:
             # Committing a subtransaction!
@@ -700,25 +750,21 @@
             self._storage._creating[:0]=self._creating
             del self._creating[:]
         else:
-            self._db.begin_invalidation()
-            self._storage.tpc_finish(transaction,
-                                     self._invalidate_invalidating)
+            def callback():
+                d = {}
+                for oid in self._modified:
+                    d[oid] = 1
+                self._db.invalidate(d, self)
+            self._storage.tpc_finish(transaction, callback)

-        self._cache.invalidate(self._invalidated)
-        self._incrgc() # This is a good time to do some GC
+        self._conflicts.clear()
+        self._flush_invalidations()

-    def _invalidate_invalidating(self):
-        for oid in self._invalidating:
-            assert oid is not None
-            self._db.invalidate(oid, self)
-        self._db.finish_invalidation()
-
     def sync(self):
         self.getTransaction().abort()
         sync=getattr(self._storage, 'sync', 0)
         if sync != 0: sync()
-        self._cache.invalidate(self._invalidated)
-        self._incrgc() # This is a good time to do some GC
+        self._flush_invalidations()

     def getDebugInfo(self):
         return self._debug_info
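The tpc_finish() comment above carries the key reasoning: the storage must run the connection's callback while it still holds its commit lock. Here is a minimal sketch of that contract, assuming a hypothetical storage class; real ZODB storages such as FileStorage differ in detail.

import threading

class StorageSketch:
    # Hypothetical storage stand-in, not a real ZODB storage class.

    def __init__(self):
        self._commit_lock = threading.Lock()

    def _finish(self, transaction):
        pass  # make the transaction's data durable (elided)

    def tpc_finish(self, transaction, f=None):
        self._commit_lock.acquire()
        try:
            # Run the connection's callback while the commit lock is
            # still held, so no other thread can read the new data
            # before every connection has queued the invalidations.
            if f is not None:
                f()
            self._finish(transaction)
        finally:
            self._commit_lock.release()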
@@ -13,8 +13,8 @@
 ##############################################################################
 """Database objects

-$Id: DB.py,v 1.47 2003/01/17 17:23:14 shane Exp $"""
-__version__='$Revision: 1.47 $'[11:-2]
+$Id: DB.py,v 1.48 2003/04/08 15:55:44 jeremy Exp $"""
+__version__='$Revision: 1.48 $'[11:-2]

 import cPickle, cStringIO, sys, POSException, UndoLogCompatible
 from Connection import Connection
@@ -26,6 +26,12 @@ from zLOG import LOG, ERROR
 from types import StringType

+def list2dict(L):
+    d = {}
+    for elt in L:
+        d[elt] = 1
+    return d
+
 class DB(UndoLogCompatible.UndoLogCompatible):
     """The Object Database
@@ -282,17 +288,7 @@ class DB(UndoLogCompatible.UndoLogCompatible):
     def importFile(self, file):
         raise 'Not yet implemented'

-    def begin_invalidation(self):
-        # Must be called before first call to invalidate and before
-        # the storage lock is held.
-        self._a()
-
-    def finish_invalidation(self):
-        # Must be called after begin_invalidation() and after final
-        # invalidate() call.
-        self._r()
-
-    def invalidate(self, oid, connection=None, version='',
+    def invalidate(self, oids, connection=None, version='',
                    rc=sys.getrefcount):
         """Invalidate references to a given oid.

@@ -304,9 +300,11 @@ class DB(UndoLogCompatible.UndoLogCompatible):
         if connection is not None:
             version=connection._version
         # Update modified in version cache
-        h=hash(oid)%131
-        o=self._miv_cache.get(h, None)
-        if o is not None and o[0]==oid: del self._miv_cache[h]
+        # XXX must make this work with list or dict to backport to 2.6
+        for oid in oids:
+            h=hash(oid)%131
+            o=self._miv_cache.get(h, None)
+            if o is not None and o[0]==oid: del self._miv_cache[h]

         # Notify connections
         for pool, allocated in self._pools[1]:
@@ -315,7 +313,7 @@
                     (not version or cc._version==version)):
                     if rc(cc) <= 3:
                         cc.close()
-                    cc.invalidate(oid)
+                    cc.invalidate(oids)

         temps=self._temps
         if temps:
@@ -324,7 +322,7 @@
                 if rc(cc) > 3:
                     if (cc is not connection and
                         (not version or cc._version==version)):
-                        cc.invalidate(oid)
+                        cc.invalidate(oids)
                         t.append(cc)
                 else: cc.close()
             self._temps=t
@@ -561,8 +559,10 @@ class DB(UndoLogCompatible.UndoLogCompatible):
             transaction.register(TransactionalUndo(self, id))
         else:
             # fall back to old undo
+            d = {}
             for oid in storage.undo(id):
-                self.invalidate(oid)
+                d[oid] = 1
+            self.invalidate(d)

     def versionEmpty(self, version):
         return self._storage.versionEmpty(version)
@@ -589,14 +589,14 @@ class CommitVersion:
     def abort(self, reallyme, t): pass

     def commit(self, reallyme, t):
-        db=self._db
         dest=self._dest
-        oids=db._storage.commitVersion(self._version, dest, t)
-        for oid in oids: db.invalidate(oid, version=dest)
+        oids = self._db._storage.commitVersion(self._version, dest, t)
+        oids = list2dict(oids)
+        self._db.invalidate(oids, version=dest)
         if dest:
             # the code above just invalidated the dest version.
             # now we need to invalidate the source!
-            for oid in oids: db.invalidate(oid, version=self._version)
+            self._db.invalidate(oids, version=self._version)

 class AbortVersion(CommitVersion):
     """An object that will see to version abortion
@@ -605,11 +605,9 @@ class AbortVersion(CommitVersion):
     """

     def commit(self, reallyme, t):
-        db=self._db
         version=self._version
-        oids = db._storage.abortVersion(version, t)
-        for oid in oids:
-            db.invalidate(oid, version=version)
+        oids = self._db._storage.abortVersion(version, t)
+        self._db.invalidate(list2dict(oids), version=version)

 class TransactionalUndo(CommitVersion):
@@ -623,7 +621,5 @@ class TransactionalUndo(CommitVersion):
     # similarity of rhythm that I think it's justified.

     def commit(self, reallyme, t):
-        db=self._db
-        oids=db._storage.transactionalUndo(self._version, t)
-        for oid in oids:
-            db.invalidate(oid)
+        oids = self._db._storage.transactionalUndo(self._version, t)
+        self._db.invalidate(list2dict(oids))
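
For illustration, a sketch of the reshaped DB.invalidate() call shape using the list2dict() helper the diff adds; the `modified_oids` values and the commented db call are assumptions for the example, not part of the commit.

def list2dict(L):
    # Mirrors the helper added to DB.py: map each element of L to 1.
    d = {}
    for elt in L:
        d[elt] = 1
    return d

# Where callers previously looped, calling db.invalidate(oid) once per
# oid between begin_invalidation() and finish_invalidation(), they now
# deliver the whole transaction's oids in one call:
#
#     db.invalidate(list2dict(modified_oids), connection)
modified_oids = ['\x00\x00\x00\x00\x00\x00\x00\x01',
                 '\x00\x00\x00\x00\x00\x00\x00\x02']
print list2dict(modified_oids)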
@@ -11,16 +11,52 @@
 # FOR A PARTICULAR PURPOSE.
 #
 ##############################################################################
-import sys, os
+import unittest

 import ZODB
 import ZODB.FileStorage
 from ZODB.PersistentMapping import PersistentMapping
+from ZODB.POSException import ReadConflictError, ConflictError
 from ZODB.tests.StorageTestBase import removefs
-import unittest
+from Persistence import Persistent

+class P(Persistent):
+    pass
+
+class Independent(Persistent):
+
+    def _p_independent(self):
+        return True
+
+class DecoyIndependent(Persistent):
+
+    def _p_independent(self):
+        return False
+
+class ZODBTests(unittest.TestCase):
+
+    def setUp(self):
+        self._storage = ZODB.FileStorage.FileStorage(
+            'ZODBTests.fs', create=1)
+        self._db = ZODB.DB(self._storage)
+
+    def populate(self):
+        get_transaction().begin()
+        conn = self._db.open()
+        root = conn.root()
+        root['test'] = pm = PersistentMapping()
+        for n in range(100):
+            pm[n] = PersistentMapping({0: 100 - n})
+        get_transaction().note('created test data')
+        get_transaction().commit()
+        conn.close()

-class ExportImportTests:
-    def checkDuplicate(self, abort_it=0, dup_name='test_duplicate'):
+    def tearDown(self):
+        self._storage.close()
+        removefs("ZODBTests.fs")
+
+    def checkExportImport(self, abort_it=0, dup_name='test_duplicate'):
+        self.populate()
         get_transaction().begin()
         get_transaction().note('duplication')
         # Duplicate the 'test' object.
@@ -83,29 +119,8 @@ class ExportImportTests:
         finally:
             conn.close()

-    def checkDuplicateAborted(self):
-        self.checkDuplicate(abort_it=1, dup_name='test_duplicate_aborted')
-
-class ZODBTests(unittest.TestCase, ExportImportTests):
-
-    def setUp(self):
-        self._storage = ZODB.FileStorage.FileStorage(
-            'ZODBTests.fs', create=1)
-        self._db = ZODB.DB(self._storage)
-        get_transaction().begin()
-        conn = self._db.open()
-        root = conn.root()
-        root['test'] = pm = PersistentMapping()
-        for n in range(100):
-            pm[n] = PersistentMapping({0: 100 - n})
-        get_transaction().note('created test data')
-        get_transaction().commit()
-        conn.close()
-
-    def tearDown(self):
-        self._storage.close()
-        removefs("ZODBTests.fs")
+    def checkExportImportAborted(self):
+        self.checkExportImport(abort_it=1, dup_name='test_duplicate_aborted')

     def checkVersionOnly(self):
         # Make sure the changes to make empty transactions a no-op
@@ -124,6 +139,7 @@ class ZODBTests(unittest.TestCase, ExportImportTests):
     def checkResetCache(self):
         # The cache size after a reset should be 0 and the GC attributes
         # ought to be linked to it rather than the old cache.
+        self.populate()
         conn = self._db.open()
         try:
             conn.root()
@@ -173,10 +189,99 @@ class ZODBTests(unittest.TestCase, ExportImportTests):
             conn1.close()
             conn2.close()

-def test_suite():
-    return unittest.makeSuite(ZODBTests, 'check')
-
-if __name__=='__main__':
-    unittest.main(defaultTest='test_suite')
+    def checkReadConflict(self):
+        self.obj = P()
+        self.readConflict()
+
+    def readConflict(self, shouldFail=True):
+        # Two transactions run concurrently.  Each reads some object,
+        # then one commits and the other tries to read an object
+        # modified by the first.  This read should fail with a conflict
+        # error because the object state read is not necessarily
+        # consistent with the objects read earlier in the transaction.
+        conn = self._db.open()
+        conn.setLocalTransaction()
+        r1 = conn.root()
+        r1["p"] = self.obj
+        self.obj.child1 = P()
+        conn.getTransaction().commit()
+
+        # start a new transaction with a new connection
+        cn2 = self._db.open()
+        # start a new transaction with the other connection
+        cn2.setLocalTransaction()
+        r2 = cn2.root()
+
+        self.assertEqual(r1._p_serial, r2._p_serial)
+
+        self.obj.child2 = P()
+        conn.getTransaction().commit()
+
+        # resume the transaction using cn2
+        obj = r2["p"]
+        # An attempt to access obj should fail, because r2 was read
+        # earlier in the transaction and obj was modified by the other
+        # transaction.
+        if shouldFail:
+            self.assertRaises(ReadConflictError, lambda: obj.child1)
+        else:
+            # make sure that accessing the object succeeds
+            obj.child1
+        cn2.getTransaction().abort()
+
+    def checkReadConflictIgnored(self):
+        # Test that an application that catches a read conflict and
+        # continues can not commit the transaction later.
+        root = self._db.open().root()
+        root["real_data"] = real_data = PersistentMapping()
+        root["index"] = index = PersistentMapping()
+
+        real_data["a"] = PersistentMapping({"indexed_value": False})
+        real_data["b"] = PersistentMapping({"indexed_value": True})
+        index[True] = PersistentMapping({"b": 1})
+        index[False] = PersistentMapping({"a": 1})
+        get_transaction().commit()
+
+        # load some objects from one connection
+        cn2 = self._db.open()
+        cn2.setLocalTransaction()
+        r2 = cn2.root()
+        real_data2 = r2["real_data"]
+        index2 = r2["index"]
+
+        real_data["b"]["indexed_value"] = False
+        del index[True]["b"]
+        index[False]["b"] = 1
+        cn2.getTransaction().commit()
+
+        del real_data2["a"]
+        try:
+            del index2[False]["a"]
+        except ReadConflictError:
+            # This is the crux of the test.  Ignore the error.
+            pass
+        else:
+            self.fail("No conflict occurred")
+
+        # real_data2 still ready to commit
+        self.assert_(real_data2._p_changed)
+
+        # index2 values not ready to commit
+        self.assert_(not index2._p_changed)
+        self.assert_(not index2[False]._p_changed)
+        self.assert_(not index2[True]._p_changed)
+
+        self.assertRaises(ConflictError, get_transaction().commit)
+        get_transaction().abort()
+
+    def checkIndependent(self):
+        self.obj = Independent()
+        self.readConflict(shouldFail=False)
+
+    def checkNotIndependent(self):
+        self.obj = DecoyIndependent()
+        self.readConflict()
+
+def test_suite():
+    return unittest.makeSuite(ZODBTests, 'check')