Commit db8db123 authored by Grégory Wisniewski's avatar Grégory Wisniewski

MTConnection handle local queue to unify ask() prototype.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1925 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f7707668
......@@ -246,7 +246,7 @@ class Application(object):
def _askStorage(self, conn, packet):
""" Send a request to a storage node and process it's answer """
try:
msg_id = conn.ask(self.local_var.queue, packet)
msg_id = conn.ask(packet)
finally:
# assume that the connection was already locked
conn.unlock()
......@@ -258,7 +258,7 @@ class Application(object):
conn = self._getMasterConnection()
conn.lock()
try:
msg_id = conn.ask(self.local_var.queue, packet)
msg_id = conn.ask(packet)
finally:
conn.unlock()
self._waitMessage(conn, msg_id, self.primary_handler)
......@@ -321,7 +321,8 @@ class Application(object):
self.trying_master_node = master_list[0]
index += 1
# Connect to master
conn = MTClientConnection(self.em, self.notifications_handler,
conn = MTClientConnection(self.local_var, self.em,
self.notifications_handler,
addr=self.trying_master_node.getAddress(),
connector=self.connector_handler(),
dispatcher=self.dispatcher)
......@@ -333,8 +334,7 @@ class Application(object):
logging.error('Connection to master node %s failed',
self.trying_master_node)
continue
msg_id = conn.ask(self.local_var.queue,
Packets.AskPrimary())
msg_id = conn.ask(Packets.AskPrimary())
finally:
conn.unlock()
try:
......@@ -358,7 +358,7 @@ class Application(object):
break
p = Packets.RequestIdentification(NodeTypes.CLIENT,
self.uuid, None, self.name)
msg_id = conn.ask(self.local_var.queue, p)
msg_id = conn.ask(p)
finally:
conn.unlock()
try:
......@@ -373,16 +373,14 @@ class Application(object):
if self.uuid is not None:
conn.lock()
try:
msg_id = conn.ask(self.local_var.queue,
Packets.AskNodeInformation())
msg_id = conn.ask(Packets.AskNodeInformation())
finally:
conn.unlock()
self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler)
conn.lock()
try:
msg_id = conn.ask(self.local_var.queue,
Packets.AskPartitionTable([]))
msg_id = conn.ask(Packets.AskPartitionTable([]))
finally:
conn.unlock()
self._waitMessage(conn, msg_id,
......@@ -600,14 +598,13 @@ class Application(object):
# Store data on each node
self.local_var.object_stored_counter_dict[oid] = 0
self.local_var.object_serial_dict[oid] = (serial, version)
local_queue = self.local_var.queue
for cell in cell_list:
conn = self.cp.getConnForCell(cell)
if conn is None:
continue
try:
try:
conn.ask(local_queue, p)
conn.ask(p)
finally:
conn.unlock()
except ConnectionClosed:
......@@ -882,8 +879,7 @@ class Application(object):
continue
try:
conn.ask(self.local_var.queue, Packets.AskTIDs(first, last,
INVALID_PARTITION))
conn.ask(Packets.AskTIDs(first, last, INVALID_PARTITION))
finally:
conn.unlock()
......
......@@ -50,7 +50,8 @@ class ConnectionPool(object):
while True:
logging.debug('trying to connect to %s - %s', node, node.getState())
app.setNodeReady()
conn = MTClientConnection(app.em, app.storage_event_handler, addr,
conn = MTClientConnection(app.local_var, app.em,
app.storage_event_handler, addr,
connector=app.connector_handler(), dispatcher=app.dispatcher)
conn.lock()
......@@ -62,7 +63,7 @@ class ConnectionPool(object):
p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name)
msg_id = conn.ask(app.local_var.queue, p)
msg_id = conn.ask(p)
finally:
conn.unlock()
......
......@@ -565,9 +565,10 @@ class ServerConnection(Connection):
class MTClientConnection(ClientConnection):
"""A Multithread-safe version of ClientConnection."""
def __init__(self, *args, **kwargs):
def __init__(self, local_var, *args, **kwargs):
# _lock is only here for lock debugging purposes. Do not use.
self._lock = lock = RLock()
self._local_var = local_var
self.acquire = lock.acquire
self.release = lock.release
self.dispatcher = kwargs.pop('dispatcher')
......@@ -600,10 +601,10 @@ class MTClientConnection(ClientConnection):
return super(MTClientConnection, self).notify(*args, **kw)
@lockCheckWrapper
def ask(self, queue, packet, timeout=CRITICAL_TIMEOUT):
def ask(self, packet, timeout=CRITICAL_TIMEOUT):
msg_id = self._getNextId()
packet.setId(msg_id)
self.dispatcher.register(self, msg_id, queue)
self.dispatcher.register(self, msg_id, self._local_var.queue)
self._addPacket(packet)
if not self._handlers.isPending():
self._timeout.update(time(), timeout=timeout)
......
......@@ -76,7 +76,7 @@ class ClientApplicationTests(NeoTestBase):
calls = conn.mockGetNamedCalls('ask')
self.assertEquals(len(calls), 1)
# client connection got queue as first parameter
packet = calls[0].getParam(1)
packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), packet_type)
if decode:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment