# # Copyright (C) 2009-2019 Nexedi SA # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License # as published by the Free Software Foundation; either version 2 # of the License, or (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. import __builtin__ import errno import functools import gc import os import random import signal import socket import subprocess import sys import tempfile import thread import unittest import weakref import transaction from contextlib import closing, contextmanager from ConfigParser import SafeConfigParser from cStringIO import StringIO try: from ZODB._compat import Unpickler except ImportError: from cPickle import Unpickler from functools import wraps from inspect import isclass from itertools import islice from .mock import Mock from neo.lib import debug, event, logging from neo.lib.protocol import NodeTypes, Packet, Packets, UUID_NAMESPACES from neo.lib.util import cached_property from neo.storage.database.manager import DatabaseManager from time import time, sleep from struct import pack, unpack from unittest.case import _ExpectedFailure, _UnexpectedSuccess try: from transaction.interfaces import IDataManager from ZODB.utils import newTid from ZODB.ConflictResolution import PersistentReferenceFactory except ImportError: pass def expectedFailure(exception=AssertionError): def decorator(func): def wrapper(*args, **kw): try: func(*args, **kw) except exception, e: # XXX: passing sys.exc_info() causes deadlocks raise _ExpectedFailure((type(e), None, None)) raise _UnexpectedSuccess return wraps(func)(wrapper) if callable(exception) and not isinstance(exception, type): func = exception exception = Exception return decorator(func) return decorator DB_PREFIX = os.getenv('NEO_DB_PREFIX', 'test_neo') DB_ADMIN = os.getenv('NEO_DB_ADMIN', 'root') DB_PASSWD = os.getenv('NEO_DB_PASSWD', '') DB_USER = os.getenv('NEO_DB_USER', 'test') DB_SOCKET = os.getenv('NEO_DB_SOCKET', '') DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db') DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld') DB_MYCNF = os.getenv('NEO_DB_MYCNF') DatabaseManager.TEST_IDENT = thread.get_ident() adapter = os.getenv('NEO_TESTS_ADAPTER') if adapter: from neo.storage.database import getAdapterKlass if getAdapterKlass(adapter).__name__ == 'MySQLDatabaseManager': os.environ['NEO_TESTS_ADAPTER'] = 'MySQL' IP_VERSION_FORMAT_DICT = { socket.AF_INET: '127.0.0.1', socket.AF_INET6: '::1', } ADDRESS_TYPE = socket.AF_INET SSL = os.path.dirname(__file__) + os.sep SSL = SSL + "ca.crt", SSL + "node.crt", SSL + "node.key" logging.default_root_handler.handle = lambda record: None debug.register() def mockDefaultValue(name, function): def method(self, *args, **kw): if name in self.mockReturnValues: return self.__getattr__(name)(*args, **kw) return function(self, *args, **kw) method.__name__ = name setattr(Mock, name, method) mockDefaultValue('__nonzero__', lambda self: self.__len__() != 0) mockDefaultValue('__repr__', lambda self: '<%s object at 0x%x>' % (self.__class__.__name__, id(self))) mockDefaultValue('__str__', repr) def buildUrlFromString(address): try: socket.inet_pton(socket.AF_INET6, address) address = '[%s]' % address except Exception: pass return address def getTempDirectory(): """get the current temp directory or a new one""" try: temp_dir = os.environ['TEMP'] except KeyError: neo_dir = os.path.join(tempfile.gettempdir(), 'neo_tests') while True: temp_name = repr(time()) temp_dir = os.path.join(neo_dir, temp_name) try: os.makedirs(temp_dir) break except OSError, e: if e.errno != errno.EEXIST: raise last = os.path.join(neo_dir, "last") try: os.remove(last) except OSError, e: if e.errno != errno.ENOENT: raise os.symlink(temp_name, last) os.environ['TEMP'] = temp_dir print 'Using temp directory %r.' % temp_dir return temp_dir def setupMySQL(db_list, clear_databases=True): if mysql_pool: return mysql_pool.setup(db_list, clear_databases) from neo.storage.database.mysql import \ Connection, OperationalError, BAD_DB_ERROR user = DB_USER password = '' kw = {'unix_socket': os.path.expanduser(DB_SOCKET)} if DB_SOCKET else {} # BBB: passwd is deprecated favour of password since 1.3.8 with closing(Connection(user=DB_ADMIN, passwd=DB_PASSWD, **kw)) as conn: for database in db_list: try: conn.select_db(database) if not clear_databases: continue conn.query('DROP DATABASE `%s`' % database) except OperationalError, (code, _): if code != BAD_DB_ERROR: raise conn.query('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED' ' BY "%s"' % (database, user, password)) conn.query('CREATE DATABASE `%s`' % database) return '{}:{}@%s{}'.format(user, password, DB_SOCKET).__mod__ class MySQLPool(object): def __init__(self, pool_dir=None): self._args = {} self._mysqld_dict = {} if not pool_dir: pool_dir = getTempDirectory() self._base = pool_dir + os.sep self._sock_template = os.path.join(pool_dir, '%s', 'mysql.sock') def __del__(self): self.kill(*self._mysqld_dict) def setup(self, db_list, clear_databases): from neo.storage.database.mysql import Connection start_list = set(db_list).difference(self._mysqld_dict) if start_list: start_list = sorted(start_list) x = [] with open(os.devnull, 'wb') as f: for db in start_list: base = self._base + db datadir = os.path.join(base, 'datadir') sock = self._sock_template % db tmpdir = os.path.join(base, 'tmp') args = [DB_INSTALL, '--defaults-file=' + DB_MYCNF, '--datadir=' + datadir, '--socket=' + sock, '--tmpdir=' + tmpdir, '--log_error=' + os.path.join(base, 'error.log')] if os.path.exists(datadir): try: os.remove(sock) except OSError, e: if e.errno != errno.ENOENT: raise else: os.makedirs(tmpdir) x.append(subprocess.Popen(args, stdout=f, stderr=subprocess.STDOUT)) args[0] = DB_MYSQLD self._args[db] = args for x in x: x = x.wait() if x: raise subprocess.CalledProcessError(x, DB_INSTALL) self.start(*start_list) for db in start_list: sock = self._sock_template % db p = self._mysqld_dict[db] while not os.path.exists(sock): sleep(1) x = p.poll() if x is not None: raise subprocess.CalledProcessError(x, DB_MYSQLD) for db in db_list: with closing(Connection(unix_socket=self._sock_template % db, user='root')) as db: if clear_databases: db.query('DROP DATABASE IF EXISTS neo') db.query('CREATE DATABASE IF NOT EXISTS neo') return ('root@neo' + self._sock_template).__mod__ def start(self, *db, **kw): assert set(db).isdisjoint(self._mysqld_dict) for db in db: self._mysqld_dict[db] = subprocess.Popen(self._args[db], **kw) def kill(self, *db): processes = [] for db in db: p = self._mysqld_dict.pop(db) processes.append(p) p.kill() for p in processes: p.wait() mysql_pool = MySQLPool() if DB_MYCNF else None def ImporterConfigParser(adapter, zodb, **kw): cfg = SafeConfigParser() cfg.add_section("neo") cfg.set("neo", "adapter", adapter) for x in kw.iteritems(): cfg.set("neo", *x) for name, zodb in zodb: cfg.add_section(name) for x in zodb.iteritems(): cfg.set(name, *x) return cfg class NeoTestBase(unittest.TestCase): maxDiff = None def setUp(self): logging.name = self.setupLog() unittest.TestCase.setUp(self) def setupLog(self): test_case, logging.name = self.id().rsplit('.', 1) logging.setup(os.path.join(getTempDirectory(), test_case + '.log')) def tearDown(self): assert self.tearDown.im_func is NeoTestBase.tearDown.im_func self._tearDown(sys._getframe(1).f_locals['success']) assert not gc.garbage, gc.garbage # XXX: I tried the following line to avoid random freezes on PyPy... gc.collect() def _tearDown(self, success): # Kill all unfinished transactions for next test. # Note we don't even abort them because it may require a valid # connection to a master node (see Storage.sync()). transaction.manager.__init__() if logging._max_size is not None: logging.flush() class failureException(AssertionError): def __init__(self, msg=None): logging.error(msg) AssertionError.__init__(self, msg) failIfEqual = failUnlessEqual = assertEquals = assertNotEquals = None def assertNotEqual(self, first, second, msg=None): assert not (isinstance(first, Mock) or isinstance(second, Mock)), \ "Mock objects can't be compared with '==' or '!='" return super(NeoTestBase, self).assertNotEqual(first, second, msg=msg) def assertEqual(self, first, second, msg=None): assert not (isinstance(first, Mock) or isinstance(second, Mock)), \ "Mock objects can't be compared with '==' or '!='" return super(NeoTestBase, self).assertEqual(first, second, msg=msg) def assertPartitionTable(self, pt, expected, key=None): self.assertEqual( expected if isinstance(expected, str) else '|'.join(expected), '|'.join(pt._formatRows(sorted(pt.count_dict, key=key)))) @contextmanager def expectedFailure(self, exception=AssertionError, regex=None): with self.assertRaisesRegexp(exception, regex) as cm: yield raise _UnexpectedSuccess # XXX: passing sys.exc_info() causes deadlocks raise _ExpectedFailure((type(cm.exception), None, None)) class NeoUnitTestBase(NeoTestBase): """ Base class for neo tests, implements common checks """ local_ip = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE] def setUp(self): self.uuid_dict = {} NeoTestBase.setUp(self) @cached_property def nm(self): from neo.lib import node return node.NodeManager() def createStorage(self, *args): return self.nm.createStorage(**dict(zip( ('address', 'uuid', 'state'), args))) def prepareDatabase(self, number, prefix=DB_PREFIX): """ create empty databases """ adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL') if adapter == 'MySQL': db_template = setupMySQL( [prefix + str(i) for i in xrange(number)]) self.db_template = lambda i: db_template(prefix + str(i)) elif adapter == 'SQLite': self.db_template = os.path.join(getTempDirectory(), prefix + '%s.sqlite').__mod__ for i in xrange(number): try: os.remove(self.db_template(i)) except OSError, e: if e.errno != errno.ENOENT: raise else: assert False, adapter def getMasterConfiguration(self, cluster='main', master_number=2, replicas=2, partitions=1009, uuid=None): assert master_number >= 1 and master_number <= 10 masters = ([(self.local_ip, 10010 + i) for i in xrange(master_number)]) return { 'cluster': cluster, 'bind': masters[0], 'masters': masters, 'replicas': replicas, 'partitions': partitions, 'uuid': uuid, } def getStorageConfiguration(self, cluster='main', master_number=2, index=0, prefix=DB_PREFIX, uuid=None): assert master_number >= 1 and master_number <= 10 masters = [(buildUrlFromString(self.local_ip), 10010 + i) for i in xrange(master_number)] adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL') return { 'cluster': cluster, 'bind': (masters[0], 10020 + index), 'masters': masters, 'database': self.db_template(index), 'uuid': uuid, 'adapter': adapter, 'wait': 0, } def getNewUUID(self, node_type): """ Retuns a 16-bytes UUID according to namespace 'prefix' """ if node_type is None: node_type = random.choice(NodeTypes) self.uuid_dict[node_type] = uuid = 1 + self.uuid_dict.get(node_type, 0) return uuid + (UUID_NAMESPACES[node_type] << 24) def getClientUUID(self): return self.getNewUUID(NodeTypes.CLIENT) def getMasterUUID(self): return self.getNewUUID(NodeTypes.MASTER) def getStorageUUID(self): return self.getNewUUID(NodeTypes.STORAGE) def getAdminUUID(self): return self.getNewUUID(NodeTypes.ADMIN) def getNextTID(self, ltid=None): return newTid(ltid) def getFakeConnector(self, descriptor=None): return Mock({ '__repr__': 'FakeConnector', 'getDescriptor': descriptor, 'getAddress': ('', 0), }) def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000), is_server=False, connector=None, peer_id=None): if connector is None: connector = self.getFakeConnector() conn = Mock({ 'getUUID': uuid, 'getAddress': address, 'isServer': is_server, '__repr__': 'FakeConnection', '__nonzero__': 0, 'getConnector': connector, 'getPeerId': peer_id, }) conn.mockAddReturnValues(__hash__ = id(conn)) conn.connecting = False return conn def checkAborted(self, conn): """ Ensure the connection was aborted """ self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1) def checkClosed(self, conn): """ Ensure the connection was closed """ self.assertEqual(len(conn.mockGetNamedCalls('close')), 1) def _checkNoPacketSend(self, conn, method_id): self.assertEqual([], conn.mockGetNamedCalls(method_id)) def checkNoPacketSent(self, conn): """ check if no packet were sent """ self._checkNoPacketSend(conn, 'send') self._checkNoPacketSend(conn, 'answer') self._checkNoPacketSend(conn, 'ask') # in check(Ask|Answer|Notify)Packet we return the packet so it can be used # in tests if more accurate checks are required def checkErrorPacket(self, conn): """ Check if an error packet was answered """ calls = conn.mockGetNamedCalls("answer") self.assertEqual(len(calls), 1) packet = calls.pop().getParam(0) self.assertTrue(isinstance(packet, Packet)) self.assertEqual(type(packet), Packets.Error) return packet def checkAskPacket(self, conn, packet_type): """ Check if an ask-packet with the right type is sent """ calls = conn.mockGetNamedCalls('ask') self.assertEqual(len(calls), 1) packet = calls.pop().getParam(0) self.assertTrue(isinstance(packet, Packet)) self.assertEqual(type(packet), packet_type) return packet def checkAnswerPacket(self, conn, packet_type): """ Check if an answer-packet with the right type is sent """ calls = conn.mockGetNamedCalls('answer') self.assertEqual(len(calls), 1) packet = calls.pop().getParam(0) self.assertTrue(isinstance(packet, Packet)) self.assertEqual(type(packet), packet_type) return packet def checkNotifyPacket(self, conn, packet_type, packet_number=0): """ Check if a notify-packet with the right type is sent """ calls = conn.mockGetNamedCalls('send') packet = calls.pop(packet_number).getParam(0) self.assertTrue(isinstance(packet, Packet)) self.assertEqual(type(packet), packet_type) return packet class TransactionalResource(object): class _sortKey(object): def __init__(self, last): self._last = last def __cmp__(self, other): assert type(self) is not type(other), other return 1 if self._last else -1 def __init__(self, txn, last, **kw): self.sortKey = lambda: self._sortKey(last) for k in kw: assert callable(IDataManager.get(k)), k self.__dict__.update(kw) txn.get().join(self) def __call__(self, func): name = func.__name__ assert callable(IDataManager.get(name)), name setattr(self, name, func) return func def __getattr__(self, attr): if callable(IDataManager.get(attr)): return lambda *_: None return self.__getattribute__(attr) try: from ZODB.Connection import TransactionMetaData except ImportError: # BBB: ZODB < 5 def getTransactionMetaData(txn, conn): return txn else: def getTransactionMetaData(txn, conn): return txn.data(conn) class Patch(object): """ Patch attributes and revert later automatically. Usage: with Patch(someObject, [new,] attrToPatch=newValue) as patch: [... code that runs with patches ...] [... code that runs without patch ...] The 'new' positional parameter defaults to False and it must be equal to not hasattr(someObject, 'attrToPatch') It is an assertion to detect when a Patch is obsolete. ' as patch' is optional: 'patch.revert()' can be used to revert patches in the middle of the 'with' clause. Or: patch = Patch(...) patch.apply() In this case, patches are automatically reverted when 'patch' is deleted. For patched callables, the new one receives the original value as first argument if 'new' is True. Alternative usage: @Patch(someObject) def funcToPatch(orig, ...): ... ... funcToPatch.revert() The decorator applies the patch immediately. """ applied = False def __new__(cls, patched, *args, **patch): if patch: return object.__new__(cls) def patch(func): self = cls(patched, *args, **{func.__name__: func}) self.apply() return self return patch def __init__(self, patched, *args, **patch): new, = args or (0,) (name, patch), = patch.iteritems() self._patched = patched self._name = name try: wrapped = getattr(patched, name) except AttributeError: assert new, (patched, name) else: assert not new, (patched, name) if callable(patch): func = patch patch = lambda *args, **kw: func(wrapped, *args, **kw) if callable(wrapped): patch = wraps(wrapped)(patch) self._patch = patch try: orig = patched.__dict__[name] except KeyError: if new or isclass(patched): self._revert = lambda: delattr(patched, name) return orig = getattr(patched, name) self._revert = lambda: setattr(patched, name, orig) def apply(self): assert not self.applied setattr(self._patched, self._name, self._patch) self.applied = True def revert(self): del self.applied self._revert() def __del__(self): if self.applied: self.revert() def __enter__(self): self.apply() return weakref.proxy(self) def __exit__(self, t, v, tb): self.__del__() def consume(iterator, n): """Advance the iterator n-steps ahead and returns the last consumed item""" return next(islice(iterator, n-1, n)) def unpickle_state(data): unpickler = Unpickler(StringIO(data)) unpickler.persistent_load = PersistentReferenceFactory().persistent_load unpickler.load() # skip the class tuple return unpickler.load() __builtin__.pdb = lambda depth=0: \ debug.getPdb().set_trace(sys._getframe(depth+1))