Commit b5a59a99 authored by Vincent Pelletier's avatar Vincent Pelletier Committed by Julien Muchembled

master: Forbid truncature before database's first transaction

This is intended as a sanity check, so simple typos in neoctl truncate
command do not easily lead to the entire database being wiped.

See merge request !23
parent 773bfa97
......@@ -41,7 +41,7 @@ class MasterHandler(EventHandler):
def askLastIDs(self, conn):
tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID()))
conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID(), tm.getFirstTID()))
def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction(
......
......@@ -239,6 +239,9 @@ class AdministrationHandler(MasterHandler):
app = self.app
if app.getLastTransaction() <= tid:
raise AnswerDenied("Truncating after last transaction does nothing")
if app.tm.getFirstTID() > tid:
raise AnswerDenied("Truncating before first transaction is"
" probably not what you intended to do")
if app.pm.getApprovedRejected(add64(tid, 1))[0]:
# TODO: The protocol must be extended to support safe cases
# (e.g. no started pack whose id is after truncation tid).
......
......@@ -20,7 +20,7 @@ from struct import pack, unpack
from neo.lib import logging
from neo.lib.exception import ProtocolError
from neo.lib.handler import DelayEvent, EventQueue
from neo.lib.protocol import uuid_str, ZERO_OID, ZERO_TID
from neo.lib.protocol import uuid_str, ZERO_OID, ZERO_TID, MAX_TID
from neo.lib.util import dump, u64, addTID, tidFromTime
class Transaction(object):
......@@ -179,6 +179,9 @@ class TransactionManager(EventQueue):
self._ttid_dict = {}
self._last_oid = ZERO_OID
self._last_tid = ZERO_TID
self._first_tid = MAX_TID
# avoid the overhead of min_tid on every _unlockPending
self._unlockPending = self._firstUnlockPending
# queue filled with ttids pointing to transactions with increasing tids
self._queue = deque()
......@@ -212,6 +215,13 @@ class TransactionManager(EventQueue):
self._last_oid = oid_list[-1]
return oid_list
def setFirstTID(self, tid):
if self._first_tid > tid:
self._first_tid = tid
def getFirstTID(self):
return self._first_tid
def setLastOID(self, oid):
if self._last_oid < oid:
self._last_oid = oid
......@@ -412,6 +422,16 @@ class TransactionManager(EventQueue):
if unlock:
self._unlockPending()
def _firstUnlockPending(self):
"""Set first TID when the first transaction is committed
Masks _unlockPending on reset.
Unmasks and call it when called.
"""
self.setFirstTID(self._ttid_dict[self._queue[0]].getTID())
del self._unlockPending
self._unlockPending()
def _unlockPending(self):
"""Serialize transaction unlocks
......
......@@ -139,11 +139,12 @@ class VerificationManager(BaseServiceHandler):
def notifyPackCompleted(self, conn, pack_id):
self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
def answerLastIDs(self, conn, ltid, loid):
def answerLastIDs(self, conn, ltid, loid, ftid):
self._uuid_set.remove(conn.getUUID())
tm = self.app.tm
tm.setLastTID(ltid)
tm.setLastOID(loid)
tm.setFirstTID(ftid)
def answerPackOrders(self, conn, pack_list):
self._uuid_set.remove(conn.getUUID())
......
......@@ -137,7 +137,7 @@ class NeoCTL(BaseApplication):
response = self.__ask(Packets.AskLastIDs())
if response[0] != Packets.AnswerLastIDs:
raise RuntimeError(response)
return response[1:]
return response[1:3]
def getLastTransaction(self):
response = self.__ask(Packets.AskLastTransaction())
......
......@@ -601,6 +601,9 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid
def getFirstTID(self):
return min(next(zodb.iterator()).tid for zodb in self.zodb)
def getLastIDs(self):
tid, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)),
......
......@@ -758,6 +758,19 @@ class DatabaseManager(object):
# XXX: Consider splitting getLastIDs/_getLastIDs because
# sometimes the last oid is not wanted.
def _getFirstTID(self, partition):
"""Return tid of first transaction in given 'partition'
tids are in unpacked format.
"""
@requires(_getFirstTID)
def getFirstTID(self):
"""Return tid of first transaction
"""
x = self._readable_set
return util.p64(min(map(self._getFirstTID, x))) if x else MAX_TID
def _getLastTID(self, partition, max_tid=None):
"""Return tid of last transaction <= 'max_tid' in given 'partition'
......
......@@ -53,7 +53,7 @@ from .manager import MVCCDatabaseManager, splitOIDField
from neo.lib import logging, util
from neo.lib.exception import NonReadableCell, UndoPackError
from neo.lib.interfaces import implements
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH, MAX_TID
class MysqlError(DatabaseFailure):
......@@ -457,6 +457,12 @@ class MySQLDatabaseManager(MVCCDatabaseManager):
def _getPartitionTable(self):
return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
(tid,), = self.query(
"SELECT MIN(tid) as t FROM trans FORCE INDEX (PRIMARY)"
" WHERE `partition`=%s" % partition)
return util.u64(MAX_TID) if tid is None else tid
def _getLastTID(self, partition, max_tid=None):
sql = ("SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)"
" WHERE `partition`=%s") % partition
......
......@@ -28,7 +28,7 @@ from .manager import DatabaseManager, splitOIDField
from neo.lib import logging, util
from neo.lib.exception import NonReadableCell, UndoPackError
from neo.lib.interfaces import implements
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH, MAX_TID
def unique_constraint_message(table, *columns):
c = sqlite3.connect(":memory:")
......@@ -343,6 +343,11 @@ class SQLiteDatabaseManager(DatabaseManager):
def _getPartitionTable(self):
return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
tid = self.query("SELECT MIN(tid) FROM trans WHERE partition=?",
(partition,)).fetchone()[0]
return util.u64(MAX_TID) if tid is None else tid
def _getLastTID(self, partition, max_tid=None):
x = self.query
if max_tid is None:
......
......@@ -55,7 +55,8 @@ class InitializationHandler(BaseMasterHandler):
if packed:
self.app.completed_pack_id = pack_id = min(packed.itervalues())
conn.send(Packets.NotifyPackCompleted(pack_id))
conn.answer(Packets.AnswerLastIDs(*dm.getLastIDs()))
last_tid, last_oid = dm.getLastIDs() # PY3
conn.answer(Packets.AnswerLastIDs(last_tid, last_oid, dm.getFirstTID()))
def askPartitionTable(self, conn):
pt = self.app.pt
......
......@@ -13,7 +13,7 @@ AnswerFetchObjects(?p64,?p64,{:})
AnswerFetchTransactions(?p64,[],?p64)
AnswerFinalTID(p64)
AnswerInformationLocked(p64)
AnswerLastIDs(?p64,?p64)
AnswerLastIDs(?p64,?p64,p64)
AnswerLastTransaction(p64)
AnswerLockedTransactions({p64:?p64})
AnswerMonitorInformation([?bin],[?bin],bin)
......
......@@ -183,6 +183,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# nothing in database
self.assertEqual(self.db.getFirstTID(), MAX_TID)
self.assertEqual(self.db.getLastIDs(), (None, None))
self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None)
......@@ -199,6 +200,7 @@ class StorageDBTests(NeoUnitTestBase):
([oid2], 'user', 'desc', 'ext', False, p64(2), None))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
self.assertEqual(self.db.getFirstTID(), tid1)
self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, True),
......
......@@ -200,7 +200,7 @@ class StressApplication(AdminApplication):
if conn:
conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, ltid, loid):
def answerLastIDs(self, ltid, loid, ftid):
self.loid = loid
self.ltid = ltid
self.em.setTimeout(int(time.time() + 1), self.askLastIDs)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment