Commit b5a59a99 authored by Vincent Pelletier's avatar Vincent Pelletier Committed by Julien Muchembled

master: Forbid truncature before database's first transaction

This is intended as a sanity check, so simple typos in neoctl truncate
command do not easily lead to the entire database being wiped.

See merge request !23
parent 773bfa97
...@@ -41,7 +41,7 @@ class MasterHandler(EventHandler): ...@@ -41,7 +41,7 @@ class MasterHandler(EventHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
tm = self.app.tm tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID())) conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID(), tm.getFirstTID()))
def askLastTransaction(self, conn): def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction( conn.answer(Packets.AnswerLastTransaction(
......
...@@ -239,6 +239,9 @@ class AdministrationHandler(MasterHandler): ...@@ -239,6 +239,9 @@ class AdministrationHandler(MasterHandler):
app = self.app app = self.app
if app.getLastTransaction() <= tid: if app.getLastTransaction() <= tid:
raise AnswerDenied("Truncating after last transaction does nothing") raise AnswerDenied("Truncating after last transaction does nothing")
if app.tm.getFirstTID() > tid:
raise AnswerDenied("Truncating before first transaction is"
" probably not what you intended to do")
if app.pm.getApprovedRejected(add64(tid, 1))[0]: if app.pm.getApprovedRejected(add64(tid, 1))[0]:
# TODO: The protocol must be extended to support safe cases # TODO: The protocol must be extended to support safe cases
# (e.g. no started pack whose id is after truncation tid). # (e.g. no started pack whose id is after truncation tid).
......
...@@ -20,7 +20,7 @@ from struct import pack, unpack ...@@ -20,7 +20,7 @@ from struct import pack, unpack
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import ProtocolError from neo.lib.exception import ProtocolError
from neo.lib.handler import DelayEvent, EventQueue from neo.lib.handler import DelayEvent, EventQueue
from neo.lib.protocol import uuid_str, ZERO_OID, ZERO_TID from neo.lib.protocol import uuid_str, ZERO_OID, ZERO_TID, MAX_TID
from neo.lib.util import dump, u64, addTID, tidFromTime from neo.lib.util import dump, u64, addTID, tidFromTime
class Transaction(object): class Transaction(object):
...@@ -179,6 +179,9 @@ class TransactionManager(EventQueue): ...@@ -179,6 +179,9 @@ class TransactionManager(EventQueue):
self._ttid_dict = {} self._ttid_dict = {}
self._last_oid = ZERO_OID self._last_oid = ZERO_OID
self._last_tid = ZERO_TID self._last_tid = ZERO_TID
self._first_tid = MAX_TID
# avoid the overhead of min_tid on every _unlockPending
self._unlockPending = self._firstUnlockPending
# queue filled with ttids pointing to transactions with increasing tids # queue filled with ttids pointing to transactions with increasing tids
self._queue = deque() self._queue = deque()
...@@ -212,6 +215,13 @@ class TransactionManager(EventQueue): ...@@ -212,6 +215,13 @@ class TransactionManager(EventQueue):
self._last_oid = oid_list[-1] self._last_oid = oid_list[-1]
return oid_list return oid_list
def setFirstTID(self, tid):
if self._first_tid > tid:
self._first_tid = tid
def getFirstTID(self):
return self._first_tid
def setLastOID(self, oid): def setLastOID(self, oid):
if self._last_oid < oid: if self._last_oid < oid:
self._last_oid = oid self._last_oid = oid
...@@ -412,6 +422,16 @@ class TransactionManager(EventQueue): ...@@ -412,6 +422,16 @@ class TransactionManager(EventQueue):
if unlock: if unlock:
self._unlockPending() self._unlockPending()
def _firstUnlockPending(self):
"""Set first TID when the first transaction is committed
Masks _unlockPending on reset.
Unmasks and call it when called.
"""
self.setFirstTID(self._ttid_dict[self._queue[0]].getTID())
del self._unlockPending
self._unlockPending()
def _unlockPending(self): def _unlockPending(self):
"""Serialize transaction unlocks """Serialize transaction unlocks
......
...@@ -139,11 +139,12 @@ class VerificationManager(BaseServiceHandler): ...@@ -139,11 +139,12 @@ class VerificationManager(BaseServiceHandler):
def notifyPackCompleted(self, conn, pack_id): def notifyPackCompleted(self, conn, pack_id):
self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
def answerLastIDs(self, conn, ltid, loid): def answerLastIDs(self, conn, ltid, loid, ftid):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
tm = self.app.tm tm = self.app.tm
tm.setLastTID(ltid) tm.setLastTID(ltid)
tm.setLastOID(loid) tm.setLastOID(loid)
tm.setFirstTID(ftid)
def answerPackOrders(self, conn, pack_list): def answerPackOrders(self, conn, pack_list):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
......
...@@ -137,7 +137,7 @@ class NeoCTL(BaseApplication): ...@@ -137,7 +137,7 @@ class NeoCTL(BaseApplication):
response = self.__ask(Packets.AskLastIDs()) response = self.__ask(Packets.AskLastIDs())
if response[0] != Packets.AnswerLastIDs: if response[0] != Packets.AnswerLastIDs:
raise RuntimeError(response) raise RuntimeError(response)
return response[1:] return response[1:3]
def getLastTransaction(self): def getLastTransaction(self):
response = self.__ask(Packets.AskLastTransaction()) response = self.__ask(Packets.AskLastTransaction())
......
...@@ -601,6 +601,9 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -601,6 +601,9 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1] zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid return zodb, oid - zodb.shift_oid
def getFirstTID(self):
return min(next(zodb.iterator()).tid for zodb in self.zodb)
def getLastIDs(self): def getLastIDs(self):
tid, oid = self.db.getLastIDs() tid, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)), return (max(tid, util.p64(self.zodb_ltid)),
......
...@@ -758,6 +758,19 @@ class DatabaseManager(object): ...@@ -758,6 +758,19 @@ class DatabaseManager(object):
# XXX: Consider splitting getLastIDs/_getLastIDs because # XXX: Consider splitting getLastIDs/_getLastIDs because
# sometimes the last oid is not wanted. # sometimes the last oid is not wanted.
def _getFirstTID(self, partition):
"""Return tid of first transaction in given 'partition'
tids are in unpacked format.
"""
@requires(_getFirstTID)
def getFirstTID(self):
"""Return tid of first transaction
"""
x = self._readable_set
return util.p64(min(map(self._getFirstTID, x))) if x else MAX_TID
def _getLastTID(self, partition, max_tid=None): def _getLastTID(self, partition, max_tid=None):
"""Return tid of last transaction <= 'max_tid' in given 'partition' """Return tid of last transaction <= 'max_tid' in given 'partition'
......
...@@ -53,7 +53,7 @@ from .manager import MVCCDatabaseManager, splitOIDField ...@@ -53,7 +53,7 @@ from .manager import MVCCDatabaseManager, splitOIDField
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import NonReadableCell, UndoPackError from neo.lib.exception import NonReadableCell, UndoPackError
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH, MAX_TID
class MysqlError(DatabaseFailure): class MysqlError(DatabaseFailure):
...@@ -457,6 +457,12 @@ class MySQLDatabaseManager(MVCCDatabaseManager): ...@@ -457,6 +457,12 @@ class MySQLDatabaseManager(MVCCDatabaseManager):
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
(tid,), = self.query(
"SELECT MIN(tid) as t FROM trans FORCE INDEX (PRIMARY)"
" WHERE `partition`=%s" % partition)
return util.u64(MAX_TID) if tid is None else tid
def _getLastTID(self, partition, max_tid=None): def _getLastTID(self, partition, max_tid=None):
sql = ("SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)" sql = ("SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)"
" WHERE `partition`=%s") % partition " WHERE `partition`=%s") % partition
......
...@@ -28,7 +28,7 @@ from .manager import DatabaseManager, splitOIDField ...@@ -28,7 +28,7 @@ from .manager import DatabaseManager, splitOIDField
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import NonReadableCell, UndoPackError from neo.lib.exception import NonReadableCell, UndoPackError
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH, MAX_TID
def unique_constraint_message(table, *columns): def unique_constraint_message(table, *columns):
c = sqlite3.connect(":memory:") c = sqlite3.connect(":memory:")
...@@ -343,6 +343,11 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -343,6 +343,11 @@ class SQLiteDatabaseManager(DatabaseManager):
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
tid = self.query("SELECT MIN(tid) FROM trans WHERE partition=?",
(partition,)).fetchone()[0]
return util.u64(MAX_TID) if tid is None else tid
def _getLastTID(self, partition, max_tid=None): def _getLastTID(self, partition, max_tid=None):
x = self.query x = self.query
if max_tid is None: if max_tid is None:
......
...@@ -55,7 +55,8 @@ class InitializationHandler(BaseMasterHandler): ...@@ -55,7 +55,8 @@ class InitializationHandler(BaseMasterHandler):
if packed: if packed:
self.app.completed_pack_id = pack_id = min(packed.itervalues()) self.app.completed_pack_id = pack_id = min(packed.itervalues())
conn.send(Packets.NotifyPackCompleted(pack_id)) conn.send(Packets.NotifyPackCompleted(pack_id))
conn.answer(Packets.AnswerLastIDs(*dm.getLastIDs())) last_tid, last_oid = dm.getLastIDs() # PY3
conn.answer(Packets.AnswerLastIDs(last_tid, last_oid, dm.getFirstTID()))
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
......
...@@ -13,7 +13,7 @@ AnswerFetchObjects(?p64,?p64,{:}) ...@@ -13,7 +13,7 @@ AnswerFetchObjects(?p64,?p64,{:})
AnswerFetchTransactions(?p64,[],?p64) AnswerFetchTransactions(?p64,[],?p64)
AnswerFinalTID(p64) AnswerFinalTID(p64)
AnswerInformationLocked(p64) AnswerInformationLocked(p64)
AnswerLastIDs(?p64,?p64) AnswerLastIDs(?p64,?p64,p64)
AnswerLastTransaction(p64) AnswerLastTransaction(p64)
AnswerLockedTransactions({p64:?p64}) AnswerLockedTransactions({p64:?p64})
AnswerMonitorInformation([?bin],[?bin],bin) AnswerMonitorInformation([?bin],[?bin],bin)
......
...@@ -183,6 +183,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -183,6 +183,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# nothing in database # nothing in database
self.assertEqual(self.db.getFirstTID(), MAX_TID)
self.assertEqual(self.db.getLastIDs(), (None, None)) self.assertEqual(self.db.getLastIDs(), (None, None))
self.assertEqual(self.db.getUnfinishedTIDDict(), {}) self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
...@@ -199,6 +200,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -199,6 +200,7 @@ class StorageDBTests(NeoUnitTestBase):
([oid2], 'user', 'desc', 'ext', False, p64(2), None)) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
self.assertEqual(self.db.getFirstTID(), tid1)
self.assertEqual(self.db.getTransaction(tid1, True), self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1), None)) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, True), self.assertEqual(self.db.getTransaction(tid2, True),
......
...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication): ...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication):
if conn: if conn:
conn.ask(Packets.AskLastIDs()) conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, ltid, loid): def answerLastIDs(self, ltid, loid, ftid):
self.loid = loid self.loid = loid
self.ltid = ltid self.ltid = ltid
self.em.setTimeout(int(time.time() + 1), self.askLastIDs) self.em.setTimeout(int(time.time() + 1), self.askLastIDs)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment