mysqldb.py 33.3 KB
Newer Older
Aurel's avatar
Aurel committed
1
#
Grégory Wisniewski's avatar
Grégory Wisniewski committed
2
# Copyright (C) 2006-2010  Nexedi SA
3
#
Aurel's avatar
Aurel committed
4 5 6 7
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
8
#
Aurel's avatar
Aurel committed
9 10 11 12 13 14 15
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
16
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
Aurel's avatar
Aurel committed
17

Yoshinori Okuji's avatar
Yoshinori Okuji committed
18 19 20
import MySQLdb
from MySQLdb import OperationalError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
21
import neo
22
from array import array
23
import string
Yoshinori Okuji's avatar
Yoshinori Okuji committed
24

25
from neo.storage.database import DatabaseManager
26
from neo.storage.database.manager import CreationUndone
27
from neo.exception import DatabaseFailure
28
from neo.protocol import CellStates, ZERO_OID, ZERO_TID
29
from neo import util
Yoshinori Okuji's avatar
Yoshinori Okuji committed
30

31 32
# When true, query() logs the beginning of every SQL statement and the
# commit/rollback helpers log their activity, all through neo.logging.debug.
LOG_QUERIES = False

33 34
def splitOIDField(tid, oids):
    """Split a string of concatenated OIDs into a list of 8-byte chunks.

    tid is only used to build the error message.
    Raises DatabaseFailure when oids is empty or when its length is not a
    multiple of 8.
    """
    size = len(oids)
    if not size or size % 8:
        raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid,
            size))
    return [oids[start:start + 8] for start in range(0, size, 8)]

43 44
class MySQLDatabaseManager(DatabaseManager):
    """This class manages a database on MySQL."""
Yoshinori Okuji's avatar
Yoshinori Okuji committed
45

46 47 48
    def __init__(self, database):
        """Parse the connection string and connect to MySQL.

        database: string following the pattern accepted by _parse, that is
        [user[:password]@]database
        """
        super(MySQLDatabaseManager, self).__init__()
        credentials = self._parse(database)
        self.user, self.passwd, self.db = credentials
        # No connection exists until _connect() has run.
        self.conn = None
        # Cache of configuration values, see getConfiguration().
        self._config = {}
        self._connect()
52 53 54 55 56 57 58 59 60 61 62

    def _parse(self, database):
        """ Get the database credentials (username, password, database) """
        # expected pattern : [user[:password]@]database
        username = None
        password = None
        if '@' in database:
            (username, database) = database.split('@')
            if ':' in username:
                (username, password) = username.split(':')
        return (username, password, database)
Yoshinori Okuji's avatar
Yoshinori Okuji committed
63

64 65 66
    def close(self):
        self.conn.close()

67
    def _connect(self):
        """Open a new MySQL connection and disable autocommit on it.

        The password is only passed to MySQLdb when one was provided.
        """
        params = {'db' : self.db, 'user' : self.user}
        passwd = self.passwd
        if passwd is not None:
            params['passwd'] = passwd
        neo.logging.info('connecting to MySQL on the database %s with user %s',
                     self.db, self.user)
        connection = MySQLdb.connect(**params)
        self.conn = connection
        connection.autocommit(False)

76
    def _begin(self):
77 78
        self.query("""BEGIN""")

79
    def _commit(self):
        """Commit on the MySQL connection, logging it if LOG_QUERIES is set."""
        connection = self.conn
        if LOG_QUERIES:
            neo.logging.debug('committing...')
        connection.commit()

84
    def _rollback(self):
        """Roll back on the MySQL connection, logging it if LOG_QUERIES is set."""
        connection = self.conn
        if LOG_QUERIES:
            neo.logging.debug('aborting...')
        connection.rollback()
Yoshinori Okuji's avatar
Yoshinori Okuji committed
88 89 90 91 92

    def query(self, query):
        """Query data from a database."""
        conn = self.conn
        try:
93 94 95 96 97 98 99
            if LOG_QUERIES:
                printable_char_list = []
                for c in query.split('\n', 1)[0][:70]:
                    if c not in string.printable or c in '\t\x0b\x0c\r':
                        c = '\\x%02x' % ord(c)
                    printable_char_list.append(c)
                query_part = ''.join(printable_char_list)
100
                neo.logging.debug('querying %s...', query_part)
101

Yoshinori Okuji's avatar
Yoshinori Okuji committed
102 103 104
            conn.query(query)
            r = conn.store_result()
            if r is not None:
105 106 107 108 109 110 111 112 113 114
                new_r = []
                for row in r.fetch_row(r.num_rows()):
                    new_row = []
                    for d in row:
                        if isinstance(d, array):
                            d = d.tostring()
                        new_row.append(d)
                    new_r.append(tuple(new_row))
                r = tuple(new_r)

Yoshinori Okuji's avatar
Yoshinori Okuji committed
115 116
        except OperationalError, m:
            if m[0] in (SERVER_GONE_ERROR, SERVER_LOST):
117
                neo.logging.info('the MySQL server is gone; reconnecting')
118
                self._connect()
Yoshinori Okuji's avatar
Yoshinori Okuji committed
119
                return self.query(query)
120
            raise DatabaseFailure('MySQL error %d: %s' % (m[0], m[1]))
Yoshinori Okuji's avatar
Yoshinori Okuji committed
121
        return r
122

123 124 125
    def escape(self, s):
        """Escape special characters in a string."""
        return self.conn.escape_string(s)
126 127

    def setup(self, reset = 0):
        """Create the tables used by this manager if they do not exist yet.

        The configuration cache is cleared first.
        reset: when true, all the tables are dropped before being created
        again.
        """
        self._config.clear()
        q = self.query

        if reset:
            q('DROP TABLE IF EXISTS config, pt, trans, obj, obj_short, '
                'ttrans, tobj')

        # The table "config" stores configuration parameters which affect the
        # persistent data.
        q("""CREATE TABLE IF NOT EXISTS config (
                 name VARBINARY(16) NOT NULL PRIMARY KEY,
                 value VARBINARY(255) NULL
             ) ENGINE = InnoDB""")

        # The table "pt" stores a partition table.
        q("""CREATE TABLE IF NOT EXISTS pt (
                 rid INT UNSIGNED NOT NULL,
                 uuid CHAR(32) NOT NULL,
                 state TINYINT UNSIGNED NOT NULL,
                 PRIMARY KEY (rid, uuid)
             ) ENGINE = InnoDB""")

        # The table "trans" stores information on committed transactions.
        q("""CREATE TABLE IF NOT EXISTS trans (
                 partition SMALLINT UNSIGNED NOT NULL,
                 tid BIGINT UNSIGNED NOT NULL,
                 packed BOOLEAN NOT NULL,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL,
                 PRIMARY KEY (partition, tid)
             ) ENGINE = InnoDB""")

        # The table "obj" stores committed object data.
        q("""CREATE TABLE IF NOT EXISTS obj (
                 partition SMALLINT UNSIGNED NOT NULL,
                 oid BIGINT UNSIGNED NOT NULL,
                 serial BIGINT UNSIGNED NOT NULL,
                 compression TINYINT UNSIGNED NULL,
                 checksum INT UNSIGNED NULL,
                 value LONGBLOB NULL,
                 value_serial BIGINT UNSIGNED NULL,
                 PRIMARY KEY (partition, oid, serial)
             ) ENGINE = InnoDB""")

        # The table "obj_short" contains columns which are accessed in queries
        # which don't need to access object data. This is needed because InnoDB
        # loads a whole row even when it only needs columns in primary key.
        q('CREATE TABLE IF NOT EXISTS obj_short ('
            'partition SMALLINT UNSIGNED NOT NULL,'
            'oid BIGINT UNSIGNED NOT NULL,'
            'serial BIGINT UNSIGNED NOT NULL,'
            'PRIMARY KEY (partition, oid, serial)'
            ') ENGINE = InnoDB')

        # The table "ttrans" stores information on uncommitted transactions.
        # It has the same columns as "trans" but no primary key.
        q("""CREATE TABLE IF NOT EXISTS ttrans (
                 partition SMALLINT UNSIGNED NOT NULL,
                 tid BIGINT UNSIGNED NOT NULL,
                 packed BOOLEAN NOT NULL,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL
             ) ENGINE = InnoDB""")

        # The table "tobj" stores uncommitted object data.
        # It has the same columns as "obj" but no primary key.
        q("""CREATE TABLE IF NOT EXISTS tobj (
                 partition SMALLINT UNSIGNED NOT NULL,
                 oid BIGINT UNSIGNED NOT NULL,
                 serial BIGINT UNSIGNED NOT NULL,
                 compression TINYINT UNSIGNED NULL,
                 checksum INT UNSIGNED NULL,
                 value LONGBLOB NULL,
                 value_serial BIGINT UNSIGNED NULL
             ) ENGINE = InnoDB""")

206 207 208 209 210 211 212 213 214 215
    def objQuery(self, query):
        """
        Execute given query for both obj and obj_short tables.
        query: query string, must contain "%(table)s" where obj table name is
        needed.
        """
        q = self.query
        for table in ('obj', 'obj_short'):
          q(query % {'table': table})

216
    def getConfiguration(self, key):
217 218
        if key in self._config:
            return self._config[key]
219 220
        q = self.query
        e = self.escape
221
        sql_key = e(str(key))
222
        try:
223
            r = q("SELECT value FROM config WHERE name = '%s'" % sql_key)[0][0]
224
        except IndexError:
225
            raise KeyError, key
226 227
        self._config[key] = r
        return r
228

229 230 231
    def _setConfiguration(self, key, value):
        q = self.query
        e = self.escape
232
        self._config[key] = value
233
        key = e(str(key))
234 235 236 237 238
        if value is None:
            value = 'NULL'
        else:
            value = "'%s'" % (e(str(value)), )
        q("""REPLACE INTO config VALUES ('%s', %s)""" % (key, value))
239

240 241 242 243 244 245 246 247 248 249
    def _setPackTID(self, tid):
        self._setConfiguration('_pack_tid', tid)

    def _getPackTID(self):
        try:
            result = int(self.getConfiguration('_pack_tid'))
        except KeyError:
            result = -1
        return result

250 251
    def getPartitionTable(self):
        """Return the content of the "pt" table.

        The result is a list of (rid, uuid, state) tuples where the stored
        uuid has been converted with util.bin().
        """
        rows = self.query("""SELECT rid, uuid, state FROM pt""")
        return [(rid, util.bin(uuid), state) for rid, uuid, state in rows]
258 259 260 261 262 263 264 265 266 267 268 269 270

    def getLastTID(self, all = True):
        # XXX this does not consider serials in obj.
        # I am not sure if this is really harmful. For safety,
        # check for tobj only at the moment. The reason why obj is
        # not tested is that it is too slow to get the max serial
        # from obj when it has a huge number of objects, because
        # serial is the second part of the primary key, so the index
        # is not used in this case. If doing it, it is better to
        # make another index for serial, but I doubt the cost increase
        # is worth.
        q = self.query
        self.begin()
271 272
        ltid = q("SELECT MAX(value) FROM (SELECT MAX(tid) AS value FROM trans "
                    "GROUP BY partition) AS foo")[0][0]
273 274 275 276 277 278 279 280
        if all:
            tmp_ltid = q("""SELECT MAX(tid) FROM ttrans""")[0][0]
            if ltid is None or (tmp_ltid is not None and ltid < tmp_ltid):
                ltid = tmp_ltid
            tmp_serial = q("""SELECT MAX(serial) FROM tobj""")[0][0]
            if ltid is None or (tmp_serial is not None and ltid < tmp_serial):
                ltid = tmp_serial
        self.commit()
281
        if ltid is not None:
282
            ltid = util.p64(ltid)
283 284 285 286 287 288 289
        return ltid

    def getUnfinishedTIDList(self):
        """Return the distinct TIDs found in "ttrans" and "tobj".

        Each TID is packed with util.p64(); the result is a list without
        duplicates.
        """
        q = self.query
        p64 = util.p64
        self.begin()
        tid_set = set([p64(tid) for (tid, ) in q("""SELECT tid FROM ttrans""")])
        tobj_rows = q("""SELECT serial FROM tobj""")
        self.commit()
        tid_set.update([p64(serial) for (serial, ) in tobj_rows])
        return list(tid_set)

    def objectPresent(self, oid, tid, all = True):
        """Tell whether a revision of oid exists at serial tid.

        "obj_short" is searched first; when nothing is found there and all is
        true, "tobj" is searched too. Returns True or False.
        """
        q = self.query
        oid = util.u64(oid)
        tid = util.u64(tid)
        partition = self._getPartition(oid)
        self.begin()
        found = q("SELECT oid FROM obj_short WHERE partition=%d AND oid=%d AND "
              "serial=%d" % (partition, oid, tid))
        if all and not found:
            found = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d""" \
                    % (oid, tid))
        self.commit()
        return bool(found)

312 313 314 315 316 317 318
    def _getObjectData(self, oid, value_serial, tid):
        if value_serial is None:
            raise CreationUndone
        if value_serial >= tid:
            raise ValueError, "Incorrect value reference found for " \
                "oid %d at tid %d: reference = %d" % (oid, value_serial, tid)
        r = self.query("""SELECT compression, checksum, value, """ \
319 320 321 322 323 324
            """value_serial FROM obj WHERE partition = %(partition)d """
            """AND oid = %(oid)d AND serial = %(serial)d""" % {
                'partition': self._getPartition(oid),
                'oid': oid,
                'serial': value_serial,
            })
325 326
        compression, checksum, value, next_value_serial = r[0]
        if value is None:
327
            neo.logging.info("Multiple levels of indirection when " \
328 329 330 331 332 333 334
                "searching for object data for oid %d at tid %d. This " \
                "causes suboptimal performance." % (oid, value_serial))
            value_serial, compression, checksum, value = self._getObjectData(
                oid, next_value_serial, value_serial)
        return value_serial, compression, checksum, value

    def _getObject(self, oid, tid=None, before_tid=None):
335
        q = self.query
336
        partition = self._getPartition(oid)
337
        if tid is not None:
338 339
            r = q("""SELECT serial, compression, checksum, value, value_serial
                        FROM obj
340 341
                        WHERE partition = %d AND oid = %d AND serial = %d""" \
                    % (partition, oid, tid))
342
            try:
343
                serial, compression, checksum, data, value_serial = r[0]
344 345 346
                next_serial = None
            except IndexError:
                return None
347
        elif before_tid is not None:
348 349
            r = q("""SELECT serial, compression, checksum, value, value_serial
                        FROM obj
350 351
                        WHERE partition = %d
                        AND oid = %d AND serial < %d
352
                        ORDER BY serial DESC LIMIT 1""" \
353
                    % (partition, oid, before_tid))
354
            try:
355
                serial, compression, checksum, data, value_serial = r[0]
356 357
            except IndexError:
                return None
358
            r = q("""SELECT serial FROM obj_short
359 360
                        WHERE partition = %d
                        AND oid = %d AND serial >= %d
361
                        ORDER BY serial LIMIT 1""" \
362
                    % (partition, oid, before_tid))
363 364 365 366
            try:
                next_serial = r[0][0]
            except IndexError:
                next_serial = None
367 368 369
        else:
            # XXX I want to express "HAVING serial = MAX(serial)", but
            # MySQL does not use an index for a HAVING clause!
370 371
            r = q("""SELECT serial, compression, checksum, value, value_serial
                        FROM obj
372 373 374
                        WHERE partition = %d AND oid = %d
                        ORDER BY serial DESC LIMIT 1""" \
                    % (partition, oid))
375
            try:
376
                serial, compression, checksum, data, value_serial = r[0]
377 378 379 380
                next_serial = None
            except IndexError:
                return None

381 382
        return serial, next_serial, compression, checksum, data, value_serial

383 384 385 386 387 388 389 390
    def doSetPartitionTable(self, ptid, cell_list, reset):
        """Update the "pt" table and the partition table ID in one transaction.

        ptid: partition table ID, handed to setPTID().
        cell_list: iterable of (offset, uuid, state) tuples; cells in the
        DISCARDED state are deleted, the others are inserted or updated.
        reset: when true, the table is truncated first.

        Any error triggers a rollback and is raised again.
        """
        q = self.query
        e = self.escape
        self.begin()
        try:
            if reset:
                q("""TRUNCATE pt""")
            for offset, uuid, state in cell_list:
                uuid = e(util.dump(uuid))
                # TODO: this logic should move out of database manager
                # add 'dropCells(cell_list)' to API and use one query
                if state == CellStates.DISCARDED:
                    sql = """DELETE FROM pt WHERE rid = %d AND uuid = '%s'""" \
                            % (offset, uuid)
                else:
                    sql = """INSERT INTO pt VALUES (%d, '%s', %d)
                            ON DUPLICATE KEY UPDATE state = %d""" \
                                    % (offset, uuid, state, state)
                q(sql)
            self.setPTID(ptid)
        except:
            self.rollback()
            raise
        else:
            self.commit()

    def changePartitionTable(self, ptid, cell_list):
408
        self.doSetPartitionTable(ptid, cell_list, False)
409 410

    def setPartitionTable(self, ptid, cell_list):
411
        self.doSetPartitionTable(ptid, cell_list, True)
412

413
    def dropPartitions(self, num_partitions, offset_list):
414
        q = self.query
415 416
        e = self.escape
        offset_list = ', '.join((str(i) for i in offset_list))
417 418
        self.begin()
        try:
419
            self.objQuery('DELETE FROM %%(table)s WHERE partition IN (%s)' %
420 421 422
                (offset_list, ))
            q("""DELETE FROM trans WHERE partition IN (%s)""" %
                (offset_list, ))
423 424
        except:
            self.rollback()
425
            raise
426 427
        self.commit()

428 429 430 431 432 433 434 435 436 437 438
    def dropUnfinishedData(self):
        q = self.query
        self.begin()
        try:
            q("""TRUNCATE tobj""")
            q("""TRUNCATE ttrans""")
        except:
            self.rollback()
            raise
        self.commit()

439
    def storeTransaction(self, tid, object_list, transaction, temporary = True):
        """Store object revisions and transaction metadata for tid.

        tid: transaction ID, converted with util.u64().
        object_list: iterable of (oid, compression, checksum, data,
        value_serial) tuples. A revision carries either data or a
        value_serial reference (or neither), never both.
        transaction: None, or a tuple (oid_list, user, desc, ext, packed).
        temporary: when true, rows go to "tobj"/"ttrans"; otherwise they go
        to "obj"/"trans", and "obj_short" is updated along with "obj".

        Raises ValueError when an object has both data and a value_serial.
        Any error triggers a rollback and is raised again.
        """
        q = self.query
        e = self.escape
        u64 = util.u64
        tid = u64(tid)

        if temporary:
            obj_table, trans_table = 'tobj', 'ttrans'
        else:
            obj_table, trans_table = 'obj', 'trans'
        update_obj_short = obj_table == 'obj'

        self.begin()
        try:
            for oid, compression, checksum, data, value_serial in object_list:
                oid = u64(oid)
                if data is None:
                    compression = checksum = data = 'NULL'
                else:
                    # TODO: unit-test this raise
                    if value_serial is not None:
                        raise ValueError('Either data or value_serial '
                            'must be None (oid %d, tid %d)' % (oid, tid))
                    compression = '%d' % (compression, )
                    checksum = '%d' % (checksum, )
                    data = "'%s'" % (e(data), )
                if value_serial is None:
                    value_serial = 'NULL'
                else:
                    value_serial = '%d' % (u64(value_serial), )
                partition = self._getPartition(oid)
                q("""REPLACE INTO %s VALUES (%d, %d, %d, %s, %s, %s, %s)""" \
                    % (obj_table, partition, oid, tid, compression, checksum,
                        data, value_serial))
                if update_obj_short:
                    # Keep obj_short in sync with obj.
                    q('REPLACE INTO obj_short VALUES (%d, %d, %d)' % (
                        partition, oid, tid))

            if transaction is not None:
                oid_list, user, desc, ext, packed = transaction
                packed = int(bool(packed))
                oids = e(''.join(oid_list))
                user = e(user)
                desc = e(desc)
                ext = e(ext)
                partition = self._getPartition(tid)
                q("REPLACE INTO %s VALUES (%d, %d, %i, '%s', '%s', '%s', '%s')"
                    % (trans_table, partition, tid, packed, oids, user, desc,
                        ext))
        except:
            self.rollback()
            raise
        else:
            self.commit()

495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513
    def _getDataTIDFromData(self, oid, result):
        tid, next_serial, compression, checksum, data, value_serial = result
        if data is None:
            try:
                data_serial = self._getObjectData(oid, value_serial, tid)[0]
            except CreationUndone:
                data_serial = None
        else:
            data_serial = tid
        return tid, data_serial

    def _getDataTID(self, oid, tid=None, before_tid=None):
        result = self._getObject(oid, tid=tid, before_tid=before_tid)
        if result is None:
            result = (None, None)
        else:
            result = self._getDataTIDFromData(oid, result)
        return result

514 515
    def finishTransaction(self, tid):
        """Move the rows of transaction tid from temporary to final tables.

        "tobj" rows are copied to "obj" and "obj_short", "ttrans" rows to
        "trans", and the temporary rows are deleted, all in one transaction.
        Any error triggers a rollback and is raised again.
        """
        q = self.query
        tid = util.u64(tid)
        # Statements are executed in this order.
        statements = (
            """INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = %d""",
            'INSERT INTO obj_short SELECT partition, oid, serial FROM tobj'
                ' WHERE tobj.serial = %d',
            """DELETE FROM tobj WHERE serial = %d""",
            """INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = %d""",
            """DELETE FROM ttrans WHERE tid = %d""",
        )
        self.begin()
        try:
            for statement in statements:
                q(statement % (tid, ))
        except:
            self.rollback()
            raise
        else:
            self.commit()

532
    def deleteTransaction(self, tid, oid_list=()):
        """Delete transaction tid from temporary and final tables.

        tid: transaction ID, converted with util.u64().
        oid_list: OIDs whose revision at tid must be removed from "obj" and
        "obj_short".

        Any error triggers a rollback and is raised again.
        """
        q = self.query
        objQuery = self.objQuery
        u64 = util.u64
        getPartition = self._getPartition
        tid = u64(tid)
        obj_delete = ('DELETE FROM %%(table)s WHERE '
            'partition=%(partition)d '
            'AND oid = %(oid)d AND serial = %(serial)d')
        self.begin()
        try:
            q("""DELETE FROM tobj WHERE serial = %d""" % tid)
            q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
            q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" %
                (getPartition(tid), tid))
            # delete from obj using indexes
            for oid in oid_list:
                oid = u64(oid)
                objQuery(obj_delete % {
                    'partition': getPartition(oid),
                    'oid': oid,
                    'serial': tid,
                })
        except:
            self.rollback()
            raise
        else:
            self.commit()

560 561 562 563 564 565 566 567 568 569 570 571 572
    def deleteTransactionsAbove(self, num_partitions, partition, tid):
        """Delete from "trans" the rows of partition with a tid above tid.

        num_partitions is not used by this implementation.
        Any error triggers a rollback and is raised again.
        """
        self.begin()
        try:
            params = {
                'partition': partition,
                'tid': util.u64(tid),
            }
            self.query('DELETE FROM trans WHERE partition=%(partition)d AND '
              'tid > %(tid)d' % params)
        except:
            self.rollback()
            raise
        else:
            self.commit()

573 574
    def deleteObject(self, oid, serial=None):
        """Delete revisions of oid from "obj" and "obj_short".

        serial: when given, only the revision at this serial is deleted;
        otherwise all the revisions of oid are.

        Any error triggers a rollback and is raised again.
        """
        u64 = util.u64
        oid = u64(oid)
        params = {
            'partition': self._getPartition(oid),
            'oid': oid,
        }
        conditions = ['partition = %(partition)d AND oid = %(oid)d']
        if serial is not None:
            params['serial'] = u64(serial)
            conditions.append('serial = %(serial)d')
        statement = 'DELETE FROM %%(table)s WHERE ' + ' AND '.join(conditions)
        self.begin()
        try:
            self.objQuery(statement % params)
        except:
            self.rollback()
            raise
        else:
            self.commit()

593 594 595 596
    def deleteObjectsAbove(self, num_partitions, partition, oid, serial):
        """Delete the object revisions of partition located above (oid, serial).

        A revision is deleted when its oid is greater than oid, or when it
        has the same oid and a serial greater than serial.
        num_partitions is not used by this implementation.

        The deletion goes through objQuery(), like the other methods deleting
        from "obj", so that "obj_short" does not keep rows for revisions
        which no longer exist in "obj".

        Any error triggers a rollback and is raised again.
        """
        u64 = util.u64
        self.begin()
        try:
            self.objQuery('DELETE FROM %%(table)s WHERE '
              'partition=%(partition)d AND ('
              'oid > %(oid)d OR (oid = %(oid)d AND serial > %(serial)d))' % {
                'partition': partition,
                'oid': u64(oid),
                'serial': u64(serial),
            })
        except:
            self.rollback()
            raise
        self.commit()

608 609
    def getTransaction(self, tid, all = False):
        """Return the metadata of transaction tid, or None if it is unknown.

        "trans" is searched first; when nothing is found there and all is
        true, "ttrans" is searched too.

        The result is a tuple (oid_list, user, desc, ext, packed) where
        oid_list comes from splitOIDField() and packed is a boolean.
        """
        q = self.query
        tid = util.u64(tid)
        self.begin()
        rows = q("""SELECT oids, user, description, ext, packed FROM trans
                    WHERE partition = %d AND tid = %d""" \
                % (self._getPartition(tid), tid))
        if all and not rows:
            rows = q("""SELECT oids, user, description, ext, packed FROM ttrans
                        WHERE tid = %d""" \
                    % tid)
        self.commit()
        if not rows:
            return None
        oids, user, desc, ext, packed = rows[0]
        return splitOIDField(tid, oids), user, desc, ext, bool(packed)

626 627 628 629
    def _getObjectLength(self, oid, value_serial):
        """Return the length of the value a revision refers to.

        oid and value_serial are integers; the row of "obj" at value_serial
        is read, and its own reference is followed recursively while the row
        holds no value.

        Raises CreationUndone when the reference to follow is None.
        """
        if value_serial is None:
            raise CreationUndone
        r = self.query("""SELECT LENGTH(value), value_serial FROM obj """ \
            """WHERE partition = %d AND oid = %d AND serial = %d""" %
            (self._getPartition(oid), oid, value_serial))
        # Keep value_serial untouched: it is still needed for the log
        # message, and the row's own reference may be None, which "%d"
        # cannot format.
        length, next_value_serial = r[0]
        if length is None:
            neo.logging.info("Multiple levels of indirection when "
                "searching for object data for oid %d at tid %d. This "
                "causes suboptimal performance.", oid, value_serial)
            length = self._getObjectLength(oid, next_value_serial)
        return length

640
    def getObjectHistory(self, oid, offset = 0, length = 1):
        """Return revisions of oid, most recent first.

        Only revisions whose serial is not lower than the pack TID are
        considered; offset and length select a window in them.

        Returns a list of (serial, length) tuples, serial being packed with
        util.p64(), or None when no revision is found. The length of a
        revision which only references another one is obtained from
        _getObjectLength(), and is 0 when that raises CreationUndone.
        """
        # FIXME: This method doesn't take client's current transaction id as
        # parameter, which means it can return transactions in the future of
        # client's transaction.
        q = self.query
        oid = util.u64(oid)
        p64 = util.p64
        pack_tid = self._getPackTID()
        rows = q("""SELECT serial, LENGTH(value), value_serial FROM obj
                    WHERE partition = %d AND oid = %d AND serial >= %d
                    ORDER BY serial DESC LIMIT %d, %d""" \
                % (self._getPartition(oid), oid, pack_tid, offset, length))
        if not rows:
            return None
        history = []
        for serial, size, value_serial in rows:
            if size is None:
                try:
                    size = self._getObjectLength(oid, value_serial)
                except CreationUndone:
                    size = 0
            history.append((p64(serial), size))
        return history
664

665 666
    def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
            num_partitions, partition):
        """Return serials of objects of partition, starting at a given point.

        Rows of "obj_short" are read in (oid, serial) order, starting at
        (min_oid, min_serial), ignoring serials above max_serial, and at
        most length rows are used. num_partitions is not used by this
        implementation.

        Returns a dict mapping each oid to the list of its serials, oids and
        serials being packed with util.p64().
        """
        q = self.query
        u64 = util.u64
        p64 = util.p64
        params = {
            'min_oid': u64(min_oid),
            'min_serial': u64(min_serial),
            'max_serial': u64(max_serial),
            'length': length,
            'partition': partition,
        }
        rows = q('SELECT oid, serial FROM obj_short '
                'WHERE partition = %(partition)s '
                'AND serial <= %(max_serial)d '
                'AND ((oid = %(min_oid)d AND serial >= %(min_serial)d) '
                'OR oid > %(min_oid)d) '
                'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % params)
        serials_by_oid = {}
        for oid, serial in rows:
            serials_by_oid.setdefault(oid, []).append(p64(serial))
        return dict([(p64(oid), serial_list)
            for oid, serial_list in serials_by_oid.items()])
693

694
    def getTIDList(self, offset, length, num_partitions, partition_list):
695
        q = self.query
696
        r = q("""SELECT tid FROM trans WHERE partition in (%s)
697
                    ORDER BY tid DESC LIMIT %d,%d""" \
698
                % (','.join([str(p) for p in partition_list]), offset, length))
699 700
        return [util.p64(t[0]) for t in r]

701
    def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions,
            partition):
        """Return TIDs of partition between min_tid and max_tid, inclusive.

        At most length TIDs are returned, in ascending order, packed with
        util.p64(). num_partitions is not used by this implementation.
        """
        u64 = util.u64
        p64 = util.p64
        params = {
            'partition': partition,
            'min_tid': u64(min_tid),
            'max_tid': u64(max_tid),
            'length': length,
        }
        rows = self.query("""SELECT tid FROM trans
                    WHERE partition = %(partition)d
                    AND tid >= %(min_tid)d AND tid <= %(max_tid)d
                    ORDER BY tid ASC LIMIT %(length)d""" % params)
        return [p64(tid) for (tid, ) in rows]
718

719 720 721 722
    def _updatePackFuture(self, oid, orig_serial, max_serial,
            updateObjectDataForPack):
        """Preserve the data of a revision that is about to be deleted.

        oid, orig_serial, max_serial -- integers (they are formatted with
            %d into the queries): the object, the revision being dropped and
            the lowest serial whose references to that revision are updated.
        updateObjectDataForPack -- callback, invoked exactly once as
            updateObjectDataForPack(oid, orig_serial, value_serial,
            getObjectData) where oid and orig_serial are converted with
            util.p64, value_serial is the converted serial now holding the
            data (or None when no reference was found), and getObjectData
            returns the (compression, checksum, value) row of the original
            revision. getObjectData may only be called when value_serial
            is None.
        """
        q = self.query
        p64 = util.p64
        # Every query below targets the same oid, so the partition is
        # computed once (assumes _getPartition is a pure function of oid).
        partition = self._getPartition(oid)
        # Before deleting this objects revision, see if there is any
        # transaction referencing its value at max_serial or above.
        # If there is, copy value to the first future transaction. Any further
        # reference is just updated to point to the new data location.
        value_serial = None
        for table in ('obj', 'tobj'):
            for (serial, ) in q('SELECT serial FROM %(table)s WHERE '
                    'partition = %(partition)d AND oid = %(oid)d '
                    'AND serial >= %(max_serial)d AND '
                    'value_serial = %(orig_serial)d ORDER BY serial ASC' % {
                        'table': table,
                        'partition': partition,
                        'oid': oid,
                        'orig_serial': orig_serial,
                        'max_serial': max_serial,
                    }):
                if value_serial is None:
                    # First found, copy data to it and mark its serial for
                    # future reference.
                    value_serial = serial
                    q('REPLACE INTO %(table)s (partition, oid, serial, compression, '
                        'checksum, value, value_serial) SELECT partition, oid, '
                        '%(serial)d, compression, checksum, value, NULL FROM '
                        'obj WHERE partition = %(partition)d AND oid = %(oid)d '
                        'AND serial = %(orig_serial)d'
                        % {
                            'table': table,
                            'partition': partition,
                            'oid': oid,
                            'serial': serial,
                            'orig_serial': orig_serial,
                    })
                else:
                    # Later references only need to point at the new
                    # location of the data.
                    q('REPLACE INTO %(table)s (partition, oid, serial, value_serial) '
                        'VALUES (%(partition)d, %(oid)d, %(serial)d, '
                        '%(value_serial)d)' % {
                            'table': table,
                            'partition': partition,
                            'oid': oid,
                            'serial': serial,
                            'value_serial': value_serial,
                    })
        def getObjectData():
            # Reads value_serial from the enclosing scope at call time: the
            # data may only be fetched when it was not copied elsewhere.
            assert value_serial is None
            return q('SELECT compression, checksum, value FROM obj WHERE '
                'partition = %(partition)d AND oid = %(oid)d '
                'AND serial = %(orig_serial)d' % {
                    'partition': partition,
                    'oid': oid,
                    'orig_serial': orig_serial,
                })[0]
        # Compare against None (as done in the loop above) rather than
        # relying on truthiness, so that a serial of 0 is converted too.
        if value_serial is not None:
            value_serial = p64(value_serial)
        updateObjectDataForPack(p64(oid), p64(orig_serial), value_serial,
            getObjectData)

    def pack(self, tid, updateObjectDataForPack):
        """Delete object revisions made obsolete by a pack at `tid`.

        tid is converted with util.u64. Everything runs between self.begin()
        and self.commit(); on any exception the transaction is rolled back
        and the exception is re-raised.

        For every oid having revisions at or before tid in obj_short, all
        revisions older than the latest one are deleted. When the latest one
        has a zero-length value, it is deleted as well. Before each
        deletion, self._updatePackFuture is called for the revision with
        updateObjectDataForPack passed through.
        """
        # TODO: unit test (along with updatePackFuture)
        run_query = self.query
        # NOTE(review): objQuery presumably fills in the %(table)s
        # placeholder left in the DELETE statement - confirm.
        run_obj_query = self.objQuery
        pack_tid = util.u64(tid)
        relocate_data = self._updatePackFuture
        partition_of = self._getPartition
        self.begin()
        try:
            self._setPackTID(pack_tid)
            candidate_list = run_query('SELECT COUNT(*) - 1, oid, '
                'MAX(serial) FROM obj_short WHERE serial <= %(tid)d '
                'GROUP BY oid' % {'tid': pack_tid})
            for obsolete_count, oid, keep_serial in candidate_list:
                latest_length = run_query(
                    'SELECT LENGTH(value) FROM obj WHERE partition ='
                    '%(partition)s AND oid = %(oid)d AND '
                    'serial = %(max_serial)d' % {
                        'oid': oid,
                        'partition': partition_of(oid),
                        'max_serial': keep_serial,
                    })[0][0]
                if latest_length == 0:
                    # The latest revision itself goes away: count it and
                    # move the (exclusive) upper bound past it.
                    obsolete_count += 1
                    keep_serial += 1
                if not obsolete_count:
                    # Nothing to delete for this object
                    continue
                obsolete_list = run_query(
                    'SELECT serial FROM obj_short WHERE '
                    'partition=%(partition)d AND oid=%(oid)d AND '
                    'serial < %(max_serial)d' % {
                        'oid': oid,
                        'partition': partition_of(oid),
                        'max_serial': keep_serial,
                    })
                for (serial, ) in obsolete_list:
                    # Data still referenced by later revisions must be
                    # moved before the row holding it is removed.
                    relocate_data(oid, serial, keep_serial,
                        updateObjectDataForPack)
                    run_obj_query('DELETE FROM %%(table)s WHERE '
                        'partition=%(partition)d '
                        'AND oid=%(oid)d AND serial=%(serial)d' % {
                            'partition': partition_of(oid),
                            'oid': oid,
                            'serial': serial
                    })
        except:
            # Deliberately catches everything: the transaction must not be
            # left open, and the exception is propagated unchanged.
            self.rollback()
            raise
        else:
            self.commit()
824 825 826 827 828 829
  
    def checkTIDRange(self, min_tid, length, num_partitions, partition):
        """Summarise a range of tids of one partition.

        The range holds the first `length` tids of `partition` that are
        greater than or equal to min_tid (converted with util.u64), in
        ascending order. num_partitions is accepted but not used here.

        Returns (count, tid_checksum, max_tid): the number of tids found,
        the XOR of their values and the highest one converted with util.p64.
        When nothing is found, the result is (count, 0, ZERO_TID).
        """
        # XXX: XOR is a lame checksum
        run_query = self.query
        params = {
            'partition': partition,
            'min_tid': util.u64(min_tid),
            'length': length,
        }
        summary = run_query('SELECT COUNT(*), '
            'BIT_XOR(tid), MAX(tid) FROM ('
              'SELECT tid FROM trans '
              'WHERE partition = %(partition)s '
              'AND tid >= %(min_tid)d '
              'ORDER BY tid ASC LIMIT %(length)d'
            ') AS foo' % params)[0]
        count, tid_checksum, max_tid = summary
        if count:
            return count, tid_checksum, util.p64(max_tid)
        # Empty range: ignore what the aggregates returned.
        return count, 0, ZERO_TID

    def checkSerialRange(self, min_oid, min_serial, length, num_partitions,
            partition):
        """Summarise a range of (oid, serial) pairs of one partition.

        Rows are read from obj_short for `partition`, ordered by oid then
        serial, starting at (min_oid, min_serial) included and limited to
        `length` rows. min_oid and min_serial are converted with util.u64.
        num_partitions is accepted but not used here.

        Returns (count, oid_checksum, max_oid, serial_checksum, max_serial):
        the number of rows, the XOR of all oids, the oid of the last row,
        the XOR of all serials and the serial of the last row. The last
        oid and serial are converted with util.p64. When no row is found,
        checksums are 0 and ZERO_OID / ZERO_TID are returned.
        """
        # XXX: XOR is a lame checksum
        unpack = util.u64
        pack = util.p64
        row_list = self.query('SELECT oid, serial FROM obj_short WHERE '
            'partition = %(partition)s AND '
            '(oid > %(min_oid)d OR '
            '(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
            'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
                'min_oid': unpack(min_oid),
                'min_serial': unpack(min_serial),
                'length': length,
                'partition': partition,
        })
        count = len(row_list)
        if not count:
            return count, 0, ZERO_OID, 0, ZERO_TID
        oid_checksum = 0
        serial_checksum = 0
        for last_oid, last_serial in row_list:
            oid_checksum ^= last_oid
            serial_checksum ^= last_serial
        # Rows are sorted, so the loop variables now hold the last row.
        return (count, oid_checksum, pack(last_oid), serial_checksum,
            pack(last_serial))
872