Commit a4731a0c authored by Julien Muchembled's avatar Julien Muchembled

Fix invalid processing of unregistered connections

This could happen if a file descriptor was reallocated by the kernel.
parent ed50edca
...@@ -28,6 +28,10 @@ class EpollEventManager(object): ...@@ -28,6 +28,10 @@ class EpollEventManager(object):
def __init__(self): def __init__(self):
self.connection_dict = {} self.connection_dict = {}
# Initialize a dummy 'unregistered' for the very rare case a registered
# connection is closed before the first call to poll. We don't care
# leaking a few integers for connections closed between 2 polls.
self.unregistered = []
self.reader_set = set() self.reader_set = set()
self.writer_set = set() self.writer_set = set()
self.epoll = epoll() self.epoll = epoll()
...@@ -95,6 +99,7 @@ class EpollEventManager(object): ...@@ -95,6 +99,7 @@ class EpollEventManager(object):
self.writer_set.discard(fd) self.writer_set.discard(fd)
if not check_timeout: if not check_timeout:
del self.connection_dict[fd] del self.connection_dict[fd]
self.unregistered.append(fd)
def isIdle(self): def isIdle(self):
return not (self._pending_processing or self.writer_set) return not (self._pending_processing or self.writer_set)
...@@ -143,45 +148,40 @@ class EpollEventManager(object): ...@@ -143,45 +148,40 @@ class EpollEventManager(object):
elif exc.errno != EINTR: elif exc.errno != EINTR:
raise raise
return return
if not event_list: if event_list:
if blocking > 0: self.unregistered = unregistered = []
timeout_conn.onTimeout() wlist = []
return elist = []
wlist = [] for fd, event in event_list:
elist = [] if event & EPOLLIN:
for fd, event in event_list: conn = self.connection_dict[fd]
if event & EPOLLIN: if conn.readable():
conn = self.connection_dict[fd] self._addPendingConnection(conn)
if conn.readable(): if event & EPOLLOUT:
self._addPendingConnection(conn) wlist.append(fd)
if event & EPOLLOUT: if event & (EPOLLERR | EPOLLHUP):
wlist.append(fd) elist.append(fd)
if event & (EPOLLERR | EPOLLHUP): for fd in wlist:
elist.append(fd) if fd not in unregistered:
self.connection_dict[fd].writable()
for fd in wlist: for fd in elist:
# This can fail, if a connection is closed in readable(). if fd in unregistered:
try: continue
conn = self.connection_dict[fd] try:
except KeyError: conn = self.connection_dict[fd]
continue except KeyError:
conn.writable() assert fd == self._trigger_fd, fd
for fd in elist:
# This can fail, if a connection is closed in previous calls to
# readable() or writable().
try:
conn = self.connection_dict[fd]
except KeyError:
if fd == self._trigger_fd:
with self._trigger_lock: with self._trigger_lock:
self.epoll.unregister(fd) self.epoll.unregister(fd)
if self._trigger_exit: if self._trigger_exit:
del self._trigger_exit del self._trigger_exit
thread.exit() thread.exit()
continue continue
if conn.readable(): if conn.readable():
self._addPendingConnection(conn) self._addPendingConnection(conn)
elif blocking > 0:
logging.debug('timeout triggered for %r', timeout_conn)
timeout_conn.onTimeout()
def wakeup(self, exit=False): def wakeup(self, exit=False):
with self._trigger_lock: with self._trigger_lock:
......
...@@ -287,6 +287,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -287,6 +287,7 @@ class ReplicationTests(NEOThreadedTest):
# XXX: review API for checking timeouts # XXX: review API for checking timeouts
backup.storage.em._blocking = 1 backup.storage.em._blocking = 1
Serialized.tic(); self.assertEqual(count[0], 2) Serialized.tic(); self.assertEqual(count[0], 2)
Serialized.tic(); self.assertEqual(count[0], 2)
Serialized.tic(); self.assertEqual(count[0], 3) Serialized.tic(); self.assertEqual(count[0], 3)
self.assertTrue(t + 1 <= time.time()) self.assertTrue(t + 1 <= time.time())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment