Commit 6df232f3 authored by Jim Fulton's avatar Jim Fulton

Fixed a bug in savepoint rollback. It's not enough to rollback

just the savepoint being rolled back because later savepoints
might involved data managers that hadn't joined when the savepoint
being rolled back was created.

Now, when a data manager joins and we have savepoints, we create a
data manager savepoint for the new data manager and add the
datamanager savepoint to all previous transaction savepoints.  Note
that this data manager savepoint can be a special savepoint that just
calls abort on the data manager when it is rolled back.
parent a870c5db
......@@ -264,9 +264,7 @@ class Transaction(object):
self._prior_operation_failed() # doesn't return, it raises
try:
savepoint = Savepoint(self, optimistic)
for resource in self._resources:
savepoint.join(resource)
savepoint = Savepoint(self, optimistic, *self._resources)
except:
self._cleanup(self._resources)
self._saveCommitishError() # reraises!
......@@ -598,24 +596,41 @@ class Savepoint:
"""
interface.implements(interfaces.ISavepoint)
def __init__(self, transaction, optimistic):
def __init__(self, transaction, optimistic, *resources):
self.transaction = transaction
self._savepoints = []
self._savepoints = savepoints = []
self.valid = True
self.next = self.previous = None
self.optimistic = optimistic
for datamanager in resources:
try:
savepoint = datamanager.savepoint
except AttributeError:
if not self.optimistic:
raise TypeError("Savepoints unsupported", datamanager)
savepoint = NoRollbackSavepoint(datamanager)
else:
savepoint = savepoint()
savepoints.append(savepoint)
def join(self, datamanager):
try:
savepoint = datamanager.savepoint
except AttributeError:
if not self.optimistic:
raise TypeError("Savepoints unsupported", datamanager)
savepoint = NoRollbackSavepoint(datamanager)
else:
savepoint = savepoint()
self._savepoints.append(savepoint)
# A data manager has joined a transaction *after* a savepoint
# was created. A couple of things are different in this case:
# 1. We need to add it's savepoint to all previous savepoints.
# so that if they are rolled back, we roll this was back too.
# 2. We don't actualy need to ask it for a savepoint. Because
# is just joining, then we can abort it if there is an error,
# so we use an AbortSavepoint.
savepoint = AbortSavepoint(datamanager, self.transaction)
while self is not None:
self._savepoints.append(savepoint)
self = self.previous
def rollback(self):
if not self.valid:
......@@ -638,6 +653,15 @@ class Savepoint:
if self.previous is not None:
self.previous._invalidate_previous()
class AbortSavepoint:
def __init__(self, datamanager, transaction):
self.datamanager = datamanager
self.transaction = transaction
def rollback(self):
self.datamanager.abort(self.transaction)
class NoRollbackSavepoint:
def __init__(self, datamanager):
......
......@@ -201,7 +201,7 @@ support savepoints:
>>> transaction.abort()
However, a flag can be passed to the transaction savepoint method to
indicate that databases without savepoint support should be tolderated
indicate that databases without savepoint support should be tolerated
until a savepoint is roled back. This allows transactions to proceed
is there are no reasons to roll back:
......@@ -212,8 +212,8 @@ is there are no reasons to roll back:
>>> dm_no_sp['name']
'sue'
>>> savepoint = transaction.savepoint(1)
>>> dm_no_sp['name'] = 'sam'
>>> savepoint = transaction.savepoint(1)
>>> savepoint.rollback()
Traceback (most recent call last):
...
......
......@@ -19,9 +19,49 @@ import unittest
from zope.testing import doctest
def testRollbackRollsbackDataManagersThatJoinedLater():
"""
A savepoint needs to not just rollback it's savepoints, but needs to
rollback savepoints for data managers that joined savepoints after the
savepoint:
>>> import transaction.tests.savepointsample
>>> dm = transaction.tests.savepointsample.SampleSavepointDataManager()
>>> dm['name'] = 'bob'
>>> sp1 = transaction.savepoint()
>>> dm['job'] = 'geek'
>>> sp2 = transaction.savepoint()
>>> dm['salary'] = 'fun'
>>> dm2 = transaction.tests.savepointsample.SampleSavepointDataManager()
>>> dm2['name'] = 'sally'
>>> 'name' in dm
True
>>> 'job' in dm
True
>>> 'salary' in dm
True
>>> 'name' in dm2
True
>>> sp1.rollback()
>>> 'name' in dm
True
>>> 'job' in dm
False
>>> 'salary' in dm
False
>>> 'name' in dm2
False
"""
def test_suite():
return unittest.TestSuite((
doctest.DocFileSuite('../savepoint.txt'),
doctest.DocTestSuite(),
))
if __name__ == '__main__':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment