Commit 52c5d862 authored by Julien Muchembled

storage: fix severe performance issue by committing backend only at key moments

parent 4741e38e
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import ZERO_TID from neo.lib.protocol import ZERO_TID
class CreationUndone(Exception): class CreationUndone(Exception):
...@@ -28,7 +27,6 @@ class DatabaseManager(object): ...@@ -28,7 +27,6 @@ class DatabaseManager(object):
""" """
Initialize the object. Initialize the object.
""" """
self._under_transaction = False
self._wait = wait self._wait = wait
self._parse(database) self._parse(database)
...@@ -50,34 +48,9 @@ class DatabaseManager(object): ...@@ -50,34 +48,9 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def __enter__(self):
"""
Begin a transaction
"""
if self._under_transaction:
raise DatabaseFailure('A transaction has already begun')
r = self.begin()
self._under_transaction = True
return r
def __exit__(self, exc_type, exc_value, tb):
if not self._under_transaction:
raise DatabaseFailure('The transaction has not begun')
self._under_transaction = False
if exc_type is None:
self.commit()
else:
self.rollback()
def begin(self):
pass
def commit(self): def commit(self):
pass pass
def rollback(self):
pass
def _getPartition(self, oid_or_tid): def _getPartition(self, oid_or_tid):
return oid_or_tid % self.getNumPartitions() return oid_or_tid % self.getNumPartitions()
...@@ -91,11 +64,8 @@ class DatabaseManager(object): ...@@ -91,11 +64,8 @@ class DatabaseManager(object):
""" """
Set a configuration value Set a configuration value
""" """
if self._under_transaction: self._setConfiguration(key, value)
self._setConfiguration(key, value) self.commit()
else:
with self:
self._setConfiguration(key, value)
def _setConfiguration(self, key, value): def _setConfiguration(self, key, value):
raise NotImplementedError raise NotImplementedError
...@@ -344,8 +314,8 @@ class DatabaseManager(object): ...@@ -344,8 +314,8 @@ class DatabaseManager(object):
else: else:
del refcount[data_id] del refcount[data_id]
if prune: if prune:
with self: self._pruneData(data_id_list)
self._pruneData(data_id_list) self.commit()
__getDataTID = set() __getDataTID = set()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
...@@ -465,11 +435,11 @@ class DatabaseManager(object): ...@@ -465,11 +435,11 @@ class DatabaseManager(object):
def truncate(self, tid): def truncate(self, tid):
assert tid not in (None, ZERO_TID), tid assert tid not in (None, ZERO_TID), tid
with self: assert self.getBackupTID()
assert self.getBackupTID() self.setBackupTID(tid)
self.setBackupTID(tid) for partition in xrange(self.getNumPartitions()):
for partition in xrange(self.getNumPartitions()): self._deleteRange(partition, tid)
self._deleteRange(partition, tid) self.commit()
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
......
...@@ -93,22 +93,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -93,22 +93,9 @@ class MySQLDatabaseManager(DatabaseManager):
self.conn.query("SET SESSION group_concat_max_len = -1") self.conn.query("SET SESSION group_concat_max_len = -1")
self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
def begin(self): def commit(self):
q = self.query logging.debug('committing...')
q("BEGIN") self.conn.commit()
return q
if LOG_QUERIES:
def commit(self):
logging.debug('committing...')
self.conn.commit()
def rollback(self):
logging.debug('aborting...')
self.conn.rollback()
else:
commit = property(lambda self: self.conn.commit)
rollback = property(lambda self: self.conn.rollback)
def query(self, query): def query(self, query):
"""Query data from a database.""" """Query data from a database."""
...@@ -271,44 +258,38 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -271,44 +258,38 @@ class MySQLDatabaseManager(DatabaseManager):
def _getLastIDs(self, all=True): def _getLastIDs(self, all=True):
p64 = util.p64 p64 = util.p64
with self as q: q = self.query
trans = dict((partition, p64(tid)) trans = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM trans GROUP BY partition")) " FROM trans GROUP BY partition"))
obj = dict((partition, p64(tid)) obj = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM obj GROUP BY partition")) " FROM obj GROUP BY partition"))
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t")[0][0] " GROUP BY partition) as t")[0][0]
if all: if all:
tid = q("SELECT MAX(tid) FROM ttrans")[0][0] tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if tid is not None: if tid is not None:
trans[None] = p64(tid) trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj")[0] tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj")[0]
if tid is not None: if tid is not None:
obj[None] = p64(tid) obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None): if toid is not None and (oid < toid or oid is None):
oid = toid oid = toid
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
tid_set = set() p64 = util.p64
with self as q: return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
r = q("""SELECT tid FROM ttrans""") " UNION SELECT tid FROM tobj")]
tid_set.update((util.p64(t[0]) for t in r))
r = q("""SELECT tid FROM tobj""")
tid_set.update((util.p64(t[0]) for t in r))
return list(tid_set)
def objectPresent(self, oid, tid, all = True): def objectPresent(self, oid, tid, all = True):
oid = util.u64(oid) oid = util.u64(oid)
tid = util.u64(tid) tid = util.u64(tid)
partition = self._getPartition(oid) q = self.query
with self as q: return q("SELECT 1 FROM obj WHERE partition=%d AND oid=%d AND tid=%d"
return q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND " % (self._getPartition(oid), oid, tid)) or all and \
"tid=%d" % (partition, oid, tid)) or all and \ q("SELECT 1 FROM tobj WHERE tid=%d AND oid=%d" % (tid, oid))
q("SELECT oid FROM tobj WHERE tid=%d AND oid=%d"
% (tid, oid))
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
...@@ -339,21 +320,21 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -339,21 +320,21 @@ class MySQLDatabaseManager(DatabaseManager):
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
offset_list = [] offset_list = []
with self as q: q = self.query
if reset: if reset:
q("""TRUNCATE pt""") q("TRUNCATE pt")
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
# TODO: this logic should move out of database manager # TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query # add 'dropCells(cell_list)' to API and use one query
if state == CellStates.DISCARDED: if state == CellStates.DISCARDED:
q("""DELETE FROM pt WHERE rid = %d AND uuid = %d""" q("DELETE FROM pt WHERE rid = %d AND uuid = %d"
% (offset, uuid)) % (offset, uuid))
else: else:
offset_list.append(offset) offset_list.append(offset)
q("""INSERT INTO pt VALUES (%d, %d, %d) q("INSERT INTO pt VALUES (%d, %d, %d)"
ON DUPLICATE KEY UPDATE state = %d""" \ " ON DUPLICATE KEY UPDATE state = %d"
% (offset, uuid, state, state)) % (offset, uuid, state, state))
self.setPTID(ptid) self.setPTID(ptid)
if self._use_partition: if self._use_partition:
for offset in offset_list: for offset in offset_list:
add = """ALTER TABLE %%s ADD PARTITION ( add = """ALTER TABLE %%s ADD PARTITION (
...@@ -372,18 +353,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -372,18 +353,18 @@ class MySQLDatabaseManager(DatabaseManager):
self.doSetPartitionTable(ptid, cell_list, True) self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
with self as q: q = self.query
# XXX: these queries are inefficient (execution time increase with # XXX: these queries are inefficient (execution time increase with
# row count, although we use indexes) when there are rows to # row count, although we use indexes) when there are rows to
# delete. It should be done as an idle task, by chunks. # delete. It should be done as an idle task, by chunks.
for partition in offset_list: for partition in offset_list:
where = " WHERE partition=%d" % partition where = " WHERE partition=%d" % partition
data_id_list = [x for x, in data_id_list = [x for x, in
q("SELECT DISTINCT data_id FROM obj" + where) if x] q("SELECT DISTINCT data_id FROM obj" + where) if x]
if not self._use_partition: if not self._use_partition:
q("DELETE FROM obj" + where) q("DELETE FROM obj" + where)
q("DELETE FROM trans" + where) q("DELETE FROM trans" + where)
self._pruneData(data_id_list) self._pruneData(data_id_list)
if self._use_partition: if self._use_partition:
drop = "ALTER TABLE %s DROP PARTITION" + \ drop = "ALTER TABLE %s DROP PARTITION" + \
','.join(' p%u' % i for i in offset_list) ','.join(' p%u' % i for i in offset_list)
...@@ -395,47 +376,46 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -395,47 +376,46 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
def dropUnfinishedData(self): def dropUnfinishedData(self):
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x] data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("""TRUNCATE tobj""") q("TRUNCATE tobj")
q("""TRUNCATE ttrans""") q("TRUNCATE ttrans")
self.unlockData(data_id_list, True) self.unlockData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary = True): def storeTransaction(self, tid, object_list, transaction, temporary = True):
e = self.escape e = self.escape
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
if temporary: if temporary:
obj_table = 'tobj' obj_table = 'tobj'
trans_table = 'ttrans' trans_table = 'ttrans'
else: else:
obj_table = 'obj' obj_table = 'obj'
trans_table = 'trans' trans_table = 'trans'
q = self.query
with self as q: for oid, data_id, value_serial in object_list:
for oid, data_id, value_serial in object_list: oid = u64(oid)
oid = u64(oid) partition = self._getPartition(oid)
partition = self._getPartition(oid) if value_serial:
if value_serial: value_serial = u64(value_serial)
value_serial = u64(value_serial) (data_id,), = q("SELECT data_id FROM obj"
(data_id,), = q("SELECT data_id FROM obj" " WHERE partition=%d AND oid=%d AND tid=%d"
" WHERE partition=%d AND oid=%d AND tid=%d" % (partition, oid, value_serial))
% (partition, oid, value_serial)) if temporary:
if temporary: self.storeData(data_id)
self.storeData(data_id) else:
else: value_serial = 'NULL'
value_serial = 'NULL' q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" % (obj_table,
q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" % (obj_table, partition, oid, tid, data_id or 'NULL', value_serial))
partition, oid, tid, data_id or 'NULL', value_serial)) if transaction:
oid_list, user, desc, ext, packed, ttid = transaction
if transaction: partition = self._getPartition(tid)
oid_list, user, desc, ext, packed, ttid = transaction assert packed in (0, 1)
partition = self._getPartition(tid) q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % (
assert packed in (0, 1) trans_table, partition, tid, packed, e(''.join(oid_list)),
q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % ( e(user), e(desc), e(ext), u64(ttid)))
trans_table, partition, tid, packed, e(''.join(oid_list)), if temporary:
e(user), e(desc), e(ext), u64(ttid))) self.commit()
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
...@@ -448,20 +428,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -448,20 +428,17 @@ class MySQLDatabaseManager(DatabaseManager):
def _storeData(self, checksum, data, compression): def _storeData(self, checksum, data, compression):
e = self.escape e = self.escape
checksum = e(checksum) checksum = e(checksum)
with self as q: try:
try: self.query("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" %
q("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" % (checksum, compression, e(data)))
(checksum, compression, e(data))) except IntegrityError, (code, _):
except IntegrityError, (code, _): if code == DUP_ENTRY:
if code != DUP_ENTRY: (r, c, d), = self.query("SELECT id, compression, value"
raise " FROM data WHERE hash='%s'" % checksum)
(r, c, d), = q("SELECT id, compression, value" if c == compression and d == data:
" FROM data WHERE hash='%s'" % checksum) return r
if c != compression or d != data: raise
raise return self.conn.insert_id()
else:
r = self.conn.insert_id()
return r
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
sql = ('SELECT tid, data_id, value_tid FROM obj' sql = ('SELECT tid, data_id, value_tid FROM obj'
...@@ -486,37 +463,37 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -486,37 +463,37 @@ class MySQLDatabaseManager(DatabaseManager):
def finishTransaction(self, tid): def finishTransaction(self, tid):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
with self as q: sql = " FROM tobj WHERE tid=%d" % tid
sql = " FROM tobj WHERE tid=%d" % tid data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] q("INSERT INTO obj SELECT *" + sql)
q("INSERT INTO obj SELECT *" + sql) q("DELETE FROM tobj WHERE tid=%d" % tid)
q("DELETE FROM tobj WHERE tid=%d" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid) q("DELETE FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid)
self.unlockData(data_id_list) self.unlockData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
sql = " FROM tobj WHERE tid=%d" % tid sql = " FROM tobj WHERE tid=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.unlockData(data_id_list) self.unlockData(data_id_list)
q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" %
(getPartition(tid), tid))
# delete from obj using indexes
data_id_set = set()
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=%d AND oid=%d AND tid=%d" \
% (getPartition(oid), oid, tid)
data_id_set.update(*q("SELECT data_id" + sql))
q("DELETE" + sql) q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid) data_id_set.discard(None)
q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" % self._pruneData(data_id_set)
(getPartition(tid), tid))
# delete from obj using indexes
data_id_set = set()
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=%d AND oid=%d AND tid=%d" \
% (getPartition(oid), oid, tid)
data_id_set.update(*q("SELECT data_id" + sql))
q("DELETE" + sql)
data_id_set.discard(None)
self._pruneData(data_id_set)
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
u64 = util.u64 u64 = util.u64
...@@ -525,10 +502,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -525,10 +502,10 @@ class MySQLDatabaseManager(DatabaseManager):
% (self._getPartition(oid), oid) % (self._getPartition(oid), oid)
if serial: if serial:
sql += ' AND tid=%d' % u64(serial) sql += ' AND tid=%d' % u64(serial)
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x]
q("DELETE" + sql) q("DELETE" + sql)
self._pruneData(data_id_list) self._pruneData(data_id_list)
def _deleteRange(self, partition, min_tid=None, max_tid=None): def _deleteRange(self, partition, min_tid=None, max_tid=None):
sql = " WHERE partition=%d" % partition sql = " WHERE partition=%d" % partition
...@@ -545,13 +522,13 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -545,13 +522,13 @@ class MySQLDatabaseManager(DatabaseManager):
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
tid = util.u64(tid) tid = util.u64(tid)
with self as q: q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE partition = %d AND tid = %d"
% (self._getPartition(tid), tid))
if not r and all:
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE partition = %d AND tid = %d" " FROM ttrans WHERE tid = %d" % tid)
% (self._getPartition(tid), tid))
if not r and all:
r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM ttrans WHERE tid = %d" % tid)
if r: if r:
oids, user, desc, ext, packed, ttid = r[0] oids, user, desc, ext, packed, ttid = r[0]
oid_list = splitOIDField(tid, oids) oid_list = splitOIDField(tid, oids)
...@@ -665,32 +642,33 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -665,32 +642,33 @@ class MySQLDatabaseManager(DatabaseManager):
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, ' for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
'MAX(tid) FROM obj WHERE tid <= %d GROUP BY oid' " FROM obj WHERE tid <= %d GROUP BY oid"
% tid): % tid):
partition = getPartition(oid) partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition = %d" if q("SELECT 1 FROM obj WHERE partition = %d"
" AND oid = %d AND tid = %d AND data_id IS NULL" " AND oid = %d AND tid = %d AND data_id IS NULL"
% (partition, oid, max_serial)): % (partition, oid, max_serial)):
max_serial += 1 max_serial += 1
elif not count: elif not count:
continue continue
# There are things to delete for this object # There are things to delete for this object
data_id_set = set() data_id_set = set()
sql = ' FROM obj WHERE partition=%d AND oid=%d' \ sql = ' FROM obj WHERE partition=%d AND oid=%d' \
' AND tid<%d' % (partition, oid, max_serial) ' AND tid<%d' % (partition, oid, max_serial)
for serial, data_id in q('SELECT tid, data_id' + sql): for serial, data_id in q('SELECT tid, data_id' + sql):
data_id_set.add(data_id) data_id_set.add(data_id)
new_serial = updatePackFuture(oid, serial, max_serial) new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial: if new_serial:
new_serial = p64(new_serial) new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial), updateObjectDataForPack(p64(oid), p64(serial),
new_serial, data_id) new_serial, data_id)
q('DELETE' + sql) q('DELETE' + sql)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
self.commit()
def checkTIDRange(self, partition, length, min_tid, max_tid): def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tid_checksum, max_tid = self.query( count, tid_checksum, max_tid = self.query(
......
...@@ -76,23 +76,13 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -76,23 +76,13 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self): def _connect(self):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, isolation_level=None, self.conn = sqlite3.connect(self.db, check_same_thread=False)
check_same_thread=False)
def begin(self): def commit(self):
q = self.query logging.debug('committing...')
retry_if_locked(q, "BEGIN IMMEDIATE") retry_if_locked(self.conn.commit)
return q
if LOG_QUERIES: if LOG_QUERIES:
def commit(self):
logging.debug('committing...')
retry_if_locked(self.conn.commit)
def rollback(self):
logging.debug('aborting...')
self.conn.rollback()
def query(self, query): def query(self, query):
printable_char_list = [] printable_char_list = []
for c in query.split('\n', 1)[0][:70]: for c in query.split('\n', 1)[0][:70]:
...@@ -102,10 +92,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -102,10 +92,7 @@ class SQLiteDatabaseManager(DatabaseManager):
logging.debug('querying %s...', ''.join(printable_char_list)) logging.debug('querying %s...', ''.join(printable_char_list))
return self.conn.execute(query) return self.conn.execute(query)
else: else:
rollback = property(lambda self: self.conn.rollback)
query = property(lambda self: self.conn.execute) query = property(lambda self: self.conn.execute)
def commit(self):
retry_if_locked(self.conn.commit)
def setup(self, reset = 0): def setup(self, reset = 0):
self._config.clear() self._config.clear()
...@@ -226,44 +213,39 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -226,44 +213,39 @@ class SQLiteDatabaseManager(DatabaseManager):
def _getLastIDs(self, all=True): def _getLastIDs(self, all=True):
p64 = util.p64 p64 = util.p64
with self as q: q = self.query
trans = dict((partition, p64(tid)) trans = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM trans GROUP BY partition")) " FROM trans GROUP BY partition"))
obj = dict((partition, p64(tid)) obj = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM obj GROUP BY partition")) " FROM obj GROUP BY partition"))
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t").next()[0] " GROUP BY partition) as t").next()[0]
if all: if all:
tid = q("SELECT MAX(tid) FROM ttrans").next()[0] tid = q("SELECT MAX(tid) FROM ttrans").next()[0]
if tid is not None: if tid is not None:
trans[None] = p64(tid) trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj").next() tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj").next()
if tid is not None: if tid is not None:
obj[None] = p64(tid) obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None): if toid is not None and (oid < toid or oid is None):
oid = toid oid = toid
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
p64 = util.p64 p64 = util.p64
tid_set = set() return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
with self as q: " UNION SELECT tid FROM tobj")]
tid_set.update((p64(t[0]) for t in q("SELECT tid FROM ttrans")))
tid_set.update((p64(t[0]) for t in q("SELECT tid FROM tobj")))
return list(tid_set)
def objectPresent(self, oid, tid, all=True): def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid) oid = util.u64(oid)
tid = util.u64(tid) tid = util.u64(tid)
with self as q: q = self.query
r = q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?", return q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?",
(self._getPartition(oid), oid, tid)).fetchone() (self._getPartition(oid), oid, tid)).fetchone() or all and \
if not r and all: q("SELECT 1 FROM tobj WHERE tid=? AND oid=?",
r = q("SELECT 1 FROM tobj WHERE tid=? AND oid=?", (tid, oid)).fetchone()
(tid, oid)).fetchone()
return bool(r)
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
...@@ -292,21 +274,21 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -292,21 +274,21 @@ class SQLiteDatabaseManager(DatabaseManager):
return serial, r and r[0], compression, checksum, data, value_serial return serial, r and r[0], compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
with self as q: q = self.query
if reset: if reset:
q("DELETE FROM pt") q("DELETE FROM pt")
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
# TODO: this logic should move out of database manager # TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query # add 'dropCells(cell_list)' to API and use one query
# WKRD: Why does SQLite need a statement journal file # WKRD: Why does SQLite need a statement journal file
# whereas we try to replace only 1 value ? # whereas we try to replace only 1 value ?
# We don't want to remove the 'NOT NULL' constraint # We don't want to remove the 'NOT NULL' constraint
# so we must simulate a "REPLACE OR FAIL". # so we must simulate a "REPLACE OR FAIL".
q("DELETE FROM pt WHERE rid=? AND uuid=?", (offset, uuid)) q("DELETE FROM pt WHERE rid=? AND uuid=?", (offset, uuid))
if state != CellStates.DISCARDED: if state != CellStates.DISCARDED:
q("INSERT OR FAIL INTO pt VALUES (?,?,?)", q("INSERT OR FAIL INTO pt VALUES (?,?,?)",
(offset, uuid, int(state))) (offset, uuid, int(state)))
self.setPTID(ptid) self.setPTID(ptid)
def changePartitionTable(self, ptid, cell_list): def changePartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, False) self.doSetPartitionTable(ptid, cell_list, False)
...@@ -316,20 +298,20 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -316,20 +298,20 @@ class SQLiteDatabaseManager(DatabaseManager):
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
where = " WHERE partition=?" where = " WHERE partition=?"
with self as q: q = self.query
for partition in offset_list: for partition in offset_list:
data_id_list = [x for x, in args = partition,
q("SELECT DISTINCT data_id FROM obj" + where, data_id_list = [x for x, in
(partition,)) if x] q("SELECT DISTINCT data_id FROM obj" + where, args) if x]
q("DELETE FROM obj" + where, (partition,)) q("DELETE FROM obj" + where, args)
q("DELETE FROM trans" + where, (partition,)) q("DELETE FROM trans" + where, args)
self._pruneData(data_id_list) self._pruneData(data_id_list)
def dropUnfinishedData(self): def dropUnfinishedData(self):
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x] data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("DELETE FROM tobj") q("DELETE FROM tobj")
q("DELETE FROM ttrans") q("DELETE FROM ttrans")
self.unlockData(data_id_list, True) self.unlockData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary=True): def storeTransaction(self, tid, object_list, transaction, temporary=True):
...@@ -337,37 +319,38 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -337,37 +319,38 @@ class SQLiteDatabaseManager(DatabaseManager):
tid = u64(tid) tid = u64(tid)
T = 't' if temporary else '' T = 't' if temporary else ''
obj_sql = "INSERT OR FAIL INTO %sobj VALUES (?,?,?,?,?)" % T obj_sql = "INSERT OR FAIL INTO %sobj VALUES (?,?,?,?,?)" % T
with self as q: q = self.query
for oid, data_id, value_serial in object_list: for oid, data_id, value_serial in object_list:
oid = u64(oid) oid = u64(oid)
partition = self._getPartition(oid) partition = self._getPartition(oid)
if value_serial: if value_serial:
value_serial = u64(value_serial) value_serial = u64(value_serial)
(data_id,), = q("SELECT data_id FROM obj" (data_id,), = q("SELECT data_id FROM obj"
" WHERE partition=? AND oid=? AND tid=?", " WHERE partition=? AND oid=? AND tid=?",
(partition, oid, value_serial)) (partition, oid, value_serial))
if temporary: if temporary:
self.storeData(data_id) self.storeData(data_id)
try: try:
q(obj_sql, (partition, oid, tid, data_id, value_serial)) q(obj_sql, (partition, oid, tid, data_id, value_serial))
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
# This may happen if a previous replication of 'obj' was # This may happen if a previous replication of 'obj' was
# interrupted. # interrupted.
if not T: if not T:
r, = q("SELECT data_id, value_tid FROM obj" r, = q("SELECT data_id, value_tid FROM obj"
" WHERE partition=? AND oid=? AND tid=?", " WHERE partition=? AND oid=? AND tid=?",
(partition, oid, tid)) (partition, oid, tid))
if r == (data_id, value_serial): if r == (data_id, value_serial):
continue continue
raise raise
if transaction:
if transaction: oid_list, user, desc, ext, packed, ttid = transaction
oid_list, user, desc, ext, packed, ttid = transaction partition = self._getPartition(tid)
partition = self._getPartition(tid) assert packed in (0, 1)
assert packed in (0, 1) q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T,
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T, (partition, tid, packed, buffer(''.join(oid_list)),
(partition, tid, packed, buffer(''.join(oid_list)), buffer(user), buffer(desc), buffer(ext), u64(ttid)))
buffer(user), buffer(desc), buffer(ext), u64(ttid))) if temporary:
self.commit()
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
...@@ -381,17 +364,16 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -381,17 +364,16 @@ class SQLiteDatabaseManager(DatabaseManager):
def _storeData(self, checksum, data, compression): def _storeData(self, checksum, data, compression):
H = buffer(checksum) H = buffer(checksum)
with self as q: try:
try: return self.query("INSERT INTO data VALUES (NULL,?,?,?)",
return q("INSERT INTO data VALUES (NULL,?,?,?)", (H, compression, buffer(data))).lastrowid
(H, compression, buffer(data))).lastrowid except sqlite3.IntegrityError, e:
except sqlite3.IntegrityError, e: if e.args[0] == 'column hash is not unique':
if e.args[0] == 'column hash is not unique': (r, c, d), = self.query("SELECT id, compression, value"
(r, c, d), = q("SELECT id, compression, value" " FROM data WHERE hash=?", (H,))
" FROM data WHERE hash=?", (H,)) if c == compression and str(d) == data:
if c == compression and str(d) == data: return r
return r raise
raise
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getPartition(oid) partition = self._getPartition(oid)
...@@ -415,38 +397,38 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -415,38 +397,38 @@ class SQLiteDatabaseManager(DatabaseManager):
def finishTransaction(self, tid): def finishTransaction(self, tid):
args = util.u64(tid), args = util.u64(tid),
with self as q: q = self.query
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args) q("INSERT OR FAIL INTO obj SELECT *" + sql, args)
q("DELETE FROM tobj WHERE tid=?", args) q("DELETE FROM tobj WHERE tid=?", args)
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args)
args) q("DELETE FROM ttrans WHERE tid=?", args)
q("DELETE FROM ttrans WHERE tid=?", args)
self.unlockData(data_id_list) self.unlockData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x]
self.unlockData(data_id_list) self.unlockData(data_id_list)
q("DELETE" + sql, (tid,)) q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,)) q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?", q("DELETE FROM trans WHERE partition=? AND tid=?",
(getPartition(tid), tid)) (getPartition(tid), tid))
# delete from obj using indexes # delete from obj using indexes
data_id_set = set() data_id_set = set()
for oid in oid_list: for oid in oid_list:
oid = u64(oid) oid = u64(oid)
sql = " FROM obj WHERE partition=? AND oid=? AND tid=?" sql = " FROM obj WHERE partition=? AND oid=? AND tid=?"
args = getPartition(oid), oid, tid args = getPartition(oid), oid, tid
data_id_set.update(*q("SELECT data_id" + sql, args)) data_id_set.update(*q("SELECT data_id" + sql, args))
q("DELETE" + sql, args) q("DELETE" + sql, args)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
oid = util.u64(oid) oid = util.u64(oid)
...@@ -455,11 +437,11 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -455,11 +437,11 @@ class SQLiteDatabaseManager(DatabaseManager):
if serial: if serial:
sql += " AND tid=?" sql += " AND tid=?"
args.append(util.u64(serial)) args.append(util.u64(serial))
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args) data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args)
if x] if x]
q("DELETE" + sql, args) q("DELETE" + sql, args)
self._pruneData(data_id_list) self._pruneData(data_id_list)
def _deleteRange(self, partition, min_tid=None, max_tid=None): def _deleteRange(self, partition, min_tid=None, max_tid=None):
sql = " WHERE partition=?" sql = " WHERE partition=?"
...@@ -480,13 +462,13 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -480,13 +462,13 @@ class SQLiteDatabaseManager(DatabaseManager):
def getTransaction(self, tid, all=False): def getTransaction(self, tid, all=False):
tid = util.u64(tid) tid = util.u64(tid)
with self as q: q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid)).fetchone()
if not r and all:
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE partition=? AND tid=?", " FROM ttrans WHERE tid=?", (tid,)).fetchone()
(self._getPartition(tid), tid)).fetchone()
if not r and all:
r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM ttrans WHERE tid=?", (tid,)).fetchone()
if r: if r:
oids, user, description, ext, packed, ttid = r oids, user, description, ext, packed, ttid = r
return splitOIDField(tid, oids), str(user), \ return splitOIDField(tid, oids), str(user), \
...@@ -515,19 +497,18 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -515,19 +497,18 @@ class SQLiteDatabaseManager(DatabaseManager):
pack_tid = self._getPackTID() pack_tid = self._getPackTID()
result = [] result = []
append = result.append append = result.append
with self as q: for serial, length, value_serial in self.query("""\
for serial, length, value_serial in q("""\ SELECT tid, LENGTH(value), value_tid
SELECT tid, LENGTH(value), value_tid FROM obj LEFT JOIN data ON obj.data_id = data.id
FROM obj LEFT JOIN data ON obj.data_id = data.id WHERE partition=? AND oid=? AND tid>=?
WHERE partition=? AND oid=? AND tid>=? ORDER BY tid DESC LIMIT ?,?""",
ORDER BY tid DESC LIMIT ?,?""", (self._getPartition(oid), oid, pack_tid, offset, length)):
(self._getPartition(oid), oid, pack_tid, offset, length)): if length is None:
if length is None: try:
try: length = self._getObjectLength(oid, value_serial)
length = self._getObjectLength(oid, value_serial) except CreationUndone:
except CreationUndone: length = 0
length = 0 append((p64(serial), length))
append((p64(serial), length))
return result or None return result or None
def getReplicationObjectList(self, min_tid, max_tid, length, partition, def getReplicationObjectList(self, min_tid, max_tid, length, partition,
...@@ -587,32 +568,33 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -587,32 +568,33 @@ class SQLiteDatabaseManager(DatabaseManager):
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid," for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
" MAX(tid) FROM obj WHERE tid<=? GROUP BY oid", " FROM obj WHERE tid<=? GROUP BY oid",
(tid,)): (tid,)):
partition = getPartition(oid) partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition=?" if q("SELECT 1 FROM obj WHERE partition=?"
" AND oid=? AND tid=? AND data_id IS NULL", " AND oid=? AND tid=? AND data_id IS NULL",
(partition, oid, max_serial)).fetchone(): (partition, oid, max_serial)).fetchone():
max_serial += 1 max_serial += 1
elif not count: elif not count:
continue continue
# There are things to delete for this object # There are things to delete for this object
data_id_set = set() data_id_set = set()
sql = " FROM obj WHERE partition=? AND oid=? AND tid<?" sql = " FROM obj WHERE partition=? AND oid=? AND tid<?"
args = partition, oid, max_serial args = partition, oid, max_serial
for serial, data_id in q("SELECT tid, data_id" + sql, args): for serial, data_id in q("SELECT tid, data_id" + sql, args):
data_id_set.add(data_id) data_id_set.add(data_id)
new_serial = updatePackFuture(oid, serial, max_serial) new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial: if new_serial:
new_serial = p64(new_serial) new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial), updateObjectDataForPack(p64(oid), p64(serial),
new_serial, data_id) new_serial, data_id)
q("DELETE" + sql, args) q("DELETE" + sql, args)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
self.commit()
def checkTIDRange(self, partition, length, min_tid, max_tid): def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tids, max_tid = self.query("""\ count, tids, max_tid = self.query("""\
......
...@@ -98,6 +98,7 @@ class StorageOperationHandler(EventHandler): ...@@ -98,6 +98,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid_list in object_dict.iteritems(): for serial, oid_list in object_dict.iteritems():
for oid in oid_list: for oid in oid_list:
deleteObject(oid, serial) deleteObject(oid, serial)
self.app.dm.commit()
assert not pack_tid, "TODO" assert not pack_tid, "TODO"
if next_tid: if next_tid:
self.app.replicator.fetchObjects(next_tid, next_oid) self.app.replicator.fetchObjects(next_tid, next_oid)
......
...@@ -20,7 +20,6 @@ from mock import Mock ...@@ -20,7 +20,6 @@ from mock import Mock
from neo.lib.util import add64, dump, p64, u64 from neo.lib.util import add64, dump, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.exception import DatabaseFailure
class StorageDBTests(NeoUnitTestBase): class StorageDBTests(NeoUnitTestBase):
...@@ -93,29 +92,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -93,29 +92,6 @@ class StorageDBTests(NeoUnitTestBase):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1)) self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1))
def test_transaction(self):
db = self.getDB()
x = []
class DB(db.__class__):
begin = lambda self: x.append('begin')
commit = lambda self: x.append('commit')
rollback = lambda self: x.append('rollback')
db.__class__ = DB
with db:
self.assertEqual(x.pop(), 'begin')
self.assertEqual(x.pop(), 'commit')
try:
with db:
self.assertEqual(x.pop(), 'begin')
with db:
self.fail()
self.fail()
except DatabaseFailure:
pass
self.assertEqual(x.pop(), 'rollback')
self.assertRaises(DatabaseFailure, db.__exit__, None, None, None)
self.assertFalse(x)
def test_getPartitionTable(self): def test_getPartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1) ptid = self.getPTID(1)
......
...@@ -300,11 +300,11 @@ class StorageApplication(ServerNode, neo.storage.app.Application): ...@@ -300,11 +300,11 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
pass pass
def switchTables(self): def switchTables(self):
with self.dm as q: q = self.dm.query
for table in ('trans', 'obj'): for table in 'trans', 'obj':
q('ALTER TABLE %s RENAME TO tmp' % table) q('ALTER TABLE %s RENAME TO tmp' % table)
q('ALTER TABLE t%s RENAME TO %s' % (table, table)) q('ALTER TABLE t%s RENAME TO %s' % (table, table))
q('ALTER TABLE tmp RENAME TO t%s' % table) q('ALTER TABLE tmp RENAME TO t%s' % table)
def getDataLockInfo(self): def getDataLockInfo(self):
dm = self.dm dm = self.dm
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment