#
# Copyright (C) 2009-2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

import unittest
from mock import Mock
from collections import deque
from neo.tests import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler
from neo.lib.protocol import INVALID_PARTITION, Packets
from neo.lib.protocol import INVALID_TID, INVALID_OID


class StorageStorageHandlerTests(NeoUnitTestBase):
    """Unit tests for StorageOperationHandler (storage <-> storage requests).

    Each test replaces the application's database manager (``app.dm``),
    and where needed its partition table (``app.pt``) or transaction
    manager (``app.tm``), with a Mock. It then invokes one handler method
    on a fake connection and checks which packet (answer, error or none)
    was sent back and how the mocked collaborators were called.
    """

    def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True,
                                    **kwargs):
        """Assert that invoking ``_call`` raises an unexpected-packet error.

        A fake connection addressed to the primary master's port is built
        (``self.master_port`` is set in setUp). ``_listening`` selects
        whether that connection is flagged as server-side. ``_msg_type`` is
        accepted but not used here. Extra keyword arguments are forwarded
        to ``checkUnexpectedPacketRaised``.
        """
        conn = self.getFakeConnection(address=("127.0.0.1", self.master_port),
                                      is_server=_listening)
        # hook
        # NOTE(review): 'peerBrokendCalled' looks like a typo; confirm
        # against the fake connection API before renaming it.
        self.operation.peerBroken = lambda c: c.peerBrokendCalled()
        self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)

    def setUp(self):
        """Build a storage Application plus the handler under test.

        Prepares one test database, creates an Application from a storage
        configuration with a single master, and resets the transaction,
        store/load lock and event-queue state that the handler touches.
        The first known master node gets a fresh UUID and becomes the
        primary master.
        """
        NeoUnitTestBase.setUp(self)
        self.prepareDatabase(number=1)
        # create an application object
        config = self.getStorageConfiguration(master_number=1)
        self.app = Application(config)
        self.app.transaction_dict = {}
        self.app.store_lock_dict = {}
        self.app.load_lock_dict = {}
        self.app.event_queue = deque()
        self.app.event_queue_keys = set()
        # handler
        self.operation = StorageOperationHandler(self.app)
        # set pmn
        self.master_uuid = self.getNewUUID()
        pmn = self.app.nm.getMasterList()[0]
        pmn.setUUID(self.master_uuid)
        self.app.primary_master_node = pmn
        self.master_port = 10010

    def test_18_askTransactionInformation1(self):
        """askTransactionInformation answers an error for an unknown TID."""
        # transaction does not exists
        conn = self.getFakeConnection()
        # 'getTransaction' is not configured on this Mock, so presumably it
        # returns None, i.e. the transaction is not found.
        self.app.dm = Mock({'getNumPartitions': 1})
        self.operation.askTransactionInformation(conn, INVALID_TID)
        self.checkErrorPacket(conn)

    def test_18_askTransactionInformation2(self):
        """askTransactionInformation answers when the transaction exists."""
        # answer
        conn = self.getFakeConnection()
        tid = self.getNextTID()
        oid_list = [self.getOID(1), self.getOID(2)]
        dm = Mock({"getTransaction": (oid_list, 'user', 'desc', '', False), })
        self.app.dm = dm
        self.operation.askTransactionInformation(conn, tid)
        self.checkAnswerTransactionInformation(conn)

    def test_24_askObject1(self):
        """askObject is queued, not answered, while the OID is load-locked."""
        # delayed response
        conn = self.getFakeConnection()
        oid = self.getOID(1)
        tid = self.getNextTID()
        serial = self.getNextTID()
        self.app.dm = Mock()
        self.app.tm = Mock({'loadLocked': True})
        self.app.load_lock_dict[oid] = object()
        self.assertEqual(len(self.app.event_queue), 0)
        self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
        # The request must be deferred: one queued event, nothing sent and
        # no database lookup performed.
        self.assertEqual(len(self.app.event_queue), 1)
        self.checkNoPacketSent(conn)
        self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)

    def test_24_askObject2(self):
        """askObject answers an error when the database finds no object."""
        # invalid serial / tid / packet not found
        self.app.dm = Mock({'getObject': None})
        conn = self.getFakeConnection()
        oid = self.getOID(1)
        tid = self.getNextTID()
        serial = self.getNextTID()
        self.assertEqual(len(self.app.event_queue), 0)
        self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
        calls = self.app.dm.mockGetNamedCalls('getObject')
        self.assertEqual(len(self.app.event_queue), 0)
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(oid, serial, tid, resolve_data=False)
        self.checkErrorPacket(conn)

    def test_24_askObject3(self):
        """askObject answers with the object when the database finds it."""
        oid = self.getOID(1)
        tid = self.getNextTID()
        serial = self.getNextTID()
        next_serial = self.getNextTID()
        # object found => answer
        self.app.dm = Mock({'getObject': (serial, next_serial, 0, 0, '', None)})
        conn = self.getFakeConnection()
        self.assertEqual(len(self.app.event_queue), 0)
        self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
        self.assertEqual(len(self.app.event_queue), 0)
        self.checkAnswerObject(conn)

    def test_25_askTIDsFrom(self):
        """askTIDsFrom forwards its range to getReplicationTIDList."""
        # well case => answer
        conn = self.getFakeConnection()
        self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
        self.app.pt = Mock({'getPartitions': 1})
        tid = self.getNextTID()
        tid2 = self.getNextTID()
        self.operation.askTIDsFrom(conn, tid, tid2, 2, [1])
        calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
        self.assertEqual(len(calls), 1)
        # Expected arguments: min tid, max tid, length (2), the partition
        # count from app.pt (1), and the single requested partition (1).
        calls[0].checkArgs(tid, tid2, 2, 1, 1)
        self.checkAnswerTidsFrom(conn)

    def test_26_askObjectHistoryFrom(self):
        """askObjectHistoryFrom answers and forwards args to the database."""
        min_oid = self.getOID(2)
        min_serial = self.getNextTID()
        max_serial = self.getNextTID()
        length = 4
        partition = 8
        num_partitions = 16
        tid = self.getNextTID()
        conn = self.getFakeConnection()
        self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
        self.app.pt = Mock({
            'getPartitions': num_partitions,
        })
        self.operation.askObjectHistoryFrom(conn, min_oid, min_serial,
                                            max_serial, length, partition)
        self.checkAnswerObjectHistoryFrom(conn)
        calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
        self.assertEqual(len(calls), 1)
        # The partition count is read from app.pt and inserted before the
        # requested partition.
        calls[0].checkArgs(min_oid, min_serial, max_serial, length,
                           num_partitions, partition)

    def test_askCheckTIDRange(self):
        """askCheckTIDRange answers with the database's checksum result."""
        count = 1
        tid_checksum = 2
        min_tid = self.getNextTID()
        num_partitions = 4
        length = 5
        partition = 6
        max_tid = self.getNextTID()
        self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
        self.app.pt = Mock({'getPartitions': num_partitions})
        conn = self.getFakeConnection()
        self.operation.askCheckTIDRange(conn, min_tid, max_tid, length,
                                        partition)
        calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(min_tid, max_tid, length, num_partitions, partition)
        # The answer echoes the request bounds followed by the values
        # returned by dm.checkTIDRange.
        pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
            self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
                                   decode=True)
        self.assertEqual(min_tid, pmin_tid)
        self.assertEqual(length, plength)
        self.assertEqual(count, pcount)
        self.assertEqual(tid_checksum, ptid_checksum)
        self.assertEqual(max_tid, pmax_tid)

    def test_askCheckSerialRange(self):
        """askCheckSerialRange answers with the database's checksum result."""
        count = 1
        oid_checksum = 2
        min_oid = self.getOID(1)
        num_partitions = 4
        length = 5
        partition = 6
        serial_checksum = 7
        min_serial = self.getNextTID()
        max_serial = self.getNextTID()
        max_oid = self.getOID(2)
        self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
                                                 serial_checksum, max_serial)})
        self.app.pt = Mock({'getPartitions': num_partitions})
        conn = self.getFakeConnection()
        self.operation.askCheckSerialRange(conn, min_oid, min_serial,
                                           max_serial, length, partition)
        calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(min_oid, min_serial, max_serial, length,
                           num_partitions, partition)
        # The answer echoes the request bounds followed by the values
        # returned by dm.checkSerialRange.
        pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
            pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
                Packets.AnswerCheckSerialRange, decode=True)
        self.assertEqual(min_oid, pmin_oid)
        self.assertEqual(min_serial, pmin_serial)
        self.assertEqual(length, plength)
        self.assertEqual(count, pcount)
        self.assertEqual(oid_checksum, poid_checksum)
        self.assertEqual(max_oid, pmax_oid)
        self.assertEqual(serial_checksum, pserial_checksum)
        self.assertEqual(max_serial, pmax_serial)


if __name__ == "__main__":
    unittest.main()