#
# Copyright (C) 2009  Nexedi SA
# 
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import os
import unittest
from neo import logging
from struct import pack, unpack
from mock import Mock
from collections import deque
from neo.tests import NeoTestBase
from neo.storage.app import Application
from neo.storage.handlers.client import TransactionInformation
from neo.storage.handlers.client import ClientOperationHandler
from neo.exception import PrimaryFailure, OperationFailure
from neo.pt import PartitionTable
from neo import protocol
from neo.protocol import *

class StorageClientHandlerTests(NeoTestBase):
    """Unit tests for the storage node's ClientOperationHandler.

    Each test builds a storage Application from a test configuration,
    swaps the collaborators it needs (``app.dm``, ``app.pt``, connections)
    for Mock objects, feeds one client request to the handler and checks
    the answer sent and the resulting application state.
    """

    def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True, **kwargs):
        """Check that ``_call`` rejects a packet of type ``_msg_type``.

        A mock connection (whose ``isServer`` answers ``_listening``) and a
        fresh packet are forwarded, with any extra ``kwargs``, to
        checkUnexpectedPacketRaised.
        """
        conn = Mock({
            "getAddress" : ("127.0.0.1", self.master_port),
            "isServer": _listening,
        })
        packet = Packet(msg_type=_msg_type)
        # hook: peerBroken only calls a marker method on the connection it
        # receives (a Mock just records the call) instead of the real logic
        self.operation.peerBroken = lambda c: c.peerBrokendCalled()
        self.checkUnexpectedPacketRaised(_call, conn=conn, packet=packet, **kwargs)

    def setUp(self):
        """Build a storage Application whose first master is the primary."""
        self.prepareDatabase(number=1)
        # create an application object
        config = self.getStorageConfiguration(master_number=1)
        self.app = Application(**config)
        # every test starts with no transactions, no locks, no queued events
        self.app.transaction_dict = {}
        self.app.store_lock_dict = {}
        self.app.load_lock_dict = {}
        self.app.event_queue = deque()
        # register each configured master in the node manager
        for address in self.app.master_node_list:
            self.app.nm.createMaster(address=address)
        # handler under test
        self.operation = ClientOperationHandler(self.app)
        # give the first master a fresh UUID and make it the primary master
        self.master_uuid = self.getNewUUID()
        pmn = self.app.nm.getMasterList()[0]
        pmn.setUUID(self.master_uuid)
        self.app.primary_master_node = pmn
        self.master_port = 10010

    def tearDown(self):
        NeoTestBase.tearDown(self)

    def test_01_TransactionInformation(self):
        """TransactionInformation keeps its UUID, objects and metadata."""
        uuid = self.getNewUUID()
        transaction = TransactionInformation(uuid)
        # uuid
        self.assertEquals(transaction._uuid, uuid)
        self.assertEquals(transaction.getUUID(), uuid)
        # objects
        self.assertEquals(transaction._object_dict, {})
        obj = (self.getNewUUID(), 1, 2, 3, )
        transaction.addObject(*obj)
        objects = transaction.getObjectList()
        self.assertEquals(len(objects), 1)
        self.assertEquals(objects[0], obj)
        # transactions
        self.assertEquals(transaction._transaction, None)
        t = ((1, 2, 3), 'user', 'desc', '')
        transaction.addTransaction(*t)
        self.assertEquals(transaction.getTransaction(), t)

    def test_05_dealWithClientFailure(self):
        """A failed client's transactions and their store locks are dropped."""
        uuid = self.getNewUUID()
        self.app.nm.createClient(
            uuid=uuid,
            address=('127.0.0.1', 10010)
        )
        self.app.store_lock_dict[0] = object()
        transaction = Mock({
            'getUUID': uuid,
            'getObjectList': ((0, ), ),
        })
        self.app.transaction_dict[0] = transaction
        # precondition: the lock and the transaction exist before the failure
        self.assertTrue(0 in self.app.store_lock_dict)
        self.assertTrue(0 in self.app.transaction_dict)
        self.operation.dealWithClientFailure(uuid)
        # objects and transaction removed
        self.assertTrue(0 not in self.app.store_lock_dict)
        self.assertTrue(0 not in self.app.transaction_dict)

    def test_18_handleAskTransactionInformation1(self):
        """Asking for an unknown transaction answers with an error packet."""
        # transaction does not exist
        conn = Mock({ })
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        self.operation.handleAskTransactionInformation(conn, packet, INVALID_TID)
        self.checkErrorPacket(conn)

    def test_18_handleAskTransactionInformation2(self):
        """A transaction known to the database manager is answered."""
        # answer
        conn = Mock({ })
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), })
        self.app.dm = dm
        self.operation.handleAskTransactionInformation(conn, packet, INVALID_TID)
        self.checkAnswerTransactionInformation(conn)

    def test_24_handleAskObject1(self):
        """A load-locked object delays the request in the event queue."""
        # delayed response
        conn = Mock({})
        self.app.dm = Mock()
        packet = Packet(msg_type=ASK_OBJECT)
        self.app.load_lock_dict[INVALID_OID] = object()
        self.assertEquals(len(self.app.event_queue), 0)
        self.operation.handleAskObject(conn, packet,
            oid=INVALID_OID,
            serial=INVALID_SERIAL,
            tid=INVALID_TID)
        self.assertEquals(len(self.app.event_queue), 1)
        self.checkNoPacketSent(conn)
        # the database must not be queried while the object is locked
        self.assertEquals(len(self.app.dm.mockGetNamedCalls('getObject')), 0)

    def test_24_handleAskObject2(self):
        """An object the database cannot find is answered with an error."""
        # invalid serial / tid / packet not found
        self.app.dm = Mock({'getObject': None})
        conn = Mock({})
        packet = Packet(msg_type=ASK_OBJECT)
        self.assertEquals(len(self.app.event_queue), 0)
        self.operation.handleAskObject(conn, packet,
            oid=INVALID_OID,
            serial=INVALID_SERIAL,
            tid=INVALID_TID)
        calls = self.app.dm.mockGetNamedCalls('getObject')
        self.assertEquals(len(self.app.event_queue), 0)
        self.assertEquals(len(calls), 1)
        calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID)
        self.checkErrorPacket(conn)

    def test_24_handleAskObject3(self):
        """An object found in the database is answered immediately."""
        # object found => answer
        self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )})
        conn = Mock({})
        packet = Packet(msg_type=ASK_OBJECT)
        self.assertEquals(len(self.app.event_queue), 0)
        self.operation.handleAskObject(conn, packet,
            oid=INVALID_OID,
            serial=INVALID_SERIAL,
            tid=INVALID_TID)
        self.assertEquals(len(self.app.event_queue), 0)
        self.checkAnswerObject(conn)

    def test_25_handleAskTIDs1(self):
        """Invalid offsets raise a protocol error before any lookup."""
        # invalid offsets => error
        app = self.app
        app.pt = Mock()
        app.dm = Mock()
        conn = Mock({})
        packet = Packet(msg_type=ASK_TIDS)
        self.checkProtocolErrorRaised(self.operation.handleAskTIDs, conn, packet, 1, 1, None)
        self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
        self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0)

    def test_25_handleAskTIDs2(self):
        """Valid offsets and an explicit partition query the database."""
        # well case => answer
        conn = Mock({})
        packet = Packet(msg_type=ASK_TIDS)
        self.app.pt = Mock({'getPartitions': 1})
        self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
        self.operation.handleAskTIDs(conn, packet, 1, 2, 1)
        calls = self.app.dm.mockGetNamedCalls('getTIDList')
        self.assertEquals(len(calls), 1)
        calls[0].checkArgs(1, 1, 1, [1, ])
        self.checkAnswerTids(conn)

    def test_25_handleAskTIDs3(self):
        """INVALID_PARTITION makes the handler use this node's partitions."""
        # invalid partition => answer usable partitions
        conn = Mock({})
        packet = Packet(msg_type=ASK_TIDS)
        cell = Mock({'getUUID':self.app.uuid})
        self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
        self.app.pt = Mock({'getCellList': (cell, ), 'getPartitions': 1})
        self.operation.handleAskTIDs(conn, packet, 1, 2, INVALID_PARTITION)
        self.assertEquals(len(self.app.pt.mockGetNamedCalls('getCellList')), 1)
        calls = self.app.dm.mockGetNamedCalls('getTIDList')
        self.assertEquals(len(calls), 1)
        calls[0].checkArgs(1, 1, 1, [0, ])
        self.checkAnswerTids(conn)

    def test_26_handleAskObjectHistory1(self):
        """Invalid offsets raise a protocol error before any lookup."""
        # invalid offsets => error
        app = self.app
        app.dm = Mock()
        conn = Mock({})
        packet = Packet(msg_type=ASK_OBJECT_HISTORY)
        self.checkProtocolErrorRaised(self.operation.handleAskObjectHistory, conn, packet, 1, 1, None)
        self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)

    def test_26_handleAskObjectHistory2(self):
        """Both an empty and a non-empty history are answered."""
        # first case: empty history
        packet = Packet(msg_type=ASK_OBJECT_HISTORY)
        conn = Mock({})
        self.app.dm = Mock({'getObjectHistory': None})
        self.operation.handleAskObjectHistory(conn, packet, INVALID_OID, 1, 2)
        self.checkAnswerObjectHistory(conn)
        # second case: not empty history
        conn = Mock({})
        self.app.dm = Mock({'getObjectHistory': [('', 0, ), ]})
        self.operation.handleAskObjectHistory(conn, packet, INVALID_OID, 1, 2)
        self.checkAnswerObjectHistory(conn)

    def test_27_handleAskStoreTransaction2(self):
        """Storing a transaction registers a TransactionInformation entry."""
        # add transaction entry
        packet = Packet(msg_type=ASK_STORE_TRANSACTION)
        conn = Mock({'getUUID': self.getNewUUID()})
        self.operation.handleAskStoreTransaction(conn, packet,
            INVALID_TID, '', '', '', ())
        t = self.app.transaction_dict.get(INVALID_TID, None)
        self.assertTrue(t is not None)
        self.assertTrue(isinstance(t, TransactionInformation))
        self.assertEquals(t.getTransaction(), ((), '', '', ''))
        self.checkAnswerStoreTransaction(conn)

    def test_28_handleAskStoreObject2(self):
        """A store on an object locked by another transaction is delayed."""
        # locked => delayed response
        packet = Packet(msg_type=ASK_STORE_OBJECT)
        conn = Mock({'getUUID': self.app.uuid})
        oid = '\x02' * 8
        tid1, tid2 = self.getTwoIDs()
        self.app.store_lock_dict[oid] = tid1
        self.assertTrue(oid in self.app.store_lock_dict)
        t_before = list(self.app.transaction_dict.items())
        self.operation.handleAskStoreObject(conn, packet, oid,
            INVALID_SERIAL, 0, 0, '', tid2)
        self.assertEquals(len(self.app.event_queue), 1)
        t_after = list(self.app.transaction_dict.items())
        self.assertEquals(t_before, t_after)
        self.checkNoPacketSent(conn)
        self.assertTrue(oid in self.app.store_lock_dict)

    def test_28_handleAskStoreObject3(self):
        """A store conflicting with the lock holder is answered as conflict."""
        # locked => unresolvable conflict => answer
        packet = Packet(msg_type=ASK_STORE_OBJECT)
        conn = Mock({'getUUID': self.app.uuid})
        tid1, tid2 = self.getTwoIDs()
        self.app.store_lock_dict[INVALID_OID] = tid2
        self.operation.handleAskStoreObject(conn, packet, INVALID_OID,
            INVALID_SERIAL, 0, 0, '', tid1)
        packet = self.checkAnswerStoreObject(conn)
        # the existing lock is left untouched
        self.assertEquals(self.app.store_lock_dict[INVALID_OID], tid2)
        # conflicting
        self.assertTrue(unpack('!B8s8s', packet._body)[0])

    def test_28_handleAskStoreObject4(self):
        """A stale serial is reported as a conflict and takes no lock."""
        # resolvable conflict => answer
        packet = Packet(msg_type=ASK_STORE_OBJECT)
        conn = Mock({'getUUID': self.app.uuid})
        self.app.dm = Mock({'getObjectHistory':((self.getNewUUID(), ), )})
        self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None)
        self.operation.handleAskStoreObject(conn, packet, INVALID_OID,
            INVALID_SERIAL, 0, 0, '', INVALID_TID)
        packet = self.checkAnswerStoreObject(conn)
        self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None)
        # conflicting
        self.assertTrue(unpack('!B8s8s', packet._body)[0])

    def test_28_handleAskStoreObject5(self):
        """A non-conflicting store is recorded in its transaction."""
        # no conflict => answer
        packet = Packet(msg_type=ASK_STORE_OBJECT)
        conn = Mock({'getUUID': self.app.uuid})
        self.operation.handleAskStoreObject(conn, packet, INVALID_OID,
            INVALID_SERIAL, 0, 0, '', INVALID_TID)
        t = self.app.transaction_dict.get(INVALID_TID, None)
        self.assertTrue(t is not None)
        self.assertEquals(len(t.getObjectList()), 1)
        stored = t.getObjectList()[0]
        self.assertEquals(stored, (INVALID_OID, 0, 0, ''))
        # no conflict
        packet = self.checkAnswerStoreObject(conn)
        self.assertFalse(unpack('!B8s8s', packet._body)[0])

    def test_29_handleAbortTransaction(self):
        """Aborting releases the transaction's locks and forgets it."""
        # remove transaction
        packet = Packet(msg_type=ABORT_TRANSACTION)
        conn = Mock({'getUUID': self.app.uuid})
        transaction = Mock({ 'getObjectList': ((0, ), ), })
        self.called = False
        def called():
            self.called = True
        # detect that queued events are replayed once the locks are released
        self.app.executeQueuedEvents = called
        self.app.load_lock_dict[0] = object()
        self.app.store_lock_dict[0] = object()
        self.app.transaction_dict[INVALID_TID] = transaction
        self.operation.handleAbortTransaction(conn, packet, INVALID_TID)
        self.assertTrue(self.called)
        self.assertEquals(len(self.app.load_lock_dict), 0)
        self.assertEquals(len(self.app.store_lock_dict), 0)
        # the aborted transaction itself must be dropped as well
        self.assertTrue(INVALID_TID not in self.app.transaction_dict)

if "__main__" == __name__:
    # executed directly: hand this module's test cases to unittest's runner
    unittest.main()