Commit 8eb14b01 authored by Julien Muchembled's avatar Julien Muchembled

Bump protocol version

parents 54e819ff 9385706f
...@@ -105,13 +105,9 @@ class Application(BaseApplication): ...@@ -105,13 +105,9 @@ class Application(BaseApplication):
""" """
self.cluster_state = None self.cluster_state = None
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN, bootstrap = BootstrapManager(self, NodeTypes.ADMIN, self.server)
self.uuid, self.server) self.master_node, self.master_conn, num_partitions, num_replicas = \
data = bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data
self.master_node = node
self.master_conn = conn
self.uuid = uuid
if self.pt is None: if self.pt is None:
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
...@@ -125,7 +121,6 @@ class Application(BaseApplication): ...@@ -125,7 +121,6 @@ class Application(BaseApplication):
# passive handler # passive handler
self.master_conn.setHandler(self.master_event_handler) self.master_conn.setHandler(self.master_event_handler)
self.master_conn.ask(Packets.AskClusterState()) self.master_conn.ask(Packets.AskClusterState())
self.master_conn.ask(Packets.AskNodeInformation())
self.master_conn.ask(Packets.AskPartitionTable()) self.master_conn.ask(Packets.AskPartitionTable())
def sendPartitionTable(self, conn, min_offset, max_offset, uuid): def sendPartitionTable(self, conn, min_offset, max_offset, uuid):
......
...@@ -106,11 +106,6 @@ class MasterEventHandler(EventHandler): ...@@ -106,11 +106,6 @@ class MasterEventHandler(EventHandler):
def answerClusterState(self, conn, state): def answerClusterState(self, conn, state):
self.app.cluster_state = state self.app.cluster_state = state
def answerNodeInformation(self, conn):
# XXX: This will no more exists when the initialization module will be
# implemented for factorize code (as done for bootstrap)
logging.debug("answerNodeInformation")
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
...@@ -125,8 +120,6 @@ class MasterEventHandler(EventHandler): ...@@ -125,8 +120,6 @@ class MasterEventHandler(EventHandler):
def notifyClusterInformation(self, conn, cluster_state): def notifyClusterInformation(self, conn, cluster_state):
self.app.cluster_state = cluster_state self.app.cluster_state = cluster_state
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
class MasterRequestEventHandler(EventHandler): class MasterRequestEventHandler(EventHandler):
""" This class handle all answer from primary master node""" """ This class handle all answer from primary master node"""
......
...@@ -240,10 +240,10 @@ class Application(ThreadedApplication): ...@@ -240,10 +240,10 @@ class Application(ThreadedApplication):
self.notifications_handler, self.notifications_handler,
node=self.trying_master_node, node=self.trying_master_node,
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
p = Packets.RequestIdentification(
NodeTypes.CLIENT, self.uuid, None, self.name, None)
try: try:
ask(conn, Packets.RequestIdentification( ask(conn, p, handler=handler)
NodeTypes.CLIENT, self.uuid, None, self.name),
handler=handler)
except ConnectionClosed: except ConnectionClosed:
continue continue
# If we reached the primary master node, mark as connected # If we reached the primary master node, mark as connected
...@@ -256,7 +256,6 @@ class Application(ThreadedApplication): ...@@ -256,7 +256,6 @@ class Application(ThreadedApplication):
# operational. Might raise ConnectionClosed so that the new # operational. Might raise ConnectionClosed so that the new
# primary can be looked-up again. # primary can be looked-up again.
logging.info('Initializing from master') logging.info('Initializing from master')
ask(conn, Packets.AskNodeInformation(), handler=handler)
ask(conn, Packets.AskPartitionTable(), handler=handler) ask(conn, Packets.AskPartitionTable(), handler=handler)
ask(conn, Packets.AskLastTransaction(), handler=handler) ask(conn, Packets.AskLastTransaction(), handler=handler)
if self.pt.operational(): if self.pt.operational():
......
...@@ -30,6 +30,16 @@ class PrimaryBootstrapHandler(AnswerBaseHandler): ...@@ -30,6 +30,16 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
self.app.trying_master_node = None self.app.trying_master_node = None
conn.close() conn.close()
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerLastTransaction(*args):
pass
class PrimaryNotificationsHandler(MTEventHandler):
""" Handler that process the notifications from the primary master """
def _acceptIdentification(self, node, uuid, num_partitions, def _acceptIdentification(self, node, uuid, num_partitions,
num_replicas, your_uuid, primary, known_master_list): num_replicas, your_uuid, primary, known_master_list):
app = self.app app = self.app
...@@ -77,27 +87,13 @@ class PrimaryBootstrapHandler(AnswerBaseHandler): ...@@ -77,27 +87,13 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
raise ProtocolError('No UUID supplied') raise ProtocolError('No UUID supplied')
app.uuid = your_uuid app.uuid = your_uuid
logging.info('Got an UUID: %s', dump(app.uuid)) logging.info('Got an UUID: %s', dump(app.uuid))
app.id_timestamp = None
# Always create partition table # Always create partition table
app.pt = PartitionTable(num_partitions, num_replicas) app.pt = PartitionTable(num_partitions, num_replicas)
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def answerLastTransaction(self, conn, ltid): def answerLastTransaction(self, conn, ltid):
pass
class PrimaryNotificationsHandler(MTEventHandler):
""" Handler that process the notifications from the primary master """
def packetReceived(self, conn, packet, kw={}):
if type(packet) is Packets.AnswerLastTransaction:
app = self.app app = self.app
ltid = packet.decode()[0]
if app.last_tid != ltid: if app.last_tid != ltid:
# Either we're connecting or we already know the last tid # Either we're connecting or we already know the last tid
# via invalidations. # via invalidations.
...@@ -124,15 +120,15 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -124,15 +120,15 @@ class PrimaryNotificationsHandler(MTEventHandler):
db = app.getDB() db = app.getDB()
db is None or db.invalidateCache() db is None or db.invalidateCache()
app.last_tid = ltid app.last_tid = ltid
elif type(packet) is Packets.AnswerTransactionFinished:
def answerTransactionFinished(self, conn, _, tid, callback, cache_dict):
app = self.app app = self.app
app.last_tid = tid = packet.decode()[1] app.last_tid = tid
callback = kw.pop('callback')
# Update cache # Update cache
cache = app._cache cache = app._cache
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
for oid, data in kw.pop('cache_dict').iteritems(): for oid, data in cache_dict.iteritems():
# Update ex-latest value in cache # Update ex-latest value in cache
cache.invalidate(oid, tid) cache.invalidate(oid, tid)
if data is not None: if data is not None:
...@@ -142,7 +138,6 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -142,7 +138,6 @@ class PrimaryNotificationsHandler(MTEventHandler):
callback(tid) callback(tid)
finally: finally:
app._cache_lock_release() app._cache_lock_release()
MTEventHandler.packetReceived(self, conn, packet, kw)
def connectionClosed(self, conn): def connectionClosed(self, conn):
app = self.app app = self.app
...@@ -185,13 +180,14 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -185,13 +180,14 @@ class PrimaryNotificationsHandler(MTEventHandler):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
nm = self.app.nm super(PrimaryNotificationsHandler, self).notifyNodeInformation(
nm.update(node_list) conn, node_list)
# XXX: 'update' automatically closes DOWN nodes. Do we really want # XXX: 'update' automatically closes DOWN nodes. Do we really want
# to do the same thing for nodes in other non-running states ? # to do the same thing for nodes in other non-running states ?
for node_type, addr, uuid, state in node_list: getByUUID = self.app.nm.getByUUID
if state != NodeStates.RUNNING: for node in node_list:
node = nm.getByUUID(uuid) if node[3] != NodeStates.RUNNING:
node = getByUUID(node[2])
if node and node.isConnected(): if node and node.isConnected():
node.getConnection().close() node.getConnection().close()
......
...@@ -41,14 +41,6 @@ class StorageEventHandler(MTEventHandler): ...@@ -41,14 +41,6 @@ class StorageEventHandler(MTEventHandler):
self.app.cp.removeConnection(node) self.app.cp.removeConnection(node)
super(StorageEventHandler, self).connectionFailed(conn) super(StorageEventHandler, self).connectionFailed(conn)
class StorageBootstrapHandler(AnswerBaseHandler):
""" Handler used when connecting to a storage node """
def notReady(self, conn, message):
conn.close()
raise NodeNotReady(message)
def _acceptIdentification(self, node, def _acceptIdentification(self, node,
uuid, num_partitions, num_replicas, your_uuid, primary, uuid, num_partitions, num_replicas, your_uuid, primary,
master_list): master_list):
...@@ -57,6 +49,13 @@ class StorageBootstrapHandler(AnswerBaseHandler): ...@@ -57,6 +49,13 @@ class StorageBootstrapHandler(AnswerBaseHandler):
primary, self.app.master_conn) primary, self.app.master_conn)
assert uuid == node.getUUID(), (uuid, node.getUUID()) assert uuid == node.getUUID(), (uuid, node.getUUID())
class StorageBootstrapHandler(AnswerBaseHandler):
""" Handler used when connecting to a storage node """
def notReady(self, conn, message):
conn.close()
raise NodeNotReady(message)
class StorageAnswersHandler(AnswerBaseHandler): class StorageAnswersHandler(AnswerBaseHandler):
""" Handle all messages related to ZODB operations """ """ Handle all messages related to ZODB operations """
......
...@@ -57,7 +57,7 @@ class ConnectionPool(object): ...@@ -57,7 +57,7 @@ class ConnectionPool(object):
conn = MTClientConnection(app, app.storage_event_handler, node, conn = MTClientConnection(app, app.storage_event_handler, node,
dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name, app.id_timestamp)
try: try:
app._ask(conn, p, handler=app.storage_bootstrap_handler) app._ask(conn, p, handler=app.storage_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
......
...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler): ...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler):
""" """
accepted = False accepted = False
def __init__(self, app, name, node_type, uuid=None, server=None): def __init__(self, app, node_type, server=None):
""" """
Manage the bootstrap stage of a non-master node, it lookup for the Manage the bootstrap stage of a non-master node, it lookup for the
primary master node, connect to it then returns when the master node primary master node, connect to it then returns when the master node
...@@ -35,12 +35,12 @@ class BootstrapManager(EventHandler): ...@@ -35,12 +35,12 @@ class BootstrapManager(EventHandler):
self.primary = None self.primary = None
self.server = server self.server = server
self.node_type = node_type self.node_type = node_type
self.uuid = uuid
self.name = name
self.num_replicas = None self.num_replicas = None
self.num_partitions = None self.num_partitions = None
self.current = None self.current = None
uuid = property(lambda self: self.app.uuid)
def announcePrimary(self, conn): def announcePrimary(self, conn):
# We found the primary master early enough to be notified of election # We found the primary master early enough to be notified of election
# end. Lucky. Anyway, we must carry on with identification request, so # end. Lucky. Anyway, we must carry on with identification request, so
...@@ -55,7 +55,7 @@ class BootstrapManager(EventHandler): ...@@ -55,7 +55,7 @@ class BootstrapManager(EventHandler):
EventHandler.connectionCompleted(self, conn) EventHandler.connectionCompleted(self, conn)
self.current.setRunning() self.current.setRunning()
conn.ask(Packets.RequestIdentification(self.node_type, self.uuid, conn.ask(Packets.RequestIdentification(self.node_type, self.uuid,
self.server, self.name)) self.server, self.app.name, None))
def connectionFailed(self, conn): def connectionFailed(self, conn):
""" """
...@@ -106,8 +106,9 @@ class BootstrapManager(EventHandler): ...@@ -106,8 +106,9 @@ class BootstrapManager(EventHandler):
self.num_replicas = num_replicas self.num_replicas = num_replicas
if self.uuid != your_uuid: if self.uuid != your_uuid:
# got an uuid from the primary master # got an uuid from the primary master
self.uuid = your_uuid self.app.uuid = your_uuid
logging.info('Got a new UUID: %s', uuid_str(self.uuid)) logging.info('Got a new UUID: %s', uuid_str(self.uuid))
self.app.id_timestamp = None
self.accepted = True self.accepted = True
def getPrimaryConnection(self): def getPrimaryConnection(self):
...@@ -141,8 +142,4 @@ class BootstrapManager(EventHandler): ...@@ -141,8 +142,4 @@ class BootstrapManager(EventHandler):
continue continue
# still processing # still processing
poll(1) poll(1)
return (self.current, conn, self.uuid, self.num_partitions, return self.current, conn, self.num_partitions, self.num_replicas
self.num_replicas)
...@@ -165,6 +165,10 @@ class EventHandler(object): ...@@ -165,6 +165,10 @@ class EventHandler(object):
return return
conn.close() conn.close()
def notifyNodeInformation(self, conn, node_list):
app = self.app
app.nm.update(app, node_list)
def ping(self, conn): def ping(self, conn):
conn.answer(Packets.Pong()) conn.answer(Packets.Pong())
...@@ -227,6 +231,9 @@ class MTEventHandler(EventHandler): ...@@ -227,6 +231,9 @@ class MTEventHandler(EventHandler):
def packetReceived(self, conn, packet, kw={}): def packetReceived(self, conn, packet, kw={}):
"""Redirect all received packet to dispatcher thread.""" """Redirect all received packet to dispatcher thread."""
if packet.isResponse(): if packet.isResponse():
if packet.poll_thread:
self.dispatch(conn, packet, kw)
kw = {}
if not (self.dispatcher.dispatch(conn, packet.getId(), packet, kw) if not (self.dispatcher.dispatch(conn, packet.getId(), packet, kw)
or type(packet) is Packets.Pong): or type(packet) is Packets.Pong):
raise ProtocolError('Unexpected response packet from %r: %r' raise ProtocolError('Unexpected response packet from %r: %r'
...@@ -254,3 +261,6 @@ class AnswerBaseHandler(EventHandler): ...@@ -254,3 +261,6 @@ class AnswerBaseHandler(EventHandler):
packetReceived = unexpectedInAnswerHandler packetReceived = unexpectedInAnswerHandler
peerBroken = unexpectedInAnswerHandler peerBroken = unexpectedInAnswerHandler
protocolError = unexpectedInAnswerHandler protocolError = unexpectedInAnswerHandler
def acceptIdentification(*args):
pass
...@@ -27,6 +27,7 @@ class Node(object): ...@@ -27,6 +27,7 @@ class Node(object):
_connection = None _connection = None
_identified = False _identified = False
id_timestamp = None
def __init__(self, manager, address=None, uuid=None, def __init__(self, manager, address=None, uuid=None,
state=NodeStates.UNKNOWN): state=NodeStates.UNKNOWN):
...@@ -172,7 +173,8 @@ class Node(object): ...@@ -172,7 +173,8 @@ class Node(object):
def asTuple(self): def asTuple(self):
""" Returned tuple is intended to be used in protocol encoders """ """ Returned tuple is intended to be used in protocol encoders """
return (self.getType(), self._address, self._uuid, self._state) return (self.getType(), self._address, self._uuid, self._state,
self.id_timestamp)
def __gt__(self, node): def __gt__(self, node):
# sort per UUID if defined # sort per UUID if defined
...@@ -348,9 +350,11 @@ class NodeManager(object): ...@@ -348,9 +350,11 @@ class NodeManager(object):
""" Return the node that match with a given address """ """ Return the node that match with a given address """
return self._address_dict.get(address, None) return self._address_dict.get(address, None)
def getByUUID(self, uuid): def getByUUID(self, uuid, *id_timestamp):
""" Return the node that match with a given UUID """ """ Return the node that match with a given UUID """
return self._uuid_dict.get(uuid, None) node = self._uuid_dict.get(uuid)
if not id_timestamp or node and (node.id_timestamp,) == id_timestamp:
return node
def _createNode(self, klass, address=None, uuid=None, **kw): def _createNode(self, klass, address=None, uuid=None, **kw):
by_address = self.getByAddress(address) by_address = self.getByAddress(address)
...@@ -386,8 +390,9 @@ class NodeManager(object): ...@@ -386,8 +390,9 @@ class NodeManager(object):
def createFromNodeType(self, node_type, **kw): def createFromNodeType(self, node_type, **kw):
return self._createNode(NODE_TYPE_MAPPING[node_type], **kw) return self._createNode(NODE_TYPE_MAPPING[node_type], **kw)
def update(self, node_list): def update(self, app, node_list):
for node_type, addr, uuid, state in node_list: node_set = self._node_set.copy() if app.id_timestamp is None else None
for node_type, addr, uuid, state, id_timestamp in node_list:
# This should be done here (although klass might not be used in this # This should be done here (although klass might not be used in this
# iteration), as it raises if type is not valid. # iteration), as it raises if type is not valid.
klass = NODE_TYPE_MAPPING[node_type] klass = NODE_TYPE_MAPPING[node_type]
...@@ -397,14 +402,14 @@ class NodeManager(object): ...@@ -397,14 +402,14 @@ class NodeManager(object):
node_by_addr = self.getByAddress(addr) node_by_addr = self.getByAddress(addr)
node = node_by_uuid or node_by_addr node = node_by_uuid or node_by_addr
log_args = node_type, uuid_str(uuid), addr, state log_args = node_type, uuid_str(uuid), addr, state, id_timestamp
if node is None: if node is None:
if state == NodeStates.DOWN: if state == NodeStates.DOWN:
logging.debug('NOT creating node %s %s %s %s', *log_args) logging.debug('NOT creating node %s %s %s %s %s', *log_args)
else: continue
node = self._createNode(klass, address=addr, uuid=uuid, node = self._createNode(klass, address=addr, uuid=uuid,
state=state) state=state)
logging.debug('creating node %r', node) logging.debug('creating node %r', node)
else: else:
assert isinstance(node, klass), 'node %r is not ' \ assert isinstance(node, klass), 'node %r is not ' \
'of expected type: %r' % (node, klass) 'of expected type: %r' % (node, klass)
...@@ -414,7 +419,7 @@ class NodeManager(object): ...@@ -414,7 +419,7 @@ class NodeManager(object):
'node_by_addr (%r)' % (node_by_uuid, node_by_addr) 'node_by_addr (%r)' % (node_by_uuid, node_by_addr)
if state == NodeStates.DOWN: if state == NodeStates.DOWN:
logging.debug('dropping node %r (%r), found with %s ' logging.debug('dropping node %r (%r), found with %s '
'%s %s %s', node, node.isConnected(), *log_args) '%s %s %s %s', node, node.isConnected(), *log_args)
if node.isConnected(): if node.isConnected():
# Cut this connection, node removed by handler. # Cut this connection, node removed by handler.
# It's important for a storage to disconnect nodes that # It's important for a storage to disconnect nodes that
...@@ -424,12 +429,20 @@ class NodeManager(object): ...@@ -424,12 +429,20 @@ class NodeManager(object):
# partition table upon disconnection. # partition table upon disconnection.
node.getConnection().close() node.getConnection().close()
self.remove(node) self.remove(node)
else: continue
logging.debug('updating node %r to %s %s %s %s', logging.debug('updating node %r to %s %s %s %s %s',
node, *log_args) node, *log_args)
node.setUUID(uuid) node.setUUID(uuid)
node.setAddress(addr) node.setAddress(addr)
node.setState(state) node.setState(state)
node.id_timestamp = id_timestamp
if app.uuid == uuid:
app.id_timestamp = id_timestamp
if node_set:
# For the first notification, we receive a full list of nodes from
# the master. Remove all unknown nodes from a previous connection.
for node in node_set - self._node_set:
self.remove(node)
self.log() self.log()
def log(self): def log(self):
......
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
PROTOCOL_VERSION = 7 PROTOCOL_VERSION = 8
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -234,6 +234,7 @@ class Packet(object): ...@@ -234,6 +234,7 @@ class Packet(object):
_code = None _code = None
_fmt = None _fmt = None
_id = None _id = None
poll_thread = False
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
assert self._code is not None, "Packet class not registered" assert self._code is not None, "Packet class not registered"
...@@ -594,6 +595,13 @@ class PTID(PItem): ...@@ -594,6 +595,13 @@ class PTID(PItem):
# same definition, for now # same definition, for now
POID = PTID POID = PTID
class PFloat(PStructItemOrNone):
"""
A float number (8-bytes length)
"""
_fmt = '!d'
_None = '\xff' * 8
# common definitions # common definitions
PFEmpty = PStruct('no_content') PFEmpty = PStruct('no_content')
...@@ -607,6 +615,7 @@ PFNodeList = PList('node_list', ...@@ -607,6 +615,7 @@ PFNodeList = PList('node_list',
PAddress('address'), PAddress('address'),
PUUID('uuid'), PUUID('uuid'),
PFNodeState, PFNodeState,
PFloat('id_timestamp'),
), ),
) )
...@@ -680,6 +689,7 @@ class RequestIdentification(Packet): ...@@ -680,6 +689,7 @@ class RequestIdentification(Packet):
Request a node identification. This must be the first packet for any Request a node identification. This must be the first packet for any
connection. Any -> Any. connection. Any -> Any.
""" """
poll_thread = True
_fmt = PStruct('request_identification', _fmt = PStruct('request_identification',
PProtocol('protocol_version'), PProtocol('protocol_version'),
...@@ -687,6 +697,7 @@ class RequestIdentification(Packet): ...@@ -687,6 +697,7 @@ class RequestIdentification(Packet):
PUUID('uuid'), PUUID('uuid'),
PAddress('address'), PAddress('address'),
PString('name'), PString('name'),
PFloat('id_timestamp'),
) )
_answer = PStruct('accept_identification', _answer = PStruct('accept_identification',
...@@ -867,6 +878,8 @@ class FinishTransaction(Packet): ...@@ -867,6 +878,8 @@ class FinishTransaction(Packet):
Finish a transaction. C -> PM. Finish a transaction. C -> PM.
Answer when a transaction is finished. PM -> C. Answer when a transaction is finished. PM -> C.
""" """
poll_thread = True
_fmt = PStruct('ask_finish_transaction', _fmt = PStruct('ask_finish_transaction',
PTID('tid'), PTID('tid'),
PFOidList, PFOidList,
...@@ -1152,12 +1165,6 @@ class NotifyNodeInformation(Packet): ...@@ -1152,12 +1165,6 @@ class NotifyNodeInformation(Packet):
PFNodeList, PFNodeList,
) )
class NodeInformation(Packet):
"""
Ask node information
"""
_answer = PFEmpty
class SetClusterState(Packet): class SetClusterState(Packet):
""" """
Set the cluster state Set the cluster state
...@@ -1373,6 +1380,7 @@ class LastTransaction(Packet): ...@@ -1373,6 +1380,7 @@ class LastTransaction(Packet):
Answer last committed TID. Answer last committed TID.
M -> C M -> C
""" """
poll_thread = True
_answer = PStruct('answer_last_transaction', _answer = PStruct('answer_last_transaction',
PTID('tid'), PTID('tid'),
...@@ -1521,6 +1529,7 @@ def register(request, ignore_when_closed=None): ...@@ -1521,6 +1529,7 @@ def register(request, ignore_when_closed=None):
# build a class for the answer # build a class for the answer
answer = type('Answer%s' % (request.__name__, ), (Packet, ), {}) answer = type('Answer%s' % (request.__name__, ), (Packet, ), {})
answer._fmt = request._answer answer._fmt = request._answer
answer.poll_thread = request.poll_thread
# compute the answer code # compute the answer code
code = code | RESPONSE_MASK code = code | RESPONSE_MASK
answer._request = request answer._request = request
...@@ -1673,8 +1682,6 @@ class Packets(dict): ...@@ -1673,8 +1682,6 @@ class Packets(dict):
AddPendingNodes, ignore_when_closed=False) AddPendingNodes, ignore_when_closed=False)
TweakPartitionTable = register( TweakPartitionTable = register(
TweakPartitionTable, ignore_when_closed=False) TweakPartitionTable, ignore_when_closed=False)
AskNodeInformation, AnswerNodeInformation = register(
NodeInformation)
SetClusterState = register( SetClusterState = register(
SetClusterState, ignore_when_closed=False) SetClusterState, ignore_when_closed=False)
NotifyClusterInformation = register( NotifyClusterInformation = register(
......
...@@ -43,6 +43,8 @@ class ThreadContainer(threading.local): ...@@ -43,6 +43,8 @@ class ThreadContainer(threading.local):
class ThreadedApplication(BaseApplication): class ThreadedApplication(BaseApplication):
"""The client node application.""" """The client node application."""
uuid = None
def __init__(self, master_nodes, name, **kw): def __init__(self, master_nodes, name, **kw):
super(ThreadedApplication, self).__init__(**kw) super(ThreadedApplication, self).__init__(**kw)
self.poll_thread = threading.Thread(target=self.run, name=name) self.poll_thread = threading.Thread(target=self.run, name=name)
...@@ -56,8 +58,6 @@ class ThreadedApplication(BaseApplication): ...@@ -56,8 +58,6 @@ class ThreadedApplication(BaseApplication):
for address in master_nodes: for address in master_nodes:
self.nm.createMaster(address=address) self.nm.createMaster(address=address)
# no self-assigned UUID, primary master will supply us one
self.uuid = None
# Internal attribute distinct between thread # Internal attribute distinct between thread
self._thread_container = ThreadContainer() self._thread_container = ThreadContainer()
app_set.add(self) # to register self.on_log app_set.add(self) # to register self.on_log
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys, weakref import sys, weakref
from collections import defaultdict
from time import time from time import time
from neo.lib import logging from neo.lib import logging
...@@ -44,7 +45,6 @@ class Application(BaseApplication): ...@@ -44,7 +45,6 @@ class Application(BaseApplication):
last_transaction = ZERO_TID last_transaction = ZERO_TID
backup_tid = None backup_tid = None
backup_app = None backup_app = None
uuid = None
truncate_tid = None truncate_tid = None
def __init__(self, config): def __init__(self, config):
...@@ -79,9 +79,7 @@ class Application(BaseApplication): ...@@ -79,9 +79,7 @@ class Application(BaseApplication):
self.primary_master_node = None self.primary_master_node = None
self.cluster_state = None self.cluster_state = None
uuid = config.getUUID() self.uuid = config.getUUID()
if uuid:
self.uuid = uuid
# election related data # election related data
self.unconnected_master_node_set = set() self.unconnected_master_node_set = set()
...@@ -227,19 +225,20 @@ class Application(BaseApplication): ...@@ -227,19 +225,20 @@ class Application(BaseApplication):
Broadcast changes for a set a nodes Broadcast changes for a set a nodes
Send only one packet per connection to reduce bandwidth Send only one packet per connection to reduce bandwidth
""" """
node_dict = {} node_dict = defaultdict(list)
# group modified nodes by destination node type # group modified nodes by destination node type
for node in node_list: for node in node_list:
node_info = node.asTuple() node_info = node.asTuple()
def assign_for_notification(node_type): if node.isAdmin():
# helper function continue
node_dict.setdefault(node_type, []).append(node_info) node_dict[NodeTypes.ADMIN].append(node_info)
if node.isMaster() or node.isStorage(): node_dict[NodeTypes.STORAGE].append(node_info)
# client get notifications for master and storage only if node.isClient():
assign_for_notification(NodeTypes.CLIENT) continue
if node.isMaster() or node.isStorage() or node.isClient(): node_dict[NodeTypes.CLIENT].append(node_info)
assign_for_notification(NodeTypes.STORAGE) if node.isStorage():
assign_for_notification(NodeTypes.ADMIN) continue
node_dict[NodeTypes.MASTER].append(node_info)
# send at most one non-empty notification packet per node # send at most one non-empty notification packet per node
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
...@@ -498,7 +497,7 @@ class Application(BaseApplication): ...@@ -498,7 +497,7 @@ class Application(BaseApplication):
conn.setHandler(handler) conn.setHandler(handler)
conn.notify(Packets.NotifyNodeInformation((( conn.notify(Packets.NotifyNodeInformation(((
node.getType(), node.getAddress(), node.getUUID(), node.getType(), node.getAddress(), node.getUUID(),
NodeStates.TEMPORARILY_DOWN),))) NodeStates.TEMPORARILY_DOWN, None),)))
conn.abort() conn.abort()
elif conn.pending(): elif conn.pending():
conn.abort() conn.abort()
......
...@@ -65,6 +65,7 @@ There is no UUID conflict between the 2 clusters: ...@@ -65,6 +65,7 @@ There is no UUID conflict between the 2 clusters:
class BackupApplication(object): class BackupApplication(object):
pt = None pt = None
uuid = None
def __init__(self, app, name, master_addresses): def __init__(self, app, name, master_addresses):
self.app = weakref.proxy(app) self.app = weakref.proxy(app)
...@@ -92,7 +93,7 @@ class BackupApplication(object): ...@@ -92,7 +93,7 @@ class BackupApplication(object):
pt = app.pt pt = app.pt
while True: while True:
app.changeClusterState(ClusterStates.STARTING_BACKUP) app.changeClusterState(ClusterStates.STARTING_BACKUP)
bootstrap = BootstrapManager(self, self.name, NodeTypes.CLIENT) bootstrap = BootstrapManager(self, NodeTypes.CLIENT)
# {offset -> node} # {offset -> node}
self.primary_partition_dict = {} self.primary_partition_dict = {}
# [[tid]] # [[tid]]
...@@ -105,7 +106,7 @@ class BackupApplication(object): ...@@ -105,7 +106,7 @@ class BackupApplication(object):
else: else:
break break
poll(1) poll(1)
node, conn, uuid, num_partitions, num_replicas = \ node, conn, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
try: try:
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
...@@ -114,7 +115,6 @@ class BackupApplication(object): ...@@ -114,7 +115,6 @@ class BackupApplication(object):
raise RuntimeError("inconsistent number of partitions") raise RuntimeError("inconsistent number of partitions")
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskNodeInformation())
conn.ask(Packets.AskPartitionTable()) conn.ask(Packets.AskPartitionTable())
conn.ask(Packets.AskLastTransaction()) conn.ask(Packets.AskLastTransaction())
# debug variable to log how big 'tid_list' can be. # debug variable to log how big 'tid_list' can be.
......
...@@ -18,7 +18,7 @@ from neo.lib import logging ...@@ -18,7 +18,7 @@ from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets, from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets,
BrokenNodeDisallowedError, BrokenNodeDisallowedError, ProtocolError,
) )
class MasterHandler(EventHandler): class MasterHandler(EventHandler):
...@@ -27,18 +27,19 @@ class MasterHandler(EventHandler): ...@@ -27,18 +27,19 @@ class MasterHandler(EventHandler):
def connectionCompleted(self, conn, new=None): def connectionCompleted(self, conn, new=None):
if new is None: if new is None:
super(MasterHandler, self).connectionCompleted(conn) super(MasterHandler, self).connectionCompleted(conn)
elif new:
self._notifyNodeInformation(conn)
def requestIdentification(self, conn, node_type, uuid, address, name): def requestIdentification(self, conn, node_type, uuid, address, name, _):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app app = self.app
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
if node: if node:
assert node_type is not NodeTypes.MASTER or node.getAddress() in ( if node_type is NodeTypes.MASTER and not (
address, None), (node, address) None != address == node.getAddress()):
raise ProtocolError
if node.isBroken(): if node.isBroken():
raise BrokenNodeDisallowedError raise BrokenNodeDisallowedError
else:
node = app.nm.getByAddress(address)
peer_uuid = self._setupNode(conn, node_type, uuid, address, node) peer_uuid = self._setupNode(conn, node_type, uuid, address, node)
if app.primary: if app.primary:
primary_address = app.server primary_address = app.server
...@@ -89,10 +90,6 @@ class MasterHandler(EventHandler): ...@@ -89,10 +90,6 @@ class MasterHandler(EventHandler):
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
def askNodeInformation(self, conn):
self._notifyNodeInformation(conn)
conn.answer(Packets.AnswerNodeInformation())
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
......
...@@ -31,12 +31,6 @@ class BackupHandler(EventHandler): ...@@ -31,12 +31,6 @@ class BackupHandler(EventHandler):
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
def answerLastTransaction(self, conn, tid): def answerLastTransaction(self, conn, tid):
app = self.app app = self.app
if tid != ZERO_TID: if tid != ZERO_TID:
......
...@@ -31,14 +31,12 @@ class ClientServiceHandler(MasterHandler): ...@@ -31,14 +31,12 @@ class ClientServiceHandler(MasterHandler):
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
app.nm.remove(node) app.nm.remove(node)
def askNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
# send informations about master and storages only
nm = self.app.nm nm = self.app.nm
node_list = [] node_list = [nm.getByUUID(conn.getUUID()).asTuple()] # for id_timestamp
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
conn.answer(Packets.AnswerNodeInformation())
def askBeginTransaction(self, conn, tid): def askBeginTransaction(self, conn, tid):
""" """
......
...@@ -23,6 +23,9 @@ from . import MasterHandler ...@@ -23,6 +23,9 @@ from . import MasterHandler
class BaseElectionHandler(EventHandler): class BaseElectionHandler(EventHandler):
def _notifyNodeInformation(self, conn):
pass
def reelectPrimary(self, conn): def reelectPrimary(self, conn):
raise ElectionFailure, 'reelection requested' raise ElectionFailure, 'reelection requested'
...@@ -53,6 +56,11 @@ class BaseElectionHandler(EventHandler): ...@@ -53,6 +56,11 @@ class BaseElectionHandler(EventHandler):
class ClientElectionHandler(BaseElectionHandler): class ClientElectionHandler(BaseElectionHandler):
def notifyNodeInformation(self, conn, node_list):
# XXX: For the moment, do nothing because
# we'll close this connection and reconnect.
pass
def connectionFailed(self, conn): def connectionFailed(self, conn):
addr = conn.getAddress() addr = conn.getAddress()
node = self.app.nm.getByAddress(addr) node = self.app.nm.getByAddress(addr)
...@@ -68,6 +76,7 @@ class ClientElectionHandler(BaseElectionHandler): ...@@ -68,6 +76,7 @@ class ClientElectionHandler(BaseElectionHandler):
app.uuid, app.uuid,
app.server, app.server,
app.name, app.name,
None,
)) ))
super(ClientElectionHandler, self).connectionCompleted(conn) super(ClientElectionHandler, self).connectionCompleted(conn)
...@@ -126,8 +135,8 @@ class ServerElectionHandler(BaseElectionHandler, MasterHandler): ...@@ -126,8 +135,8 @@ class ServerElectionHandler(BaseElectionHandler, MasterHandler):
logging.info('reject a connection from a non-master') logging.info('reject a connection from a non-master')
raise NotReadyError raise NotReadyError
if node is None: if node is None is app.nm.getByAddress(address):
node = app.nm.createMaster(address=address) app.nm.createMaster(address=address)
self.elect(conn, address) self.elect(conn, address)
return uuid return uuid
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \
NotReadyError, ProtocolError, uuid_str NotReadyError, ProtocolError, uuid_str
...@@ -30,18 +31,32 @@ class IdentificationHandler(MasterHandler): ...@@ -30,18 +31,32 @@ class IdentificationHandler(MasterHandler):
def _setupNode(self, conn, node_type, uuid, address, node): def _setupNode(self, conn, node_type, uuid, address, node):
app = self.app app = self.app
if node: by_addr = address and app.nm.getByAddress(address)
if node.isRunning(): while 1:
if uuid > 0: if by_addr:
# cloned/evil/buggy node connecting to us if not by_addr.isConnected():
raise ProtocolError('already connected') if node is by_addr:
# The peer wants a temporary id that's already assigned. break
# Let's give it another one. if not node or uuid < 0:
node = uuid = None # In case of address conflict for a peer with temporary
# ids, we'll generate a new id.
node = by_addr
break
elif node:
if node.isConnected():
if uuid < 0:
# The peer wants a temporary id that's already assigned.
# Let's give it another one.
node = uuid = None
break
else:
node.setAddress(address)
break
# Id conflict for a storage node.
else: else:
assert not node.isConnected() break
node.setAddress(address) # cloned/evil/buggy node connecting to us
node.setRunning() raise ProtocolError('already connected')
state = NodeStates.RUNNING state = NodeStates.RUNNING
if node_type == NodeTypes.CLIENT: if node_type == NodeTypes.CLIENT:
...@@ -68,14 +83,16 @@ class IdentificationHandler(MasterHandler): ...@@ -68,14 +83,16 @@ class IdentificationHandler(MasterHandler):
handler = app.administration_handler handler = app.administration_handler
human_readable_node_type = 'n admin ' human_readable_node_type = 'n admin '
else: else:
raise NotImplementedError(node_type) raise ProtocolError
uuid = app.getNewUUID(uuid, address, node_type) uuid = app.getNewUUID(uuid, address, node_type)
logging.info('Accept a' + human_readable_node_type + uuid_str(uuid)) logging.info('Accept a' + human_readable_node_type + uuid_str(uuid))
if node is None: if node is None:
node = app.nm.createFromNodeType(node_type, node = app.nm.createFromNodeType(node_type,
uuid=uuid, address=address) uuid=uuid, address=address)
node.setUUID(uuid) else:
node.setUUID(uuid)
node.id_timestamp = time()
node.setState(state) node.setState(state)
node.setConnection(conn) node.setConnection(conn)
conn.setHandler(handler) conn.setHandler(handler)
......
...@@ -36,6 +36,10 @@ class SecondaryMasterHandler(MasterHandler): ...@@ -36,6 +36,10 @@ class SecondaryMasterHandler(MasterHandler):
def reelectPrimary(self, conn): def reelectPrimary(self, conn):
raise ElectionFailure, 'reelection requested' raise ElectionFailure, 'reelection requested'
def _notifyNodeInformation(self, conn):
node_list = [n.asTuple() for n in self.app.nm.getMasterList()]
conn.notify(Packets.NotifyNodeInformation(node_list))
class PrimaryHandler(EventHandler): class PrimaryHandler(EventHandler):
""" Handler used by secondaries to handle primary master""" """ Handler used by secondaries to handle primary master"""
...@@ -58,6 +62,7 @@ class PrimaryHandler(EventHandler): ...@@ -58,6 +62,7 @@ class PrimaryHandler(EventHandler):
app.uuid, app.uuid,
app.server, app.server,
app.name, app.name,
None,
)) ))
super(PrimaryHandler, self).connectionCompleted(conn) super(PrimaryHandler, self).connectionCompleted(conn)
...@@ -68,27 +73,11 @@ class PrimaryHandler(EventHandler): ...@@ -68,27 +73,11 @@ class PrimaryHandler(EventHandler):
self.app.cluster_state = state self.app.cluster_state = state
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
app = self.app super(PrimaryHandler, self).notifyNodeInformation(conn, node_list)
for node_type, addr, uuid, state in node_list: for node_type, _, uuid, state, _ in node_list:
if node_type != NodeTypes.MASTER: assert node_type == NodeTypes.MASTER, node_type
# No interest. if uuid == self.app.uuid and state == NodeStates.UNKNOWN:
continue
if uuid == app.uuid and state == NodeStates.UNKNOWN:
sys.exit() sys.exit()
# Register new master nodes.
if app.server == addr:
# This is self.
continue
else:
n = app.nm.getByAddress(addr)
# master node must be known
assert n is not None
if uuid is not None:
# If I don't know the UUID yet, believe what the peer
# told me at the moment.
if n.getUUID() is None:
n.setUUID(uuid)
def _acceptIdentification(self, node, uuid, num_partitions, def _acceptIdentification(self, node, uuid, num_partitions,
num_replicas, your_uuid, primary, known_master_list): num_replicas, your_uuid, primary, known_master_list):
...@@ -101,4 +90,5 @@ class PrimaryHandler(EventHandler): ...@@ -101,4 +90,5 @@ class PrimaryHandler(EventHandler):
logging.info('My UUID: ' + uuid_str(your_uuid)) logging.info('My UUID: ' + uuid_str(your_uuid))
node.setUUID(uuid) node.setUUID(uuid)
app.id_timestamp = None
...@@ -27,13 +27,11 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -27,13 +27,11 @@ class StorageServiceHandler(BaseServiceHandler):
def connectionCompleted(self, conn, new): def connectionCompleted(self, conn, new):
app = self.app app = self.app
uuid = conn.getUUID() uuid = conn.getUUID()
node = app.nm.getByUUID(uuid)
app.setStorageNotReady(uuid) app.setStorageNotReady(uuid)
if new: if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new) super(StorageServiceHandler, self).connectionCompleted(conn, new)
# XXX: what other values could happen ? if app.nm.getByUUID(uuid).isRunning(): # node may be PENDING
if node.isRunning(): conn.notify(Packets.StartOperation(app.backup_tid))
conn.notify(Packets.StartOperation(bool(app.backup_tid)))
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
......
...@@ -146,15 +146,14 @@ class Log(object): ...@@ -146,15 +146,14 @@ class Log(object):
def notifyNodeInformation(self, node_list): def notifyNodeInformation(self, node_list):
node_list.sort(key=lambda x: x[2]) node_list.sort(key=lambda x: x[2])
node_list = [(self.uuid_str(uuid), str(node_type), node_list = [(self.uuid_str(x[2]), str(x[0]),
'%s:%u' % address if address else '?', state) '%s:%u' % x[1] if x[1] else '?', str(x[3]))
for node_type, address, uuid, state in node_list] + ((repr(x[4]),) if len(x) > 4 else ()) # BBB
for x in node_list]
if node_list: if node_list:
t = ' ! %%%us | %%%us | %%%us | %%s' % ( t = ''.join(' %%%us |' % max(len(x[i]) for x in node_list)
max(len(x[0]) for x in node_list), for i in xrange(len(node_list[0]) - 1))
max(len(x[1]) for x in node_list), return map((' !' + t + ' %s').__mod__, node_list)
max(len(x[2]) for x in node_list))
return map(t.__mod__, node_list)
return () return ()
......
...@@ -219,14 +219,11 @@ class Application(BaseApplication): ...@@ -219,14 +219,11 @@ class Application(BaseApplication):
conn.close() conn.close()
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, bootstrap = BootstrapManager(self, NodeTypes.STORAGE, self.server)
NodeTypes.STORAGE, self.uuid, self.server) self.master_node, self.master_conn, num_partitions, num_replicas = \
data = bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data uuid = self.uuid
self.master_node = node
self.master_conn = conn
logging.info('I am %s', uuid_str(uuid)) logging.info('I am %s', uuid_str(uuid))
self.uuid = uuid
self.dm.setUUID(uuid) self.dm.setUUID(uuid)
# Reload a partition table from the database. This is necessary # Reload a partition table from the database. This is necessary
......
...@@ -50,8 +50,8 @@ class Checker(object): ...@@ -50,8 +50,8 @@ class Checker(object):
conn.asClient() conn.asClient()
else: else:
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
conn.ask(Packets.RequestIdentification( conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
NodeTypes.STORAGE, uuid, app.server, name)) uuid, app.server, name, app.id_timestamp))
self.conn_dict[conn] = node.isIdentified() self.conn_dict[conn] = node.isIdentified()
conn_set = set(self.conn_dict) conn_set = set(self.conn_dict)
conn_set.discard(None) conn_set.discard(None)
......
...@@ -38,8 +38,8 @@ class BaseMasterHandler(EventHandler): ...@@ -38,8 +38,8 @@ class BaseMasterHandler(EventHandler):
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
"""Store information on nodes, only if this is sent by a primary """Store information on nodes, only if this is sent by a primary
master node.""" master node."""
self.app.nm.update(node_list) super(BaseMasterHandler, self).notifyNodeInformation(conn, node_list)
for node_type, addr, uuid, state in node_list: for node_type, _, uuid, state, _ in node_list:
if uuid == self.app.uuid: if uuid == self.app.uuid:
# This is me, do what the master tell me # This is me, do what the master tell me
logging.info("I was told I'm %s", state) logging.info("I was told I'm %s", state)
......
...@@ -27,7 +27,8 @@ class IdentificationHandler(EventHandler): ...@@ -27,7 +27,8 @@ class IdentificationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
logging.warning('A connection was lost during identification') logging.warning('A connection was lost during identification')
def requestIdentification(self, conn, node_type, uuid, address, name): def requestIdentification(self, conn, node_type, uuid, address, name,
id_timestamp):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app app = self.app
# reject any incoming connections if not ready # reject any incoming connections if not ready
...@@ -41,7 +42,7 @@ class IdentificationHandler(EventHandler): ...@@ -41,7 +42,7 @@ class IdentificationHandler(EventHandler):
else: else:
if uuid == app.uuid: if uuid == app.uuid:
raise ProtocolError("uuid conflict or loopback connection") raise ProtocolError("uuid conflict or loopback connection")
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid, id_timestamp)
if node is None: if node is None:
# Do never create node automatically, or we could get id # Do never create node automatically, or we could get id
# conflicts. We must only rely on the notifications from the # conflicts. We must only rely on the notifications from the
...@@ -56,12 +57,7 @@ class IdentificationHandler(EventHandler): ...@@ -56,12 +57,7 @@ class IdentificationHandler(EventHandler):
handler = ClientReadOnlyOperationHandler handler = ClientReadOnlyOperationHandler
else: else:
handler = ClientOperationHandler handler = ClientOperationHandler
if node.isConnected(): # XXX assert not node.isConnected(), node
# This can happen if we haven't processed yet a notification
# from the master, telling us the existing node is not
# running anymore. If we accept the new client, we won't
# know what to do with this late notification.
raise NotReadyError('uuid conflict: retry later')
assert node.isRunning(), node assert node.isRunning(), node
elif node_type == NodeTypes.STORAGE: elif node_type == NodeTypes.STORAGE:
handler = StorageOperationHandler handler = StorageOperationHandler
......
...@@ -20,9 +20,6 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID ...@@ -20,9 +20,6 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
def answerNodeInformation(self, conn):
pass
def sendPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, row_list):
app = self.app app = self.app
pt = app.pt pt = app.pt
......
...@@ -258,7 +258,8 @@ class Replicator(object): ...@@ -258,7 +258,8 @@ class Replicator(object):
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
try: try:
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name)) None if name else app.uuid, app.server, name or app.name,
app.id_timestamp))
except ConnectionClosed: except ConnectionClosed:
if previous_node is self.current_node: if previous_node is self.current_node:
return return
......
...@@ -753,11 +753,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -753,11 +753,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# will raise IndexError at the third iteration # will raise IndexError at the third iteration
app = self.getApp('127.0.0.1:10010 127.0.0.1:10011') app = self.getApp('127.0.0.1:10010 127.0.0.1:10011')
# TODO: test more connection failure cases # TODO: test more connection failure cases
all_passed = []
# askLastTransaction # askLastTransaction
def _ask9(_):
all_passed.append(1)
# Seventh packet : askNodeInformation succeeded
def _ask8(_): def _ask8(_):
pass pass
# Sixth packet : askPartitionTable succeeded # Sixth packet : askPartitionTable succeeded
...@@ -789,8 +785,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -789,8 +785,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# telling us what its address is.) # telling us what its address is.)
def _ask1(_): def _ask1(_):
pass pass
ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask6, _ask7, ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask6, _ask7, _ask8]
_ask8, _ask9]
def _ask_base(conn, _, handler=None): def _ask_base(conn, _, handler=None):
ask_func_list.pop(0)(conn) ask_func_list.pop(0)(conn)
app.nm.getByAddress(conn.getAddress())._connection = None app.nm.getByAddress(conn.getAddress())._connection = None
...@@ -801,7 +796,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -801,7 +796,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.pt = Mock({ 'operational': False}) app.pt = Mock({ 'operational': False})
app.start = lambda: None app.start = lambda: None
app.master_conn = app._connectToPrimaryNode() app.master_conn = app._connectToPrimaryNode()
self.assertEqual(len(all_passed), 1) self.assertFalse(ask_func_list)
self.assertTrue(app.master_conn is not None) self.assertTrue(app.master_conn is not None)
self.assertTrue(app.pt.operational()) self.assertTrue(app.pt.operational())
......
...@@ -44,69 +44,6 @@ class MasterHandlerTests(NeoUnitTestBase): ...@@ -44,69 +44,6 @@ class MasterHandlerTests(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
return node, conn return node, conn
class MasterBootstrapHandlerTests(MasterHandlerTests):
def setUp(self):
super(MasterBootstrapHandlerTests, self).setUp()
self.handler = PrimaryBootstrapHandler(self.app)
def checkCalledOnApp(self, method, index=0):
calls = self.app.mockGetNamedCalls(method)
self.assertTrue(len(calls) > index)
return calls[index].params
def test_notReady(self):
conn = self.getFakeConnection()
self.handler.notReady(conn, 'message')
self.assertEqual(self.app.trying_master_node, None)
def test_acceptIdentification1(self):
""" Non-master node """
node, conn = self.getKnownMaster()
self.handler.acceptIdentification(conn, NodeTypes.CLIENT,
node.getUUID(), 100, 0, None, None, [])
self.checkClosed(conn)
def test_acceptIdentification2(self):
""" No UUID supplied """
node, conn = self.getKnownMaster()
uuid = self.getMasterUUID()
addr = conn.getAddress()
self.checkProtocolErrorRaised(self.handler.acceptIdentification,
conn, NodeTypes.MASTER, uuid, 100, 0, None,
addr, [(addr, uuid)],
)
def test_acceptIdentification3(self):
""" identification accepted """
node, conn = self.getKnownMaster()
uuid = self.getMasterUUID()
addr = conn.getAddress()
your_uuid = self.getClientUUID()
self.handler.acceptIdentification(conn, NodeTypes.MASTER, uuid,
100, 2, your_uuid, addr, [(addr, uuid)])
self.assertEqual(self.app.uuid, your_uuid)
self.assertEqual(node.getUUID(), uuid)
self.assertTrue(isinstance(self.app.pt, PartitionTable))
def _getMasterList(self, uuid_list):
port = 1000
master_list = []
for uuid in uuid_list:
master_list.append((('127.0.0.1', port), uuid))
port += 1
return master_list
def test_answerPartitionTable(self):
conn = self.getFakeConnection()
self.app.pt = Mock()
ptid = 0
row_list = ([], [])
self.handler.answerPartitionTable(conn, ptid, row_list)
load_calls = self.app.pt.mockGetNamedCalls('load')
self.assertEqual(len(load_calls), 1)
# load_calls[0].checkArgs(ptid, row_list, self.app.nm)
class MasterNotificationsHandlerTests(MasterHandlerTests): class MasterNotificationsHandlerTests(MasterHandlerTests):
......
...@@ -119,7 +119,7 @@ class NEOProcess(object): ...@@ -119,7 +119,7 @@ class NEOProcess(object):
except ImportError: except ImportError:
raise NotFound, '%s not found' % (command) raise NotFound, '%s not found' % (command)
self.command = command self.command = command
self.arg_dict = {'--' + k: v for k, v in arg_dict.iteritems()} self.arg_dict = arg_dict
self.with_uuid = True self.with_uuid = True
self.setUUID(uuid) self.setUUID(uuid)
...@@ -131,11 +131,11 @@ class NEOProcess(object): ...@@ -131,11 +131,11 @@ class NEOProcess(object):
args = [] args = []
self.with_uuid = with_uuid self.with_uuid = with_uuid
for arg, param in self.arg_dict.iteritems(): for arg, param in self.arg_dict.iteritems():
if with_uuid is False and arg == '--uuid': args.append('--' + arg)
continue
args.append(arg)
if param is not None: if param is not None:
args.append(str(param)) args.append(str(param))
if with_uuid:
args += '--uuid', str(self.uuid)
self.pid = os.fork() self.pid = os.fork()
if self.pid == 0: if self.pid == 0:
# Child # Child
...@@ -213,7 +213,6 @@ class NEOProcess(object): ...@@ -213,7 +213,6 @@ class NEOProcess(object):
Note: for this change to take effect, the node must be restarted. Note: for this change to take effect, the node must be restarted.
""" """
self.uuid = uuid self.uuid = uuid
self.arg_dict['--uuid'] = str(uuid)
def isAlive(self): def isAlive(self):
try: try:
...@@ -297,7 +296,6 @@ class NEOCluster(object): ...@@ -297,7 +296,6 @@ class NEOCluster(object):
def _newProcess(self, node_type, logfile=None, port=None, **kw): def _newProcess(self, node_type, logfile=None, port=None, **kw):
self.uuid_dict[node_type] = uuid = 1 + self.uuid_dict.get(node_type, 0) self.uuid_dict[node_type] = uuid = 1 + self.uuid_dict.get(node_type, 0)
uuid += UUID_NAMESPACES[node_type] << 24 uuid += UUID_NAMESPACES[node_type] << 24
kw['uuid'] = uuid
kw['cluster'] = self.cluster_name kw['cluster'] = self.cluster_name
kw['masters'] = self.master_nodes kw['masters'] = self.master_nodes
if logfile: if logfile:
...@@ -483,13 +481,9 @@ class NEOCluster(object): ...@@ -483,13 +481,9 @@ class NEOCluster(object):
return self.__getNodeList(NodeTypes.CLIENT, state) return self.__getNodeList(NodeTypes.CLIENT, state)
def __getNodeState(self, node_type, uuid): def __getNodeState(self, node_type, uuid):
node_list = self.__getNodeList(node_type) for node in self.__getNodeList(node_type):
for node_type, address, node_uuid, state in node_list: if node[2] == uuid:
if node_uuid == uuid: return node[3]
break
else:
state = None
return state
def getMasterNodeState(self, uuid): def getMasterNodeState(self, uuid):
return self.__getNodeState(NodeTypes.MASTER, uuid) return self.__getNodeState(NodeTypes.MASTER, uuid)
......
...@@ -144,18 +144,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -144,18 +144,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.assertEqual(len(txn.getOIDList()), 0) self.assertEqual(len(txn.getOIDList()), 0)
self.assertEqual(len(txn.getUUIDList()), 1) self.assertEqual(len(txn.getUUIDList()), 1)
def test_askNodeInformations(self):
# check that only informations about master and storages nodes are
# send to a client
self.app.nm.createClient()
conn = self.getFakeConnection()
self.service.askNodeInformation(conn)
calls = conn.mockGetNamedCalls('notify')
self.assertEqual(len(calls), 1)
packet = calls[0].getParam(0)
(node_list, ) = packet.decode()
self.assertEqual(len(node_list), 2)
def test_connectionClosed(self): def test_connectionClosed(self):
# give a client uuid which have unfinished transactions # give a client uuid which have unfinished transactions
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
......
...@@ -231,7 +231,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -231,7 +231,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
def test_requestIdentification1(self): def test_requestIdentification1(self):
""" A non-master node request identification """ """ A non-master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.assertRaises(protocol.NotReadyError, self.assertRaises(protocol.NotReadyError,
self.election.requestIdentification, self.election.requestIdentification,
conn, NodeTypes.CLIENT, *args) conn, NodeTypes.CLIENT, *args)
...@@ -240,7 +240,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -240,7 +240,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
""" A broken master node request identification """ """ A broken master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
node.setBroken() node.setBroken()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.assertRaises(protocol.BrokenNodeDisallowedError, self.assertRaises(protocol.BrokenNodeDisallowedError,
self.election.requestIdentification, self.election.requestIdentification,
conn, NodeTypes.MASTER, *args) conn, NodeTypes.MASTER, *args)
...@@ -248,7 +248,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -248,7 +248,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
def test_requestIdentification4(self): def test_requestIdentification4(self):
""" No conflict """ """ No conflict """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.election.requestIdentification(conn, self.election.requestIdentification(conn,
NodeTypes.MASTER, *args) NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID()) self.checkUUIDSet(conn, node.getUUID())
...@@ -280,11 +280,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -280,11 +280,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
conn = self.__getClient() conn = self.__getClient()
self.checkNotReadyErrorRaised( self.checkNotReadyErrorRaised(
self.election.requestIdentification, self.election.requestIdentification,
conn=conn, conn,
node_type=NodeTypes.CLIENT, NodeTypes.CLIENT,
uuid=conn.getUUID(), conn.getUUID(),
address=conn.getAddress(), conn.getAddress(),
name=self.app.name self.app.name,
None,
) )
def _requestIdentification(self): def _requestIdentification(self):
...@@ -297,6 +298,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -297,6 +298,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
peer_uuid, peer_uuid,
address, address,
self.app.name, self.app.name,
None,
) )
node_type, uuid, partitions, replicas, _peer_uuid, primary, \ node_type, uuid, partitions, replicas, _peer_uuid, primary, \
master_list = self.checkAcceptIdentification(conn, decode=True) master_list = self.checkAcceptIdentification(conn, decode=True)
......
...@@ -50,6 +50,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -50,6 +50,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
self.getClientUUID(), self.getClientUUID(),
None, None,
self.app.name, self.app.name,
None,
) )
self.app.ready = True self.app.ready = True
self.assertRaises( self.assertRaises(
...@@ -60,6 +61,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -60,6 +61,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
self.getStorageUUID(), self.getStorageUUID(),
None, None,
self.app.name, self.app.name,
None,
) )
def test_requestIdentification3(self): def test_requestIdentification3(self):
...@@ -75,6 +77,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -75,6 +77,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
uuid, uuid,
None, None,
self.app.name, self.app.name,
None,
) )
def test_requestIdentification2(self): def test_requestIdentification2(self):
...@@ -87,7 +90,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -87,7 +90,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
'getAddress': master, 'getAddress': master,
}) })
self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid, self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid,
None, self.app.name) None, self.app.name, None)
self.assertTrue(node.isRunning()) self.assertTrue(node.isRunning())
self.assertTrue(node.isConnected()) self.assertTrue(node.isConnected())
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
......
...@@ -28,7 +28,7 @@ class BootstrapManagerTests(NeoUnitTestBase): ...@@ -28,7 +28,7 @@ class BootstrapManagerTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getStorageConfiguration() config = self.getStorageConfiguration()
self.app = Application(config) self.app = Application(config)
self.bootstrap = BootstrapManager(self.app, 'main', NodeTypes.STORAGE) self.bootstrap = BootstrapManager(self.app, NodeTypes.STORAGE)
# define some variable to simulate client and storage node # define some variable to simulate client and storage node
self.master_port = 10010 self.master_port = 10010
self.storage_port = 10020 self.storage_port = 10020
......
...@@ -183,15 +183,15 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -183,15 +183,15 @@ class NodeManagerTests(NeoUnitTestBase):
old_uuid = self.storage.getUUID() old_uuid = self.storage.getUUID()
new_uuid = self.getStorageUUID() new_uuid = self.getStorageUUID()
node_list = ( node_list = (
(NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.DOWN), (NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.DOWN, None),
(NodeTypes.MASTER, new_address, self.master.getUUID(), NodeStates.RUNNING), (NodeTypes.MASTER, new_address, self.master.getUUID(), NodeStates.RUNNING, None),
(NodeTypes.STORAGE, self.storage.getAddress(), new_uuid, (NodeTypes.STORAGE, self.storage.getAddress(), new_uuid,
NodeStates.RUNNING), NodeStates.RUNNING, None),
(NodeTypes.ADMIN, self.admin.getAddress(), self.admin.getUUID(), (NodeTypes.ADMIN, self.admin.getAddress(), self.admin.getUUID(),
NodeStates.UNKNOWN), NodeStates.UNKNOWN, None),
) )
# update manager content # update manager content
manager.update(node_list) manager.update(Mock(), node_list)
# - the client gets down # - the client gets down
self.checkClients([]) self.checkClients([])
# - master change it's address # - master change it's address
......
...@@ -27,14 +27,14 @@ from ZODB import DB, POSException ...@@ -27,14 +27,14 @@ from ZODB import DB, POSException
from ZODB.DB import TransactionalUndo from ZODB.DB import TransactionalUndo
from neo.storage.transactions import TransactionManager, \ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError DelayedError, ConflictError
from neo.lib.connection import MTClientConnection from neo.lib.connection import ServerConnection, MTClientConnection
from neo.lib.exception import DatabaseFailure, StoppedOperation from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID ZERO_OID, ZERO_TID
from .. import expectedFailure, Patch from .. import expectedFailure, Patch
from . import LockLock, NEOCluster, NEOThreadedTest from . import LockLock, NEOCluster, NEOThreadedTest
from neo.lib.util import add64, makeChecksum, p64, u64 from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
...@@ -1347,6 +1347,58 @@ class Test(NEOThreadedTest): ...@@ -1347,6 +1347,58 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testIdTimestamp(self):
"""
Given a master M, a storage S, and 2 clients Ca and Cb.
While Ca(id=1) is being identified by S:
1. connection between Ca and M breaks
2. M -> S: C1 down
3. Cb connect to M: id=1
4. M -> S: C1 up
5. S processes RequestIdentification from Ca with id=1
At 5, S must reject Ca, otherwise Cb can't connect to S. This is where
id timestamps come into play: with C1 up since t2, S rejects Ca due to
a request with t1 < t2.
To avoid issues with clocks that are out of sync, the client gets its
connection timestamp by being notified about itself from the master.
"""
s2c = []
def __init__(orig, self, *args, **kw):
orig(self, *args, **kw)
self.readable = bool
s2c.append(self)
ll()
def connectToStorage(client):
next(client.cp.iterateForObject(0))
cluster = NEOCluster()
try:
cluster.start()
Ca = cluster.client
Ca.pt # only connect to the master
# In a separate thread, connect to the storage but suspend the
# processing of the RequestIdentification packet, until the
# storage is notified about the existence of the other client.
with LockLock() as ll, Patch(ServerConnection, __init__=__init__):
t = self.newThread(connectToStorage, Ca)
ll()
s2c, = s2c
m2c, = cluster.master.getConnectionList(cluster.client)
m2c.close()
Cb = cluster.newClient()
try:
Cb.pt # only connect to the master
del s2c.readable
self.assertRaises(NEOPrimaryMasterLost, t.join)
self.assertTrue(s2c.isClosed())
connectToStorage(Cb)
finally:
Cb.close()
finally:
cluster.stop()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment