Commit eef52c27 authored by Julien Muchembled

Tickless poll loop, for lowest latency and cpu usage

With this patch, the epoll-based event manager is no longer woken up every second
to check whether a timeout has expired. The Connection API is changed so that the
poll loop can ask each connection for its next timeout and wait only until the
smallest one.
parent fd0b9c98
...@@ -42,8 +42,6 @@ class _ThreadedPoll(Thread): ...@@ -42,8 +42,6 @@ class _ThreadedPoll(Thread):
try: try:
while 1: while 1:
try: try:
# XXX: Delay can't be infinite here, because we need
# to check connection timeouts.
self.em.poll(1) self.em.poll(1)
except Exception: except Exception:
log(ERROR, 'poll raised, retrying', exc_info=1) log(ERROR, 'poll raised, retrying', exc_info=1)
......
...@@ -225,7 +225,7 @@ class BaseConnection(object): ...@@ -225,7 +225,7 @@ class BaseConnection(object):
def cancelRequests(self, *args, **kw): def cancelRequests(self, *args, **kw):
return self._handlers.cancelRequests(self, *args, **kw) return self._handlers.cancelRequests(self, *args, **kw)
def checkTimeout(self, t): def getTimeout(self):
pass pass
def lockWrapper(self, func): def lockWrapper(self, func):
...@@ -351,7 +351,8 @@ class Connection(BaseConnection): ...@@ -351,7 +351,8 @@ class Connection(BaseConnection):
client = False client = False
server = False server = False
peer_id = None peer_id = None
_base_timeout = None _next_timeout = None
_timeout = 0
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
...@@ -428,22 +429,23 @@ class Connection(BaseConnection): ...@@ -428,22 +429,23 @@ class Connection(BaseConnection):
def updateTimeout(self, t=None): def updateTimeout(self, t=None):
if not self._queue: if not self._queue:
if t: if not t:
self._base_timeout = t t = self._next_timeout - self._timeout
self._timeout = self._handlers.getNextTimeout() or self.KEEP_ALIVE self._timeout = self._handlers.getNextTimeout() or self.KEEP_ALIVE
self._next_timeout = t + self._timeout
def getTimeout(self):
if not self._queue:
return self._next_timeout
def checkTimeout(self, t): def onTimeout(self):
# first make sure we don't timeout on answers we already received
if self._base_timeout and not self._queue:
if self._timeout <= t - self._base_timeout:
handlers = self._handlers handlers = self._handlers
if handlers.isPending(): if handlers.isPending():
msg_id = handlers.timeout(self) msg_id = handlers.timeout(self)
if msg_id is None: if msg_id is None:
self._base_timeout = t self._next_timeout = time() + self._timeout
else: else:
logging.info('timeout for #0x%08x with %r', logging.info('timeout for #0x%08x with %r', msg_id, self)
msg_id, self)
self.close() self.close()
else: else:
self.idle() self.idle()
...@@ -544,8 +546,8 @@ class Connection(BaseConnection): ...@@ -544,8 +546,8 @@ class Connection(BaseConnection):
# try to reenable polling for writing. # try to reenable polling for writing.
self.write_buf[:] = '', self.write_buf[:] = '',
self.em.unregister(self, check_timeout=True) self.em.unregister(self, check_timeout=True)
self.checkTimeout = self.lockWrapper(lambda t: self.getTimeout = lambda: connect_limit
t < connect_limit or self._delayed_closure()) self.onTimeout = self.lockWrapper(self._delayed_closure)
self.readable = self.writable = lambda: None self.readable = self.writable = lambda: None
else: else:
connect_limit = t + 1 connect_limit = t + 1
...@@ -575,7 +577,8 @@ class Connection(BaseConnection): ...@@ -575,7 +577,8 @@ class Connection(BaseConnection):
logging.debug('Connection %r closed in recv', self.connector) logging.debug('Connection %r closed in recv', self.connector)
self._closure() self._closure()
return return
self._base_timeout = time() # last known remote activity # last known remote activity
self._next_timeout = time() + self._timeout
self.read_buf.append(data) self.read_buf.append(data)
def _send(self): def _send(self):
...@@ -639,7 +642,11 @@ class Connection(BaseConnection): ...@@ -639,7 +642,11 @@ class Connection(BaseConnection):
handlers = self._handlers handlers = self._handlers
t = None if handlers.isPending() else time() t = None if handlers.isPending() else time()
handlers.emit(packet, timeout, on_timeout, kw) handlers.emit(packet, timeout, on_timeout, kw)
if not self._queue:
next_timeout = self._next_timeout
self.updateTimeout(t) self.updateTimeout(t)
if self._next_timeout < next_timeout:
self.em.wakeup()
return msg_id return msg_id
@not_closed @not_closed
...@@ -717,7 +724,7 @@ class MTConnectionType(type): ...@@ -717,7 +724,7 @@ class MTConnectionType(type):
if __debug__: if __debug__:
for name in 'analyse', 'answer': for name in 'analyse', 'answer':
setattr(cls, name, cls.lockCheckWrapper(name)) setattr(cls, name, cls.lockCheckWrapper(name))
for name in ('close', 'checkTimeout', 'notify', for name in ('close', 'notify', 'onTimeout',
'process', 'readable', 'writable'): 'process', 'readable', 'writable'):
setattr(cls, name, cls.__class__.lockWrapper(cls, name)) setattr(cls, name, cls.__class__.lockWrapper(cls, name))
...@@ -775,5 +782,9 @@ class MTClientConnection(ClientConnection): ...@@ -775,5 +782,9 @@ class MTClientConnection(ClientConnection):
handlers = self._handlers handlers = self._handlers
t = None if handlers.isPending() else time() t = None if handlers.isPending() else time()
handlers.emit(packet, timeout, on_timeout, kw) handlers.emit(packet, timeout, on_timeout, kw)
if not self._queue:
next_timeout = self._next_timeout
self.updateTimeout(t) self.updateTimeout(t)
if self._next_timeout < next_timeout:
self.em.wakeup()
return msg_id return msg_id
...@@ -123,6 +123,17 @@ class EpollEventManager(object): ...@@ -123,6 +123,17 @@ class EpollEventManager(object):
self._poll(timeout=0) self._poll(timeout=0)
def _poll(self, timeout=1): def _poll(self, timeout=1):
if timeout:
timeout = None
for conn in self.connection_dict.itervalues():
t = conn.getTimeout()
if t and (timeout is None or t < timeout):
timeout = t
timeout_conn = conn
# Make sure epoll_wait does not return too early, because it has a
# granularity of 1ms and Python 2.7 rounds the timeout towards zero.
# See also https://bugs.python.org/issue20452 (fixed in Python 3).
timeout = .001 + max(0, timeout - time()) if timeout else -1
try: try:
event_list = self.epoll.poll(timeout) event_list = self.epoll.poll(timeout)
except IOError, exc: except IOError, exc:
...@@ -131,7 +142,11 @@ class EpollEventManager(object): ...@@ -131,7 +142,11 @@ class EpollEventManager(object):
exc.errno) exc.errno)
elif exc.errno != EINTR: elif exc.errno != EINTR:
raise raise
event_list = () return
if not event_list:
if timeout > 0:
timeout_conn.onTimeout()
return
wlist = [] wlist = []
elist = [] elist = []
for fd, event in event_list: for fd, event in event_list:
...@@ -168,10 +183,6 @@ class EpollEventManager(object): ...@@ -168,10 +183,6 @@ class EpollEventManager(object):
if conn.readable(): if conn.readable():
self._addPendingConnection(conn) self._addPendingConnection(conn)
t = time()
for conn in self.connection_dict.values():
conn.checkTimeout(t)
def wakeup(self, exit=False): def wakeup(self, exit=False):
with self._trigger_lock: with self._trigger_lock:
self._trigger_exit |= exit self._trigger_exit |= exit
......
...@@ -789,8 +789,12 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -789,8 +789,12 @@ class ConnectionTests(NeoUnitTestBase):
p.setId(packet_id) p.setId(packet_id)
conn.connector.receive = [''.join(p.encode())].pop conn.connector.receive = [''.join(p.encode())].pop
conn.readable() conn.readable()
conn.checkTimeout(connection.time()) checkTimeout()
conn.process() conn.process()
def checkTimeout():
timeout = conn.getTimeout()
if timeout and timeout <= connection.time():
conn.onTimeout()
try: try:
for use_case, expected in use_case_list: for use_case, expected in use_case_list:
i = iter(use_case) i = iter(use_case)
...@@ -801,7 +805,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -801,7 +805,7 @@ class ConnectionTests(NeoUnitTestBase):
conn.ask(Packets.Ping()) conn.ask(Packets.Ping())
for t in i: for t in i:
set_time(t) set_time(t)
conn.checkTimeout(connection.time()) checkTimeout()
packet_id = i.next() packet_id = i.next()
if packet_id is None: if packet_id is None:
conn.ask(Packets.Ping()) conn.ask(Packets.Ping())
...@@ -810,11 +814,11 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -810,11 +814,11 @@ class ConnectionTests(NeoUnitTestBase):
i = iter(expected) i = iter(expected)
for t in i: for t in i:
set_time(t - .1) set_time(t - .1)
conn.checkTimeout(connection.time()) checkTimeout()
set_time(t) set_time(t)
# this test method relies on the fact that only # this test method relies on the fact that only
# conn.close is called in case of a timeout # conn.close is called in case of a timeout
conn.checkTimeout(connection.time()) checkTimeout()
self.assertEqual(closed.pop(), connection.time()) self.assertEqual(closed.pop(), connection.time())
answer(i.next()) answer(i.next())
self.assertFalse(conn.isPending()) self.assertFalse(conn.isPending())
......
...@@ -96,12 +96,12 @@ class EventTests(NeoUnitTestBase): ...@@ -96,12 +96,12 @@ class EventTests(NeoUnitTestBase):
(r_connector.getDescriptor(), EPOLLIN), (r_connector.getDescriptor(), EPOLLIN),
(w_connector.getDescriptor(), EPOLLOUT), (w_connector.getDescriptor(), EPOLLOUT),
)}) )})
em.poll(timeout=10) em.poll(timeout=1)
# check it called poll on epoll # check it called poll on epoll
self.assertEqual(len(em.epoll.mockGetNamedCalls("poll")), 1) self.assertEqual(len(em.epoll.mockGetNamedCalls("poll")), 1)
call = em.epoll.mockGetNamedCalls("poll")[0] call = em.epoll.mockGetNamedCalls("poll")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data, 10) self.assertEqual(data, -1)
# need to rebuild completely this test and the the packet queue # need to rebuild completely this test and the the packet queue
# check readable conn # check readable conn
#self.assertEqual(len(r_conn.mockGetNamedCalls("readable")), 1) #self.assertEqual(len(r_conn.mockGetNamedCalls("readable")), 1)
......
...@@ -137,7 +137,7 @@ class SerializedEventManager(EventManager): ...@@ -137,7 +137,7 @@ class SerializedEventManager(EventManager):
def _poll(self, timeout=1): def _poll(self, timeout=1):
if self._pending_processing: if self._pending_processing:
assert timeout <= 0 assert timeout == 0, timeout
elif 0 == self._timeout == timeout == Serialized.pending == len( elif 0 == self._timeout == timeout == Serialized.pending == len(
self.writer_set): self.writer_set):
return return
...@@ -365,7 +365,7 @@ class NeoCTL(neo.neoctl.app.NeoCTL): ...@@ -365,7 +365,7 @@ class NeoCTL(neo.neoctl.app.NeoCTL):
@SerializedEventManager.decorate @SerializedEventManager.decorate
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
super(NeoCTL, self).__init__(*args, **kw) super(NeoCTL, self).__init__(*args, **kw)
self.em._timeout = -1 self.em._timeout = 1
class LoggerThreadName(str): class LoggerThreadName(str):
...@@ -466,7 +466,7 @@ class ConnectionFilter(object): ...@@ -466,7 +466,7 @@ class ConnectionFilter(object):
class NEOCluster(object): class NEOCluster(object):
BaseConnection_checkTimeout = staticmethod(BaseConnection.checkTimeout) BaseConnection_getTimeout = staticmethod(BaseConnection.getTimeout)
SocketConnector_makeClientConnection = staticmethod( SocketConnector_makeClientConnection = staticmethod(
SocketConnector.makeClientConnection) SocketConnector.makeClientConnection)
SocketConnector_makeListeningConnection = staticmethod( SocketConnector_makeListeningConnection = staticmethod(
...@@ -517,7 +517,7 @@ class NEOCluster(object): ...@@ -517,7 +517,7 @@ class NEOCluster(object):
# TODO: 'sleep' should 'tic' in a smart way, so that storages can be # TODO: 'sleep' should 'tic' in a smart way, so that storages can be
# safely started even if the cluster isn't. # safely started even if the cluster isn't.
bootstrap.sleep = lambda seconds: None bootstrap.sleep = lambda seconds: None
BaseConnection.checkTimeout = lambda self, t: None BaseConnection.getTimeout = lambda self: None
SocketConnector.makeClientConnection = makeClientConnection SocketConnector.makeClientConnection = makeClientConnection
SocketConnector.makeListeningConnection = lambda self, addr: \ SocketConnector.makeListeningConnection = lambda self, addr: \
cls.SocketConnector_makeListeningConnection(self, BIND) cls.SocketConnector_makeListeningConnection(self, BIND)
...@@ -533,7 +533,7 @@ class NEOCluster(object): ...@@ -533,7 +533,7 @@ class NEOCluster(object):
if cls._patch_count: if cls._patch_count:
return return
bootstrap.sleep = time.sleep bootstrap.sleep = time.sleep
BaseConnection.checkTimeout = cls.BaseConnection_checkTimeout BaseConnection.getTimeout = cls.BaseConnection_getTimeout
SocketConnector.makeClientConnection = \ SocketConnector.makeClientConnection = \
cls.SocketConnector_makeClientConnection cls.SocketConnector_makeClientConnection
SocketConnector.makeListeningConnection = \ SocketConnector.makeListeningConnection = \
......
...@@ -22,6 +22,7 @@ from functools import wraps ...@@ -22,6 +22,7 @@ from functools import wraps
from neo.lib import logging from neo.lib import logging
from neo.storage.checker import CHECK_COUNT from neo.storage.checker import CHECK_COUNT
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64 from neo.lib.util import p64
...@@ -249,17 +250,22 @@ class ReplicationTests(NEOThreadedTest): ...@@ -249,17 +250,22 @@ class ReplicationTests(NEOThreadedTest):
""" """
conn, = backup.master.getConnectionList(backup.upstream.master) conn, = backup.master.getConnectionList(backup.upstream.master)
# trigger ping # trigger ping
conn.updateTimeout(1)
self.assertFalse(conn.isPending()) self.assertFalse(conn.isPending())
conn.checkTimeout(time.time()) conn.onTimeout()
self.assertTrue(conn.isPending()) self.assertTrue(conn.isPending())
# force ping to have expired # force ping to have expired
conn.updateTimeout(1)
# connection will be closed before upstream master has time # connection will be closed before upstream master has time
# to answer # to answer
def _poll(orig, self, timeout):
if backup.master.em is self:
p.revert()
conn.onTimeout()
else:
orig(self, timeout)
with Patch(EventManager, _poll=_poll) as p:
backup.tic(force=1) backup.tic(force=1)
new_conn, = backup.master.getConnectionList(backup.upstream.master) new_conn, = backup.master.getConnectionList(backup.upstream.master)
self.assertFalse(new_conn is conn) self.assertIsNot(new_conn, conn)
@backup_test() @backup_test()
def testBackupUpstreamStorageDead(self, backup): def testBackupUpstreamStorageDead(self, backup):
...@@ -277,11 +283,12 @@ class ReplicationTests(NEOThreadedTest): ...@@ -277,11 +283,12 @@ class ReplicationTests(NEOThreadedTest):
upstream.storage.listening_conn.close() upstream.storage.listening_conn.close()
Serialized.tic(); self.assertEqual(count[0], 0) Serialized.tic(); self.assertEqual(count[0], 0)
Serialized.tic(); count[0] or Serialized.tic() Serialized.tic(); count[0] or Serialized.tic()
t = time.time()
# XXX: review API for checking timeouts
backup.storage.em._timeout = 1
Serialized.tic(); self.assertEqual(count[0], 2) Serialized.tic(); self.assertEqual(count[0], 2)
Serialized.tic(); self.assertEqual(count[0], 2)
time.sleep(1.1)
Serialized.tic(); self.assertEqual(count[0], 3)
Serialized.tic(); self.assertEqual(count[0], 3) Serialized.tic(); self.assertEqual(count[0], 3)
self.assertTrue(t + 1 <= time.time())
@backup_test() @backup_test()
def testBackupDelayedUnlockTransaction(self, backup): def testBackupDelayedUnlockTransaction(self, backup):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment