Commit c61072ab authored by Julien Muchembled's avatar Julien Muchembled

wip

parent c081a86f
...@@ -101,6 +101,9 @@ def datetimeFromTID(tid): ...@@ -101,6 +101,9 @@ def datetimeFromTID(tid):
seconds, lower = divmod(lower * 60, TID_LOW_OVERFLOW) seconds, lower = divmod(lower * 60, TID_LOW_OVERFLOW)
return datetime(*(higher + (seconds, int(lower * MICRO_FROM_UINT32)))) return datetime(*(higher + (seconds, int(lower * MICRO_FROM_UINT32))))
def timeFromTID(tid, _epoch=datetime.utcfromtimestamp(0)):
return (datetimeFromTID(tid) - _epoch).total_seconds()
def addTID(ptid, offset): def addTID(ptid, offset):
""" """
Offset given packed TID. Offset given packed TID.
......
...@@ -55,7 +55,6 @@ UNIT_TEST_MODULES = [ ...@@ -55,7 +55,6 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testClientHandler', 'neo.tests.storage.testClientHandler',
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testTransactions',
# client application # client application
'neo.tests.client.testClientApp', 'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler', 'neo.tests.client.testMasterHandler',
......
...@@ -18,7 +18,6 @@ import os, errno, socket, sys, threading, weakref ...@@ -18,7 +18,6 @@ import os, errno, socket, sys, threading, weakref
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy from copy import copy
from functools import partial, wraps
from time import time from time import time
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import NonReadableCell from neo.lib.exception import NonReadableCell
...@@ -117,9 +116,7 @@ class DeleteWorker(object): ...@@ -117,9 +116,7 @@ class DeleteWorker(object):
round(drop_time - before[1], 3), round(drop_time - before[1], 3),
round(drop_time, 3)) round(drop_time, 3))
elif self._pack_set: elif self._pack_set:
dm_pack = partial(dm._pack, pack_id, approved, partial, oids, tid = self._pack_info
weak_app().tm.updateObjectDataForPack)
pack_id, approved, partial_, oids, tid = self._pack_info
assert approved, self._pack_info assert approved, self._pack_info
tid = util.u64(tid) tid = util.u64(tid)
before = delete_count, delete_time = self._pack_stats before = delete_count, delete_time = self._pack_stats
...@@ -146,13 +143,13 @@ class DeleteWorker(object): ...@@ -146,13 +143,13 @@ class DeleteWorker(object):
limit = max(100, limit = max(100,
int(.1 * delete_count / delete_time) int(.1 * delete_count / delete_time)
) if delete_time else 1000 ) if delete_time else 1000
if partial_: if partial:
i = oid_index + limit i = oid_index + limit
deleted = dm_pack(offset, oids[oid_index:i], deleted = dm._pack(offset, oids[oid_index:i],
tid)[1] tid)[1]
oid_index = i oid_index = i
else: else:
oid, deleted = dm_pack(offset, oid, tid, limit) oid, deleted = dm._pack(offset, oid, tid, limit)
if log: if log:
log = False log = False
logging.info("pack %s: partition %s...", logging.info("pack %s: partition %s...",
...@@ -160,7 +157,7 @@ class DeleteWorker(object): ...@@ -160,7 +157,7 @@ class DeleteWorker(object):
delete_count += deleted delete_count += deleted
delete_time += time() - start delete_time += time() - start
self._pack_stats = delete_count, delete_time self._pack_stats = delete_count, delete_time
if (oid_index < len(oids) if partial_ else if (oid_index < len(oids) if partial else
oid is None): oid is None):
parts.remove(offset) parts.remove(offset)
packed += 1 packed += 1
...@@ -1025,7 +1022,7 @@ class DatabaseManager(object): ...@@ -1025,7 +1022,7 @@ class DatabaseManager(object):
r = self._getObject(oid, tid, before_tid) r = self._getObject(oid, tid, before_tid)
return (r[0], r[-1]) if r else (None, None) return (r[0], r[-1]) if r else (None, None)
def findUndoTID(self, oid, ltid, undone_tid, current_tid): def findUndoTID(self, oid, ltid, undo_tid, current_tid):
""" """
oid oid
Object OID Object OID
...@@ -1034,7 +1031,7 @@ class DatabaseManager(object): ...@@ -1034,7 +1031,7 @@ class DatabaseManager(object):
ltid ltid
Upper (excluded) bound of transactions visible to transaction doing Upper (excluded) bound of transactions visible to transaction doing
the undo. the undo.
undone_tid undo_tid
Transaction to undo Transaction to undo
current_tid current_tid
Serial of object data from memory, if it was modified by running Serial of object data from memory, if it was modified by running
...@@ -1046,8 +1043,7 @@ class DatabaseManager(object): ...@@ -1046,8 +1043,7 @@ class DatabaseManager(object):
see. This is used later to detect current conflicts (eg, another see. This is used later to detect current conflicts (eg, another
client modifying the same object in parallel) client modifying the same object in parallel)
data_tid (int) data_tid (int)
TID containing (without indirection) the data prior to undone TID containing the data prior to undone transaction.
transaction.
None if object doesn't exist prior to transaction being undone None if object doesn't exist prior to transaction being undone
(its creation is being undone). (its creation is being undone).
is_current (bool) is_current (bool)
...@@ -1056,38 +1052,26 @@ class DatabaseManager(object): ...@@ -1056,38 +1052,26 @@ class DatabaseManager(object):
""" """
u64 = util.u64 u64 = util.u64
oid = u64(oid) oid = u64(oid)
undone_tid = u64(undone_tid) undo_tid = u64(undo_tid)
def getDataTID(tid=None, before_tid=None): if self._getDataTID(oid, undo_tid)[0] is None:
tid, data_tid = self._getDataTID(oid, tid, before_tid)
current_tid = tid
while data_tid:
if data_tid < tid:
tid, data_tid = self._getDataTID(oid, data_tid)
if tid is not None:
continue
logging.error("Incorrect data serial for oid %s at tid %s",
oid, current_tid)
return current_tid, current_tid
return current_tid, tid
found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
if found_undone_tid is None:
return return
undone_tid = self._getDataTID(oid, before_tid=undo_tid)[0]
if current_tid: if current_tid:
current_data_tid = u64(current_tid) tid = data_tid = u64(current_tid)
else: else:
if ltid: if ltid:
ltid = u64(ltid) ltid = u64(ltid)
current_tid, current_data_tid = getDataTID(before_tid=ltid) tid, data_tid = self._getDataTID(oid, before_tid=ltid)
if current_tid is None: if tid is None:
return None, None, False return None, None, False
current_tid = util.p64(current_tid) current_tid = util.p64(tid)
# Load object data as it was before given transaction. if undo_tid < tid:
# It can be None, in which case it means we are undoing object tid = data_tid
# creation. while undo_tid < tid:
_, data_tid = getDataTID(before_tid=undone_tid) tid = self._getDataTID(oid, tid)[1]
if data_tid is not None: return (current_tid,
data_tid = util.p64(data_tid) None if undone_tid is None else util.p64(undone_tid),
return current_tid, data_tid, undone_data_tid == current_data_tid undo_tid == tid)
@abstract @abstract
def storePackOrder(self, tid, approved, partial, oid_list, pack_tid): def storePackOrder(self, tid, approved, partial, oid_list, pack_tid):
...@@ -1235,7 +1219,7 @@ class DatabaseManager(object): ...@@ -1235,7 +1219,7 @@ class DatabaseManager(object):
passed to filter out non-applicable TIDs.""" passed to filter out non-applicable TIDs."""
@abstract @abstract
def _pack(self, updateObjectDataForPack, offset, oid, tid, limit=None): def _pack(self, offset, oid, tid, limit=None):
"""""" """"""
@abstract @abstract
......
...@@ -681,7 +681,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -681,7 +681,10 @@ class MySQLDatabaseManager(DatabaseManager):
" WHERE `partition`=%s AND oid=%s AND tid=%s" " WHERE `partition`=%s AND oid=%s AND tid=%s"
% (p, oid, util.u64(data_tid))): % (p, oid, util.u64(data_tid))):
return r return r
assert p not in self._readable_set if p in self._readable_set: # and not checksum:
raise NotImplementedError("Race condition between undo & pack:"
" an Error must be returned to the client"
" and the transaction must be aborted.")
if not checksum: if not checksum:
return # delete return # delete
e = self.escape e = self.escape
...@@ -928,13 +931,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -928,13 +931,11 @@ class MySQLDatabaseManager(DatabaseManager):
'' if length is None else ' LIMIT %s' % length)) '' if length is None else ' LIMIT %s' % length))
return [p64(t[0]) for t in r] return [p64(t[0]) for t in r]
def _pack(self, updateObjectDataForPack, offset, oid, tid, limit=None): def _pack(self, offset, oid, tid, limit=None):
p64 = util.p64
q = self.query q = self.query
if limit: if limit:
q("SET @next_oid=NULL") q("SET @next_oid=NULL")
data_id_set = set() data_id_set = set()
value_dict = defaultdict(list)
for oid, max_serial in q( for oid, max_serial in q(
"SELECT obj.oid, tid + (data_id IS NULL)" "SELECT obj.oid, tid + (data_id IS NULL)"
" FROM (SELECT COUNT(*) AS n, oid, MAX(tid) AS max_tid" " FROM (SELECT COUNT(*) AS n, oid, MAX(tid) AS max_tid"
...@@ -948,34 +949,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -948,34 +949,9 @@ class MySQLDatabaseManager(DatabaseManager):
" IN (%s)" % ','.join(map(str, oid)), " IN (%s)" % ','.join(map(str, oid)),
" LIMIT %s" % limit if limit else "", " LIMIT %s" % limit if limit else "",
offset)): offset)):
value_dict.clear()
for i, serial, value_serial in q(
"SELECT 0, tid, value_tid FROM obj FORCE INDEX(PRIMARY)"
" WHERE `partition`=%s AND oid=%s AND tid>=%s"
" AND value_tid<%s UNION "
"SELECT 1, tid, value_tid FROM tobj"
" WHERE `partition`=%s AND oid=%s AND value_tid<%s"
% (offset, oid, max_serial, max_serial,
offset, oid, max_serial)):
value_dict[value_serial].append((i, serial))
sql = " FROM obj WHERE `partition`=%s AND oid=%s AND tid<%s" \ sql = " FROM obj WHERE `partition`=%s AND oid=%s AND tid<%s" \
% (offset, oid, max_serial) % (offset, oid, max_serial)
for serial, data_id in q("SELECT tid, data_id" + sql): data_id_set.update(*zip(*q("SELECT DISTINCT data_id" + sql)))
data_id_set.add(data_id)
if serial in value_dict:
value_serial = None
for i, t in value_dict[serial]:
q("UPDATE %s SET value_tid=%s"
" WHERE `partition`=%s AND oid=%s AND tid=%s" % (
'tobj' if i else 'obj',
'NULL' if value_serial is None else value_serial,
offset, oid, t))
if value_serial is None and not i:
value_serial = t
if value_serial is not None:
value_serial = p64(value_serial)
updateObjectDataForPack(p64(oid), p64(serial),
value_serial, data_id)
q("DELETE" + sql) q("DELETE" + sql)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
......
...@@ -491,7 +491,10 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -491,7 +491,10 @@ class SQLiteDatabaseManager(DatabaseManager):
" WHERE partition=? AND oid=? AND tid=?", " WHERE partition=? AND oid=? AND tid=?",
(p, oid, util.u64(data_tid))): (p, oid, util.u64(data_tid))):
return r return r
assert p not in self._readable_set if p in self._readable_set: # and not checksum:
raise NotImplementedError("Race condition between undo & pack:"
" an Error must be returned to the client"
" and the transaction must be aborted.")
if not checksum: if not checksum:
return # delete return # delete
H = buffer(checksum) H = buffer(checksum)
...@@ -710,8 +713,10 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -710,8 +713,10 @@ class SQLiteDatabaseManager(DatabaseManager):
ORDER BY tid ASC LIMIT ?""", ORDER BY tid ASC LIMIT ?""",
(partition, min_tid, max_tid, length))] (partition, min_tid, max_tid, length))]
def _pack(self, updateObjectDataForPack, offset, oid, tid, limit=None): _pack = " FROM obj WHERE partition=? AND oid=? AND tid<?"
p64 = util.p64 def _pack(self, offset, oid, tid, limit=None,
_select_data_id_sql="SELECT DISTINCT data_id" + _pack,
_delete_obj_sql="DELETE" + _pack):
q = self.query q = self.query
data_id_set = set() data_id_set = set()
value_dict = defaultdict(list) value_dict = defaultdict(list)
...@@ -722,41 +727,14 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -722,41 +727,14 @@ class SQLiteDatabaseManager(DatabaseManager):
" LIMIT %s" % limit if limit else "") " LIMIT %s" % limit if limit else "")
oid = None oid = None
for x, oid, max_tid in q(sql): for x, oid, max_tid in q(sql):
x = q("SELECT tid + (data_id IS NULL) FROM obj" for x in q("SELECT tid + (data_id IS NULL) FROM obj"
" WHERE partition=? AND oid=? AND tid=?" " WHERE partition=? AND oid=? AND tid=?"
" AND (data_id IS NULL OR ?>1)", " AND (data_id IS NULL OR ?>1)",
(offset, oid, max_tid, x)).fetchone() (offset, oid, max_tid, x)):
if x is None: x = (offset, oid) + x
continue data_id_set.update(*zip(*q(_select_data_id_sql, x)))
max_serial, = x q(_delete_obj_sql, x)
value_dict.clear() break
for i, serial, value_serial in q(
"SELECT 0, tid, value_tid FROM obj"
" WHERE partition=? AND oid=? AND tid>=?"
" AND value_tid<? UNION "
"SELECT 1, tid, value_tid FROM tobj"
" WHERE partition=? AND oid=? AND value_tid<?",
(offset, oid, max_serial, max_serial,
offset, oid, max_serial)):
value_dict[value_serial].append((i, serial))
sql = " FROM obj WHERE partition=? AND oid=? AND tid<?"
args = offset, oid, max_serial
for serial, data_id in q("SELECT tid, data_id" + sql, args):
data_id_set.add(data_id)
if serial in value_dict:
value_serial = None
for i, t in value_dict[serial]:
q("UPDATE %s SET value_tid=?"
" WHERE partition=? AND oid=? AND tid=?"
% ('tobj' if i else 'obj'),
(value_serial, offset, oid, t))
if value_serial is None and not i:
value_serial = t
if value_serial is not None:
value_serial = p64(value_serial)
updateObjectDataForPack(p64(oid), p64(serial),
value_serial, data_id)
q("DELETE" + sql, args)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
return limit and (None if oid is None else oid+1), len(data_id_set) return limit and (None if oid is None else oid+1), len(data_id_set)
......
...@@ -571,10 +571,3 @@ class TransactionManager(EventQueue): ...@@ -571,10 +571,3 @@ class TransactionManager(EventQueue):
logging.info(' %s by %s', dump(oid), dump(ttid)) logging.info(' %s by %s', dump(oid), dump(ttid))
self.logQueuedEvents() self.logQueuedEvents()
self.read_queue.logQueuedEvents() self.read_queue.logQueuedEvents()
def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id):
lock_tid = self.getLockingTID(oid)
if lock_tid is not None:
transaction = self._transaction_dict[lock_tid]
if transaction.store_dict[oid][2] == orig_serial:
transaction.store(oid, data_id, new_serial)
#
# Copyright (C) 2010-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.storage.transactions import TransactionManager
class TransactionManagerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.app = Mock()
# no history
self.app.dm = Mock({'getObjectHistory': []})
self.app.pt = Mock({'isAssigned': True, 'getPartitions': 2})
self.app.em = Mock({'setTimeout': None})
self.manager = TransactionManager(self.app)
def register(self, uuid, ttid):
self.manager.register(Mock({'getUUID': uuid}), ttid)
def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID()
oid = p64(1)
orig_serial = self.getNextTID()
uuid = self.getClientUUID()
locking_serial = self.getNextTID()
other_serial = self.getNextTID()
new_serial = self.getNextTID()
data_id = (1 << 48) + 2
self.register(uuid, locking_serial)
# Object not known, nothing happens
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known, but doesn't point at orig_serial, it is not updated
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20,
'bar', None)
holdData = self.app.dm.mockGetNamedCalls('holdData')
self.assertEqual(holdData.pop(0).params,
("3" * 20, oid, 'bar', 0, None))
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, other_serial)
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known and points at undone data it gets updated
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, new_serial,
data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, data_id, new_serial))
self.manager.abort(locking_serial, even_if_locked=True)
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, data_id, None))
self.manager.abort(locking_serial, even_if_locked=True)
if __name__ == "__main__":
unittest.main()
...@@ -19,22 +19,24 @@ from collections import defaultdict, deque ...@@ -19,22 +19,24 @@ from collections import defaultdict, deque
from contextlib import contextmanager, nested from contextlib import contextmanager, nested
from time import time from time import time
from persistent import Persistent from persistent import Persistent
from ZODB.POSException import UndoError
from neo.lib.protocol import Packets from neo.lib.protocol import Packets
from neo.lib.util import timeFromTID
from .. import consume, Patch from .. import consume, Patch
from . import ConnectionFilter, NEOThreadedTest, with_cluster from . import ConnectionFilter, NEOThreadedTest, with_cluster
class PCounter(Persistent): class PCounter(Persistent):
value = 0 value = 0
class TestPack(NEOThreadedTest): class PackTests(NEOThreadedTest):
@contextmanager @contextmanager
def assertPackOperationCount(self, cluster_or_storages, *counts): def assertPackOperationCount(self, cluster_or_storages, *counts):
def patch(storage): def patch(storage):
packs = {} packs = {}
def _pack(orig, *args): def _pack(orig, *args):
offset = args[1] offset = args[0]
tid = args[3] tid = args[2]
try: try:
tids = packs[offset] tids = packs[offset]
except KeyError: except KeyError:
...@@ -106,7 +108,7 @@ class TestPack(NEOThreadedTest): ...@@ -106,7 +108,7 @@ class TestPack(NEOThreadedTest):
self.checkReplicas(cluster) self.checkReplicas(cluster)
@with_cluster(replicas=1) @with_cluster(replicas=1)
def testValueSerial(self, cluster): def testValueSerialVsReplication(self, cluster):
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
ob = c.root()[''] = PCounter() ob = c.root()[''] = PCounter()
t.commit() t.commit()
...@@ -128,6 +130,32 @@ class TestPack(NEOThreadedTest): ...@@ -128,6 +130,32 @@ class TestPack(NEOThreadedTest):
self.tic() self.tic()
self.checkReplicas(cluster) self.checkReplicas(cluster)
@with_cluster()
def testValueSerialMultipleUndo(self, cluster):
t, c = cluster.getTransaction()
r = c.root()
ob = r[''] = PCounter()
t.commit()
tids = []
for x in xrange(2):
ob.value += 1
t.commit()
tids.append(ob._p_serial)
r._p_changed = 1
t.commit()
db = c.db()
def undo(tid):
db.undo(tid, t.get())
t.commit()
tids.append(db.lastTransaction())
undo(tids[1])
undo(tids[0])
undo(tids[-1])
cluster.client.pack(timeFromTID(r._p_serial))
self.tic()
db.undo(tids[2], t.get())
t.commit()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment