Commit d5c469be authored by Julien Muchembled's avatar Julien Muchembled

Fix protocol and DB schema so that storages can handle transactions of any size

- Change protocol to use SHA1 for all checksums:
  - Use SHA1 instead of CRC32 for data checksums.
  - Use SHA1 instead of MD5 for replication.

- Change DatabaseManager API so that backends can store raw data separately from
  object metadata:
  - When processing AskStoreObject, call the backend to store the data
    immediately, instead of keeping it in RAM or in the temporary object table.
    Data is then referenced only by its checksum.
    Without such change, the storage could fail to store the transaction due to
    lack of RAM, or it could make tpc_finish step very slow.
  - Backends have to store data in a separate space, and remove entries as soon
    as they get unreferenced. So they must have an index of checksums in object
    metadata space. A new '_uncommitted_data' backend attribute keeps references
    of uncommitted data.
  - New methods: _pruneData, _storeData, storeData, unlockData
  - MySQL: change vertical partitioning of 'obj' by having data in a separate
    'data' table instead of using a shortened 'obj_short' table.
  - BTree: data is moved from '_obj' to a new '_data' btree.

- Undo is optimized so that backpointers are not required anymore to fetch data:
  - The checksum of an object is None only when creation is undone.
  - Removed DatabaseManager methods: _getObjectData, _getDataTIDFromData
  - DatabaseManager: move some code from _getDataTID to findUndoTID so that
    _getDataTID only has what's specific to backend.

- Removed because already covered by ZODB tests:
  - neo.tests.storage.testStorageDBTests.StorageDBTests.test__getDataTID
  - neo.tests.storage.testStorageDBTests.StorageDBTests.test__getDataTIDFromData
parent d90c5b83
......@@ -4,6 +4,8 @@ Change History
0.10 (unreleased)
-----------------
- Storage was unable or slow to process large-sized transactions.
This required changing the protocol and the MySQL table format.
- NEO learned to store empty values (although it's useless when managed by
a ZODB Connection).
......
......@@ -28,7 +28,8 @@ from ZODB.ConflictResolution import ResolvedSerial
from persistent.TimeStamp import TimeStamp
import neo.lib
from neo.lib.protocol import NodeTypes, Packets, INVALID_PARTITION, ZERO_TID
from neo.lib.protocol import NodeTypes, Packets, \
INVALID_PARTITION, ZERO_HASH, ZERO_TID
from neo.lib.event import EventManager
from neo.lib.util import makeChecksum as real_makeChecksum, dump
from neo.lib.locking import Lock
......@@ -444,7 +445,7 @@ class Application(object):
except ConnectionClosed:
continue
if data or checksum:
if data or checksum != ZERO_HASH:
if checksum != makeChecksum(data):
neo.lib.logging.error('wrong checksum from %s for oid %s',
conn, dump(oid))
......@@ -509,7 +510,7 @@ class Application(object):
# an older object revision).
compressed_data = ''
compression = 0
checksum = 0
checksum = ZERO_HASH
else:
assert data_serial is None
compression = self.compress
......
......@@ -66,9 +66,6 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerObject(self, conn, oid, start_serial, end_serial,
compression, checksum, data, data_serial):
if data_serial is not None:
raise NEOStorageError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), )
self.app.setHandlerData((oid, start_serial, end_serial,
compression, checksum, data))
......
......@@ -112,6 +112,7 @@ INVALID_TID = '\xff' * 8
INVALID_OID = '\xff' * 8
INVALID_PARTITION = 0xffffffff
INVALID_ADDRESS_TYPE = socket.AF_UNSPEC
ZERO_HASH = '\0' * 20
ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID)
......@@ -527,6 +528,17 @@ class PProtocol(PStructItem):
raise ProtocolError('protocol version mismatch')
return (major, minor)
class PChecksum(PItem):
"""
A hash (SHA1)
"""
def _encode(self, writer, checksum):
assert len(checksum) == 20, (len(checksum), checksum)
writer(checksum)
def _decode(self, reader):
return reader(20)
class PUUID(PItem):
"""
An UUID (node identifier)
......@@ -561,7 +573,6 @@ class PTID(PItem):
# same definition, for now
POID = PTID
PChecksum = PUUID # (md5 is same length as uuid)
# common definitions
......@@ -908,7 +919,7 @@ class StoreObject(Packet):
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PNumber('checksum'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
PTID('tid'),
......@@ -964,7 +975,7 @@ class GetObject(Packet):
PTID('serial_start'),
PTID('serial_end'),
PBoolean('compression'),
PNumber('checksum'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
......
......@@ -18,7 +18,7 @@
import re
import socket
from zlib import adler32
from hashlib import sha1
from Queue import deque
from struct import pack, unpack
......@@ -62,8 +62,8 @@ def bin(s):
def makeChecksum(s):
"""Return a 4-byte integer checksum against a string."""
return adler32(s) & 0xffffffff
"""Return a 20-byte checksum against a string."""
return sha1(s).digest()
def resolve(hostname):
......
......@@ -22,23 +22,19 @@ Not persistent ! (no data retained after process exit)
from BTrees.OOBTree import OOBTree as _OOBTree
import neo.lib
from hashlib import md5
from hashlib import sha1
from neo.storage.database import DatabaseManager
from neo.storage.database.manager import CreationUndone
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID
from neo.lib import util
# The only purpose of this value (and code using it) is to avoid creating
# arbitrarily-long lists of values when cleaning up dictionaries.
KEY_BATCH_SIZE = 1000
# Keep dropped trees in memory to avoid instantiating them when not needed.
TREE_POOL = []
# How many empty BTree instances to keep in RAM
MAX_TREE_POOL_SIZE = 100
def batchDelete(tree, tester_callback, iter_kw=None, recycle_subtrees=False):
def batchDelete(tree, tester_callback=None, deleter_callback=None, **kw):
"""
Iter over given BTree and delete found entries.
tree BTree
......@@ -46,49 +42,21 @@ def batchDelete(tree, tester_callback, iter_kw=None, recycle_subtrees=False):
tester_callback function(key, value) -> boolean
Called with each key, value pair found in tree.
If return value is true, delete entry. Otherwise, skip to next key.
iter_kw dict
deleter_callback function(tree, key_list) -> None (None)
Custom function to delete items
**kw
Keyword arguments for tree.items .
Warning: altered in this function.
recycle_subtrees boolean (False)
If true, deleted values will be put in TREE_POOL for future reuse.
They must be BTrees.
If False, values are not touched.
"""
if iter_kw is None:
iter_kw = {}
if recycle_subtrees:
deleter_callback = _btreeDeleterCallback
if tester_callback is None:
key_list = list(safeIter(tree.iterkeys, **kw))
else:
deleter_callback = _deleterCallback
items = tree.items
while True:
to_delete = []
append = to_delete.append
for key, value in safeIter(items, **iter_kw):
if tester_callback(key, value):
append(key)
if len(to_delete) >= KEY_BATCH_SIZE:
iter_kw['min'] = key
iter_kw['excludemin'] = True
break
if to_delete:
deleter_callback(tree, to_delete)
else:
break
def _deleterCallback(tree, key_list):
for key in key_list:
del tree[key]
if hasattr(_OOBTree, 'pop'):
def _btreeDeleterCallback(tree, key_list):
for key in key_list:
prune(tree.pop(key))
else:
def _btreeDeleterCallback(tree, key_list):
key_list = [key for key, value in safeIter(tree.iteritems, **kw)
if tester_callback(key, value)]
if deleter_callback is None:
for key in key_list:
prune(tree[key])
del tree[key]
else:
deleter_callback(tree, key_list)
def OOBTree():
try:
......@@ -153,24 +121,20 @@ def safeIter(func, *args, **kw):
class BTreeDatabaseManager(DatabaseManager):
_obj = None
_trans = None
_tobj = None
_ttrans = None
_pt = None
_config = None
def __init__(self, database):
super(BTreeDatabaseManager, self).__init__()
self.setup(reset=1)
def setup(self, reset=0):
if reset:
self._data = OOBTree()
self._obj = OOBTree()
self._trans = OOBTree()
self.dropUnfinishedData()
self._tobj = OOBTree()
self._ttrans = OOBTree()
self._pt = {}
self._config = {}
self._uncommitted_data = {}
def _begin(self):
pass
......@@ -249,29 +213,6 @@ class BTreeDatabaseManager(DatabaseManager):
result = False
return result
def _getObjectData(self, oid, value_serial, tid):
if value_serial is None:
raise CreationUndone
if value_serial >= tid:
raise ValueError, "Incorrect value reference found for " \
"oid %d at tid %d: reference = %d" % (oid, value_serial, tid)
try:
tserial = self._obj[oid]
except KeyError:
raise IndexError(oid)
try:
compression, checksum, value, next_value_serial = tserial[
value_serial]
except KeyError:
raise IndexError(value_serial)
if value is None:
neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
value_serial, compression, checksum, value = self._getObjectData(
oid, next_value_serial, value_serial)
return value_serial, compression, checksum, value
def _getObject(self, oid, tid=None, before_tid=None):
tserial = self._obj.get(oid)
if tserial is not None:
......@@ -282,14 +223,20 @@ class BTreeDatabaseManager(DatabaseManager):
else:
tid = tserial.maxKey(before_tid - 1)
except ValueError:
return
result = tserial.get(tid)
if result:
try:
next_serial = tserial.minKey(tid + 1)
except ValueError:
next_serial = None
return (tid, next_serial) + result
return False
try:
checksum, value_serial = tserial[tid]
except KeyError:
return False
try:
next_serial = tserial.minKey(tid + 1)
except ValueError:
next_serial = None
if checksum is None:
compression = data = None
else:
compression, data, _ = self._data[checksum]
return tid, next_serial, compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset):
pt = self._pt
......@@ -311,16 +258,48 @@ class BTreeDatabaseManager(DatabaseManager):
def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True)
def _oidDeleterCallback(self, oid):
data = self._data
uncommitted_data = self._uncommitted_data
def deleter_callback(tree, key_list):
for tid in key_list:
checksum = tree[tid][0] # BBB: recent ZODB provides pop()
del tree[tid] #
if checksum:
index = data[checksum][2]
index.remove((oid, tid))
if not index and checksum not in uncommitted_data:
del data[checksum]
return deleter_callback
def _objDeleterCallback(self, tree, key_list):
data = self._data
checksum_list = []
checksum_set = set()
for oid in key_list:
tserial = tree[oid]; del tree[oid] # BBB: recent ZODB provides pop()
for tid, (checksum, _) in tserial.items():
if checksum:
index = data[checksum][2]
try:
index.remove((oid, tid))
except KeyError: # _tobj
checksum_list.append(checksum)
checksum_set.add(checksum)
prune(tserial)
self.unlockData(checksum_list)
self._pruneData(checksum_set)
def dropPartitions(self, num_partitions, offset_list):
offset_list = frozenset(offset_list)
def same_partition(key, _):
return key % num_partitions in offset_list
batchDelete(self._obj, same_partition, recycle_subtrees=True)
batchDelete(self._obj, same_partition, self._objDeleterCallback)
batchDelete(self._trans, same_partition)
def dropUnfinishedData(self):
self._tobj = OOBTree()
self._ttrans = OOBTree()
batchDelete(self._tobj, deleter_callback=self._objDeleterCallback)
self._ttrans.clear()
def storeTransaction(self, tid, object_list, transaction, temporary=True):
u64 = util.u64
......@@ -331,45 +310,39 @@ class BTreeDatabaseManager(DatabaseManager):
else:
obj = self._obj
trans = self._trans
for oid, compression, checksum, data, value_serial in object_list:
data = self._data
for oid, checksum, value_serial in object_list:
oid = u64(oid)
if data is None:
compression = checksum = data
else:
# TODO: unit-test this raise
if value_serial is not None:
raise ValueError, 'Either data or value_serial ' \
'must be None (oid %d, tid %d)' % (oid, tid)
if value_serial:
value_serial = u64(value_serial)
checksum = self._obj[oid][value_serial][0]
if temporary:
self.storeData(checksum)
if checksum:
if not temporary:
data[checksum][2].add((oid, tid))
try:
tserial = obj[oid]
except KeyError:
tserial = obj[oid] = OOBTree()
if value_serial is not None:
value_serial = u64(value_serial)
tserial[tid] = (compression, checksum, data, value_serial)
tserial[tid] = checksum, value_serial
if transaction is not None:
oid_list, user, desc, ext, packed = transaction
trans[tid] = (tuple(oid_list), user, desc, ext, packed)
def _getDataTIDFromData(self, oid, result):
tid, _, _, _, data, value_serial = result
if data is None:
try:
data_serial = self._getObjectData(oid, value_serial, tid)[0]
except CreationUndone:
data_serial = None
else:
data_serial = tid
return tid, data_serial
def _pruneData(self, checksum_list):
data = self._data
for checksum in set(checksum_list).difference(self._uncommitted_data):
if not data[checksum][2]:
del data[checksum]
def _getDataTID(self, oid, tid=None, before_tid=None):
result = self._getObject(oid, tid=tid, before_tid=before_tid)
if result is None:
result = (None, None)
else:
result = self._getDataTIDFromData(oid, result)
return result
def _storeData(self, checksum, data, compression):
try:
if self._data[checksum][:2] != (compression, data):
raise AssertionError("hash collision")
except KeyError:
self._data[checksum] = compression, data, set()
def finishTransaction(self, tid):
tid = util.u64(tid)
......@@ -384,8 +357,9 @@ class BTreeDatabaseManager(DatabaseManager):
self._trans[tid] = data
def _popTransactionFromTObj(self, tid, to_obj):
checksum_list = []
if to_obj:
recycle_subtrees = False
deleter_callback = None
obj = self._obj
def callback(oid, data):
try:
......@@ -393,8 +367,12 @@ class BTreeDatabaseManager(DatabaseManager):
except KeyError:
tserial = obj[oid] = OOBTree()
tserial[tid] = data
checksum = data[0]
if checksum:
self._data[checksum][2].add((oid, tid))
checksum_list.append(checksum)
else:
recycle_subtrees = True
deleter_callback = self._objDeleterCallback
callback = lambda oid, data: None
def tester_callback(oid, tserial):
try:
......@@ -405,8 +383,8 @@ class BTreeDatabaseManager(DatabaseManager):
del tserial[tid]
callback(oid, data)
return not tserial
batchDelete(self._tobj, tester_callback,
recycle_subtrees=recycle_subtrees)
batchDelete(self._tobj, tester_callback, deleter_callback)
self.unlockData(checksum_list)
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
......@@ -427,7 +405,7 @@ class BTreeDatabaseManager(DatabaseManager):
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(self._trans, same_partition,
iter_kw={'min': util.u64(tid), 'max': util.u64(max_tid)})
min=util.u64(tid), max=util.u64(max_tid))
def deleteObject(self, oid, serial=None):
u64 = util.u64
......@@ -438,16 +416,11 @@ class BTreeDatabaseManager(DatabaseManager):
try:
tserial = obj[oid]
except KeyError:
pass
else:
if serial is not None:
try:
del tserial[serial]
except KeyError:
pass
if serial is None or not tserial:
prune(obj[oid])
del obj[oid]
return
batchDelete(tserial, deleter_callback=self._oidDeleterCallback(oid),
min=serial, max=serial)
if not tserial:
del obj[oid]
def deleteObjectsAbove(self, num_partitions, partition, oid, serial,
max_tid):
......@@ -462,13 +435,14 @@ class BTreeDatabaseManager(DatabaseManager):
except KeyError:
pass
else:
batchDelete(tserial, lambda _, __: True,
iter_kw={'min': serial, 'max': max_tid})
batchDelete(tserial, min=serial, max=max_tid,
deleter_callback=self._oidDeleterCallback(oid))
if not tserial:
del tserial[oid]
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(obj, same_partition,
iter_kw={'min': oid, 'excludemin': True, 'max': max_tid},
recycle_subtrees=True)
batchDelete(obj, same_partition, self._objDeleterCallback,
min=oid, excludemin=True, max=max_tid)
def getTransaction(self, tid, all=False):
tid = util.u64(tid)
......@@ -504,15 +478,13 @@ class BTreeDatabaseManager(DatabaseManager):
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
_, _, value, value_serial = self._obj[oid][value_serial]
if value is None:
checksum, value_serial = self._obj[oid][value_serial]
if checksum is None:
neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
length = self._getObjectLength(oid, value_serial)
else:
length = len(value)
return length
return self._getObjectLength(oid, value_serial)
return len(self._data[checksum][1])
def getObjectHistory(self, oid, offset=0, length=1):
# FIXME: This method doesn't take client's current transaction id as
......@@ -532,17 +504,18 @@ class BTreeDatabaseManager(DatabaseManager):
while offset > 0:
tserial_iter.next()
offset -= 1
for serial, (_, _, value, value_serial) in tserial_iter:
data = self._data
for serial, (checksum, value_serial) in tserial_iter:
if length == 0 or serial < pack_tid:
break
length -= 1
if value is None:
if checksum is None:
try:
data_length = self._getObjectLength(oid, value_serial)
except CreationUndone:
data_length = 0
else:
data_length = len(value)
data_length = len(data[checksum][1])
append((p64(serial), data_length))
if not result:
result = None
......@@ -613,39 +586,28 @@ class BTreeDatabaseManager(DatabaseManager):
append(p64(tid))
return result
def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack):
p64 = util.p64
def _updatePackFuture(self, oid, orig_serial, max_serial):
# Before deleting this object's revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
value_serial = None
new_serial = None
obj = self._obj
for tree in (obj, self._tobj):
try:
tserial = tree[oid]
except KeyError:
continue
for serial, record in tserial.items(
for serial, (checksum, value_serial) in tserial.iteritems(
min=max_serial):
if record[3] == orig_serial:
if value_serial is None:
value_serial = serial
tserial[serial] = tserial[orig_serial]
else:
record = list(record)
record[3] = value_serial
tserial[serial] = tuple(record)
def getObjectData():
assert value_serial is None
return obj[oid][orig_serial][:3]
if value_serial:
value_serial = p64(value_serial)
updateObjectDataForPack(p64(oid), p64(orig_serial), value_serial,
getObjectData)
if value_serial == orig_serial:
tserial[serial] = checksum, new_serial
if not new_serial:
new_serial = serial
return new_serial
def pack(self, tid, updateObjectDataForPack):
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
self._setPackTID(tid)
......@@ -656,17 +618,21 @@ class BTreeDatabaseManager(DatabaseManager):
# No entry before pack TID, nothing to pack on this object.
pass
else:
if tserial[max_serial][1] is None:
if tserial[max_serial][0] is None:
# Last version before/at pack TID is a creation undo, drop
# it too.
max_serial += 1
def serial_callback(serial, _):
updatePackFuture(oid, serial, max_serial,
updateObjectDataForPack)
def serial_callback(serial, value):
new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial:
new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial),
new_serial, value[0])
batchDelete(tserial, serial_callback,
iter_kw={'max': max_serial, 'excludemax': True})
self._oidDeleterCallback(oid),
max=max_serial, excludemax=True)
return not tserial
batchDelete(self._obj, obj_callback, recycle_subtrees=True)
batchDelete(self._obj, obj_callback, self._objDeleterCallback)
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
if length:
......@@ -679,9 +645,9 @@ class BTreeDatabaseManager(DatabaseManager):
break
if tid_list:
return (len(tid_list),
md5(','.join(map(str, tid_list))).digest(),
sha1(','.join(map(str, tid_list))).digest(),
util.p64(tid_list[-1]))
return 0, None, ZERO_TID
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
num_partitions, partition):
......@@ -712,8 +678,8 @@ class BTreeDatabaseManager(DatabaseManager):
if oid_list:
p64 = util.p64
return (len(oid_list),
md5(','.join(map(str, oid_list))).digest(),
sha1(','.join(map(str, oid_list))).digest(),
p64(oid_list[-1]),
md5(','.join(map(str, serial_list))).digest(),
sha1(','.join(map(str, serial_list))).digest(),
p64(serial_list[-1]))
return 0, None, ZERO_OID, None, ZERO_TID
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
......@@ -15,6 +15,7 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib
from neo.lib import util
from neo.lib.exception import DatabaseFailure
......@@ -24,6 +25,8 @@ class CreationUndone(Exception):
class DatabaseManager(object):
"""This class only describes an interface for database managers."""
def __init__(self):
"""
Initialize the object.
......@@ -59,8 +62,17 @@ class DatabaseManager(object):
self._under_transaction = False
def setup(self, reset = 0):
"""Set up a database. If reset is true, existing data must be
discarded."""
"""Set up a database
It must recover self._uncommitted_data from temporary object table.
_uncommitted_data is a dict containing refcounts to data of
write-locked objects, except in case of undo, where the refcount is
increased later, when the object is read-locked.
Keys are checksums and values are number of references.
If reset is true, existing data must be discarded and
self._uncommitted_data must be an empty dict.
"""
raise NotImplementedError
def _begin(self):
......@@ -213,7 +225,7 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def getObject(self, oid, tid=None, before_tid=None, resolve_data=True):
def getObject(self, oid, tid=None, before_tid=None):
"""
oid (packed)
Identifier of object to retrieve.
......@@ -222,9 +234,6 @@ class DatabaseManager(object):
before_tid (packed, None)
Serial to retrieve is the highest existing one strictly below this
value.
resolve_data (bool, True)
If actual object data is desired, or raw record content.
This is different in case retrieved line undoes a transaction.
Return value:
None: Given oid doesn't exist in database.
......@@ -237,7 +246,6 @@ class DatabaseManager(object):
- data (binary string, None)
- data_serial (packed, None)
"""
# TODO: resolve_data must be unit-tested
u64 = util.u64
p64 = util.p64
oid = u64(oid)
......@@ -246,32 +254,20 @@ class DatabaseManager(object):
if before_tid is not None:
before_tid = u64(before_tid)
result = self._getObject(oid, tid, before_tid)
if result is None:
# See if object exists at all
result = self._getObject(oid)
if result is not None:
# Object exists
result = False
else:
if result:
serial, next_serial, compression, checksum, data, data_serial = \
result
assert before_tid is None or next_serial is None or \
before_tid <= next_serial
if data is None and resolve_data:
try:
_, compression, checksum, data = self._getObjectData(oid,
data_serial, serial)
except CreationUndone:
pass
data_serial = None
if serial is not None:
serial = p64(serial)
if next_serial is not None:
next_serial = p64(next_serial)
if data_serial is not None:
data_serial = p64(data_serial)
result = serial, next_serial, compression, checksum, data, data_serial
return result
return serial, next_serial, compression, checksum, data, data_serial
# See if object exists at all
return self._getObject(oid) and False
def changePartitionTable(self, ptid, cell_list):
"""Change a part of a partition table. The list of cells is
......@@ -298,12 +294,68 @@ class DatabaseManager(object):
"""Store a transaction temporarily, if temporary is true. Note
that this transaction is not finished yet. The list of objects
contains tuples, each of which consists of an object ID,
a compression specification, a checksum and object data.
a checksum and object serial.
The transaction is either None or a tuple of the list of OIDs,
user information, a description, extension information and transaction
pack state (True for packed)."""
raise NotImplementedError
def _pruneData(self, checksum_list):
"""To be overridden by the backend to delete any unreferenced data
'unreferenced' means:
- not in self._uncommitted_data
- and not referenced by a fully-committed object (storage should have
an index or a refcount of all data checksums of all objects)
"""
raise NotImplementedError
def _storeData(self, checksum, data, compression):
"""To be overridden by the backend to store object raw data
If same data was already stored, the storage only has to check there's
no hash collision.
"""
raise NotImplementedError
def storeData(self, checksum, data=None, compression=None):
"""Store object raw data
'checksum' must be the result of neo.lib.util.makeChecksum(data)
'compression' indicates if 'data' is compressed.
A volatile reference is set to this data until 'unlockData' is called
with this checksum.
If called with only a checksum, it only increments the volatile
reference to the data matching the checksum.
"""
refcount = self._uncommitted_data
refcount[checksum] = 1 + refcount.get(checksum, 0)
if data is not None:
self._storeData(checksum, data, compression)
def unlockData(self, checksum_list, prune=False):
"""Release 1 volatile reference to given list of checksums
If 'prune' is true, any data that is not referenced anymore (either by
a volatile reference or by a fully-committed object) is deleted.
"""
refcount = self._uncommitted_data
for checksum in checksum_list:
count = refcount[checksum] - 1
if count:
refcount[checksum] = count
else:
del refcount[checksum]
if prune:
self.begin()
try:
self._pruneData(checksum_list)
except:
self.rollback()
raise
self.commit()
__getDataTID = set()
def _getDataTID(self, oid, tid=None, before_tid=None):
"""
Return a 2-tuple:
......@@ -321,7 +373,17 @@ class DatabaseManager(object):
Otherwise, it's an undo transaction which did not involve conflict
resolution.
"""
raise NotImplementedError
if self.__class__ not in self.__getDataTID:
self.__getDataTID.add(self.__class__)
neo.lib.logging.warning("Fallback to generic/slow implementation"
" of _getDataTID. It should be overriden by backend storage.")
r = self._getObject(oid, tid, before_tid)
if r:
serial, _, _, checksum, _, value_serial = r
if value_serial is None and checksum:
return serial, serial
return serial, value_serial
return None, None
def findUndoTID(self, oid, tid, ltid, undone_tid, transaction_object):
"""
......@@ -360,21 +422,31 @@ class DatabaseManager(object):
if ltid:
ltid = u64(ltid)
undone_tid = u64(undone_tid)
_getDataTID = self._getDataTID
if transaction_object is not None:
_, _, _, _, tvalue_serial = transaction_object
current_tid = current_data_tid = u64(tvalue_serial)
def getDataTID(tid=None, before_tid=None):
tid, value_serial = self._getDataTID(oid, tid, before_tid)
if value_serial not in (None, tid):
if value_serial >= tid:
raise ValueError("Incorrect value reference found for"
" oid %d at tid %d: reference = %d"
% (oid, value_serial, tid))
if value_serial != getDataTID(value_serial)[1]:
neo.lib.logging.warning("Multiple levels of indirection"
" when getting data serial for oid %d at tid %d."
" This causes suboptimal performance." % (oid, tid))
return tid, value_serial
if transaction_object:
current_tid = current_data_tid = u64(transaction_object[2])
else:
current_tid, current_data_tid = _getDataTID(oid, before_tid=ltid)
current_tid, current_data_tid = getDataTID(before_tid=ltid)
if current_tid is None:
return (None, None, False)
found_undone_tid, undone_data_tid = _getDataTID(oid, tid=undone_tid)
found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
assert found_undone_tid is not None, (oid, undone_tid)
is_current = undone_data_tid in (current_data_tid, tid)
# Load object data as it was before given transaction.
# It can be None, in which case it means we are undoing object
# creation.
_, data_tid = _getDataTID(oid, before_tid=undone_tid)
_, data_tid = getDataTID(before_tid=undone_tid)
if data_tid is not None:
data_tid = p64(data_tid)
return p64(current_tid), data_tid, is_current
......@@ -471,8 +543,8 @@ class DatabaseManager(object):
Returns a 3-tuple:
- number of records actually found
- a XOR computed from record's TID
0 if no record found
- a SHA1 computed from record's TID
ZERO_HASH if no record found
- biggest TID found (ie, TID of last record read)
ZERO_TID if no record found
"""
......@@ -493,12 +565,12 @@ class DatabaseManager(object):
Returns a 5-tuple:
- number of records actually found
- a XOR computed from record's OID
0 if no record found
- a SHA1 computed from record's OID
ZERO_HASH if no record found
- biggest OID found (ie, OID of last record read)
ZERO_OID if no record found
- a XOR computed from record's serial
0 if no record found
- a SHA1 computed from record's serial
ZERO_HASH if no record found
- biggest serial found for biggest OID found (ie, serial of last
record read)
ZERO_TID if no record found
......
......@@ -17,18 +17,19 @@
from binascii import a2b_hex
import MySQLdb
from MySQLdb import OperationalError
from MySQLdb import IntegrityError, OperationalError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DUP_ENTRY
import neo.lib
from array import array
from hashlib import md5
from hashlib import sha1
import re
import string
from neo.storage.database import DatabaseManager
from neo.storage.database.manager import CreationUndone
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib import util
LOG_QUERIES = False
......@@ -46,6 +47,9 @@ def splitOIDField(tid, oids):
class MySQLDatabaseManager(DatabaseManager):
"""This class manages a database on MySQL."""
# WARNING: some parts are not concurrency-safe (ex: storeData)
# (there must be only 1 writable connection per DB)
# Disabled even on MySQL 5.1-5.5 and MariaDB 5.2-5.3 because
# 'select count(*) from obj' sometimes returns incorrect values
# (tested with testOudatedCellsOnDownStorage).
......@@ -136,8 +140,7 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query
if reset:
q('DROP TABLE IF EXISTS config, pt, trans, obj, obj_short, '
'ttrans, tobj')
q('DROP TABLE IF EXISTS config, pt, trans, obj, data, ttrans, tobj')
# The table "config" stores configuration parameters which affect the
# persistent data.
......@@ -174,22 +177,18 @@ class MySQLDatabaseManager(DatabaseManager):
partition SMALLINT UNSIGNED NOT NULL,
oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NULL,
checksum INT UNSIGNED NULL,
value LONGBLOB NULL,
hash BINARY(20) NULL,
value_serial BIGINT UNSIGNED NULL,
PRIMARY KEY (partition, oid, serial)
PRIMARY KEY (partition, oid, serial),
KEY (hash(4))
) ENGINE = InnoDB""" + p)
# The table "obj_short" contains columns which are accessed in queries
# which don't need to access object data. This is needed because InnoDB
# loads a whole row even when it only needs columns in primary key.
q('CREATE TABLE IF NOT EXISTS obj_short ('
'partition SMALLINT UNSIGNED NOT NULL,'
'oid BIGINT UNSIGNED NOT NULL,'
'serial BIGINT UNSIGNED NOT NULL,'
'PRIMARY KEY (partition, oid, serial)'
') ENGINE = InnoDB' + p)
#
q("""CREATE TABLE IF NOT EXISTS data (
hash BINARY(20) NOT NULL PRIMARY KEY,
compression TINYINT UNSIGNED NULL,
value LONGBLOB NULL
) ENGINE = InnoDB""")
# The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans (
......@@ -207,21 +206,13 @@ class MySQLDatabaseManager(DatabaseManager):
partition SMALLINT UNSIGNED NOT NULL,
oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NULL,
checksum INT UNSIGNED NULL,
value LONGBLOB NULL,
value_serial BIGINT UNSIGNED NULL
hash BINARY(20) NULL,
value_serial BIGINT UNSIGNED NULL,
PRIMARY KEY (serial, oid)
) ENGINE = InnoDB""")
def objQuery(self, query):
"""
Execute given query for both obj and obj_short tables.
query: query string, must contain "%(table)s" where obj table name is
needed.
"""
q = self.query
for table in ('obj', 'obj_short'):
q(query % {'table': table})
self._uncommitted_data = dict(q("SELECT hash, count(*)"
" FROM tobj WHERE hash IS NOT NULL GROUP BY hash") or ())
def getConfiguration(self, key):
if key in self._config:
......@@ -309,45 +300,22 @@ class MySQLDatabaseManager(DatabaseManager):
tid = util.u64(tid)
partition = self._getPartition(oid)
self.begin()
r = q("SELECT oid FROM obj_short WHERE partition=%d AND oid=%d AND "
r = q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND "
"serial=%d" % (partition, oid, tid))
if not r and all:
r = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d""" \
% (oid, tid))
r = q("SELECT oid FROM tobj WHERE serial=%d AND oid=%d"
% (tid, oid))
self.commit()
if r:
return True
return False
def _getObjectData(self, oid, value_serial, tid):
if value_serial is None:
raise CreationUndone
if value_serial >= tid:
raise ValueError, "Incorrect value reference found for " \
"oid %d at tid %d: reference = %d" % (oid, value_serial, tid)
r = self.query("""SELECT compression, checksum, value, """ \
"""value_serial FROM obj WHERE partition = %(partition)d """
"""AND oid = %(oid)d AND serial = %(serial)d""" % {
'partition': self._getPartition(oid),
'oid': oid,
'serial': value_serial,
})
compression, checksum, value, next_value_serial = r[0]
if value is None:
neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
value_serial, compression, checksum, value = self._getObjectData(
oid, next_value_serial, value_serial)
return value_serial, compression, checksum, value
def _getObject(self, oid, tid=None, before_tid=None):
q = self.query
partition = self._getPartition(oid)
sql = """SELECT serial, compression, checksum, value, value_serial
FROM obj
WHERE partition = %d
AND oid = %d""" % (partition, oid)
sql = ('SELECT serial, compression, obj.hash, value, value_serial'
' FROM obj LEFT JOIN data ON (obj.hash = data.hash)'
' WHERE partition = %d AND oid = %d') % (partition, oid)
if tid is not None:
sql += ' AND serial = %d' % tid
elif before_tid is not None:
......@@ -361,7 +329,7 @@ class MySQLDatabaseManager(DatabaseManager):
serial, compression, checksum, data, value_serial = r[0]
except IndexError:
return None
r = q("""SELECT serial FROM obj_short
r = q("""SELECT serial FROM obj
WHERE partition = %d AND oid = %d AND serial > %d
ORDER BY serial LIMIT 1""" % (partition, oid, serial))
try:
......@@ -399,7 +367,7 @@ class MySQLDatabaseManager(DatabaseManager):
for offset in offset_list:
add = """ALTER TABLE %%s ADD PARTITION (
PARTITION p%u VALUES IN (%u))""" % (offset, offset)
for table in 'trans', 'obj', 'obj_short':
for table in 'trans', 'obj':
try:
self.conn.query(add % table)
except OperationalError, (code, _):
......@@ -414,42 +382,45 @@ class MySQLDatabaseManager(DatabaseManager):
def dropPartitions(self, num_partitions, offset_list):
q = self.query
if self._use_partition:
drop = "ALTER TABLE %s DROP PARTITION" + \
','.join(' p%u' % i for i in offset_list)
for table in 'trans', 'obj', 'obj_short':
try:
self.conn.query(drop % table)
except OperationalError, (code, _):
if code != 1508: # already dropped
raise
return
e = self.escape
offset_list = ', '.join((str(i) for i in offset_list))
self.begin()
try:
# XXX: these queries are inefficient (execution time increase with
# row count, although we use indexes) when there are rows to
# delete. It should be done as an idle task, by chunks.
self.objQuery('DELETE FROM %%(table)s WHERE partition IN (%s)' %
(offset_list, ))
q("""DELETE FROM trans WHERE partition IN (%s)""" %
(offset_list, ))
for partition in offset_list:
where = " WHERE partition=%d" % partition
checksum_list = [x for x, in
q("SELECT DISTINCT hash FROM obj" + where) if x]
if not self._use_partition:
q("DELETE FROM obj" + where)
q("DELETE FROM trans" + where)
self._pruneData(checksum_list)
except:
self.rollback()
raise
self.commit()
if self._use_partition:
drop = "ALTER TABLE %s DROP PARTITION" + \
','.join(' p%u' % i for i in offset_list)
for table in 'trans', 'obj':
try:
self.conn.query(drop % table)
except OperationalError, (code, _):
if code != 1508: # already dropped
raise
def dropUnfinishedData(self):
q = self.query
self.begin()
try:
checksum_list = [x for x, in q("SELECT hash FROM tobj") if x]
q("""TRUNCATE tobj""")
q("""TRUNCATE ttrans""")
except:
self.rollback()
raise
self.commit()
self.unlockData(checksum_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary = True):
q = self.query
......@@ -466,30 +437,24 @@ class MySQLDatabaseManager(DatabaseManager):
self.begin()
try:
for oid, compression, checksum, data, value_serial in object_list:
for oid, checksum, value_serial in object_list:
oid = u64(oid)
if data is None:
compression = checksum = data = 'NULL'
partition = self._getPartition(oid)
if value_serial:
value_serial = u64(value_serial)
(checksum,), = q("SELECT hash FROM obj"
" WHERE partition=%d AND oid=%d AND serial=%d"
% (partition, oid, value_serial))
if temporary:
self.storeData(checksum)
else:
# TODO: unit-test this raise
if value_serial is not None:
raise ValueError, 'Either data or value_serial ' \
'must be None (oid %d, tid %d)' % (oid, tid)
compression = '%d' % (compression, )
checksum = '%d' % (checksum, )
data = "'%s'" % (e(data), )
if value_serial is None:
value_serial = 'NULL'
if checksum:
checksum = "'%s'" % e(checksum)
else:
value_serial = '%d' % (u64(value_serial), )
partition = self._getPartition(oid)
q("""REPLACE INTO %s VALUES (%d, %d, %d, %s, %s, %s, %s)""" \
% (obj_table, partition, oid, tid, compression, checksum,
data, value_serial))
if obj_table == 'obj':
# Update obj_short too
q('REPLACE INTO obj_short VALUES (%d, %d, %d)' % (
partition, oid, tid))
checksum = 'NULL'
q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" %
(obj_table, partition, oid, tid, checksum, value_serial))
if transaction is not None:
oid_list, user, desc, ext, packed = transaction
......@@ -507,66 +472,95 @@ class MySQLDatabaseManager(DatabaseManager):
raise
self.commit()
def _getDataTIDFromData(self, oid, result):
tid, next_serial, compression, checksum, data, value_serial = result
if data is None:
def _pruneData(self, checksum_list):
checksum_list = set(checksum_list).difference(self._uncommitted_data)
if checksum_list:
self.query("DELETE data FROM data"
" LEFT JOIN obj ON (data.hash = obj.hash)"
" WHERE data.hash IN ('%s') AND obj.hash IS NULL"
% "','".join(map(self.escape, checksum_list)))
def _storeData(self, checksum, data, compression):
e = self.escape
checksum = e(checksum)
self.begin()
try:
try:
data_serial = self._getObjectData(oid, value_serial, tid)[0]
except CreationUndone:
data_serial = None
else:
data_serial = tid
return tid, data_serial
self.query("INSERT INTO data VALUES ('%s', %d, '%s')" %
(checksum, compression, e(data)))
except IntegrityError, (code, _):
if code != DUP_ENTRY:
raise
r, = self.query("SELECT compression, value FROM data"
" WHERE hash='%s'" % checksum)
if r != (compression, data):
raise
except:
self.rollback()
raise
self.commit()
def _getDataTID(self, oid, tid=None, before_tid=None):
result = self._getObject(oid, tid=tid, before_tid=before_tid)
if result is None:
result = (None, None)
sql = ('SELECT serial, hash, value_serial FROM obj'
' WHERE partition = %d AND oid = %d'
) % (self._getPartition(oid), oid)
if tid is not None:
sql += ' AND serial = %d' % tid
elif before_tid is not None:
sql += ' AND serial < %d ORDER BY serial DESC LIMIT 1' % before_tid
else:
result = self._getDataTIDFromData(oid, result)
return result
# XXX I want to express "HAVING serial = MAX(serial)", but
# MySQL does not use an index for a HAVING clause!
sql += ' ORDER BY serial DESC LIMIT 1'
r = self.query(sql)
if r:
(serial, checksum, value_serial), = r
if value_serial is None and checksum:
return serial, serial
return serial, value_serial
return None, None
def finishTransaction(self, tid):
q = self.query
tid = util.u64(tid)
self.begin()
try:
q("""INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = %d""" \
% tid)
q('INSERT INTO obj_short SELECT partition, oid, serial FROM tobj'
' WHERE tobj.serial = %d' % (tid, ))
q("""DELETE FROM tobj WHERE serial = %d""" % tid)
q("""INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = %d"""
% tid)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
sql = " FROM tobj WHERE serial=%d" % tid
checksum_list = [x for x, in q("SELECT hash" + sql) if x]
q("INSERT INTO obj SELECT *" + sql)
q("DELETE FROM tobj WHERE serial=%d" % tid)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid)
except:
self.rollback()
raise
self.commit()
self.unlockData(checksum_list)
def deleteTransaction(self, tid, oid_list=()):
q = self.query
objQuery = self.objQuery
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
self.begin()
try:
q("""DELETE FROM tobj WHERE serial = %d""" % tid)
sql = " FROM tobj WHERE serial=%d" % tid
checksum_list = [x for x, in q("SELECT hash" + sql) if x]
self.unlockData(checksum_list)
q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" %
(getPartition(tid), tid))
# delete from obj using indexes
checksum_set = set()
for oid in oid_list:
oid = u64(oid)
partition = getPartition(oid)
objQuery('DELETE FROM %%(table)s WHERE '
'partition=%(partition)d '
'AND oid = %(oid)d AND serial = %(serial)d' % {
'partition': partition,
'oid': oid,
'serial': tid,
})
sql = " FROM obj WHERE partition=%d AND oid=%d AND serial=%d" \
% (getPartition(oid), oid, tid)
checksum_set.update(*q("SELECT hash" + sql))
q("DELETE" + sql)
checksum_set.discard(None)
self._pruneData(checksum_set)
except:
self.rollback()
raise
......@@ -587,20 +581,18 @@ class MySQLDatabaseManager(DatabaseManager):
self.commit()
def deleteObject(self, oid, serial=None):
q = self.query
u64 = util.u64
oid = u64(oid)
query_param_dict = {
'partition': self._getPartition(oid),
'oid': oid,
}
query_fmt = 'DELETE FROM %%(table)s WHERE ' \
'partition = %(partition)d AND oid = %(oid)d'
if serial is not None:
query_param_dict['serial'] = u64(serial)
query_fmt = query_fmt + ' AND serial = %(serial)d'
sql = " FROM obj WHERE partition=%d AND oid=%d" \
% (self._getPartition(oid), oid)
if serial:
sql += ' AND serial=%d' % u64(serial)
self.begin()
try:
self.objQuery(query_fmt % query_param_dict)
checksum_list = [x for x, in q("SELECT DISTINCT hash" + sql) if x]
q("DELETE" + sql)
self._pruneData(checksum_list)
except:
self.rollback()
raise
......@@ -608,17 +600,17 @@ class MySQLDatabaseManager(DatabaseManager):
def deleteObjectsAbove(self, num_partitions, partition, oid, serial,
max_tid):
q = self.query
u64 = util.u64
oid = u64(oid)
sql = (" FROM obj WHERE partition=%d AND serial <= %d"
" AND (oid > %d OR (oid = %d AND serial >= %d))" %
(partition, u64(max_tid), oid, oid, u64(serial)))
self.begin()
try:
self.objQuery('DELETE FROM %%(table)s WHERE partition=%(partition)d'
' AND serial <= %(max_tid)d AND ('
'oid > %(oid)d OR (oid = %(oid)d AND serial >= %(serial)d))' % {
'partition': partition,
'max_tid': u64(max_tid),
'oid': u64(oid),
'serial': u64(serial),
})
checksum_list = [x for x, in q("SELECT DISTINCT hash" + sql) if x]
q("DELETE" + sql)
self._pruneData(checksum_list)
except:
self.rollback()
raise
......@@ -645,8 +637,9 @@ class MySQLDatabaseManager(DatabaseManager):
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
r = self.query("""SELECT LENGTH(value), value_serial FROM obj """ \
"""WHERE partition = %d AND oid = %d AND serial = %d""" %
r = self.query("""SELECT LENGTH(value), value_serial
FROM obj LEFT JOIN data ON (obj.hash = data.hash)
WHERE partition = %d AND oid = %d AND serial = %d""" %
(self._getPartition(oid), oid, value_serial))
length, value_serial = r[0]
if length is None:
......@@ -660,11 +653,11 @@ class MySQLDatabaseManager(DatabaseManager):
# FIXME: This method doesn't take client's current ransaction id as
# parameter, which means it can return transactions in the future of
# client's transaction.
q = self.query
oid = util.u64(oid)
p64 = util.p64
pack_tid = self._getPackTID()
r = q("""SELECT serial, LENGTH(value), value_serial FROM obj
r = self.query("""SELECT serial, LENGTH(value), value_serial
FROM obj LEFT JOIN data ON (obj.hash = data.hash)
WHERE partition = %d AND oid = %d AND serial >= %d
ORDER BY serial DESC LIMIT %d, %d""" \
% (self._getPartition(oid), oid, pack_tid, offset, length))
......@@ -689,7 +682,7 @@ class MySQLDatabaseManager(DatabaseManager):
min_oid = u64(min_oid)
min_serial = u64(min_serial)
max_serial = u64(max_serial)
r = q('SELECT oid, serial FROM obj_short '
r = q('SELECT oid, serial FROM obj '
'WHERE partition = %(partition)s '
'AND serial <= %(max_serial)d '
'AND ((oid = %(min_oid)d AND serial >= %(min_serial)d) '
......@@ -735,71 +728,37 @@ class MySQLDatabaseManager(DatabaseManager):
})
return [p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack):
def _updatePackFuture(self, oid, orig_serial, max_serial):
q = self.query
p64 = util.p64
getPartition = self._getPartition
# Before deleting this objects revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
value_serial = None
for table in ('obj', 'tobj'):
for (serial, ) in q('SELECT serial FROM %(table)s WHERE '
'partition = %(partition)d AND oid = %(oid)d '
'AND serial >= %(max_serial)d AND '
'value_serial = %(orig_serial)d ORDER BY serial ASC' % {
'table': table,
'partition': getPartition(oid),
'oid': oid,
'orig_serial': orig_serial,
'max_serial': max_serial,
}):
kw = {
'partition': self._getPartition(oid),
'oid': oid,
'orig_serial': orig_serial,
'max_serial': max_serial,
'new_serial': 'NULL',
}
for kw['table'] in 'obj', 'tobj':
for kw['serial'], in q('SELECT serial FROM %(table)s'
' WHERE partition=%(partition)d AND oid=%(oid)d'
' AND serial>=%(max_serial)d AND value_serial=%(orig_serial)d'
' ORDER BY serial ASC' % kw):
q('UPDATE %(table)s SET value_serial=%(new_serial)s'
' WHERE partition=%(partition)d AND oid=%(oid)d'
' AND serial=%(serial)d' % kw)
if value_serial is None:
# First found, copy data to it and mark its serial for
# future reference.
value_serial = serial
q('REPLACE INTO %(table)s (partition, oid, serial, compression, '
'checksum, value, value_serial) SELECT partition, oid, '
'%(serial)d, compression, checksum, value, NULL FROM '
'obj WHERE partition = %(partition)d AND oid = %(oid)d '
'AND serial = %(orig_serial)d' \
% {
'table': table,
'partition': getPartition(oid),
'oid': oid,
'serial': serial,
'orig_serial': orig_serial,
})
else:
q('REPLACE INTO %(table)s (partition, oid, serial, value_serial) '
'VALUES (%(partition)d, %(oid)d, %(serial)d, '
'%(value_serial)d)' % {
'table': table,
'partition': getPartition(oid),
'oid': oid,
'serial': serial,
'value_serial': value_serial,
})
def getObjectData():
assert value_serial is None
return q('SELECT compression, checksum, value FROM obj WHERE '
'partition = %(partition)d AND oid = %(oid)d '
'AND serial = %(orig_serial)d' % {
'partition': getPartition(oid),
'oid': oid,
'orig_serial': orig_serial,
})[0]
if value_serial:
value_serial = p64(value_serial)
updateObjectDataForPack(p64(oid), p64(orig_serial), value_serial,
getObjectData)
# First found, mark its serial for future reference.
kw['new_serial'] = value_serial = kw['serial']
return value_serial
def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture)
q = self.query
objQuery = self.objQuery
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
getPartition = self._getPartition
......@@ -807,35 +766,29 @@ class MySQLDatabaseManager(DatabaseManager):
try:
self._setPackTID(tid)
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, '
'MAX(serial) FROM obj_short WHERE serial <= %(tid)d '
'GROUP BY oid' % {'tid': tid}):
if q('SELECT 1 FROM obj WHERE partition ='
'%(partition)s AND oid = %(oid)d AND '
'serial = %(max_serial)d AND checksum IS NULL' % {
'oid': oid,
'partition': getPartition(oid),
'max_serial': max_serial,
}):
count += 1
'MAX(serial) FROM obj WHERE serial <= %d GROUP BY oid'
% tid):
partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition = %d"
" AND oid = %d AND serial = %d AND hash IS NULL"
% (partition, oid, max_serial)):
max_serial += 1
if count:
# There are things to delete for this object
for (serial, ) in q('SELECT serial FROM obj_short WHERE '
'partition=%(partition)d AND oid=%(oid)d AND '
'serial < %(max_serial)d' % {
'oid': oid,
'partition': getPartition(oid),
'max_serial': max_serial,
}):
updatePackFuture(oid, serial, max_serial,
updateObjectDataForPack)
objQuery('DELETE FROM %%(table)s WHERE '
'partition=%(partition)d '
'AND oid=%(oid)d AND serial=%(serial)d' % {
'partition': getPartition(oid),
'oid': oid,
'serial': serial
})
elif not count:
continue
# There are things to delete for this object
checksum_set = set()
sql = ' FROM obj WHERE partition=%d AND oid=%d' \
' AND serial<%d' % (partition, oid, max_serial)
for serial, checksum in q('SELECT serial, hash' + sql):
checksum_set.add(checksum)
new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial:
new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial),
new_serial, checksum)
q('DELETE' + sql)
checksum_set.discard(None)
self._pruneData(checksum_set)
except:
self.rollback()
raise
......@@ -843,7 +796,7 @@ class MySQLDatabaseManager(DatabaseManager):
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
count, tid_checksum, max_tid = self.query(
"""SELECT COUNT(*), MD5(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
"""SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
FROM (SELECT tid FROM trans
WHERE partition = %(partition)s
AND tid >= %(min_tid)d
......@@ -854,12 +807,9 @@ class MySQLDatabaseManager(DatabaseManager):
'max_tid': util.u64(max_tid),
'length': length,
})[0]
if count == 0:
max_tid = ZERO_TID
else:
tid_checksum = a2b_hex(tid_checksum)
max_tid = util.p64(max_tid)
return count, tid_checksum, max_tid
if count:
return count, a2b_hex(tid_checksum), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
num_partitions, partition):
......@@ -870,7 +820,7 @@ class MySQLDatabaseManager(DatabaseManager):
# last grouped value, instead of the greatest one.
r = self.query(
"""SELECT oid, serial
FROM obj_short
FROM obj
WHERE partition = %(partition)s
AND serial <= %(max_tid)d
AND (oid > %(min_oid)d OR
......@@ -885,8 +835,8 @@ class MySQLDatabaseManager(DatabaseManager):
if r:
p64 = util.p64
return (len(r),
md5(','.join(str(x[0]) for x in r)).digest(),
sha1(','.join(str(x[0]) for x in r)).digest(),
p64(r[-1][0]),
md5(','.join(str(x[1]) for x in r)).digest(),
sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1]))
return 0, None, ZERO_OID, None, ZERO_TID
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
......@@ -21,7 +21,7 @@ from neo.lib.handler import EventHandler
from neo.lib import protocol
from neo.lib.util import dump
from neo.lib.exception import PrimaryFailure, OperationFailure
from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors
from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors, ZERO_HASH
class BaseMasterHandler(EventHandler):
......@@ -97,7 +97,7 @@ class BaseClientAndStorageOperationHandler(EventHandler):
neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial))
if checksum is None:
checksum = 0
checksum = ZERO_HASH
data = ''
p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data, data_serial)
......
......@@ -18,7 +18,7 @@
import neo.lib
from neo.lib import protocol
from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors
from neo.lib.protocol import Packets, LockState, Errors, ZERO_HASH
from neo.storage.handlers import BaseClientAndStorageOperationHandler
from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.exception import AlreadyPendingError
......@@ -88,7 +88,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
compression, checksum, data, data_serial, ttid, unlock):
# register the transaction
self.app.tm.register(conn.getUUID(), ttid)
if data or checksum:
if data or checksum != ZERO_HASH:
# TODO: return an appropriate error packet
assert makeChecksum(data) == checksum
assert data_serial is None
......
......@@ -20,7 +20,7 @@ from functools import wraps
import neo.lib
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, ZERO_TID, ZERO_OID
from neo.lib.protocol import Packets, ZERO_HASH, ZERO_TID, ZERO_OID
from neo.lib.util import add64, u64
# TODO: benchmark how different values behave
......@@ -173,12 +173,14 @@ class ReplicationHandler(EventHandler):
@checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial):
app = self.app
dm = self.app.dm
if data or checksum != ZERO_HASH:
dm.storeData(checksum, data, compression)
else:
checksum = None
# Directly store the transaction.
obj = (oid, compression, checksum, data, data_serial)
app.dm.storeTransaction(serial_start, [obj], None, False)
del obj
del data
obj = oid, checksum, data_serial
dm.storeTransaction(serial_start, [obj], None, False)
def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid,
length=RANGE_LENGTH):
......
......@@ -21,7 +21,10 @@ from neo.lib.protocol import Packets
class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def _askObject(self, oid, serial, tid):
return self.app.dm.getObject(oid, serial, tid, resolve_data=False)
result = self.app.dm.getObject(oid, serial, tid)
if result and result[5]:
return result[:2] + (None, None, None) + result[4:]
return result
def askLastIDs(self, conn):
app = self.app
......
......@@ -98,22 +98,21 @@ class Transaction(object):
# assert self._transaction is not None
self._transaction = (oid_list, user, desc, ext, packed)
def addObject(self, oid, compression, checksum, data, value_serial):
def addObject(self, oid, checksum, value_serial):
"""
Add an object to the transaction
"""
assert oid not in self._checked_set, dump(oid)
self._object_dict[oid] = (oid, compression, checksum, data,
value_serial)
self._object_dict[oid] = oid, checksum, value_serial
def delObject(self, oid):
try:
del self._object_dict[oid]
return self._object_dict.pop(oid)[1]
except KeyError:
self._checked_set.remove(oid)
def getObject(self, oid):
return self._object_dict.get(oid)
return self._object_dict[oid]
def getObjectList(self):
return self._object_dict.values()
......@@ -163,10 +162,10 @@ class TransactionManager(object):
Return object data for given running transaction.
Return None if not found.
"""
result = self._transaction_dict.get(ttid)
if result is not None:
result = result.getObject(oid)
return result
try:
return self._transaction_dict[ttid].getObject(oid)
except KeyError:
return None
def reset(self):
"""
......@@ -242,7 +241,9 @@ class TransactionManager(object):
# drop the lock it held on this object, and drop object data for
# consistency.
del self._store_lock_dict[oid]
self._transaction_dict[ttid].delObject(oid)
checksum = self._transaction_dict[ttid].delObject(oid)
if checksum:
self._app.dm.pruneData((checksum,))
# Give a chance to pending events to take that lock now.
self._app.executeQueuedEvents()
# Attemp to acquire lock again.
......@@ -252,7 +253,7 @@ class TransactionManager(object):
elif locking_tid == ttid:
# If previous store was an undo, next store must be based on
# undo target.
previous_serial = self._transaction_dict[ttid].getObject(oid)[4]
previous_serial = self._transaction_dict[ttid].getObject(oid)[2]
if previous_serial is None:
# XXX: use some special serial when previous store was not
# an undo ? Maybe it should just not happen.
......@@ -301,8 +302,11 @@ class TransactionManager(object):
self.lockObject(ttid, serial, oid, unlock=unlock)
# store object
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid]
transaction.addObject(oid, compression, checksum, data, value_serial)
if data is None:
checksum = None
else:
self._app.dm.storeData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, checksum, value_serial)
def abort(self, ttid, even_if_locked=False):
"""
......@@ -320,8 +324,13 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid]
has_load_lock = transaction.isLocked()
# if the transaction is locked, ensure we can drop it
if not even_if_locked and has_load_lock:
return
if has_load_lock:
if not even_if_locked:
return
else:
self._app.dm.unlockData([checksum
for oid, checksum, value_serial in transaction.getObjectList()
if checksum], True)
# unlock any object
for oid in transaction.getLockedOIDList():
if has_load_lock:
......@@ -370,19 +379,13 @@ class TransactionManager(object):
for oid, ttid in self._store_lock_dict.items():
neo.lib.logging.info(' %r by %r', dump(oid), dump(ttid))
def updateObjectDataForPack(self, oid, orig_serial, new_serial,
getObjectData):
def updateObjectDataForPack(self, oid, orig_serial, new_serial, checksum):
lock_tid = self.getLockingTID(oid)
if lock_tid is not None:
transaction = self._transaction_dict[lock_tid]
oid, compression, checksum, data, value_serial = \
transaction.getObject(oid)
if value_serial == orig_serial:
if transaction.getObject(oid)[2] == orig_serial:
if new_serial:
value_serial = new_serial
checksum = None
else:
compression, checksum, data = getObjectData()
value_serial = None
transaction.addObject(oid, compression, checksum, data,
value_serial)
self._app.dm.storeData(checksum)
transaction.addObject(oid, checksum, new_serial)
......@@ -88,10 +88,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.handler.answerObject(conn, *the_object)
self._checkHandlerData(the_object[:-1])
# Check handler raises on non-None data_serial.
the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID())
self.assertRaises(NEOStorageError, self.handler.answerObject, conn,
*the_object)
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict):
......
......@@ -23,9 +23,8 @@ from neo.tests import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_PARTITION
from neo.lib.protocol import INVALID_TID, INVALID_OID
from neo.lib.protocol import Packets, LockState
from neo.lib.protocol import INVALID_PARTITION, INVALID_TID, INVALID_OID
from neo.lib.protocol import Packets, LockState, ZERO_HASH
class StorageClientHandlerTests(NeoUnitTestBase):
......@@ -124,7 +123,8 @@ class StorageClientHandlerTests(NeoUnitTestBase):
next_serial = self.getNextTID()
oid = self.getOID(1)
tid = self.getNextTID()
self.app.dm = Mock({'getObject': (serial, next_serial, 0, 0, '', None)})
H = "0" * 20
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self._getConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
......@@ -239,7 +239,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, 0,
self.operation.askStoreObject(conn, oid, serial, comp, ZERO_HASH,
'', data_tid, tid, False)
self._checkStoreObjectCalled(tid, serial, oid, comp,
None, None, data_tid, False)
......
......@@ -128,8 +128,11 @@ class ReplicationTests(NeoUnitTestBase):
transaction = ([ZERO_OID], 'user', 'desc', '', False)
storage.storeTransaction(makeid(tid), [], transaction, False)
# store object history
H = "0" * 20
storage.storeData(H, '', 0)
storage.unlockData((H,))
for tid, oid_list in objects.iteritems():
object_list = [(makeid(oid), False, 0, '', None) for oid in oid_list]
object_list = [(makeid(oid), H, None) for oid in oid_list]
storage.storeTransaction(makeid(tid), object_list, None, False)
return storage
......
......@@ -268,15 +268,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
serial_start = self.getNextTID()
serial_end = self.getNextTID()
compression = 1
checksum = 2
checksum = "0" * 20
data = 'foo'
data_serial = None
ReplicationHandler(app).answerObject(conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial)
calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(serial_start, [(oid, compression, checksum, data,
data_serial)], None, False)
calls[0].checkArgs(serial_start, [(oid, checksum, data_serial)],
None, False)
# CheckTIDRange
def test_answerCheckTIDFullRangeIdenticalChunkWithNext(self):
......
......@@ -121,7 +121,10 @@ class StorageDBTests(NeoUnitTestBase):
def getTransaction(self, oid_list):
transaction = (oid_list, 'user', 'desc', 'ext', False)
object_list = [(oid, 1, 0, '', None) for oid in oid_list]
H = "0" * 20
for _ in oid_list:
self.db.storeData(H, '', 1)
object_list = [(oid, H, None) for oid in oid_list]
return (transaction, object_list)
def checkSet(self, list1, list2):
......@@ -180,9 +183,9 @@ class StorageDBTests(NeoUnitTestBase):
oid1, = self.getOIDs(1)
tid1, tid2 = self.getTIDs(2)
FOUND_BUT_NOT_VISIBLE = False
OBJECT_T1_NO_NEXT = (tid1, None, 1, 0, '', None)
OBJECT_T1_NEXT = (tid1, tid2, 1, 0, '', None)
OBJECT_T2 = (tid2, None, 1, 0, '', None)
OBJECT_T1_NO_NEXT = (tid1, None, 1, "0"*20, '', None)
OBJECT_T1_NEXT = (tid1, tid2, 1, "0"*20, '', None)
OBJECT_T2 = (tid2, None, 1, "0"*20, '', None)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
# non-present
......@@ -277,14 +280,14 @@ class StorageDBTests(NeoUnitTestBase):
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, 0, '', None))
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [tid2])
# drop it
self.db.dropUnfinishedData()
self.assertEqual(self.db.getUnfinishedTIDList(), [])
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, 0, '', None))
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
def test_storeTransaction(self):
......@@ -393,8 +396,8 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1)
self.assertFalse(self.db.getObject(oid2, tid=tid1))
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \
objs2[1][1:])
self.assertEqual(self.db.getObject(oid2, tid=tid2),
(tid2, None, 1, "0" * 20, '', None))
def test_deleteObjectsAbove(self):
self.setNumPartitions(2)
......@@ -574,138 +577,6 @@ class StorageDBTests(NeoUnitTestBase):
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 2, 0)
self.checkSet(result, [tid1])
def test__getObjectData(self):
    """Check recursive data-serial resolution by _getObjectData.

    Two transactions are stored so that some rows carry real data and
    others only a data serial (backpointer).  Each case below checks
    both the resolved value and, via a counting wrapper, how many times
    _getObjectData is entered (initial call + recursions) before data
    is found or an error is raised.
    """
    self.setNumPartitions(4, True)
    db = self.db
    tid0 = self.getNextTID()
    tid1 = self.getNextTID()
    tid2 = self.getNextTID()
    tid3 = self.getNextTID()
    assert tid0 < tid1 < tid2 < tid3
    oid1 = self.getOID(1)
    oid2 = self.getOID(2)
    oid3 = self.getOID(3)
    db.storeTransaction(
        tid1, (
            (oid1, 0, 0, 'foo', None),
            (oid2, None, None, None, tid0),
            (oid3, None, None, None, tid2),
        ), None, temporary=False)
    db.storeTransaction(
        tid2, (
            (oid1, None, None, None, tid1),
            (oid2, None, None, None, tid1),
            (oid3, 0, 0, 'bar', None),
        ), None, temporary=False)
    # Wrap _getObjectData so each entry (including recursive ones)
    # appends to call_counter; sum(call_counter) is the call count.
    original_getObjectData = db._getObjectData
    def _getObjectData(*args, **kw):
        call_counter.append(1)
        return original_getObjectData(*args, **kw)
    db._getObjectData = _getObjectData
    # NOTE: all tests are done as if values were fetched by _getObject, so
    # there is already one indirection level.
    # oid1 at tid1: data is immediately found
    call_counter = []
    self.assertEqual(
        db._getObjectData(u64(oid1), u64(tid1), u64(tid3)),
        (u64(tid1), 0, 0, 'foo'))
    self.assertEqual(sum(call_counter), 1)
    # oid2 at tid1: missing data in table, raise IndexError on next
    # recursive call
    call_counter = []
    self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid1),
        u64(tid3))
    self.assertEqual(sum(call_counter), 2)
    # oid3 at tid1: data_serial greater than row's tid, raise ValueError
    # on next recursive call - even if data does exist at that tid (see
    # "oid3 at tid2" case below)
    call_counter = []
    self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
        u64(tid3))
    self.assertEqual(sum(call_counter), 2)
    # Same with wrong parameters (tid0 < tid1)
    call_counter = []
    self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
        u64(tid0))
    self.assertEqual(sum(call_counter), 1)
    # Same with wrong parameters (tid1 == tid1)
    call_counter = []
    self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
        u64(tid1))
    self.assertEqual(sum(call_counter), 1)
    # oid1 at tid2: data is found after one recursive call
    call_counter = []
    self.assertEqual(
        db._getObjectData(u64(oid1), u64(tid2), u64(tid3)),
        (u64(tid1), 0, 0, 'foo'))
    self.assertEqual(sum(call_counter), 2)
    # oid2 at tid2: missing data in table, raise IndexError after two
    # recursive calls
    call_counter = []
    self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid2),
        u64(tid3))
    self.assertEqual(sum(call_counter), 3)
    # oid3 at tid2: data is immediately found
    call_counter = []
    self.assertEqual(
        db._getObjectData(u64(oid3), u64(tid2), u64(tid3)),
        (u64(tid2), 0, 0, 'bar'))
    self.assertEqual(sum(call_counter), 1)
def test__getDataTIDFromData(self):
    """_getDataTIDFromData resolves an object row to (serial, data serial).

    The first revision carries actual data; the second is only a
    backpointer to the first, so both must resolve to the first's tid
    as data serial.
    """
    self.setNumPartitions(4, True)
    database = self.db
    serial1 = self.getNextTID()
    serial2 = self.getNextTID()
    oid = self.getOID(1)
    database.storeTransaction(
        serial1, (
            (oid, 0, 0, 'foo', None),
        ), None, temporary=False)
    database.storeTransaction(
        serial2, (
            (oid, None, None, None, serial1),
        ), None, temporary=False)
    for serial, expected in (
            (serial1, (u64(serial1), u64(serial1))),
            (serial2, (u64(serial2), u64(serial1))),
        ):
        row = database._getObject(u64(oid), tid=u64(serial))
        self.assertEqual(
            database._getDataTIDFromData(u64(oid), row),
            expected)
def test__getDataTID(self):
    """_getDataTID returns (serial, serial of the actual data).

    A revision holding data resolves to itself; a backpointer revision
    resolves to the tid it points at.
    """
    self.setNumPartitions(4, True)
    database = self.db
    serial1 = self.getNextTID()
    serial2 = self.getNextTID()
    oid = self.getOID(1)
    database.storeTransaction(
        serial1, (
            (oid, 0, 0, 'foo', None),
        ), None, temporary=False)
    database.storeTransaction(
        serial2, (
            (oid, None, None, None, serial1),
        ), None, temporary=False)
    for serial, data_serial in (serial1, serial1), (serial2, serial1):
        self.assertEqual(
            database._getDataTID(u64(oid), tid=u64(serial)),
            (u64(serial), u64(data_serial)))
def test_findUndoTID(self):
self.setNumPartitions(4, True)
db = self.db
......@@ -715,9 +586,14 @@ class StorageDBTests(NeoUnitTestBase):
tid4 = self.getNextTID()
tid5 = self.getNextTID()
oid1 = self.getOID(1)
foo = "3" * 20
bar = "4" * 20
db.storeData(foo, 'foo', 0)
db.storeData(bar, 'bar', 0)
db.unlockData((foo, bar))
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
(oid1, foo, None),
), None, temporary=False)
# Undoing oid1 tid1, OK: tid1 is latest
......@@ -730,7 +606,7 @@ class StorageDBTests(NeoUnitTestBase):
# Store a new transaction
db.storeTransaction(
tid2, (
(oid1, 0, 0, 'bar', None),
(oid1, bar, None),
), None, temporary=False)
# Undoing oid1 tid2, OK: tid2 is latest
......@@ -753,13 +629,13 @@ class StorageDBTests(NeoUnitTestBase):
# to tid1
self.assertEqual(
db.findUndoTID(oid1, tid5, tid4, tid1,
(u64(oid1), None, None, None, tid1)),
(u64(oid1), None, tid1)),
(tid1, None, True))
# Store a new transaction
db.storeTransaction(
tid3, (
(oid1, None, None, None, tid1),
(oid1, None, tid1),
), None, temporary=False)
# Undoing oid1 tid1, OK: tid3 is latest with tid1 data
......
......@@ -97,7 +97,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEqual(len(self.app.event_queue), 0)
self.assertEqual(len(calls), 1)
calls[0].checkArgs(oid, serial, tid, resolve_data=False)
calls[0].checkArgs(oid, serial, tid)
self.checkErrorPacket(conn)
def test_24_askObject3(self):
......@@ -105,8 +105,9 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
tid = self.getNextTID()
serial = self.getNextTID()
next_serial = self.getNextTID()
H = "0" * 20
# object found => answer
self.app.dm = Mock({'getObject': (serial, next_serial, 0, 0, '', None)})
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self.getFakeConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
......@@ -149,7 +150,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
def test_askCheckTIDRange(self):
count = 1
tid_checksum = self.getNewUUID()
tid_checksum = "1" * 20
min_tid = self.getNextTID()
num_partitions = 4
length = 5
......@@ -173,12 +174,12 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
def test_askCheckSerialRange(self):
count = 1
oid_checksum = self.getNewUUID()
oid_checksum = "2" * 20
min_oid = self.getOID(1)
num_partitions = 4
length = 5
partition = 6
serial_checksum = self.getNewUUID()
serial_checksum = "3" * 20
min_serial = self.getNextTID()
max_serial = self.getNextTID()
max_oid = self.getOID(2)
......
......@@ -125,23 +125,6 @@ class StorageMySQSLdbTests(StorageDBTests):
self.assertEqual(self.db.escape('a"b'), 'a\\"b')
self.assertEqual(self.db.escape("a'b"), "a\\'b")
def test_setup(self):
    """setup() issues one query per table; reset=True adds one extra.

    XXX: this test verifies irrelevant symptoms. It should instead check
    that:
    - setup, store, setup, load -> data still there
    - setup, store, setup(reset=True), load -> data not found
    Then, it should be moved to the generic test class.
    """
    for kw, expected_query_count in ({}, 7), ({'reset': True}, 8):
        # Fresh mock connection so each run counts only its own queries.
        self.db.conn = Mock()
        self.db.setup(**kw)
        query_calls = self.db.conn.mockGetNamedCalls('query')
        self.assertEqual(len(query_calls), expected_query_count)
del StorageDBTests
if __name__ == "__main__":
......
......@@ -63,8 +63,8 @@ class TransactionTests(NeoUnitTestBase):
def testObjects(self):
txn = Transaction(self.getNewUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2)
object1 = (oid1, 1, '1', 'O1', None)
object2 = (oid2, 1, '2', 'O2', None)
object1 = oid1, "0" * 20, None
object2 = oid2, "1" * 20, None
self.assertEqual(txn.getObjectList(), [])
self.assertEqual(txn.getOIDList(), [])
txn.addObject(*object1)
......@@ -78,9 +78,9 @@ class TransactionTests(NeoUnitTestBase):
oid_1 = self.getOID(1)
oid_2 = self.getOID(2)
txn = Transaction(self.getNewUUID(), self.getNextTID())
object_info = (oid_1, None, None, None, None)
object_info = oid_1, None, None
txn.addObject(*object_info)
self.assertEqual(txn.getObject(oid_2), None)
self.assertRaises(KeyError, txn.getObject, oid_2)
self.assertEqual(txn.getObject(oid_1), object_info)
class TransactionManagerTests(NeoUnitTestBase):
......@@ -102,12 +102,12 @@ class TransactionManagerTests(NeoUnitTestBase):
def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]):
self.manager.storeObject(tid, None,
oid, 1, str(i), '0' + str(i), None)
oid, 1, '%020d' % i, '0' + str(i), None)
def _getObject(self, value):
oid = self.getOID(value)
serial = self.getNextTID()
return (serial, (oid, 1, str(value), 'O' + str(value), None))
return (serial, (oid, 1, '%020d' % value, 'O' + str(value), None))
def _checkTransactionStored(self, *args):
calls = self.app.dm.mockGetNamedCalls('storeTransaction')
......@@ -136,7 +136,10 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.storeObject(ttid, serial2, *object2)
self.assertTrue(ttid in self.manager)
self.manager.lock(ttid, tid, txn[0])
self._checkTransactionStored(tid, [object1, object2], txn)
self._checkTransactionStored(tid, [
(object1[0], object1[2], object1[4]),
(object2[0], object2[2], object2[4]),
], txn)
self.manager.unlock(ttid)
self.assertFalse(ttid in self.manager)
self._checkTransactionFinished(tid)
......@@ -340,7 +343,7 @@ class TransactionManagerTests(NeoUnitTestBase):
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]),
None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]),
obj1)
(obj1[0], obj1[2], obj1[4]))
def test_getLockingTID(self):
uuid = self.getNewUUID()
......@@ -360,26 +363,24 @@ class TransactionManagerTests(NeoUnitTestBase):
locking_serial = self.getNextTID()
other_serial = self.getNextTID()
new_serial = self.getNextTID()
compression = 1
checksum = 42
value = 'foo'
checksum = "2" * 20
self.manager.register(uuid, locking_serial)
def getObjectData():
return (compression, checksum, value)
# Object not known, nothing happens
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known, but doesn't point at orig_serial, it is not updated
self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, 0, 512,
self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20,
'bar', None)
storeData = self.app.dm.mockGetNamedCalls('storeData')
self.assertEqual(storeData.pop(0).params, ("3" * 20, 'bar', 0))
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
......@@ -389,29 +390,29 @@ class TransactionManagerTests(NeoUnitTestBase):
None, other_serial)
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known and points at undone data it gets updated
# ...with data_serial: getObjectData must not be called
self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, new_serial,
None)
checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, None, None, None, new_serial))
oid), (oid, None, new_serial))
self.manager.abort(locking_serial, even_if_locked=True)
# with data
self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None,
getObjectData)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(storeData.pop(0).params, (checksum,))
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, compression, checksum, value, None))
oid), (oid, checksum, None))
self.manager.abort(locking_serial, even_if_locked=True)
self.assertFalse(storeData)
if __name__ == "__main__":
unittest.main()
......@@ -387,7 +387,8 @@ class ProtocolTests(NeoUnitTestBase):
tid = self.getNextTID()
tid2 = self.getNextTID()
unlock = False
p = Packets.AskStoreObject(oid, serial, 1, 55, "to", tid2, tid, unlock)
H = "1" * 20
p = Packets.AskStoreObject(oid, serial, 1, H, "to", tid2, tid, unlock)
poid, pserial, compression, checksum, data, ptid2, ptid, punlock = \
p.decode()
self.assertEqual(oid, poid)
......@@ -395,7 +396,7 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(tid, ptid)
self.assertEqual(tid2, ptid2)
self.assertEqual(compression, 1)
self.assertEqual(checksum, 55)
self.assertEqual(checksum, H)
self.assertEqual(data, "to")
self.assertEqual(unlock, punlock)
......@@ -423,7 +424,8 @@ class ProtocolTests(NeoUnitTestBase):
serial_start = self.getNextTID()
serial_end = self.getNextTID()
data_serial = self.getNextTID()
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to",
H = "2" * 20
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, H, "to",
data_serial)
poid, pserial_start, pserial_end, compression, checksum, data, \
pdata_serial = p.decode()
......@@ -431,7 +433,7 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(serial_start, pserial_start)
self.assertEqual(serial_end, pserial_end)
self.assertEqual(compression, 1)
self.assertEqual(checksum, 55)
self.assertEqual(checksum, H)
self.assertEqual(data, "to")
self.assertEqual(pdata_serial, data_serial)
......@@ -686,7 +688,7 @@ class ProtocolTests(NeoUnitTestBase):
min_tid = self.getNextTID()
length = 2
count = 1
tid_checksum = self.getNewUUID()
tid_checksum = "3" * 20
max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid)
......@@ -717,9 +719,9 @@ class ProtocolTests(NeoUnitTestBase):
min_serial = self.getNextTID()
length = 2
count = 1
oid_checksum = self.getNewUUID()
oid_checksum = "4" * 20
max_oid = self.getOID(5)
tid_checksum = self.getNewUUID()
tid_checksum = "5" * 20
max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial)
......
......@@ -259,6 +259,10 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
if adapter == 'BTree':
dm._obj, dm._tobj = dm._tobj, dm._obj
dm._trans, dm._ttrans = dm._ttrans, dm._trans
uncommitted_data = dm._uncommitted_data
for checksum, (_, _, index) in dm._data.iteritems():
uncommitted_data[checksum] = len(index)
index.clear()
elif adapter == 'MySQL':
q = dm.query
dm.begin()
......@@ -266,11 +270,22 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
q('RENAME TABLE %s to tmp' % table)
q('RENAME TABLE t%s to %s' % (table, table))
q('RENAME TABLE tmp to t%s' % table)
q('TRUNCATE obj_short')
dm.commit()
else:
assert False
def getDataLockInfo(self):
    """Map every stored data checksum to its uncommitted-reference count.

    Checksums whose data is fully committed map to 0.  The listing
    method depends on the configured backend adapter.
    """
    dm = self.dm
    adapter = self._init_args[1]['getAdapter']
    if adapter == 'BTree':
        checksums = dm._data
    else:
        assert adapter == 'MySQL'
        checksums = [row[0] for row in dm.query("SELECT hash FROM data")]
    uncommitted = dm._uncommitted_data
    # Any uncommitted checksum must already exist in the data store.
    assert set(uncommitted).issubset(checksums)
    return dict((checksum, uncommitted.get(checksum, 0))
                for checksum in checksums)
class ClientApplication(Node, neo.client.app.Application):
@SerializedEventManager.decorate
......
......@@ -26,6 +26,7 @@ from neo.lib.connection import MTClientConnection
from neo.lib.protocol import NodeStates, Packets, ZERO_TID
from neo.tests.threaded import NEOCluster, NEOThreadedTest, \
Patch, ConnectionFilter
from neo.lib.util import makeChecksum
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
class PCounter(Persistent):
......@@ -43,13 +44,19 @@ class Test(NEOThreadedTest):
try:
cluster.start()
storage = cluster.getZODBStorage()
for data in 'foo', '':
data_info = {}
for data in 'foo', '', 'foo':
checksum = makeChecksum(data)
oid = storage.new_oid()
txn = transaction.Transaction()
storage.tpc_begin(txn)
r1 = storage.store(oid, None, data, '', txn)
r2 = storage.tpc_vote(txn)
data_info[checksum] = 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
serial = storage.tpc_finish(txn)
data_info[checksum] = 0
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
self.assertEqual((data, serial), storage.load(oid, ''))
storage._cache.clear()
self.assertEqual((data, serial), storage.load(oid, ''))
......@@ -57,6 +64,51 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def testStorageDataLock(self):
    """Check reference counting of uncommitted data on the storage node.

    The same 'foo' value is stored by several transactions at once;
    getDataLockInfo() must report, for its checksum, how many voted
    transactions still reference it uncommitted, and the count must
    drop by one on each tpc_abort/tpc_finish.
    """
    cluster = NEOCluster()
    try:
        cluster.start()
        storage = cluster.getZODBStorage()
        data_info = {}
        data = 'foo'
        checksum = makeChecksum(data)
        oid = storage.new_oid()
        txn = transaction.Transaction()
        storage.tpc_begin(txn)
        r1 = storage.store(oid, None, data, '', txn)
        r2 = storage.tpc_vote(txn)
        tid = storage.tpc_finish(txn)
        # Committed: the data exists but is no longer locked.
        data_info[checksum] = 0
        storage.sync()
        # 3 concurrent transactions storing the same value: the first
        # updates the existing object (tid truthy on first iteration),
        # the other two create new oids.
        txn = [transaction.Transaction() for x in xrange(3)]
        for t in txn:
            storage.tpc_begin(t)
            storage.store(tid and oid or storage.new_oid(),
                          tid, data, '', t)
            tid = None
        for t in txn:
            storage.tpc_vote(t)
        # After vote, each of the 3 holds an uncommitted reference.
        data_info[checksum] = 3
        self.assertEqual(data_info, cluster.storage.getDataLockInfo())
        storage.tpc_abort(txn[1])
        storage.sync()
        data_info[checksum] -= 1
        self.assertEqual(data_info, cluster.storage.getDataLockInfo())
        tid1 = storage.tpc_finish(txn[2])
        data_info[checksum] -= 1
        self.assertEqual(data_info, cluster.storage.getDataLockInfo())
        storage.tpc_abort(txn[0])
        storage.sync()
        data_info[checksum] -= 1
        self.assertEqual(data_info, cluster.storage.getDataLockInfo())
    finally:
        cluster.stop()
def testDelayedUnlockInformation(self):
except_list = []
def delayUnlockInformation(conn, packet):
......@@ -273,16 +325,21 @@ class Test(NEOThreadedTest):
t, c = cluster.getTransaction()
c.root()[0] = 'ok'
t.commit()
data_info = cluster.storage.getDataLockInfo()
self.assertEqual(data_info.values(), [0, 0])
# (obj|trans) become t(obj|trans)
cluster.storage.switchTables()
finally:
cluster.stop()
cluster.reset()
# XXX: (obj|trans) become t(obj|trans)
cluster.storage.switchTables()
self.assertEqual(dict.fromkeys(data_info, 1),
cluster.storage.getDataLockInfo())
try:
cluster.start(fast_startup=fast_startup)
t, c = cluster.getTransaction()
# transaction should be verified and commited
self.assertEqual(c.root()[0], 'ok')
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
finally:
cluster.stop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment