Commit edefaca7 authored by Julien Muchembled's avatar Julien Muchembled

client: add support for reconnection to master

This implementation proper cache invalidation.

Connection to master is also made optional to load from storage nodes, as long
as partition table contains up-to-date data (which is anyway unlikely to change
when there is no master).
parent 1ea04be7
...@@ -68,7 +68,6 @@ class Application(object): ...@@ -68,7 +68,6 @@ class Application(object):
self.dispatcher = Dispatcher(self.poll_thread) self.dispatcher = Dispatcher(self.poll_thread)
self.nm = NodeManager(dynamic_master_list) self.nm = NodeManager(dynamic_master_list)
self.cp = ConnectionPool(self) self.cp = ConnectionPool(self)
self.pt = None
self.master_conn = None self.master_conn = None
self.primary_master_node = None self.primary_master_node = None
self.trying_master_node = None self.trying_master_node = None
...@@ -117,6 +116,16 @@ class Application(object): ...@@ -117,6 +116,16 @@ class Application(object):
self.compress = compress self.compress = compress
registerLiveDebugger(on_log=self.log) registerLiveDebugger(on_log=self.log)
def __getattr__(self, attr):
if attr == 'pt':
self._getMasterConnection()
return self.__getattribute__(attr)
@property
def txn_contexts(self):
# do not iter lazily to avoid race condition
return self._txn_container.values
def getHandlerData(self): def getHandlerData(self):
return self._thread_container.answer return self._thread_container.answer
...@@ -241,13 +250,6 @@ class Application(object): ...@@ -241,13 +250,6 @@ class Application(object):
result = self.master_conn = self._connectToPrimaryNode() result = self.master_conn = self._connectToPrimaryNode()
return result return result
def getPartitionTable(self):
""" Return the partition table manager, reconnect the PMN if needed """
# this ensure the master connection is established and the partition
# table is up to date.
self._getMasterConnection()
return self.pt
def _connectToPrimaryNode(self): def _connectToPrimaryNode(self):
""" """
Lookup for the current primary master node Lookup for the current primary master node
...@@ -660,7 +662,6 @@ class Application(object): ...@@ -660,7 +662,6 @@ class Application(object):
ttid = txn_context['ttid'] ttid = txn_context['ttid']
# Store data on each node # Store data on each node
txn_stored_counter = 0
assert not txn_context['data_dict'], txn_context assert not txn_context['data_dict'], txn_context
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension), str(transaction.description), dumps(transaction._extension),
...@@ -674,20 +675,17 @@ class Application(object): ...@@ -674,20 +675,17 @@ class Application(object):
except ConnectionClosed: except ConnectionClosed:
continue continue
add_involved_nodes(node) add_involved_nodes(node)
txn_stored_counter += 1
# check at least one storage node accepted # check at least one storage node accepted
if txn_stored_counter == 0: if txn_context['involved_nodes']:
txn_context['voted'] = None
# We must not go further if connection to master was lost since
# tpc_begin, to lower the probability of failing during tpc_finish.
if 'error' in txn_context:
raise NEOStorageError(txn_context['error'])
return result
logging.error('tpc_vote failed') logging.error('tpc_vote failed')
raise NEOStorageError('tpc_vote failed') raise NEOStorageError('tpc_vote failed')
# Check if master connection is still alive.
# This is just here to lower the probability of detecting a problem
# in tpc_finish, as we should do our best to detect problem before
# tpc_finish.
self._getMasterConnection()
txn_context['txn_voted'] = True
return result
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
"""Abort current transaction.""" """Abort current transaction."""
...@@ -718,7 +716,7 @@ class Application(object): ...@@ -718,7 +716,7 @@ class Application(object):
def tpc_finish(self, transaction, tryToResolveConflict, f=None): def tpc_finish(self, transaction, tryToResolveConflict, f=None):
"""Finish current transaction.""" """Finish current transaction."""
txn_container = self._txn_container txn_container = self._txn_container
if not txn_container.get(transaction)['txn_voted']: if 'voted' not in txn_container.get(transaction):
self.tpc_vote(transaction, tryToResolveConflict) self.tpc_vote(transaction, tryToResolveConflict)
self._load_lock_acquire() self._load_lock_acquire()
try: try:
...@@ -735,15 +733,13 @@ class Application(object): ...@@ -735,15 +733,13 @@ class Application(object):
def undo(self, undone_tid, txn, tryToResolveConflict): def undo(self, undone_tid, txn, tryToResolveConflict):
txn_context = self._txn_container.get(txn) txn_context = self._txn_container.get(txn)
txn_info, txn_ext = self._getTransactionInformation(undone_tid) txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids'] txn_oid_list = txn_info['oids']
# Regroup objects per partition, to ask a minimum set of storage. # Regroup objects per partition, to ask a minimum set of storage.
partition_oid_dict = {} partition_oid_dict = {}
pt = self.getPartitionTable()
for oid in txn_oid_list: for oid in txn_oid_list:
partition = pt.getPartition(oid) partition = self.pt.getPartition(oid)
try: try:
oid_list = partition_oid_dict[partition] oid_list = partition_oid_dict[partition]
except KeyError: except KeyError:
...@@ -752,7 +748,7 @@ class Application(object): ...@@ -752,7 +748,7 @@ class Application(object):
# Ask storage the undo serial (serial at which object's previous data # Ask storage the undo serial (serial at which object's previous data
# is) # is)
getCellList = pt.getCellList getCellList = self.pt.getCellList
getCellSortKey = self.cp.getCellSortKey getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell getConnForCell = self.cp.getConnForCell
queue = self._thread_container.queue queue = self._thread_container.queue
...@@ -838,11 +834,10 @@ class Application(object): ...@@ -838,11 +834,10 @@ class Application(object):
# First get a list of transactions from all storage nodes. # First get a list of transactions from all storage nodes.
# Each storage node will return TIDs only for UP_TO_DATE state and # Each storage node will return TIDs only for UP_TO_DATE state and
# FEEDING state cells # FEEDING state cells
pt = self.getPartitionTable()
queue = self._thread_container.queue queue = self._thread_container.queue
packet = Packets.AskTIDs(first, last, INVALID_PARTITION) packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
tid_set = set() tid_set = set()
for storage_node in pt.getNodeSet(True): for storage_node in self.pt.getNodeSet(True):
conn = self.cp.getConnForNode(storage_node) conn = self.cp.getConnForNode(storage_node)
if conn is None: if conn is None:
continue continue
......
...@@ -235,6 +235,22 @@ class ClientCache(object): ...@@ -235,6 +235,22 @@ class ClientCache(object):
else: else:
assert item.next_tid <= tid, (item, oid, tid) assert item.next_tid <= tid, (item, oid, tid)
def clear_current(self):
oid_list = []
for oid, item_list in self._oid_dict.items():
item = item_list[-1]
if item.next_tid is None:
self._remove(item)
del item_list[-1]
# We don't preserve statistics of removed items. This could be
# done easily when previous versions are cached, by copying
# counters, but it would not be fair for other oids, so it's
# probably not worth it.
if not item_list:
del self._oid_dict[oid]
oid_list.append(oid)
return oid_list
def test(self): def test(self):
cache = ClientCache() cache = ClientCache()
...@@ -250,7 +266,11 @@ def test(self): ...@@ -250,7 +266,11 @@ def test(self):
data = '15', 15, None data = '15', 15, None
cache.store(1, *data) cache.store(1, *data)
self.assertEqual(cache.load(1, None), data) self.assertEqual(cache.load(1, None), data)
self.assertEqual(cache.clear_current(), [1])
self.assertEqual(cache.load(1, None), None)
cache.store(1, *data)
cache.invalidate(1, 20) cache.invalidate(1, 20)
self.assertEqual(cache.clear_current(), [])
self.assertEqual(cache.load(1, 20), ('15', 15, 20)) self.assertEqual(cache.load(1, 20), ('15', 15, 20))
cache.store(1, '10', 10, 15) cache.store(1, '10', 10, 15)
cache.store(1, '20', 20, 21) cache.store(1, '20', 20, 21)
......
...@@ -99,7 +99,6 @@ class TransactionContainer(dict): ...@@ -99,7 +99,6 @@ class TransactionContainer(dict):
'object_stored_counter_dict': {}, 'object_stored_counter_dict': {},
'conflict_serial_dict': {}, 'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {}, 'resolved_conflict_serial_dict': {},
'txn_voted': False,
'involved_nodes': set(), 'involved_nodes': set(),
} }
return context return context
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.pt import MTPartitionTable as PartitionTable from neo.lib.pt import MTPartitionTable as PartitionTable
from neo.lib.protocol import NodeStates, Packets, ProtocolError from neo.lib.protocol import NodeStates, Packets, ProtocolError
from neo.lib.util import dump from neo.lib.util import dump, add64
from . import BaseHandler, AnswerBaseHandler from . import BaseHandler, AnswerBaseHandler
from ..exception import NEOStorageError from ..exception import NEOStorageError
...@@ -96,7 +96,20 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -96,7 +96,20 @@ class PrimaryNotificationsHandler(BaseHandler):
def packetReceived(self, conn, packet, kw={}): def packetReceived(self, conn, packet, kw={}):
if type(packet) is Packets.AnswerLastTransaction: if type(packet) is Packets.AnswerLastTransaction:
self.app.last_tid = packet.decode()[0] app = self.app
ltid = packet.decode()[0]
if app.last_tid != ltid:
if app.master_conn is None:
app._cache_lock_acquire()
try:
oid_list = app._cache.clear_current()
db = app.getDB()
if db is not None:
db.invalidate(app.last_tid and
add64(app.last_tid, 1), oid_list)
finally:
app._cache_lock_release()
app.last_tid = ltid
elif type(packet) is Packets.AnswerTransactionFinished: elif type(packet) is Packets.AnswerTransactionFinished:
app = self.app app = self.app
app.last_tid = tid = packet.decode()[1] app.last_tid = tid = packet.decode()[1]
...@@ -125,8 +138,11 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -125,8 +138,11 @@ class PrimaryNotificationsHandler(BaseHandler):
def connectionClosed(self, conn): def connectionClosed(self, conn):
app = self.app app = self.app
if app.master_conn is not None: if app.master_conn is not None:
logging.critical("connection to primary master node closed") msg = "connection to primary master node closed"
logging.critical(msg)
app.master_conn = None app.master_conn = None
for txn_context in app.txn_contexts():
txn_context['error'] = msg
app.primary_master_node = None app.primary_master_node = None
super(PrimaryNotificationsHandler, self).connectionClosed(conn) super(PrimaryNotificationsHandler, self).connectionClosed(conn)
...@@ -151,10 +167,6 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -151,10 +167,6 @@ class PrimaryNotificationsHandler(BaseHandler):
finally: finally:
app._cache_lock_release() app._cache_lock_release()
# For the two methods below, we must not use app._getPartitionTable()
# to avoid a dead lock. It is safe to not check the master connection
# because it's in the master handler, so the connection is already
# established.
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
if self.app.pt.filled(): if self.app.pt.filled():
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
......
...@@ -107,7 +107,7 @@ class ConnectionPool(object): ...@@ -107,7 +107,7 @@ class ConnectionPool(object):
def iterateForObject(self, object_id, readable=False): def iterateForObject(self, object_id, readable=False):
""" Iterate over nodes managing an object """ """ Iterate over nodes managing an object """
pt = self.app.getPartitionTable() pt = self.app.pt
if type(object_id) is str: if type(object_id) is str:
object_id = pt.getPartition(object_id) object_id = pt.getPartition(object_id)
cell_list = pt.getCellList(object_id, readable) cell_list = pt.getCellList(object_id, readable)
......
...@@ -44,11 +44,6 @@ def _getMasterConnection(self): ...@@ -44,11 +44,6 @@ def _getMasterConnection(self):
self.master_conn = Mock() self.master_conn = Mock()
return self.master_conn return self.master_conn
def getPartitionTable(self):
if self.pt is None:
self.master_conn = _getMasterConnection(self)
return self.pt
def _ask(self, conn, packet, handler=None, **kw): def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None) self.setHandlerData(None)
conn.ask(packet, **kw) conn.ask(packet, **kw)
...@@ -71,10 +66,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -71,10 +66,8 @@ class ClientApplicationTests(NeoUnitTestBase):
# apply monkey patches # apply monkey patches
self._getMasterConnection = Application._getMasterConnection self._getMasterConnection = Application._getMasterConnection
self._ask = Application._ask self._ask = Application._ask
self.getPartitionTable = Application.getPartitionTable
Application._getMasterConnection = _getMasterConnection Application._getMasterConnection = _getMasterConnection
Application._ask = _ask Application._ask = _ask
Application.getPartitionTable = getPartitionTable
self._to_stop_list = [] self._to_stop_list = []
def _tearDown(self, success): def _tearDown(self, success):
...@@ -82,9 +75,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -82,9 +75,8 @@ class ClientApplicationTests(NeoUnitTestBase):
for app in self._to_stop_list: for app in self._to_stop_list:
app.close() app.close()
# restore environnement # restore environnement
Application._getMasterConnection = self._getMasterConnection
Application._ask = self._ask Application._ask = self._ask
Application.getPartitionTable = self.getPartitionTable Application._getMasterConnection = self._getMasterConnection
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
# some helpers # some helpers
...@@ -499,7 +491,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -499,7 +491,7 @@ class ClientApplicationTests(NeoUnitTestBase):
'getAddress': ('127.0.0.1', 10010), 'getAddress': ('127.0.0.1', 10010),
'fakeReceived': packet, 'fakeReceived': packet,
}) })
txn_context['txn_voted'] = True txn_context['voted'] = None
app.tpc_finish(txn, None) app.tpc_finish(txn, None)
self.checkAskFinishTransaction(app.master_conn) self.checkAskFinishTransaction(app.master_conn)
#self.checkDispatcherRegisterCalled(app, app.master_conn) #self.checkDispatcherRegisterCalled(app, app.master_conn)
......
...@@ -76,8 +76,8 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -76,8 +76,8 @@ class ConnectionPoolTests(NeoUnitTestBase):
def test_iterateForObject_noStorageAvailable(self): def test_iterateForObject_noStorageAvailable(self):
# no node available # no node available
oid = self.getOID(1) oid = self.getOID(1)
pt = Mock({'getCellList': []}) app = Mock()
app = Mock({'getPartitionTable': pt}) app.pt = Mock({'getCellList': []})
pool = ConnectionPool(app) pool = ConnectionPool(app)
self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next) self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next)
...@@ -87,8 +87,8 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -87,8 +87,8 @@ class ConnectionPoolTests(NeoUnitTestBase):
node = Mock({'__repr__': 'node', 'isRunning': True}) node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node}) cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'}) conn = Mock({'__repr__': 'conn'})
pt = Mock({'getCellList': [cell]}) app = Mock()
app = Mock({'getPartitionTable': pt}) app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app) pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': ReturnValues(None, conn)}) pool.getConnForNode = Mock({'__call__': ReturnValues(None, conn)})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)]) self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
...@@ -99,8 +99,8 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -99,8 +99,8 @@ class ConnectionPoolTests(NeoUnitTestBase):
node = Mock({'__repr__': 'node', 'isRunning': True}) node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node}) cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'}) conn = Mock({'__repr__': 'conn'})
pt = Mock({'getCellList': [cell]}) app = Mock()
app = Mock({'getPartitionTable': pt}) app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app) pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': conn}) pool.getConnForNode = Mock({'__call__': conn})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)]) self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
......
...@@ -29,7 +29,8 @@ class MasterHandlerTests(NeoUnitTestBase): ...@@ -29,7 +29,8 @@ class MasterHandlerTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
super(MasterHandlerTests, self).setUp() super(MasterHandlerTests, self).setUp()
self.db = Mock() self.db = Mock()
self.app = Mock({'getDB': self.db}) self.app = Mock({'getDB': self.db,
'txn_contexts': ()})
self.app.nm = NodeManager() self.app.nm = NodeManager()
self.app.dispatcher = Mock() self.app.dispatcher = Mock()
self._next_port = 3000 self._next_port = 3000
......
...@@ -649,6 +649,44 @@ class Test(NEOThreadedTest): ...@@ -649,6 +649,44 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testClientReconnection(self):
cluster = NEOCluster()
try:
cluster.start()
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter()
c1.root()['y'] = y = PCounter()
y.value = 1
t1.commit()
x = c1._storage.load(x1._p_oid)[0]
y = c1._storage.load(y._p_oid)[0]
# close connections to master & storage
c, = cluster.master.nm.getClientList()
c.getConnection().close()
c, = cluster.storage.nm.getClientList()
c.getConnection().close()
cluster.tic()
# modify x with another client
client = ClientApplication(name=cluster.name,
master_nodes=cluster.master_nodes)
cluster.client.setPoll(0)
client.setPoll(1)
txn = transaction.Transaction()
client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn)
tid = client.tpc_finish(txn, None)
client.close()
client.setPoll(0)
cluster.client.setPoll(1)
t1.begin()
self.assertEqual(x1._p_changed ,None)
self.assertEqual(x1.value, 1)
finally:
cluster.stop()
def testInvalidTTID(self): def testInvalidTTID(self):
cluster = NEOCluster() cluster = NEOCluster()
try: try:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment