Commit ca2caf87 authored by Julien Muchembled

Bump protocol version and upgrade storages automatically

parents cff279af d3c8b76d
......@@ -58,8 +58,6 @@
committed by future transactions.
- Add a 'devid' storage configuration so that the master does not distribute
replicated partitions on storages with the same 'devid'.
- Make tpc_finish safer as described in its __doc__: moving work to
tpc_vote and recovering from master failure when possible.
Storage
- Use libmysqld instead of a stand-alone MySQL server.
......@@ -143,9 +141,7 @@
Admin
- Make admin node able to monitor multiple clusters simultaneously
- Send notifications (ie: mail) when a storage or master node is lost
- Add ctl command to truncate DB at arbitrary TID. 'Truncate' message
can be reused. There should also be a way to list last transactions,
like fstail for FileStorage.
- Add ctl command to list last transactions, like fstail for FileStorage.
Tests
- Use another mock library: Python 3.3+ has unittest.mock, which is
......
......@@ -65,10 +65,12 @@ class AdminEventHandler(EventHandler):
askLastIDs = forward_ask(Packets.AskLastIDs)
askLastTransaction = forward_ask(Packets.AskLastTransaction)
addPendingNodes = forward_ask(Packets.AddPendingNodes)
askRecovery = forward_ask(Packets.AskRecovery)
tweakPartitionTable = forward_ask(Packets.TweakPartitionTable)
setClusterState = forward_ask(Packets.SetClusterState)
setNodeState = forward_ask(Packets.SetNodeState)
checkReplicas = forward_ask(Packets.CheckReplicas)
truncate = forward_ask(Packets.Truncate)
class MasterEventHandler(EventHandler):
......
......@@ -612,18 +612,29 @@ class Application(ThreadedApplication):
packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension),
txn_context['cache_dict'])
add_involved_nodes = txn_context['involved_nodes'].add
queue = txn_context['queue']
trans_nodes = []
for node, conn in self.cp.iterateForObject(ttid):
logging.debug("voting transaction %s on %s", dump(ttid),
dump(conn.getUUID()))
try:
self._askStorage(conn, packet)
conn.ask(packet, queue=queue)
except ConnectionClosed:
continue
add_involved_nodes(node)
trans_nodes.append(node)
# check at least one storage node accepted
if txn_context['involved_nodes']:
if trans_nodes:
involved_nodes = txn_context['involved_nodes']
packet = Packets.AskVoteTransaction(ttid)
for node in involved_nodes.difference(trans_nodes):
conn = self.cp.getConnForNode(node)
if conn is not None:
try:
conn.ask(packet, queue=queue)
except ConnectionClosed:
pass
involved_nodes.update(trans_nodes)
self.waitResponses(queue)
txn_context['voted'] = None
# We must not go further if connection to master was lost since
# tpc_begin, to lower the probability of failing during tpc_finish.
......@@ -667,27 +678,14 @@ class Application(ThreadedApplication):
fail in tpc_finish. In particular, making a transaction permanent
should ideally be as simple as switching a bit permanently.
In NEO, tpc_finish breaks this promise by not ensuring earlier that all
data and metadata are written, and it is for example vulnerable to
ENOSPC errors. In other words, some work should be moved to tpc_vote.
TODO: - In tpc_vote, all involved storage nodes must be asked to write
all metadata to ttrans/tobj and _commit_. AskStoreTransaction
can be extended for this: for nodes that don't store anything
in ttrans, it can just contain the ttid. The final tid is not
known yet, so ttrans/tobj would contain the ttid.
- In tpc_finish, AskLockInformation is still required for read
locking, ttrans.tid must be updated with the final value and
ttrans _committed_.
- The Verification phase would need some change because
ttrans/tobj may contain data for which tpc_finish was not
called. The ttid is also in trans so a mapping ttid<->tid is
always possible and can be forwarded via the master so that all
storage are still able to update the tid column with the final
value when moving rows from tobj to obj.
The resulting cost is:
- additional RPCs in tpc_vote
- 1 updated row in ttrans + commit
In NEO, all the data (with the exception of the tid, simply because
it is not known yet) is already flushed to disk at the end of the vote.
During tpc_finish, all nodes storing the transaction metadata are asked
to commit by saving the new tid and flushing again: for SQL backends,
it's just an UPDATE of 1 cell. At last, the metadata is moved to
a final place so that the new transaction is readable, but this is
something that can always be replayed (during the verification phase)
if any failure happens.
TODO: We should recover from master failures when the transaction got
successfully committed. More precisely, we should not raise:
......
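As a rough, self-contained illustration of the scheme described in the docstring above (hypothetical names, with plain dicts standing in for the real ttrans/trans tables, not actual NEO code): all metadata is flushed at vote time, the final tid is recorded with a single small update, and moving the metadata to its final place can be replayed if interrupted.

    class StorageSketch(object):
        def __init__(self):
            self.ttrans = {}  # ttid -> (metadata, tid or None), temporary area
            self.trans = {}   # tid -> metadata, final (readable) area

        def vote(self, ttid, metadata):
            # Everything except the final tid is written and flushed here,
            # so the finish phase can no longer fail for lack of space/data.
            self.ttrans[ttid] = (metadata, None)

        def lock(self, ttid, tid):
            # tpc_finish, step 1: record the final tid (one small update).
            metadata, _ = self.ttrans[ttid]
            self.ttrans[ttid] = (metadata, tid)

        def unlock(self, ttid):
            # tpc_finish, step 2: move metadata to its final place; replayable
            # during verification because ttrans already holds the final tid.
            metadata, tid = self.ttrans.pop(ttid)
            self.trans[tid] = metadata

    s = StorageSketch()
    s.vote('ttid1', 'user/desc/ext/oids')
    s.lock('ttid1', 'tid1')
    s.unlock('ttid1')
    assert s.trans == {'tid1': 'user/desc/ext/oids'}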
......@@ -102,11 +102,17 @@ class PrimaryNotificationsHandler(MTEventHandler):
if app.master_conn is None:
app._cache_lock_acquire()
try:
oid_list = app._cache.clear_current()
db = app.getDB()
if db is not None:
db.invalidate(app.last_tid and
add64(app.last_tid, 1), oid_list)
if app.last_tid < ltid:
oid_list = app._cache.clear_current()
db is None or db.invalidate(
app.last_tid and add64(app.last_tid, 1),
oid_list)
else:
# The DB was truncated. It happens so
# rarely that we don't need to optimize.
app._cache.clear()
db is None or db.invalidateCache()
finally:
app._cache_lock_release()
app.last_tid = ltid
......
......@@ -112,9 +112,11 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, _):
def answerStoreTransaction(self, conn):
pass
answerVoteTransaction = answerStoreTransaction
def answerTIDsFrom(self, conn, tid_list):
logging.debug('Get %u TIDs from %r', len(tid_list), conn)
self.app.setHandlerData(tid_list)
......
......@@ -41,9 +41,6 @@ class BootstrapManager(EventHandler):
self.num_partitions = None
self.current = None
def notifyNodeInformation(self, conn, node_list):
pass
def announcePrimary(self, conn):
# We found the primary master early enough to be notified of election
# end. Lucky. Anyway, we must carry on with identification request, so
......
......@@ -23,7 +23,7 @@ class ElectionFailure(NeoException):
class PrimaryFailure(NeoException):
pass
class OperationFailure(NeoException):
class StoppedOperation(NeoException):
pass
class DatabaseFailure(NeoException):
......
......@@ -20,7 +20,7 @@ import traceback
from cStringIO import StringIO
from struct import Struct
PROTOCOL_VERSION = 4
PROTOCOL_VERSION = 5
# Size restrictions.
MIN_PACKET_SIZE = 10
......@@ -722,16 +722,24 @@ class ReelectPrimary(Packet):
Force a re-election of a primary master node. M -> M.
"""
class Recovery(Packet):
"""
Ask all data needed by master to recover. PM -> S, S -> PM.
"""
_answer = PStruct('answer_recovery',
PPTID('ptid'),
PTID('backup_tid'),
PTID('truncate_tid'),
)
class LastIDs(Packet):
"""
Ask the last OID, the last TID and the last Partition Table ID so that
a master can recover. PM -> S, S -> PM.
Ask the last OID/TID so that a master can initialize its TransactionManager.
PM -> S, S -> PM.
"""
_answer = PStruct('answer_last_ids',
POID('last_oid'),
PTID('last_tid'),
PPTID('last_ptid'),
PTID('backup_tid'),
)
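For illustration only, a self-contained sketch of how a master could fold the AnswerRecovery replies described above (made-up uuids, ptids and tids; the actual logic is RecoveryManager.answerRecovery further below): keep the highest partition table id seen and remember each storage's backup/truncate tid.

    answers = {                   # uuid -> (ptid, backup_tid, truncate_tid)
        'S1': (3, None, None),
        'S2': (4, None, '\x00\x00\x00\x00\x00\x00\x12\x34'),
    }
    target_ptid = None            # highest partition table id seen so far
    backup_tid_dict = {}
    truncate_dict = {}
    for uuid, (ptid, backup_tid, truncate_tid) in sorted(answers.items()):
        if target_ptid is None or target_ptid < ptid:
            target_ptid = ptid    # this node will be asked for its table
        backup_tid_dict[uuid] = backup_tid
        truncate_dict[uuid] = truncate_tid
    assert target_ptid == 4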
class PartitionTable(Packet):
......@@ -775,6 +783,8 @@ class StartOperation(Packet):
this message, it must not serve client nodes. PM -> S.
"""
_fmt = PStruct('start_operation',
# XXX: Is this boolean needed ? Maybe this
# can be deduced from cluster state.
PBoolean('backup'),
)
......@@ -786,8 +796,8 @@ class StopOperation(Packet):
class UnfinishedTransactions(Packet):
"""
Ask unfinished transactions PM -> S.
Answer unfinished transactions S -> PM.
Ask unfinished transactions S -> PM.
Answer unfinished transactions PM -> S.
"""
_answer = PStruct('answer_unfinished_transactions',
PTID('max_tid'),
......@@ -796,36 +806,36 @@ class UnfinishedTransactions(Packet):
),
)
class ObjectPresent(Packet):
class LockedTransactions(Packet):
"""
Ask if an object is present. If not present, OID_NOT_FOUND should be
returned. PM -> S.
Answer that an object is present. PM -> S.
Ask locked transactions PM -> S.
Answer locked transactions S -> PM.
"""
_fmt = PStruct('object_present',
POID('oid'),
PTID('tid'),
)
_answer = PStruct('object_present',
POID('oid'),
PTID('tid'),
_answer = PStruct('answer_locked_transactions',
PDict('tid_dict',
PTID('ttid'),
PTID('tid'),
),
)
class DeleteTransaction(Packet):
class FinalTID(Packet):
"""
Delete a transaction. PM -> S.
Return final tid if ttid has been committed. * -> S.
"""
_fmt = PStruct('delete_transaction',
_fmt = PStruct('final_tid',
PTID('ttid'),
)
_answer = PStruct('final_tid',
PTID('tid'),
PFOidList,
)
class CommitTransaction(Packet):
class ValidateTransaction(Packet):
"""
Commit a transaction. PM -> S.
"""
_fmt = PStruct('commit_transaction',
_fmt = PStruct('validate_transaction',
PTID('ttid'),
PTID('tid'),
)
......@@ -878,11 +888,10 @@ class LockInformation(Packet):
_fmt = PStruct('ask_lock_informations',
PTID('ttid'),
PTID('tid'),
PFOidList,
)
_answer = PStruct('answer_information_locked',
PTID('tid'),
PTID('ttid'),
)
class InvalidateObjects(Packet):
......@@ -899,7 +908,7 @@ class UnlockInformation(Packet):
Unlock information on a transaction. PM -> S.
"""
_fmt = PStruct('notify_unlock_information',
PTID('tid'),
PTID('ttid'),
)
class GenerateOIDs(Packet):
......@@ -961,10 +970,17 @@ class StoreTransaction(Packet):
PString('extension'),
PFOidList,
)
_answer = PFEmpty
_answer = PStruct('answer_store_transaction',
class VoteTransaction(Packet):
"""
Ask to store a transaction. C -> S.
Answer if transaction has been stored. S -> C.
"""
_fmt = PStruct('ask_vote_transaction',
PTID('tid'),
)
_answer = PFEmpty
class GetObject(Packet):
"""
......@@ -1462,13 +1478,14 @@ class ReplicationDone(Packet):
class Truncate(Packet):
"""
XXX: Used for both make storage consistent and leave backup mode
M -> S
Request DB to be truncated. Also used to leave backup mode.
"""
_fmt = PStruct('truncate',
PTID('tid'),
)
_answer = Error
StaticRegistry = {}
def register(request, ignore_when_closed=None):
......@@ -1586,6 +1603,8 @@ class Packets(dict):
ReelectPrimary)
NotifyNodeInformation = register(
NotifyNodeInformation)
AskRecovery, AnswerRecovery = register(
Recovery)
AskLastIDs, AnswerLastIDs = register(
LastIDs)
AskPartitionTable, AnswerPartitionTable = register(
......@@ -1600,12 +1619,12 @@ class Packets(dict):
StopOperation)
AskUnfinishedTransactions, AnswerUnfinishedTransactions = register(
UnfinishedTransactions)
AskObjectPresent, AnswerObjectPresent = register(
ObjectPresent)
DeleteTransaction = register(
DeleteTransaction)
CommitTransaction = register(
CommitTransaction)
AskLockedTransactions, AnswerLockedTransactions = register(
LockedTransactions)
AskFinalTID, AnswerFinalTID = register(
FinalTID)
ValidateTransaction = register(
ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction)
AskFinishTransaction, AnswerTransactionFinished = register(
......@@ -1624,6 +1643,8 @@ class Packets(dict):
AbortTransaction)
AskStoreTransaction, AnswerStoreTransaction = register(
StoreTransaction)
AskVoteTransaction, AnswerVoteTransaction = register(
VoteTransaction)
AskObject, AnswerObject = register(
GetObject)
AskTIDs, AnswerTIDs = register(
......
......@@ -24,7 +24,7 @@ from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
from neo.lib.handler import EventHandler
from neo.lib.connection import ListeningConnection, ClientConnection
from neo.lib.exception import ElectionFailure, PrimaryFailure, OperationFailure
from neo.lib.exception import ElectionFailure, PrimaryFailure, StoppedOperation
class StateChangedException(Exception): pass
......@@ -45,6 +45,7 @@ class Application(BaseApplication):
backup_tid = None
backup_app = None
uuid = None
truncate_tid = None
def __init__(self, config):
super(Application, self).__init__(
......@@ -77,7 +78,6 @@ class Application(BaseApplication):
self.primary = None
self.primary_master_node = None
self.cluster_state = None
self._startup_allowed = False
uuid = config.getUUID()
if uuid:
......@@ -221,7 +221,7 @@ class Application(BaseApplication):
self.primary = self.primary is None
break
def broadcastNodesInformation(self, node_list):
def broadcastNodesInformation(self, node_list, exclude=None):
"""
Broadcast changes for a set of nodes
Send only one packet per connection to reduce bandwidth
......@@ -243,7 +243,7 @@ class Application(BaseApplication):
# send at most one non-empty notification packet per node
for node in self.nm.getIdentifiedList():
node_list = node_dict.get(node.getType())
if node_list and node.isRunning():
if node_list and node.isRunning() and node is not exclude:
node.notify(Packets.NotifyNodeInformation(node_list))
def broadcastPartitionChanges(self, cell_list):
......@@ -254,7 +254,6 @@ class Application(BaseApplication):
ptid = self.pt.setNextID()
packet = Packets.NotifyPartitionChanges(ptid, cell_list)
for node in self.nm.getIdentifiedList():
# TODO: notify masters
if node.isRunning() and not node.isMaster():
node.notify(packet)
......@@ -266,8 +265,6 @@ class Application(BaseApplication):
"""
logging.info('provide service')
poll = self.em.poll
self.tm.reset()
self.changeClusterState(ClusterStates.RUNNING)
# Now everything is passive.
......@@ -278,8 +275,13 @@ class Application(BaseApplication):
if e.args[0] != ClusterStates.STARTING_BACKUP:
raise
self.backup_tid = tid = self.getLastTransaction()
self.pt.setBackupTidDict({node.getUUID(): tid
for node in self.nm.getStorageList(only_identified=True)})
packet = Packets.StartOperation(True)
tid_dict = {}
for node in self.nm.getStorageList(only_identified=True):
tid_dict[node.getUUID()] = tid
if node.isRunning():
node.notify(packet)
self.pt.setBackupTidDict(tid_dict)
def playPrimaryRole(self):
logging.info('play the primary role with %r', self.listening_conn)
......@@ -323,30 +325,46 @@ class Application(BaseApplication):
in_conflict)
in_conflict.setUUID(None)
# recover the cluster status at startup
# Do not restart automatically if ElectionFailure is raised, in order
# to avoid a split of the database. For example, with 2 machines with
# a master and a storage on each one and replicas=1, the secondary
# master becomes primary in case of network failure between the 2
# machines but must not start automatically: otherwise, each storage
# node would diverge.
self._startup_allowed = False
try:
self.runManager(RecoveryManager)
while True:
self.runManager(VerificationManager)
self.runManager(RecoveryManager)
try:
if self.backup_tid:
if self.backup_app is None:
raise RuntimeError("No upstream cluster to backup"
" defined in configuration")
self.backup_app.provideService()
# Reset connection with storages (and go through a
# recovery phase) when leaving backup mode in order
# to get correct last oid/tid.
self.runManager(RecoveryManager)
continue
self.provideService()
except OperationFailure:
self.runManager(VerificationManager)
if not self.backup_tid:
self.provideService()
# self.provideService only returns without raising
# when switching to backup mode.
if self.backup_app is None:
raise RuntimeError("No upstream cluster to backup"
" defined in configuration")
truncate = Packets.Truncate(
self.backup_app.provideService())
except StoppedOperation, e:
logging.critical('No longer operational')
truncate = Packets.Truncate(*e.args) if e.args else None
# Automatic restart except if we truncate or retry to truncate.
self._startup_allowed = not (self.truncate_tid or truncate)
node_list = []
for node in self.nm.getIdentifiedList():
if node.isStorage() or node.isClient():
node.notify(Packets.StopOperation())
conn = node.getConnection()
conn.notify(Packets.StopOperation())
if node.isClient():
node.getConnection().abort()
conn.abort()
continue
if truncate:
conn.notify(truncate)
if node.isRunning():
node.setPending()
node_list.append(node)
self.broadcastNodesInformation(node_list)
except StateChangedException, e:
assert e.args[0] == ClusterStates.STOPPING
self.shutdown()
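To make the rearranged control flow above easier to follow, a minimal sketch of the primary master's main loop, with stand-in callables instead of the real RecoveryManager, VerificationManager, provideService and backup_app.provideService (assumptions, not actual NEO code):

    def primary_loop(run_recovery, run_verification, provide_service):
        # Sketch only: recovery, then verification, then service; a stopped
        # operation sends the cluster back through recovery.
        while True:
            run_recovery()         # wait for storages, load the partition table
            run_verification()     # resolve partially finished transactions
            try:
                provide_service()  # RUNNING (or backup) until a failure
            except RuntimeError:   # stands in for StoppedOperation
                continue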
......@@ -427,7 +445,7 @@ class Application(BaseApplication):
continue # keep handler
if type(handler) is not type(conn.getLastHandler()):
conn.setHandler(handler)
handler.connectionCompleted(conn)
handler.connectionCompleted(conn, new=False)
self.cluster_state = state
def getNewUUID(self, uuid, address, node_type):
......@@ -461,7 +479,7 @@ class Application(BaseApplication):
# wait for all transactions to be finished
while self.tm.hasPending():
self.em.poll(1)
except OperationFailure:
except StoppedOperation:
logging.critical('No longer operational')
logging.info("asking remaining nodes to shutdown")
......
......@@ -152,24 +152,20 @@ class BackupApplication(object):
assert tid != ZERO_TID
logging.warning("Truncating at %s (last_tid was %s)",
dump(app.backup_tid), dump(last_tid))
# XXX: We want to go through a recovery phase in order to
# initialize the transaction manager, but this is only
# possible if storages already know that we left backup
# mode. To that purpose, we always send a Truncate packet,
# even if there's nothing to truncate.
p = Packets.Truncate(tid)
for node in app.nm.getStorageList(only_identified=True):
conn = node.getConnection()
conn.setHandler(handler)
node.setState(NodeStates.TEMPORARILY_DOWN)
# Packets will be sent at the beginning of the recovery
# phase.
conn.notify(p)
conn.abort()
else:
# We will do a dummy truncation, just to leave backup mode,
# so it's fine to start automatically if there's any
# missing storage.
# XXX: Consider using another method to leave backup mode,
# at least when there's nothing to truncate. Because
# in case of StoppedOperation during VERIFYING state,
# this flag will be wrongly set to False.
app._startup_allowed = True
# If any error happened before reaching this line, we'd go back
# to backup mode, which is the right mode to recover.
del app.backup_tid
break
# Now back to RECOVERY...
return tid
finally:
del self.primary_partition_dict, self.tid_list
pt.clearReplicating()
......
......@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging
from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler
from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets,
BrokenNodeDisallowedError,
......@@ -23,6 +24,10 @@ from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets,
class MasterHandler(EventHandler):
"""This class implements a generic part of the event handlers."""
def connectionCompleted(self, conn, new=None):
if new is None:
super(MasterHandler, self).connectionCompleted(conn)
def requestIdentification(self, conn, node_type, uuid, address, name):
self.checkClusterName(name)
app = self.app
......@@ -61,25 +66,31 @@ class MasterHandler(EventHandler):
state = self.app.getClusterState()
conn.answer(Packets.AnswerClusterState(state))
def askLastIDs(self, conn):
def askRecovery(self, conn):
app = self.app
conn.answer(Packets.AnswerLastIDs(
app.tm.getLastOID(),
app.tm.getLastTID(),
conn.answer(Packets.AnswerRecovery(
app.pt.getID(),
app.backup_tid))
app.backup_tid and app.pt.getBackupTid(),
app.truncate_tid))
def askLastIDs(self, conn):
tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastOID(), tm.getLastTID()))
def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction(
self.app.getLastTransaction()))
def askNodeInformation(self, conn):
def _notifyNodeInformation(self, conn):
nm = self.app.nm
node_list = []
node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getClientList())
node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list))
def askNodeInformation(self, conn):
self._notifyNodeInformation(conn)
conn.answer(Packets.AnswerNodeInformation())
def askPartitionTable(self, conn):
......@@ -94,15 +105,18 @@ DISCONNECTED_STATE_DICT = {
class BaseServiceHandler(MasterHandler):
"""This class deals with events for a service phase."""
def nodeLost(self, conn, node):
# This method provides a hook point overridable by service classes.
# It is triggered when a connection to a node gets lost.
pass
def connectionCompleted(self, conn, new):
self._notifyNodeInformation(conn)
pt = self.app.pt
conn.notify(Packets.SendPartitionTable(pt.getID(), pt.getRowList()))
def connectionLost(self, conn, new_state):
node = self.app.nm.getByUUID(conn.getUUID())
app = self.app
node = app.nm.getByUUID(conn.getUUID())
if node is None:
return # for example, when a storage is removed by an admin
assert node.isStorage(), node
logging.info('storage node lost')
if new_state != NodeStates.BROKEN:
new_state = DISCONNECTED_STATE_DICT.get(node.getType(),
NodeStates.DOWN)
......@@ -117,10 +131,13 @@ class BaseServiceHandler(MasterHandler):
# was in pending state, so drop it from the node manager to forget
# it and do not set in running state when it comes back
logging.info('drop a pending node from the node manager')
self.app.nm.remove(node)
self.app.broadcastNodesInformation([node])
# clean node related data in specialized handlers
self.nodeLost(conn, node)
app.nm.remove(node)
app.broadcastNodesInformation([node])
if app.truncate_tid:
raise StoppedOperation
app.broadcastPartitionChanges(app.pt.outdate(node))
if not app.pt.operational():
raise StoppedOperation
def notifyReady(self, conn):
self.app.setStorageReady(conn.getUUID())
......
......@@ -19,6 +19,7 @@ import random
from . import MasterHandler
from ..app import StateChangedException
from neo.lib import logging
from neo.lib.exception import StoppedOperation
from neo.lib.pt import PartitionTableException
from neo.lib.protocol import ClusterStates, Errors, \
NodeStates, NodeTypes, Packets, ProtocolError, uuid_str
......@@ -159,6 +160,13 @@ class AdministrationHandler(MasterHandler):
map(app.nm.getByUUID, uuid_list)))
conn.answer(Errors.Ack(''))
def truncate(self, conn, tid):
app = self.app
if app.cluster_state != ClusterStates.RUNNING:
raise ProtocolError('Can not truncate in this state')
conn.answer(Errors.Ack(''))
raise StoppedOperation(tid)
def checkReplicas(self, conn, partition_dict, min_tid, max_tid):
app = self.app
pt = app.pt
......
......@@ -20,9 +20,6 @@ from . import MasterHandler
class ClientServiceHandler(MasterHandler):
""" Handler dedicated to client during service state """
def connectionCompleted(self, conn):
pass
def connectionLost(self, conn, new_state):
# cancel its transactions and forgot the node
app = self.app
......@@ -59,9 +56,10 @@ class ClientServiceHandler(MasterHandler):
pt = app.pt
# Collect partitions related to this transaction.
lock_oid_list = oid_list + checked_list
partition_set = set(map(pt.getPartition, lock_oid_list))
partition_set.add(pt.getPartition(ttid))
getPartition = pt.getPartition
partition_set = set(map(getPartition, oid_list))
partition_set.update(map(getPartition, checked_list))
partition_set.add(getPartition(ttid))
# Collect the UUIDs of nodes related to this transaction.
uuid_list = filter(app.isStorageReady, {cell.getUUID()
......@@ -85,7 +83,6 @@ class ClientServiceHandler(MasterHandler):
{x.getUUID() for x in identified_node_list},
conn.getPeerId(),
),
lock_oid_list,
)
for node in identified_node_list:
node.ask(p, timeout=60)
......
......@@ -26,7 +26,7 @@ class IdentificationHandler(MasterHandler):
**kw)
handler = conn.getHandler()
assert not isinstance(handler, IdentificationHandler), handler
handler.connectionCompleted(conn)
handler.connectionCompleted(conn, True)
def _setupNode(self, conn, node_type, uuid, address, node):
app = self.app
......@@ -72,7 +72,7 @@ class IdentificationHandler(MasterHandler):
node.setState(state)
node.setConnection(conn)
conn.setHandler(handler)
app.broadcastNodesInformation([node])
app.broadcastNodesInformation([node], node)
return uuid
class SecondaryIdentificationHandler(MasterHandler):
......
......@@ -16,7 +16,7 @@
from neo.lib import logging
from neo.lib.protocol import CellStates, ClusterStates, Packets, ProtocolError
from neo.lib.exception import OperationFailure
from neo.lib.exception import StoppedOperation
from neo.lib.pt import PartitionTableException
from . import BaseServiceHandler
......@@ -24,25 +24,27 @@ from . import BaseServiceHandler
class StorageServiceHandler(BaseServiceHandler):
""" Handler dedicated to storages during service state """
def connectionCompleted(self, conn):
# TODO: unit test
def connectionCompleted(self, conn, new):
app = self.app
uuid = conn.getUUID()
node = app.nm.getByUUID(uuid)
app.setStorageNotReady(uuid)
if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new)
# XXX: what other values could happen ?
if node.isRunning():
conn.notify(Packets.StartOperation(bool(app.backup_tid)))
def nodeLost(self, conn, node):
logging.info('storage node lost')
assert not node.isRunning(), node.getState()
def connectionLost(self, conn, new_state):
app = self.app
app.broadcastPartitionChanges(app.pt.outdate(node))
if not app.pt.operational():
raise OperationFailure, 'cannot continue operation'
node = app.nm.getByUUID(conn.getUUID())
super(StorageServiceHandler, self).connectionLost(conn, new_state)
app.tm.forget(conn.getUUID())
if app.getClusterState() == ClusterStates.BACKINGUP:
if (app.getClusterState() == ClusterStates.BACKINGUP
# Also check if we're exiting, because backup_app is not usable
# in this case. Maybe cluster state should be set to something
# else, like STOPPING, during cleanup (__del__/close).
and app.listening_conn):
app.backup_app.nodeLost(node)
if app.packing is not None:
self.answerPack(conn, False)
......@@ -74,7 +76,7 @@ class StorageServiceHandler(BaseServiceHandler):
CellStates.CORRUPTED))
self.app.broadcastPartitionChanges(change_list)
if not self.app.pt.operational():
raise OperationFailure('cannot continue operation')
raise StoppedOperation
def notifyReplicationDone(self, conn, offset, tid):
app = self.app
......
......@@ -299,15 +299,19 @@ class PartitionTable(neo.lib.pt.PartitionTable):
yield offset, cell
break
def getReadableCellNodeSet(self):
def getOperationalNodeSet(self):
"""
Return a set of all nodes which are part of at least one UP TO DATE
partition.
partition. An empty sequence is returned if these nodes are not enough
to become operational.
"""
return {cell.getNode()
for row in self.partition_list
for cell in row
if cell.isReadable()}
node_set = set()
for row in self.partition_list:
if not any(cell.isReadable() and cell.getNode().isPending()
for cell in row):
return () # not operational
node_set.update(cell.getNode() for cell in row if cell.isReadable())
return node_set
def clearReplicating(self):
for row in self.partition_list:
......
......@@ -15,9 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging
from neo.lib.util import dump
from neo.lib.protocol import Packets, ProtocolError, ClusterStates, NodeStates
from neo.lib.protocol import ZERO_OID
from .handlers import MasterHandler
......@@ -29,7 +27,9 @@ class RecoveryManager(MasterHandler):
def __init__(self, app):
# The target node's uuid to request next.
self.target_ptid = None
self.ask_pt = []
self.backup_tid_dict = {}
self.truncate_dict = {}
def getHandler(self):
return self
......@@ -51,7 +51,7 @@ class RecoveryManager(MasterHandler):
app = self.app
pt = app.pt
app.changeClusterState(ClusterStates.RECOVERING)
pt.setID(None)
pt.clear()
# collect the last partition table available
poll = app.em.poll
......@@ -60,11 +60,15 @@ class RecoveryManager(MasterHandler):
if pt.filled():
# A partition table exists, we are starting an existing
# cluster.
node_list = pt.getReadableCellNodeSet()
node_list = pt.getOperationalNodeSet()
if app._startup_allowed:
node_list = [node for node in node_list if node.isPending()]
elif not all(node.isPending() for node in node_list):
continue
elif node_list:
# we want all nodes to be there if we're going to truncate
if app.truncate_tid:
node_list = pt.getNodeSet()
if not all(node.isPending() for node in node_list):
continue
elif app._startup_allowed or app.autostart:
# No partition table and admin allowed startup, we are
# creating a new cluster out of all pending nodes.
......@@ -76,6 +80,17 @@ class RecoveryManager(MasterHandler):
if node_list and not any(node.getConnection().isPending()
for node in node_list):
if pt.filled():
if app.truncate_tid:
node_list = app.nm.getIdentifiedList(pool_set={uuid
for uuid, tid in self.truncate_dict.iteritems()
if not tid or app.truncate_tid < tid})
if node_list:
truncate = Packets.Truncate(app.truncate_tid)
for node in node_list:
conn = node.getConnection()
conn.notify(truncate)
self.connectionCompleted(conn, False)
continue
node_list = pt.getConnectedNodeList()
break
......@@ -88,64 +103,81 @@ class RecoveryManager(MasterHandler):
if pt.getID() is None:
logging.info('creating a new partition table')
# reset ID generators & build a new partition table with running nodes
app.tm.setLastOID(ZERO_OID)
pt.make(node_list)
self._broadcastPartitionTable(pt.getID(), pt.getRowList())
elif app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict)
app.backup_tid = pt.getBackupTid()
app.setLastTransaction(app.tm.getLastTID())
logging.debug('cluster starts with loid=%s and this partition table :',
dump(app.tm.getLastOID()))
self._notifyAdmins(Packets.SendPartitionTable(
pt.getID(), pt.getRowList()))
else:
cell_list = pt.outdate()
if cell_list:
self._notifyAdmins(Packets.NotifyPartitionChanges(
pt.setNextID(), cell_list))
if app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict)
app.backup_tid = pt.getBackupTid()
logging.debug('cluster starts this partition table:')
pt.log()
def connectionLost(self, conn, new_state):
node = self.app.nm.getByUUID(conn.getUUID())
assert node is not None
uuid = conn.getUUID()
self.backup_tid_dict.pop(uuid, None)
self.truncate_dict.pop(uuid, None)
node = self.app.nm.getByUUID(uuid)
try:
i = self.ask_pt.index(uuid)
except ValueError:
pass
else:
del self.ask_pt[i]
if not i:
if self.ask_pt:
self.app.nm.getByUUID(self.ask_pt[0]) \
.ask(Packets.AskPartitionTable())
else:
logging.warning("Waiting for %r to come back."
" No other node has version %s of the partition table.",
node, self.target_ptid)
if node.getState() == new_state:
return
node.setState(new_state)
# broadcast to all so that admin nodes gets informed
self.app.broadcastNodesInformation([node])
def connectionCompleted(self, conn):
def connectionCompleted(self, conn, new):
# ask the last IDs to perform the recovery
conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
# Get max values.
if loid is not None:
self.app.tm.setLastOID(loid)
if ltid is not None:
self.app.tm.setLastTID(ltid)
if lptid > self.target_ptid:
# something newer
self.target_ptid = lptid
conn.ask(Packets.AskPartitionTable())
self.backup_tid_dict[conn.getUUID()] = backup_tid
conn.ask(Packets.AskRecovery())
def answerRecovery(self, conn, ptid, backup_tid, truncate_tid):
uuid = conn.getUUID()
if self.target_ptid <= ptid:
# Maybe a newer partition table.
if self.target_ptid == ptid and self.ask_pt:
# Another node is already asked.
self.ask_pt.append(uuid)
elif self.target_ptid < ptid or self.ask_pt is not ():
# No node asked yet for the newest partition table.
self.target_ptid = ptid
self.ask_pt = [uuid]
conn.ask(Packets.AskPartitionTable())
self.backup_tid_dict[uuid] = backup_tid
self.truncate_dict[uuid] = truncate_tid
def answerPartitionTable(self, conn, ptid, row_list):
if ptid != self.target_ptid:
# If this is not from a target node, ignore it.
logging.warn('Got %s while waiting %s', dump(ptid),
dump(self.target_ptid))
else:
self._broadcastPartitionTable(ptid, row_list)
self.app.backup_tid = self.backup_tid_dict[conn.getUUID()]
def _broadcastPartitionTable(self, ptid, row_list):
try:
new_nodes = self.app.pt.load(ptid, row_list, self.app.nm)
except IndexError:
raise ProtocolError('Invalid offset')
else:
notification = Packets.NotifyNodeInformation(new_nodes)
ptid = self.app.pt.getID()
row_list = self.app.pt.getRowList()
partition_table = Packets.SendPartitionTable(ptid, row_list)
# notify the admin nodes
for node in self.app.nm.getAdminList(only_identified=True):
node.notify(notification)
node.notify(partition_table)
# If this is not from a target node, ignore it.
if ptid == self.target_ptid:
app = self.app
try:
new_nodes = app.pt.load(ptid, row_list, app.nm)
except IndexError:
raise ProtocolError('Invalid offset')
self._notifyAdmins(Packets.NotifyNodeInformation(new_nodes),
Packets.SendPartitionTable(ptid, row_list))
self.ask_pt = ()
uuid = conn.getUUID()
app.backup_tid = self.backup_tid_dict[uuid]
app.truncate_tid = self.truncate_dict[uuid]
def _notifyAdmins(self, *packets):
for node in self.app.nm.getAdminList(only_identified=True):
for packet in packets:
node.notify(packet)
......@@ -17,7 +17,7 @@
from time import time
from struct import pack, unpack
from neo.lib import logging
from neo.lib.protocol import ProtocolError, uuid_str, ZERO_TID
from neo.lib.protocol import ProtocolError, uuid_str, ZERO_OID, ZERO_TID
from neo.lib.util import dump, u64, addTID, tidFromTime
class DelayedError(Exception):
......@@ -155,15 +155,18 @@ class TransactionManager(object):
"""
Manage current transactions
"""
_last_tid = ZERO_TID
def __init__(self, on_commit):
self._on_commit = on_commit
self.reset()
def reset(self):
# ttid -> transaction
self._ttid_dict = {}
# node -> transactions mapping
self._node_dict = {}
self._last_oid = None
self._on_commit = on_commit
self._last_oid = ZERO_OID
self._last_tid = ZERO_TID
# queue filled with ttids pointing to transactions with increasing tids
self._queue = []
......@@ -182,8 +185,6 @@ class TransactionManager(object):
def getNextOIDList(self, num_oids):
""" Generate a new OID list """
if self._last_oid is None:
raise RuntimeError, 'I do not know the last OID'
oid = unpack('!Q', self._last_oid)[0] + 1
oid_list = [pack('!Q', oid + i) for i in xrange(num_oids)]
self._last_oid = oid_list[-1]
......@@ -249,14 +250,6 @@ class TransactionManager(object):
if self._last_tid < tid:
self._last_tid = tid
def reset(self):
"""
Discard all manager content
This doesn't reset the last TID.
"""
self._ttid_dict = {}
self._node_dict = {}
def hasPending(self):
"""
Returns True if some transactions are pending
......@@ -359,7 +352,12 @@ class TransactionManager(object):
self._unlockPending()
def _unlockPending(self):
# unlock pending transactions
"""Serialize transaction unlocks
This should rarely delay unlocks since the time needed to lock a
transaction is roughly constant. The most common case where reordering
is required is when some storages are already busy with other tasks.
"""
queue = self._queue
pop = queue.pop
insert = queue.insert
......
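A tiny self-contained sketch of the serialization idea stated in the docstring above (hypothetical data, not the real queue handling): transactions are queued by increasing tid and may only be unlocked in that order, so one transaction still being locked delays the ones behind it.

    queue = ['t1', 't2', 't3']         # ttids ordered by increasing tid
    still_locking = {'t2'}             # 't2' is not fully locked yet
    unlocked = []
    while queue and queue[0] not in still_locking:
        unlocked.append(queue.pop(0))  # unlocks are issued in tid order only
    assert unlocked == ['t1'] and queue == ['t2', 't3']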
......@@ -14,48 +14,29 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import defaultdict
from neo.lib import logging
from neo.lib.util import dump
from neo.lib.protocol import ClusterStates, Packets, NodeStates
from .handlers import BaseServiceHandler
class VerificationFailure(Exception):
"""
Exception raised each time the cluster integrity check fails.
- A required storage node is missing
- A transaction or an object is missing on a node
"""
pass
class VerificationManager(BaseServiceHandler):
"""
Manager for verification step of a NEO cluster:
- Wait for at least one available storage per partition
- Check if all expected content is present
"""
def __init__(self, app):
self._oid_set = set()
self._tid_set = set()
self._locked_dict = {}
self._voted_dict = defaultdict(set)
self._uuid_set = set()
self._object_present = False
def _askStorageNodesAndWait(self, packet, node_list):
poll = self.app.em.poll
operational = self.app.pt.operational
uuid_set = self._uuid_set
uuid_set.clear()
for node in node_list:
uuid_set.add(node.getUUID())
node.ask(packet)
while True:
while uuid_set:
poll(1)
if not operational():
raise VerificationFailure
if not uuid_set:
break
def getHandler(self):
return self
......@@ -76,135 +57,80 @@ class VerificationManager(BaseServiceHandler):
return state, self
def run(self):
self.app.changeClusterState(ClusterStates.VERIFYING)
while True:
try:
self.verifyData()
except VerificationFailure:
continue
break
# At this stage, all non-working nodes are out-of-date.
self.app.broadcastPartitionChanges(self.app.pt.outdate())
app = self.app
app.changeClusterState(ClusterStates.VERIFYING)
app.tm.reset()
if not app.backup_tid:
self.verifyData()
# This is where storages truncate if requested:
# - we make sure all nodes are running with a truncate_tid value saved
# - there's no unfinished data
# - just before they return the last tid/oid
self._askStorageNodesAndWait(Packets.AskLastIDs(),
[x for x in app.nm.getIdentifiedList() if x.isStorage()])
app.setLastTransaction(app.tm.getLastTID())
# Just to not return meaningless information in AnswerRecovery.
app.truncate_tid = None
def verifyData(self):
"""Verify the data in storage nodes and clean them up, if necessary."""
app = self.app
# wait for any missing node
logging.debug('waiting for the cluster to be operational')
while not app.pt.operational():
app.em.poll(1)
if app.backup_tid:
return
logging.info('start to verify data')
getIdentifiedList = app.nm.getIdentifiedList
# Gather all unfinished transactions.
self._askStorageNodesAndWait(Packets.AskUnfinishedTransactions(),
# Gather all transactions that may have been partially finished.
self._askStorageNodesAndWait(Packets.AskLockedTransactions(),
[x for x in getIdentifiedList() if x.isStorage()])
# Gather OIDs for each unfinished TID, and verify whether the
# transaction can be finished or must be aborted. This could be
# in parallel in theory, but not so easy. Thus do it one-by-one
# at the moment.
for tid in self._tid_set:
uuid_set = self.verifyTransaction(tid)
if uuid_set is None:
packet = Packets.DeleteTransaction(tid, self._oid_set or [])
# Make sure that no node has this transaction.
for node in getIdentifiedList():
if node.isStorage():
node.notify(packet)
# Some nodes may have already unlocked these transactions and
# _locked_dict is incomplete, but we can ask them the final tid.
for ttid, voted_set in self._voted_dict.iteritems():
if ttid in self._locked_dict:
continue
partition = app.pt.getPartition(ttid)
for node in getIdentifiedList(pool_set={cell.getUUID()
# If an outdated cell had unlocked ttid, then either
# it is already in _locked_dict or a readable cell also
# unlocked it.
for cell in app.pt.getCellList(partition, readable=True)
} - voted_set):
self._askStorageNodesAndWait(Packets.AskFinalTID(ttid), (node,))
if self._tid is not None:
self._locked_dict[ttid] = self._tid
break
else:
if app.getLastTransaction() < tid: # XXX: refactoring needed
app.setLastTransaction(tid)
app.tm.setLastTID(tid)
packet = Packets.CommitTransaction(tid)
# Transaction not locked. No need to tell nodes to delete it,
# since they drop any unfinished data just before being
# operational.
pass
# Finish all transactions for which we know that tpc_finish was called
# but not fully processed. This may include replicas with transactions
# that were not even locked.
for ttid, tid in self._locked_dict.iteritems():
uuid_set = self._voted_dict.get(ttid)
if uuid_set:
packet = Packets.ValidateTransaction(ttid, tid)
for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet)
self._oid_set = set()
# If possible, send the packets now.
app.em.poll(0)
def verifyTransaction(self, tid):
nm = self.app.nm
uuid_set = set()
# Determine to which nodes I should ask.
partition = self.app.pt.getPartition(tid)
uuid_list = [cell.getUUID() for cell \
in self.app.pt.getCellList(partition, readable=True)]
if len(uuid_list) == 0:
raise VerificationFailure
uuid_set.update(uuid_list)
# Gather OIDs.
node_list = self.app.nm.getIdentifiedList(pool_set=uuid_list)
if len(node_list) == 0:
raise VerificationFailure
self._askStorageNodesAndWait(Packets.AskTransactionInformation(tid),
node_list)
if self._oid_set is None or len(self._oid_set) == 0:
# Not commitable.
return None
# Verify that all objects are present.
for oid in self._oid_set:
partition = self.app.pt.getPartition(oid)
object_uuid_list = [cell.getUUID() for cell \
in self.app.pt.getCellList(partition, readable=True)]
if len(object_uuid_list) == 0:
raise VerificationFailure
uuid_set.update(object_uuid_list)
self._object_present = True
self._askStorageNodesAndWait(Packets.AskObjectPresent(oid, tid),
nm.getIdentifiedList(pool_set=object_uuid_list))
if not self._object_present:
# Not commitable.
return None
return uuid_set
def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
logging.info('got unfinished transactions %s from %r',
map(dump, tid_list), conn)
self._uuid_set.remove(conn.getUUID())
self._tid_set.update(tid_list)
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
self._uuid_set.remove(conn.getUUID())
oid_set = set(oid_list)
if self._oid_set is None:
# Someone does not agree.
pass
elif len(self._oid_set) == 0:
# This is the first answer.
self._oid_set.update(oid_set)
elif self._oid_set != oid_set:
raise ValueError, "Inconsistent transaction %s" % \
(dump(tid, ))
def tidNotFound(self, conn, message):
logging.info('TID not found: %s', message)
def answerLastIDs(self, conn, loid, ltid):
self._uuid_set.remove(conn.getUUID())
self._oid_set = None
def answerObjectPresent(self, conn, oid, tid):
logging.info('object %s:%s found', dump(oid), dump(tid))
tm = self.app.tm
tm.setLastOID(loid)
tm.setLastTID(ltid)
def answerLockedTransactions(self, conn, tid_dict):
uuid = conn.getUUID()
self._uuid_set.remove(uuid)
for ttid, tid in tid_dict.iteritems():
if tid:
self._locked_dict[ttid] = tid
self._voted_dict[ttid].add(uuid)
def answerFinalTID(self, conn, tid):
self._uuid_set.remove(conn.getUUID())
self._tid = tid
def oidNotFound(self, conn, message):
logging.info('OID not found: %s', message)
self._uuid_set.remove(conn.getUUID())
self._object_present = False
def connectionCompleted(self, conn):
pass
def nodeLost(self, conn, node):
if not self.app.pt.operational():
raise VerificationFailure, 'cannot continue verification'
def connectionLost(self, conn, new_state):
self._uuid_set.discard(conn.getUUID())
super(VerificationManager, self).connectionLost(conn, new_state)
......@@ -37,6 +37,7 @@ action_dict = {
'tweak': 'tweakPartitionTable',
'drop': 'dropNode',
'kill': 'killNode',
'truncate': 'truncate',
}
uuid_int = (lambda ns: lambda uuid:
......@@ -85,11 +86,14 @@ class TerminalNeoCTL(object):
Get last ids.
"""
assert not params
r = self.neoctl.getLastIds()
if r[3]:
return "last_tid = 0x%x" % u64(self.neoctl.getLastTransaction())
return "last_oid = 0x%x\nlast_tid = 0x%x\nlast_ptid = %u" % (
u64(r[0]), u64(r[1]), r[2])
ptid, backup_tid, truncate_tid = self.neoctl.getRecovery()
if backup_tid:
ltid = self.neoctl.getLastTransaction()
r = "backup_tid = 0x%x" % u64(backup_tid)
else:
loid, ltid = self.neoctl.getLastIds()
r = "last_oid = 0x%x" % u64(loid)
return r + "\nlast_tid = 0x%x\nlast_ptid = %u" % (u64(ltid), ptid)
def getPartitionRowList(self, params):
"""
......@@ -193,6 +197,19 @@ class TerminalNeoCTL(object):
"""
return uuid_str(self.neoctl.getPrimary())
def truncate(self, params):
"""
Truncate the database at the given tid.
The cluster must be in RUNNING state, without any pending transaction.
This causes the cluster to go back to the RECOVERING state, waiting for
all nodes to be pending (do not use the 'start' command unless you are
sure the missing nodes don't need to be truncated).
Parameters: tid
"""
self.neoctl.truncate(self.asTID(*params))
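A hypothetical usage sketch (the 0x1234 value and the 'neoctl' instance are made up): the tid argument is packed to the usual 8-byte big-endian form, as neo.lib.util.p64 does, before NeoCTL.truncate sends Packets.Truncate to the admin node, which forwards it to the primary master.

    from struct import pack
    tid = pack('!Q', 0x1234)   # 8-byte big-endian tid (p64-style), made-up value
    # neoctl.truncate(tid)     # 'neoctl' assumed to be a connected NeoCTL instance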
def checkReplicas(self, params):
"""
Test whether partitions have corrupted metadata
......
......@@ -61,3 +61,4 @@ class CommandEventHandler(EventHandler):
answerPrimary = __answer(Packets.AnswerPrimary)
answerLastIDs = __answer(Packets.AnswerLastIDs)
answerLastTransaction = __answer(Packets.AnswerLastTransaction)
answerRecovery = __answer(Packets.AnswerRecovery)
......@@ -120,6 +120,12 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response)
return response[1]
def getRecovery(self):
response = self.__ask(Packets.AskRecovery())
if response[0] != Packets.AnswerRecovery:
raise RuntimeError(response)
return response[1:]
def getNodeList(self, node_type=None):
"""
Get a list of nodes, filtering with given type.
......@@ -163,6 +169,12 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response)
return response[1]
def truncate(self, tid):
response = self.__ask(Packets.Truncate(tid))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
raise RuntimeError(response)
return response[2]
def checkReplicas(self, *args):
response = self.__ask(Packets.CheckReplicas(*args))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
......
......@@ -53,7 +53,6 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions',
# client application
......
......@@ -23,14 +23,14 @@ from neo.lib.protocol import uuid_str, \
CellStates, ClusterStates, NodeTypes, Packets
from neo.lib.node import NodeManager
from neo.lib.connection import ListeningConnection
from neo.lib.exception import OperationFailure, PrimaryFailure
from neo.lib.exception import StoppedOperation, PrimaryFailure
from neo.lib.pt import PartitionTable
from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager
from .checker import Checker
from .database import buildDatabaseManager
from .exception import AlreadyPendingError
from .handlers import identification, verification, initialization
from .handlers import identification, initialization
from .handlers import master, hidden
from .replicator import Replicator
from .transactions import TransactionManager
......@@ -193,14 +193,11 @@ class Application(BaseApplication):
self.event_queue = deque()
self.event_queue_dict = {}
try:
self.verifyData()
self.initialize()
self.doOperation()
raise RuntimeError, 'should not reach here'
except OperationFailure, msg:
except StoppedOperation, msg:
logging.error('operation stopped: %s', msg)
if self.cluster_state == ClusterStates.STOPPING_BACKUP:
self.dm.setBackupTID(None)
except PrimaryFailure, msg:
logging.error('primary master is down: %s', msg)
finally:
......@@ -247,30 +244,11 @@ class Application(BaseApplication):
self.pt = PartitionTable(num_partitions, num_replicas)
self.loadPartitionTable()
def verifyData(self):
"""Verify data under the control by a primary master node.
Connections from client nodes may not be accepted at this stage."""
logging.info('verifying data')
handler = verification.VerificationHandler(self)
self.master_conn.setHandler(handler)
_poll = self._poll
while not self.operational:
_poll()
def initialize(self):
""" Retreive partition table and node informations from the primary """
logging.debug('initializing...')
_poll = self._poll
handler = initialization.InitializationHandler(self)
self.master_conn.setHandler(handler)
# ask node list and partition table
self.pt.clear()
self.master_conn.ask(Packets.AskNodeInformation())
self.master_conn.ask(Packets.AskPartitionTable())
while self.master_conn.isPending():
self.master_conn.setHandler(initialization.InitializationHandler(self))
while not self.operational:
_poll()
self.ready = True
self.replicator.populate()
......
......@@ -297,8 +297,9 @@ class ImporterDatabaseManager(DatabaseManager):
self.db = buildDatabaseManager(main['adapter'],
(main['database'], main.get('engine'), main['wait']))
for x in """query erase getConfiguration _setConfiguration
getPartitionTable changePartitionTable getUnfinishedTIDList
dropUnfinishedData storeTransaction finishTransaction
getPartitionTable changePartitionTable
getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction
storeData _pruneData
""".split():
setattr(self, x, getattr(self.db, x))
......@@ -421,7 +422,7 @@ class ImporterDatabaseManager(DatabaseManager):
logging.warning("All data are imported. You should change"
" your configuration to use the native backend and restart.")
self._import = None
for x in """getObject objectPresent getReplicationTIDList
for x in """getObject getReplicationTIDList
""".split():
setattr(self, x, getattr(self.db, x))
......@@ -434,23 +435,11 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid
def getLastIDs(self, all=True):
tid, _, _, oid = self.db.getLastIDs(all)
def getLastIDs(self):
tid, _, _, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)), None, None,
max(oid, util.p64(self.zodb_loid)))
def objectPresent(self, oid, tid, all=True):
r = self.db.objectPresent(oid, tid, all)
if not r:
u_oid = util.u64(oid)
u_tid = util.u64(tid)
if self.inZodb(u_oid, u_tid):
zodb, oid = self.zodbFromOid(u_oid)
try:
return zodb.loadSerial(util.p64(oid), tid)
except POSKeyError:
pass
def getObject(self, oid, tid=None, before_tid=None):
u64 = util.u64
u_oid = u64(oid)
......@@ -511,6 +500,16 @@ class ImporterDatabaseManager(DatabaseManager):
else:
return self.db.getTransaction(tid, all)
def getFinalTID(self, ttid):
if u64(ttid) <= self.zodb_ltid and self._import:
raise NotImplementedError
return self.db.getFinalTID(ttid)
def deleteTransaction(self, tid):
if u64(tid) <= self.zodb_ltid and self._import:
raise NotImplementedError
self.db.deleteTransaction(tid)
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
p64 = util.p64
tid = p64(self.zodb_tid)
......
......@@ -17,6 +17,7 @@
from collections import defaultdict
from functools import wraps
from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import ZERO_TID, BackendNotImplemented
def lazymethod(func):
......@@ -94,6 +95,22 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def nonempty(self, table):
"""Check whether table is empty or return None if it does not exist"""
raise NotImplementedError
def _checkNoUnfinishedTransactions(self, *hint):
if self.nonempty('ttrans') or self.nonempty('tobj'):
raise DatabaseFailure(
"The database can not be upgraded because you have unfinished"
" transactions. Use an older version of NEO to verify them.")
def _getVersion(self):
version = int(self.getConfiguration("version") or 0)
if self.VERSION < version:
raise DatabaseFailure("The database can not be downgraded.")
return version
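To illustrate the upgrade mechanism these helpers enable, a self-contained sketch with an imaginary in-memory backend (not the real DatabaseManager API): a fresh database is created at the current VERSION, an existing one is migrated step by step, and a newer schema is refused.

    class UpgradeSketch(object):
        VERSION = 1

        def __init__(self, config=None):
            # 'config' stands in for the persistent 'config' table;
            # None means the database does not exist yet.
            self.config = config

        def setup(self):
            if self.config is None:
                self.config = {}    # fresh database: create everything
            else:
                version = int(self.config.get('version') or 0)
                if self.VERSION < version:
                    raise Exception("The database can not be downgraded.")
                if version < 1:
                    # 0 -> 1 migration steps go here (the MySQL backend
                    # refuses if there are unfinished transactions, then
                    # drops the old 'ttrans' table and recreates it).
                    pass
            self.config['version'] = self.VERSION

    old = UpgradeSketch({})            # pre-versioned database (version 0)
    old.setup()
    assert old.config['version'] == 1  # automatically upgraded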
def doOperation(self, app):
pass
......@@ -194,10 +211,18 @@ class DatabaseManager(object):
def getBackupTID(self):
return util.bin(self.getConfiguration('backup_tid'))
def setBackupTID(self, backup_tid):
tid = util.dump(backup_tid)
def _setBackupTID(self, tid):
tid = util.dump(tid)
logging.debug('backup_tid = %s', tid)
return self.setConfiguration('backup_tid', tid)
return self._setConfiguration('backup_tid', tid)
def getTruncateTID(self):
return util.bin(self.getConfiguration('truncate_tid'))
def _setTruncateTID(self, tid):
tid = util.dump(tid)
logging.debug('truncate_tid = %s', tid)
return self._setConfiguration('truncate_tid', tid)
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
......@@ -222,10 +247,10 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def _getLastIDs(self, all=True):
def _getLastIDs(self):
raise NotImplementedError
def getLastIDs(self, all=True):
def getLastIDs(self):
trans, obj, oid = self._getLastIDs()
if trans:
tid = max(trans.itervalues())
......@@ -241,16 +266,16 @@ class DatabaseManager(object):
trans = obj = {}
return tid, trans, obj, oid
def getUnfinishedTIDList(self):
"""Return a list of unfinished transaction's IDs."""
def _getUnfinishedTIDDict(self):
raise NotImplementedError
def objectPresent(self, oid, tid, all = True):
"""Return true iff an object specified by a given pair of an
object ID and a transaction ID is present in a database.
Otherwise, return false. If all is true, the object must be
searched from unfinished transactions as well."""
raise NotImplementedError
def getUnfinishedTIDDict(self):
trans, obj = self._getUnfinishedTIDDict()
obj = dict.fromkeys(obj)
obj.update(trans)
p64 = util.p64
return {p64(ttid): None if tid is None else p64(tid)
for ttid, tid in obj.iteritems()}
@fallback
def getLastObjectTID(self, oid):
......@@ -478,14 +503,18 @@ class DatabaseManager(object):
data_tid = p64(data_tid)
return p64(current_tid), data_tid, is_current
def finishTransaction(self, tid):
"""Finish a transaction specified by a given ID, by moving
temporarily data to a finished area."""
def lockTransaction(self, tid, ttid):
"""Mark voted transaction 'ttid' as committed with given 'tid'"""
raise NotImplementedError
def unlockTransaction(self, tid, ttid):
"""Finalize a transaction by moving data to a finished area."""
raise NotImplementedError
def deleteTransaction(self, tid, oid_list=()):
"""Delete a transaction and its content specified by a given ID and
an oid list"""
def abortTransaction(self, ttid):
raise NotImplementedError
def deleteTransaction(self, tid):
raise NotImplementedError
def deleteObject(self, oid, serial=None):
......@@ -498,13 +527,14 @@ class DatabaseManager(object):
and max_tid (included)"""
raise NotImplementedError
def truncate(self, tid):
assert tid not in (None, ZERO_TID), tid
assert self.getBackupTID()
self.setBackupTID(None) # XXX
for partition in xrange(self.getNumPartitions()):
self._deleteRange(partition, tid)
self.commit()
def truncate(self):
tid = self.getTruncateTID()
if tid:
assert tid != ZERO_TID, tid
for partition in xrange(self.getNumPartitions()):
self._deleteRange(partition, tid)
self._setTruncateTID(None)
self.commit()
def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information,
......
......@@ -16,9 +16,10 @@
from binascii import a2b_hex
import MySQLdb
from MySQLdb import DataError, IntegrityError, OperationalError
from MySQLdb import DataError, IntegrityError, \
OperationalError, ProgrammingError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
from array import array
from hashlib import sha1
import os
......@@ -42,6 +43,7 @@ def getPrintableQuery(query, max=70):
class MySQLDatabaseManager(DatabaseManager):
"""This class manages a database on MySQL."""
VERSION = 1
ENGINES = "InnoDB", "TokuDB"
_engine = ENGINES[0] # default engine
......@@ -144,16 +146,33 @@ class MySQLDatabaseManager(DatabaseManager):
self.query("DROP TABLE IF EXISTS"
" config, pt, trans, obj, data, bigdata, ttrans, tobj")
def nonempty(self, table):
try:
return bool(self.query("SELECT 1 FROM %s LIMIT 1" % table))
except ProgrammingError, (code, _):
if code != NO_SUCH_TABLE:
raise
def _setup(self):
self._config.clear()
q = self.query
p = engine = self._engine
# The table "config" stores configuration parameters which affect the
# persistent data.
q("""CREATE TABLE IF NOT EXISTS config (
name VARBINARY(255) NOT NULL PRIMARY KEY,
value VARBINARY(255) NULL
) ENGINE=""" + engine)
if self.nonempty("config") is None:
# The table "config" stores configuration
# parameters which affect the persistent data.
q("""CREATE TABLE config (
name VARBINARY(255) NOT NULL PRIMARY KEY,
value VARBINARY(255) NULL
) ENGINE=""" + engine)
else:
# Automatic migration.
version = self._getVersion()
if version < 1:
self._checkNoUnfinishedTransactions()
q("DROP TABLE IF EXISTS ttrans")
self._setConfiguration("version", self.VERSION)
# The table "pt" stores a partition table.
q("""CREATE TABLE IF NOT EXISTS pt (
......@@ -214,7 +233,7 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans (
`partition` SMALLINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED,
packed BOOLEAN NOT NULL,
oids MEDIUMBLOB NOT NULL,
user BLOB NOT NULL,
......@@ -274,7 +293,7 @@ class MySQLDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans"
" WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0]
def _getLastIDs(self, all=True):
def _getLastIDs(self):
p64 = util.p64
q = self.query
trans = {partition: p64(tid)
......@@ -285,29 +304,21 @@ class MySQLDatabaseManager(DatabaseManager):
" FROM obj GROUP BY `partition`")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY `partition`) as t")[0][0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj")[0]
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all = True):
oid = util.u64(oid)
tid = util.u64(tid)
def _getUnfinishedTIDDict(self):
q = self.query
return q("SELECT 1 FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d"
% (self._getPartition(oid), oid, tid)) or all and \
q("SELECT 1 FROM tobj WHERE tid=%d AND oid=%d" % (tid, oid))
return q("SELECT ttid, tid FROM ttrans"), (ttid
for ttid, in q("SELECT DISTINCT tid FROM tobj"))
def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# MariaDB is smart enough to realize that 'ttid' is constant.
r = self.query("SELECT tid FROM trans"
" WHERE `partition`=%s AND tid>=ttid AND ttid=%s LIMIT 1"
% (self._getPartition(ttid), ttid))
if r:
return util.p64(r[0][0])
def getLastObjectTID(self, oid):
oid = util.u64(oid)
......@@ -450,9 +461,9 @@ class MySQLDatabaseManager(DatabaseManager):
oid_list, user, desc, ext, packed, ttid = transaction
partition = self._getPartition(tid)
assert packed in (0, 1)
q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % (
trans_table, partition, tid, packed, e(''.join(oid_list)),
e(user), e(desc), e(ext), u64(ttid)))
q("REPLACE INTO %s VALUES (%s,%s,%s,'%s','%s','%s','%s',%s)" % (
trans_table, partition, 'NULL' if temporary else tid, packed,
e(''.join(oid_list)), e(user), e(desc), e(ext), u64(ttid)))
if temporary:
self.commit()
......@@ -544,40 +555,40 @@ class MySQLDatabaseManager(DatabaseManager):
r = self.query(sql)
return r[0] if r else (None, None)
def finishTransaction(self, tid):
def lockTransaction(self, tid, ttid):
u64 = util.u64
self.query("UPDATE ttrans SET tid=%d WHERE ttid=%d LIMIT 1"
% (u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query
tid = util.u64(tid)
sql = " FROM tobj WHERE tid=%d" % tid
u64 = util.u64
tid = u64(tid)
sql = " FROM tobj WHERE tid=%d" % u64(ttid)
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("INSERT INTO obj SELECT *" + sql)
q("DELETE FROM tobj WHERE tid=%d" % tid)
q("INSERT INTO obj SELECT `partition`, oid, %d, data_id, value_tid %s"
% (tid, sql))
q("DELETE" + sql)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid)
self.releaseData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
def abortTransaction(self, ttid):
ttid = util.u64(ttid)
q = self.query
sql = " FROM tobj WHERE tid=%d" % tid
sql = " FROM tobj WHERE tid=%s" % ttid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.releaseData(data_id_list)
q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
q("""DELETE FROM trans WHERE `partition` = %d AND tid = %d""" %
(getPartition(tid), tid))
# delete from obj using indexes
data_id_list = set(data_id_list)
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d" \
% (getPartition(oid), oid, tid)
data_id_list.update(*q("SELECT data_id" + sql))
q("DELETE" + sql)
data_id_list.discard(None)
self._pruneData(data_id_list)
q("DELETE FROM ttrans WHERE ttid=%s" % ttid)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" %
(self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None):
u64 = util.u64
......
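To keep the two backends side by side, here is a condensed sketch of how the new per-transaction methods above are driven, based on the handler and test changes later in this commit. The wrapper function and its arguments are illustrative; only the dm.* calls come from NEO's DatabaseManager API.

def storage_commit(dm, ttid, tid, objects, txn_info):
    # vote: metadata and object records land in ttrans/tobj with a NULL
    # tid column, since the definitive tid is not known yet.
    dm.storeTransaction(ttid, objects, txn_info)
    # lock: the definitive tid is written to ttrans and committed, so the
    # outcome survives a crash even before the move to the final tables.
    dm.lockTransaction(tid, ttid)
    # unlock: rows move from ttrans/tobj to trans/obj; from this point
    # dm.getFinalTID(ttid) resolves to tid.
    dm.unlockTransaction(tid, ttid)
    # Calling dm.abortTransaction(ttid) instead of the two calls above
    # would discard the temporary rows and release their data.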
......@@ -66,6 +66,8 @@ class SQLiteDatabaseManager(DatabaseManager):
never be used for small requests.
"""
VERSION = 1
def __init__(self, *args, **kw):
super(SQLiteDatabaseManager, self).__init__(*args, **kw)
self._config = {}
......@@ -101,15 +103,32 @@ class SQLiteDatabaseManager(DatabaseManager):
for t in 'config', 'pt', 'trans', 'obj', 'data', 'ttrans', 'tobj':
self.query('DROP TABLE IF EXISTS ' + t)
def nonempty(self, table):
try:
return bool(self.query(
"SELECT 1 FROM %s LIMIT 1" % table).fetchone())
except sqlite3.OperationalError as e:
if not e.args[0].startswith("no such table:"):
raise
def _setup(self):
self._config.clear()
q = self.query
# The table "config" stores configuration parameters which affect the
# persistent data.
q("""CREATE TABLE IF NOT EXISTS config (
name TEXT NOT NULL PRIMARY KEY,
value TEXT)
""")
if self.nonempty("config") is None:
# The table "config" stores configuration
# parameters which affect the persistent data.
q("CREATE TABLE IF NOT EXISTS config ("
" name TEXT NOT NULL PRIMARY KEY,"
" value TEXT)")
else:
# Automatic migration.
version = self._getVersion()
if version < 1:
self._checkNoUnfinishedTransactions()
q("DROP TABLE IF EXISTS ttrans")
self._setConfiguration("version", self.VERSION)
# The table "pt" stores a partition table.
q("""CREATE TABLE IF NOT EXISTS pt (
......@@ -162,7 +181,7 @@ class SQLiteDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
tid INTEGER,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
......@@ -221,7 +240,7 @@ class SQLiteDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?",
(max_tid,)).next()[0]
def _getLastIDs(self, all=True):
def _getLastIDs(self):
p64 = util.p64
q = self.query
trans = {partition: p64(tid)
......@@ -232,30 +251,21 @@ class SQLiteDatabaseManager(DatabaseManager):
" FROM obj GROUP BY partition")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t").next()[0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans").next()[0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj").next()
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid)
tid = util.u64(tid)
def _getUnfinishedTIDDict(self):
q = self.query
return q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?",
(self._getPartition(oid), oid, tid)).fetchone() or all and \
q("SELECT 1 FROM tobj WHERE tid=? AND oid=?",
(tid, oid)).fetchone()
return q("SELECT ttid, tid FROM ttrans"), (ttid
for ttid, in q("SELECT DISTINCT tid FROM tobj"))
def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# As of SQLite 3.8.7.1, 'tid>=ttid' would ignore the index on tid,
# even though ttid is a constant.
for tid, in self.query("SELECT tid FROM trans"
" WHERE partition=? AND tid>=? AND ttid=? LIMIT 1",
(self._getPartition(ttid), ttid, ttid)):
return util.p64(tid)
def getLastObjectTID(self, oid):
oid = util.u64(oid)
......@@ -362,7 +372,8 @@ class SQLiteDatabaseManager(DatabaseManager):
partition = self._getPartition(tid)
assert packed in (0, 1)
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T,
(partition, tid, packed, buffer(''.join(oid_list)),
(partition, None if temporary else tid,
packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext), u64(ttid)))
if temporary:
self.commit()
......@@ -407,40 +418,41 @@ class SQLiteDatabaseManager(DatabaseManager):
r = r.fetchone()
return r or (None, None)
def finishTransaction(self, tid):
args = util.u64(tid),
def lockTransaction(self, tid, ttid):
u64 = util.u64
self.query("UPDATE ttrans SET tid=? WHERE ttid=?",
(u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query
u64 = util.u64
tid = u64(tid)
ttid = u64(ttid)
sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args)
q("DELETE FROM tobj WHERE tid=?", args)
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args)
q("DELETE FROM ttrans WHERE tid=?", args)
data_id_list = [x for x, in q("SELECT data_id" + sql, (ttid,)) if x]
q("INSERT INTO obj SELECT partition, oid, ?, data_id, value_tid" + sql,
(tid, ttid))
q("DELETE" + sql, (ttid,))
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
self.releaseData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
def abortTransaction(self, ttid):
args = util.u64(ttid),
q = self.query
sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x]
self.releaseData(data_id_list)
q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?",
(getPartition(tid), tid))
# delete from obj using indexes
data_id_list = set(data_id_list)
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=? AND oid=? AND tid=?"
args = getPartition(oid), oid, tid
data_id_list.update(*q("SELECT data_id" + sql, args))
q("DELETE" + sql, args)
data_id_list.discard(None)
self._pruneData(data_id_list)
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("DELETE" + sql, args)
q("DELETE FROM ttrans WHERE ttid=?", args)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None):
oid = util.u64(oid)
......
......@@ -16,8 +16,8 @@
from neo.lib import logging
from neo.lib.handler import EventHandler
from neo.lib.exception import PrimaryFailure, OperationFailure
from neo.lib.protocol import uuid_str, NodeStates, NodeTypes
from neo.lib.exception import PrimaryFailure, StoppedOperation
from neo.lib.protocol import uuid_str, NodeStates, NodeTypes, Packets
class BaseMasterHandler(EventHandler):
......@@ -27,7 +27,7 @@ class BaseMasterHandler(EventHandler):
raise PrimaryFailure('connection lost')
def stopOperation(self, conn):
raise OperationFailure('operation stopped')
raise StoppedOperation
def reelectPrimary(self, conn):
raise PrimaryFailure('re-election occurs')
......@@ -48,7 +48,7 @@ class BaseMasterHandler(EventHandler):
erase = state == NodeStates.DOWN
self.app.shutdown(erase=erase)
elif state == NodeStates.HIDDEN:
raise OperationFailure
raise StoppedOperation
elif node_type == NodeTypes.CLIENT and state != NodeStates.RUNNING:
logging.info('Notified of non-running client, abort (%s)',
uuid_str(uuid))
......@@ -56,3 +56,6 @@ class BaseMasterHandler(EventHandler):
def answerUnfinishedTransactions(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(*args, **kw)
def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
......@@ -19,7 +19,7 @@ from neo.lib.handler import EventHandler
from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \
ZERO_HASH, INVALID_PARTITION
from ..transactions import ConflictError, DelayedError
from ..transactions import ConflictError, DelayedError, NotRegisteredError
from ..exception import AlreadyPendingError
import time
......@@ -68,21 +68,17 @@ class ClientOperationHandler(EventHandler):
def abortTransaction(self, conn, ttid):
self.app.tm.abort(ttid)
def askStoreTransaction(self, conn, ttid, user, desc, ext, oid_list):
def askStoreTransaction(self, conn, ttid, *txn_info):
self.app.tm.register(conn.getUUID(), ttid)
self.app.tm.storeTransaction(ttid, oid_list, user, desc, ext, False)
conn.answer(Packets.AnswerStoreTransaction(ttid))
self.app.tm.vote(ttid, txn_info)
conn.answer(Packets.AnswerStoreTransaction())
def askVoteTransaction(self, conn, ttid):
self.app.tm.vote(ttid)
conn.answer(Packets.AnswerVoteTransaction())
def _askStoreObject(self, conn, oid, serial, compression, checksum, data,
data_serial, ttid, unlock, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
return
try:
self.app.tm.storeObject(ttid, serial, oid, compression,
checksum, data, data_serial, unlock)
......@@ -101,6 +97,13 @@ class ClientOperationHandler(EventHandler):
raise_on_duplicate=unlock)
except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
else:
if SLOW_STORE is not None:
duration = time.time() - request_time
......@@ -189,14 +192,6 @@ class ClientOperationHandler(EventHandler):
self._askCheckCurrentSerial(conn, ttid, serial, oid, time.time())
def _askCheckCurrentSerial(self, conn, ttid, serial, oid, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
return
try:
self.app.tm.checkCurrentSerial(ttid, serial, oid)
except ConflictError, err:
......@@ -210,6 +205,13 @@ class ClientOperationHandler(EventHandler):
serial, oid, request_time), key=(oid, ttid))
except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
else:
if SLOW_STORE is not None:
duration = time.time() - request_time
......
......@@ -15,24 +15,23 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import BaseMasterHandler
from neo.lib import logging, protocol
from neo.lib import logging
from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
class InitializationHandler(BaseMasterHandler):
def answerNodeInformation(self, conn):
pass
def answerPartitionTable(self, conn, ptid, row_list):
def sendPartitionTable(self, conn, ptid, row_list):
app = self.app
pt = app.pt
pt.load(ptid, row_list, self.app.nm)
if not pt.filled():
raise protocol.ProtocolError('Partial partition table received')
logging.debug('Got the partition table:')
self.app.pt.log()
raise ProtocolError('Partial partition table received')
# Install the partition table into the database for persistency.
cell_list = []
num_partitions = app.pt.getPartitions()
num_partitions = pt.getPartitions()
unassigned_set = set(xrange(num_partitions))
for offset in xrange(num_partitions):
for cell in pt.getCellList(offset):
......@@ -46,12 +45,47 @@ class InitializationHandler(BaseMasterHandler):
app.dm.changePartitionTable(ptid, cell_list, reset=True)
def notifyPartitionChanges(self, conn, ptid, cell_list):
        # XXX: It is safe to ignore those notifications because all of the
        #      following apply:
# - we first ask for node information, and *then* partition
# table content, so it is possible to get notifyPartitionChanges
# packets in between (or even before asking for node information).
# - this handler will be changed after receiving answerPartitionTable
# and before handling the next packet
logging.debug('ignoring notifyPartitionChanges during initialization')
def truncate(self, conn, tid):
dm = self.app.dm
dm._setBackupTID(None)
dm._setTruncateTID(tid)
dm.commit()
def askRecovery(self, conn):
app = self.app
conn.answer(Packets.AnswerRecovery(
app.pt.getID(),
app.dm.getBackupTID(),
app.dm.getTruncateTID()))
def askLastIDs(self, conn):
dm = self.app.dm
dm.truncate()
ltid, _, _, loid = dm.getLastIDs()
conn.answer(Packets.AnswerLastIDs(loid, ltid))
def askPartitionTable(self, conn):
pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
def askLockedTransactions(self, conn):
conn.answer(Packets.AnswerLockedTransactions(
self.app.dm.getUnfinishedTIDDict()))
def validateTransaction(self, conn, ttid, tid):
dm = self.app.dm
dm.lockTransaction(tid, ttid)
dm.unlockTransaction(tid, ttid)
def startOperation(self, conn, backup):
self.app.operational = True
# XXX: see comment in protocol
dm = self.app.dm
if backup:
if dm.getBackupTID():
return
tid = dm.getLastIDs()[0] or ZERO_TID
else:
tid = None
dm._setBackupTID(tid)
dm.commit()
......@@ -16,13 +16,21 @@
from neo.lib import logging
from neo.lib.util import dump
from neo.lib.protocol import Packets, ProtocolError
from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
from . import BaseMasterHandler
class MasterOperationHandler(BaseMasterHandler):
""" This handler is used for the primary master """
def startOperation(self, conn, backup):
# XXX: see comment in protocol
assert self.app.operational and backup
dm = self.app.dm
if not dm.getBackupTID():
dm._setBackupTID(dm.getLastIDs()[0] or ZERO_TID)
dm.commit()
def notifyTransactionFinished(self, conn, *args, **kw):
self.app.replicator.transactionFinished(*args, **kw)
......@@ -42,17 +50,11 @@ class MasterOperationHandler(BaseMasterHandler):
# Check changes for replications
app.replicator.notifyPartitionChanges(cell_list)
def askLockInformation(self, conn, ttid, tid, oid_list):
if not ttid in self.app.tm:
raise ProtocolError('Unknown transaction')
self.app.tm.lock(ttid, tid, oid_list)
if not conn.isClosed():
conn.answer(Packets.AnswerInformationLocked(ttid))
def askLockInformation(self, conn, ttid, tid):
self.app.tm.lock(ttid, tid)
conn.answer(Packets.AnswerInformationLocked(ttid))
def notifyUnlockInformation(self, conn, ttid):
if not ttid in self.app.tm:
raise ProtocolError('Unknown transaction')
# TODO: send an answer
self.app.tm.unlock(ttid)
def askPack(self, conn, tid):
......@@ -60,17 +62,11 @@ class MasterOperationHandler(BaseMasterHandler):
logging.info('Pack started, up to %s...', dump(tid))
app.dm.pack(tid, app.tm.updateObjectDataForPack)
logging.info('Pack finished.')
if not conn.isClosed():
conn.answer(Packets.AnswerPack(True))
conn.answer(Packets.AnswerPack(True))
def replicate(self, conn, tid, upstream_name, source_dict):
self.app.replicator.backup(tid, {p: a and (a, upstream_name)
for p, a in source_dict.iteritems()})
def truncate(self, conn, tid):
self.app.replicator.cancel()
self.app.dm.truncate(tid)
conn.close()
def checkPartition(self, conn, *args):
self.app.checker(*args)
#
# Copyright (C) 2006-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import BaseMasterHandler
from neo.lib import logging
from neo.lib.protocol import Packets, Errors, INVALID_TID, ZERO_TID
from neo.lib.util import dump
from neo.lib.exception import OperationFailure
class VerificationHandler(BaseMasterHandler):
"""This class deals with events for a verification phase."""
def askLastIDs(self, conn):
app = self.app
ltid, _, _, loid = app.dm.getLastIDs()
conn.answer(Packets.AnswerLastIDs(
loid,
ltid,
app.pt.getID(),
app.dm.getBackupTID()))
def askPartitionTable(self, conn):
pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
def notifyPartitionChanges(self, conn, ptid, cell_list):
"""This is very similar to Send Partition Table, except that
the information is only about changes from the previous."""
app = self.app
if ptid <= app.pt.getID():
# Ignore this packet.
logging.debug('ignoring older partition changes')
return
# update partition table in memory and the database
app.pt.update(ptid, cell_list, app.nm)
app.dm.changePartitionTable(ptid, cell_list)
def startOperation(self, conn, backup):
self.app.operational = True
dm = self.app.dm
if backup:
if dm.getBackupTID():
return
tid = dm.getLastIDs()[0] or ZERO_TID
else:
tid = None
dm.setBackupTID(tid)
def stopOperation(self, conn):
raise OperationFailure('operation stopped')
def askUnfinishedTransactions(self, conn):
tid_list = self.app.dm.getUnfinishedTIDList()
conn.answer(Packets.AnswerUnfinishedTransactions(INVALID_TID, tid_list))
def askTransactionInformation(self, conn, tid):
app = self.app
t = app.dm.getTransaction(tid, all=True)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def askObjectPresent(self, conn, oid, tid):
if self.app.dm.objectPresent(oid, tid):
p = Packets.AnswerObjectPresent(oid, tid)
else:
p = Errors.OidNotFound(
'%s:%s do not exist' % (dump(oid), dump(tid)))
conn.answer(p)
def deleteTransaction(self, conn, tid, oid_list):
self.app.dm.deleteTransaction(tid, oid_list)
def commitTransaction(self, conn, tid):
self.app.dm.finishTransaction(tid)
......@@ -128,7 +128,8 @@ class Replicator(object):
if tid:
new_tid = self.getBackupTID()
if tid != new_tid:
dm.setBackupTID(new_tid)
dm._setBackupTID(new_tid)
dm.commit()
def populate(self):
app = self.app
......
......@@ -38,18 +38,22 @@ class DelayedError(Exception):
Raised when an object is locked by a previous transaction
"""
class NotRegisteredError(Exception):
"""
Raised when a ttid is not registered
"""
class Transaction(object):
"""
Container for a pending transaction
"""
_tid = None
has_trans = False
def __init__(self, uuid, ttid):
self._uuid = uuid
self._ttid = ttid
self._object_dict = {}
self._transaction = None
self._locked = False
self._birth = time()
self._checked_set = set()
......@@ -89,13 +93,6 @@ class Transaction(object):
def isLocked(self):
return self._locked
def prepare(self, oid_list, user, desc, ext, packed):
"""
Set the transaction informations
"""
# assert self._transaction is not None
self._transaction = oid_list, user, desc, ext, packed, self._ttid
def addObject(self, oid, data_id, value_serial):
"""
Add an object to the transaction
......@@ -121,9 +118,6 @@ class Transaction(object):
def getLockedOIDList(self):
return self._object_dict.keys() + list(self._checked_set)
def getTransactionInformations(self):
return self._transaction
class TransactionManager(object):
"""
......@@ -137,12 +131,6 @@ class TransactionManager(object):
self._load_lock_dict = {}
self._uuid_dict = {}
def __contains__(self, ttid):
"""
Returns True if the TID is known by the manager
"""
return ttid in self._transaction_dict
def register(self, uuid, ttid):
"""
Register a transaction, it may be already registered
......@@ -174,7 +162,21 @@ class TransactionManager(object):
self._load_lock_dict.clear()
self._uuid_dict.clear()
def lock(self, ttid, tid, oid_list):
def vote(self, ttid, txn_info=None):
"""
Store transaction information received from client node
"""
logging.debug('Vote TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid]
object_list = transaction.getObjectList()
if txn_info:
user, desc, ext, oid_list = txn_info
txn_info = oid_list, user, desc, ext, False, ttid
transaction.has_trans = True
# store metadata to temporary table
self._app.dm.storeTransaction(ttid, object_list, txn_info)
def lock(self, ttid, tid):
"""
Lock a transaction
"""
......@@ -182,43 +184,22 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid]
# remember that the transaction has been locked
transaction.lock()
for oid in transaction.getOIDList():
self._load_lock_dict[oid] = ttid
# check every object that should be locked
uuid = transaction.getUUID()
is_assigned = self._app.pt.isAssigned
for oid in oid_list:
if is_assigned(oid, uuid) and \
self._load_lock_dict.get(oid) != ttid:
raise ValueError, 'Some locks are not held'
object_list = transaction.getObjectList()
        # txn_info is None if the transaction information is not stored on
        # this storage.
txn_info = transaction.getTransactionInformations()
# store data from memory to temporary table
self._app.dm.storeTransaction(tid, object_list, txn_info)
# ...and remember its definitive TID
self._load_lock_dict.update(
dict.fromkeys(transaction.getOIDList(), ttid))
# commit transaction and remember its definitive TID
if transaction.has_trans:
self._app.dm.lockTransaction(tid, ttid)
transaction.setTID(tid)
def getTIDFromTTID(self, ttid):
return self._transaction_dict[ttid].getTID()
def unlock(self, ttid):
"""
Unlock transaction
"""
logging.debug('Unlock TXN %s', dump(ttid))
self._app.dm.finishTransaction(self.getTIDFromTTID(ttid))
tid = self._transaction_dict[ttid].getTID()
logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid))
self._app.dm.unlockTransaction(tid, ttid)
self.abort(ttid, even_if_locked=True)
def storeTransaction(self, ttid, oid_list, user, desc, ext, packed):
"""
Store transaction information received from client node
"""
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid]
transaction.prepare(oid_list, user, desc, ext, packed)
def getLockingTID(self, oid):
return self._store_lock_dict.get(oid)
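Putting the new TransactionManager methods together, the per-transaction call order on a storage node now looks roughly like the sketch below; it parallels testSimpleCase further down in this diff, and the wrapper function with its parameters is illustrative only.

def run_one_transaction(tm, uuid, ttid, tid, oid, serial, checksum, data):
    tm.register(uuid, ttid)                # client node identified
    tm.storeObject(ttid, serial, oid, 1, checksum, data, None)
    tm.vote(ttid, None)    # pass (user, desc, ext, oid_list) instead of
                           # None on the node storing the metadata
    tm.lock(ttid, tid)     # load locks taken, definitive tid recorded
    tm.unlock(ttid)        # rows made permanent, ttid forgotten
    # tm.abort(ttid) before lock() would discard the transaction instead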
......@@ -283,9 +264,11 @@ class TransactionManager(object):
self._store_lock_dict[oid] = ttid
def checkCurrentSerial(self, ttid, serial, oid):
try:
transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=True)
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid]
transaction.addCheckedObject(oid)
def storeObject(self, ttid, serial, oid, compression, checksum, data,
......@@ -293,14 +276,17 @@ class TransactionManager(object):
"""
Store an object received from client node
"""
try:
transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=unlock)
# store object
assert ttid in self, "Transaction not registered"
if data is None:
data_id = None
else:
data_id = self._app.dm.holdData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, data_id, value_serial)
transaction.addObject(oid, data_id, value_serial)
def abort(self, ttid, even_if_locked=False):
"""
......@@ -322,9 +308,7 @@ class TransactionManager(object):
if not even_if_locked:
return
else:
self._app.dm.releaseData([data_id
for oid, data_id, value_serial in transaction.getObjectList()
if data_id], True)
self._app.dm.abortTransaction(ttid)
# unlock any object
for oid in transaction.getLockedOIDList():
if has_load_lock:
......
......@@ -463,9 +463,6 @@ class NeoUnitTestBase(NeoTestBase):
def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObjectPresent(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObjectPresent, **kw)
def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw)
......@@ -514,18 +511,12 @@ class NeoUnitTestBase(NeoTestBase):
def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)
def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
def checkAnswerObjectPresent(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
class Patch(object):
......
......@@ -77,7 +77,7 @@ class ClusterTests(NEOFunctionalTest):
self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0)
self.neo.killStorage()
self.neo.expectClusterVerifying()
self.neo.expectClusterRecovering()
def testClusterBreaksWithTwoNodes(self):
self.neo = NEOCluster(['test_neo1', 'test_neo2'],
......@@ -88,7 +88,7 @@ class ClusterTests(NEOFunctionalTest):
self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0)
self.neo.killStorage()
self.neo.expectClusterVerifying()
self.neo.expectClusterRecovering()
def testClusterDoesntBreakWithTwoNodesOneReplica(self):
self.neo = NEOCluster(['test_neo1', 'test_neo2'],
......@@ -127,7 +127,7 @@ class ClusterTests(NEOFunctionalTest):
self.assertEqual(len(self.neo.getClientlist()), 1)
# drop the storage, the cluster is no more operational...
self.neo.getStorageProcessList()[0].stop()
self.neo.expectClusterVerifying()
self.neo.expectClusterRecovering()
# ...and the client gets disconnected
self.assertEqual(len(self.neo.getClientlist()), 0)
# restart storage so that the cluster is operational again
......
......@@ -179,7 +179,7 @@ class StorageTests(NEOFunctionalTest):
# Cluster not operational anymore. Only cells of second storage that
# were shared with the third one should become outdated.
self.neo.expectUnavailable(started[1])
self.neo.expectClusterVerifying()
self.neo.expectClusterRecovering()
self.neo.expectOudatedCells(3)
def testVerificationTriggered(self):
......@@ -200,7 +200,7 @@ class StorageTests(NEOFunctionalTest):
# stop it, the cluster must switch to verification
started[0].stop()
self.neo.expectUnavailable(started[0])
self.neo.expectClusterVerifying()
self.neo.expectClusterRecovering()
# client must have been disconnected
self.assertEqual(len(self.neo.getClientlist()), 0)
conn.close()
......@@ -245,7 +245,7 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectUnavailable(started[1])
self.neo.expectUnavailable(started[2])
self.neo.expectOudatedCells(number=20)
self.neo.expectClusterVerifying()
self.neo.expectClusterRecovering()
def testConflictingStorageRejected(self):
""" Check that a storage coming after the recovery process with the same
......@@ -403,7 +403,7 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectUnavailable(started[0])
self.neo.expectUnavailable(started[1])
self.neo.expectOudatedCells(number=10)
self.neo.expectClusterVerifying()
self.neo.expectClusterRecovering()
# XXX: need to sync with storages first
self.neo.stop()
......
......@@ -67,29 +67,6 @@ class MasterRecoveryTests(NeoUnitTestBase):
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.TEMPORARILY_DOWN)
def test_09_answerLastIDs(self):
recovery = self.recovery
uuid = self.identifyToMasterNode()
oid1 = self.getOID(1)
oid2 = self.getOID(2)
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
ptid1 = self.getPTID(1)
ptid2 = self.getPTID(2)
self.app.tm.setLastOID(oid1)
self.app.tm.setLastTID(tid1)
self.app.pt.setID(ptid1)
        # send information which is more recent than what PMN knows; this must update the target node
conn = self.getFakeConnection(uuid, self.storage_port)
self.assertTrue(ptid2 > self.app.pt.getID())
self.assertTrue(oid2 > self.app.tm.getLastOID())
self.assertTrue(tid2 > self.app.tm.getLastTID())
recovery.answerLastIDs(conn, oid2, tid2, ptid2, None)
self.assertEqual(oid2, self.app.tm.getLastOID())
self.assertEqual(tid2, self.app.tm.getLastTID())
self.assertEqual(ptid2, recovery.target_ptid)
def test_10_answerPartitionTable(self):
recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
......
......@@ -21,7 +21,7 @@ from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.storage import StorageServiceHandler
from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application
from neo.lib.exception import OperationFailure
from neo.lib.exception import StoppedOperation
class MasterStorageHandlerTests(NeoUnitTestBase):
......@@ -114,24 +114,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.checkNotifyUnlockInformation(storage_conn_1)
self.checkNotifyUnlockInformation(storage_conn_2)
def test_12_askLastIDs(self):
service = self.service
node, conn = self.identifyToMasterNode()
# give a uuid
conn = self.getFakeConnection(node.getUUID(), self.storage_address)
ptid = self.app.pt.getID()
oid = self.getOID(1)
tid = self.getNextTID()
self.app.tm.setLastOID(oid)
self.app.tm.setLastTID(tid)
service.askLastIDs(conn)
packet = self.checkAnswerLastIDs(conn)
loid, ltid, lptid, backup_tid = packet.decode()
self.assertEqual(loid, oid)
self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid)
self.assertEqual(backup_tid, None)
def test_13_askUnfinishedTransactions(self):
service = self.service
node, conn = self.identifyToMasterNode()
......@@ -173,64 +155,10 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# drop the second, no storage node left
lptid = self.app.pt.getID()
self.assertEqual(node2.getState(), NodeStates.RUNNING)
self.assertRaises(OperationFailure, method, conn2)
self.assertRaises(StoppedOperation, method, conn2)
self.assertEqual(node2.getState(), state)
self.assertEqual(lptid, self.app.pt.getID())
def test_nodeLostAfterAskLockInformation(self):
# 2 storage nodes, one will die
node1, conn1 = self._getStorage()
node2, conn2 = self._getStorage()
# client nodes, to distinguish answers for the sample transactions
client1, cconn1 = self._getClient()
client2, cconn2 = self._getClient()
client3, cconn3 = self._getClient()
oid_list = [self.getOID(), ]
# Some shortcuts to simplify test code
self.app.pt = Mock({'operational': True})
# Register some transactions
tm = self.app.tm
# Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock
msg_id_1 = 1
ttid1 = tm.begin(client1)
tid1 = tm.prepare(ttid1, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_1)
tm.lock(ttid1, node2.getUUID())
        # storage 1 requests a notification at commit
        tm.registerForNotification(node1.getUUID())
self.checkNoPacketSent(cconn1)
# Storage 1 dies
node1.setTemporarilyDown()
self.service.nodeLost(conn1, node1)
# T1: last locking node lost, client receives AnswerTransactionFinished
self.checkAnswerTransactionFinished(cconn1)
self.checkNotifyTransactionFinished(conn1)
self.checkNotifyUnlockInformation(conn2)
# ...and notifications are sent to other clients
self.checkInvalidateObjects(cconn2)
self.checkInvalidateObjects(cconn3)
# Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2
ttid2 = tm.begin(node1)
tid2 = tm.prepare(ttid2, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_2)
# T2: pending locking answer, client keeps waiting
self.checkNoPacketSent(cconn2, check_notify=False)
tm.remove(node1.getUUID(), ttid2)
# Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3
ttid3 = tm.begin(node1)
tid3 = tm.prepare(ttid3, 1, oid_list,
[node2.getUUID(), ], msg_id_3)
        # T3: action not significant to this transaction, so no response
self.checkNoPacketSent(cconn3, check_notify=False)
tm.remove(node1.getUUID(), ttid3)
def test_answerPack(self):
# Note: incomming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage()
......
......@@ -112,19 +112,6 @@ class testTransactionManager(NeoUnitTestBase):
# ...and the lock is available
txnman.begin(client, self.getNextTID())
def test_getNextOIDList(self):
txnman = TransactionManager(lambda tid, txn: None)
        # must raise as we don't have one
self.assertEqual(txnman.getLastOID(), None)
self.assertRaises(RuntimeError, txnman.getNextOIDList, 1)
# ask list
txnman.setLastOID(self.getOID(1))
oid_list = txnman.getNextOIDList(15)
self.assertEqual(len(oid_list), 15)
# begin from 1, so generated oid from 2 to 16
for i, oid in zip(xrange(len(oid_list)), oid_list):
self.assertEqual(oid, self.getOID(i+2))
def test_forget(self):
client1 = Mock({'__hash__': 1})
client2 = Mock({'__hash__': 2})
......
......@@ -191,18 +191,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askObjectHistory(conn, oid2, 1, 2)
self.checkAnswerObjectHistory(conn)
def test_askStoreTransaction(self):
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
user = 'USER'
desc = 'DESC'
ext = 'EXT'
oid_list = (self.getOID(1), self.getOID(2))
self.operation.askStoreTransaction(conn, tid, user, desc, ext, oid_list)
calls = self.app.tm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
self.checkAnswerStoreTransaction(conn)
def _getObject(self):
oid = self.getOID(0)
serial = self.getNextTID()
......
......@@ -76,7 +76,7 @@ class StorageInitializationHandlerTests(NeoUnitTestBase):
(2, ((node_2, CellStates.UP_TO_DATE), (node_3, CellStates.UP_TO_DATE)))]
self.assertFalse(self.app.pt.filled())
# send a complete new table and ack
self.verification.answerPartitionTable(conn, 2, row_list)
self.verification.sendPartitionTable(conn, 2, row_list)
self.assertTrue(self.app.pt.filled())
self.assertEqual(self.app.pt.getID(), 2)
self.assertTrue(list(self.app.dm.getPartitionTable()))
......
......@@ -20,7 +20,7 @@ from collections import deque
from .. import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler
from neo.lib.exception import PrimaryFailure, OperationFailure
from neo.lib.exception import PrimaryFailure
from neo.lib.pt import PartitionTable
from neo.lib.protocol import CellStates, ProtocolError, Packets
......@@ -104,58 +104,9 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ptid2, cells)
def test_16_stopOperation1(self):
# OperationFailure
conn = self.getFakeConnection(is_server=False)
self.assertRaises(OperationFailure, self.operation.stopOperation, conn)
def _getConnection(self):
return self.getFakeConnection()
def test_askLockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = self.getNextTID()
ttid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.askLockInformation, conn,
ttid, tid, oid_list)
def test_askLockInformation2(self):
""" Lock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
ttid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)]
self.operation.askLockInformation(conn, ttid, tid, oid_list)
calls = self.app.tm.mockGetNamedCalls('lock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ttid, tid, oid_list)
self.checkAnswerInformationLocked(conn)
def test_notifyUnlockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
tid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
conn, tid)
def test_notifyUnlockInformation2(self):
""" Unlock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
self.operation.notifyUnlockInformation(conn, tid)
calls = self.app.tm.mockGetNamedCalls('unlock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
self.checkNoPacketSent(conn)
def test_askPack(self):
self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection()
......
......@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from binascii import a2b_hex
from contextlib import contextmanager
import unittest
from neo.lib.util import add64, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
......@@ -80,6 +81,17 @@ class StorageDBTests(NeoUnitTestBase):
set_call(value * 2)
self.assertEqual(get_call(), value * 2)
@contextmanager
def commitTransaction(self, tid, objs, txn, commit=True):
ttid = txn[-1]
self.db.storeTransaction(ttid, objs, txn)
self.db.lockTransaction(tid, ttid)
yield
if commit:
self.db.unlockTransaction(tid, ttid)
elif commit is not None:
self.db.abortTransaction(ttid)
def test_UUID(self):
db = self.getDB()
self.checkConfigEntry(db.getUUID, db.setUUID, 123)
......@@ -122,38 +134,24 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2))
def test_getUnfinishedTIDList(self):
def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2)
txn, objs = self.getTransaction([oid1, oid2])
# nothing pending
self.db.storeTransaction(tid1, objs, txn, False)
self.checkSet(self.db.getUnfinishedTIDList(), [])
# one unfinished txn
self.db.storeTransaction(tid2, objs, txn)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2])
# no changes
self.db.storeTransaction(tid3, objs, None, False)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2])
# a second txn known by objs only
self.db.storeTransaction(tid4, objs, None)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2, tid4])
def test_objectPresent(self):
tid = self.getNextTID()
oid = self.getOID(1)
txn, objs = self.getTransaction([oid])
# not present
self.assertFalse(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in temp table
self.db.storeTransaction(tid, objs, txn)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in both tables
self.db.finishTransaction(tid)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertTrue(self.db.objectPresent(oid, tid, all=False))
with self.commitTransaction(tid2, objs, txn):
expected = {txn[-1]: tid2}
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# no changes
self.db.storeTransaction(tid3, objs, None, False)
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# a second txn known by objs only
expected[tid4] = None
self.db.storeTransaction(tid4, objs, None)
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
self.db.abortTransaction(tid4)
# nothing pending
self.assertEqual(self.db.getUnfinishedTIDDict(), {})
def test_getObject(self):
oid1, = self.getOIDs(1)
......@@ -169,27 +167,26 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one non-commited version
self.db.storeTransaction(tid1, objs1, txn1)
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
with self.commitTransaction(tid1, objs1, txn1):
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one commited version
self.db.finishTransaction(tid1)
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE)
# two version available, one non-commited
self.db.storeTransaction(tid2, objs2, txn2)
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, tid2), FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NO_NEXT)
with self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, tid2),
FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NO_NEXT)
# two commited versions
self.db.finishTransaction(tid2)
self.assertEqual(self.db.getObject(oid1), OBJECT_T2)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
......@@ -242,82 +239,28 @@ class StorageDBTests(NeoUnitTestBase):
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
def test_dropUnfinishedData(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
# nothing
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [])
# one is still pending
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [tid2])
# drop it
self.db.dropUnfinishedData()
self.assertEqual(self.db.getUnfinishedTIDList(), [])
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
def test_storeTransaction(self):
def test_commitTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# nothing in database
self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None))
self.assertEqual(self.db.getUnfinishedTIDList(), [])
self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# store in temporary tables
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# commit pending transaction
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
result = self.db.getTransaction(tid1, False)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, False)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_askFinishTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# stored but not finished
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# stored and finished
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1)))
self.assertEqual(self.db.getTransaction(tid2, True),
([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
......@@ -328,32 +271,29 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_deleteTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.deleteTransaction(tid1, [oid1])
self.db.deleteTransaction(tid2, [oid2])
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
txn, objs = self.getTransaction([])
tid = txn[-1]
self.db.storeTransaction(tid, objs, txn, False)
self.assertEqual(self.db.getTransaction(tid), txn)
self.db.deleteTransaction(tid)
self.assertEqual(self.db.getTransaction(tid), None)
def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
tid1 = txn1[-1]
tid2 = txn2[-1]
self.db.storeTransaction(tid1, objs1, txn1, False)
self.db.storeTransaction(tid2, objs2, txn2, False)
self.assertEqual(self.db.getObject(oid1, tid=tid1),
(tid1, tid2, 1, "0" * 20, '', None))
self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.assertIs(self.db.getObject(oid1, tid=tid1), None)
self.assertIs(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1)
self.assertFalse(self.db.getObject(oid2, tid=tid1))
self.assertIs(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2),
(tid2, None, 1, "0" * 20, '', None))
......@@ -364,8 +304,7 @@ class StorageDBTests(NeoUnitTestBase):
oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list)
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.storeTransaction(tid, objs, txn, False)
def check(offset, tid_list, *tids):
self.assertEqual(self.db.getReplicationTIDList(ZERO_TID,
MAX_TID, len(tid_list) + 1, offset), tid_list)
......@@ -386,9 +325,9 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# get from temporary table or not
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2, None):
pass
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
......@@ -405,15 +344,13 @@ class StorageDBTests(NeoUnitTestBase):
txn2, objs2 = self.getTransaction([oid])
txn3, objs3 = self.getTransaction([oid])
# one revision
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid1, objs1, txn1, False)
result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 1)
self.assertEqual(result, None)
# two revisions
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid2)
self.db.storeTransaction(tid2, objs2, txn2, False)
result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid2, 0), (tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 3)
......@@ -427,8 +364,7 @@ class StorageDBTests(NeoUnitTestBase):
oid = self.getOID(1)
for tid in tid_list:
txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.storeTransaction(tid, objs, txn, False)
return tid_list
def test_getTIDList(self):
......
......@@ -45,15 +45,6 @@ class TransactionTests(NeoUnitTestBase):
# disallow lock more than once
self.assertRaises(AssertionError, txn.lock)
def testTransaction(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
repr(txn) # check __repr__ does not raise
oid_list = [self.getOID(1), self.getOID(2)]
txn_info = (oid_list, 'USER', 'DESC', 'EXT', False)
txn.prepare(*txn_info)
self.assertEqual(txn.getTransactionInformations(),
txn_info + (txn.getTTID(),))
def testObjects(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2)
......@@ -91,10 +82,10 @@ class TransactionManagerTests(NeoUnitTestBase):
def _getTransaction(self):
tid = self.getNextTID(self.ltid)
oid_list = [self.getOID(1), self.getOID(2)]
return (tid, (oid_list, 'USER', 'DESC', 'EXT', False))
return (tid, ('USER', 'DESC', 'EXT', oid_list))
def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]):
for i, oid in enumerate(txn[3]):
self.manager.storeObject(tid, None,
oid, 1, '%020d' % i, '0' + str(i), None)
......@@ -108,15 +99,21 @@ class TransactionManagerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args)
def _checkTransactionFinished(self, tid):
calls = self.app.dm.mockGetNamedCalls('finishTransaction')
def _checkTransactionFinished(self, *args):
calls = self.app.dm.mockGetNamedCalls('unlockTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
calls[0].checkArgs(*args)
def _checkQueuedEventExecuted(self, number=1):
calls = self.app.mockGetNamedCalls('executeQueuedEvents')
self.assertEqual(len(calls), number)
def assertRegistered(self, ttid):
self.assertIn(ttid, self.manager._transaction_dict)
def assertNotRegistered(self, ttid):
self.assertNotIn(ttid, self.manager._transaction_dict)
def testSimpleCase(self):
""" One node, one transaction, not abort """
data_id_list = random.random(), random.random()
......@@ -127,18 +124,23 @@ class TransactionManagerTests(NeoUnitTestBase):
serial1, object1 = self._getObject(1)
serial2, object2 = self._getObject(2)
self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self.manager.storeObject(ttid, serial1, *object1)
self.manager.storeObject(ttid, serial2, *object2)
self.assertTrue(ttid in self.manager)
self.manager.lock(ttid, tid, txn[0])
self._checkTransactionStored(tid, [
self.assertRegistered(ttid)
self.manager.vote(ttid, txn)
user, desc, ext, oid_list = txn
call, = self.app.dm.mockGetNamedCalls('storeTransaction')
call.checkArgs(ttid, [
(object1[0], data_id_list[0], object1[4]),
(object2[0], data_id_list[1], object2[4]),
], txn + (ttid,))
], (oid_list, user, desc, ext, False, ttid))
self.manager.lock(ttid, tid)
call, = self.app.dm.mockGetNamedCalls('lockTransaction')
call.checkArgs(tid, ttid)
self.manager.unlock(ttid)
self.assertFalse(ttid in self.manager)
self._checkTransactionFinished(tid)
self.assertNotRegistered(ttid)
call, = self.app.dm.mockGetNamedCalls('unlockTransaction')
call.checkArgs(tid, ttid)
def testDelayed(self):
""" Two transactions, the first cause the second to be delayed """
......@@ -150,14 +152,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1)
# first transaction lock the object
self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self._storeTransactionObjects(ttid1, txn1)
self.manager.lock(ttid1, tid1, txn1[0])
self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# the second is delayed
self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.assertTrue(ttid2 in self.manager)
self.assertRegistered(ttid2)
self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial, *obj)
......@@ -171,14 +172,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1)
# the (later) transaction lock (change) the object
self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.assertTrue(ttid2 in self.manager)
self.assertRegistered(ttid2)
self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0])
self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
        # the previous one is not using the latest version
self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial, *obj)
......@@ -191,7 +191,6 @@ class TransactionManagerTests(NeoUnitTestBase):
# try to store without the last revision
self.app.dm = Mock({'getLastObjectTID': next_serial})
self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.assertRaises(ConflictError, self.manager.storeObject,
tid, serial, *obj)
......@@ -208,15 +207,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2)
# the first transaction locks the objects
self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self.manager.storeObject(ttid1, serial1, *obj1)
self.manager.storeObject(ttid1, serial1, *obj2)
self.manager.lock(ttid1, tid1, txn1[0])
self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# second transaction is delayed
self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.assertTrue(ttid2 in self.manager)
self.assertRegistered(ttid2)
self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial1, *obj1)
self.assertRaises(DelayedError, self.manager.storeObject,
......@@ -235,15 +233,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2)
# the second transaction locks the objects
self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeObject(ttid2, serial1, *obj1)
self.manager.storeObject(ttid2, serial2, *obj2)
self.assertTrue(ttid2 in self.manager)
self.manager.lock(ttid2, tid2, txn1[0])
self.assertRegistered(ttid2)
self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the first one gets a conflict
self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial1, *obj1)
self.assertRaises(ConflictError, self.manager.storeObject,
......@@ -255,12 +252,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.manager.storeObject(tid, serial, *obj)
self.assertTrue(tid in self.manager)
self.assertRegistered(tid)
self.manager.vote(tid, txn)
# transaction is not locked
self.manager.abort(tid)
self.assertFalse(tid in self.manager)
self.assertNotRegistered(tid)
self.assertFalse(self.manager.loadLocked(obj[0]))
self._checkQueuedEventExecuted()
......@@ -270,14 +267,14 @@ class TransactionManagerTests(NeoUnitTestBase):
ttid = self.getNextTID()
tid, txn = self._getTransaction()
self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn)
self.manager.vote(ttid, txn)
# lock transaction
self.manager.lock(ttid, tid, txn[0])
self.assertTrue(ttid in self.manager)
self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.abort(ttid)
self.assertTrue(ttid in self.manager)
for oid in txn[0]:
self.assertRegistered(ttid)
for oid in txn[-1]:
self.assertTrue(self.manager.loadLocked(oid))
self._checkQueuedEventExecuted(number=0)
......@@ -295,20 +292,20 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.register(uuid1, ttid1)
self.manager.register(uuid2, ttid2)
self.manager.register(uuid2, ttid3)
self.manager.storeTransaction(ttid1, *txn1)
self.manager.vote(ttid1, txn1)
# node 2 owns tid2 & tid3 and locks tid2 only
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeTransaction(ttid3, *txn3)
self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0])
self.assertTrue(ttid1 in self.manager)
self.assertTrue(ttid2 in self.manager)
self.assertTrue(ttid3 in self.manager)
self.manager.vote(ttid2, txn2)
self.manager.vote(ttid3, txn3)
self.manager.lock(ttid2, tid2)
self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertRegistered(ttid3)
self.manager.abortFor(uuid2)
# only tid3 is aborted
self.assertTrue(ttid1 in self.manager)
self.assertTrue(ttid2 in self.manager)
self.assertFalse(ttid3 in self.manager)
self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertNotRegistered(ttid3)
self._checkQueuedEventExecuted(number=1)
def testReset(self):
......@@ -317,12 +314,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction()
ttid = self.getNextTID()
self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn)
self.manager.lock(ttid, tid, txn[0])
self.assertTrue(ttid in self.manager)
self.manager.vote(ttid, txn)
self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.reset()
self.assertFalse(ttid in self.manager)
self.assertNotRegistered(ttid)
for oid in txn[0]:
self.assertFalse(self.manager.loadLocked(oid))
......
#
# Copyright (C) 2009-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.lib.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.verification import VerificationHandler
from neo.lib.protocol import CellStates, ErrorCodes
from neo.lib.exception import PrimaryFailure
from neo.lib.util import p64, u64
class StorageVerificationHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.verification = VerificationHandler(self.app)
# define some variables to simulate client and storage nodes
self.master_port = 10010
self.storage_port = 10020
self.client_port = 11011
self.num_partitions = 1009
self.num_replicas = 2
self.app.operational = False
self.app.load_lock_dict = {}
self.app.pt = PartitionTable(self.num_partitions, self.num_replicas)
def _tearDown(self, success):
self.app.close()
del self.app
super(StorageVerificationHandlerTests, self)._tearDown(success)
# Common methods
def getMasterConnection(self):
return self.getFakeConnection(address=("127.0.0.1", self.master_port))
# Tests
def test_03_connectionClosed(self):
conn = self.getMasterConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
# nothing happens
self.checkNoPacketSent(conn)
def test_08_askPartitionTable(self):
node = self.app.nm.createStorage(
address=("127.7.9.9", 1),
uuid=self.getStorageUUID()
)
self.app.pt.setCell(1, node, CellStates.UP_TO_DATE)
self.assertTrue(self.app.pt.hasOffset(1))
conn = self.getMasterConnection()
self.verification.askPartitionTable(conn)
ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True)
self.assertEqual(len(row_list), 1009)
def test_10_notifyPartitionChanges(self):
# old partition change
conn = self.getMasterConnection()
self.verification.notifyPartitionChanges(conn, 1, ())
self.verification.notifyPartitionChanges(conn, 0, ())
self.assertEqual(self.app.pt.getID(), 1)
# new node
conn = self.getMasterConnection()
new_uuid = self.getStorageUUID()
cell = (0, new_uuid, CellStates.UP_TO_DATE)
self.app.nm.createStorage(uuid=new_uuid)
self.app.pt = PartitionTable(1, 1)
self.app.dm = Mock({ })
ptid = self.getPTID()
# pt updated
self.verification.notifyPartitionChanges(conn, ptid, (cell, ))
# check db update
calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0].getParam(0), ptid)
self.assertEqual(calls[0].getParam(1), (cell, ))
def test_13_askUnfinishedTransactions(self):
# client connection with no data
self.app.dm = Mock({
'getUnfinishedTIDList': [],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 0)
call_list = self.app.dm.mockGetNamedCalls('getUnfinishedTIDList')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
# client connection with some data
self.app.dm = Mock({
'getUnfinishedTIDList': [p64(4)],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1)
self.assertEqual(u64(tid_list[0]), 4)
def test_14_askTransactionInformation(self):
# ask from client conn with no data
self.app.dm = Mock({
'getTransaction': None,
})
conn = self.getMasterConnection()
tid = p64(1)
self.verification.askTransactionInformation(conn, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('getTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, all=True)
# input some tmp data and ask from the client; the transaction must be found
self.app.dm = Mock({
'getTransaction': ([p64(2)], 'u2', 'd2', 'e2', False),
})
conn = self.getMasterConnection()
self.verification.askTransactionInformation(conn, p64(1))
tid, user, desc, ext, packed, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1)
self.assertEqual(user, 'u2')
self.assertEqual(desc, 'd2')
self.assertEqual(ext, 'e2')
self.assertFalse(packed)
self.assertEqual(len(oid_list), 1)
self.assertEqual(u64(oid_list[0]), 2)
def test_15_askObjectPresent(self):
# client connection with no data
self.app.dm = Mock({
'objectPresent': False,
})
conn = self.getMasterConnection()
oid, tid = p64(1), p64(2)
self.verification.askObjectPresent(conn, oid, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.OID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('objectPresent')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(oid, tid)
# client connection with some data
self.app.dm = Mock({
'objectPresent': True,
})
conn = self.getMasterConnection()
self.verification.askObjectPresent(conn, oid, tid)
oid, tid = self.checkAnswerObjectPresent(conn, decode=True)
self.assertEqual(u64(tid), 2)
self.assertEqual(u64(oid), 1)
def test_16_deleteTransaction(self):
# client connection with no data
self.app.dm = Mock({
'deleteTransaction': None,
})
conn = self.getMasterConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = p64(1)
self.verification.deleteTransaction(conn, tid, oid_list)
call_list = self.app.dm.mockGetNamedCalls('deleteTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, oid_list)
def test_17_commitTransaction(self):
# commit a transaction
conn = self.getMasterConnection()
dm = Mock()
self.app.dm = dm
self.verification.commitTransaction(conn, p64(1))
self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
call = dm.mockGetNamedCalls("finishTransaction")[0]
tid = call.getParam(0)
self.assertEqual(u64(tid), 1)
if __name__ == "__main__":
unittest.main()
......@@ -33,7 +33,7 @@ from neo.lib import logging
from neo.lib.connection import BaseConnection, Connection
from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.locking import SimpleQueue
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes
from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
......@@ -478,6 +478,8 @@ class ConnectionFilter(object):
queue.appendleft(packet)
break
else:
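# The connection may have been closed while its packets were delayed:
# in that case, give up instead of writing to a closed connection.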
if conn.isClosed():
return
cls._addPacket(conn, packet)
continue
break
......@@ -731,9 +733,12 @@ class NEOCluster(object):
return node[3]
def getOutdatedCells(self):
return [cell for row in self.neoctl.getPartitionRowList()[1]
for cell in row[1]
if cell[1] == CellStates.OUT_OF_DATE]
# Ask the admin instead of the primary master to check that it is
# notified of every change.
return [(i, cell.getUUID())
for i, row in enumerate(self.admin.pt.partition_list)
for cell in row
if not cell.isReadable()]
def getZODBStorage(self, **kw):
kw['_app'] = kw.pop('client', self.client)
......
......@@ -21,12 +21,12 @@ import transaction
import unittest
from thread import get_ident
from zlib import compress
from persistent import Persistent
from persistent import Persistent, GHOST
from ZODB import DB, POSException
from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError
from neo.lib.connection import ConnectionClosed, MTClientConnection
from neo.lib.exception import OperationFailure
from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID
from .. import expectedFailure, _ExpectedFailure, _UnexpectedSuccess, Patch
......@@ -34,6 +34,7 @@ from . import NEOCluster, NEOThreadedTest
from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
from neo.storage.handlers.initialization import InitializationHandler
class PCounter(Persistent):
value = 0
......@@ -394,51 +395,122 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def testRestartStoragesWithReplicas(self):
"""
Check that the master discards its partition table when the cluster
is not operational anymore, which means that it must go back to
RECOVERING state and remain there as long as the partition table
can't be operational.
This also checks that if the master remains the primary one after going
back to recovery, it automatically starts the cluster if possible
(i.e. without manual intervention).
"""
outdated = []
def doOperation(orig):
outdated.append(cluster.getOutdatedCells())
orig()
def stop():
with cluster.master.filterConnection(s0) as m2s0:
m2s0.add(lambda conn, packet:
isinstance(packet, Packets.NotifyPartitionChanges))
s1.stop()
cluster.join((s1,))
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(cluster.getOutdatedCells(),
[(0, s1.uuid), (1, s1.uuid)])
s0.stop()
cluster.join((s0,))
self.assertNotEqual(getClusterState(), ClusterStates.RUNNING)
s0.resetNode()
s1.resetNode()
cluster = NEOCluster(storage_count=2, partitions=2, replicas=1)
try:
cluster.start()
s0, s1 = cluster.storage_list
getClusterState = cluster.neoctl.getClusterState
if 1:
# Scenario 1: When all storage nodes are restarting,
# we want a chance to not restart with outdated cells.
stop()
with Patch(s1, doOperation=doOperation):
s0.start()
s1.start()
self.tic()
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(outdated, [[]])
if 1:
# Scenario 2: When only the first storage node to be stopped
# is started, the cluster must be able to restart.
stop()
s1.start()
self.tic()
# The master doesn't wait for s0 to come back.
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(cluster.getOutdatedCells(),
[(0, s0.uuid), (1, s0.uuid)])
finally:
cluster.stop()
def testVerificationCommitUnfinishedTransactions(self):
""" Verification step should commit locked transactions """
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreTransaction(storage, die=False):
def storeTransaction(orig, *args, **kw):
orig(*args, **kw)
def onLockTransaction(storage, die=False):
def lock(orig, *args, **kw):
if die:
sys.exit()
orig(*args, **kw)
storage.master_conn.close()
return Patch(storage.dm, storeTransaction=storeTransaction)
return Patch(storage.tm, lock=lock)
cluster = NEOCluster(partitions=2, storage_count=2)
try:
cluster.start()
s0, s1 = cluster.sortStorageList()
t, c = cluster.getTransaction()
r = c.root()
r[0] = x = PCounter()
r[0] = PCounter()
tids = [r._p_serial]
t.commit()
tids.append(r._p_serial)
r[1] = PCounter()
with onStoreTransaction(s0), onStoreTransaction(s1):
with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit)
self.assertEqual(r._p_state, GHOST)
self.tic()
t.begin()
y = r[1]
self.assertEqual(y.value, 0)
assert [u64(o._p_oid) for o in (r, x, y)] == range(3)
r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s, \
cluster.moduloTID(1):
m2s.add(delayUnlockInformation)
t.commit()
x.value = 1
# s0 will accept the store of y (because it's not locked) but will
# never lock the transaction (packets from the master are delayed),
# so the last transaction will be dropped.
y.value = 2
with onStoreTransaction(s1, die=True):
x = r[0]
self.assertEqual(x.value, 0)
cluster.master.tm._last_oid = x._p_oid
tids.append(r._p_serial)
r[1] = PCounter()
c.readCurrent(x)
with cluster.moduloTID(1):
with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit)
self.tic()
t.begin()
# The following line checks that s1 moved the transaction
# metadata to its final place during the verification phase.
# If it didn't, a NEOStorageError would be raised.
self.assertEqual(3, len(c.db().history(r._p_oid, 4)))
y = r[1]
self.assertEqual(y.value, 0)
self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3))
r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s:
m2s.add(delayUnlockInformation)
t.commit()
x.value = 1
# s0 will accept the store of y (because it's not locked) but will
# never lock the transaction (packets from the master are delayed),
# so the last transaction will be dropped.
y.value = 2
di0 = s0.getDataLockInfo()
with onLockTransaction(s1, die=True):
self.assertRaises(ConnectionClosed, t.commit)
finally:
cluster.stop()
cluster.reset()
di0 = s0.getDataLockInfo()
(k, v), = set(s0.getDataLockInfo().iteritems()
).difference(di0.iteritems())
self.assertEqual(v, 1)
k, = (k for k, v in di0.iteritems() if v == 1)
di0[k] = 0 # r[2] = 'ok'
self.assertEqual(di0.values(), [0, 0, 0, 0, 0])
......@@ -458,6 +530,51 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def testDropUnfinishedData(self):
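# The storage loses its connection to the master right after locking a
# transaction: dropUnfinishedData() must then be called with exactly one
# unfinished transaction pending and leave none behind.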
def lock(orig, *args, **kw):
orig(*args, **kw)
storage.master_conn.close()
r = []
def dropUnfinishedData(orig):
r.append(len(orig.__self__.getUnfinishedTIDDict()))
orig()
r.append(len(orig.__self__.getUnfinishedTIDDict()))
cluster = NEOCluster(partitions=2, storage_count=2, replicas=1)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()._p_changed = 1
storage = cluster.storage_list[0]
with Patch(storage.tm, lock=lock), \
Patch(storage.dm, dropUnfinishedData=dropUnfinishedData):
t.commit()
self.tic()
self.assertEqual(r, [1, 0])
finally:
cluster.stop()
def testStorageUpgrade1(self):
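# Clearing the 'version' configuration simulates a database created by an
# older NEO version: a plain restart is expected to upgrade it automatically,
# whereas doing the same with unfinished data left behind (the storage exits
# in lock()) is expected to make resetNode() fail with DatabaseFailure.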
cluster = NEOCluster()
try:
cluster.start()
storage = cluster.storage
t, c = cluster.getTransaction()
storage.dm.setConfiguration("version", None)
c.root()._p_changed = 1
t.commit()
storage.stop()
cluster.join((storage,))
storage.resetNode()
storage.start()
t.begin()
storage.dm.setConfiguration("version", None)
c.root()._p_changed = 1
with Patch(storage.tm, lock=lambda *_: sys.exit()):
self.assertRaises(ConnectionClosed, t.commit)
self.assertRaises(DatabaseFailure, storage.resetNode)
finally:
cluster.stop()
def testStorageReconnectDuringStore(self):
cluster = NEOCluster(replicas=1)
try:
......@@ -550,12 +667,17 @@ class Test(NEOThreadedTest):
cluster.start()
# prevent the storage from reconnecting, in order to easily test
# that the cluster becomes non-operational
storage.connectToPrimary = sys.exit
# send an unexpected packet to the master so it aborts the connection to the storage
storage.master_conn.answer(Packets.Pong())
with Patch(storage, connectToPrimary=sys.exit):
# send an unexpected packet to the master so it aborts the connection to the storage
storage.master_conn.answer(Packets.Pong())
self.tic()
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RECOVERING)
storage.resetNode()
storage.start()
self.tic()
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.VERIFYING)
ClusterStates.RUNNING)
finally:
cluster.stop()
......@@ -833,7 +955,7 @@ class Test(NEOThreadedTest):
def testStorageFailureDuringTpcFinish(self):
def answerTransactionFinished(conn, packet):
if isinstance(packet, Packets.AnswerTransactionFinished):
raise OperationFailure
raise StoppedOperation
cluster = NEOCluster()
try:
cluster.start()
......@@ -849,8 +971,12 @@ class Test(NEOThreadedTest):
raise _UnexpectedSuccess
except ConnectionClosed, e:
e = type(e), None, None
# Also check that the master reset the last oid to a correct value.
self.assertTrue(cluster.client.new_oid_list)
t.begin()
self.assertIn('x', c.root())
self.assertEqual(1, u64(c.root()['x']._p_oid))
self.assertFalse(cluster.client.new_oid_list)
self.assertEqual(2, u64(cluster.client.new_oid()))
finally:
cluster.stop()
raise _ExpectedFailure(e)
......@@ -908,5 +1034,111 @@ class Test(NEOThreadedTest):
cluster.stop()
del cluster.startCluster
def testAbortVotedTransaction(self):
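# A failure in the client's tpc_finish after a successful vote must make the
# storages abort the transaction: each storage holds one unfinished
# transaction when tpc_finish fails, and none once the abort is processed.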
r = []
def tpc_finish(*args, **kw):
for storage in cluster.storage_list:
r.append(len(storage.dm.getUnfinishedTIDDict()))
raise NEOStorageError
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()['x'] = PCounter()
with Patch(cluster.client, tpc_finish=tpc_finish):
self.assertRaises(NEOStorageError, t.commit)
self.tic()
self.assertEqual(r, [1, 1])
for storage in cluster.storage_list:
self.assertFalse(storage.dm.getUnfinishedTIDDict())
t.begin()
self.assertNotIn('x', c.root())
finally:
cluster.stop()
def testStorageLostDuringRecovery(self):
# Initialize a cluster.
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
finally:
cluster.stop()
cluster.reset()
# Restart with a connection failure for the first AskPartitionTable.
# The master must not be stuck in RECOVERING state
# or re-make the partition table.
def make(*args):
sys.exit()
def askPartitionTable(orig, self, conn):
p.revert()
conn.close()
try:
with Patch(cluster.master.pt, make=make), \
Patch(InitializationHandler,
askPartitionTable=askPartitionTable) as p:
cluster.start()
self.assertFalse(p.applied)
finally:
cluster.stop()
def testTruncate(self):
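# Truncate the database at an arbitrary TID while each storage fails once
# (s0 in unlock, s1 in truncate): the cluster must stay in RECOVERING until
# every node has been truncated, and then come back RUNNING with all storages
# reporting the truncation TID as their last TID.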
calls = [0, 0]
def dieFirst(i):
def f(orig, *args, **kw):
calls[i] += 1
if calls[i] == 1:
sys.exit()
return orig(*args, **kw)
return f
cluster = NEOCluster(replicas=1)
try:
cluster.start()
t, c = cluster.getTransaction()
r = c.root()
tids = []
for x in xrange(4):
r[x] = None
t.commit()
tids.append(r._p_serial)
truncate_tid = tids[2]
r['x'] = PCounter()
s0, s1 = cluster.storage_list
with Patch(s0.tm, unlock=dieFirst(0)), \
Patch(s1.dm, truncate=dieFirst(1)):
t.commit()
cluster.neoctl.truncate(truncate_tid)
self.tic()
getClusterState = cluster.neoctl.getClusterState
# Unless forced, the cluster waits for all nodes to be up,
# so that all nodes are truncated.
self.assertEqual(getClusterState(), ClusterStates.RECOVERING)
self.assertEqual(calls, [1, 0])
s0.resetNode()
s0.start()
# s0 died with unfinished data, and before processing the
# Truncate packet from the master.
self.assertFalse(s0.dm.getTruncateTID())
self.assertEqual(s1.dm.getTruncateTID(), truncate_tid)
self.tic()
self.assertEqual(calls, [1, 1])
self.assertEqual(getClusterState(), ClusterStates.RECOVERING)
s1.resetNode()
with Patch(s1.dm, truncate=dieFirst(1)):
s1.start()
self.assertEqual(s0.dm.getLastIDs()[0], truncate_tid)
self.assertEqual(s1.dm.getLastIDs()[0], r._p_serial)
self.tic()
self.assertEqual(calls, [1, 2])
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
t.begin()
self.assertEqual(r, dict.fromkeys(xrange(3)))
self.assertEqual(r._p_serial, truncate_tid)
self.assertEqual(1, u64(c._storage.new_oid()))
for s in cluster.storage_list:
self.assertEqual(s.dm.getLastIDs()[0], truncate_tid)
finally:
cluster.stop()
if __name__ == "__main__":
unittest.main()
......@@ -424,7 +424,7 @@ class ReplicationTests(NEOThreadedTest):
check(ClusterStates.RUNNING, 1)
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
self.tic()
check(ClusterStates.VERIFYING, 4)
check(ClusterStates.RECOVERING, 4)
finally:
checker.CHECK_COUNT = CHECK_COUNT
cluster.stop()
......