Commit 7fac1696 authored by Julien Muchembled's avatar Julien Muchembled

tests: some cleanup in threaded.__init__

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2787 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent e0aa8ef3
...@@ -20,7 +20,6 @@ import os, random, socket, sys, tempfile, threading, time, types ...@@ -20,7 +20,6 @@ import os, random, socket, sys, tempfile, threading, time, types
from collections import deque from collections import deque
from functools import wraps from functools import wraps
from Queue import Queue, Empty from Queue import Queue, Empty
from weakref import ref as weak_ref
from mock import Mock from mock import Mock
import transaction, ZODB import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app import neo.admin.app, neo.master.app, neo.storage.app
...@@ -48,66 +47,68 @@ def getVirtualIp(server_type): ...@@ -48,66 +47,68 @@ def getVirtualIp(server_type):
class Serialized(object): class Serialized(object):
_global_lock = threading.Lock() @classmethod
_global_lock.acquire() def init(cls):
cls._global_lock = threading.Lock()
cls._global_lock.acquire()
# TODO: use something else than Queue, for inspection or editing # TODO: use something else than Queue, for inspection or editing
# (e.g. we'd like to suspend nodes temporarily) # (e.g. we'd like to suspend nodes temporarily)
_lock_list = Queue() cls._lock_list = Queue()
_pdb = False cls._pdb = False
pending = 0 cls.pending = 0
@staticmethod @classmethod
def release(lock=None, wake_other=True, stop=False): def release(cls, lock=None, wake_other=True, stop=False):
"""Suspend lock owner and resume first suspended thread""" """Suspend lock owner and resume first suspended thread"""
if lock is None: if lock is None:
lock = Serialized._global_lock lock = cls._global_lock
if stop: # XXX: we should fix ClusterStates.STOPPING if stop: # XXX: we should fix ClusterStates.STOPPING
Serialized.pending = None cls.pending = None
else: else:
Serialized.pending = 0 cls.pending = 0
try: try:
sys._getframe(1).f_trace.im_self.set_continue() sys._getframe(1).f_trace.im_self.set_continue()
Serialized._pdb = True cls._pdb = True
except AttributeError: except AttributeError:
pass pass
q = Serialized._lock_list q = cls._lock_list
q.put(lock) q.put(lock)
if wake_other: if wake_other:
q.get().release() q.get().release()
@staticmethod @classmethod
def acquire(lock=None): def acquire(cls, lock=None):
"""Suspend all threads except lock owner""" """Suspend all threads except lock owner"""
if lock is None: if lock is None:
lock = Serialized._global_lock lock = cls._global_lock
lock.acquire() lock.acquire()
if Serialized.pending is None: # XXX if cls.pending is None: # XXX
if lock is Serialized._global_lock: if lock is cls._global_lock:
Serialized.pending = 0 cls.pending = 0
else: else:
sys.exit() sys.exit()
if Serialized._pdb: if cls._pdb:
Serialized._pdb = False cls._pdb = False
try: try:
sys.stdout.write(threading.currentThread().node_name) sys.stdout.write(threading.currentThread().node_name)
except AttributeError: except AttributeError:
pass pass
pdb(1) pdb(1)
@staticmethod @classmethod
def tic(lock=None): def tic(cls, lock=None):
# switch to another thread # switch to another thread
# (the following calls are not supposed to be debugged into) # (the following calls are not supposed to be debugged into)
Serialized.release(lock); Serialized.acquire(lock) cls.release(lock); cls.acquire(lock)
@staticmethod @classmethod
def background(): def background(cls):
try: try:
Serialized._lock_list.get(0).release() cls._lock_list.get(0).release()
except Empty: except Empty:
pass pass
class SerializedEventManager(Serialized, EventManager): class SerializedEventManager(EventManager):
_lock = None _lock = None
_timeout = 0 _timeout = 0
...@@ -147,7 +148,7 @@ class SerializedEventManager(Serialized, EventManager): ...@@ -147,7 +148,7 @@ class SerializedEventManager(Serialized, EventManager):
# before the first message is sent. # before the first message is sent.
# TODO: Detect where a message is sent to jump immediately to nodes # TODO: Detect where a message is sent to jump immediately to nodes
# that will do something. # that will do something.
self.tic(self._lock) Serialized.tic(self._lock)
if timeout != 0: if timeout != 0:
timeout = self._timeout timeout = self._timeout
if timeout != 0 and Serialized.pending: if timeout != 0 and Serialized.pending:
...@@ -294,15 +295,13 @@ class NEOCluster(object): ...@@ -294,15 +295,13 @@ class NEOCluster(object):
SocketConnector_send = staticmethod(SocketConnector.send) SocketConnector_send = staticmethod(SocketConnector.send)
Storage__init__ = staticmethod(Storage.__init__) Storage__init__ = staticmethod(Storage.__init__)
_cluster = None _patched = threading.Lock()
@classmethod def _patch(cluster):
def patch(cls): cls = cluster.__class__
if not cls._patched.acquire(0):
raise RuntimeError("Can't run several cluster at the same time")
def makeClientConnection(self, addr): def makeClientConnection(self, addr):
# XXX: 'threading.currentThread()._cluster'
# does not work for client. We could monkey-patch
# ClientConnection instead of using a global variable.
cluster = cls._cluster()
try: try:
real_addr = cluster.resolv(addr) real_addr = cluster.resolv(addr)
return cls.SocketConnector_makeClientConnection(self, real_addr) return cls.SocketConnector_makeClientConnection(self, real_addr)
...@@ -314,11 +313,6 @@ class NEOCluster(object): ...@@ -314,11 +313,6 @@ class NEOCluster(object):
return result return result
# TODO: 'sleep' should 'tic' in a smart way, so that storages can be # TODO: 'sleep' should 'tic' in a smart way, so that storages can be
# safely started even if the cluster isn't. # safely started even if the cluster isn't.
def sleep(seconds):
l = threading.currentThread().em._lock
while Serialized.pending:
Serialized.tic(l)
Serialized.tic(l)
bootstrap.sleep = lambda seconds: None bootstrap.sleep = lambda seconds: None
BaseConnection.checkTimeout = lambda self, t: None BaseConnection.checkTimeout = lambda self, t: None
SocketConnector.makeClientConnection = makeClientConnection SocketConnector.makeClientConnection = makeClientConnection
...@@ -328,7 +322,7 @@ class NEOCluster(object): ...@@ -328,7 +322,7 @@ class NEOCluster(object):
Storage.setupLog = lambda *args, **kw: None Storage.setupLog = lambda *args, **kw: None
@classmethod @classmethod
def unpatch(cls): def _unpatch(cls):
bootstrap.sleep = time.sleep bootstrap.sleep = time.sleep
BaseConnection.checkTimeout = cls.BaseConnection_checkTimeout BaseConnection.checkTimeout = cls.BaseConnection_checkTimeout
SocketConnector.makeClientConnection = \ SocketConnector.makeClientConnection = \
...@@ -337,6 +331,7 @@ class NEOCluster(object): ...@@ -337,6 +331,7 @@ class NEOCluster(object):
cls.SocketConnector_makeListeningConnection cls.SocketConnector_makeListeningConnection
SocketConnector.send = cls.SocketConnector_send SocketConnector.send = cls.SocketConnector_send
Storage.setupLog = setupLog Storage.setupLog = setupLog
cls._patched.release()
def __init__(self, master_count=1, partitions=1, replicas=0, def __init__(self, master_count=1, partitions=1, replicas=0,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'), adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'),
...@@ -405,7 +400,8 @@ class NEOCluster(object): ...@@ -405,7 +400,8 @@ class NEOCluster(object):
self.neoctl = NeoCTL(self) self.neoctl = NeoCTL(self)
def start(self, storage_list=None, fast_startup=True): def start(self, storage_list=None, fast_startup=True):
self.__class__._cluster = weak_ref(self) self._patch()
Serialized.init()
for node_type in 'master', 'admin': for node_type in 'master', 'admin':
for node in getattr(self, node_type + '_list'): for node in getattr(self, node_type + '_list'):
node.start() node.start()
...@@ -448,7 +444,7 @@ class NEOCluster(object): ...@@ -448,7 +444,7 @@ class NEOCluster(object):
node.join() node.join()
finally: finally:
Serialized.acquire() Serialized.acquire()
self.__class__._cluster = None self._unpatch()
def tic(self, force=False): def tic(self, force=False):
if force: if force:
...@@ -483,11 +479,3 @@ class NEOThreadedTest(NeoUnitTestBase): ...@@ -483,11 +479,3 @@ class NEOThreadedTest(NeoUnitTestBase):
def setupLog(self): def setupLog(self):
log_file = os.path.join(getTempDirectory(), self.id() + '.log') log_file = os.path.join(getTempDirectory(), self.id() + '.log')
setupLog(LoggerThreadName(), log_file, True) setupLog(LoggerThreadName(), log_file, True)
def setUp(self):
NeoUnitTestBase.setUp(self)
NEOCluster.patch()
def tearDown(self):
NEOCluster.unpatch()
NeoUnitTestBase.tearDown(self)
...@@ -45,16 +45,8 @@ class MatrixImportBenchmark(BenchmarkRunner): ...@@ -45,16 +45,8 @@ class MatrixImportBenchmark(BenchmarkRunner):
if storages[-1] < max_s: if storages[-1] < max_s:
storages.append(max_s) storages.append(max_s)
replicas = range(min_r, max_r + 1) replicas = range(min_r, max_r + 1)
if self._config.threaded:
from neo.tests.threaded import NEOCluster
NEOCluster.patch() # XXX ugly
try:
result_list = [self.runMatrix(storages, replicas) result_list = [self.runMatrix(storages, replicas)
for x in xrange(self._config.repeat)] for x in xrange(self._config.repeat)]
finally:
if self._config.threaded:
from neo.tests.threaded import NEOCluster
NEOCluster.unpatch()# XXX ugly
results = {} results = {}
for s in storages: for s in storages:
results[s] = z = {} results[s] = z = {}
...@@ -84,7 +76,7 @@ class MatrixImportBenchmark(BenchmarkRunner): ...@@ -84,7 +76,7 @@ class MatrixImportBenchmark(BenchmarkRunner):
datafs = 'PROD1' datafs = 'PROD1'
import random, neo.tests.stat_zodb import random, neo.tests.stat_zodb
dfs_storage = getattr(neo.tests.stat_zodb, datafs)( dfs_storage = getattr(neo.tests.stat_zodb, datafs)(
random.Random(0)).as_storage(10000) random.Random(0)).as_storage(100)
print "Import of %s with m=%s, s=%s, r=%s, p=%s" % ( print "Import of %s with m=%s, s=%s, r=%s, p=%s" % (
datafs, masters, storages, replicas, partitions) datafs, masters, storages, replicas, partitions)
# cluster # cluster
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment