Commit eea5aff2 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Raise an exception when a broken node send packet through a connection. Fix

tests according to this changes and fix some others affected with previous
commits.


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@506 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent aded81bc
......@@ -841,11 +841,12 @@ class Application(object):
def __del__(self):
"""Clear all connection."""
# TODO: Stop polling thread here.
# Due to bug in ZODB, close is not always called when shutting
# down zope, so use __del__ to close connections
for conn in self.em.getConnectionList():
conn.close()
# Stop polling thread
self.poll_thread.stop()
close = __del__
def sync(self):
......
......@@ -15,7 +15,7 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from threading import Thread
from threading import Thread, Event
import logging
class ThreadedPoll(Thread):
......@@ -25,10 +25,11 @@ class ThreadedPoll(Thread):
Thread.__init__(self, **kw)
self.em = em
self.setDaemon(True)
self._stop = Event()
self.start()
def run(self):
while 1:
while not self._stop.isSet():
# First check if we receive any new message from other node
try:
self.em.poll()
......@@ -36,4 +37,7 @@ class ThreadedPoll(Thread):
# This happen when there is no connection
# XXX: This should be handled inside event manager, not here.
logging.error('Dispatcher, run, poll returned a KeyError')
logging.info('Threaded poll stopped')
def stop(self):
self._stop.set()
......@@ -335,17 +335,8 @@ class ClientEventHandlerTest(unittest.TestCase):
self.assertEquals(app.uuid, 'C' * 16)
def _testHandleUnexpectedPacketCalledWithMedhod(self, client_handler, method, args=(), kw=()):
# Monkey-patch handleUnexpectedPacket to check if it is called
call_list = []
def ClientHandler_handleUnexpectedPacket(self, conn, packet):
call_list.append((conn, packet))
original_handleUnexpectedPacket = client_handler.__class__.handleUnexpectedPacket
client_handler.__class__.handleUnexpectedPacket = ClientHandler_handleUnexpectedPacket
try:
self.assertRaises(UnexpectedPacketError, method, *args, **dict(kw))
finally:
# Restore original method
client_handler.__class__.handleUnexpectedPacket = original_handleUnexpectedPacket
self.assertRaises(UnexpectedPacketError, method, *args, **dict(kw))
# Master node handler
def test_initialAnswerPrimaryMaster(self):
......
......@@ -18,7 +18,8 @@
import logging
from neo import protocol
from neo.protocol import Packet, PacketMalformedError, UnexpectedPacketError
from neo.protocol import Packet, PacketMalformedError, UnexpectedPacketError, \
BrokenNotDisallowedError, NotReadyError
from neo.connection import ServerConnection
from protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
......@@ -156,8 +157,11 @@ class EventHandler(object):
conn.notify(protocol.protocolError(message))
conn.abort()
self.peerBroken(conn)
# TODO: remove this old method name
handleUnexpectedPacket = unexpectedPacket
def brokenNodeDisallowedError(conn, packet, message=None):
""" Called when a broken node send packets """
conn.notify(protocol.brokenNodeDisallowedError('go away'))
conn.abort()
def dispatch(self, conn, packet):
"""This is a helper method to handle various packet types."""
......@@ -172,6 +176,9 @@ class EventHandler(object):
self.unexpectedPacket(conn, packet, msg)
except PacketMalformedError, msg:
self.packetMalformed(conn, packet, msg)
except BrokenNotDisallowedError, msg:
self.brokenNodeDisallowedError(conn, packet, msg)
# Packet handlers.
......
......@@ -29,10 +29,6 @@ from neo.node import MasterNode, StorageNode, ClientNode
from neo.handler import identification_required, restrict_node_types, \
client_connection_required, server_connection_required
# TODO: finalize decorators integration (identification, restriction, client...)
# TODO: here use specific decorator such as restrict_node_types which do custom
# operations such as send retryLater instead of unexpectedPacket
class ElectionEventHandler(MasterEventHandler):
"""This class deals with events for a primary master election."""
......@@ -198,10 +194,7 @@ class ElectionEventHandler(MasterEventHandler):
# If this node is broken, reject it.
if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE:
conn.answer(protocol.brokenNodeDisallowedError(
'go away'), packet)
conn.abort()
return
raise protocol.BrokenNotDisallowedError
# supplied another uuid in case of conflict
while not app.isValidUUID(uuid, addr):
......
......@@ -166,14 +166,10 @@ class RecoveryEventHandler(MasterEventHandler):
# If this node is broken, reject it. Otherwise, assume that it is
# working again.
if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away')
conn.answer(p, packet)
conn.abort()
return
else:
node.setUUID(uuid)
node.setState(RUNNING_STATE)
app.broadcastNodeInformation(node)
raise protocol.BrokenNotDisallowedError
node.setUUID(uuid)
node.setState(RUNNING_STATE)
app.broadcastNodeInformation(node)
conn.setUUID(uuid)
......
......@@ -263,14 +263,11 @@ class ServiceEventHandler(MasterEventHandler):
# If this node is broken, reject it. Otherwise, assume that
# it is working again.
if node.getState() == BROKEN_STATE:
conn.notify(protocol.brokenNodeDisallowedError('go away'))
conn.abort()
return
else:
node.setUUID(uuid)
node.setState(RUNNING_STATE)
logging.debug('broadcasting node information')
app.broadcastNodeInformation(node)
raise protocol.BrokenNotDisallowedError
node.setUUID(uuid)
node.setState(RUNNING_STATE)
logging.debug('broadcasting node information')
app.broadcastNodeInformation(node)
conn.setUUID(uuid)
......
......@@ -22,7 +22,7 @@ from tempfile import mkstemp
from mock import Mock
from struct import pack, unpack
from neo import protocol
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID
from neo.protocol import Packet, INVALID_UUID
from neo.master.election import ElectionEventHandler
from neo.master.app import Application
from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
......@@ -136,12 +136,16 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
......@@ -577,14 +581,15 @@ server: 127.0.0.1:10023
self.assertEqual(node.getState(), RUNNING_STATE)
node.setState(BROKEN_STATE)
self.assertEqual(node.getState(), BROKEN_STATE)
election.handleRequestNodeIdentification(conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
uuid=new_uuid,
ip_address='127.0.0.1',
port=self.master_port+1,
name=self.app.name,)
self.checkCalledAbort(conn)
self.checkBrokenNotDisallowedErrorRaised(
election.handleRequestNodeIdentification,
conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
uuid=new_uuid,
ip_address='127.0.0.1',
port=self.master_port+1,
name=self.app.name,)
def test_11_handleAskPrimaryMaster(self):
......
......@@ -22,7 +22,7 @@ from tempfile import mkstemp
from mock import Mock
from struct import pack, unpack
from neo import protocol
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID
from neo.protocol import Packet, INVALID_UUID
from neo.master.recovery import RecoveryEventHandler
from neo.master.app import Application
from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
......@@ -120,15 +120,6 @@ server: 127.0.0.1:10023
# Delete tmp file
os.remove(self.tmp_path)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
# Common methods
def getNewUUID(self):
uuid = INVALID_UUID
......@@ -155,12 +146,25 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
# Method to test the kind of packet returned in answer
def checkCalledRequestNodeIdentification(self, conn, packet_number=0):
""" Check Request Node Identification has been send"""
......@@ -443,15 +447,15 @@ server: 127.0.0.1:10023
self.assertEqual(node.getState(), BROKEN_STATE)
self.assertEqual(node.getUUID(), uuid)
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
recovery.handleRequestNodeIdentification(conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
uuid=uuid,
ip_address='127.0.0.1',
port=self.master_port,
name=self.app.name,)
self.checkCalledAbort(conn)
self.checkBrokenNotDisallowedErrorRaised(
recovery.handleRequestNodeIdentification,
conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
uuid=uuid,
ip_address='127.0.0.1',
port=self.master_port,
name=self.app.name,)
# 8. known node but down
conn = Mock({"addPacket" : None,
......
......@@ -22,7 +22,7 @@ from tempfile import mkstemp
from mock import Mock
from struct import pack, unpack
from neo import protocol
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID
from neo.protocol import Packet, INVALID_UUID
from neo.master.service import ServiceEventHandler
from neo.master.app import Application
from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
......@@ -118,12 +118,16 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
# Method to test the kind of packet returned in answer
def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent"""
......@@ -379,14 +383,15 @@ server: 127.0.0.1:10023
sn.setState(BROKEN_STATE)
self.assertEquals(sn.getState(), BROKEN_STATE)
service.handleRequestNodeIdentification(conn,
packet=packet,
node_type=STORAGE_NODE_TYPE,
uuid=uuid,
ip_address='127.0.0.1',
port=self.storage_port,
name=self.app.name,)
self.checkCalledNotifyAbort(conn)
self.checkBrokenNotDisallowedErrorRaised(
service.handleRequestNodeIdentification,
conn,
packet=packet,
node_type=STORAGE_NODE_TYPE,
uuid=uuid,
ip_address='127.0.0.1',
port=self.storage_port,
name=self.app.name,)
self.assertEquals(len(self.app.nm.getStorageNodeList()), 2)
sn = self.app.nm.getStorageNodeList()[0]
self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port))
......
......@@ -21,7 +21,7 @@ import logging
from tempfile import mkstemp
from mock import Mock
from struct import pack, unpack
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID
from neo.protocol import Packet, INVALID_UUID
from neo.master.verification import VerificationEventHandler
from neo.master.app import Application
from neo import protocol
......@@ -125,12 +125,16 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
......@@ -465,15 +469,15 @@ server: 127.0.0.1:10023
self.assertEqual(node.getState(), BROKEN_STATE)
self.assertEqual(node.getUUID(), uuid)
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
verification.handleRequestNodeIdentification(conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
uuid=uuid,
ip_address='127.0.0.1',
port=self.master_port,
name=self.app.name,)
self.checkCalledAbort(conn)
self.checkBrokenNotDisallowedErrorRaised(
verification.handleRequestNodeIdentification,
conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
uuid=uuid,
ip_address='127.0.0.1',
port=self.master_port,
name=self.app.name,)
# 8. known node but down
conn = Mock({"addPacket" : None,
......
......@@ -189,14 +189,10 @@ class VerificationEventHandler(MasterEventHandler):
# If this node is broken, reject it. Otherwise, assume that it is
# working again.
if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away')
conn.answer(p, packet)
conn.abort()
return
else:
node.setUUID(uuid)
node.setState(RUNNING_STATE)
app.broadcastNodeInformation(node)
raise protocol.BrokenNotDisallowedError
node.setUUID(uuid)
node.setState(RUNNING_STATE)
app.broadcastNodeInformation(node)
conn.setUUID(uuid)
......
......@@ -318,9 +318,26 @@ UUID_NAMESPACES = {
ADMIN_NODE_TYPE: ADMIN_NS,
}
class ProtocolError(Exception): pass
class PacketMalformedError(ProtocolError): pass
class UnexpectedPacketError(ProtocolError): pass
class ProtocolError(Exception):
""" Base class for protocol errors, close the connection """
pass
class PacketMalformedError(ProtocolError):
""" Close the connection and set the node as broken"""
pass
class UnexpectedPacketError(ProtocolError):
""" Close the connection and set the node as broken"""
pass
class NotReadyError(ProtocolError):
""" Just close the connection """
pass
class BrokenNotDisallowedError(ProtocolError):
""" Just close the connection """
pass
decode_table = {}
......
......@@ -131,10 +131,7 @@ class BootstrapEventHandler(StorageEventHandler):
# If this node is broken, reject it.
if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away')
conn.answer(p, packet)
conn.abort()
return
raise protocol.BrokenNotDisallowedError
# Trust the UUID sent by the peer.
node.setUUID(uuid)
......
......@@ -163,10 +163,7 @@ class OperationEventHandler(StorageEventHandler):
# If this node is broken, reject it.
if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away')
conn.answer(p, packet)
conn.abort()
return
raise protocol.BrokenNotDisallowedError
# Trust the UUID sent by the peer.
node.setUUID(uuid)
......
......@@ -26,6 +26,7 @@ from neo.pt import PartitionTable
from neo.storage.app import Application, StorageNode
from neo.storage.bootstrap import BootstrapEventHandler
from neo.storage.verification import VerificationEventHandler
from neo import protocol
from neo.protocol import STORAGE_NODE_TYPE, MASTER_NODE_TYPE
from neo.protocol import BROKEN_STATE, RUNNING_STATE, Packet, INVALID_UUID
from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION
......@@ -111,12 +112,16 @@ server: 127.0.0.1:10020
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
# Method to test the kind of packet returned in answer
def checkCalledRequestNodeIdentification(self, conn, packet_number=0):
""" Check Request Node Identification has been send"""
......@@ -284,7 +289,8 @@ server: 127.0.0.1:10020
conn = Mock({"isServerConnection": False,
"getAddress" : ("127.0.0.1", self.master_port), })
self.app.trying_master_node = self.trying_master_node
self.bootstrap.handleRequestNodeIdentification(
self.checkUnexpectedPacketRaised(
self.bootstrap.handleRequestNodeIdentification,
conn=conn,
uuid=self.getNewUUID(),
packet=packet,
......@@ -292,7 +298,6 @@ server: 127.0.0.1:10020
node_type=MASTER_NODE_TYPE,
ip_address='127.0.0.1',
name='',)
self.checkCalledAbort(conn)
self.assertEquals(len(conn.mockGetNamedCalls("setUUID")), 0)
def test_08_handleRequestNodeIdentification2(self):
......@@ -366,7 +371,8 @@ server: 127.0.0.1:10020
uuid=self.getNewUUID()
master.setState(BROKEN_STATE)
master.setUUID(uuid)
self.bootstrap.handleRequestNodeIdentification(
self.checkBrokenNotDisallowedErrorRaised(
self.bootstrap.handleRequestNodeIdentification,
conn=conn,
uuid=uuid,
packet=packet,
......@@ -374,7 +380,6 @@ server: 127.0.0.1:10020
node_type=MASTER_NODE_TYPE,
ip_address='127.0.0.1',
name=self.app.name,)
self.checkCalledAbort(conn)
self.assertEquals(len(conn.mockGetNamedCalls("setUUID")), 0)
def test_08_handleRequestNodeIdentification6(self):
......@@ -415,7 +420,8 @@ server: 127.0.0.1:10020
"getAddress" : ("127.0.0.1", self.master_port), })
packet = Packet(msg_type=ACCEPT_NODE_IDENTIFICATION)
self.app.trying_master_node = self.trying_master_node
self.bootstrap.handleAcceptNodeIdentification(
self.checkUnexpectedPacketRaised(
self.bootstrap.handleAcceptNodeIdentification,
conn=conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
......@@ -425,7 +431,6 @@ server: 127.0.0.1:10020
num_partitions=self.app.num_partitions,
num_replicas=self.app.num_replicas,
your_uuid=self.getNewUUID())
self.checkCalledAbort(conn)
def test_09_handleAcceptNodeIdentification2(self):
# not a master node -> rejected
......@@ -560,13 +565,13 @@ server: 127.0.0.1:10020
packet = Packet(msg_type=ANSWER_PRIMARY_MASTER)
self.app.trying_master_node = self.trying_master_node
self.app.primary_master_node = None
self.bootstrap.handleAnswerPrimaryMaster(
self.checkUnexpectedPacketRaised(
self.bootstrap.handleAnswerPrimaryMaster,
conn=conn,
packet=packet,
primary_uuid=self.getNewUUID(),
known_master_list=()
)
self.checkCalledAbort(conn)
self.assertEquals(self.app.trying_master_node, self.trying_master_node)
self.assertEquals(self.app.primary_master_node, None)
......
......@@ -28,6 +28,7 @@ from neo.storage.app import Application, StorageNode
from neo.storage.operation import TransactionInformation, OperationEventHandler
from neo.exception import PrimaryFailure, OperationFailure
from neo.pt import PartitionTable
from neo import protocol
from neo.protocol import *
SQL_ADMIN_USER = 'root'
......@@ -52,12 +53,16 @@ class StorageOperationTests(unittest.TestCase):
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent"""
# sometimes we answer an error, sometimes we just send it
......@@ -373,7 +378,8 @@ server: 127.0.0.1:10020
"getAddress" : ("127.0.0.1", self.master_port),
})
count = len(self.app.nm.getNodeList())
self.operation.handleRequestNodeIdentification(
self.checkBrokenNotDisallowedErrorRaised(
self.operation.handleRequestNodeIdentification,
conn=conn,
packet=packet,
node_type=MASTER_NODE_TYPE,
......@@ -381,7 +387,6 @@ server: 127.0.0.1:10020
ip_address='127.0.0.1',
port=self.master_port,
name=self.app.name)
self.checkPacket(conn, packet_type=ERROR)
self.assertEquals(len(self.app.nm.getNodeList()), count)
def test_09_handleRequestNodeIdentification4(self):
......
......@@ -36,7 +36,7 @@ from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION
UNLOCK_INFORMATION, TID_NOT_FOUND_CODE, ASK_TRANSACTION_INFORMATION, ANSWER_TRANSACTION_INFORMATION, \
ANSWER_PARTITION_TABLE,SEND_PARTITION_TABLE, COMMIT_TRANSACTION
from neo.protocol import ERROR, BROKEN_NODE_DISALLOWED_CODE, ASK_PRIMARY_MASTER
from neo.protocol import ANSWER_PRIMARY_MASTER, UnexpectedPacketError
from neo.protocol import ANSWER_PRIMARY_MASTER
from neo.exception import PrimaryFailure, OperationFailure
from neo.storage.mysqldb import MySQLDatabaseManager, p64, u64
......@@ -129,12 +129,16 @@ server: 127.0.0.1:10020
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent"""
# sometimes we answer an error, sometimes we just notify it
......@@ -277,14 +281,10 @@ server: 127.0.0.1:10020
node = self.app.nm.getNodeByServer(conn.getAddress())
node.setState(BROKEN_STATE)
self.assertEqual(node.getUUID(), uuid)
self.verification.handleRequestNodeIdentification(conn, p, MASTER_NODE_TYPE,
uuid, "127.0.0.1", self.master_port, "main")
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
call = conn.mockGetNamedCalls("answer")[0]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ERROR)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 1)
self.checkBrokenNotDisallowedErrorRaised(
self.verification.handleRequestNodeIdentification,
conn, p, MASTER_NODE_TYPE,
uuid, "127.0.0.1", self.master_port, "main")
# change uuid of a known node
uuid = self.getNewUUID()
......
......@@ -88,10 +88,7 @@ class VerificationEventHandler(StorageEventHandler):
# If this node is broken, reject it.
if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away')
conn.answer(p, packet)
conn.abort()
return
raise protocol.BrokenNotDisallowedError
# Trust the UUID sent by the peer.
node.setUUID(uuid)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment