Commit 57481c35 authored by Julien Muchembled's avatar Julien Muchembled

Review API betweeen connections and connectors

- Review error handling. Only 2 exceptions remain in connector.py:

  - Drop useless exception handling for EAGAIN since it should not happen
    if the kernel says the socket is ready.
  - Do not distinguish other socket errors. Just close and log in a generic way.
  - No need to raise a specific exception for EOF.
  - Make 'connect' return a boolean instead of raising an exception.
  - Raise appropriate exception when answer/ask/notify is called on a closed
    non-MT connection.

- Add support for more complex connectors, which may need to write for a read
  operation, or to read when there's pending data to send. This will be
  required for SSL support (more exactly, the handshake will be done in
  a transparent way):

  - Move write buffer to connector.
  - Make 'receive' fill the read buffer, instead of returning the read data.
  - Make 'receive' & 'send' return a boolean to switch polling for writing.
  - Tolerate that sockets return 0 as number of bytes sent.

- In testConnection, simply delete all failing tests, as announced
  in commit 71e30fb9.
parent 36a32f23
...@@ -18,9 +18,7 @@ from functools import wraps ...@@ -18,9 +18,7 @@ from functools import wraps
from time import time from time import time
from . import attributeTracker, logging from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorTryAgainException, \ from .connector import ConnectorException, ConnectorDelayedConnection
ConnectorInProgressException, ConnectorConnectionRefusedException, \
ConnectorConnectionClosedException, ConnectorDelayedConnection
from .locking import RLock from .locking import RLock
from .protocol import uuid_str, Errors, \ from .protocol import uuid_str, Errors, \
PacketMalformedError, Packets, ParserState PacketMalformedError, Packets, ParserState
...@@ -31,14 +29,6 @@ CRITICAL_TIMEOUT = 30 ...@@ -31,14 +29,6 @@ CRITICAL_TIMEOUT = 30
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
def not_closed(func):
def decorator(self, *args, **kw):
if self.connector is None:
raise ConnectorConnectionClosedException
return func(self, *args, **kw)
return wraps(func)(decorator)
class HandlerSwitcher(object): class HandlerSwitcher(object):
_is_handling = False _is_handling = False
_next_timeout = None _next_timeout = None
...@@ -316,14 +306,11 @@ class ListeningConnection(BaseConnection): ...@@ -316,14 +306,11 @@ class ListeningConnection(BaseConnection):
self.em.register(self) self.em.register(self)
def readable(self): def readable(self):
try: connector, addr = self.connector.accept()
connector, addr = self.connector.accept() logging.debug('accepted a connection from %s:%d', *addr)
logging.debug('accepted a connection from %s:%d', *addr) handler = self.getHandler()
handler = self.getHandler() new_conn = ServerConnection(self.em, handler, connector, addr)
new_conn = ServerConnection(self.em, handler, connector, addr) handler.connectionAccepted(new_conn)
handler.connectionAccepted(new_conn)
except ConnectorTryAgainException:
pass
def getAddress(self): def getAddress(self):
return self.connector.getAddress() return self.connector.getAddress()
...@@ -347,7 +334,6 @@ class Connection(BaseConnection): ...@@ -347,7 +334,6 @@ class Connection(BaseConnection):
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer() self.read_buf = ReadBuffer()
self.write_buf = []
self.cur_id = 0 self.cur_id = 0
self.aborted = False self.aborted = False
self.uuid = None self.uuid = None
...@@ -444,39 +430,47 @@ class Connection(BaseConnection): ...@@ -444,39 +430,47 @@ class Connection(BaseConnection):
"""Abort dealing with this connection.""" """Abort dealing with this connection."""
logging.debug('aborting a connector for %r', self) logging.debug('aborting a connector for %r', self)
self.aborted = True self.aborted = True
assert self.write_buf assert self.pending()
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
def writable(self): def writable(self):
"""Called when self is writable.""" """Called when self is writable."""
self._send() try:
if not self.write_buf and self.connector is not None: if self.connector.send():
if self.aborted: if self.aborted:
self.close() self.close()
else: else:
self.em.removeWriter(self) self.em.removeWriter(self)
except ConnectorException:
self._closure()
def readable(self): def readable(self):
"""Called when self is readable.""" """Called when self is readable."""
self._recv() # last known remote activity
self._analyse() self._next_timeout = time() + self._timeout
if self.aborted: read_buf = self.read_buf
self.em.removeReader(self)
return not not self._queue
def _analyse(self):
"""Analyse received data."""
try: try:
while True: try:
packet = Packets.parse(self.read_buf, self._parser_state) if self.connector.receive(read_buf):
if packet is None: self.em.addWriter(self)
break finally:
self._queue.append(packet) # A connector may read some data
# before raising ConnectorException
while 1:
packet = Packets.parse(read_buf, self._parser_state)
if packet is None:
break
self._queue.append(packet)
except ConnectorException:
self._closure()
except PacketMalformedError, e: except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', self, e) logging.error('malformed packet from %r: %s', self, e)
self._closure() self._closure()
if self.aborted:
self.em.removeReader(self)
return not not self._queue
def hasPendingMessages(self): def hasPendingMessages(self):
""" """
...@@ -493,7 +487,8 @@ class Connection(BaseConnection): ...@@ -493,7 +487,8 @@ class Connection(BaseConnection):
self.updateTimeout() self.updateTimeout()
def pending(self): def pending(self):
return self.connector is not None and self.write_buf connector = self.connector
return connector is not None and connector.queued
@property @property
def setReconnectionNoDelay(self): def setReconnectionNoDelay(self):
...@@ -503,7 +498,6 @@ class Connection(BaseConnection): ...@@ -503,7 +498,6 @@ class Connection(BaseConnection):
if self.connector is None: if self.connector is None:
assert self._on_close is None assert self._on_close is None
assert not self.read_buf assert not self.read_buf
assert not self.write_buf
assert not self.isPending() assert not self.isPending()
return return
# process the network events with the last registered handler to # process the network events with the last registered handler to
...@@ -514,7 +508,6 @@ class Connection(BaseConnection): ...@@ -514,7 +508,6 @@ class Connection(BaseConnection):
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
del self.write_buf[:]
self.read_buf.clear() self.read_buf.clear()
try: try:
if self.connecting: if self.connecting:
...@@ -531,89 +524,28 @@ class Connection(BaseConnection): ...@@ -531,89 +524,28 @@ class Connection(BaseConnection):
self._handlers.handle(self, self._queue.pop(0)) self._handlers.handle(self, self._queue.pop(0))
self.close() self.close()
def _recv(self):
"""Receive data from a connector."""
try:
data = self.connector.receive()
except ConnectorTryAgainException:
pass
except ConnectorConnectionRefusedException:
assert self.connecting
self._closure()
except ConnectorConnectionClosedException:
# connection resetted by peer, according to the man, this error
# should not occurs but it seems it's false
logging.debug('Connection reset by peer: %r', self.connector)
self._closure()
except:
logging.debug('Unknown connection error: %r', self.connector)
self._closure()
# unhandled connector exception
raise
else:
if not data:
logging.debug('Connection %r closed in recv', self.connector)
self._closure()
return
# last known remote activity
self._next_timeout = time() + self._timeout
self.read_buf.append(data)
def _send(self):
"""Send data to a connector."""
if not self.write_buf:
return
msg = ''.join(self.write_buf)
try:
n = self.connector.send(msg)
except ConnectorTryAgainException:
pass
except ConnectorConnectionClosedException:
# connection resetted by peer
logging.debug('Connection reset by peer: %r', self.connector)
self._closure()
except:
logging.debug('Unknown connection error: %r', self.connector)
# unhandled connector exception
self._closure()
raise
else:
if not n:
logging.debug('Connection %r closed in send', self.connector)
self._closure()
return
if n == len(msg):
del self.write_buf[:]
else:
self.write_buf = [msg[n:]]
def _addPacket(self, packet): def _addPacket(self, packet):
"""Add a packet into the write buffer.""" """Add a packet into the write buffer."""
if self.connector is None: if self.connector.queue(packet.encode()):
return
was_empty = not self.write_buf
self.write_buf.extend(packet.encode())
if was_empty:
# enable polling for writing. # enable polling for writing.
self.em.addWriter(self) self.em.addWriter(self)
logging.packet(self, packet, True) logging.packet(self, packet, True)
@not_closed
def notify(self, packet): def notify(self, packet):
""" Then a packet with a new ID """ """ Then a packet with a new ID """
if self.isClosed():
raise ConnectionClosed
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
self._addPacket(packet) self._addPacket(packet)
return msg_id return msg_id
@not_closed
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, **kw): def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, **kw):
""" """
Send a packet with a new ID and register the expectation of an answer Send a packet with a new ID and register the expectation of an answer
""" """
if self.isClosed():
raise ConnectionClosed
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
self._addPacket(packet) self._addPacket(packet)
...@@ -627,9 +559,10 @@ class Connection(BaseConnection): ...@@ -627,9 +559,10 @@ class Connection(BaseConnection):
self.em.wakeup() self.em.wakeup()
return msg_id return msg_id
@not_closed
def answer(self, packet, msg_id=None): def answer(self, packet, msg_id=None):
""" Answer to a packet by re-using its ID for the packet answer """ """ Answer to a packet by re-using its ID for the packet answer """
if self.isClosed():
raise ConnectionClosed
if msg_id is None: if msg_id is None:
msg_id = self.getPeerId() msg_id = self.getPeerId()
packet.setId(msg_id) packet.setId(msg_id)
...@@ -656,32 +589,25 @@ class ClientConnection(Connection): ...@@ -656,32 +589,25 @@ class ClientConnection(Connection):
def _connect(self): def _connect(self):
try: try:
self.connector.makeClientConnection() connected = self.connector.makeClientConnection()
except ConnectorInProgressException:
self.em.register(self)
self.em.addWriter(self)
except ConnectorDelayedConnection, c: except ConnectorDelayedConnection, c:
connect_limit, = c.args connect_limit, = c.args
self.getTimeout = lambda: connect_limit self.getTimeout = lambda: connect_limit
self.onTimeout = self._delayedConnect self.onTimeout = self._delayedConnect
self.em.register(self, timeout_only=True) self.em.register(self, timeout_only=True)
# Fake _addPacket so that if does not
# try to reenable polling for writing.
self.write_buf.insert(0, '')
except ConnectorConnectionRefusedException:
self._closure()
except ConnectorException: except ConnectorException:
# unhandled connector exception
self._closure() self._closure()
raise
else: else:
self.em.register(self) self.em.register(self)
if self.write_buf: if connected:
self.em.addWriter(self) self._connectionCompleted()
self._connectionCompleted() # A client connection usually has a pending packet to send
# from the beginning. It would be too smart to detect when
# it's not required to poll for writing.
self.em.addWriter(self)
def _delayedConnect(self): def _delayedConnect(self):
del self.getTimeout, self.onTimeout, self.write_buf[0] del self.getTimeout, self.onTimeout
self._connect() self._connect()
def writable(self): def writable(self):
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import socket import socket
import errno import errno
from time import time from time import time
from . import logging
# Global connector registry. # Global connector registry.
# Fill by calling registerConnectorHandler. # Fill by calling registerConnectorHandler.
...@@ -56,8 +57,19 @@ class SocketConnector(object): ...@@ -56,8 +57,19 @@ class SocketConnector(object):
s.setblocking(0) s.setblocking(0)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = []
return self return self
def queue(self, data):
was_empty = not self.queued
self.queued += data
return was_empty
def _error(self, op, exc):
logging.debug("%s failed for %s: %s (%s)",
op, self, errno.errorcode[exc.errno], exc.strerror)
raise ConnectorException
# Threaded tests monkey-patch the following 2 operations. # Threaded tests monkey-patch the following 2 operations.
_connect = lambda self, addr: self.socket.connect(addr) _connect = lambda self, addr: self.socket.connect(addr)
_bind = lambda self, addr: self.socket.bind(addr) _bind = lambda self, addr: self.socket.bind(addr)
...@@ -68,20 +80,23 @@ class SocketConnector(object): ...@@ -68,20 +80,23 @@ class SocketConnector(object):
try: try:
connect_limit = self.connect_limit[addr] connect_limit = self.connect_limit[addr]
if time() < connect_limit: if time() < connect_limit:
# Next call to queue() must return False
# in order not to enable polling for writing.
self.queued or self.queued.append('')
raise ConnectorDelayedConnection(connect_limit) raise ConnectorDelayedConnection(connect_limit)
if self.queued and not self.queued[0]:
del self.queued[0]
except KeyError: except KeyError:
pass pass
self.connect_limit[addr] = time() + self.CONNECT_LIMIT self.connect_limit[addr] = time() + self.CONNECT_LIMIT
self.is_server = self.is_closed = False self.is_server = self.is_closed = False
try: try:
self._connect(addr) self._connect(addr)
except socket.error, (err, errmsg): except socket.error, e:
if err == errno.EINPROGRESS: if e.errno == errno.EINPROGRESS:
raise ConnectorInProgressException return False
if err == errno.ECONNREFUSED: self._error('connect', e)
raise ConnectorConnectionRefusedException return True
raise ConnectorException, 'makeClientConnection to %s failed:' \
' %s:%s' % (addr, err, errmsg)
def makeListeningConnection(self): def makeListeningConnection(self):
assert self.is_closed is None assert self.is_closed is None
...@@ -90,10 +105,9 @@ class SocketConnector(object): ...@@ -90,10 +105,9 @@ class SocketConnector(object):
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._bind(self.addr) self._bind(self.addr)
self.socket.listen(5) self.socket.listen(5)
except socket.error, (err, errmsg): except socket.error, e:
self.socket.close() self.socket.close()
raise ConnectorException, 'makeListeningConnection on %s failed:' \ self._error('listen', e)
' %s:%s' % (addr, err, errmsg)
def getError(self): def getError(self):
return self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) return self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
...@@ -116,33 +130,39 @@ class SocketConnector(object): ...@@ -116,33 +130,39 @@ class SocketConnector(object):
s, addr = self.socket.accept() s, addr = self.socket.accept()
s = self.__class__(addr, s) s = self.__class__(addr, s)
return s, s.addr return s, s.addr
except socket.error, (err, errmsg): except socket.error, e:
if err == errno.EAGAIN: self._error('accept', e)
raise ConnectorTryAgainException
raise ConnectorException, 'accept failed: %s:%s' % \
(err, errmsg)
def receive(self): def receive(self, read_buf):
try: try:
return self.socket.recv(4096) data = self.socket.recv(4096)
except socket.error, (err, errmsg): except socket.error, e:
if err == errno.EAGAIN: self._error('recv', e)
raise ConnectorTryAgainException if data:
if err in (errno.ECONNREFUSED, errno.EHOSTUNREACH): read_buf.append(data)
raise ConnectorConnectionRefusedException return
if err in (errno.ECONNRESET, errno.ETIMEDOUT): logging.debug('%r closed in recv', self)
raise ConnectorConnectionClosedException raise ConnectorException
raise ConnectorException, 'receive failed: %s:%s' % (err, errmsg)
def send(self):
def send(self, msg): msg = ''.join(self.queued)
try: if msg:
return self.socket.send(msg) try:
except socket.error, (err, errmsg): n = self.socket.send(msg)
if err == errno.EAGAIN: except socket.error, e:
raise ConnectorTryAgainException self._error('send', e)
if err in (errno.ECONNRESET, errno.ETIMEDOUT, errno.EPIPE): # Do nothing special if n == 0:
raise ConnectorConnectionClosedException # - it never happens for simple sockets;
raise ConnectorException, 'send failed: %s:%s' % (err, errmsg) # - for SSL sockets, this is always the case unless everything
# could be sent.
if n != len(msg):
self.queued[:] = msg[n:],
return False
del self.queued[:]
else:
assert not self.queued
return True
def close(self): def close(self):
self.is_closed = True self.is_closed = True
...@@ -195,17 +215,5 @@ registerConnectorHandler(SocketConnectorIPv6) ...@@ -195,17 +215,5 @@ registerConnectorHandler(SocketConnectorIPv6)
class ConnectorException(Exception): class ConnectorException(Exception):
pass pass
class ConnectorTryAgainException(ConnectorException):
pass
class ConnectorInProgressException(ConnectorException):
pass
class ConnectorConnectionClosedException(ConnectorException):
pass
class ConnectorConnectionRefusedException(ConnectorException):
pass
class ConnectorDelayedConnection(ConnectorException): class ConnectorDelayedConnection(ConnectorException):
pass pass
...@@ -16,8 +16,7 @@ ...@@ -16,8 +16,7 @@
from collections import deque from collections import deque
from neo.lib import logging from neo.lib import logging
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection, ConnectionClosed
from neo.lib.connector import ConnectorConnectionClosedException
from neo.lib.protocol import NodeTypes, Packets, ZERO_OID from neo.lib.protocol import NodeTypes, Packets, ZERO_OID
from neo.lib.util import add64, dump from neo.lib.util import add64, dump
from .handlers.storage import StorageOperationHandler from .handlers.storage import StorageOperationHandler
...@@ -85,7 +84,7 @@ class Checker(object): ...@@ -85,7 +84,7 @@ class Checker(object):
if self.conn_dict: if self.conn_dict:
break break
msg = "no replica" msg = "no replica"
except ConnectorConnectionClosedException: except ConnectionClosed:
msg = "connection closed" msg = "connection closed"
finally: finally:
conn_set.update(self.conn_dict) conn_set.update(self.conn_dict)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import weakref import weakref
from functools import wraps from functools import wraps
from neo.lib.connector import ConnectorConnectionClosedException from neo.lib.connection import ConnectionClosed
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Errors, NodeStates, Packets, ProtocolError, \ from neo.lib.protocol import Errors, NodeStates, Packets, ProtocolError, \
ZERO_HASH ZERO_HASH
...@@ -154,7 +154,7 @@ class StorageOperationHandler(EventHandler): ...@@ -154,7 +154,7 @@ class StorageOperationHandler(EventHandler):
r = app.dm.checkTIDRange(*args) r = app.dm.checkTIDRange(*args)
try: try:
conn.answer(Packets.AnswerCheckTIDRange(*r), msg_id) conn.answer(Packets.AnswerCheckTIDRange(*r), msg_id)
except (weakref.ReferenceError, ConnectorConnectionClosedException): except (weakref.ReferenceError, ConnectionClosed):
pass pass
yield yield
app.newTask(check()) app.newTask(check())
...@@ -170,7 +170,7 @@ class StorageOperationHandler(EventHandler): ...@@ -170,7 +170,7 @@ class StorageOperationHandler(EventHandler):
r = app.dm.checkSerialRange(*args) r = app.dm.checkSerialRange(*args)
try: try:
conn.answer(Packets.AnswerCheckSerialRange(*r), msg_id) conn.answer(Packets.AnswerCheckSerialRange(*r), msg_id)
except (weakref.ReferenceError, ConnectorConnectionClosedException): except (weakref.ReferenceError, ConnectionClosed):
pass pass
yield yield
app.newTask(check()) app.newTask(check())
...@@ -211,7 +211,7 @@ class StorageOperationHandler(EventHandler): ...@@ -211,7 +211,7 @@ class StorageOperationHandler(EventHandler):
conn.answer(Packets.AnswerFetchTransactions( conn.answer(Packets.AnswerFetchTransactions(
pack_tid, next_tid, peer_tid_set), msg_id) pack_tid, next_tid, peer_tid_set), msg_id)
yield yield
except (weakref.ReferenceError, ConnectorConnectionClosedException): except (weakref.ReferenceError, ConnectionClosed):
pass pass
app.newTask(push()) app.newTask(push())
...@@ -253,6 +253,6 @@ class StorageOperationHandler(EventHandler): ...@@ -253,6 +253,6 @@ class StorageOperationHandler(EventHandler):
conn.answer(Packets.AnswerFetchObjects( conn.answer(Packets.AnswerFetchObjects(
pack_tid, next_tid, next_oid, object_dict), msg_id) pack_tid, next_tid, next_oid, object_dict), msg_id)
yield yield
except (weakref.ReferenceError, ConnectorConnectionClosedException): except (weakref.ReferenceError, ConnectionClosed):
pass pass
app.newTask(push()) app.newTask(push())
...@@ -18,14 +18,10 @@ import unittest ...@@ -18,14 +18,10 @@ import unittest
from time import time from time import time
from mock import Mock from mock import Mock
from neo.lib import connection, logging from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ListeningConnection, \ from neo.lib.connection import BaseConnection, ClientConnection, \
Connection, ClientConnection, ServerConnection, MTClientConnection, \ MTClientConnection, HandlerSwitcher, CRITICAL_TIMEOUT
HandlerSwitcher, CRITICAL_TIMEOUT
from neo.lib.connector import registerConnectorHandler
from neo.lib.connector import ConnectorException, ConnectorTryAgainException, \
ConnectorInProgressException, ConnectorConnectionRefusedException
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, PACKET_HEADER_FORMAT from neo.lib.protocol import Packets
from . import NeoUnitTestBase, Patch from . import NeoUnitTestBase, Patch
...@@ -64,18 +60,6 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -64,18 +60,6 @@ class ConnectionTests(NeoUnitTestBase):
self.handler = Mock({'__repr__': 'Fake Handler'}) self.handler = Mock({'__repr__': 'Fake Handler'})
self.address = ("127.0.0.7", 93413) self.address = ("127.0.0.7", 93413)
self.node = Mock({'getAddress': self.address}) self.node = Mock({'getAddress': self.address})
connection.connect_limit = 0
def _makeListeningConnection(self, addr):
with dummy_connector:
conn = ListeningConnection(self.app, self.handler, addr)
self.connector = conn.connector
return conn
def _makeServerConnection(self):
addr = self.address
self.connector = DummyConnector(addr)
return Connection(self.em, self.handler, self.connector, addr)
def _makeClientConnection(self): def _makeClientConnection(self):
with dummy_connector: with dummy_connector:
...@@ -83,598 +67,8 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -83,598 +67,8 @@ class ConnectionTests(NeoUnitTestBase):
self.connector = conn.connector self.connector = conn.connector
return conn return conn
_makeConnection = _makeClientConnection def testTimeout(self):
# NOTE: This method uses ping/pong packets only because MT connections
def _checkRegistered(self, n=1):
self.assertEqual(len(self.em.mockGetNamedCalls("register")), n)
def _checkUnregistered(self, n=1):
self.assertEqual(len(self.em.mockGetNamedCalls("unregister")), n)
def _checkReaderRemoved(self, n=1):
self.assertEqual(len(self.em.mockGetNamedCalls("removeReader")), n)
def _checkWriterAdded(self, n=1):
self.assertEqual(len(self.em.mockGetNamedCalls("addWriter")), n)
def _checkWriterRemoved(self, n=1):
self.assertEqual(len(self.em.mockGetNamedCalls("removeWriter")), n)
def _checkClose(self, n=1):
self.assertEqual(len(self.connector.mockGetNamedCalls("close")), n)
def _checkAccept(self, n=1):
calls = self.connector.mockGetNamedCalls('accept')
self.assertEqual(len(calls), n)
def _checkSend(self, n=1, data=None):
calls = self.connector.mockGetNamedCalls('send')
self.assertEqual(len(calls), n)
if n > 1 and data is not None:
data = calls[n-1].getParam(0)
self.assertEqual(data, "testdata")
def _checkConnectionAccepted(self, n=1):
calls = self.handler.mockGetNamedCalls('connectionAccepted')
self.assertEqual(len(calls), n)
def _checkConnectionFailed(self, n=1):
calls = self.handler.mockGetNamedCalls('connectionFailed')
self.assertEqual(len(calls), n)
def _checkConnectionClosed(self, n=1):
calls = self.handler.mockGetNamedCalls('connectionClosed')
self.assertEqual(len(calls), n)
def _checkConnectionStarted(self, n=1):
calls = self.handler.mockGetNamedCalls('connectionStarted')
self.assertEqual(len(calls), n)
def _checkConnectionCompleted(self, n=1):
calls = self.handler.mockGetNamedCalls('connectionCompleted')
self.assertEqual(len(calls), n)
def _checkMakeListeningConnection(self, n=1):
calls = self.connector.mockGetNamedCalls('makeListeningConnection')
self.assertEqual(len(calls), n)
def _checkMakeClientConnection(self, n=1):
calls = self.connector.mockGetNamedCalls("makeClientConnection")
self.assertEqual(len(calls), n)
def _checkPacketReceived(self, n=1):
calls = self.handler.mockGetNamedCalls('packetReceived')
self.assertEqual(len(calls), n)
def _checkReadBuf(self, bc, data):
content = bc.read_buf.read(len(bc.read_buf))
self.assertEqual(''.join(content), data)
def _appendToReadBuf(self, bc, data):
bc.read_buf.append(data)
def _appendPacketToReadBuf(self, bc, packet):
data = ''.join(packet.encode())
bc.read_buf.append(data)
def _checkWriteBuf(self, bc, data):
self.assertEqual(''.join(bc.write_buf), data)
def test_01_BaseConnection(self):
# init with address
bc = self._makeConnection()
self.assertEqual(bc.getAddress(), self.address)
self.assertIsNot(bc.connector, None)
self._checkRegistered(1)
def test_02_ListeningConnection1(self):
# test init part
addr = ("127.0.0.7", 93413)
with Patch(DummyConnector, accept=lambda orig, self: (self, ('', 0))):
bc = self._makeListeningConnection(addr=addr)
self.assertEqual(bc.getAddress(), addr)
self._checkRegistered()
self._checkMakeListeningConnection()
# test readable
bc.readable()
self._checkAccept()
self._checkConnectionAccepted()
def test_02_ListeningConnection2(self):
# test with exception raise when getting new connection
def accept(orig, self):
raise ConnectorTryAgainException
addr = ("127.0.0.7", 93413)
with Patch(DummyConnector, accept=accept):
bc = self._makeListeningConnection(addr=addr)
self.assertEqual(bc.getAddress(), addr)
self._checkRegistered()
self._checkMakeListeningConnection()
# test readable
bc.readable()
self._checkAccept(1)
self._checkConnectionAccepted(0)
def test_03_Connection(self):
bc = self._makeConnection()
self.assertEqual(bc.getAddress(), self.address)
self._checkReadBuf(bc, '')
self._checkWriteBuf(bc, '')
self.assertEqual(bc.cur_id, 0)
self.assertFalse(bc.aborted)
# test uuid
self.assertEqual(bc.uuid, None)
self.assertEqual(bc.getUUID(), None)
uuid = self.getNewUUID(None)
bc.setUUID(uuid)
self.assertEqual(bc.getUUID(), uuid)
# test next id
cur_id = bc.cur_id
next_id = bc._getNextId()
self.assertEqual(next_id, cur_id)
next_id = bc._getNextId()
self.assertTrue(next_id > cur_id)
# test overflow of next id
bc.cur_id = 0xffffffff
next_id = bc._getNextId()
self.assertEqual(next_id, 0xffffffff)
next_id = bc._getNextId()
self.assertEqual(next_id, 0)
def test_Connection_pending(self):
bc = self._makeConnection()
self.assertEqual(''.join(bc.write_buf), '')
self.assertFalse(bc.pending())
bc.write_buf += '1'
self.assertTrue(bc.pending())
def test_Connection_recv1(self):
# patch receive method to return data
with Patch(DummyConnector, receive=lambda orig, self: "testdata"):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
bc._recv()
self._checkReadBuf(bc, 'testdata')
def test_Connection_recv2(self):
# patch receive method to raise try again
def receive(orig, self):
raise ConnectorTryAgainException
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
bc._recv()
self._checkReadBuf(bc, '')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
def test_Connection_recv3(self):
# patch receive method to raise ConnectorConnectionRefusedException
def receive(orig, self):
raise ConnectorConnectionRefusedException
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
# fake client connection instance with connecting attribute
bc.connecting = True
bc._recv()
self._checkReadBuf(bc, '')
self._checkConnectionFailed(1)
self._checkUnregistered(1)
def test_Connection_recv4(self):
# patch receive method to raise any other connector error
def receive(orig, self):
raise ConnectorException
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
self._checkReadBuf(bc, '')
self.assertRaises(ConnectorException, bc._recv)
self._checkReadBuf(bc, '')
self._checkConnectionClosed(1)
self._checkUnregistered(1)
def test_Connection_send1(self):
# no data, nothing done
# patch receive method to return data
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc._send()
self._checkSend(0)
self._checkConnectionClosed(0)
self._checkUnregistered(0)
def test_Connection_send2(self):
# send all data
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
bc._send()
self._checkSend(1, "testdata")
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
def test_Connection_send3(self):
# send part of the data
with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
bc._send()
self._checkSend(1, "testdata")
self._checkWriteBuf(bc, 'data')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
def test_Connection_send4(self):
# send multiple packet
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
bc._send()
self._checkSend(1, "testdatasecondthird")
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
def test_Connection_send5(self):
# send part of multiple packet
with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
bc._send()
self._checkSend(1, "testdatasecondthird")
self._checkWriteBuf(bc, 'econdthird')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
def test_Connection_send6(self):
# raise try again
def send(orig, self, data):
raise ConnectorTryAgainException
with Patch(DummyConnector, send=send):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
bc._send()
self._checkSend(1, "testdatasecondthird")
self._checkWriteBuf(bc, 'testdatasecondthird')
self._checkConnectionClosed(0)
self._checkUnregistered(0)
def test_Connection_send7(self):
# raise other error
def send(orig, self, data):
raise ConnectorException
with Patch(DummyConnector, send=send):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"]
self.assertRaises(ConnectorException, bc._send)
self._checkSend(1, "testdatasecondthird")
# connection closed -> buffers flushed
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(1)
self._checkUnregistered(1)
def test_07_Connection_addPacket(self):
# new packet
p = Packets.Ping()
p._id = 0
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc._addPacket(p)
self._checkWriteBuf(bc, PACKET_HEADER_FORMAT.pack(0, p._code, 10))
self._checkWriterAdded(1)
def test_Connection_analyse1(self):
# nothing to read, nothing is done
bc = self._makeConnection()
bc._queue = Mock()
self._checkReadBuf(bc, '')
bc._analyse()
self._checkPacketReceived(0)
self._checkReadBuf(bc, '')
p = Packets.AnswerPrimary(self.getNewUUID(None))
p.setId(1)
p_data = ''.join(p.encode())
data_edge = len(p_data) - 1
p_data_1, p_data_2 = p_data[:data_edge], p_data[data_edge:]
# append an incomplete packet, nothing is done
bc.read_buf.append(p_data_1)
bc._analyse()
self._checkPacketReceived(0)
self.assertNotEqual(len(bc.read_buf), 0)
self.assertNotEqual(len(bc.read_buf), len(p_data))
# append the rest of the packet
bc.read_buf.append(p_data_2)
bc._analyse()
# check packet decoded
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(type(data), type(p))
self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode())
self._checkReadBuf(bc, '')
def test_Connection_analyse2(self):
# give multiple packet
bc = self._makeConnection()
bc._queue = Mock()
p1 = Packets.AnswerPrimary(self.getNewUUID(None))
p1.setId(1)
self._appendPacketToReadBuf(bc, p1)
p2 = Packets.AnswerPrimary( self.getNewUUID(None))
p2.setId(2)
self._appendPacketToReadBuf(bc, p2)
self.assertEqual(len(bc.read_buf), len(p1) + len(p2))
bc._analyse()
# check two packets decoded
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 2)
# packet 1
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(type(data), type(p1))
self.assertEqual(data.getId(), p1.getId())
self.assertEqual(data.decode(), p1.decode())
# packet 2
call = bc._queue.mockGetNamedCalls("append")[1]
data = call.getParam(0)
self.assertEqual(type(data), type(p2))
self.assertEqual(data.getId(), p2.getId())
self.assertEqual(data.decode(), p2.decode())
self._checkReadBuf(bc, '')
def test_Connection_analyse3(self):
# give a bad packet, won't be decoded
bc = self._makeConnection()
p = Packets.Ping()
p.setId(1)
self._appendToReadBuf(bc, '%s%sdatadatadatadata' % p.encode())
bc._analyse()
self._checkPacketReceived(1) # ping packet
self._checkClose(1) # malformed packet
def test_Connection_analyse4(self):
# give an expected packet
bc = self._makeConnection()
bc._queue = Mock()
p = Packets.AnswerPrimary(self.getNewUUID(None))
p.setId(1)
self._appendPacketToReadBuf(bc, p)
bc._analyse()
# check packet decoded
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(type(data), type(p))
self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode())
self._checkReadBuf(bc, '')
def test_Connection_writable1(self):
# with pending operation after send
with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
self.assertTrue(bc.pending())
self.assertFalse(bc.aborted)
bc.writable()
# test send was called
self._checkSend(1, "testdata")
self._checkWriteBuf(bc, "data")
self._checkConnectionClosed(0)
self._checkClose(0)
self._checkUnregistered(0)
# pending, so nothing called
self.assertTrue(bc.pending())
self._checkWriterRemoved(0)
self._checkReaderRemoved(0)
self._checkClose(0)
def test_Connection_writable2(self):
# without pending operation after send
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
self.assertTrue(bc.pending())
self.assertFalse(bc.aborted)
bc.writable()
# test send was called
self._checkSend(1, "testdata")
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(0)
self._checkClose(0)
self._checkUnregistered(0)
# nothing else pending, so writer has been removed
self.assertFalse(bc.pending())
self._checkWriterRemoved(1)
self._checkReaderRemoved(0)
self._checkClose(0)
def test_Connection_writable3(self):
# without pending operation after send and aborted set to true
with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
self.assertTrue(bc.pending())
bc.abort()
self.assertTrue(bc.aborted)
bc.writable()
# test send was called
self._checkSend(1, "testdata")
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(1)
self._checkClose(1)
self._checkUnregistered(1)
# nothing else pending, so writer has been removed
self.assertFalse(bc.pending())
self._checkClose(1)
def test_Connection_readable(self):
# With aborted set to false
# patch receive method to return data
def receive(orig, self):
p = Packets.AnswerPrimary(self.getNewUUID(None))
p.setId(1)
return ''.join(p.encode())
with Patch(DummyConnector, receive=receive):
bc = self._makeConnection()
bc._queue = Mock({'__len__': 0})
self._checkReadBuf(bc, '')
self.assertFalse(bc.aborted)
bc.readable()
# check packet decoded
self._checkReadBuf(bc, '')
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(type(data), Packets.AnswerPrimary)
self.assertEqual(data.getId(), 1)
self._checkReadBuf(bc, '')
# check not aborted
self.assertFalse(bc.aborted)
self._checkUnregistered(0)
self._checkWriterRemoved(0)
self._checkReaderRemoved(0)
self._checkClose(0)
def test_ClientConnection_init1(self):
# create a good client connection
bc = self._makeClientConnection()
# check connector created and connection initialize
self.assertFalse(bc.connecting)
self.assertFalse(bc.isServer())
self._checkMakeClientConnection(1)
# check call to handler
self.assertFalse(bc.getHandler() is None)
self._checkConnectionStarted(1)
self._checkConnectionCompleted(1)
self._checkConnectionFailed(0)
# check call to event manager
self.assertIsNot(bc.em, None)
self._checkWriterAdded(0)
def test_ClientConnection_init2(self):
# raise connection in progress
def makeClientConnection(orig, self):
raise ConnectorInProgressException
with Patch(DummyConnector, makeClientConnection=makeClientConnection):
bc = self._makeClientConnection()
# check connector created and connection initialize
self.assertTrue(bc.connecting)
self.assertFalse(bc.isServer())
self._checkMakeClientConnection(1)
# check call to handler
self.assertFalse(bc.getHandler() is None)
self._checkConnectionStarted(1)
self._checkConnectionCompleted(0)
self._checkConnectionFailed(0)
# check call to event manager
self.assertIsNot(bc.em, None)
self._checkWriterAdded(1)
def test_ClientConnection_init3(self):
# raise another error, connection must fail
def makeClientConnection(orig, self):
raise ConnectorException
with Patch(DummyConnector, makeClientConnection=makeClientConnection):
self.assertRaises(ConnectorException, self._makeClientConnection)
# since the exception was raised, the connection is not created
# check call to handler
self._checkConnectionStarted(1)
self._checkConnectionCompleted(0)
self._checkConnectionFailed(1)
# check call to event manager
self._checkWriterAdded(0)
def test_ClientConnection_writable1(self):
# with a non connecting connection, will call parent's method
with Patch(DummyConnector, send=lambda orig, self, data: len(data)), \
Patch(DummyConnector,
makeClientConnection=lambda orig, self: "OK") as p:
bc = self._makeClientConnection()
p.revert()
# check connector created and connection initialize
self.assertFalse(bc.connecting)
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
self.assertTrue(bc.pending())
self.assertFalse(bc.aborted)
# call
self._checkConnectionCompleted(1)
bc.writable()
self.assertFalse(bc.pending())
self.assertFalse(bc.aborted)
self.assertFalse(bc.connecting)
self._checkSend(1, "testdata")
self._checkConnectionClosed(0)
self._checkConnectionCompleted(1)
self._checkConnectionFailed(0)
self._checkUnregistered(0)
self._checkWriterRemoved(1)
self._checkReaderRemoved(0)
self._checkClose(0)
def test_ClientConnection_writable2(self):
# with a connecting connection, must not call parent's method
# with errors, close connection
with Patch(DummyConnector, getError=lambda orig, self: True):
bc = self._makeClientConnection()
# check connector created and connection initialize
self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"]
self.assertTrue(bc.pending())
self.assertFalse(bc.aborted)
# call
self._checkConnectionCompleted(1)
bc.writable()
self.assertFalse(bc.connecting)
self.assertFalse(bc.pending())
self.assertFalse(bc.aborted)
self._checkWriteBuf(bc, '')
self._checkConnectionClosed(1)
self._checkConnectionCompleted(1)
self._checkConnectionFailed(0)
self._checkUnregistered(1)
def test_14_ServerConnection(self):
bc = self._makeServerConnection()
self.assertEqual(bc.getAddress(), ("127.0.0.7", 93413))
self._checkReadBuf(bc, '')
self._checkWriteBuf(bc, '')
self.assertEqual(bc.cur_id, 0)
self.assertFalse(bc.aborted)
# test uuid
self.assertEqual(bc.uuid, None)
self.assertEqual(bc.getUUID(), None)
uuid = self.getNewUUID(None)
bc.setUUID(uuid)
self.assertEqual(bc.getUUID(), uuid)
# test next id
cur_id = bc.cur_id
next_id = bc._getNextId()
self.assertEqual(next_id, cur_id)
next_id = bc._getNextId()
self.assertTrue(next_id > cur_id)
# test overflow of next id
bc.cur_id = 0xffffffff
next_id = bc._getNextId()
self.assertEqual(next_id, 0xffffffff)
next_id = bc._getNextId()
self.assertEqual(next_id, 0)
def test_15_Timeout(self):
# NOTE: This method uses ping/pong packets only because MT connection
# don't accept any other packet without specifying a queue. # don't accept any other packet without specifying a queue.
self.handler = EventHandler(self.app) self.handler = EventHandler(self.app)
conn = self._makeClientConnection() conn = self._makeClientConnection()
...@@ -692,7 +86,6 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -692,7 +86,6 @@ class ConnectionTests(NeoUnitTestBase):
((.1, None, .5, 1), (1.5, 0)), ((.1, None, .5, 1), (1.5, 0)),
) )
from neo.lib import connection
def set_time(t): def set_time(t):
connection.time = lambda: int(CRITICAL_TIMEOUT * (1000 + t)) connection.time = lambda: int(CRITICAL_TIMEOUT * (1000 + t))
closed = [] closed = []
...@@ -700,7 +93,8 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -700,7 +93,8 @@ class ConnectionTests(NeoUnitTestBase):
def answer(packet_id): def answer(packet_id):
p = Packets.Pong() p = Packets.Pong()
p.setId(packet_id) p.setId(packet_id)
conn.connector.receive = [''.join(p.encode())].pop conn.connector.receive = lambda read_buf: \
read_buf.append(''.join(p.encode()))
conn.readable() conn.readable()
checkTimeout() checkTimeout()
conn.process() conn.process()
......
...@@ -31,8 +31,7 @@ import neo.client.app, neo.neoctl.app ...@@ -31,8 +31,7 @@ import neo.client.app, neo.neoctl.app
from neo.client import Storage from neo.client import Storage
from neo.lib import logging from neo.lib import logging
from neo.lib.connection import BaseConnection, Connection from neo.lib.connection import BaseConnection, Connection
from neo.lib.connector import SocketConnector, \ from neo.lib.connector import SocketConnector, ConnectorException
ConnectorConnectionRefusedException
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
...@@ -322,7 +321,7 @@ class ServerNode(Node): ...@@ -322,7 +321,7 @@ class ServerNode(Node):
try: try:
return self.listening_conn.getAddress() return self.listening_conn.getAddress()
except AttributeError: except AttributeError:
raise ConnectorConnectionRefusedException raise ConnectorException
class AdminApplication(ServerNode, neo.admin.app.Application): class AdminApplication(ServerNode, neo.admin.app.Application):
pass pass
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment