#
# Copyright (C) 2006-2009  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import MySQLdb
from MySQLdb import OperationalError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
import logging
from array import array
import string
import sys
from struct import pack, unpack

from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure
from neo.protocol import DISCARDED_STATE, INVALID_PTID

def p64(n):
    """Pack the unsigned integer n into an 8-byte big-endian string."""
    return pack('!Q', n)

def u64(s):
    """Unpack an 8-byte big-endian string into an unsigned integer."""
    return unpack('!Q', s)[0]

class MySQLDatabaseManager(DatabaseManager):
    """This class manages a database on MySQL.

    Identifiers (OIDs, TIDs, serials) cross the public API as 8-byte
    packed strings (see p64/u64) and are stored as BIGINT UNSIGNED.
    """

    def __init__(self, **kwargs):
        """Connect to MySQL, then initialize the base class.

        Reads the keyword arguments 'database' and 'user' (required) and
        'password' (optional). All keyword arguments are forwarded to the
        base class constructor.
        """
        self.db = kwargs['database']
        self.user = kwargs['user']
        self.passwd = kwargs.get('password')
        self.conn = None
        self.connect()
        super(MySQLDatabaseManager, self).__init__(**kwargs)

    def close(self):
        """Close the connection to the MySQL server."""
        self.conn.close()

    def connect(self):
        """Open a new connection, with autocommit disabled.

        Also resets the transaction flag, since a fresh connection has
        no transaction in progress.
        """
        kwd = {'db' : self.db, 'user' : self.user}
        if self.passwd is not None:
            kwd['passwd'] = self.passwd
        logging.info('connecting to MySQL on the database %s with user %s',
                     self.db, self.user)
        self.conn = MySQLdb.connect(**kwd)
        self.conn.autocommit(False)
        self.under_transaction = False

    def begin(self):
        """Start a transaction, committing any transaction still open."""
        if self.under_transaction:
            try:
                self.commit()
            except Exception:
                # The implicit commit is best effort: a failure must not
                # prevent the new transaction from starting. Log it rather
                # than dropping it silently, and do not swallow
                # KeyboardInterrupt/SystemExit.
                logging.warning('ignoring error in implicit commit',
                                exc_info=1)
        self.query("""BEGIN""")
        self.under_transaction = True

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()
        self.under_transaction = False

    def rollback(self):
        """Roll back the current transaction."""
        self.conn.rollback()
        self.under_transaction = False

    def query(self, query):
        """Query data from a database.

        Returns a tuple of row tuples when the statement produces a result
        set, None otherwise. Values returned as array objects are converted
        to strings.

        If the server connection is gone or lost, reconnect and retry the
        statement. Any other OperationalError is re-raised as
        DatabaseFailure.
        """
        conn = self.conn
        try:
            # Log only the beginning of the first line of the query, with
            # non-printable characters (queries may embed binary data)
            # replaced by hexadecimal escapes.
            printable_char_list = []
            for c in query.split('\n', 1)[0][:70]:
                if c not in string.printable or c in '\t\x0b\x0c\r':
                    c = '\\x%02x' % ord(c)
                printable_char_list.append(c)
            query_part = ''.join(printable_char_list)
            logging.debug('querying %s...', query_part)

            conn.query(query)
            r = conn.store_result()
            if r is not None:
                new_r = []
                for row in r.fetch_row(r.num_rows()):
                    new_row = []
                    for d in row:
                        if isinstance(d, array):
                            d = d.tostring()
                        new_row.append(d)
                    new_r.append(tuple(new_row))
                r = tuple(new_r)
        except OperationalError:
            # Fetch the exception through sys.exc_info() and read its
            # arguments through .args so that this handler is valid
            # on every Python version.
            m = sys.exc_info()[1]
            if m.args[0] in (SERVER_GONE_ERROR, SERVER_LOST):
                # NOTE(review): connect() resets under_transaction, so when
                # the connection drops in the middle of a transaction only
                # this statement is replayed on the new connection.
                logging.info('the MySQL server is gone; reconnecting')
                self.connect()
                return self.query(query)
            raise DatabaseFailure('MySQL error %d: %s'
                                  % (m.args[0], m.args[1]))
        return r

    def escape(self, s):
        """Escape special characters in a string."""
        return self.conn.escape_string(s)

    def setup(self, reset = 0):
        """Create the tables if they do not exist.

        If reset is true, all the tables are dropped first.
        """
        q = self.query

        if reset:
            q("""DROP TABLE IF EXISTS config, pt, trans, obj, ttrans, tobj""")

        # The table "config" stores configuration parameters which affect the
        # persistent data.
        q("""CREATE TABLE IF NOT EXISTS config (
                 name VARBINARY(16) NOT NULL PRIMARY KEY,
                 value VARBINARY(255) NOT NULL
             ) ENGINE = InnoDB""")

        # The table "pt" stores a partition table.
        q("""CREATE TABLE IF NOT EXISTS pt (
                 rid INT UNSIGNED NOT NULL,
                 uuid BINARY(16) NOT NULL,
                 state TINYINT UNSIGNED NOT NULL,
                 PRIMARY KEY (rid, uuid)
             ) ENGINE = InnoDB""")

        # The table "trans" stores information on committed transactions.
        q("""CREATE TABLE IF NOT EXISTS trans (
                 tid BIGINT UNSIGNED NOT NULL PRIMARY KEY,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL
             ) ENGINE = InnoDB""")

        # The table "obj" stores committed object data.
        q("""CREATE TABLE IF NOT EXISTS obj (
                 oid BIGINT UNSIGNED NOT NULL,
                 serial BIGINT UNSIGNED NOT NULL,
                 compression TINYINT UNSIGNED NOT NULL,
                 checksum INT UNSIGNED NOT NULL,
                 value MEDIUMBLOB NOT NULL,
                 PRIMARY KEY (oid, serial)
             ) ENGINE = InnoDB""")

        # The table "ttrans" stores information on uncommitted transactions.
        q("""CREATE TABLE IF NOT EXISTS ttrans (
                 tid BIGINT UNSIGNED NOT NULL,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL
             ) ENGINE = InnoDB""")

        # The table "tobj" stores uncommitted object data.
        q("""CREATE TABLE IF NOT EXISTS tobj (
                 oid BIGINT UNSIGNED NOT NULL,
                 serial BIGINT UNSIGNED NOT NULL,
                 compression TINYINT UNSIGNED NOT NULL,
                 checksum INT UNSIGNED NOT NULL,
                 value MEDIUMBLOB NOT NULL
             ) ENGINE = InnoDB""")

    def getConfiguration(self, key):
        """Return the value stored under key in "config", or None."""
        q = self.query
        e = self.escape
        key = e(str(key))
        r = q("""SELECT value FROM config WHERE name = '%s'""" % key)
        try:
            return r[0][0]
        except IndexError:
            return None

    def setConfiguration(self, key, value):
        """Store value under key in "config", replacing any previous value.

        This does not open or commit a transaction by itself.
        """
        q = self.query
        e = self.escape
        key = e(str(key))
        value = e(str(value))
        q("""REPLACE INTO config VALUES ('%s', '%s')""" % (key, value))

    def _commitConfiguration(self, key, value):
        """Store one configuration value in its own transaction."""
        self.begin()
        try:
            self.setConfiguration(key, value)
        except:
            self.rollback()
            raise
        self.commit()

    def getUUID(self):
        """Return the stored 'uuid' configuration value, or None."""
        return self.getConfiguration('uuid')

    def setUUID(self, uuid):
        """Store the 'uuid' configuration value and commit."""
        self._commitConfiguration('uuid', uuid)

    def getNumPartitions(self):
        """Return the stored number of partitions as an int, or None."""
        n = self.getConfiguration('partitions')
        if n is not None:
            return int(n)

    def setNumPartitions(self, num_partitions):
        """Store the number of partitions and commit."""
        self._commitConfiguration('partitions', num_partitions)

    def getNumReplicas(self):
        """Return the stored number of replicas as an int, or None."""
        n = self.getConfiguration('replicas')
        if n is not None:
            return int(n)

    def setNumReplicas(self, num_replicas):
        """Store the number of replicas and commit."""
        self._commitConfiguration('replicas', num_replicas)

    def getName(self):
        """Return the stored 'name' configuration value, or None."""
        return self.getConfiguration('name')

    def setName(self, name):
        """Store the 'name' configuration value and commit."""
        self._commitConfiguration('name', name)

    def getPTID(self):
        """Return the stored partition table ID, or INVALID_PTID if unset."""
        ptid = self.getConfiguration('ptid')
        if ptid is None:
            return INVALID_PTID
        return ptid

    def setPTID(self, ptid):
        """Store the partition table ID and commit."""
        self._commitConfiguration('ptid', ptid)

    def getPartitionTable(self):
        """Return all rows of "pt" as (rid, uuid, state) tuples."""
        q = self.query
        return q("""SELECT rid, uuid, state FROM pt""")

    def getLastOID(self, all = True):
        """Return the greatest stored OID, packed, or None if there is none.

        If all is true, uncommitted objects ("tobj") are considered too.
        """
        q = self.query
        self.begin()
        loid = q("""SELECT MAX(oid) FROM obj""")[0][0]
        if all:
            tmp_loid = q("""SELECT MAX(oid) FROM tobj""")[0][0]
            if loid is None or (tmp_loid is not None and loid < tmp_loid):
                loid = tmp_loid
        self.commit()
        if loid is not None:
            loid = p64(loid)
        return loid

    def getLastTID(self, all = True):
        """Return the greatest stored TID, packed, or None if there is none.

        If all is true, uncommitted transactions ("ttrans") and the serials
        of uncommitted objects ("tobj") are considered too.
        """
        # XXX this does not consider serials in obj.
        # I am not sure if this is really harmful. For safety,
        # check for tobj only at the moment. The reason why obj is
        # not tested is that it is too slow to get the max serial
        # from obj when it has a huge number of objects, because
        # serial is the second part of the primary key, so the index
        # is not used in this case. If doing it, it is better to
        # make another index for serial, but I doubt the cost increase
        # is worth.
        q = self.query
        self.begin()
        ltid = q("""SELECT MAX(tid) FROM trans""")[0][0]
        if all:
            tmp_ltid = q("""SELECT MAX(tid) FROM ttrans""")[0][0]
            if ltid is None or (tmp_ltid is not None and ltid < tmp_ltid):
                ltid = tmp_ltid
            tmp_serial = q("""SELECT MAX(serial) FROM tobj""")[0][0]
            if ltid is None or (tmp_serial is not None and ltid < tmp_serial):
                ltid = tmp_serial
        self.commit()
        if ltid is not None:
            ltid = p64(ltid)
        return ltid

    def getUnfinishedTIDList(self):
        """Return the packed TIDs found in "ttrans" or as serials in "tobj".

        Each TID appears once; the order of the list is unspecified.
        """
        q = self.query
        tid_set = set()
        self.begin()
        r = q("""SELECT tid FROM ttrans""")
        tid_set.update((p64(t[0]) for t in r))
        r = q("""SELECT serial FROM tobj""")
        self.commit()
        tid_set.update((p64(t[0]) for t in r))
        return list(tid_set)

    def objectPresent(self, oid, tid, all = True):
        """Return True if the object oid exists with the serial tid.

        If all is true, uncommitted objects ("tobj") are checked as well.
        """
        q = self.query
        oid = u64(oid)
        tid = u64(tid)
        self.begin()
        r = q("""SELECT oid FROM obj WHERE oid = %d AND serial = %d"""
              % (oid, tid))
        if not r and all:
            r = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d"""
                  % (oid, tid))
        self.commit()
        if r:
            return True
        return False

    def getObject(self, oid, tid = None, before_tid = None):
        """Return one committed revision of the object oid, or None.

        With tid, the revision whose serial is exactly tid is looked up.
        Otherwise, with before_tid, the latest revision whose serial is
        lower than before_tid. Otherwise the latest revision.

        The result is (serial, next_serial, compression, checksum, data),
        where serials are packed. next_serial is only looked up in the
        before_tid case (the first serial >= before_tid) and is None
        otherwise or when no such revision exists.
        """
        q = self.query
        oid = u64(oid)
        if tid is not None:
            r = q("""SELECT serial, compression, checksum, value FROM obj
                     WHERE oid = %d AND serial = %d"""
                  % (oid, u64(tid)))
        elif before_tid is not None:
            before_tid = u64(before_tid)
            r = q("""SELECT serial, compression, checksum, value FROM obj
                     WHERE oid = %d AND serial < %d
                     ORDER BY serial DESC LIMIT 1"""
                  % (oid, before_tid))
        else:
            # XXX I want to express "HAVING serial = MAX(serial)", but
            # MySQL does not use an index for a HAVING clause!
            r = q("""SELECT serial, compression, checksum, value FROM obj
                     WHERE oid = %d ORDER BY serial DESC LIMIT 1"""
                  % oid)
        if not r:
            return None
        serial, compression, checksum, data = r[0]

        next_serial = None
        if tid is None and before_tid is not None:
            next_r = q("""SELECT serial FROM obj
                          WHERE oid = %d AND serial >= %d
                          ORDER BY serial LIMIT 1"""
                       % (oid, before_tid))
            if next_r:
                next_serial = next_r[0][0]

        if serial is not None:
            serial = p64(serial)
        if next_serial is not None:
            next_serial = p64(next_serial)
        return serial, next_serial, compression, checksum, data

    def doSetPartitionTable(self, ptid, cell_list, reset):
        """Apply cell_list to "pt" and record ptid, in one transaction.

        cell_list is an iterable of (offset, uuid, state). Cells in
        DISCARDED_STATE are deleted, the others are inserted or have their
        state updated. If reset is true, "pt" is emptied first.
        """
        q = self.query
        e = self.escape
        self.begin()
        try:
            if reset:
                # NOTE(review): MySQL documents TRUNCATE TABLE as causing an
                # implicit commit, in which case the rollback below cannot
                # undo it - confirm for the deployed server version.
                q("""TRUNCATE pt""")
            for offset, uuid, state in cell_list:
                uuid = e(uuid)
                if state == DISCARDED_STATE:
                    q("""DELETE FROM pt WHERE rid = %d AND uuid = '%s'"""
                      % (offset, uuid))
                else:
                    q("""INSERT INTO pt VALUES (%d, '%s', %d)
                         ON DUPLICATE KEY UPDATE state = %d"""
                      % (offset, uuid, state, state))
            # Use REPLACE rather than UPDATE: an UPDATE silently stores
            # nothing when no 'ptid' row exists yet, leaving the stored
            # PTID out of sync with the partition table.
            ptid = e(ptid)
            q("""REPLACE INTO config VALUES ('ptid', '%s')""" % ptid)
        except:
            self.rollback()
            raise
        self.commit()

    def changePartitionTable(self, ptid, cell_list):
        """Update the stored partition table with cell_list."""
        self.doSetPartitionTable(ptid, cell_list, False)

    def setPartitionTable(self, ptid, cell_list):
        """Replace the stored partition table with cell_list."""
        self.doSetPartitionTable(ptid, cell_list, True)

    def dropUnfinishedData(self):
        """Empty the tables of uncommitted data ("tobj" and "ttrans")."""
        q = self.query
        self.begin()
        try:
            # NOTE(review): see doSetPartitionTable about TRUNCATE and
            # implicit commits in MySQL.
            q("""TRUNCATE tobj""")
            q("""TRUNCATE ttrans""")
        except:
            self.rollback()
            raise
        self.commit()

    def storeTransaction(self, tid, object_list, transaction, temporary = True):
        """Store the objects and the metadata of a transaction.

        object_list is an iterable of (oid, compression, checksum, data).
        transaction is None or (oid_list, user, desc, ext). Data goes to
        "tobj"/"ttrans" if temporary is true, to "obj"/"trans" otherwise.
        """
        q = self.query
        e = self.escape
        tid = u64(tid)

        if temporary:
            obj_table = 'tobj'
            trans_table = 'ttrans'
        else:
            obj_table = 'obj'
            trans_table = 'trans'

        self.begin()
        try:
            # XXX it might be more efficient to insert multiple objects
            # at a time, but it is potentially dangerous, because
            # a packet to MySQL can exceed the maximum packet size.
            # However, I do not think this would be a big problem, because
            # tobj has no index, so inserting one by one should not be
            # significantly different from inserting many at a time.
            for oid, compression, checksum, data in object_list:
                oid = u64(oid)
                data = e(data)
                q("""REPLACE INTO %s VALUES (%d, %d, %d, %d, '%s')"""
                  % (obj_table, oid, tid, compression, checksum, data))
            if transaction is not None:
                oid_list, user, desc, ext = transaction
                oids = e(''.join(oid_list))
                user = e(user)
                desc = e(desc)
                ext = e(ext)
                q("""REPLACE INTO %s VALUES (%d, '%s', '%s', '%s', '%s')"""
                  % (trans_table, tid, oids, user, desc, ext))
        except:
            self.rollback()
            raise
        self.commit()

    def finishTransaction(self, tid):
        """Move the data of the transaction tid to the committed tables."""
        q = self.query
        tid = u64(tid)
        self.begin()
        try:
            q("""INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = %d"""
              % tid)
            q("""DELETE FROM tobj WHERE serial = %d""" % tid)
            q("""INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = %d"""
              % tid)
            q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
        except:
            self.rollback()
            raise
        self.commit()

    def deleteTransaction(self, tid, all = False):
        """Delete the uncommitted data of the transaction tid.

        If all is true, its committed data is deleted as well.
        """
        q = self.query
        tid = u64(tid)
        self.begin()
        try:
            q("""DELETE FROM tobj WHERE serial = %d""" % tid)
            q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
            if all:
                # Note that this can be very slow.
                q("""DELETE FROM obj WHERE serial = %d""" % tid)
                q("""DELETE FROM trans WHERE tid = %d""" % tid)
        except:
            self.rollback()
            raise
        self.commit()

    def getTransaction(self, tid, all = False):
        """Return (oid_list, user, desc, ext) for the transaction tid.

        If all is true and the transaction is not in "trans", "ttrans" is
        searched too. Return None if the transaction is not found. Raise
        DatabaseFailure if the stored OID string is empty or its length is
        not a multiple of 8.
        """
        q = self.query
        tid = u64(tid)
        self.begin()
        r = q("""SELECT oids, user, description, ext FROM trans
                 WHERE tid = %d"""
              % tid)
        if not r and all:
            r = q("""SELECT oids, user, description, ext FROM ttrans
                     WHERE tid = %d"""
                  % tid)
        self.commit()
        if r:
            oids, user, desc, ext = r[0]
            if (len(oids) % 8) != 0 or len(oids) == 0:
                raise DatabaseFailure('invalid oids for tid %x' % tid)
            # OIDs are stored concatenated, 8 bytes each.
            oid_list = [oids[i:i+8] for i in xrange(0, len(oids), 8)]
            return oid_list, user, desc, ext
        return None

    def getOIDList(self, offset, length, num_partitions, partition_list):
        """Return packed OIDs of "obj" belonging to the given partitions.

        An OID belongs to the partition MOD(oid, num_partitions). OIDs are
        distinct, in descending order, limited to length entries starting
        at offset. An empty partition_list gives an empty list.
        """
        partitions = ','.join([str(p) for p in partition_list])
        if not partitions:
            # "IN ()" is not valid SQL, and nothing could match anyway.
            return []
        q = self.query
        r = q("""SELECT DISTINCT oid FROM obj WHERE MOD(oid,%d) in (%s)
                 ORDER BY oid DESC LIMIT %d,%d"""
              % (num_partitions, partitions, offset, length))
        return [p64(t[0]) for t in r]

    def getObjectHistory(self, oid, offset = 0, length = 1):
        """Return [(packed serial, value length)] for the object oid.

        Revisions are in descending serial order, limited to length entries
        starting at offset. Return None if there is no matching revision.
        """
        q = self.query
        oid = u64(oid)
        r = q("""SELECT serial, LENGTH(value) FROM obj WHERE oid = %d
                 ORDER BY serial DESC LIMIT %d,%d"""
              % (oid, offset, length))
        if r:
            return [(p64(serial), size) for serial, size in r]
        return None

    def getTIDList(self, offset, length, num_partitions, partition_list):
        """Return packed TIDs of "trans" belonging to the given partitions.

        A TID belongs to the partition MOD(tid, num_partitions). TIDs are
        in descending order, limited to length entries starting at offset.
        An empty partition_list gives an empty list.
        """
        partitions = ','.join([str(p) for p in partition_list])
        if not partitions:
            # "IN ()" is not valid SQL, and nothing could match anyway.
            return []
        q = self.query
        r = q("""SELECT tid FROM trans WHERE MOD(tid,%d) in (%s)
                 ORDER BY tid DESC LIMIT %d,%d"""
              % (num_partitions, partitions, offset, length))
        return [p64(t[0]) for t in r]

    def getTIDListPresent(self, tid_list):
        """Return the packed TIDs of tid_list which are found in "trans"."""
        tids = ','.join([str(u64(tid)) for tid in tid_list])
        if not tids:
            # "IN ()" is not valid SQL, and nothing could match anyway.
            return []
        q = self.query
        r = q("""SELECT tid FROM trans WHERE tid in (%s)""" % tids)
        return [p64(t[0]) for t in r]

    def getSerialListPresent(self, oid, serial_list):
        """Return the packed serials of serial_list found in "obj" for oid."""
        oid = u64(oid)
        serials = ','.join([str(u64(serial)) for serial in serial_list])
        if not serials:
            # "IN ()" is not valid SQL, and nothing could match anyway.
            return []
        q = self.query
        r = q("""SELECT serial FROM obj WHERE oid = %d AND serial in (%s)"""
              % (oid, serials))
        return [p64(t[0]) for t in r]