Commit 2db4de86 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Make partition part of the SQL index.

Better performances are expected because of the removal of all MOD() operators
that would do a full scan to find the rows maching a given partition. Now a
query like '... where partition = x limit 10' should match only a subtree of
the table and not scan lots of rows if there is none matching this partition.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2306 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 2b7a0882
...@@ -51,6 +51,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -51,6 +51,9 @@ class MySQLDatabaseManager(DatabaseManager):
self.conn = None self.conn = None
self._connect() self._connect()
def getPartition(self, oid_or_tid):
return oid_or_tid % self.getNumPartitions()
def _parse(self, database): def _parse(self, database):
""" Get the database credentials (username, password, database) """ """ Get the database credentials (username, password, database) """
# expected pattern : [user[:password]@]database # expected pattern : [user[:password]@]database
...@@ -148,27 +151,31 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -148,27 +151,31 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "trans" stores information on committed transactions. # The table "trans" stores information on committed transactions.
q("""CREATE TABLE IF NOT EXISTS trans ( q("""CREATE TABLE IF NOT EXISTS trans (
tid BIGINT UNSIGNED NOT NULL PRIMARY KEY, partition SMALLINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED NOT NULL,
packed BOOLEAN NOT NULL, packed BOOLEAN NOT NULL,
oids MEDIUMBLOB NOT NULL, oids MEDIUMBLOB NOT NULL,
user BLOB NOT NULL, user BLOB NOT NULL,
description BLOB NOT NULL, description BLOB NOT NULL,
ext BLOB NOT NULL ext BLOB NOT NULL,
PRIMARY KEY (partition, tid)
) ENGINE = InnoDB""") ) ENGINE = InnoDB""")
# The table "obj" stores committed object data. # The table "obj" stores committed object data.
q("""CREATE TABLE IF NOT EXISTS obj ( q("""CREATE TABLE IF NOT EXISTS obj (
partition SMALLINT UNSIGNED NOT NULL,
oid BIGINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NULL, compression TINYINT UNSIGNED NULL,
checksum INT UNSIGNED NULL, checksum INT UNSIGNED NULL,
value LONGBLOB NULL, value LONGBLOB NULL,
value_serial BIGINT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL,
PRIMARY KEY (oid, serial) PRIMARY KEY (partition, oid, serial)
) ENGINE = InnoDB""") ) ENGINE = InnoDB""")
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans ( q("""CREATE TABLE IF NOT EXISTS ttrans (
partition SMALLINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED NOT NULL, tid BIGINT UNSIGNED NOT NULL,
packed BOOLEAN NOT NULL, packed BOOLEAN NOT NULL,
oids MEDIUMBLOB NOT NULL, oids MEDIUMBLOB NOT NULL,
...@@ -179,6 +186,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -179,6 +186,7 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "tobj" stores uncommitted object data. # The table "tobj" stores uncommitted object data.
q("""CREATE TABLE IF NOT EXISTS tobj ( q("""CREATE TABLE IF NOT EXISTS tobj (
partition SMALLINT UNSIGNED NOT NULL,
oid BIGINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NULL, compression TINYINT UNSIGNED NULL,
...@@ -265,9 +273,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -265,9 +273,10 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
oid = util.u64(oid) oid = util.u64(oid)
tid = util.u64(tid) tid = util.u64(tid)
partition = self.getPartition(oid)
self.begin() self.begin()
r = q("""SELECT oid FROM obj WHERE oid = %d AND serial = %d""" \ r = q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND serial=%d"
% (oid, tid)) % (partition, oid, tid))
if not r and all: if not r and all:
r = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d""" \ r = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d""" \
% (oid, tid)) % (oid, tid))
...@@ -296,11 +305,12 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -296,11 +305,12 @@ class MySQLDatabaseManager(DatabaseManager):
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
partition = self.getPartition(oid)
if tid is not None: if tid is not None:
r = q("""SELECT serial, compression, checksum, value, value_serial r = q("""SELECT serial, compression, checksum, value, value_serial
FROM obj FROM obj
WHERE oid = %d AND serial = %d""" \ WHERE partition = %d AND oid = %d AND serial = %d""" \
% (oid, tid)) % (partition, oid, tid))
try: try:
serial, compression, checksum, data, value_serial = r[0] serial, compression, checksum, data, value_serial = r[0]
next_serial = None next_serial = None
...@@ -309,17 +319,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -309,17 +319,19 @@ class MySQLDatabaseManager(DatabaseManager):
elif before_tid is not None: elif before_tid is not None:
r = q("""SELECT serial, compression, checksum, value, value_serial r = q("""SELECT serial, compression, checksum, value, value_serial
FROM obj FROM obj
WHERE oid = %d AND serial < %d WHERE partition = %d
AND oid = %d AND serial < %d
ORDER BY serial DESC LIMIT 1""" \ ORDER BY serial DESC LIMIT 1""" \
% (oid, before_tid)) % (partition, oid, before_tid))
try: try:
serial, compression, checksum, data, value_serial = r[0] serial, compression, checksum, data, value_serial = r[0]
except IndexError: except IndexError:
return None return None
r = q("""SELECT serial FROM obj r = q("""SELECT serial FROM obj
WHERE oid = %d AND serial >= %d WHERE partition = %d
AND oid = %d AND serial >= %d
ORDER BY serial LIMIT 1""" \ ORDER BY serial LIMIT 1""" \
% (oid, before_tid)) % (partition, oid, before_tid))
try: try:
next_serial = r[0][0] next_serial = r[0][0]
except IndexError: except IndexError:
...@@ -329,8 +341,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -329,8 +341,9 @@ class MySQLDatabaseManager(DatabaseManager):
# MySQL does not use an index for a HAVING clause! # MySQL does not use an index for a HAVING clause!
r = q("""SELECT serial, compression, checksum, value, value_serial r = q("""SELECT serial, compression, checksum, value, value_serial
FROM obj FROM obj
WHERE oid = %d ORDER BY serial DESC LIMIT 1""" \ WHERE partition = %d AND oid = %d
% oid) ORDER BY serial DESC LIMIT 1""" \
% (partition, oid))
try: try:
serial, compression, checksum, data, value_serial = r[0] serial, compression, checksum, data, value_serial = r[0]
next_serial = None next_serial = None
...@@ -375,10 +388,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -375,10 +388,10 @@ class MySQLDatabaseManager(DatabaseManager):
offset_list = ', '.join((str(i) for i in offset_list)) offset_list = ', '.join((str(i) for i in offset_list))
self.begin() self.begin()
try: try:
q("""DELETE FROM obj WHERE MOD(oid, %d) IN (%s)""" % q("""DELETE FROM obj WHERE partition IN (%s)""" %
(num_partitions, offset_list)) (offset_list, ))
q("""DELETE FROM trans WHERE MOD(tid, %d) IN (%s)""" % q("""DELETE FROM trans WHERE partition IN (%s)""" %
(num_partitions, offset_list)) (offset_list, ))
except: except:
self.rollback() self.rollback()
raise raise
...@@ -426,9 +439,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -426,9 +439,11 @@ class MySQLDatabaseManager(DatabaseManager):
value_serial = 'NULL' value_serial = 'NULL'
else: else:
value_serial = '%d' % (u64(value_serial), ) value_serial = '%d' % (u64(value_serial), )
q("""REPLACE INTO %s VALUES (%d, %d, %s, %s, %s, %s)""" \ partition = self.getPartition(oid)
% (obj_table, oid, tid, compression, checksum, data, q("""REPLACE INTO %s VALUES (%d, %d, %d, %s, %s, %s, %s)""" \
value_serial)) % (obj_table, partition, oid, tid, compression, checksum,
data, value_serial))
if transaction is not None: if transaction is not None:
oid_list, user, desc, ext, packed = transaction oid_list, user, desc, ext, packed = transaction
packed = packed and 1 or 0 packed = packed and 1 or 0
...@@ -436,8 +451,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -436,8 +451,10 @@ class MySQLDatabaseManager(DatabaseManager):
user = e(user) user = e(user)
desc = e(desc) desc = e(desc)
ext = e(ext) ext = e(ext)
q("""REPLACE INTO %s VALUES (%d, %i, '%s', '%s', '%s', '%s')""" \ partition = self.getPartition(tid)
% (trans_table, tid, packed, oids, user, desc, ext)) q("REPLACE INTO %s VALUES (%d, %d, %i, '%s', '%s', '%s', '%s')"
% (trans_table, partition, tid, packed, oids, user, desc,
ext))
except: except:
self.rollback() self.rollback()
raise raise
...@@ -488,7 +505,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -488,7 +505,8 @@ class MySQLDatabaseManager(DatabaseManager):
if all: if all:
# Note that this can be very slow. # Note that this can be very slow.
q("""DELETE FROM obj WHERE serial = %d""" % tid) q("""DELETE FROM obj WHERE serial = %d""" % tid)
q("""DELETE FROM trans WHERE tid = %d""" % tid) q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" %
(self.getPartition(tid), tid))
except: except:
self.rollback() self.rollback()
raise raise
...@@ -496,10 +514,13 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -496,10 +514,13 @@ class MySQLDatabaseManager(DatabaseManager):
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
u64 = util.u64 u64 = util.u64
oid = u64(oid)
query_param_dict = { query_param_dict = {
'oid': u64(oid), 'partition': self.getPartition(oid),
'oid': oid,
} }
query_fmt = 'DELETE FROM obj WHERE oid = %(oid)d' query_fmt = """DELETE FROM obj WHERE partition = %(partition)d
AND oid = %(oid)d"""
if serial is not None: if serial is not None:
query_param_dict['serial'] = u64(serial) query_param_dict['serial'] = u64(serial)
query_fmt = query_fmt + ' AND serial = %(serial)d' query_fmt = query_fmt + ' AND serial = %(serial)d'
...@@ -516,8 +537,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -516,8 +537,8 @@ class MySQLDatabaseManager(DatabaseManager):
tid = util.u64(tid) tid = util.u64(tid)
self.begin() self.begin()
r = q("""SELECT oids, user, description, ext, packed FROM trans r = q("""SELECT oids, user, description, ext, packed FROM trans
WHERE tid = %d""" \ WHERE partition = %d AND tid = %d""" \
% tid) % (self.getPartition(tid), tid))
if not r and all: if not r and all:
r = q("""SELECT oids, user, description, ext, packed FROM ttrans r = q("""SELECT oids, user, description, ext, packed FROM ttrans
WHERE tid = %d""" \ WHERE tid = %d""" \
...@@ -533,7 +554,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -533,7 +554,8 @@ class MySQLDatabaseManager(DatabaseManager):
if value_serial is None: if value_serial is None:
raise CreationUndone raise CreationUndone
r = self.query("""SELECT LENGTH(value), value_serial FROM obj """ \ r = self.query("""SELECT LENGTH(value), value_serial FROM obj """ \
"""WHERE oid = %d AND serial = %d""" % (oid, value_serial)) """WHERE partition = %d AND oid = %d AND serial = %d""" %
(self.getPartition(oid), oid, value_serial))
length, value_serial = r[0] length, value_serial = r[0]
if length is None: if length is None:
logging.info("Multiple levels of indirection when " \ logging.info("Multiple levels of indirection when " \
...@@ -551,9 +573,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -551,9 +573,9 @@ class MySQLDatabaseManager(DatabaseManager):
p64 = util.p64 p64 = util.p64
pack_tid = self._getPackTID() pack_tid = self._getPackTID()
r = q("""SELECT serial, LENGTH(value), value_serial FROM obj r = q("""SELECT serial, LENGTH(value), value_serial FROM obj
WHERE oid = %d AND serial >= %d WHERE partition = %d AND oid = %d AND serial >= %d
ORDER BY serial DESC LIMIT %d, %d""" \ ORDER BY serial DESC LIMIT %d, %d""" \
% (oid, pack_tid, offset, length)) % (self.getPartition(oid), oid, pack_tid, offset, length))
if r: if r:
result = [] result = []
append = result.append append = result.append
...@@ -576,16 +598,16 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -576,16 +598,16 @@ class MySQLDatabaseManager(DatabaseManager):
min_serial = u64(min_serial) min_serial = u64(min_serial)
max_serial = u64(max_serial) max_serial = u64(max_serial)
r = q('SELECT oid, serial FROM obj ' r = q('SELECT oid, serial FROM obj '
'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR ' 'WHERE (partition=%(partition)s AND (oid = %(min_oid)d '
'AND serial >= %(min_serial)d) OR '
'oid > %(min_oid)d) AND ' 'oid > %(min_oid)d) AND '
'MOD(oid, %(num_partitions)d) = %(partition)s AND ' 'partition = %(partition)d AND '
'serial <= %(max_serial)d ' 'serial <= %(max_serial)d '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % { 'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid, 'min_oid': min_oid,
'min_serial': min_serial, 'min_serial': min_serial,
'max_serial': max_serial, 'max_serial': max_serial,
'length': length, 'length': length,
'num_partitions': num_partitions,
'partition': partition, 'partition': partition,
}) })
result = {} result = {}
...@@ -599,11 +621,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -599,11 +621,9 @@ class MySQLDatabaseManager(DatabaseManager):
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query q = self.query
r = q("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s) r = q("""SELECT tid FROM trans WHERE partition in (%s)
ORDER BY tid DESC LIMIT %d,%d""" \ ORDER BY tid DESC LIMIT %d,%d""" \
% (num_partitions, % (','.join([str(p) for p in partition_list]), offset, length))
','.join([str(p) for p in partition_list]),
offset, length))
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions,
...@@ -613,11 +633,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -613,11 +633,10 @@ class MySQLDatabaseManager(DatabaseManager):
p64 = util.p64 p64 = util.p64
min_tid = u64(min_tid) min_tid = u64(min_tid)
max_tid = u64(max_tid) max_tid = u64(max_tid)
r = q("""SELECT tid FROM trans WHERE r = q("""SELECT tid FROM trans
MOD(tid, %(num_partitions)d) = %(partition)d WHERE partition = %(partition)d
AND tid >= %(min_tid)d AND tid <= %(max_tid)d AND tid >= %(min_tid)d AND tid <= %(max_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % { ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partition': partition, 'partition': partition,
'min_tid': min_tid, 'min_tid': min_tid,
'max_tid': max_tid, 'max_tid': max_tid,
...@@ -636,9 +655,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -636,9 +655,11 @@ class MySQLDatabaseManager(DatabaseManager):
value_serial = None value_serial = None
for table in ('obj', 'tobj'): for table in ('obj', 'tobj'):
for (serial, ) in q('SELECT serial FROM %(table)s WHERE ' for (serial, ) in q('SELECT serial FROM %(table)s WHERE '
'oid = %(oid)d AND serial >= %(max_serial)d AND ' 'partition = %(partition)d AND oid = %(oid)d '
'AND serial >= %(max_serial)d AND '
'value_serial = %(orig_serial)d ORDER BY serial ASC' % { 'value_serial = %(orig_serial)d ORDER BY serial ASC' % {
'table': table, 'table': table,
'partition': self.getPartition(oid),
'oid': oid, 'oid': oid,
'orig_serial': orig_serial, 'orig_serial': orig_serial,
'max_serial': max_serial, 'max_serial': max_serial,
...@@ -647,20 +668,24 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -647,20 +668,24 @@ class MySQLDatabaseManager(DatabaseManager):
# First found, copy data to it and mark its serial for # First found, copy data to it and mark its serial for
# future reference. # future reference.
value_serial = serial value_serial = serial
q('REPLACE INTO %(table)s (oid, serial, compression, ' q('REPLACE INTO %(table)s (partition, oid, serial, compression, '
'checksum, value, value_serial) SELECT oid, ' 'checksum, value, value_serial) SELECT partition, oid, '
'%(serial)d, compression, checksum, value, NULL FROM ' '%(serial)d, compression, checksum, value, NULL FROM '
'obj WHERE oid = %(oid)d AND serial = %(orig_serial)d' \ 'obj WHERE partition = %(partition)d AND oid = %(oid)d '
'AND serial = %(orig_serial)d' \
% { % {
'table': table, 'table': table,
'partition': self.getPartition(oid),
'oid': oid, 'oid': oid,
'serial': serial, 'serial': serial,
'orig_serial': orig_serial, 'orig_serial': orig_serial,
}) })
else: else:
q('REPLACE INTO %(table)s (oid, serial, value_serial) ' q('REPLACE INTO %(table)s (partition, oid, serial, value_serial) '
'VALUES (%(oid)d, %(serial)d, %(value_serial)d)' % { 'VALUES (%(partition)d, %(oid)d, %(serial)d, '
'%(value_serial)d)' % {
'table': table, 'table': table,
'partition': self.getPartition(oid),
'oid': oid, 'oid': oid,
'serial': serial, 'serial': serial,
'value_serial': value_serial, 'value_serial': value_serial,
...@@ -668,7 +693,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -668,7 +693,9 @@ class MySQLDatabaseManager(DatabaseManager):
def getObjectData(): def getObjectData():
assert value_serial is None assert value_serial is None
return q('SELECT compression, checksum, value FROM obj WHERE ' return q('SELECT compression, checksum, value FROM obj WHERE '
'oid = %(oid)d AND serial = %(orig_serial)d' % { 'partition = %(partition)d AND oid = %(oid)d '
'AND serial = %(orig_serial)d' % {
'partition': self.getPartition(oid),
'oid': oid, 'oid': oid,
'orig_serial': orig_serial, 'orig_serial': orig_serial,
})[0] })[0]
...@@ -688,9 +715,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -688,9 +715,11 @@ class MySQLDatabaseManager(DatabaseManager):
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, ' for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, '
'MAX(serial) FROM obj WHERE serial <= %(tid)d ' 'MAX(serial) FROM obj WHERE serial <= %(tid)d '
'GROUP BY oid' % {'tid': tid}): 'GROUP BY oid' % {'tid': tid}):
if q('SELECT LENGTH(value) FROM obj WHERE oid = %(oid)d AND ' if q('SELECT LENGTH(value) FROM obj WHERE partition ='
'%(partition)s AND oid = %(oid)d AND '
'serial = %(max_serial)d' % { 'serial = %(max_serial)d' % {
'oid': oid, 'oid': oid,
'partition': self.getPartition(oid),
'max_serial': max_serial, 'max_serial': max_serial,
})[0][0] == 0: })[0][0] == 0:
count += 1 count += 1
...@@ -698,14 +727,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -698,14 +727,17 @@ class MySQLDatabaseManager(DatabaseManager):
if count: if count:
# There are things to delete for this object # There are things to delete for this object
for (serial, ) in q('SELECT serial FROM obj WHERE ' for (serial, ) in q('SELECT serial FROM obj WHERE '
'oid=%(oid)d AND serial < %(max_serial)d' % { 'partition=%(partition)d AND oid=%(oid)d AND '
'serial < %(max_serial)d' % {
'oid': oid, 'oid': oid,
'partition': self.getPartition(oid),
'max_serial': max_serial, 'max_serial': max_serial,
}): }):
updatePackFuture(oid, serial, max_serial, updatePackFuture(oid, serial, max_serial,
updateObjectDataForPack) updateObjectDataForPack)
q('DELETE FROM obj WHERE oid=%(oid)d AND ' q('DELETE FROM obj WHERE partition=%(partition)d '
'serial=%(serial)d' % { 'AND oid=%(oid)d AND serial=%(serial)d' % {
'partition': self.getPartition(oid),
'oid': oid, 'oid': oid,
'serial': serial 'serial': serial
}) })
...@@ -719,11 +751,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -719,11 +751,10 @@ class MySQLDatabaseManager(DatabaseManager):
count, tid_checksum, max_tid = self.query('SELECT COUNT(*), ' count, tid_checksum, max_tid = self.query('SELECT COUNT(*), '
'BIT_XOR(tid), MAX(tid) FROM (' 'BIT_XOR(tid), MAX(tid) FROM ('
'SELECT tid FROM trans ' 'SELECT tid FROM trans '
'WHERE MOD(tid, %(num_partitions)d) = %(partition)s ' 'WHERE partition = %(partition)s '
'AND tid >= %(min_tid)d ' 'AND tid >= %(min_tid)d '
'ORDER BY tid ASC LIMIT %(length)d' 'ORDER BY tid ASC LIMIT %(length)d'
') AS foo' % { ') AS foo' % {
'num_partitions': num_partitions,
'partition': partition, 'partition': partition,
'min_tid': util.u64(min_tid), 'min_tid': util.u64(min_tid),
'length': length, 'length': length,
...@@ -741,14 +772,13 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -741,14 +772,13 @@ class MySQLDatabaseManager(DatabaseManager):
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
r = self.query('SELECT oid, serial FROM obj WHERE ' r = self.query('SELECT oid, serial FROM obj WHERE '
'partition = %(partition)s AND '
'(oid > %(min_oid)d OR ' '(oid > %(min_oid)d OR '
'(oid = %(min_oid)d AND serial >= %(min_serial)d)) ' '(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
'AND MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % { 'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': u64(min_oid), 'min_oid': u64(min_oid),
'min_serial': u64(min_serial), 'min_serial': u64(min_serial),
'length': length, 'length': length,
'num_partitions': num_partitions,
'partition': partition, 'partition': partition,
}) })
count = len(r) count = len(r)
......
...@@ -74,6 +74,7 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -74,6 +74,7 @@ class StorageClientHandlerTests(NeoTestBase):
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = self._getConnection() conn = self._getConnection()
self.app.dm = Mock({'getNumPartitions': 1})
self.operation.askTransactionInformation(conn, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
......
...@@ -57,6 +57,7 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -57,6 +57,7 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.dm = Mock({'getNumPartitions': 1})
self.operation.askTransactionInformation(conn, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
......
...@@ -35,6 +35,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -35,6 +35,7 @@ class StorageMySQSLdbTests(NeoTestBase):
database = '%s@%s' % (NEO_SQL_USER, NEO_SQL_DATABASE) database = '%s@%s' % (NEO_SQL_USER, NEO_SQL_DATABASE)
self.db = MySQLDatabaseManager(database) self.db = MySQLDatabaseManager(database)
self.db.setup() self.db.setup()
self.db.setNumPartitions(1)
def tearDown(self): def tearDown(self):
self.db.close() self.db.close()
...@@ -159,6 +160,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -159,6 +160,7 @@ class StorageMySQSLdbTests(NeoTestBase):
self.checkConfigEntry(self.db.getUUID, self.db.setUUID, 'TEST_VALUE') self.checkConfigEntry(self.db.getUUID, self.db.setUUID, 'TEST_VALUE')
def test_NumPartitions(self): def test_NumPartitions(self):
self.db.setup(reset=True)
self.checkConfigEntry(self.db.getNumPartitions, self.checkConfigEntry(self.db.getNumPartitions,
self.db.setNumPartitions, 10) self.db.setNumPartitions, 10)
...@@ -613,6 +615,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -613,6 +615,7 @@ class StorageMySQSLdbTests(NeoTestBase):
def test__getObjectData(self): def test__getObjectData(self):
db = self.db db = self.db
db.setup(reset=True) db.setup(reset=True)
self.db.setNumPartitions(4)
tid0 = self.getNextTID() tid0 = self.getNextTID()
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
...@@ -699,6 +702,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -699,6 +702,7 @@ class StorageMySQSLdbTests(NeoTestBase):
def test__getDataTIDFromData(self): def test__getDataTIDFromData(self):
db = self.db db = self.db
db.setup(reset=True) db.setup(reset=True)
self.db.setNumPartitions(4)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
oid1 = self.getOID(1) oid1 = self.getOID(1)
...@@ -723,6 +727,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -723,6 +727,7 @@ class StorageMySQSLdbTests(NeoTestBase):
def test__getDataTID(self): def test__getDataTID(self):
db = self.db db = self.db
db.setup(reset=True) db.setup(reset=True)
self.db.setNumPartitions(4)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
oid1 = self.getOID(1) oid1 = self.getOID(1)
...@@ -745,6 +750,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -745,6 +750,7 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_findUndoTID(self): def test_findUndoTID(self):
db = self.db db = self.db
db.setup(reset=True) db.setup(reset=True)
self.db.setNumPartitions(4)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment