Commit 326ff391 authored by Jeremy Hylton's avatar Jeremy Hylton

Revise Connection.

Make _added_during_commit a regular instance variable.  Don't use
try/finally to reset it; just clear it at the start of a transaction.
XXX There was a test that needed to be removed, but it seemed to be
just a shallow test that try/finally was used.  Can't see any feature
that depends on specific of error handling: The txn is going to abort.

Remove unused _opened instance variable.
Split commit() into two smaller parts.
Get rid of extra manipulation of _creating.
Don't look for _p_serial of None; z64 is now required.
Undo local variable aliases in subtransaction methods.

Also, trivial change to pickle cache API -- get() works like dict get().
parent 982055e6
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
############################################################################## ##############################################################################
"""Database connection support """Database connection support
$Id: Connection.py,v 1.138 2004/03/12 06:37:23 jeremy Exp $""" $Id: Connection.py,v 1.139 2004/03/13 07:48:11 jeremy Exp $"""
import logging import logging
import sys import sys
...@@ -125,16 +125,15 @@ class Connection(ExportImport, object): ...@@ -125,16 +125,15 @@ class Connection(ExportImport, object):
their state and register changes. The methods are setstate(), their state and register changes. The methods are setstate(),
register(), setklassstate(). register(), setklassstate().
$Id: Connection.py,v 1.138 2004/03/12 06:37:23 jeremy Exp $ $Id: Connection.py,v 1.139 2004/03/13 07:48:11 jeremy Exp $
""" """
_tmp = None _tmp = None
_code_timestamp = 0 _code_timestamp = 0
_transaction = None _transaction = None
_added_during_commit = None
def __init__(self, version='', cache_size=400, def __init__(self, version='', cache_size=400,
cache_deactivate_after=60, mvcc=True): cache_deactivate_after=None, mvcc=True):
"""Create a new Connection. """Create a new Connection.
A Connection instance should by instantiated by the DB A Connection instance should by instantiated by the DB
...@@ -143,7 +142,6 @@ class Connection(ExportImport, object): ...@@ -143,7 +142,6 @@ class Connection(ExportImport, object):
self._log = logging.getLogger("zodb.conn") self._log = logging.getLogger("zodb.conn")
self._storage = None self._storage = None
self._opened = None
self._debug_info = () self._debug_info = ()
self._version = version self._version = version
...@@ -158,6 +156,7 @@ class Connection(ExportImport, object): ...@@ -158,6 +156,7 @@ class Connection(ExportImport, object):
self._cache.cache_drain_resistance = 100 self._cache.cache_drain_resistance = 100
self._committed = [] self._committed = []
self._added = {} self._added = {}
self._added_during_commit = None
self._reset_counter = global_reset_counter self._reset_counter = global_reset_counter
self._load_count = 0 # Number of objects unghosted self._load_count = 0 # Number of objects unghosted
self._store_count = 0 # Number of objects stored self._store_count = 0 # Number of objects stored
...@@ -315,12 +314,10 @@ class Connection(ExportImport, object): ...@@ -315,12 +314,10 @@ class Connection(ExportImport, object):
raise InvalidObjectReference(obj, obj._p_jar) raise InvalidObjectReference(obj, obj._p_jar)
def sortKey(self): def sortKey(self):
# XXX will raise an exception if the DB hasn't been set
storage_key = self._sortKey()
# If two connections use the same storage, give them a # If two connections use the same storage, give them a
# consistent order using id(). This is unique for the # consistent order using id(). This is unique for the
# lifetime of a connection, which is good enough. # lifetime of a connection, which is good enough.
return "%s:%s" % (storage_key, id(self)) return "%s:%s" % (self._sortKey(), id(self))
def _setDB(self, odb): def _setDB(self, odb):
"""Register odb, the DB that this Connection uses. """Register odb, the DB that this Connection uses.
...@@ -348,7 +345,6 @@ class Connection(ExportImport, object): ...@@ -348,7 +345,6 @@ class Connection(ExportImport, object):
self._flush_invalidations() self._flush_invalidations()
self._reader = ConnectionObjectReader(self, self._cache, self._reader = ConnectionObjectReader(self, self._cache,
self._db.classFactory) self._db.classFactory)
self._opened = time()
def _resetCache(self): def _resetCache(self):
"""Creates a new cache, discarding the old. """Creates a new cache, discarding the old.
...@@ -452,148 +448,123 @@ class Connection(ExportImport, object): ...@@ -452,148 +448,123 @@ class Connection(ExportImport, object):
self._log.error("Close callback failed for %s", f, self._log.error("Close callback failed for %s", f,
sys.exc_info()) sys.exc_info())
self.__onCloseCallbacks = None self.__onCloseCallbacks = None
self._storage = self._tmp = self.new_oid = self._opened = None self._storage = self._tmp = self.new_oid = None
self._debug_info = () self._debug_info = ()
# Return the connection to the pool. # Return the connection to the pool.
if self._db is not None: if self._db is not None:
self._db._closeConnection(self) self._db._closeConnection(self)
self._db = None self._db = None
def commit(self, object, transaction): def commit(self, obj, transaction):
if object is self: if obj is self:
# We registered ourself. Execute a commit action, if any. # We registered ourself. Execute a commit action, if any.
if self._import: if self._import:
self._importDuringCommit(transaction, *self._import) self._importDuringCommit(transaction, *self._import)
self._import = None self._import = None
return return
oid = object._p_oid oid = obj._p_oid
if self._conflicts.has_key(oid): if oid in self._conflicts:
self.getTransaction().register(object) self.getTransaction().register(obj)
raise ReadConflictError(object=object) raise ReadConflictError(object=obj)
invalid = self._invalid
# XXX In the case of a new object or an object added using add(), if oid is None or obj._p_jar is not self:
# the oid is appended to _creating.
# However, this ought to be unnecessary because the _p_serial
# of the object will be z64 or None, so it will be appended
# to _creating about 30 lines down. The removal from _added
# ought likewise to be unnecessary.
if oid is None or object._p_jar is not self:
# new object # new object
oid = self.new_oid() oid = self.new_oid()
object._p_jar = self obj._p_jar = self
object._p_oid = oid obj._p_oid = oid
self._creating.append(oid) # maybe don't need this assert obj._p_serial == z64
elif oid in self._added: elif oid in self._added:
# maybe don't need these assert obj._p_serial == z64
self._creating.append(oid) elif obj._p_changed:
del self._added[oid] if oid in self._invalidated:
elif object._p_changed: resolve = getattr(obj, "_p_resolveConflict", None)
if invalid(oid):
resolve = getattr(object, "_p_resolveConflict", None)
if resolve is None: if resolve is None:
raise ConflictError(object=object) raise ConflictError(object=obj)
self._modified.append(oid) self._modified.append(oid)
else: else:
# Nothing to do # Nothing to do
return return
w = ObjectWriter(object) self._store_objects(ObjectWriter(obj), transaction)
def _store_objects(self, writer, transaction):
self._added_during_commit = [] self._added_during_commit = []
try: for obj in itertools.chain(writer, self._added_during_commit):
for obj in itertools.chain(w, self._added_during_commit): oid = obj._p_oid
oid = obj._p_oid serial = getattr(obj, "_p_serial", z64)
serial = getattr(obj, '_p_serial', z64)
if serial == z64:
# XXX which one? z64 or None? Why do I have to check both? # new object
if serial == z64 or serial is None: self._creating.append(oid)
# new object # If this object was added, it is now in _creating, so can
self._creating.append(oid) # be removed from _added.
# If this object was added, it is now in _creating, so can self._added.pop(oid, None)
# be removed from _added. else:
self._added.pop(oid, None) if (oid in self._invalidated
and not hasattr(obj, '_p_resolveConflict')):
raise ConflictError(object=obj)
self._modified.append(oid)
p = writer.serialize(obj) # This calls __getstate__ of obj
s = self._storage.store(oid, serial, p, self._version, transaction)
self._store_count += 1
# Put the object in the cache before handling the
# response, just in case the response contains the
# serial number for a newly created object
try:
self._cache[oid] = obj
except:
# Dang, I bet its wrapped:
if hasattr(obj, 'aq_base'):
self._cache[oid] = obj.aq_base
else: else:
if (invalid(oid) raise
and not hasattr(object, '_p_resolveConflict')):
raise ConflictError(object=obj) self._handle_serial(s, oid)
self._modified.append(oid) self._added_during_commit = None
p = w.serialize(obj) # This calls __getstate__ of obj
s = self._storage.store(oid, serial, p, self._version,
transaction)
self._store_count = self._store_count + 1
# Put the object in the cache before handling the
# response, just in case the response contains the
# serial number for a newly created object
try:
self._cache[oid] = obj
except:
# Dang, I bet its wrapped:
if hasattr(obj, 'aq_base'):
self._cache[oid] = obj.aq_base
else:
raise
self._handle_serial(s, oid)
finally:
del self._added_during_commit
def commit_sub(self, t): def commit_sub(self, t):
"""Commit all work done in all subtransactions for this transaction""" """Commit all work done in all subtransactions for this transaction"""
tmp=self._tmp if self._tmp is None:
if tmp is None: return return
src=self._storage src = self._storage
self._storage = self._tmp
self._log.debug("Commiting subtransaction of size %s", self._tmp = None
src.getSize())
self._log.debug("Commiting subtransaction of size %s", src.getSize())
self._storage=tmp oids = src._index.keys()
self._tmp=None self._storage.tpc_begin(t)
tmp.tpc_begin(t)
load=src.load
store=tmp.store
dest=self._version
oids=src._index.keys()
# Copy invalidating and creating info from temporary storage: # Copy invalidating and creating info from temporary storage:
modified = self._modified self._modified[len(self._modified):] = oids
modified[len(modified):] = oids self._creating[len(self._creating):] = src._creating
creating = self._creating
creating[len(creating):]=src._creating
for oid in oids: for oid in oids:
data, serial = load(oid, src) data, serial = src.load(oid, src)
s=store(oid, serial, data, dest, t) s = self._storage.store(oid, serial, data, self._version, t)
self._handle_serial(s, oid, change=0) self._handle_serial(s, oid, change=False)
def abort_sub(self, t): def abort_sub(self, t):
"""Abort work done in all subtransactions for this transaction""" """Abort work done in all subtransactions for this transaction"""
tmp=self._tmp if self._tmp is None:
if tmp is None: return return
src=self._storage src = self._storage
self._tmp=None self._storage = self._tmp
self._storage=tmp self._tmp = None
self._cache.invalidate(src._index.keys()) self._cache.invalidate(src._index.keys())
self._invalidate_creating(src._creating) self._invalidate_creating(src._creating)
def _invalidate_creating(self, creating=None): def _invalidate_creating(self, creating=None):
"""Dissown any objects newly saved in an uncommitted transaction. """Dissown any objects newly saved in an uncommitted transaction."""
"""
if creating is None: if creating is None:
creating=self._creating creating = self._creating
self._creating=[] self._creating = []
cache=self._cache
cache_get=cache.get
for oid in creating: for oid in creating:
o=cache_get(oid, None) o = self._cache.get(oid)
if o is not None: if o is not None:
del cache[oid] del self._cache[oid]
del o._p_jar del o._p_jar
del o._p_oid del o._p_oid
...@@ -844,6 +815,9 @@ class Connection(ExportImport, object): ...@@ -844,6 +815,9 @@ class Connection(ExportImport, object):
def tpc_begin(self, transaction, sub=None): def tpc_begin(self, transaction, sub=None):
self._modified = [] self._modified = []
# _creating is a list of oids of new objects, which is used to
# remove them from the cache if a transaction aborts.
self._creating = [] self._creating = []
if sub: if sub:
# Sub-transaction! # Sub-transaction!
......
...@@ -250,7 +250,7 @@ class CacheErrors(unittest.TestCase): ...@@ -250,7 +250,7 @@ class CacheErrors(unittest.TestCase):
self.cache = PickleCache(self.jar) self.cache = PickleCache(self.jar)
def checkGetBogusKey(self): def checkGetBogusKey(self):
self.assertRaises(KeyError, self.cache.get, p64(0)) self.assertEqual(self.cache.get(p64(0)), None)
try: try:
self.cache[12] self.cache[12]
except KeyError: except KeyError:
......
...@@ -124,16 +124,6 @@ class ConnectionDotAdd(unittest.TestCase): ...@@ -124,16 +124,6 @@ class ConnectionDotAdd(unittest.TestCase):
"subobject was not stored") "subobject was not stored")
self.assert_(self.datamgr._added_during_commit is None) self.assert_(self.datamgr._added_during_commit is None)
def checkErrorDuringCommit(self):
# We need to check that _added_during_commit still gets set to None
# when there is an error during commit()/
obj = ErrorOnGetstateObject()
self.datamgr.tpc_begin(self.transaction)
self.assertRaises(ErrorOnGetstateException,
self.datamgr.commit, obj, self.transaction)
self.assert_(self.datamgr._added_during_commit is None)
def checkUnusedAddWorks(self): def checkUnusedAddWorks(self):
# When an object is added, but not committed, it shouldn't be stored, # When an object is added, but not committed, it shouldn't be stored,
# but also it should be an error. # but also it should be an error.
......
...@@ -90,7 +90,7 @@ process must skip such objects, rather than deactivating them. ...@@ -90,7 +90,7 @@ process must skip such objects, rather than deactivating them.
static char cPickleCache_doc_string[] = static char cPickleCache_doc_string[] =
"Defines the PickleCache used by ZODB Connection objects.\n" "Defines the PickleCache used by ZODB Connection objects.\n"
"\n" "\n"
"$Id: cPickleCache.c,v 1.91 2004/03/02 22:13:54 jeremy Exp $\n"; "$Id: cPickleCache.c,v 1.92 2004/03/13 07:48:12 jeremy Exp $\n";
#define DONT_USE_CPERSISTENCECAPI #define DONT_USE_CPERSISTENCECAPI
#include "cPersistence.h" #include "cPersistence.h"
...@@ -408,12 +408,10 @@ cc_get(ccobject *self, PyObject *args) ...@@ -408,12 +408,10 @@ cc_get(ccobject *self, PyObject *args)
r = PyDict_GetItem(self->data, key); r = PyDict_GetItem(self->data, key);
if (!r) { if (!r) {
if (d) { if (d)
r = d; r = d;
} else { else
PyErr_SetObject(PyExc_KeyError, key); r = Py_None;
return NULL;
}
} }
Py_INCREF(r); Py_INCREF(r);
return r; return r;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment