Commit 93a40774 authored by Shane Hathaway's avatar Shane Hathaway

Merged shane-local-transactions-branch.

This change adds a new method, setLocalTransaction(), to the
Connection class.  ZODB applications can call this method to bind
transactions to connections rather than threads.  This is especially
useful for GUI applications, which often have only one thread but
multiple independent activities within that thread (generally one per
window).  Thanks to Christian Reis for championing this feature.

Applications that take advantage of this feature should not use the
get_transaction() function.  Until now, ZODB itself sometimes assumed
get_transaction() was the only way to get the transaction.  Minor
corrections have been added.  The ZODB test suite, on the other hand,
can continue to use get_transaction(), since it is free to assume that
transactions are bound to threads.
parent 9259274f
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
############################################################################## ##############################################################################
"""Database connection support """Database connection support
$Id: Connection.py,v 1.82 2003/01/15 21:29:26 jeremy Exp $""" $Id: Connection.py,v 1.83 2003/01/17 17:23:14 shane Exp $"""
from cPickleCache import PickleCache from cPickleCache import PickleCache
from POSException import ConflictError, ReadConflictError from POSException import ConflictError, ReadConflictError
...@@ -22,6 +22,7 @@ import ExportImport, TmpStore ...@@ -22,6 +22,7 @@ import ExportImport, TmpStore
from zLOG import LOG, ERROR, BLATHER, WARNING from zLOG import LOG, ERROR, BLATHER, WARNING
from coptimizations import new_persistent_id from coptimizations import new_persistent_id
from ConflictResolution import ResolvedSerial from ConflictResolution import ResolvedSerial
from Transaction import Transaction, get_transaction
from cPickle import Unpickler, Pickler from cPickle import Unpickler, Pickler
from cStringIO import StringIO from cStringIO import StringIO
...@@ -55,6 +56,7 @@ class Connection(ExportImport.ExportImport): ...@@ -55,6 +56,7 @@ class Connection(ExportImport.ExportImport):
_debug_info=() _debug_info=()
_opened=None _opened=None
_code_timestamp = 0 _code_timestamp = 0
_transaction = None
# Experimental. Other connections can register to be closed # Experimental. Other connections can register to be closed
# when we close by putting something here. # when we close by putting something here.
...@@ -80,6 +82,19 @@ class Connection(ExportImport.ExportImport): ...@@ -80,6 +82,19 @@ class Connection(ExportImport.ExportImport):
self._load_count = 0 # Number of objects unghosted self._load_count = 0 # Number of objects unghosted
self._store_count = 0 # Number of objects stored self._store_count = 0 # Number of objects stored
def getTransaction(self):
t = self._transaction
if t is None:
# Fall back to thread-bound transactions
t = get_transaction()
return t
def setLocalTransaction(self):
"""Use a transaction bound to the connection rather than the thread"""
if self._transaction is None:
self._transaction = Transaction()
return self._transaction
def _cache_items(self): def _cache_items(self):
# find all items on the lru list # find all items on the lru list
items = self._cache.lru_items() items = self._cache.lru_items()
...@@ -269,7 +284,7 @@ class Connection(ExportImport.ExportImport): ...@@ -269,7 +284,7 @@ class Connection(ExportImport.ExportImport):
if self.__onCommitActions is None: if self.__onCommitActions is None:
self.__onCommitActions = [] self.__onCommitActions = []
self.__onCommitActions.append((method_name, args, kw)) self.__onCommitActions.append((method_name, args, kw))
get_transaction().register(self) self.getTransaction().register(self)
def commit(self, object, transaction): def commit(self, object, transaction):
if object is self: if object is self:
...@@ -484,7 +499,7 @@ class Connection(ExportImport.ExportImport): ...@@ -484,7 +499,7 @@ class Connection(ExportImport.ExportImport):
assert object._p_jar is self assert object._p_jar is self
# XXX Figure out why this assert causes test failures # XXX Figure out why this assert causes test failures
# assert object._p_oid is not None # assert object._p_oid is not None
get_transaction().register(object) self.getTransaction().register(object)
def root(self): def root(self):
return self['\0\0\0\0\0\0\0\0'] return self['\0\0\0\0\0\0\0\0']
...@@ -516,7 +531,7 @@ class Connection(ExportImport.ExportImport): ...@@ -516,7 +531,7 @@ class Connection(ExportImport.ExportImport):
# XXX Need unit tests for _p_independent. # XXX Need unit tests for _p_independent.
if self._invalid(oid): if self._invalid(oid):
if not hasattr(object.__class__, '_p_independent'): if not hasattr(object.__class__, '_p_independent'):
get_transaction().register(self) self.getTransaction().register(self)
raise ReadConflictError(object=object) raise ReadConflictError(object=object)
invalid = 1 invalid = 1
else: else:
...@@ -544,7 +559,7 @@ class Connection(ExportImport.ExportImport): ...@@ -544,7 +559,7 @@ class Connection(ExportImport.ExportImport):
except KeyError: except KeyError:
pass pass
else: else:
get_transaction().register(self) self.getTransaction().register(self)
raise ConflictError(object=object) raise ConflictError(object=object)
except ConflictError: except ConflictError:
...@@ -695,7 +710,7 @@ class Connection(ExportImport.ExportImport): ...@@ -695,7 +710,7 @@ class Connection(ExportImport.ExportImport):
self._db.finish_invalidation() self._db.finish_invalidation()
def sync(self): def sync(self):
get_transaction().abort() self.getTransaction().abort()
sync=getattr(self._storage, 'sync', 0) sync=getattr(self._storage, 'sync', 0)
if sync != 0: sync() if sync != 0: sync()
self._cache.invalidate(self._invalidated) self._cache.invalidate(self._invalidated)
...@@ -726,7 +741,7 @@ class Connection(ExportImport.ExportImport): ...@@ -726,7 +741,7 @@ class Connection(ExportImport.ExportImport):
new._p_oid=oid new._p_oid=oid
new._p_jar=self new._p_jar=self
new._p_changed=1 new._p_changed=1
get_transaction().register(new) self.getTransaction().register(new)
self._cache[oid]=new self._cache[oid]=new
class tConnection(Connection): class tConnection(Connection):
......
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
############################################################################## ##############################################################################
"""Database objects """Database objects
$Id: DB.py,v 1.46 2002/12/03 17:40:56 jeremy Exp $""" $Id: DB.py,v 1.47 2003/01/17 17:23:14 shane Exp $"""
__version__='$Revision: 1.46 $'[11:-2] __version__='$Revision: 1.47 $'[11:-2]
import cPickle, cStringIO, sys, POSException, UndoLogCompatible import cPickle, cStringIO, sys, POSException, UndoLogCompatible
from Connection import Connection from Connection import Connection
from bpthread import allocate_lock from bpthread import allocate_lock
from Transaction import Transaction from Transaction import Transaction, get_transaction
from referencesf import referencesf from referencesf import referencesf
from time import time, ctime from time import time, ctime
from zLOG import LOG, ERROR from zLOG import LOG, ERROR
...@@ -153,8 +153,10 @@ class DB(UndoLogCompatible.UndoLogCompatible): ...@@ -153,8 +153,10 @@ class DB(UndoLogCompatible.UndoLogCompatible):
self._temps=t self._temps=t
finally: self._r() finally: self._r()
def abortVersion(self, version): def abortVersion(self, version, transaction=None):
AbortVersion(self, version) if transaction is None:
transaction = get_transaction()
transaction.register(AbortVersion(self, version))
def cacheDetail(self): def cacheDetail(self):
"""Return information on objects in the various caches """Return information on objects in the various caches
...@@ -248,8 +250,10 @@ class DB(UndoLogCompatible.UndoLogCompatible): ...@@ -248,8 +250,10 @@ class DB(UndoLogCompatible.UndoLogCompatible):
def close(self): def close(self):
self._storage.close() self._storage.close()
def commitVersion(self, source, destination=''): def commitVersion(self, source, destination='', transaction=None):
CommitVersion(self, source, destination) if transaction is None:
transaction = get_transaction()
transaction.register(CommitVersion(self, source, destination))
def exportFile(self, oid, file=None): def exportFile(self, oid, file=None):
raise 'Not yet implemented' raise 'Not yet implemented'
...@@ -542,7 +546,7 @@ class DB(UndoLogCompatible.UndoLogCompatible): ...@@ -542,7 +546,7 @@ class DB(UndoLogCompatible.UndoLogCompatible):
def cacheStatistics(self): return () # :( def cacheStatistics(self): return () # :(
def undo(self, id): def undo(self, id, transaction=None):
storage=self._storage storage=self._storage
try: supportsTransactionalUndo = storage.supportsTransactionalUndo try: supportsTransactionalUndo = storage.supportsTransactionalUndo
except AttributeError: except AttributeError:
...@@ -552,7 +556,9 @@ class DB(UndoLogCompatible.UndoLogCompatible): ...@@ -552,7 +556,9 @@ class DB(UndoLogCompatible.UndoLogCompatible):
if supportsTransactionalUndo: if supportsTransactionalUndo:
# new style undo # new style undo
TransactionalUndo(self, id) if transaction is None:
transaction = get_transaction()
transaction.register(TransactionalUndo(self, id))
else: else:
# fall back to old undo # fall back to old undo
for oid in storage.undo(id): for oid in storage.undo(id):
...@@ -576,7 +582,6 @@ class CommitVersion: ...@@ -576,7 +582,6 @@ class CommitVersion:
self.tpc_vote=s.tpc_vote self.tpc_vote=s.tpc_vote
self.tpc_finish=s.tpc_finish self.tpc_finish=s.tpc_finish
self._sortKey=s.sortKey self._sortKey=s.sortKey
get_transaction().register(self)
def sortKey(self): def sortKey(self):
return "%s:%s" % (self._sortKey(), id(self)) return "%s:%s" % (self._sortKey(), id(self))
...@@ -613,9 +618,9 @@ class TransactionalUndo(CommitVersion): ...@@ -613,9 +618,9 @@ class TransactionalUndo(CommitVersion):
in cooperation with a transaction manager. in cooperation with a transaction manager.
""" """
# I'm lazy. I'm reusing __init__ and abort and reusing the # I (Jim) am lazy. I'm reusing __init__ and abort and reusing the
# version attr for the transavtion id. There's such a strong # version attr for the transaction id. There's such a strong
# similarity of rythm, that I think it's justified. # similarity of rhythm that I think it's justified.
def commit(self, reallyme, t): def commit(self, reallyme, t):
db=self._db db=self._db
......
...@@ -76,7 +76,7 @@ class ExportImport: ...@@ -76,7 +76,7 @@ class ExportImport:
return customImporters[magic](self, file, clue) return customImporters[magic](self, file, clue)
raise POSException.ExportError, 'Invalid export header' raise POSException.ExportError, 'Invalid export header'
t = get_transaction() t = self.getTransaction()
if clue: t.note(clue) if clue: t.note(clue)
return_oid_list = [] return_oid_list = []
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
This module provides a wrapper that causes a database connection to be created This module provides a wrapper that causes a database connection to be created
and used when bobo publishes a bobo_application object. and used when bobo publishes a bobo_application object.
""" """
__version__='$Revision: 1.11 $'[11:-2] __version__='$Revision: 1.12 $'[11:-2]
StringType=type('') StringType=type('')
connection_open_hooks = [] connection_open_hooks = []
...@@ -31,7 +31,7 @@ class ZApplicationWrapper: ...@@ -31,7 +31,7 @@ class ZApplicationWrapper:
root=conn.root() root=conn.root()
if not root.has_key(name): if not root.has_key(name):
root[name]=klass() root[name]=klass()
get_transaction().commit() conn.getTransaction().commit()
conn.close() conn.close()
self._klass=klass self._klass=klass
......
...@@ -111,9 +111,13 @@ class ZODBTests(unittest.TestCase, ExportImportTests): ...@@ -111,9 +111,13 @@ class ZODBTests(unittest.TestCase, ExportImportTests):
# Make sure the changes to make empty transactions a no-op # Make sure the changes to make empty transactions a no-op
# still allow things like abortVersion(). This should work # still allow things like abortVersion(). This should work
# because abortVersion() calls tpc_begin() itself. # because abortVersion() calls tpc_begin() itself.
r = self._db.open("version").root() conn = self._db.open("version")
try:
r = conn.root()
r[1] = 1 r[1] = 1
get_transaction().commit() get_transaction().commit()
finally:
conn.close()
self._db.abortVersion("version") self._db.abortVersion("version")
get_transaction().commit() get_transaction().commit()
...@@ -131,24 +135,48 @@ class ZODBTests(unittest.TestCase, ExportImportTests): ...@@ -131,24 +135,48 @@ class ZODBTests(unittest.TestCase, ExportImportTests):
finally: finally:
conn.close() conn.close()
def checkLocalTransactions(self):
# Test of transactions that apply to only the connection,
# not the thread.
conn1 = self._db.open()
conn2 = self._db.open()
try:
conn1.setLocalTransaction()
conn2.setLocalTransaction()
r1 = conn1.root()
r2 = conn2.root()
if r1.has_key('item'):
del r1['item']
conn1.getTransaction().commit()
r1.get('item')
r2.get('item')
r1['item'] = 1
conn1.getTransaction().commit()
self.assertEqual(r1['item'], 1)
# r2 has not seen a transaction boundary,
# so it should be unchanged.
self.assertEqual(r2.get('item'), None)
conn2.sync()
# Now r2 is updated.
self.assertEqual(r2['item'], 1)
def test_suite(): # Now, for good measure, send an update in the other direction.
return unittest.makeSuite(ZODBTests, 'check') r2['item'] = 2
conn2.getTransaction().commit()
def main(): self.assertEqual(r1['item'], 1)
alltests=test_suite() self.assertEqual(r2['item'], 2)
runner = unittest.TextTestRunner() conn1.sync()
runner.run(alltests) conn2.sync()
self.assertEqual(r1['item'], 2)
self.assertEqual(r2['item'], 2)
finally:
conn1.close()
conn2.close()
def debug():
test_suite().debug()
def pdebug(): def test_suite():
import pdb return unittest.makeSuite(ZODBTests, 'check')
pdb.run('debug()')
if __name__=='__main__': if __name__=='__main__':
if len(sys.argv) > 1: unittest.main(defaultTest='test_suite')
globals()[sys.argv[1]]()
else:
main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment