#
# Copyright (C) 2009-2010  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

import unittest
from mock import Mock, ReturnValues
from collections import deque
from neo.lib.util import makeChecksum
from neo.tests import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_PARTITION, INVALID_TID, INVALID_OID
from neo.lib.protocol import Packets, LockState, ZERO_HASH

class StorageClientHandlerTests(NeoUnitTestBase):

    def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True, **kwargs):
        conn = self.getFakeConnection(address=("127.0.0.1", self.master_port),
                is_server=_listening)
        # hook
        self.operation.peerBroken = lambda c: c.peerBrokendCalled()
        self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)

    def setUp(self):
        NeoUnitTestBase.setUp(self)
        self.prepareDatabase(number=1)
        # create an application object
        config = self.getStorageConfiguration(master_number=1)
        self.app = Application(config)
        self.app.transaction_dict = {}
        self.app.store_lock_dict = {}
        self.app.load_lock_dict = {}
        self.app.event_queue = deque()
        self.app.event_queue_dict = {}
        self.app.tm = Mock({'__contains__': True})
        # handler
        self.operation = ClientOperationHandler(self.app)
        # set pmn
        self.master_uuid = self.getNewUUID()
        pmn = self.app.nm.getMasterList()[0]
        pmn.setUUID(self.master_uuid)
        self.app.primary_master_node = pmn
        self.master_port = 10010

    def tearDown(self):
        self.app.close()
        del self.app
        super(StorageClientHandlerTests, self).tearDown()

    def _getConnection(self, uuid=None):
        return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000))

    def _checkTransactionsAborted(self, uuid):
        calls = self.app.tm.mockGetNamedCalls('abortFor')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(uuid)

    def test_connectionLost(self):
        uuid = self.getNewUUID()
        self.app.nm.createClient(uuid=uuid)
        conn = self._getConnection(uuid=uuid)
        self.operation.connectionClosed(conn)

    def test_18_askTransactionInformation1(self):
        # transaction does not exists
        conn = self._getConnection()
        self.app.dm = Mock({'getNumPartitions': 1})
        self.operation.askTransactionInformation(conn, INVALID_TID)
        self.checkErrorPacket(conn)

    def test_18_askTransactionInformation2(self):
        # answer
        conn = self._getConnection()
        oid_list = [self.getOID(1), self.getOID(2)]
        dm = Mock({ "getTransaction": (oid_list, 'user', 'desc', '', False), })
        self.app.dm = dm
        self.operation.askTransactionInformation(conn, INVALID_TID)
        self.checkAnswerTransactionInformation(conn)

    def test_24_askObject1(self):
        # delayed response
        conn = self._getConnection()
        self.app.dm = Mock()
        self.app.tm = Mock({'loadLocked': True})
        self.app.load_lock_dict[INVALID_OID] = object()
        self.assertEqual(len(self.app.event_queue), 0)
        self.operation.askObject(conn, oid=INVALID_OID,
            serial=INVALID_TID, tid=INVALID_TID)
        self.assertEqual(len(self.app.event_queue), 1)
        self.checkNoPacketSent(conn)
        self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)

    def test_24_askObject2(self):
        # invalid serial / tid / packet not found
        self.app.dm = Mock({'getObject': None})
        conn = self._getConnection()
        self.assertEqual(len(self.app.event_queue), 0)
        self.operation.askObject(conn, oid=INVALID_OID,
            serial=INVALID_TID, tid=INVALID_TID)
        calls = self.app.dm.mockGetNamedCalls('getObject')
        self.assertEqual(len(self.app.event_queue), 0)
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID)
        self.checkErrorPacket(conn)

    def test_24_askObject3(self):
        # object found => answer
        serial = self.getNextTID()
        next_serial = self.getNextTID()
        oid = self.getOID(1)
        tid = self.getNextTID()
        H = "0" * 20
        self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
        conn = self._getConnection()
        self.assertEqual(len(self.app.event_queue), 0)
        self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
        self.assertEqual(len(self.app.event_queue), 0)
        self.checkAnswerObject(conn)

    def test_25_askTIDs1(self):
        # invalid offsets => error
        app = self.app
        app.pt = Mock()
        app.dm = Mock()
        conn = self._getConnection()
        self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None)
        self.assertEqual(len(app.pt.mockGetNamedCalls('getCellList')), 0)
        self.assertEqual(len(app.dm.mockGetNamedCalls('getTIDList')), 0)

    def test_25_askTIDs2(self):
        # well case => answer
        conn = self._getConnection()
        self.app.pt = Mock({'getPartitions': 1})
        self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
        self.operation.askTIDs(conn, 1, 2, 1)
        calls = self.app.dm.mockGetNamedCalls('getTIDList')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(1, 1, 1, [1, ])
        self.checkAnswerTids(conn)

    def test_25_askTIDs3(self):
        # invalid partition => answer usable partitions
        conn = self._getConnection()
        cell = Mock({'getUUID':self.app.uuid})
        self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
        self.app.pt = Mock({
            'getCellList': (cell, ),
            'getPartitions': 1,
            'getAssignedPartitionList': [0],
        })
        self.operation.askTIDs(conn, 1, 2, INVALID_PARTITION)
        self.assertEqual(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
        calls = self.app.dm.mockGetNamedCalls('getTIDList')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(1, 1, 1, [0])
        self.checkAnswerTids(conn)

    def test_26_askObjectHistory1(self):
        # invalid offsets => error
        app = self.app
        app.dm = Mock()
        conn = self._getConnection()
        self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn,
            1, 1, None)
        self.assertEqual(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)

    def test_26_askObjectHistory2(self):
        oid1, oid2 = self.getOID(1), self.getOID(2)
        # first case: empty history
        conn = self._getConnection()
        self.app.dm = Mock({'getObjectHistory': None})
        self.operation.askObjectHistory(conn, oid1, 1, 2)
        self.checkErrorPacket(conn)
        # second case: not empty history
        conn = self._getConnection()
        serial = self.getNextTID()
        self.app.dm = Mock({'getObjectHistory': [(serial, 0, ), ]})
        self.operation.askObjectHistory(conn, oid2, 1, 2)
        self.checkAnswerObjectHistory(conn)

    def test_askStoreTransaction(self):
        uuid = self.getNewUUID()
        conn = self._getConnection(uuid=uuid)
        tid = self.getNextTID()
        user = 'USER'
        desc = 'DESC'
        ext = 'EXT'
        oid_list = (self.getOID(1), self.getOID(2))
        self.operation.askStoreTransaction(conn, tid, user, desc, ext, oid_list)
        calls = self.app.tm.mockGetNamedCalls('storeTransaction')
        self.assertEqual(len(calls), 1)
        self.checkAnswerStoreTransaction(conn)

    def _getObject(self):
        oid = self.getOID(0)
        serial = self.getNextTID()
        data = 'DATA'
        return (oid, serial, 1, makeChecksum(data), data)

    def _checkStoreObjectCalled(self, *args):
        calls = self.app.tm.mockGetNamedCalls('storeObject')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(*args)

    def test_askStoreObject1(self):
        # no conflict => answer
        uuid = self.getNewUUID()
        conn = self._getConnection(uuid=uuid)
        tid = self.getNextTID()
        oid, serial, comp, checksum, data = self._getObject()
        self.operation.askStoreObject(conn, oid, serial, comp, checksum,
                data, None, tid, False)
        self._checkStoreObjectCalled(tid, serial, oid, comp,
                checksum, data, None, False)
        pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
            decode=True)
        self.assertEqual(pconflicting, 0)
        self.assertEqual(poid, oid)
        self.assertEqual(pserial, serial)

    def test_askStoreObjectWithDataTID(self):
        # same as test_askStoreObject1, but with a non-None data_tid value
        uuid = self.getNewUUID()
        conn = self._getConnection(uuid=uuid)
        tid = self.getNextTID()
        oid, serial, comp, checksum, data = self._getObject()
        data_tid = self.getNextTID()
        self.operation.askStoreObject(conn, oid, serial, comp, ZERO_HASH,
                '', data_tid, tid, False)
        self._checkStoreObjectCalled(tid, serial, oid, comp,
                None, None, data_tid, False)
        pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
            decode=True)
        self.assertEqual(pconflicting, 0)
        self.assertEqual(poid, oid)
        self.assertEqual(pserial, serial)

    def test_askStoreObject2(self):
        # conflict error
        uuid = self.getNewUUID()
        conn = self._getConnection(uuid=uuid)
        tid = self.getNextTID()
        locking_tid = self.getNextTID(tid)
        def fakeStoreObject(*args):
            raise ConflictError(locking_tid)
        self.app.tm.storeObject = fakeStoreObject
        oid, serial, comp, checksum, data = self._getObject()
        self.operation.askStoreObject(conn, oid, serial, comp, checksum,
                data, None, tid, False)
        pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
            decode=True)
        self.assertEqual(pconflicting, 1)
        self.assertEqual(poid, oid)
        self.assertEqual(pserial, locking_tid)

    def test_abortTransaction(self):
        conn = self._getConnection()
        tid = self.getNextTID()
        self.operation.abortTransaction(conn, tid)
        calls = self.app.tm.mockGetNamedCalls('abort')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(tid)

    def test_askObjectUndoSerial(self):
        uuid = self.getNewUUID()
        conn = self._getConnection(uuid=uuid)
        tid = self.getNextTID()
        ltid = self.getNextTID()
        undone_tid = self.getNextTID()
        # Keep 2 entries here, so we check findUndoTID is called only once.
        oid_list = [self.getOID(1), self.getOID(2)]
        obj2_data = [] # Marker
        self.app.tm = Mock({
            'getObjectFromTransaction': None,
        })
        self.app.dm = Mock({
            'findUndoTID': ReturnValues((None, None, False), )
        })
        self.operation.askObjectUndoSerial(conn, tid, ltid, undone_tid, oid_list)
        self.checkErrorPacket(conn)

    def test_askHasLock(self):
        tid_1 = self.getNextTID()
        tid_2 = self.getNextTID()
        oid = self.getNextTID()
        def getLockingTID(oid):
            return locking_tid
        self.app.tm.getLockingTID = getLockingTID
        for locking_tid, status in (
                    (None, LockState.NOT_LOCKED),
                    (tid_1, LockState.GRANTED),
                    (tid_2, LockState.GRANTED_TO_OTHER),
                ):
            conn = self._getConnection()
            self.operation.askHasLock(conn, tid_1, oid)
            p_oid, p_status = self.checkAnswerPacket(conn,
                Packets.AnswerHasLock, decode=True)
            self.assertEqual(oid, p_oid)
            self.assertEqual(status, p_status)

if __name__ == "__main__":
    unittest.main()