Commit 06a64d80 authored by Julien Muchembled's avatar Julien Muchembled

client: fix spurious connection timeouts

This fixes a regression caused by
commit eef52c27
parent f180b00e
......@@ -422,6 +422,12 @@ class Connection(BaseConnection):
def onTimeout(self):
handlers = self._handlers
if handlers.isPending():
# It is possible that another thread used ask() while getting a
# timeout from epoll, so we must check again the value of
# _next_timeout (we know that _queue is still empty).
# Although this test is only useful for MTClientConnection,
# it's not worth complicating the code more.
if self._next_timeout <= time():
msg_id = handlers.timeout(self)
if msg_id is None:
self._next_timeout = time() + self._timeout
......
......@@ -30,8 +30,10 @@ import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app
from neo.client import Storage
from neo.lib import logging
from neo.lib.connection import BaseConnection, Connection
from neo.lib.connection import BaseConnection, \
ClientConnection, Connection, ListeningConnection
from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes
from neo.lib.util import cached_property, parseMasterList, p64
......@@ -829,6 +831,21 @@ class NEOThreadedTest(NeoTestBase):
tic = Serialized.tic
def getLoopbackConnection(self):
app = MasterApplication(getSSL=NEOCluster.SSL,
getReplicas=0, getPartitions=1)
handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(),
uuid=app.uuid)
conn = ClientConnection.__new__(ClientConnection)
def reset():
conn.__dict__.clear()
conn.__init__(app, handler, node)
conn.reset = reset
reset()
return conn
def getUnpickler(self, conn):
reader = conn._reader
def unpickler(data, compression=False):
......
......@@ -1176,6 +1176,19 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def testConnectionTimeout(self):
conn = self.getLoopbackConnection()
conn.KEEP_ALIVE
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
def onTimeout(orig):
conn.idle()
orig()
with Patch(conn, onTimeout=onTimeout):
conn.em.poll(1)
self.assertFalse(conn.isClosed())
if __name__ == "__main__":
unittest.main()
......@@ -253,6 +253,7 @@ class ReplicationTests(NEOThreadedTest):
def _poll(orig, self, blocking):
if backup.master.em is self:
p.revert()
conn._next_timeout = 0
conn.onTimeout()
else:
orig(self, blocking)
......
......@@ -15,11 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from neo.lib.connection import ClientConnection, ListeningConnection
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets
from .. import SSL
from . import MasterApplication, NEOCluster, test, testReplication
from . import NEOCluster, test, testReplication
class SSLMixin:
......@@ -38,27 +36,25 @@ class SSLTests(SSLMixin, test.Test):
testDeadlockAvoidance = testStorageFailureDuringTpcFinish = None
def testAbortConnection(self):
app = MasterApplication(getSSL=SSL, getReplicas=0, getPartitions=1)
handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(),
uuid=app.uuid)
for after_handshake in 1, 0:
conn = ClientConnection(app, handler, node)
try:
conn.reset()
except UnboundLocalError:
conn = self.getLoopbackConnection()
conn.ask(Packets.Ping())
connector = conn.getConnector()
del connector.connect_limit[connector.addr]
app.em.poll(1)
conn.em.poll(1)
self.assertTrue(isinstance(connector,
connector.SSLHandshakeConnectorClass))
self.assertNotIn(connector.getDescriptor(), app.em.writer_set)
self.assertNotIn(connector.getDescriptor(), conn.em.writer_set)
if after_handshake:
while not isinstance(connector, connector.SSLConnectorClass):
app.em.poll(1)
conn.em.poll(1)
conn.abort()
fd = connector.getDescriptor()
while fd in app.em.reader_set:
app.em.poll(1)
while fd in conn.em.reader_set:
conn.em.poll(1)
self.assertIs(conn.getConnector(), None)
class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment