Commit 1e0c5efc authored by Julien Muchembled's avatar Julien Muchembled

Master must not die if client sends an invalid ttid

parent 8aef2569
...@@ -733,6 +733,7 @@ class Application(object): ...@@ -733,6 +733,7 @@ class Application(object):
tid = self._askPrimary(Packets.AskFinishTransaction( tid = self._askPrimary(Packets.AskFinishTransaction(
txn_context['ttid'], cache_dict), txn_context['ttid'], cache_dict),
cache_dict=cache_dict, callback=f) cache_dict=cache_dict, callback=f)
assert tid
return tid return tid
finally: finally:
self._load_lock_release() self._load_lock_release()
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import ProtocolError, Packets from neo.lib.protocol import ProtocolError, Packets
from ZODB.POSException import StorageError
class BaseHandler(EventHandler): class BaseHandler(EventHandler):
"""Base class for client-side EventHandler implementations.""" """Base class for client-side EventHandler implementations."""
...@@ -59,3 +60,5 @@ class AnswerBaseHandler(EventHandler): ...@@ -59,3 +60,5 @@ class AnswerBaseHandler(EventHandler):
packetReceived = unexpectedInAnswerHandler packetReceived = unexpectedInAnswerHandler
peerBroken = unexpectedInAnswerHandler peerBroken = unexpectedInAnswerHandler
def protocolError(self, conn, message):
raise StorageError("protocol error: %s" % message)
...@@ -162,7 +162,7 @@ class EventHandler(object): ...@@ -162,7 +162,7 @@ class EventHandler(object):
# Error packet handlers. # Error packet handlers.
def error(self, conn, code, message): def error(self, conn, code, message, **kw):
try: try:
getattr(self, Errors[code])(conn, message) getattr(self, Errors[code])(conn, message)
except (AttributeError, ValueError): except (AttributeError, ValueError):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import uuid_str, ZERO_TID from neo.lib.protocol import ProtocolError, uuid_str, ZERO_TID
from neo.lib.util import dump, u64, addTID, tidFromTime from neo.lib.util import dump, u64, addTID, tidFromTime
class DelayedError(Exception): class DelayedError(Exception):
...@@ -295,7 +295,10 @@ class TransactionManager(object): ...@@ -295,7 +295,10 @@ class TransactionManager(object):
Prepare a transaction to be finished Prepare a transaction to be finished
""" """
# XXX: not efficient but the list should be often small # XXX: not efficient but the list should be often small
try:
txn = self._ttid_dict[ttid] txn = self._ttid_dict[ttid]
except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid))
node = txn.getNode() node = txn.getNode()
for _, tid in self._queue: for _, tid in self._queue:
if ttid == tid: if ttid == tid:
......
...@@ -27,7 +27,7 @@ from neo.lib.connection import MTClientConnection ...@@ -27,7 +27,7 @@ from neo.lib.connection import MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID ZERO_TID
from . import ClientApplication, NEOCluster, NEOThreadedTest, Patch from . import ClientApplication, NEOCluster, NEOThreadedTest, Patch
from neo.lib.util import makeChecksum from neo.lib.util import add64, makeChecksum
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
class PCounter(Persistent): class PCounter(Persistent):
...@@ -649,6 +649,21 @@ class Test(NEOThreadedTest): ...@@ -649,6 +649,21 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testInvalidTTID(self):
cluster = NEOCluster()
try:
cluster.start()
client = cluster.client
client.setPoll(1)
txn = transaction.Transaction()
client.tpc_begin(txn)
txn_context = client._txn_container.get(txn)
txn_context['ttid'] = add64(txn_context['ttid'], 1)
self.assertRaises(POSException.StorageError,
client.tpc_finish, txn, None)
finally:
cluster.stop()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment