Commit d68e9053 authored by Julien Muchembled's avatar Julien Muchembled

client: fix undetected disconnections to storage nodes during commit

When a client-storage connection breaks, the storage node discards data of all
ongoing transactions by the client. Therefore, a reconnection within the
context of the transaction is wrong, as it could lead to partially-written
transactions.

This fixes cases where such reconnection happened. The biggest issue was that
the mechanism to dispatch disconnection events only works when waiting for an
answer.

The client can still reconnect for other purposes but the new connection won't
be reused by transactions that already involved the storage node.
parent 59698faa
...@@ -515,15 +515,12 @@ class Application(ThreadedApplication): ...@@ -515,15 +515,12 @@ class Application(ThreadedApplication):
self._store(txn_context, oid, serial, data) self._store(txn_context, oid, serial, data)
def _askStorageForWrite(self, txn_context, uuid, packet): def _askStorageForWrite(self, txn_context, uuid, packet):
node = self.nm.getByUUID(uuid) conn = txn_context.conn_dict[uuid]
if node is not None:
conn = self.cp.getConnForNode(node)
if conn is not None:
try: try:
return conn.ask(packet, queue=txn_context.queue) return conn.ask(packet, queue=txn_context.queue)
except ConnectionClosed: except ConnectionClosed:
pass
txn_context.involved_nodes[uuid] = 2 txn_context.involved_nodes[uuid] = 2
del txn_context.conn_dict[uuid]
def waitResponses(self, queue): def waitResponses(self, queue):
"""Wait for all requests to be answered (or their connection to be """Wait for all requests to be answered (or their connection to be
...@@ -600,10 +597,10 @@ class Application(ThreadedApplication): ...@@ -600,10 +597,10 @@ class Application(ThreadedApplication):
# condition. The consequence would be that storage nodes lock oids # condition. The consequence would be that storage nodes lock oids
# forever. # forever.
p = Packets.AbortTransaction(txn_context.ttid, ()) p = Packets.AbortTransaction(txn_context.ttid, ())
for uuid in txn_context.involved_nodes: for conn in txn_context.conn_dict.itervalues():
try: try:
self.cp.connection_dict[uuid].send(p) conn.send(p)
except (KeyError, ConnectionClosed): except ConnectionClosed:
pass pass
# Because we want to be sure that the involved nodes are notified, # Because we want to be sure that the involved nodes are notified,
# we still have to send the full list to the master. Most of the # we still have to send the full list to the master. Most of the
......
...@@ -18,6 +18,7 @@ from ZODB.TimeStamp import TimeStamp ...@@ -18,6 +18,7 @@ from ZODB.TimeStamp import TimeStamp
from neo.lib import logging from neo.lib import logging
from neo.lib.compress import decompress_list from neo.lib.compress import decompress_list
from neo.lib.connection import ConnectionClosed
from neo.lib.protocol import Packets, uuid_str from neo.lib.protocol import Packets, uuid_str
from neo.lib.util import dump, makeChecksum from neo.lib.util import dump, makeChecksum
from neo.lib.exception import NodeNotReady from neo.lib.exception import NodeNotReady
...@@ -95,7 +96,9 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -95,7 +96,9 @@ class StorageAnswersHandler(AnswerBaseHandler):
conn.ask(Packets.AskRebaseObject(ttid, oid), conn.ask(Packets.AskRebaseObject(ttid, oid),
queue=queue, oid=oid) queue=queue, oid=oid)
except ConnectionClosed: except ConnectionClosed:
txn_context.involved_nodes[conn.getUUID()] = 2 uuid = conn.getUUID()
txn_context.involved_nodes[uuid] = 2
del txn_context.conn_dict[uuid]
def answerRebaseObject(self, conn, conflict, oid): def answerRebaseObject(self, conn, conflict, oid):
if conflict: if conflict:
...@@ -107,8 +110,10 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -107,8 +110,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
cached = txn_context.cache_dict.pop(oid) cached = txn_context.cache_dict.pop(oid)
except KeyError: except KeyError:
if resolved: if resolved:
# We should still be waiting for an answer from this node. # We should still be waiting for an answer from this node,
assert conn.uuid in txn_context.data_dict[oid][2] # unless we lost connection.
assert conn.uuid in txn_context.data_dict[oid][2] or \
txn_context.involved_nodes[conn.uuid] == 2
return return
assert oid in txn_context.data_dict assert oid in txn_context.data_dict
if serial <= txn_context.conflict_dict.get(oid, ''): if serial <= txn_context.conflict_dict.get(oid, ''):
...@@ -136,7 +141,7 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -136,7 +141,7 @@ class StorageAnswersHandler(AnswerBaseHandler):
if cached: if cached:
assert cached == data assert cached == data
txn_context.cache_size -= size txn_context.cache_size -= size
txn_context.data_dict[oid] = data, serial, None txn_context.data_dict[oid] = data, serial, []
txn_context.conflict_dict[oid] = conflict txn_context.conflict_dict[oid] = conflict
def answerStoreTransaction(self, conn): def answerStoreTransaction(self, conn):
...@@ -145,6 +150,7 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -145,6 +150,7 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerVoteTransaction = answerStoreTransaction answerVoteTransaction = answerStoreTransaction
def connectionClosed(self, conn): def connectionClosed(self, conn):
# only called if we were waiting for an answer
txn_context = self.app.getHandlerData() txn_context = self.app.getHandlerData()
if type(txn_context) is Transaction: if type(txn_context) is Transaction:
txn_context.nodeLost(self.app, conn.getUUID()) txn_context.nodeLost(self.app, conn.getUUID())
......
...@@ -52,6 +52,7 @@ class Transaction(object): ...@@ -52,6 +52,7 @@ class Transaction(object):
# if the id is still known by the NodeManager. # if the id is still known by the NodeManager.
# status: 0 -> check only, 1 -> store, 2 -> failed # status: 0 -> check only, 1 -> store, 2 -> failed
self.involved_nodes = {} # {node_id: status} self.involved_nodes = {} # {node_id: status}
self.conn_dict = {} # {node_id: connection}
def wakeup(self, conn): def wakeup(self, conn):
self.queue.put((conn, _WakeupPacket, {})) self.queue.put((conn, _WakeupPacket, {}))
...@@ -69,7 +70,10 @@ class Transaction(object): ...@@ -69,7 +70,10 @@ class Transaction(object):
involved[uuid] = store involved[uuid] = store
elif status > 1: elif status > 1:
continue continue
conn = app.cp.getConnForNode(node) if status < 0:
conn = self.conn_dict[uuid] = app.cp.getConnForNode(node)
else:
conn = self.conn_dict[uuid]
if conn is not None: if conn is not None:
try: try:
if status < 0 and self.locking_tid and 'oid' in kw: if status < 0 and self.locking_tid and 'oid' in kw:
...@@ -84,6 +88,7 @@ class Transaction(object): ...@@ -84,6 +88,7 @@ class Transaction(object):
continue continue
except ConnectionClosed: except ConnectionClosed:
pass pass
del self.conn_dict[uuid]
involved[uuid] = 2 involved[uuid] = 2
if uuid_list: if uuid_list:
return uuid_list return uuid_list
...@@ -131,7 +136,9 @@ class Transaction(object): ...@@ -131,7 +136,9 @@ class Transaction(object):
self.cache_dict[oid] = data self.cache_dict[oid] = data
def nodeLost(self, app, uuid): def nodeLost(self, app, uuid):
# The following 2 lines are sometimes redundant with the 2 in write().
self.involved_nodes[uuid] = 2 self.involved_nodes[uuid] = 2
self.conn_dict.pop(uuid, None)
for oid in list(self.data_dict): for oid in list(self.data_dict):
self.written(app, uuid, oid) self.written(app, uuid, oid)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment