Commit f7378a70 authored by Julien Muchembled's avatar Julien Muchembled

Remove overkill Packet.getType method

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2761 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 86ce8602
......@@ -494,13 +494,13 @@ class Connection(BaseConnection):
self.getHandler()._packetMalformed(self, msg)
return
self._timeout.refresh(time())
packet_type = packet.getType()
if packet_type == Packets.Ping:
packet_type = type(packet)
if packet_type is Packets.Ping:
# Send a pong notification
PACKET_LOGGER.dispatch(self, packet, False)
if not self.aborted:
self.answer(Packets.Pong(), packet.getId())
elif packet_type == Packets.Pong:
elif packet_type is Packets.Pong:
# Skip PONG packets, its only purpose is refresh the timeout
# generated upong ping. But still log them.
PACKET_LOGGER.dispatch(self, packet, False)
......@@ -772,7 +772,7 @@ class MTClientConnection(ClientConnection):
msg_id = self._getNextId()
packet.setId(msg_id)
if queue is None:
if not isinstance(packet, Packets.Ping):
if type(packet) is not Packets.Ping:
raise TypeError, 'Only Ping packet can be asked ' \
'without a queue, got a %r.' % (packet, )
else:
......
......@@ -33,7 +33,7 @@ class EventHandler(object):
def __unexpectedPacket(self, conn, packet, message=None):
"""Handle an unexpected packet."""
if message is None:
message = 'unexpected packet type %s in %s' % (packet.getType(),
message = 'unexpected packet type %s in %s' % (type(packet),
self.__class__.__name__)
else:
message = 'unexpected packet: %s in %s' % (message,
......
......@@ -193,9 +193,6 @@ class Packet(object):
assert self._id is not None, "No identifier applied on the packet"
return self._id
def getType(self):
return self.__class__
def encode(self):
""" Encode a packet as a string to send it over the network """
content = self._body
......
......@@ -334,10 +334,10 @@ class NeoUnitTestBase(NeoTestBase):
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), Packets.Error)
self.assertEqual(type(packet), Packets.Error)
if decode:
return packet.decode()
return protocol.decode_table[packet.getType()](packet._body)
return protocol.decode_table[type(packet)](packet._body)
return packet
def checkAskPacket(self, conn, packet_type, decode=False):
......@@ -346,7 +346,7 @@ class NeoUnitTestBase(NeoTestBase):
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), packet_type)
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
......@@ -357,7 +357,7 @@ class NeoUnitTestBase(NeoTestBase):
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), packet_type)
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
......@@ -367,7 +367,7 @@ class NeoUnitTestBase(NeoTestBase):
calls = conn.mockGetNamedCalls('notify')
packet = calls.pop(packet_number).getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), packet_type)
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
......
......@@ -934,7 +934,7 @@ class ClientApplicationTests(NeoUnitTestBase):
now = time.time()
app.pack(now)
self.assertEqual(len(marker), 1)
self.assertEqual(marker[0].getType(), Packets.AskPack)
self.assertEqual(type(marker[0]), Packets.AskPack)
# XXX: how to validate packet content ?
if __name__ == '__main__':
......
......@@ -121,7 +121,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(next_range.getType(), Packets.AskCheckTIDRange)
self.assertEqual(type(next_range), Packets.AskCheckTIDRange)
pmin_tid, plength, ppartition = next_range.decode()
self.assertEqual(pmin_tid, add64(next_tid, 1))
self.assertEqual(plength, RANGE_LENGTH)
......@@ -132,7 +132,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
self.assertEqual(len(packet_list), len(tid_list))
for packet in packet_list:
self.assertEqual(packet.getType(),
self.assertEqual(type(packet),
Packets.AskTransactionInformation)
ptid = packet.decode()[0]
for tid in tid_list:
......@@ -147,7 +147,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(next_range.getType(), Packets.AskCheckSerialRange)
self.assertEqual(type(next_range), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_range.decode()
self.assertEqual(pmin_oid, next_oid)
self.assertEqual(pmin_serial, add64(next_serial, 1))
......@@ -422,7 +422,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
tid_packet = calls[0].getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
......@@ -449,7 +449,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 2)
tid_packet = calls[0].getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
......@@ -577,7 +577,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
serial_packet = calls[0].getParam(0)
self.assertEqual(serial_packet.getType(), Packets.AskObjectHistoryFrom)
self.assertEqual(type(serial_packet), Packets.AskObjectHistoryFrom)
pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
serial_packet.decode()
self.assertEqual(pmin_oid, min_oid)
......
......@@ -115,7 +115,8 @@ class StorageReplicatorTests(NeoUnitTestBase):
act()
unfinished_tids = app.master_conn.mockGetNamedCalls('ask')[0].getParam(0)
self.assertTrue(replicator.new_partition_set)
self.assertEqual(unfinished_tids.getType(), Packets.AskUnfinishedTransactions)
self.assertEqual(type(unfinished_tids),
Packets.AskUnfinishedTransactions)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# nothing happens until waiting_for_unfinished_tids becomes False
act()
......
......@@ -413,7 +413,7 @@ class ConnectionTests(NeoUnitTestBase):
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(data.getType(), p.getType())
self.assertEqual(type(data), type(p))
self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode())
self._checkReadBuf(bc, '')
......@@ -455,13 +455,13 @@ class ConnectionTests(NeoUnitTestBase):
# packet 1
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(data.getType(), p1.getType())
self.assertEqual(type(data), type(p1))
self.assertEqual(data.getId(), p1.getId())
self.assertEqual(data.decode(), p1.decode())
# packet 2
call = bc._queue.mockGetNamedCalls("append")[1]
data = call.getParam(0)
self.assertEqual(data.getType(), p2.getType())
self.assertEqual(type(data), type(p2))
self.assertEqual(data.getId(), p2.getId())
self.assertEqual(data.decode(), p2.decode())
self._checkReadBuf(bc, '')
......@@ -497,7 +497,7 @@ class ConnectionTests(NeoUnitTestBase):
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(data.getType(), p.getType())
self.assertEqual(type(data), type(p))
self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode())
self._checkReadBuf(bc, '')
......@@ -519,7 +519,7 @@ class ConnectionTests(NeoUnitTestBase):
buffer.append(chunk)
answer = Packets.parse(buffer, parser_state)
self.assertTrue(answer is not None)
self.assertTrue(answer.getType() == Packets.Pong)
self.assertTrue(type(answer) == Packets.Pong)
self.assertEqual(answer.getId(), p.getId())
def test_Connection_analyse6(self):
......@@ -636,7 +636,7 @@ class ConnectionTests(NeoUnitTestBase):
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0)
self.assertEqual(data.getType(), Packets.AnswerPrimary)
self.assertEqual(type(data), Packets.AnswerPrimary)
self.assertEqual(data.getId(), 1)
self._checkReadBuf(bc, '')
# check not aborted
......
......@@ -259,7 +259,7 @@ class ProtocolTests(NeoUnitTestBase):
tid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)]
p = Packets.DeleteTransaction(tid, oid_list)
self.assertEqual(p.getType(), Packets.DeleteTransaction)
self.assertEqual(type(p), Packets.DeleteTransaction)
self.assertEqual(p.decode(), (tid, oid_list))
def test_31_commitTransaction(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment