Commit ed50edca authored by Julien Muchembled's avatar Julien Muchembled

Simplify API to establish connections and accept mix of IPv4/IPv6

parent c2c97752
......@@ -21,7 +21,6 @@ from neo.lib.connection import ListeningConnection
from neo.lib.exception import PrimaryFailure
from .handler import AdminEventHandler, MasterEventHandler, \
MasterRequestEventHandler
from neo.lib.connector import getConnectorHandler
from neo.lib.bootstrap import BootstrapManager
from neo.lib.pt import PartitionTable
from neo.lib.protocol import ClusterStates, Errors, \
......@@ -39,8 +38,7 @@ class Application(object):
self.name = config.getCluster()
self.server = config.getBind()
self.master_addresses, connector_name = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
self.master_addresses = config.getMasters()
logging.debug('IP address is %s, port is %d', *self.server)
# The partition table is initialized after getting the number of
......@@ -87,8 +85,7 @@ class Application(object):
# Make a listening port.
handler = AdminEventHandler(self)
self.listening_conn = ListeningConnection(self.em, handler,
addr=self.server, connector=self.connector_handler())
self.listening_conn = ListeningConnection(self.em, handler, self.server)
while self.cluster_state != ClusterStates.STOPPING:
self.connectToPrimary()
......@@ -120,7 +117,7 @@ class Application(object):
# search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN,
self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler)
data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data
nm.update([(node.getType(), node.getAddress(), node.getUUID(),
NodeStates.RUNNING)])
......
......@@ -36,7 +36,6 @@ from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed
from neo.lib.node import NodeManager
from neo.lib.connector import getConnectorHandler
from .exception import NEOStorageError, NEOStorageCreationUndoneError
from .exception import NEOStorageNotFoundError
from .handlers import storage, master
......@@ -80,8 +79,6 @@ class Application(object):
# Internal Attributes common to all thread
self._db = None
self.name = name
master_addresses, connector_name = parseMasterList(master_nodes)
self.connector_handler = getConnectorHandler(connector_name)
self.dispatcher = Dispatcher(self.poll_thread)
self.nm = NodeManager(dynamic_master_list)
self.cp = ConnectionPool(self)
......@@ -90,7 +87,7 @@ class Application(object):
self.trying_master_node = None
# load master node list
for address in master_addresses:
for address in parseMasterList(master_nodes):
self.nm.createMaster(address=address)
# no self-assigned UUID, primary master will supply us one
......@@ -290,7 +287,6 @@ class Application(object):
conn = MTClientConnection(self.em,
self.notifications_handler,
node=self.trying_master_node,
connector=self.connector_handler(),
dispatcher=self.dispatcher)
# Query for primary master node
if conn.getConnector() is None:
......
......@@ -54,7 +54,7 @@ class ConnectionPool(object):
app = self.app
logging.debug('trying to connect to %s - %s', node, node.getState())
conn = MTClientConnection(app.em, app.storage_event_handler, node,
connector=app.connector_handler(), dispatcher=app.dispatcher)
dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name)
try:
......
......@@ -116,7 +116,7 @@ class BootstrapManager(EventHandler):
logging.info('Got a new UUID: %s', uuid_str(self.uuid))
self.accepted = True
def getPrimaryConnection(self, connector_handler):
def getPrimaryConnection(self):
"""
Primary lookup/connection process.
Returns when the connection is made.
......@@ -140,8 +140,7 @@ class BootstrapManager(EventHandler):
sleep(1)
if conn is None:
# open the connection
conn = ClientConnection(em, self, self.current,
connector_handler())
conn = ClientConnection(em, self, self.current)
# still processing
em.poll(1)
return (self.current, conn, self.uuid, self.num_partitions,
......
......@@ -206,6 +206,7 @@ class BaseConnection(object):
Timeouts in HandlerSwitcher are only there to prioritize some packets.
"""
from .connector import SocketConnector as ConnectorClass
KEEP_ALIVE = 60
def __init__(self, event_manager, handler, connector, addr=None):
......@@ -318,19 +319,18 @@ attributeTracker.track(BaseConnection)
class ListeningConnection(BaseConnection):
"""A listen connection."""
def __init__(self, event_manager, handler, addr, connector, **kw):
def __init__(self, event_manager, handler, addr):
logging.debug('listening to %s:%d', *addr)
BaseConnection.__init__(self, event_manager, handler,
addr=addr, connector=connector)
self.connector.makeListeningConnection(addr)
connector = self.ConnectorClass(addr)
BaseConnection.__init__(self, event_manager, handler, connector, addr)
connector.makeListeningConnection()
def readable(self):
try:
new_s, addr = self.connector.getNewConnection()
connector, addr = self.connector.accept()
logging.debug('accepted a connection from %s:%d', *addr)
handler = self.getHandler()
new_conn = ServerConnection(self.em, handler,
connector=new_s, addr=addr)
new_conn = ServerConnection(self.em, handler, connector, addr)
handler.connectionAccepted(new_conn)
except ConnectorTryAgainException:
pass
......@@ -668,14 +668,15 @@ class ClientConnection(Connection):
connecting = True
client = True
def __init__(self, event_manager, handler, node, connector):
def __init__(self, event_manager, handler, node):
addr = node.getAddress()
connector = self.ConnectorClass(addr)
Connection.__init__(self, event_manager, handler, connector, addr)
node.setConnection(self)
handler.connectionStarted(self)
try:
try:
self.connector.makeClientConnection(addr)
connector.makeClientConnection()
except ConnectorInProgressException:
event_manager.addWriter(self)
else:
......
......@@ -19,52 +19,51 @@ import errno
# Global connector registry.
# Fill by calling registerConnectorHandler.
# Read by calling getConnectorHandler.
# Read by calling SocketConnector.__new__
connector_registry = {}
DEFAULT_CONNECTOR = 'SocketConnectorIPv4'
def registerConnectorHandler(connector_handler):
connector_registry[connector_handler.__name__] = connector_handler
def getConnectorHandler(connector=None):
if connector is None:
connector = DEFAULT_CONNECTOR
if isinstance(connector, basestring):
connector_handler = connector_registry.get(connector)
else:
# Allow to directly provide a handler class without requiring to
# register it first.
connector_handler = connector
return connector_handler
connector_registry[connector_handler.af_type] = connector_handler
class SocketConnector:
class SocketConnector(object):
""" This class is a wrapper for a socket """
is_listening = False
remote_addr = None
is_closed = None
is_closed = is_server = None
def __init__(self, s=None, accepted_from=None):
self.accepted_from = accepted_from
if accepted_from is not None:
self.remote_addr = accepted_from
self.is_listening = False
self.is_closed = False
def __new__(cls, addr, s=None):
if s is None:
self.socket = socket.socket(self.af_type, socket.SOCK_STREAM)
host, port = addr
for af_type, cls in connector_registry.iteritems():
try :
socket.inet_pton(af_type, host)
break
except socket.error:
pass
else:
raise ValueError("Unknown type of host", host)
self = object.__new__(cls)
self.addr = cls._normAddress(addr)
if s is None:
s = socket.socket(af_type, socket.SOCK_STREAM)
else:
self.is_server = True
self.is_closed = False
self.socket = s
self.socket_fd = self.socket.fileno()
self.socket_fd = s.fileno()
# always use non-blocking sockets
self.socket.setblocking(0)
s.setblocking(0)
# disable Nagle algorithm to reduce latency
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return self
def makeClientConnection(self, addr):
self.is_closed = False
self.remote_addr = addr
# Threaded tests monkey-patch the following 2 operations.
_connect = lambda self, addr: self.socket.connect(addr)
_bind = lambda self, addr: self.socket.bind(addr)
def makeClientConnection(self):
assert self.is_closed is None
self.is_server = self.is_closed = False
try:
self.socket.connect(addr)
self._connect(self.addr)
except socket.error, (err, errmsg):
if err == errno.EINPROGRESS:
raise ConnectorInProgressException
......@@ -73,12 +72,12 @@ class SocketConnector:
raise ConnectorException, 'makeClientConnection to %s failed:' \
' %s:%s' % (addr, err, errmsg)
def makeListeningConnection(self, addr):
def makeListeningConnection(self):
assert self.is_closed is None
self.is_closed = False
self.is_listening = True
try:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(addr)
self._bind(self.addr)
self.socket.listen(5)
except socket.error, (err, errmsg):
self.socket.close()
......@@ -94,15 +93,22 @@ class SocketConnector:
# in epoll
return self.socket_fd
def getNewConnection(self):
@staticmethod
def _normAddress(addr):
return addr
def getAddress(self):
return self._normAddress(self.socket.getsockname())
def accept(self):
try:
(new_s, addr) = self._accept()
new_s = self.__class__(new_s, accepted_from=addr)
return (new_s, addr)
s, addr = self.socket.accept()
s = self.__class__(addr, s)
return s, s.addr
except socket.error, (err, errmsg):
if err == errno.EAGAIN:
raise ConnectorTryAgainException
raise ConnectorException, 'getNewConnection failed: %s:%s' % \
raise ConnectorException, 'accept failed: %s:%s' % \
(err, errmsg)
def receive(self):
......@@ -139,14 +145,14 @@ class SocketConnector:
state = 'closed '
else:
state = 'opened '
if self.is_listening:
if self.is_server is None:
state += 'listening'
else:
if self.accepted_from is None:
state += 'to '
else:
if self.is_server:
state += 'from '
state += str(self.remote_addr)
else:
state += 'to '
state += str(self.addr)
return '<%s at 0x%x fileno %s %s, %s>' % (self.__class__.__name__,
id(self), '?' if self.is_closed else self.socket_fd,
self.getAddress(), state)
......@@ -155,22 +161,13 @@ class SocketConnectorIPv4(SocketConnector):
" Wrapper for IPv4 sockets"
af_type = socket.AF_INET
def _accept(self):
return self.socket.accept()
def getAddress(self):
return self.socket.getsockname()
class SocketConnectorIPv6(SocketConnector):
" Wrapper for IPv6 sockets"
af_type = socket.AF_INET6
def _accept(self):
new_s, addr = self.socket.accept()
return new_s, addr[:2]
def getAddress(self):
return self.socket.getsockname()[:2]
@staticmethod
def _normAddress(addr):
return addr[:2]
registerConnectorHandler(SocketConnectorIPv4)
registerConnectorHandler(SocketConnectorIPv6)
......
......@@ -19,12 +19,8 @@ import sys
import traceback
from cStringIO import StringIO
from struct import Struct
try:
from .util import getAddressType
except ImportError:
pass
PROTOCOL_VERSION = 2
PROTOCOL_VERSION = 3
# Size restrictions.
MIN_PACKET_SIZE = 10
......@@ -449,65 +445,6 @@ class PEnum(PStructItem):
enum = self._enum.__class__.__name__
raise ValueError, 'Invalid code for %s enum: %r' % (enum, code)
class PAddressIPGeneric(PStructItem):
def __init__(self, name, format):
PStructItem.__init__(self, name, format)
def encode(self, writer, address):
host, port = address
host = socket.inet_pton(self.af_type, host)
writer(self.pack(host, port))
def decode(self, reader):
data = reader(self.size)
address = self.unpack(data)
host, port = address
host = socket.inet_ntop(self.af_type, host)
return (host, port)
class PAddressIPv4(PAddressIPGeneric):
af_type = socket.AF_INET
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!4sH')
class PAddressIPv6(PAddressIPGeneric):
af_type = socket.AF_INET6
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!16sH')
class PAddress(PStructItem):
"""
An host address (IPv4/IPv6)
"""
address_format_dict = {
socket.AF_INET: PAddressIPv4('ipv4'),
socket.AF_INET6: PAddressIPv6('ipv6'),
}
def __init__(self, name):
PStructItem.__init__(self, name, '!L')
def _encode(self, writer, address):
if address is None:
writer(self.pack(INVALID_ADDRESS_TYPE))
return
af_type = getAddressType(address)
writer(self.pack(af_type))
encoder = self.address_format_dict[af_type]
encoder.encode(writer, address)
def _decode(self, reader):
af_type = self.unpack(reader(self.size))[0]
if af_type == INVALID_ADDRESS_TYPE:
return None
decoder = self.address_format_dict[af_type]
host, port = decoder.decode(reader)
return (host, port)
class PString(PStructItem):
"""
A variable-length string
......@@ -523,6 +460,29 @@ class PString(PStructItem):
length = self.unpack(reader(self.size))[0]
return reader(length)
class PAddress(PString):
"""
An host address (IPv4/IPv6)
"""
def __init__(self, name):
PString.__init__(self, name)
self._port = Struct('!H')
def _encode(self, writer, address):
if address:
host, port = address
PString._encode(self, writer, host)
writer(self._port.pack(port))
else:
PString._encode(self, writer, '')
def _decode(self, reader):
host = PString._decode(self, reader)
if host:
p = self._port
return host, p.unpack(reader(p.size))[0]
class PBoolean(PStructItem):
"""
A boolean value, encoded as a single byte
......
......@@ -23,11 +23,6 @@ from Queue import deque
from struct import pack, unpack
from time import gmtime
SOCKET_CONNECTORS_DICT = {
socket.AF_INET : 'SocketConnectorIPv4',
socket.AF_INET6: 'SocketConnectorIPv6',
}
TID_LOW_OVERFLOW = 2**32
TID_LOW_MAX = TID_LOW_OVERFLOW - 1
SECOND_PER_TID_LOW = 60.0 / TID_LOW_OVERFLOW
......@@ -125,25 +120,6 @@ def makeChecksum(s):
return sha1(s).digest()
def getAddressType(address):
"Return the type (IPv4 or IPv6) of an ip"
(host, port) = address
for af_type in SOCKET_CONNECTORS_DICT:
try :
socket.inet_pton(af_type, host)
except:
continue
else:
break
else:
raise ValueError("Unknown type of host", host)
return af_type
def getConnectorFromAddress(address):
address_type = getAddressType(address)
return SOCKET_CONNECTORS_DICT[address_type]
def parseNodeAddress(address, port_opt=None):
if address[:1] == '[':
(host, port) = address[1:].split(']')
......@@ -164,24 +140,12 @@ def parseNodeAddress(address, port_opt=None):
def parseMasterList(masters, except_node=None):
assert masters, 'At least one master must be defined'
# load master node list
socket_connector = None
master_node_list = []
for node in masters.split(' '):
if not node:
continue
for node in masters.split():
address = parseNodeAddress(node)
if (address != except_node):
if address != except_node:
master_node_list.append(address)
socket_connector_temp = getConnectorFromAddress(address)
if socket_connector is None:
socket_connector = socket_connector_temp
elif socket_connector != socket_connector_temp:
raise TypeError("Wrong connector type : you're trying to use "
"ipv6 and ipv4 simultaneously")
return master_node_list, socket_connector
return master_node_list
class ReadBuffer(object):
......
......@@ -18,7 +18,6 @@ import sys, weakref
from time import time
from neo.lib import logging
from neo.lib.connector import getConnectorHandler
from neo.lib.debug import register as registerLiveDebugger
from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
......@@ -59,9 +58,7 @@ class Application(object):
self.autostart = config.getAutostart()
self.storage_readiness = set()
master_addresses, connector_name = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
for master_address in config.getMasters():
self.nm.createMaster(address=master_address)
logging.debug('IP address is %s, port is %d', *self.server)
......@@ -102,7 +99,7 @@ class Application(object):
raise ValueError("upstream cluster name must be"
" different from cluster name")
self.backup_app = BackupApplication(self, upstream_cluster,
*config.getUpstreamMasters())
config.getUpstreamMasters())
self.administration_handler = administration.AdministrationHandler(
self)
......@@ -141,8 +138,7 @@ class Application(object):
def _run(self):
"""Make sure that the status is sane and start a loop."""
# Make a listening port.
self.listening_conn = ListeningConnection(self.em, None,
addr=self.server, connector=self.connector_handler())
self.listening_conn = ListeningConnection(self.em, None, self.server)
# Start a normal operation.
while self.cluster_state != ClusterStates.STOPPING:
......@@ -196,8 +192,7 @@ class Application(object):
ClientConnection(self.em, client_handler,
# XXX: Ugly, but the whole election code will be
# replaced soon
node=getByAddress(addr),
connector=self.connector_handler())
getByAddress(addr))
self.unconnected_master_node_set.clear()
self.em.poll(1)
except ElectionFailure, m:
......@@ -381,9 +376,7 @@ class Application(object):
# Reconnect to primary master node.
primary_handler = secondary.PrimaryHandler(self)
ClientConnection(self.em, primary_handler,
node=self.primary_master_node,
connector=self.connector_handler())
ClientConnection(self.em, primary_handler, self.primary_master_node)
# and another for the future incoming connections
self.listening_conn.setHandler(
......
......@@ -19,7 +19,6 @@ from bisect import bisect
from collections import defaultdict
from neo.lib import logging
from neo.lib.bootstrap import BootstrapManager
from neo.lib.connector import getConnectorHandler
from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler
from neo.lib.node import NodeManager
......@@ -67,11 +66,10 @@ class BackupApplication(object):
pt = None
def __init__(self, app, name, master_addresses, connector_name):
def __init__(self, app, name, master_addresses):
self.app = weakref.proxy(app)
self.name = name
self.nm = NodeManager()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
self.nm.createMaster(address=master_address)
......@@ -107,7 +105,7 @@ class BackupApplication(object):
break
poll(1)
node, conn, uuid, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection(self.connector_handler)
bootstrap.getPrimaryConnection()
try:
app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node
......
......@@ -14,11 +14,9 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib.connector import getConnectorHandler
from neo.lib.connection import ClientConnection
from neo.lib.event import EventManager
from neo.lib.protocol import ClusterStates, NodeStates, ErrorCodes, Packets
from neo.lib.util import getConnectorFromAddress
from neo.lib.node import NodeManager
from .handler import CommandEventHandler
......@@ -31,8 +29,6 @@ class NeoCTL(object):
connected = False
def __init__(self, address):
connector_name = getConnectorFromAddress(address)
self.connector_handler = getConnectorHandler(connector_name)
self.nm = nm = NodeManager()
self.server = nm.createAdmin(address=address)
self.em = EventManager()
......@@ -47,7 +43,7 @@ class NeoCTL(object):
def __getConnection(self):
if not self.connected:
self.connection = ClientConnection(self.em, self.handler,
node=self.server, connector=self.connector_handler())
self.server)
while not self.connected:
self.em.poll(1)
if self.connection is None:
......
......@@ -24,7 +24,6 @@ from neo.lib.node import NodeManager
from neo.lib.event import EventManager
from neo.lib.connection import ListeningConnection
from neo.lib.exception import OperationFailure, PrimaryFailure
from neo.lib.connector import getConnectorHandler
from neo.lib.pt import PartitionTable
from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager
......@@ -54,9 +53,7 @@ class Application(object):
)
# load master nodes
master_addresses, connector_name = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses :
for master_address in config.getMasters():
self.nm.createMaster(address=master_address)
# set the bind address
......@@ -177,8 +174,7 @@ class Application(object):
# Make a listening port
handler = identification.IdentificationHandler(self)
self.listening_conn = ListeningConnection(self.em, handler,
addr=self.server, connector=self.connector_handler())
self.listening_conn = ListeningConnection(self.em, handler, self.server)
self.server = self.listening_conn.getAddress()
# Connect to a primary master node, verify data, and
......@@ -234,7 +230,7 @@ class Application(object):
# search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name,
NodeTypes.STORAGE, self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler)
data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data
self.master_node = node
self.master_conn = conn
......
......@@ -46,7 +46,7 @@ class Checker(object):
conn.asClient()
else:
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
node)
conn.ask(Packets.RequestIdentification(
NodeTypes.STORAGE, uuid, app.server, name))
self.conn_dict[conn] = node.isIdentified()
......
......@@ -254,8 +254,7 @@ class Replicator(object):
self.fetchTransactions()
else:
assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
conn = ClientConnection(app.em, StorageOperationHandler(app), node)
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name))
if previous_node is not None and previous_node.isConnected():
......
......@@ -30,7 +30,6 @@ from functools import wraps
from mock import Mock
from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import getAddressType
from time import time
from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess
......@@ -203,8 +202,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({
'getCluster': cluster,
'getBind': masters[0],
'getMasters': (masters, getAddressType((
self.local_ip, 0))),
'getMasters': masters,
'getReplicas': replicas,
'getPartitions': partitions,
'getUUID': uuid,
......@@ -226,8 +224,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({
'getCluster': cluster,
'getBind': (masters[0], 10020 + index),
'getMasters': (masters, getAddressType((
self.local_ip, 0))),
'getMasters': masters,
'getDatabase': db,
'getUUID': uuid,
'getReset': False,
......@@ -554,29 +551,5 @@ class Patch(object):
self.__del__()
connector_cpt = 0
class DoNothingConnector(Mock):
def __init__(self, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
Mock.__init__(self)
def getAddress(self):
return self.addr
def makeClientConnection(self, addr):
self.addr = addr
def makeListeningConnection(self, addr):
self.addr = addr
def getDescriptor(self):
return self.desc
__builtin__.pdb = lambda depth=0: \
debug.getPdb().set_trace(sys._getframe(depth+1))
......@@ -25,7 +25,7 @@ from neo.client.cache import test as testCache
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.lib.protocol import NodeTypes, Packets, Errors, \
INVALID_PARTITION, UUID_NAMESPACES
from neo.lib.util import makeChecksum, SOCKET_CONNECTORS_DICT
from neo.lib.util import makeChecksum
import time
class Dispatcher(object):
......@@ -95,10 +95,9 @@ class ClientApplicationTests(NeoUnitTestBase):
return txn_context
def getApp(self, master_nodes=None, name='test', **kw):
connector = SOCKET_CONNECTORS_DICT[ADDRESS_TYPE]
if master_nodes is None:
master_nodes = '%s:10010' % buildUrlFromString(self.local_ip)
app = Application(master_nodes, name, connector, **kw)
app = Application(master_nodes, name, **kw)
self._to_stop_list.append(app)
app.dispatcher = Mock({ })
return app
......@@ -750,7 +749,6 @@ class ClientApplicationTests(NeoUnitTestBase):
# the third will not be ready
# after the third, the partition table will be operational
# (as if it was connected to the primary master node)
from .. import DoNothingConnector
# will raise IndexError at the third iteration
app = self.getApp('127.0.0.1:10010 127.0.0.1:10011')
# TODO: test more connection failure cases
......@@ -797,7 +795,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.nm.getByAddress(conn.getAddress())._connection = None
app._ask = _ask_base
# faked environnement
app.connector_handler = DoNothingConnector
app.em = Mock({'getConnectionList': []})
app.pt = Mock({ 'operational': False})
app.master_conn = app._connectToPrimaryNode()
......
......@@ -17,17 +17,43 @@
import unittest
from time import time
from mock import Mock
from neo.lib import connection
from neo.lib.connection import ListeningConnection, Connection, \
ClientConnection, ServerConnection, MTClientConnection, \
from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ListeningConnection, \
Connection, ClientConnection, ServerConnection, MTClientConnection, \
HandlerSwitcher, CRITICAL_TIMEOUT
from neo.lib.connector import getConnectorHandler, registerConnectorHandler
from . import DoNothingConnector
from neo.lib.connector import registerConnectorHandler
from neo.lib.connector import ConnectorException, ConnectorTryAgainException, \
ConnectorInProgressException, ConnectorConnectionRefusedException
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, PACKET_HEADER_FORMAT
from . import NeoUnitTestBase
from . import NeoUnitTestBase, Patch
connector_cpt = 0
class DummyConnector(Mock):
def __init__(self, addr, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
self.addr = addr
Mock.__init__(self)
def getAddress(self):
return self.addr
def getDescriptor(self):
return self.desc
accept = getError = makeClientConnection = makeListeningConnection = \
receive = send = lambda *args, **kw: None
dummy_connector = Patch(BaseConnection,
ConnectorClass=lambda orig, self, *args, **kw: DummyConnector(*args, **kw))
class ConnectionTests(NeoUnitTestBase):
......@@ -41,25 +67,23 @@ class ConnectionTests(NeoUnitTestBase):
connection.connect_limit = 0
def _makeListeningConnection(self, addr):
# create instance after monkey patches
self.connector = DoNothingConnector()
return ListeningConnection(event_manager=self.em, handler=self.handler,
connector=self.connector, addr=addr)
with dummy_connector:
conn = ListeningConnection(self.em, self.handler, addr)
self.connector = conn.connector
return conn
def _makeConnection(self):
self.connector = DoNothingConnector()
return Connection(event_manager=self.em, handler=self.handler,
connector=self.connector, addr=self.address)
addr = self.address
self.connector = DummyConnector(addr)
return Connection(self.em, self.handler, self.connector, addr)
def _makeClientConnection(self):
self.connector = DoNothingConnector()
return ClientConnection(event_manager=self.em, handler=self.handler,
connector=self.connector, node=self.node)
with dummy_connector:
conn = ClientConnection(self.em, self.handler, self.node)
self.connector = conn.connector
return conn
def _makeServerConnection(self):
self.connector = DoNothingConnector()
return ServerConnection(event_manager=self.em, handler=self.handler,
connector=self.connector, addr=self.address)
_makeServerConnection = _makeConnection
def _checkRegistered(self, n=1):
self.assertEqual(len(self.em.mockGetNamedCalls("register")), n)
......@@ -82,8 +106,8 @@ class ConnectionTests(NeoUnitTestBase):
def _checkClose(self, n=1):
self.assertEqual(len(self.connector.mockGetNamedCalls("close")), n)
def _checkGetNewConnection(self, n=1):
calls = self.connector.mockGetNamedCalls('getNewConnection')
def _checkAccept(self, n=1):
calls = self.connector.mockGetNamedCalls('accept')
self.assertEqual(len(calls), n)
def _checkSend(self, n=1, data=None):
......@@ -120,7 +144,6 @@ class ConnectionTests(NeoUnitTestBase):
def _checkMakeClientConnection(self, n=1):
calls = self.connector.mockGetNamedCalls("makeClientConnection")
self.assertEqual(len(calls), n)
self.assertEqual(calls[n-1].getParam(0), self.address)
def _checkPacketReceived(self, n=1):
calls = self.handler.mockGetNamedCalls('packetReceived')
......@@ -140,28 +163,17 @@ class ConnectionTests(NeoUnitTestBase):
def _checkWriteBuf(self, bc, data):
self.assertEqual(''.join(bc.write_buf), data)
def test_01_BaseConnection1(self):
# init with connector
registerConnectorHandler(DoNothingConnector)
connector = getConnectorHandler("DoNothingConnector")()
self.assertFalse(connector is None)
bc = self._makeConnection()
self.assertFalse(bc.connector is None)
self._checkRegistered(1)
def test_01_BaseConnection2(self):
def test_01_BaseConnection(self):
# init with address
bc = self._makeConnection()
self.assertEqual(bc.getAddress(), self.address)
self.assertIsNot(bc.connector, None)
self._checkRegistered(1)
def test_02_ListeningConnection1(self):
# test init part
def getNewConnection(self):
return self, ('', 0)
DoNothingConnector.getNewConnection = getNewConnection
addr = ("127.0.0.7", 93413)
try:
with Patch(DummyConnector, accept=lambda orig, self: (self, ('', 0))):
bc = self._makeListeningConnection(addr=addr)
self.assertEqual(bc.getAddress(), addr)
self._checkRegistered()
......@@ -169,18 +181,15 @@ class ConnectionTests(NeoUnitTestBase):
self._checkMakeListeningConnection()
# test readable
bc.readable()
self._checkGetNewConnection()
self._checkAccept()
self._checkConnectionAccepted()
finally:
del DoNothingConnector.getNewConnection
def test_02_ListeningConnection2(self):
# test with exception raise when getting new connection
def getNewConnection(self):
def accept(orig, self):
raise ConnectorTryAgainException
DoNothingConnector.getNewConnection = getNewConnection
addr = ("127.0.0.7", 93413)
try:
with Patch(DummyConnector, accept=accept):
bc = self._makeListeningConnection(addr=addr)
self.assertEqual(bc.getAddress(), addr)
self._checkRegistered()
......@@ -188,10 +197,8 @@ class ConnectionTests(NeoUnitTestBase):
self._checkMakeListeningConnection()
# test readable
bc.readable()
self._checkGetNewConnection(1)
self._checkAccept(1)
self._checkConnectionAccepted(0)
finally:
del DoNothingConnector.getNewConnection
def test_03_Connection(self):
bc = self._makeConnection()
......@@ -229,38 +236,29 @@ class ConnectionTests(NeoUnitTestBase):
def test_Connection_recv1(self):
# patch receive method to return data
def receive(self):
return "testdata"
DoNothingConnector.receive = receive
try:
with Patch(DummyConnector, receive=lambda orig, self: "testdata"):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
bc._recv()
self._checkReadBuf(bc, 'testdata')
finally:
del DoNothingConnector.receive
def test_Connection_recv2(self):
# patch receive method to raise try again
def receive(self):
def receive(orig, self):
raise ConnectorTryAgainException
DoNothingConnector.receive = receive
try:
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
bc._recv()
self._checkReadBuf(bc, '')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
finally:
del DoNothingConnector.receive
def test_Connection_recv3(self):
# patch receive method to raise ConnectorConnectionRefusedException
def receive(self):
def receive(orig, self):
raise ConnectorConnectionRefusedException
DoNothingConnector.receive = receive
try:
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
# fake client connection instance with connecting attribute
......@@ -269,23 +267,18 @@ class ConnectionTests(NeoUnitTestBase):
self._checkReadBuf(bc, '')
self._checkConnectionFailed(1)
self._checkUnregistered(1)
finally:
del DoNothingConnector.receive
def test_Connection_recv4(self):
# patch receive method to raise any other connector error
def receive(self):
def receive(orig, self):
raise ConnectorException
DoNothingConnector.receive = receive
try:
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
self.assertRaises(ConnectorException, bc._recv)
self._checkReadBuf(bc, '')
self._checkConnectionClosed(1)
self._checkUnregistered(1)
finally:
del DoNothingConnector.receive
def test_Connection_send1(self):
# no data, nothing done
......@@ -299,10 +292,7 @@ class ConnectionTests(NeoUnitTestBase):
def test_Connection_send2(self):
# send all data
def send(self, data):
return len(data)
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
......@@ -311,15 +301,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send3(self):
# send part of the data
def send(self, data):
return len(data)/2
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
......@@ -328,15 +313,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, 'data')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send4(self):
# send multiple packet
def send(self, data):
return len(data)
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
......@@ -345,15 +325,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send5(self):
# send part of multiple packet
def send(self, data):
return len(data)/2
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
......@@ -362,15 +337,12 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, 'econdthird')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send6(self):
# raise try again
def send(self, data):
def send(orig, self, data):
raise ConnectorTryAgainException
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=send):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
......@@ -379,15 +351,12 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, 'testdatasecondthird')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send7(self):
# raise other error
def send(self, data):
def send(orig, self, data):
raise ConnectorException
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=send):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
......@@ -397,8 +366,6 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(1)
self._checkUnregistered(1)
finally:
del DoNothingConnector.send
def test_07_Connection_addPacket(self):
# new packet
......@@ -499,10 +466,7 @@ class ConnectionTests(NeoUnitTestBase):
def test_Connection_writable1(self):
# with pending operation after send
def send(self, data):
return len(data)/2
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
......@@ -520,15 +484,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(0)
self._checkReaderRemoved(0)
self._checkClose(0)
finally:
del DoNothingConnector.send
def test_Connection_writable2(self):
# without pending operation after send
def send(self, data):
return len(data)
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
......@@ -546,15 +505,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(1)
self._checkReaderRemoved(0)
self._checkClose(0)
finally:
del DoNothingConnector.send
def test_Connection_writable3(self):
# without pending operation after send and aborted set to true
def send(self, data):
return len(data)
DoNothingConnector.send = send
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
......@@ -571,18 +525,15 @@ class ConnectionTests(NeoUnitTestBase):
# nothing else pending, so writer has been removed
self.assertFalse(bc.pending())
self._checkClose(1)
finally:
del DoNothingConnector.send
def test_Connection_readable(self):
# With aborted set to false
# patch receive method to return data
def receive(self):
def receive(orig, self):
p = Packets.AnswerPrimary(self.getNewUUID(None))
p.setId(1)
return ''.join(p.encode())
DoNothingConnector.receive = receive
try:
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
bc._queue = Mock({'__len__': 0})
self._checkReadBuf(bc, '')
......@@ -602,8 +553,6 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(0)
self._checkReaderRemoved(0)
self._checkClose(0)
finally:
del DoNothingConnector.receive
def test_ClientConnection_init1(self):
# create a good client connection
......@@ -624,14 +573,10 @@ class ConnectionTests(NeoUnitTestBase):
def test_ClientConnection_init2(self):
# raise connection in progress
makeClientConnection_org = DoNothingConnector.makeClientConnection
def makeClientConnection(self, *args, **kw):
def makeClientConnection(orig, self):
raise ConnectorInProgressException
DoNothingConnector.makeClientConnection = makeClientConnection
try:
with Patch(DummyConnector, makeClientConnection=makeClientConnection):
bc = self._makeClientConnection()
finally:
DoNothingConnector.makeClientConnection = makeClientConnection_org
# check connector created and connection initialize
self.assertTrue(bc.connecting)
self.assertFalse(bc.isServer())
......@@ -648,14 +593,10 @@ class ConnectionTests(NeoUnitTestBase):
def test_ClientConnection_init3(self):
# raise another error, connection must fail
makeClientConnection_org = DoNothingConnector.makeClientConnection
def makeClientConnection(self, *args, **kw):
def makeClientConnection(orig, self):
raise ConnectorException
DoNothingConnector.makeClientConnection = makeClientConnection
try:
with Patch(DummyConnector, makeClientConnection=makeClientConnection):
self.assertRaises(ConnectorException, self._makeClientConnection)
finally:
DoNothingConnector.makeClientConnection = makeClientConnection_org
# since the exception was raised, the connection is not created
# check call to handler
self._checkConnectionStarted(1)
......@@ -667,18 +608,11 @@ class ConnectionTests(NeoUnitTestBase):
def test_ClientConnection_writable1(self):
# with a non connecting connection, will call parent's method
def makeClientConnection(self, *args, **kw):
return "OK"
def send(self, data):
return len(data)
makeClientConnection_org = DoNothingConnector.makeClientConnection
DoNothingConnector.send = send
DoNothingConnector.makeClientConnection = makeClientConnection
try:
try:
with Patch(DummyConnector, send=lambda orig, self, data: len(data)), \
Patch(DummyConnector,
makeClientConnection=lambda orig, self: "OK") as p:
bc = self._makeClientConnection()
finally:
DoNothingConnector.makeClientConnection = makeClientConnection_org
p.revert()
# check connector created and connection initialize
self.assertFalse(bc.connecting)
self._checkWriteBuf(bc, '')
......@@ -701,19 +635,12 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(1)
self._checkReaderRemoved(0)
self._checkClose(0)
finally:
del DoNothingConnector.send
def test_ClientConnection_writable2(self):
# with a connecting connection, must not call parent's method
# with errors, close connection
def getError(self):
return True
DoNothingConnector.getError = getError
try:
with Patch(DummyConnector, getError=lambda orig, self: True):
bc = self._makeClientConnection()
finally:
del DoNothingConnector.getError
# check connector created and connection initialize
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
......@@ -836,10 +763,11 @@ class MTConnectionTests(ConnectionTests):
self.dispatcher = Mock({'__repr__': 'Fake Dispatcher'})
def _makeClientConnection(self):
self.connector = DoNothingConnector()
return MTClientConnection(event_manager=self.em, handler=self.handler,
connector=self.connector, node=self.node,
with dummy_connector:
conn = MTClientConnection(self.em, self.handler, self.node,
dispatcher=self.dispatcher)
self.connector = conn.connector
return conn
def test_MTClientConnectionQueueParameter(self):
ask = self._makeClientConnection().ask
......
......@@ -15,35 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
import socket
from . import NeoUnitTestBase, IP_VERSION_FORMAT_DICT
from neo.lib.util import ReadBuffer, getAddressType, parseNodeAddress, \
getConnectorFromAddress, SOCKET_CONNECTORS_DICT
from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, parseNodeAddress
class UtilTests(NeoUnitTestBase):
def test_getConnectorFromAddress(self):
""" Connector name must correspond to address type """
connector = getConnectorFromAddress((
IP_VERSION_FORMAT_DICT[socket.AF_INET], 0))
self.assertEqual(connector, SOCKET_CONNECTORS_DICT[socket.AF_INET])
connector = getConnectorFromAddress((
IP_VERSION_FORMAT_DICT[socket.AF_INET6], 0))
self.assertEqual(connector, SOCKET_CONNECTORS_DICT[socket.AF_INET6])
self.assertRaises(ValueError, getConnectorFromAddress, ('', 0))
self.assertRaises(ValueError, getConnectorFromAddress, ('test', 0))
def test_getAddressType(self):
""" Get the type on an IP Address """
self.assertRaises(ValueError, getAddressType, ('', 0))
address_type = getAddressType(('::1', 0))
self.assertEqual(address_type, socket.AF_INET6)
address_type = getAddressType(('0.0.0.0', 0))
self.assertEqual(address_type, socket.AF_INET)
address_type = getAddressType(('127.0.0.1', 0))
self.assertEqual(address_type, socket.AF_INET)
def test_parseNodeAddress(self):
""" Parsing of addesses """
def test(parsed, *args):
......
......@@ -35,7 +35,7 @@ from neo.lib.connector import SocketConnector, \
ConnectorConnectionRefusedException, ConnectorTryAgainException
from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64
from neo.lib.util import parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
......@@ -166,7 +166,7 @@ class SerializedEventManager(EventManager):
class Node(object):
def getConnectionList(self, *peers):
addr = lambda c: c and (c.accepted_from or c.getAddress())
addr = lambda c: c and (c.addr if c.is_server else c.getAddress())
addr_set = {addr(c.connector) for peer in peers
for c in peer.em.connection_dict.itervalues()
if isinstance(c, Connection)}
......@@ -467,10 +467,8 @@ class ConnectionFilter(object):
class NEOCluster(object):
BaseConnection_getTimeout = staticmethod(BaseConnection.getTimeout)
SocketConnector_makeClientConnection = staticmethod(
SocketConnector.makeClientConnection)
SocketConnector_makeListeningConnection = staticmethod(
SocketConnector.makeListeningConnection)
SocketConnector_bind = staticmethod(SocketConnector._bind)
SocketConnector_connect = staticmethod(SocketConnector._connect)
SocketConnector_receive = staticmethod(SocketConnector.receive)
SocketConnector_send = staticmethod(SocketConnector.send)
_patch_count = 0
......@@ -489,12 +487,6 @@ class NEOCluster(object):
cls._patch_count += 1
if cls._patch_count > 1:
return
def makeClientConnection(self, addr):
real_addr = ServerNode.resolv(addr)
try:
return cls.SocketConnector_makeClientConnection(self, real_addr)
finally:
self.remote_addr = addr
def send(self, msg):
result = cls.SocketConnector_send(self, msg)
if type(Serialized.pending) is not frozenset:
......@@ -518,9 +510,10 @@ class NEOCluster(object):
# safely started even if the cluster isn't.
bootstrap.sleep = lambda seconds: None
BaseConnection.getTimeout = lambda self: None
SocketConnector.makeClientConnection = makeClientConnection
SocketConnector.makeListeningConnection = lambda self, addr: \
cls.SocketConnector_makeListeningConnection(self, BIND)
SocketConnector._bind = lambda self, addr: \
cls.SocketConnector_bind(self, BIND)
SocketConnector._connect = lambda self, addr: \
cls.SocketConnector_connect(self, ServerNode.resolv(addr))
SocketConnector.receive = receive
SocketConnector.send = send
Serialized.init()
......@@ -534,10 +527,8 @@ class NEOCluster(object):
return
bootstrap.sleep = time.sleep
BaseConnection.getTimeout = cls.BaseConnection_getTimeout
SocketConnector.makeClientConnection = \
cls.SocketConnector_makeClientConnection
SocketConnector.makeListeningConnection = \
cls.SocketConnector_makeListeningConnection
SocketConnector._bind = cls.SocketConnector_bind
SocketConnector._connect = cls.SocketConnector_connect
SocketConnector.receive = cls.SocketConnector_receive
SocketConnector.send = cls.SocketConnector_send
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment