Commit a5eebbae authored by Jens Vagelpohl's avatar Jens Vagelpohl

- full linting with flake8

parent 0e19e22b
......@@ -2,7 +2,7 @@
# https://github.com/zopefoundation/meta/tree/master/config/pure-python
[meta]
template = "pure-python"
commit-id = "3b712f305ca8207e971c5bf81f2bdb5872489f2f"
commit-id = "0c07a1cfd78d28a07aebd23383ed16959f166574"
[python]
with-windows = false
......@@ -13,7 +13,7 @@ with-docs = true
with-sphinx-doctests = false
[tox]
use-flake8 = false
use-flake8 = true
testenv-commands = [
"# Run unit tests first.",
"zope-testrunner -u --test-path=src {posargs:-vc}",
......
......@@ -4,6 +4,8 @@ Changelog
5.4.0 (unreleased)
------------------
- linted the code with flake8
- Add support for Python 3.10.
- Add ``ConflictError`` to the list of unlogged server exceptions
......
......@@ -11,11 +11,12 @@
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
version = '5.3.1.dev0'
from setuptools import setup, find_packages
import os
version = '5.3.1.dev0'
install_requires = [
'ZODB >= 5.1.1',
'six',
......@@ -64,12 +65,14 @@ Operating System :: Unix
Framework :: ZODB
""".strip().split('\n')
def _modname(path, base, name=''):
if path == base:
return name
dirname, basename = os.path.split(path)
return _modname(dirname, base, basename + '.' + name)
def _flatten(suite, predicate=lambda *x: True):
from unittest import TestCase
for suite_or_case in suite:
......@@ -80,18 +83,20 @@ def _flatten(suite, predicate=lambda *x: True):
for x in _flatten(suite_or_case):
yield x
def _no_layer(suite_or_case):
return getattr(suite_or_case, 'layer', None) is None
def _unittests_only(suite, mod_suite):
for case in _flatten(mod_suite, _no_layer):
suite.addTest(case)
def alltests():
import logging
import pkg_resources
import unittest
import ZEO.ClientStorage
class NullHandler(logging.Handler):
level = 50
......@@ -107,7 +112,8 @@ def alltests():
for dirpath, dirnames, filenames in os.walk(base):
if os.path.basename(dirpath) == 'tests':
for filename in filenames:
if filename != 'testZEO.py': continue
if filename != 'testZEO.py':
continue
if filename.endswith('.py') and filename.startswith('test'):
mod = __import__(
_modname(dirpath, base, os.path.splitext(filename)[0]),
......@@ -115,11 +121,13 @@ def alltests():
_unittests_only(suite, mod.test_suite())
return suite
long_description = (
open('README.rst').read()
+ '\n' +
open('CHANGES.rst').read()
)
setup(name="ZEO",
version=version,
description=long_description.split('\n', 2)[1],
......@@ -133,7 +141,7 @@ setup(name="ZEO",
license="ZPL 2.1",
platforms=["any"],
classifiers=classifiers,
test_suite="__main__.alltests", # to support "setup.py test"
test_suite="__main__.alltests", # to support "setup.py test"
tests_require=tests_require,
extras_require={
'test': tests_require,
......@@ -164,4 +172,4 @@ setup(name="ZEO",
""",
include_package_data=True,
python_requires='>=2.7.9,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*',
)
)
......@@ -52,9 +52,11 @@ import ZEO.cache
logger = logging.getLogger(__name__)
def tid2time(tid):
return str(TimeStamp(tid))
def get_timestamp(prev_ts=None):
"""Internal helper to return a unique TimeStamp instance.
......@@ -69,8 +71,10 @@ def get_timestamp(prev_ts=None):
t = t.laterThan(prev_ts)
return t
MB = 1024**2
@zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage)
class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
"""A storage class that is a network client to a remote storage.
......@@ -90,7 +94,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
blob_cache_size=None, blob_cache_size_check=10,
client_label=None,
cache=None,
ssl = None, ssl_server_hostname=None,
ssl=None, ssl_server_hostname=None,
# Mostly ignored backward-compatability options
client=None, var=None,
min_disconnect_poll=1, max_disconnect_poll=None,
......@@ -189,14 +193,15 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
if isinstance(addr, int):
addr = ('127.0.0.1', addr)
self.__name__ = name or str(addr) # Standard convention for storages
self.__name__ = name or str(addr) # Standard convention for storages
if isinstance(addr, six.string_types):
if WIN:
raise ValueError("Unix sockets are not available on Windows")
addr = [addr]
elif (isinstance(addr, tuple) and len(addr) == 2 and
isinstance(addr[0], six.string_types) and isinstance(addr[1], int)):
isinstance(addr[0], six.string_types) and
isinstance(addr[1], int)):
addr = [addr]
logger.info(
......@@ -212,7 +217,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._is_read_only = read_only
self._read_only_fallback = read_only_fallback
self._addr = addr # For tests
self._addr = addr # For tests
self._iterators = weakref.WeakValueDictionary()
self._iterator_ids = set()
......@@ -228,7 +233,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._db = None
self._oids = [] # List of pre-fetched oids from server
self._oids = [] # List of pre-fetched oids from server
cache = self._cache = open_cache(
cache, var, client, storage, cache_size)
......@@ -266,7 +271,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
addr, self, cache, storage,
ZEO.asyncio.client.Fallback if read_only_fallback else read_only,
wait_timeout or 30,
ssl = ssl, ssl_server_hostname=ssl_server_hostname,
ssl=ssl, ssl_server_hostname=ssl_server_hostname,
credentials=credentials,
)
self._call = self._server.call
......@@ -308,6 +313,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._check_blob_size_thread.join()
_check_blob_size_thread = None
def _check_blob_size(self, bytes=None):
if self._blob_cache_size is None:
return
......@@ -349,8 +355,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
pass
_connection_generation = 0
def notify_connected(self, conn, info):
reconnected = self._connection_generation
self.set_server_addr(conn.get_peername())
self.protocol_version = conn.protocol_version
self._is_read_only = conn.is_read_only()
......@@ -373,22 +379,20 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._info.update(info)
for iface in (
ZODB.interfaces.IStorageRestoreable,
ZODB.interfaces.IStorageIteration,
ZODB.interfaces.IStorageUndoable,
ZODB.interfaces.IStorageCurrentRecordIteration,
ZODB.interfaces.IBlobStorage,
ZODB.interfaces.IExternalGC,
):
if (iface.__module__, iface.__name__) in self._info.get(
'interfaces', ()):
for iface in (ZODB.interfaces.IStorageRestoreable,
ZODB.interfaces.IStorageIteration,
ZODB.interfaces.IStorageUndoable,
ZODB.interfaces.IStorageCurrentRecordIteration,
ZODB.interfaces.IBlobStorage,
ZODB.interfaces.IExternalGC):
if (iface.__module__, iface.__name__) in \
self._info.get('interfaces', ()):
zope.interface.alsoProvides(self, iface)
if self.protocol_version[1:] >= b'5':
self.ping = lambda : self._call('ping')
self.ping = lambda: self._call('ping')
else:
self.ping = lambda : self._call('lastTransaction')
self.ping = lambda: self._call('lastTransaction')
if self.server_sync:
self.sync = self.ping
......@@ -536,7 +540,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
try:
return self._oids.pop()
except IndexError:
pass # We ran out. We need to get some more.
pass # We ran out. We need to get some more.
self._oids[:0] = reversed(self._call('new_oids'))
......@@ -735,7 +739,6 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
finally:
lock.close()
def temporaryDirectory(self):
return self.fshelper.temp_dir
......@@ -747,7 +750,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
conflicts = True
vote_attempts = 0
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
conflicts = False
for oid in self._call('vote', id(txn)) or ():
if isinstance(oid, dict):
......@@ -843,11 +846,11 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def tpc_abort(self, txn, timeout=None):
"""Storage API: abort a transaction.
(The timeout keyword argument is for tests to wat longer than
(The timeout keyword argument is for tests to wait longer than
they normally would.)
"""
try:
tbuf = txn.data(self)
tbuf = txn.data(self) # NOQA: F841 unused variable
except KeyError:
return
......@@ -899,7 +902,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
while blobs:
oid, blobfilename = blobs.pop()
self._blob_data_bytes_loaded += os.stat(blobfilename).st_size
targetpath = self.fshelper.getPathForOID(oid, create=True)
self.fshelper.getPathForOID(oid, create=True)
target_blob_file_name = self.fshelper.getBlobFilename(oid, tid)
lock = _lock_blob(target_blob_file_name)
try:
......@@ -1037,6 +1040,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def server_status(self):
return self._call('server_status')
class TransactionIterator(object):
def __init__(self, storage, iid, *args):
......@@ -1130,14 +1134,18 @@ class BlobCacheLayout(object):
ZODB.blob.BLOB_SUFFIX)
)
def _accessed(filename):
try:
os.utime(filename, (time.time(), os.stat(filename).st_mtime))
except OSError:
pass # We tried. :)
pass # We tried. :)
return filename
cache_file_name = re.compile(r'\d+$').match
def _check_blob_cache_size(blob_dir, target):
logger = logging.getLogger(__name__+'.check_blob_cache')
......@@ -1162,7 +1170,7 @@ def _check_blob_cache_size(blob_dir, target):
# Someone is already cleaning up, so don't bother
logger.debug("%s Another thread is checking the blob cache size.",
get_ident())
open(attempt_path, 'w').close() # Mark that we tried
open(attempt_path, 'w').close() # Mark that we tried
return
logger.debug("%s Checking blob cache size. (target: %s)",
......@@ -1200,7 +1208,7 @@ def _check_blob_cache_size(blob_dir, target):
try:
os.remove(attempt_path)
except OSError:
pass # Sigh, windows
pass # Sigh, windows
continue
logger.debug("%s -->", get_ident())
break
......@@ -1222,8 +1230,8 @@ def _check_blob_cache_size(blob_dir, target):
fsize = os.stat(file_name).st_size
try:
ZODB.blob.remove_committed(file_name)
except OSError as v:
pass # probably open on windows
except OSError:
pass # probably open on windows
else:
size -= fsize
finally:
......@@ -1238,12 +1246,14 @@ def _check_blob_cache_size(blob_dir, target):
finally:
check_lock.close()
def check_blob_size_script(args=None):
if args is None:
args = sys.argv[1:]
blob_dir, target = args
_check_blob_cache_size(blob_dir, int(target))
def _lock_blob(path):
lockfilename = os.path.join(os.path.dirname(path), '.lock')
n = 0
......@@ -1258,6 +1268,7 @@ def _lock_blob(path):
else:
break
def open_cache(cache, var, client, storage, cache_size):
if isinstance(cache, (None.__class__, str)):
from ZEO.cache import ClientCache
......
......@@ -17,27 +17,33 @@ import transaction.interfaces
from ZODB.POSException import StorageError
class ClientStorageError(StorageError):
"""An error occurred in the ZEO Client Storage.
"""
class UnrecognizedResult(ClientStorageError):
"""A server call returned an unrecognized result.
"""
class ClientDisconnected(ClientStorageError,
transaction.interfaces.TransientError):
"""The database storage is disconnected from the storage.
"""
class AuthError(StorageError):
"""The client provided invalid authentication credentials.
"""
class ProtocolError(ClientStorageError):
"""A client contacted a server with an incomparible protocol
"""
class ServerException(ClientStorageError):
"""
"""
......@@ -23,13 +23,11 @@ import codecs
import itertools
import logging
import os
import socket
import sys
import tempfile
import threading
import time
import warnings
import ZEO.asyncio.server
import ZODB.blob
import ZODB.event
import ZODB.serialize
......@@ -37,8 +35,7 @@ import ZODB.TimeStamp
import zope.interface
import six
from ZEO._compat import Pickler, Unpickler, PY3, BytesIO
from ZEO.Exceptions import AuthError
from ZEO._compat import Pickler, Unpickler, PY3
from ZEO.monitor import StorageStats
from ZEO.asyncio.server import Delay, MTDelay, Result
from ZODB.Connection import TransactionMetaData
......@@ -46,10 +43,10 @@ from ZODB.loglevels import BLATHER
from ZODB.POSException import StorageError, StorageTransactionError
from ZODB.POSException import TransactionError, ReadOnlyError, ConflictError
from ZODB.serialize import referencesf
from ZODB.utils import oid_repr, p64, u64, z64, Lock, RLock
from ZODB.utils import p64, u64, z64, Lock, RLock
# BBB mtacceptor is unused and will be removed in ZEO version 6
if os.environ.get("ZEO_MTACCEPTOR"): # mainly for tests
if os.environ.get("ZEO_MTACCEPTOR"): # mainly for tests
warnings.warn('The mtacceptor module is deprecated and will be removed '
'in ZEO version 6.', DeprecationWarning)
from .asyncio.mtacceptor import Acceptor
......@@ -58,6 +55,7 @@ else:
logger = logging.getLogger('ZEO.StorageServer')
def log(message, level=logging.INFO, label='', exc_info=False):
"""Internal helper to log a message."""
if label:
......@@ -68,15 +66,18 @@ def log(message, level=logging.INFO, label='', exc_info=False):
class StorageServerError(StorageError):
"""Error reported when an unpicklable exception is raised."""
registered_methods = set(( 'get_info', 'lastTransaction',
'getInvalidations', 'new_oids', 'pack', 'loadBefore', 'storea',
'checkCurrentSerialInTransaction', 'restorea', 'storeBlobStart',
'storeBlobChunk', 'storeBlobEnd', 'storeBlobShared',
'deleteObject', 'tpc_begin', 'vote', 'tpc_finish', 'tpc_abort',
'history', 'record_iternext', 'sendBlob', 'getTid', 'loadSerial',
'new_oid', 'undoa', 'undoLog', 'undoInfo', 'iterator_start',
'iterator_next', 'iterator_record_start', 'iterator_record_next',
'iterator_gc', 'server_status', 'set_client_label', 'ping'))
registered_methods = set(
('get_info', 'lastTransaction',
'getInvalidations', 'new_oids', 'pack', 'loadBefore', 'storea',
'checkCurrentSerialInTransaction', 'restorea', 'storeBlobStart',
'storeBlobChunk', 'storeBlobEnd', 'storeBlobShared',
'deleteObject', 'tpc_begin', 'vote', 'tpc_finish', 'tpc_abort',
'history', 'record_iternext', 'sendBlob', 'getTid', 'loadSerial',
'new_oid', 'undoa', 'undoLog', 'undoInfo', 'iterator_start',
'iterator_next', 'iterator_record_start', 'iterator_record_next',
'iterator_gc', 'server_status', 'set_client_label', 'ping'))
class ZEOStorage(object):
"""Proxy to underlying storage for a single remote client."""
......@@ -146,7 +147,7 @@ class ZEOStorage(object):
info = self.get_info()
if not info['supportsUndo']:
self.undoLog = self.undoInfo = lambda *a,**k: ()
self.undoLog = self.undoInfo = lambda *a, **k: ()
# XXX deprecated: but ZODB tests use getTid. They shouldn't
self.getTid = storage.getTid
......@@ -166,16 +167,16 @@ class ZEOStorage(object):
"Falling back to using _transaction attribute, which\n."
"is icky.",
logging.ERROR)
self.tpc_transaction = lambda : storage._transaction
self.tpc_transaction = lambda: storage._transaction
else:
raise
self.connection.methods = registered_methods
def history(self,tid,size=1):
def history(self, tid, size=1):
# This caters for storages which still accept
# a version parameter.
return self.storage.history(tid,size=size)
return self.storage.history(tid, size=size)
def _check_tid(self, tid, exc=None):
if self.read_only:
......@@ -235,7 +236,7 @@ class ZEOStorage(object):
def get_info(self):
storage = self.storage
supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)()
supportsUndo = (getattr(storage, 'supportsUndo', lambda: False)()
and self.connection.protocol_version[1:] >= b'310')
# Communicate the backend storage interfaces to the client
......@@ -382,7 +383,7 @@ class ZEOStorage(object):
# Called from client thread
if not self.connected:
return # We're disconnected
return # We're disconnected
try:
self.log(
......@@ -404,14 +405,12 @@ class ZEOStorage(object):
oid, oldserial, data, blobfilename = self.blob_log.pop()
self._store(oid, oldserial, data, blobfilename)
if not self.conflicts:
try:
serials = self.storage.tpc_vote(self.transaction)
except ConflictError as err:
if (self.client_conflict_resolution and
err.oid and err.serials and err.data
):
if self.client_conflict_resolution and \
err.oid and err.serials and err.data:
self.conflicts[err.oid] = dict(
oid=err.oid, serials=err.serials, data=err.data)
else:
......@@ -424,7 +423,7 @@ class ZEOStorage(object):
self.storage.tpc_abort(self.transaction)
return list(self.conflicts.values())
else:
self.locked = True # signal to lock manager to hold lock
self.locked = True # signal to lock manager to hold lock
return self.serials
except Exception as err:
......@@ -474,7 +473,7 @@ class ZEOStorage(object):
def storeBlobEnd(self, oid, serial, data, id):
self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
assert self.txnlog is not None # effectively not allowed after undo
fd, tempname = self.blob_tempfile
self.blob_tempfile = None
os.close(fd)
......@@ -482,14 +481,11 @@ class ZEOStorage(object):
def storeBlobShared(self, oid, serial, data, filename, id):
self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
assert self.txnlog is not None # effectively not allowed after undo
# Reconstruct the full path from the filename in the OID directory
if (os.path.sep in filename
or not (filename.endswith('.tmp')
or filename[:-1].endswith('.tmp')
)
):
if os.path.sep in filename or \
not (filename.endswith('.tmp') or filename[:-1].endswith('.tmp')):
logger.critical(
"We're under attack! (bad filename to storeBlobShared, %r)",
filename)
......@@ -623,6 +619,7 @@ class ZEOStorage(object):
def ping(self):
pass
class StorageServerDB(object):
"""Adapter from StorageServerDB to ZODB.interfaces.IStorageWrapper
......@@ -649,6 +646,7 @@ class StorageServerDB(object):
transform_record_data = untransform_record_data = lambda self, data: data
class StorageServer(object):
"""The server side implementation of ZEO.
......@@ -722,9 +720,8 @@ class StorageServer(object):
log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg))
self._lock = Lock()
self.ssl = ssl # For dev convenience
self.ssl = ssl # For dev convenience
self.read_only = read_only
self.database = None
......@@ -736,9 +733,9 @@ class StorageServer(object):
self.invq_bound = invalidation_queue_size
self.invq = {}
self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]}
self.lock_managers = {} # {storage_id -> LockManager}
self.stats = {} # {storage_id -> StorageStats}
self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]}
self.lock_managers = {} # {storage_id -> LockManager}
self.stats = {} # {storage_id -> StorageStats}
for name, storage in storages.items():
self._setup_invq(name, storage)
storage.registerDB(StorageServerDB(self, name))
......@@ -895,6 +892,7 @@ class StorageServer(object):
return latest_tid, list(oids)
__thread = None
def start_thread(self, daemon=True):
self.__thread = thread = threading.Thread(target=self.loop)
thread.setName("StorageServer(%s)" % _addr_label(self.addr))
......@@ -902,6 +900,7 @@ class StorageServer(object):
thread.start()
__closed = False
def close(self, join_timeout=1):
"""Close the dispatcher so that there are no new connections.
......@@ -959,6 +958,7 @@ class StorageServer(object):
return dict((storage_id, self.server_status(storage_id))
for storage_id in self.storages)
class StubTimeoutThread(object):
def begin(self, client):
......@@ -967,7 +967,8 @@ class StubTimeoutThread(object):
def end(self, client):
pass
is_alive = lambda self: 'stub'
def is_alive(self):
return 'stub'
class TimeoutThread(threading.Thread):
......@@ -983,7 +984,7 @@ class TimeoutThread(threading.Thread):
self._timeout = timeout
self._client = None
self._deadline = None
self._cond = threading.Condition() # Protects _client and _deadline
self._cond = threading.Condition() # Protects _client and _deadline
def begin(self, client):
# Called from the restart code the "main" thread, whenever the
......@@ -1013,14 +1014,14 @@ class TimeoutThread(threading.Thread):
if howlong <= 0:
# Prevent reporting timeout more than once
self._deadline = None
client = self._client # For the howlong <= 0 branch below
client = self._client # For the howlong <= 0 branch below
if howlong <= 0:
client.log("Transaction timeout after %s seconds" %
self._timeout, logging.CRITICAL)
try:
client.call_soon_threadsafe(client.connection.close)
except:
except: # NOQA: E722 bare except
client.log("Timeout failure", logging.CRITICAL,
exc_info=sys.exc_info())
self.end(client)
......@@ -1074,6 +1075,7 @@ def _addr_label(addr):
host, port = addr
return str(host) + ":" + str(port)
class CommitLog(object):
def __init__(self):
......@@ -1116,23 +1118,28 @@ class CommitLog(object):
self.file.close()
self.file = None
class ServerEvent(object):
def __init__(self, server, **kw):
self.__dict__.update(kw)
self.server = server
class Serving(ServerEvent):
pass
class Closed(ServerEvent):
pass
def never_resolve_conflict(oid, committedSerial, oldSerial, newpickle,
committedData=b''):
raise ConflictError(oid=oid, serials=(committedSerial, oldSerial),
data=newpickle)
class LockManager(object):
def __init__(self, storage_id, stats, timeout):
......@@ -1140,7 +1147,7 @@ class LockManager(object):
self.stats = stats
self.timeout = timeout
self.locked = None
self.waiting = {} # {ZEOStorage -> (func, delay)}
self.waiting = {} # {ZEOStorage -> (func, delay)}
self._lock = RLock()
def lock(self, zs, func):
......@@ -1218,10 +1225,10 @@ class LockManager(object):
zs, "(%r) dequeue lock: transactions waiting: %s")
def _log_waiting(self, zs, message):
l = len(self.waiting)
zs.log(message % (self.storage_id, l),
logging.CRITICAL if l > 9 else (
logging.WARNING if l > 3 else logging.DEBUG)
length = len(self.waiting)
zs.log(message % (self.storage_id, length),
logging.CRITICAL if length > 9 else (
logging.WARNING if length > 3 else logging.DEBUG)
)
def _can_lock(self, zs):
......
......@@ -21,12 +21,11 @@ is used to store the data until a commit or abort.
# A faster implementation might store trans data in memory until it
# reaches a certain size.
import os
import tempfile
import ZODB.blob
from ZEO._compat import Pickler, Unpickler
class TransactionBuffer(object):
# The TransactionBuffer is used by client storage to hold update
......@@ -44,8 +43,8 @@ class TransactionBuffer(object):
# stored are builtin types -- strings or None.
self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1
self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number}
self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number}
self.exception = None
def close(self):
......@@ -93,9 +92,7 @@ class TransactionBuffer(object):
if oid not in seen:
yield oid, None, True
# Support ZEO4:
def serialnos(self, args):
for oid in args:
if isinstance(oid, bytes):
......
......@@ -21,6 +21,7 @@ ZEO is now part of ZODB; ZODB's home on the web is
"""
def client(*args, **kw):
"""
Shortcut for :class:`ZEO.ClientStorage.ClientStorage`.
......@@ -28,6 +29,7 @@ def client(*args, **kw):
import ZEO.ClientStorage
return ZEO.ClientStorage.ClientStorage(*args, **kw)
def DB(*args, **kw):
"""
Shortcut for creating a :class:`ZODB.DB` using a ZEO :func:`~ZEO.client`.
......@@ -40,6 +42,7 @@ def DB(*args, **kw):
s.close()
raise
def connection(*args, **kw):
db = DB(*args, **kw)
try:
......@@ -48,6 +51,7 @@ def connection(*args, **kw):
db.close()
raise
def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
port=0, threaded=True, **kw):
"""Convenience function to start a server for interactive exploration
......
......@@ -16,13 +16,20 @@
import sys
import platform
from ZODB._compat import BytesIO # NOQA: F401 unused import
PY3 = sys.version_info[0] >= 3
PY32 = sys.version_info[:2] == (3, 2)
PYPY = getattr(platform, 'python_implementation', lambda: None)() == 'PyPy'
WIN = sys.platform.startswith('win')
if PY3:
from zodbpickle.pickle import Pickler, Unpickler as _Unpickler, dump, dumps, loads
from zodbpickle.pickle import dump
from zodbpickle.pickle import dumps
from zodbpickle.pickle import loads
from zodbpickle.pickle import Pickler
from zodbpickle.pickle import Unpickler as _Unpickler
class Unpickler(_Unpickler):
# Py3: Python 3 doesn't allow assignments to find_global,
# instead, find_class can be overridden
......@@ -44,24 +51,17 @@ else:
dumps = cPickle.dumps
loads = cPickle.loads
# String and Bytes IO
from ZODB._compat import BytesIO
if PY3:
import _thread as thread
import _thread as thread # NOQA: F401 unused import
if PY32:
from threading import _get_ident as get_ident
from threading import _get_ident as get_ident # NOQA: F401 unused
else:
from threading import get_ident
from threading import get_ident # NOQA: F401 unused import
else:
import thread
from thread import get_ident
import thread # NOQA: F401 unused import
from thread import get_ident # NOQA: F401 unused import
try:
from cStringIO import StringIO
except:
from io import StringIO
from cStringIO import StringIO # NOQA: F401 unused import
except ImportError:
from io import StringIO # NOQA: F401 unused import
......@@ -26,11 +26,10 @@ import six
from ZEO._compat import StringIO
logger = logging.getLogger('ZEO.tests.forker')
DEBUG = os.environ.get('ZEO_TEST_SERVER_DEBUG')
ZEO4_SERVER = os.environ.get('ZEO4_SERVER')
class ZEOConfig(object):
"""Class to generate ZEO configuration file. """
......@@ -61,8 +60,7 @@ class ZEOConfig(object):
for name in (
'invalidation_queue_size', 'invalidation_age',
'transaction_timeout', 'pid_filename', 'msgpack',
'ssl_certificate', 'ssl_key', 'client_conflict_resolution',
):
'ssl_certificate', 'ssl_key', 'client_conflict_resolution'):
v = getattr(self, name, None)
if v:
print(name.replace('_', '-'), v, file=f)
......@@ -134,7 +132,7 @@ def runner(config, qin, qout, timeout=None,
os.remove(config)
try:
qin.get(timeout=timeout) # wait for shutdown
qin.get(timeout=timeout) # wait for shutdown
except Empty:
pass
server.server.close()
......@@ -158,6 +156,7 @@ def runner(config, qin, qout, timeout=None,
ZEO.asyncio.server.best_protocol_version = old_protocol
ZEO.asyncio.server.ServerProtocol.protocols = old_protocols
def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
qin.put('stop')
try:
......@@ -180,6 +179,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
gc.collect()
def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
path='Data.fs', protocol=None, blob_dir=None,
suicide=True, debug=False,
......@@ -220,7 +220,8 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
print(zeo_conf)
# Store the config info in a temp file.
fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker', dir=os.getcwd())
fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker',
dir=os.getcwd())
with os.fdopen(fd, 'w') as fp:
fp.write(zeo_conf)
......@@ -273,10 +274,12 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG):
return stop
def whine(*message):
print(*message, file=sys.stderr)
sys.stderr.flush()
class ThreadlessQueue(object):
def __init__(self):
......
......@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface
"""
......@@ -30,9 +31,9 @@ class Protocol(asyncio.Protocol):
def __init__(self, loop, addr):
self.loop = loop
self.addr = addr
self.input = [] # Input buffer when assembling messages
self.output = [] # Output buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
self.input = [] # Input buffer when assembling messages
self.output = [] # Output buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
......@@ -41,6 +42,7 @@ class Protocol(asyncio.Protocol):
return self.name
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -50,7 +52,6 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport):
logger.info("Connected %s", self)
if sys.version_info < (3, 6):
sock = transport.get_extra_info('socket')
if sock is not None and sock.family in INET_FAMILIES:
......@@ -91,6 +92,7 @@ class Protocol(asyncio.Protocol):
got = 0
want = 4
getting_size = True
def data_received(self, data):
# Low-level input handler collects data into sized messages.
......@@ -135,7 +137,7 @@ class Protocol(asyncio.Protocol):
def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
del self.message_received # use default handler from here on
self.finish_connect(protocol_version)
def call_async(self, method, args):
......@@ -162,7 +164,7 @@ class Protocol(asyncio.Protocol):
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
if paused: # paused again. Put iter back.
output.insert(0, data)
break
......
......@@ -19,7 +19,8 @@ logger = logging.getLogger(__name__)
Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests
local_random = random.Random() # use separate generator to facilitate tests
def future_generator(func):
"""Decorates a generator that generates futures
......@@ -52,6 +53,7 @@ def future_generator(func):
return call_generator
class Protocol(base.Protocol):
"""asyncio low-level ZEO client interface
"""
......@@ -85,7 +87,7 @@ class Protocol(base.Protocol):
self.client = client
self.connect_poll = connect_poll
self.heartbeat_interval = heartbeat_interval
self.futures = {} # { message_id -> future }
self.futures = {} # { message_id -> future }
self.ssl = ssl
self.ssl_server_hostname = ssl_server_hostname
self.credentials = credentials
......@@ -132,7 +134,9 @@ class Protocol(base.Protocol):
elif future.exception() is not None:
logger.info("Connection to %r failed, %s",
self.addr, future.exception())
else: return
else:
return
# keep trying
if not self.closed:
logger.info("retry connecting %r", self.addr)
......@@ -141,7 +145,6 @@ class Protocol(base.Protocol):
self.connect,
)
def connection_made(self, transport):
super(Protocol, self).connection_made(transport)
self.heartbeat(write=False)
......@@ -190,7 +193,8 @@ class Protocol(base.Protocol):
try:
server_tid = yield self.fut(
'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False,
(self.read_only if self.read_only is not Fallback
else False),
*credentials)
except ZODB.POSException.ReadOnlyError:
if self.read_only is Fallback:
......@@ -208,11 +212,12 @@ class Protocol(base.Protocol):
self.client.registered(self, server_tid)
exception_type_type = type(Exception)
def message_received(self, data):
msgid, async_, name, args = self.decode(data)
if name == '.reply':
future = self.futures.pop(msgid)
if async_: # ZEO 5 exception
if async_: # ZEO 5 exception
class_, args = args
factory = exc_factories.get(class_)
if factory:
......@@ -237,13 +242,14 @@ class Protocol(base.Protocol):
else:
future.set_result(args)
else:
assert async_ # clients only get async calls
assert async_ # clients only get async calls
if name in self.client_methods:
getattr(self.client, name)(*args)
else:
raise AttributeError(name)
message_id = 0
def call(self, future, method, args):
self.message_id += 1
self.futures[self.message_id] = future
......@@ -262,6 +268,7 @@ class Protocol(base.Protocol):
self.futures[message_id] = future
self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid)))
@future.add_done_callback
def _(future):
try:
......@@ -271,6 +278,7 @@ class Protocol(base.Protocol):
if data:
data, start, end = data
self.client.cache.store(oid, start, end, data)
return future
# Methods called by the server.
......@@ -290,29 +298,34 @@ class Protocol(base.Protocol):
self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat)
def create_Exception(class_, args):
return exc_classes[class_](*args)
def create_ConflictError(class_, args):
exc = exc_classes[class_](
message = args['message'],
oid = args['oid'],
serials = args['serials'],
message=args['message'],
oid=args['oid'],
serials=args['serials'],
)
exc.class_name = args.get('class_name')
return exc
def create_BTreesConflictError(class_, args):
return ZODB.POSException.BTreesConflictError(
p1 = args['p1'],
p2 = args['p2'],
p3 = args['p3'],
reason = args['reason'],
p1=args['p1'],
p2=args['p2'],
p3=args['p3'],
reason=args['reason'],
)
def create_MultipleUndoErrors(class_, args):
return ZODB.POSException.MultipleUndoErrors(args['_errs'])
exc_classes = {
'builtins.KeyError': KeyError,
'builtins.TypeError': TypeError,
......@@ -340,6 +353,8 @@ exc_factories = {
}
unlogged_exceptions = (ZODB.POSException.POSKeyError,
ZODB.POSException.ConflictError)
class Client(object):
"""asyncio low-level ZEO client interface
"""
......@@ -352,8 +367,11 @@ class Client(object):
# connect.
protocol = None
ready = None # Tri-value: None=Never connected, True=connected,
# False=Disconnected
# ready can have three values:
# None=Never connected
# True=connected
# False=Disconnected
ready = None
def __init__(self, loop,
addrs, client, cache, storage_key, read_only, connect_poll,
......@@ -404,6 +422,7 @@ class Client(object):
self.is_read_only() and self.read_only is Fallback)
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -466,7 +485,7 @@ class Client(object):
self.upgrade(protocol)
self.verify(server_tid)
else:
protocol.close() # too late, we went home with another
protocol.close() # too late, we went home with another
def register_failed(self, protocol, exc):
# A protocol failed registration. That's weird. If they've all
......@@ -474,18 +493,17 @@ class Client(object):
if protocol is not self:
protocol.close()
logger.exception("Registration or cache validation failed, %s", exc)
if (self.protocol is None and not
any(not p.closed for p in self.protocols)
):
if self.protocol is None and \
not any(not p.closed for p in self.protocols):
self.loop.call_later(
self.register_failed_poll + local_random.random(),
self.try_connecting)
verify_result = None # for tests
verify_result = None # for tests
@future_generator
def verify(self, server_tid):
self.verify_invalidation_queue = [] # See comment in init :(
self.verify_invalidation_queue = [] # See comment in init :(
protocol = self.protocol
try:
......@@ -739,6 +757,7 @@ class Client(object):
else:
return protocol.read_only
class ClientRunner(object):
def set_options(self, addrs, wrapper, cache, storage_key, read_only,
......@@ -855,6 +874,7 @@ class ClientRunner(object):
timeout = self.timeout
self.wait_for_result(self.client.connected, timeout)
class ClientThread(ClientRunner):
"""Thread wrapper for client interface
......@@ -883,6 +903,7 @@ class ClientThread(ClientRunner):
raise self.exception
exception = None
def run(self):
loop = None
try:
......@@ -909,6 +930,7 @@ class ClientThread(ClientRunner):
logger.debug('Stopping client thread')
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -918,6 +940,7 @@ class ClientThread(ClientRunner):
if self.exception:
raise self.exception
class Fut(object):
"""Lightweight future that calls it's callbacks immediately rather than soon
"""
......@@ -929,6 +952,7 @@ class Fut(object):
self.cbv.append(cb)
exc = None
def set_exception(self, exc):
self.exc = exc
for cb in self.cbv:
......
......@@ -6,5 +6,5 @@ if PY3:
except ImportError:
from asyncio import new_event_loop
else:
import trollius as asyncio
from trollius import new_event_loop
import trollius as asyncio # NOQA: F401 unused import
from trollius import new_event_loop # NOQA: F401 unused import
......@@ -21,19 +21,22 @@ Python-independent format, or possibly a minimal pickle subset.
import logging
from .._compat import Unpickler, Pickler, BytesIO, PY3, PYPY
from .._compat import Unpickler, Pickler, BytesIO, PY3
from ..shortrepr import short_repr
PY2 = not PY3
logger = logging.getLogger(__name__)
def encoder(protocol, server=False):
"""Return a non-thread-safe encoder
"""
if protocol[:1] == b'M':
from msgpack import packb
default = server_default if server else None
def encode(*args):
return packb(
args, use_bin_type=True, default=default)
......@@ -49,6 +52,7 @@ def encoder(protocol, server=False):
pickler = Pickler(f, 3)
pickler.fast = 1
dump = pickler.dump
def encode(*args):
seek(0)
truncate()
......@@ -57,21 +61,26 @@ def encoder(protocol, server=False):
return encode
def encode(*args):
return encoder(b'Z')(*args)
def decoder(protocol):
if protocol[:1] == b'M':
from msgpack import unpackb
def msgpack_decode(data):
"""Decodes msg and returns its parts"""
return unpackb(data, raw=False, use_list=False)
return msgpack_decode
else:
assert protocol[:1] == b'Z'
return pickle_decode
def pickle_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
......@@ -82,11 +91,12 @@ def pickle_decode(msg):
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg))
raise
def server_decoder(protocol):
if protocol[:1] == b'M':
return decoder(protocol)
......@@ -94,6 +104,7 @@ def server_decoder(protocol):
assert protocol[:1] == b'Z'
return pickle_server_decode
def pickle_server_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
......@@ -105,22 +116,25 @@ def pickle_server_decode(msg):
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg))
raise
def server_default(obj):
if isinstance(obj, Exception):
return reduce_exception(obj)
else:
return obj
def reduce_exception(exc):
class_ = exc.__class__
class_ = "%s.%s" % (class_.__module__, class_.__name__)
return class_, exc.__dict__ or exc.args
_globals = globals()
_silly = ('__doc__',)
......@@ -131,6 +145,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__',
)
def find_global(module, name):
"""Helper for message unpickler"""
try:
......@@ -143,7 +158,8 @@ def find_global(module, name):
except AttributeError:
raise ImportError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES)
safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe:
return r
......@@ -153,6 +169,7 @@ def find_global(module, name):
raise ImportError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name):
"""Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES:
......
......@@ -72,6 +72,7 @@ import logging
logger = logging.getLogger(__name__)
class Acceptor(asyncore.dispatcher):
"""A server that accepts incoming RPC connections
......@@ -115,13 +116,13 @@ class Acceptor(asyncore.dispatcher):
for i in range(25):
try:
self.bind(addr)
except Exception as exc:
except Exception:
logger.info("bind on %s failed %s waiting", addr, i)
if i == 24:
raise
else:
time.sleep(5)
except:
except: # NOQA: E722 bare except
logger.exception('binding')
raise
else:
......@@ -146,7 +147,6 @@ class Acceptor(asyncore.dispatcher):
logger.info("accepted failed: %s", msg)
return
# We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below,
......@@ -159,7 +159,7 @@ class Acceptor(asyncore.dispatcher):
# closed, but I don't see a way to do that. :(
# Drop flow-info from IPv6 addresses
if addr: # Sometimes None on Mac. See above.
if addr: # Sometimes None on Mac. See above.
addr = addr[:2]
try:
......@@ -172,23 +172,25 @@ class Acceptor(asyncore.dispatcher):
protocol.stop = loop.stop
if self.ssl_context is None:
cr = loop.create_connection((lambda : protocol), sock=sock)
cr = loop.create_connection((lambda: protocol), sock=sock)
else:
if hasattr(loop, 'connect_accepted_socket'):
cr = loop.connect_accepted_socket(
(lambda : protocol), sock, ssl=self.ssl_context)
(lambda: protocol), sock, ssl=self.ssl_context)
else:
#######################################################
# XXX See http://bugs.python.org/issue27392 :(
_make_ssl_transport = loop._make_ssl_transport
def make_ssl_transport(*a, **kw):
kw['server_side'] = True
return _make_ssl_transport(*a, **kw)
loop._make_ssl_transport = make_ssl_transport
#
#######################################################
cr = loop.create_connection(
(lambda : protocol), sock=sock,
(lambda: protocol), sock=sock,
ssl=self.ssl_context,
server_hostname=''
)
......@@ -212,11 +214,12 @@ class Acceptor(asyncore.dispatcher):
asyncore.loop(map=self.__socket_map, timeout=timeout)
except Exception:
if not self.__closed:
raise # Unexpected exc
raise # Unexpected exc
logger.debug('acceptor %s loop stopped', self.addr)
__closed = False
def close(self):
if not self.__closed:
self.__closed = True
......
import json
import logging
import os
import random
import threading
import ZODB.POSException
logger = logging.getLogger(__name__)
from ..shortrepr import short_repr
from . import base
from .compat import asyncio, new_event_loop
from .marshal import server_decoder, encoder, reduce_exception
logger = logging.getLogger(__name__)
class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface
"""
......@@ -39,6 +40,7 @@ class ServerProtocol(base.Protocol):
)
closed = False
def close(self):
logger.debug("Closing server protocol")
if not self.closed:
......@@ -46,7 +48,8 @@ class ServerProtocol(base.Protocol):
if self.transport is not None:
self.transport.close()
connected = None # for tests
connected = None # for tests
def connection_made(self, transport):
self.connected = True
super(ServerProtocol, self).connection_made(transport)
......@@ -60,7 +63,7 @@ class ServerProtocol(base.Protocol):
self.stop()
def stop(self):
pass # Might be replaced when running a thread per client
pass # Might be replaced when running a thread per client
def finish_connect(self, protocol_version):
if protocol_version == b'ruok':
......@@ -95,7 +98,7 @@ class ServerProtocol(base.Protocol):
return
if message_id == -1:
return # keep-alive
return # keep-alive
if name not in self.methods:
logger.error('Invalid method, %r', name)
......@@ -109,7 +112,7 @@ class ServerProtocol(base.Protocol):
"%s`%r` raised exception:",
'async ' if async_ else '', name)
if async_:
return self.close() # No way to recover/cry for help
return self.close() # No way to recover/cry for help
else:
return self.send_error(message_id, exc)
......@@ -147,16 +150,19 @@ class ServerProtocol(base.Protocol):
def async_threadsafe(self, method, *args):
self.call_soon_threadsafe(self.call_async, method, args)
best_protocol_version = os.environ.get(
'ZEO_SERVER_PROTOCOL',
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket)
cr = loop.create_connection((lambda: protocol), sock=socket)
asyncio.ensure_future(cr, loop=loop)
class Delay(object):
"""Used to delay response to client for synchronous calls.
......@@ -192,6 +198,7 @@ class Delay(object):
def __reduce__(self):
raise TypeError("Can't pickle delays.")
class Result(Delay):
def __init__(self, *args):
......@@ -202,6 +209,7 @@ class Result(Delay):
protocol.send_reply(msgid, reply)
callback()
class MTDelay(Delay):
def __init__(self):
......@@ -266,6 +274,7 @@ class Acceptor(object):
self.event_loop.close()
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -277,6 +286,7 @@ class Acceptor(object):
self.server.close()
f = asyncio.ensure_future(self.server.wait_closed(), loop=loop)
@f.add_done_callback
def server_closed(f):
# stop the loop when the server closes:
......
......@@ -11,7 +11,6 @@ except NameError:
class ConnectionRefusedError(OSError):
pass
import pprint
class Loop(object):
......@@ -19,7 +18,7 @@ class Loop(object):
def __init__(self, addrs=(), debug=True):
self.addrs = addrs
self.get_debug = lambda : debug
self.get_debug = lambda: debug
self.connecting = {}
self.later = []
self.exceptions = []
......@@ -31,7 +30,7 @@ class Loop(object):
func(*args)
def _connect(self, future, protocol_factory):
self.protocol = protocol = protocol_factory()
self.protocol = protocol = protocol_factory()
self.transport = transport = Transport(protocol)
protocol.connection_made(transport)
future.set_result((transport, protocol))
......@@ -45,10 +44,8 @@ class Loop(object):
if not future.cancelled():
future.set_exception(ConnectionRefusedError())
def create_connection(
self, protocol_factory, host=None, port=None, sock=None,
ssl=None, server_hostname=None
):
def create_connection(self, protocol_factory, host=None, port=None,
sock=None, ssl=None, server_hostname=None):
future = asyncio.Future(loop=self)
if sock is None:
addr = host, port
......@@ -83,13 +80,16 @@ class Loop(object):
self.exceptions.append(context)
closed = False
def close(self):
self.closed = True
stopped = False
def stop(self):
self.stopped = True
class Handle(object):
cancelled = False
......@@ -97,6 +97,7 @@ class Handle(object):
def cancel(self):
self.cancelled = True
class Transport(object):
capacity = 1 << 64
......@@ -136,12 +137,14 @@ class Transport(object):
self.protocol.resume_writing()
closed = False
def close(self):
self.closed = True
def get_extra_info(self, name):
return self.extra[name]
class AsyncRPC(object):
"""Adapt an asyncio API to an RPC to help hysterical tests
"""
......@@ -151,6 +154,7 @@ class AsyncRPC(object):
def __getattr__(self, name):
return lambda *a, **kw: self.api.call(name, *a, **kw)
class ClientRunner(object):
def __init__(self, addr, client, cache, storage, read_only, timeout,
......
......@@ -2,17 +2,17 @@ from .._compat import PY3
if PY3:
import asyncio
def to_byte(i):
return bytes([i])
else:
import trollius as asyncio
import trollius as asyncio # NOQA: F401 unused import
def to_byte(b):
return b
from zope.testing import setupstack
from concurrent.futures import Future
import mock
from ZODB.POSException import ReadOnlyError
from ZODB.utils import maxtid, RLock
import collections
......@@ -28,6 +28,7 @@ from .client import ClientRunner, Fallback
from .server import new_connection, best_protocol_version
from .marshal import encoder, decoder
class Base(object):
enc = b'Z'
......@@ -56,6 +57,7 @@ class Base(object):
return self.unsized(data, True)
target = None
def send(self, method, *args, **kw):
target = kw.pop('target', self.target)
called = kw.pop('called', True)
......@@ -77,6 +79,7 @@ class Base(object):
def pop(self, count=None, parse=True):
return self.unsized(self.loop.transport.pop(count), parse)
class ClientTests(Base, setupstack.TestCase, ClientRunner):
maxDiff = None
......@@ -204,7 +207,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, maxtid)
# The data wasn't in the cache, so we made a server call:
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)))
self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
# Note load_before uses the oid as the message id.
self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
......@@ -224,7 +231,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# the requests will be collapsed:
loaded2 = self.load_before(b'1'*8, maxtid)
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)))
self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
......@@ -238,7 +249,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(self.pop(), ((b'1'*8, b'_'*8), False, 'loadBefore', (b'1'*8, b'_'*8)))
self.assertEqual(self.pop(),
((b'1'*8, b'_'*8),
False,
'loadBefore',
(b'1'*8, b'_'*8)))
self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
......@@ -247,6 +262,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# iteratable to tpc_finish_threadsafe.
tids = []
def finished_cb(tid):
tids.append(tid)
......@@ -349,7 +365,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, )))
self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
self.respond(3, (b'e'*8, [b'4'*8]))
self.assertEqual(self.pop(), (4, False, 'get_info', ()))
......@@ -361,7 +378,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# And the cache has been updated:
self.assertEqual(cache.load(b'2'*8),
('2 data', b'a'*8)) # unchanged
('2 data', b'a'*8)) # unchanged
self.assertEqual(cache.load(b'4'*8), None)
# Because we were able to update the cache, we didn't have to
......@@ -384,7 +401,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, )))
self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
# We respond None, indicating that we're too far out of date:
self.respond(3, None)
......@@ -451,10 +469,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.respond(2, 'a'*8)
self.pop()
self.assertFalse(client.connected.done() or transport.data)
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
self.assertTrue(8 < delay < 10)
self.assertEqual(len(loop.later), 1) # first in later is heartbeat
func(*args) # connect again
self.assertEqual(len(loop.later), 1) # first in later is heartbeat
func(*args) # connect again
self.assertFalse(protocol is loop.protocol)
self.assertFalse(transport is loop.transport)
protocol = loop.protocol
......@@ -512,7 +530,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We connect the second address:
loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), self.enc + b'3101')
self.assertEqual(self.unsized(loop.transport.pop(2)),
self.enc + b'3101')
self.assertEqual(self.parse(loop.transport.pop()),
(1, False, 'register', ('TEST', False)))
self.assertTrue(self.is_read_only())
......@@ -613,7 +632,6 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.data_received(sized(self.enc + b'200'))
self.assertTrue(isinstance(error.call_args[0][1], ProtocolError))
def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True)
......@@ -641,7 +659,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# that caused it to fail badly if errors were raised while
# handling data.
wrapper, cache, loop, client, protocol, transport =self.start(
wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True)
wrapper.receiveBlobStart.side_effect = ValueError('test')
......@@ -694,10 +712,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.connection_lost(None)
self.assertTrue(handle.cancelled)
class MsgpackClientTests(ClientTests):
enc = b'M'
seq_type = tuple
class MemoryCache(object):
def __init__(self):
......@@ -709,6 +729,7 @@ class MemoryCache(object):
clear = __init__
closed = False
def close(self):
self.closed = True
......@@ -771,6 +792,7 @@ class ServerTests(Base, setupstack.TestCase):
message_id = 0
target = None
def call(self, meth, *args, **kw):
if kw:
expect = kw.pop('expect', self)
......@@ -835,10 +857,12 @@ class ServerTests(Base, setupstack.TestCase):
self.call('foo', target=None)
self.assertTrue(protocol.loop.transport.closed)
class MsgpackServerTests(ServerTests):
enc = b'M'
seq_type = tuple
def server_protocol(msgpack,
zeo_storage=None,
protocol_version=None,
......@@ -847,18 +871,17 @@ def server_protocol(msgpack,
if zeo_storage is None:
zeo_storage = mock.Mock()
loop = Loop()
sock = () # anything not None
sock = () # anything not None
new_connection(loop, addr, sock, zeo_storage, msgpack)
if protocol_version:
loop.protocol.data_received(sized(protocol_version))
return loop.protocol
def response(*data):
return sized(self.encode(*data))
def sized(message):
return struct.pack(">I", len(message)) + message
class Logging(object):
def __init__(self, level=logging.ERROR):
......@@ -885,9 +908,11 @@ class ProtocolTests(setupstack.TestCase):
loop = self.loop
protocol, transport = loop.protocol, loop.transport
transport.capacity = 1 # single message
def it(tag):
yield tag
yield tag
protocol._writeit(it(b"0"))
protocol._writeit(it(b"1"))
for b in b"0011":
......
......@@ -86,7 +86,7 @@ ZEC_HEADER_SIZE = 12
# need to write a free block that is almost twice as big. If we die
# in the middle of a store, then we need to split the large free records
# while opening.
max_block_size = (1<<31) - 1
max_block_size = (1 << 31) - 1
# After the header, the file contains a contiguous sequence of blocks. All
......@@ -132,12 +132,13 @@ allocated_record_overhead = 43
# Under PyPy, the available dict specializations perform significantly
# better (faster) than the pure-Python BTree implementation. They may
# use less memory too. And we don't require any of the special BTree features...
# use less memory too. And we don't require any of the special BTree features.
_current_index_type = ZODB.fsIndex.fsIndex if not PYPY else dict
_noncurrent_index_type = BTrees.LOBTree.LOBTree if not PYPY else dict
# ...except at this leaf level
_noncurrent_bucket_type = BTrees.LLBTree.LLBucket
class ClientCache(object):
"""A simple in-memory cache."""
......@@ -193,7 +194,7 @@ class ClientCache(object):
if path:
self._lock_file = zc.lockfile.LockFile(path + '.lock')
if not os.path.exists(path):
# Create a small empty file. We'll make it bigger in _initfile.
# Create a small empty file. We'll make it bigger in _initfile.
self.f = open(path, 'wb+')
self.f.write(magic+z64)
logger.info("created persistent cache file %r", path)
......@@ -209,10 +210,10 @@ class ClientCache(object):
try:
self._initfile(fsize)
except:
except: # NOQA: E722 bare except
self.f.close()
if not path:
raise # unrecoverable temp file error :(
raise # unrecoverable temp file error :(
badpath = path+'.bad'
if os.path.exists(badpath):
logger.critical(
......@@ -271,7 +272,7 @@ class ClientCache(object):
self.current = _current_index_type()
self.noncurrent = _noncurrent_index_type()
l = 0
length = 0
last = ofs = ZEC_HEADER_SIZE
first_free_offset = 0
current = self.current
......@@ -290,7 +291,7 @@ class ClientCache(object):
assert start_tid < end_tid, (ofs, f.tell())
self._set_noncurrent(oid, start_tid, ofs)
assert lver == 0, "Versions aren't supported"
l += 1
length += 1
else:
# free block
if first_free_offset == 0:
......@@ -331,7 +332,7 @@ class ClientCache(object):
break
if fsize < maxsize:
assert ofs==fsize
assert ofs == fsize
# Make sure the OS really saves enough bytes for the file.
seek(self.maxsize - 1)
write(b'x')
......@@ -349,7 +350,7 @@ class ClientCache(object):
assert last and (status in b' f1234')
first_free_offset = last
else:
assert ofs==maxsize
assert ofs == maxsize
if maxsize < fsize:
seek(maxsize)
f.truncate()
......@@ -357,7 +358,7 @@ class ClientCache(object):
# We use the first_free_offset because it is most likely the
# place where we last wrote.
self.currentofs = first_free_offset or ZEC_HEADER_SIZE
self._len = l
self._len = length
def _set_noncurrent(self, oid, tid, ofs):
noncurrent_for_oid = self.noncurrent.get(u64(oid))
......@@ -375,7 +376,6 @@ class ClientCache(object):
except KeyError:
logger.error("Couldn't find non-current %r", (oid, tid))
def clearStats(self):
self._n_adds = self._n_added_bytes = 0
self._n_evicts = self._n_evicted_bytes = 0
......@@ -384,8 +384,7 @@ class ClientCache(object):
def getStats(self):
return (self._n_adds, self._n_added_bytes,
self._n_evicts, self._n_evicted_bytes,
self._n_accesses
)
self._n_accesses)
##
# The number of objects currently in the cache.
......@@ -403,7 +402,7 @@ class ClientCache(object):
sync(f)
f.close()
if hasattr(self,'_lock_file'):
if hasattr(self, '_lock_file'):
self._lock_file.close()
##
......@@ -517,9 +516,9 @@ class ClientCache(object):
if ofsofs < 0:
ofsofs += self.maxsize
if (ofsofs > self.rearrange and
self.maxsize > 10*len(data) and
size > 4):
if ofsofs > self.rearrange and \
self.maxsize > 10*len(data) and \
size > 4:
# The record is far back and might get evicted, but it's
# valuable, so move it forward.
......@@ -619,8 +618,8 @@ class ClientCache(object):
raise ValueError("already have current data for oid")
else:
noncurrent_for_oid = self.noncurrent.get(u64(oid))
if noncurrent_for_oid and (
u64(start_tid) in noncurrent_for_oid):
if noncurrent_for_oid and \
u64(start_tid) in noncurrent_for_oid:
return
size = allocated_record_overhead + len(data)
......@@ -692,7 +691,6 @@ class ClientCache(object):
self.currentofs += size
##
# If `tid` is None,
# forget all knowledge of `oid`. (`tid` can be None only for
......@@ -765,8 +763,7 @@ class ClientCache(object):
for oid, tid in L:
print(oid_repr(oid), oid_repr(tid))
print("dll contents")
L = list(self)
L.sort(lambda x, y: cmp(x.key, y.key))
L = sorted(list(self), key=lambda x: x.key)
for x in L:
end_tid = x.end_tid or z64
print(oid_repr(x.key[0]), oid_repr(x.key[1]), oid_repr(end_tid))
......@@ -779,6 +776,7 @@ class ClientCache(object):
# tracing by setting self._trace to a dummy function, and set
# self._tracefile to None.
_tracefile = None
def _trace(self, *a, **kw):
pass
......@@ -797,6 +795,7 @@ class ClientCache(object):
return
now = time.time
def _trace(code, oid=b"", tid=z64, end_tid=z64, dlen=0):
# The code argument is two hex digits; bits 0 and 7 must be zero.
# The first hex digit shows the operation, the second the outcome.
......@@ -812,7 +811,7 @@ class ClientCache(object):
pack(">iiH8s8s",
int(now()), encoded, len(oid), tid, end_tid) + oid,
)
except:
except: # NOQA: E722 bare except
print(repr(tid), repr(end_tid))
raise
......@@ -826,10 +825,7 @@ class ClientCache(object):
self._tracefile.close()
del self._tracefile
def sync(f):
f.flush()
if hasattr(os, 'fsync'):
def sync(f):
f.flush()
os.fsync(f.fileno())
os.fsync(f.fileno())
......@@ -14,6 +14,7 @@
import zope.interface
class StaleCache(object):
"""A ZEO cache is stale and requires verification.
"""
......@@ -21,6 +22,7 @@ class StaleCache(object):
def __init__(self, storage):
self.storage = storage
class IClientCache(zope.interface.Interface):
"""Client cache interface.
......@@ -86,6 +88,7 @@ class IClientCache(zope.interface.Interface):
"""Clear/empty the cache
"""
class IServeable(zope.interface.Interface):
"""Interface provided by storages that can be served by ZEO
"""
......
......@@ -30,10 +30,7 @@ from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
import asyncore
import socket
import time
import logging
zeo_version = 'unknown'
try:
......@@ -47,6 +44,7 @@ else:
if zeo_dist is not None:
zeo_version = zeo_dist.version
class StorageStats(object):
"""Per-storage usage statistics."""
......
......@@ -33,6 +33,7 @@ diff_names = 'aborts commits conflicts conflicts_resolved loads stores'.split()
per_times = dict(seconds=1.0, minutes=60.0, hours=3600.0, days=86400.0)
def new_metric(metrics, storage_id, name, value):
if storage_id == '1':
label = name
......@@ -43,6 +44,7 @@ def new_metric(metrics, storage_id, name, value):
label = "%s:%s" % (storage_id, name)
metrics.append("%s=%s" % (label, value))
def result(messages, metrics=(), status=None):
if metrics:
messages[0] += '|' + metrics[0]
......@@ -51,12 +53,15 @@ def result(messages, metrics=(), status=None):
print('\n'.join(messages))
return status
def error(message):
return result((message, ), (), 2)
def warn(message):
return result((message, ), (), 1)
def check(addr, output_metrics, status, per):
m = re.match(r'\[(\S+)\]:(\d+)$', addr)
if m:
......@@ -75,7 +80,7 @@ def check(addr, output_metrics, status, per):
return error("Can't connect %s" % err)
s.sendall(b'\x00\x00\x00\x04ruok')
proto = s.recv(struct.unpack(">I", s.recv(4))[0])
proto = s.recv(struct.unpack(">I", s.recv(4))[0]) # NOQA: F841 unused
datas = s.recv(struct.unpack(">I", s.recv(4))[0])
s.close()
data = json.loads(datas.decode("ascii"))
......@@ -94,8 +99,8 @@ def check(addr, output_metrics, status, per):
now = time.time()
if os.path.exists(status):
dt = now - os.stat(status).st_mtime
if dt > 0: # sanity :)
with open(status) as f: # Read previous
if dt > 0: # sanity :)
with open(status) as f: # Read previous
old = json.loads(f.read())
dt /= per_times[per]
for storage_id, sdata in sorted(data.items()):
......@@ -105,7 +110,7 @@ def check(addr, output_metrics, status, per):
for name in diff_names:
v = (sdata[name] - sold[name]) / dt
new_metric(metrics, storage_id, name, v)
with open(status, 'w') as f: # save current
with open(status, 'w') as f: # save current
f.write(json.dumps(data))
for storage_id, sdata in sorted(data.items()):
......@@ -116,6 +121,7 @@ def check(addr, output_metrics, status, per):
messages.append('OK')
return result(messages, metrics, level or None)
def main(args=None):
if args is None:
args = sys.argv[1:]
......@@ -139,5 +145,6 @@ def main(args=None):
return check(
addr, options.output_metrics, options.status_path, options.time_units)
if __name__ == '__main__':
main()
......@@ -46,21 +46,25 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address
def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown.
import asyncore
asyncore.close_all()
class ZEOOptionsMixin(object):
storages = None
......@@ -69,14 +73,17 @@ class ZEOOptionsMixin(object):
self.family, self.address = parse_binding_address(arg)
def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*!
from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object):
def __init__(self, name, path):
self._name = name
self.path = path
self.stop = None
def getSectionName(self):
return self._name
if not self.storages:
self.storages = []
name = str(1 + len(self.storages))
......@@ -84,6 +91,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf)
testing_exit_immediately = False
def handle_test(self, *args):
self.testing_exit_immediately = True
......@@ -108,6 +116,7 @@ class ZEOOptionsMixin(object):
None, 'pid-file=')
self.add("ssl", "zeo.ssl")
class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__
......@@ -164,15 +173,15 @@ class ZEOServer(object):
root = logging.getLogger()
root.setLevel(logging.INFO)
fmt = logging.Formatter(
"------\n%(asctime)s %(levelname)s %(name)s %(message)s",
"%Y-%m-%dT%H:%M:%S")
"------\n%(asctime)s %(levelname)s %(name)s %(message)s",
"%Y-%m-%dT%H:%M:%S")
handler = logging.StreamHandler()
handler.setFormatter(fmt)
root.addHandler(handler)
def check_socket(self):
if (isinstance(self.options.address, tuple) and
self.options.address[1] is None):
if isinstance(self.options.address, tuple) and \
self.options.address[1] is None:
self.options.address = self.options.address[0], 0
return
......@@ -217,7 +226,7 @@ class ZEOServer(object):
self.setup_win32_signals()
return
if hasattr(signal, 'SIGXFSZ'):
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
init_signames()
for sig, name in signames.items():
method = getattr(self, "handle_" + name.lower(), None)
......@@ -237,12 +246,12 @@ class ZEOServer(object):
"will *not* be installed.")
return
SignalHandler = Signals.Signals.SignalHandler
if SignalHandler is not None: # may be None if no pywin32.
if SignalHandler is not None: # may be None if no pywin32.
SignalHandler.registerHandler(signal.SIGTERM,
windows_shutdown_handler)
SignalHandler.registerHandler(signal.SIGINT,
windows_shutdown_handler)
SIGUSR2 = 12 # not in signal module on Windows.
SIGUSR2 = 12 # not in signal module on Windows.
SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2)
def create_server(self):
......@@ -278,7 +287,8 @@ class ZEOServer(object):
def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"):
if self.options.config_logger is None or \
os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
level=logging.WARNING)
return
......@@ -286,13 +296,13 @@ class ZEOServer(object):
loggers = [self.options.config_logger]
if os.name == "posix":
for l in loggers:
l.reopen()
for logger in loggers:
logger.reopen()
log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers:
for f in l.handler_factories:
handler = f()
else: # nt - same rotation code as in Zope's Signals/Signals.py
for logger in loggers:
for factory in logger.handler_factories:
handler = factory()
if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate()
log("Log files rotation complete", level=logging.INFO)
......@@ -350,21 +360,21 @@ def create_server(storages, options):
return StorageServer(
options.address,
storages,
read_only = options.read_only,
read_only=options.read_only,
client_conflict_resolution=options.client_conflict_resolution,
msgpack=(options.msgpack if isinstance(options.msgpack, bool)
else os.environ.get('ZEO_MSGPACK')),
invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout,
ssl = options.ssl,
)
invalidation_queue_size=options.invalidation_queue_size,
invalidation_age=options.invalidation_age,
transaction_timeout=options.transaction_timeout,
ssl=options.ssl)
# Signal names
signames = None
def signame(sig):
"""Return a symbolic name for a signal.
......@@ -376,6 +386,7 @@ def signame(sig):
init_signames()
return signames.get(sig) or "signal %d" % sig
def init_signames():
global signames
signames = {}
......@@ -395,11 +406,13 @@ def main(args=None):
s = ZEOServer(options)
s.main()
def run(args):
options = ZEOOptions()
options.realize(args)
s = ZEOServer(options)
s.run()
if __name__ == "__main__":
main()
......@@ -27,6 +27,7 @@ from __future__ import print_function, absolute_import
import bisect
import struct
import random
import re
import sys
import ZEO.cache
......@@ -34,6 +35,7 @@ import argparse
from ZODB.utils import z64
from ..cache import ZEC_HEADER_SIZE
from .cache_stats import add_interval_argument
from .cache_stats import add_tracefile_argument
......@@ -46,7 +48,7 @@ def main(args=None):
if args is None:
args = sys.argv[1:]
# Parse options.
MB = 1<<20
MB = 1 << 20
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--size", "-s",
default=20*MB, dest="cachelimit",
......@@ -115,6 +117,7 @@ def main(args=None):
interval_sim.report()
sim.finish()
class Simulation(object):
"""Base class for simulations.
......@@ -270,7 +273,6 @@ class CircularCacheEntry(object):
self.end_tid = end_tid
self.offset = offset
from ZEO.cache import ZEC_HEADER_SIZE
class CircularCacheSimulation(Simulation):
"""Simulate the ZEO 3.0 cache."""
......@@ -285,8 +287,6 @@ class CircularCacheSimulation(Simulation):
evicts = 0
def __init__(self, cachelimit, rearrange):
from ZEO import cache
Simulation.__init__(self, cachelimit, rearrange)
self.total_evicts = 0 # number of cache evictions
......@@ -296,7 +296,7 @@ class CircularCacheSimulation(Simulation):
# Map offset in file to (size, CircularCacheEntry) pair, or to
# (size, None) if the offset starts a free block.
self.filemap = {ZEC_HEADER_SIZE: (self.cachelimit - ZEC_HEADER_SIZE,
None)}
None)}
# Map key to CircularCacheEntry. A key is an (oid, tid) pair.
self.key2entry = {}
......@@ -322,10 +322,11 @@ class CircularCacheSimulation(Simulation):
self.evicted_hit = self.evicted_miss = 0
evicted_hit = evicted_miss = 0
def load(self, oid, size, tid, code):
if (code == 0x20) or (code == 0x22):
# Trying to load current revision.
if oid in self.current: # else it's a cache miss
if oid in self.current: # else it's a cache miss
self.hits += 1
self.total_hits += 1
......@@ -433,7 +434,8 @@ class CircularCacheSimulation(Simulation):
# Storing current revision.
if oid in self.current: # we already have it in cache
if evhit:
import pdb; pdb.set_trace()
import pdb
pdb.set_trace()
raise ValueError('WTF')
return
self.current[oid] = start_tid
......@@ -442,7 +444,8 @@ class CircularCacheSimulation(Simulation):
self.add(oid, size, start_tid)
return
if evhit:
import pdb; pdb.set_trace()
import pdb
pdb.set_trace()
raise ValueError('WTF')
# Storing non-current revision.
L = self.noncurrent.setdefault(oid, [])
......@@ -514,7 +517,7 @@ class CircularCacheSimulation(Simulation):
self.inuse = round(100.0 * used / total, 1)
self.total_inuse = self.inuse
Simulation.report(self)
#print self.evicted_hit, self.evicted_miss
# print self.evicted_hit, self.evicted_miss
def check(self):
oidcount = 0
......@@ -538,16 +541,18 @@ class CircularCacheSimulation(Simulation):
def roundup(size):
k = MINSIZE
k = MINSIZE # NOQA: F821 undefined name
while k < size:
k += k
return k
def hitrate(loads, hits):
if loads < 1:
return 'n/a'
return "%5.1f%%" % (100.0 * hits / loads)
def duration(secs):
mm, ss = divmod(secs, 60)
hh, mm = divmod(mm, 60)
......@@ -557,7 +562,10 @@ def duration(secs):
return "%d:%02d" % (mm, ss)
return "%d" % ss
nre = re.compile('([=-]?)(\d+)([.]\d*)?').match
nre = re.compile(r'([=-]?)(\d+)([.]\d*)?').match
def addcommas(n):
sign, s, d = nre(str(n)).group(1, 2, 3)
if d == '.0':
......@@ -571,11 +579,11 @@ def addcommas(n):
return (sign or '') + result + (d or '')
import random
def maybe(f, p=0.5):
if random.random() < p:
f()
if __name__ == "__main__":
sys.exit(main())
......@@ -55,6 +55,7 @@ import gzip
from time import ctime
import six
def add_interval_argument(parser):
def _interval(a):
interval = int(60 * float(a))
......@@ -63,9 +64,11 @@ def add_interval_argument(parser):
elif interval > 3600:
interval = 3600
return interval
parser.add_argument("--interval", "-i",
default=15*60, type=_interval,
help="summarizing interval in minutes (default 15; max 60)")
parser.add_argument(
"--interval", "-i",
default=15*60, type=_interval,
help="summarizing interval in minutes (default 15; max 60)")
def add_tracefile_argument(parser):
......@@ -82,15 +85,17 @@ def add_tracefile_argument(parser):
parser.add_argument("tracefile", type=GzipFileType(),
help="The trace to read; may be gzipped")
def main(args=None):
if args is None:
args = sys.argv[1:]
# Parse options
parser = argparse.ArgumentParser(description="Trace file statistics analyzer",
# Our -h, short for --load-histogram
# conflicts with default for help, so we handle
# manually.
add_help=False)
parser = argparse.ArgumentParser(
description="Trace file statistics analyzer",
# Our -h, short for --load-histogram
# conflicts with default for help, so we handle
# manually.
add_help=False)
verbose_group = parser.add_mutually_exclusive_group()
verbose_group.add_argument('--verbose', '-v',
default=False, action='store_true',
......@@ -99,18 +104,22 @@ def main(args=None):
default=False, action='store_true',
help="Reduce output; don't print summaries")
parser.add_argument("--sizes", '-s',
default=False, action="store_true", dest="print_size_histogram",
default=False, action="store_true",
dest="print_size_histogram",
help="print histogram of object sizes")
parser.add_argument("--no-stats", '-S',
default=True, action="store_false", dest="dostats",
help="don't print statistics")
parser.add_argument("--load-histogram", "-h",
default=False, action="store_true", dest="print_histogram",
default=False, action="store_true",
dest="print_histogram",
help="print histogram of object load frequencies")
parser.add_argument("--check", "-X",
default=False, action="store_true", dest="heuristic",
help=" enable heuristic checking for misaligned records: oids > 2**32"
" will be rejected; this requires the tracefile to be seekable")
help=" enable heuristic checking for misaligned "
"records: oids > 2**32"
" will be rejected; this requires the tracefile "
"to be seekable")
add_interval_argument(parser)
add_tracefile_argument(parser)
......@@ -123,20 +132,20 @@ def main(args=None):
f = options.tracefile
rt0 = time.time()
bycode = {} # map code to count of occurrences
byinterval = {} # map code to count in current interval
records = 0 # number of trace records read
versions = 0 # number of trace records with versions
datarecords = 0 # number of records with dlen set
datasize = 0 # sum of dlen across records with dlen set
oids = {} # map oid to number of times it was loaded
bysize = {} # map data size to number of loads
bysizew = {} # map data size to number of writes
bycode = {} # map code to count of occurrences
byinterval = {} # map code to count in current interval
records = 0 # number of trace records read
versions = 0 # number of trace records with versions
datarecords = 0 # number of records with dlen set
datasize = 0 # sum of dlen across records with dlen set
oids = {} # map oid to number of times it was loaded
bysize = {} # map data size to number of loads
bysizew = {} # map data size to number of writes
total_loads = 0
t0 = None # first timestamp seen
te = None # most recent timestamp seen
h0 = None # timestamp at start of current interval
he = None # timestamp at end of current interval
t0 = None # first timestamp seen
te = None # most recent timestamp seen
h0 = None # timestamp at start of current interval
he = None # timestamp at end of current interval
thisinterval = None # generally te//interval
f_read = f.read
unpack = struct.unpack
......@@ -144,7 +153,8 @@ def main(args=None):
FMT_SIZE = struct.calcsize(FMT)
assert FMT_SIZE == 26
# Read file, gathering statistics, and printing each record if verbose.
print(' '*16, "%7s %7s %7s %7s" % ('loads', 'hits', 'inv(h)', 'writes'), end=' ')
print(' '*16, "%7s %7s %7s %7s" % (
'loads', 'hits', 'inv(h)', 'writes'), end=' ')
print('hitrate')
try:
while 1:
......@@ -187,10 +197,10 @@ def main(args=None):
bycode[code] = bycode.get(code, 0) + 1
byinterval[code] = byinterval.get(code, 0) + 1
if dlen:
if code & 0x70 == 0x20: # All loads
if code & 0x70 == 0x20: # All loads
bysize[dlen] = d = bysize.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1
elif code & 0x70 == 0x50: # All stores
elif code & 0x70 == 0x50: # All stores
bysizew[dlen] = d = bysizew.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1
if options.verbose:
......@@ -205,7 +215,7 @@ def main(args=None):
if code & 0x70 == 0x20:
oids[oid] = oids.get(oid, 0) + 1
total_loads += 1
elif code == 0x00: # restart
elif code == 0x00: # restart
if not options.quiet:
dumpbyinterval(byinterval, h0, he)
byinterval = {}
......@@ -279,6 +289,7 @@ def main(args=None):
dumpbysize(bysizew, "written", "writes")
dumpbysize(bysize, "loaded", "loads")
def dumpbysize(bysize, how, how2):
print()
print("Unique sizes %s: %s" % (how, addcommas(len(bysize))))
......@@ -292,6 +303,7 @@ def dumpbysize(bysize, how, how2):
len(bysize.get(size, "")),
loads))
def dumpbyinterval(byinterval, h0, he):
loads = hits = invals = writes = 0
for code in byinterval:
......@@ -301,7 +313,7 @@ def dumpbyinterval(byinterval, h0, he):
if code in (0x22, 0x26):
hits += n
elif code & 0x40:
writes += byinterval[code]
writes += byinterval[code]
elif code & 0x10:
if code != 0x10:
invals += byinterval[code]
......@@ -315,6 +327,7 @@ def dumpbyinterval(byinterval, h0, he):
ctime(h0)[4:-8], ctime(he)[14:-8],
loads, hits, invals, writes, hr))
def hitrate(bycode):
loads = hits = 0
for code in bycode:
......@@ -328,6 +341,7 @@ def hitrate(bycode):
else:
return 0.0
def histogram(d):
bins = {}
for v in six.itervalues(d):
......@@ -335,15 +349,18 @@ def histogram(d):
L = sorted(bins.items())
return L
def U64(s):
return struct.unpack(">Q", s)[0]
def oid_repr(oid):
if isinstance(oid, six.binary_type) and len(oid) == 8:
return '%16x' % U64(oid)
else:
return repr(oid)
def addcommas(n):
sign, s = '', str(n)
if s[0] == '-':
......@@ -354,6 +371,7 @@ def addcommas(n):
i -= 3
return sign + s
explain = {
# The first hex digit shows the operation, the second the outcome.
# If the second digit is in "02468" then it is a 'miss'.
......
......@@ -3,7 +3,7 @@
"""Parse the BLATHER logging generated by ZEO2.
An example of the log format is:
2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514)
2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514) # NOQA: E501 line too long
"""
from __future__ import print_function
from __future__ import print_function
......@@ -14,7 +14,8 @@ from __future__ import print_function
import re
import time
rx_time = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
rx_time = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None."""
......@@ -26,11 +27,14 @@ def parse_time(line):
time_l = [int(elt) for elt in time_.split(':')]
return int(time.mktime(date_l + time_l + [0, 0, 0]))
rx_meth = re.compile("zrpc:\d+ calling (\w+)\((.*)")
rx_meth = re.compile(r"zrpc:\d+ calling (\w+)\((.*)")
def parse_method(line):
pass
def parse_line(line):
"""Parse a log entry and return time, method info, and client."""
t = parse_time(line)
......@@ -47,6 +51,7 @@ def parse_line(line):
m = meth_name, tuple(meth_args)
return t, m
class TStats(object):
counter = 1
......@@ -61,7 +66,6 @@ class TStats(object):
def report(self):
"""Print a report about the transaction"""
t = time.ctime(self.begin)
if hasattr(self, "vote"):
d_vote = self.vote - self.begin
else:
......@@ -69,10 +73,11 @@ class TStats(object):
if hasattr(self, "finish"):
d_finish = self.finish - self.begin
else:
d_finish = "*"
d_finish = "*"
print(self.fmt % (time.ctime(self.begin), d_vote, d_finish,
self.user, self.url))
class TransactionParser(object):
def __init__(self):
......@@ -122,6 +127,7 @@ class TransactionParser(object):
L.sort()
return [t for (id, t) in L]
if __name__ == "__main__":
import fileinput
......@@ -131,7 +137,7 @@ if __name__ == "__main__":
i += 1
try:
p.parse(line)
except:
except: # NOQA: E722 bare except
print("line", i)
raise
print("Transaction: %d" % len(p.txns))
......
......@@ -12,18 +12,21 @@
#
##############################################################################
from __future__ import print_function
import doctest, re, unittest
import doctest
import re
import unittest
from zope.testing import renormalizing
def test_suite():
return unittest.TestSuite((
doctest.DocFileSuite(
'zeopack.test',
checker=renormalizing.RENormalizing([
(re.compile('usage: Usage: '), 'Usage: '), # Py 2.4
(re.compile('options:'), 'Options:'), # Py 2.4
(re.compile('usage: Usage: '), 'Usage: '), # Py 2.4
(re.compile('options:'), 'Options:'), # Py 2.4
]),
globs={'print_function': print_function},
),
))
......@@ -25,6 +25,7 @@ from ZEO.ClientStorage import ClientStorage
ZERO = '\0'*8
def main():
if len(sys.argv) not in (3, 4):
sys.stderr.write("Usage: timeout.py address delay [storage-name]\n" %
......@@ -68,5 +69,6 @@ def main():
time.sleep(delay)
print("Done.")
if __name__ == "__main__":
main()
......@@ -8,7 +8,6 @@ import time
import traceback
import ZEO.ClientStorage
from six.moves import map
from six.moves import zip
usage = """Usage: %prog [options] [servers]
......@@ -21,7 +20,8 @@ each is of the form:
"""
WAIT = 10 # wait no more than 10 seconds for client to connect
WAIT = 10 # wait no more than 10 seconds for client to connect
def _main(args=None, prog=None):
if args is None:
......@@ -160,10 +160,11 @@ def _main(args=None, prog=None):
continue
cs.pack(packt, wait=True)
cs.close()
except:
except: # NOQA: E722 bare except
traceback.print_exception(*(sys.exc_info()+(99, sys.stderr)))
error("Error packing storage %s in %r" % (name, addr))
def main(*args):
root_logger = logging.getLogger()
old_level = root_logger.getEffectiveLevel()
......@@ -178,6 +179,6 @@ def main(*args):
logging.getLogger().setLevel(old_level)
logging.getLogger().removeHandler(handler)
if __name__ == "__main__":
main()
......@@ -37,7 +37,6 @@ STATEFILE = 'zeoqueue.pck'
PROGRAM = sys.argv[0]
tcre = re.compile(r"""
(?P<ymd>
\d{4}- # year
......@@ -67,7 +66,6 @@ ccre = re.compile(r"""
wcre = re.compile(r'Clients waiting: (?P<num>\d+)')
def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None."""
mo = tcre.match(line)
......@@ -97,7 +95,6 @@ class Txn(object):
return False
class Status(object):
"""Track status of ZEO server by replaying log records.
......@@ -303,7 +300,6 @@ class Status(object):
break
def usage(code, msg=''):
print(__doc__ % globals(), file=sys.stderr)
if msg:
......
......@@ -41,25 +41,25 @@ import time
import getopt
import operator
# ZEO logs measure wall-clock time so for consistency we need to do the same
#from time import clock as now
# from time import clock as now
from time import time as now
from ZODB.FileStorage import FileStorage
#from BDBStorage.BDBFullStorage import BDBFullStorage
#from Standby.primary import PrimaryStorage
#from Standby.config import RS_PORT
# from BDBStorage.BDBFullStorage import BDBFullStorage
# from Standby.primary import PrimaryStorage
# from Standby.config import RS_PORT
from ZODB.Connection import TransactionMetaData
from ZODB.utils import p64
from functools import reduce
datecre = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
methcre = re.compile("ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)")
datecre = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
methcre = re.compile(r"ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)")
class StopParsing(Exception):
pass
def usage(code, msg=''):
print(__doc__)
if msg:
......@@ -67,7 +67,6 @@ def usage(code, msg=''):
sys.exit(code)
def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None."""
mo = datecre.match(line)
......@@ -95,7 +94,6 @@ def parse_line(line):
return t, m, c
class StoreStat(object):
def __init__(self, when, oid, size):
self.when = when
......@@ -104,8 +102,10 @@ class StoreStat(object):
# Crufty
def __getitem__(self, i):
if i == 0: return self.oid
if i == 1: return self.size
if i == 0:
return self.oid
if i == 1:
return self.size
raise IndexError
......@@ -136,10 +136,10 @@ class TxnStat(object):
self._finishtime = when
# Mapping oid -> revid
_revids = {}
class ReplayTxn(TxnStat):
def __init__(self, storage):
self._storage = storage
......@@ -157,7 +157,7 @@ class ReplayTxn(TxnStat):
# BAW: simulate a pickle of the given size
data = 'x' * obj.size
# BAW: ignore versions for now
newrevid = self._storage.store(p64(oid), revid, data, '', t)
newrevid = self._storage.store(p64(oid), revid, data, '', t)
_revids[oid] = newrevid
if self._aborttime:
self._storage.tpc_abort(t)
......@@ -172,7 +172,6 @@ class ReplayTxn(TxnStat):
self._replaydelta = t1 - t0 - origdelta
class ZEOParser(object):
def __init__(self, maxtxns=-1, report=1, storage=None):
self.__txns = []
......@@ -261,7 +260,6 @@ class ZEOParser(object):
print('average faster txn was:', float(sum) / len(faster))
def main():
try:
opts, args = getopt.getopt(
......@@ -294,8 +292,8 @@ def main():
if replay:
storage = FileStorage(storagefile)
#storage = BDBFullStorage(storagefile)
#storage = PrimaryStorage('yyz', storage, RS_PORT)
# storage = BDBFullStorage(storagefile)
# storage = PrimaryStorage('yyz', storage, RS_PORT)
t0 = now()
p = ZEOParser(maxtxns, report, storage)
i = 0
......@@ -308,7 +306,7 @@ def main():
p.parse(line)
except StopParsing:
break
except:
except: # NOQA: E722 bare except
print('input file line:', i)
raise
t1 = now()
......@@ -321,6 +319,5 @@ def main():
print('total time:', t3-t0)
if __name__ == '__main__':
main()
......@@ -169,9 +169,11 @@ from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
import datetime, sys, re, os
import datetime
import os
import re
import sys
from six.moves import map
from six.moves import zip
def time(line):
......@@ -187,9 +189,10 @@ def sub(t1, t2):
return delta.days*86400.0+delta.seconds+delta.microseconds/1000000.0
waitre = re.compile(r'Clients waiting: (\d+)')
idre = re.compile(r' ZSS:\d+/(\d+.\d+.\d+.\d+:\d+) ')
def blocked_times(args):
f, thresh = args
......@@ -217,7 +220,6 @@ def blocked_times(args):
t2 = t1
if not blocking and last_blocking:
last_wait = 0
t2 = time(line)
cid = idre.search(line).group(1)
......@@ -225,11 +227,14 @@ def blocked_times(args):
d = sub(t1, time(line))
if d >= thresh:
print(t1, sub(t1, t2), cid, d)
t1 = t2 = cid = blocking = waiting = last_wait = max_wait = 0
t1 = t2 = cid = blocking = waiting = 0
last_blocking = blocking
connidre = re.compile(r' zrpc-conn:(\d+.\d+.\d+.\d+:\d+) ')
def time_calls(f):
f, thresh = f
if f == '-':
......@@ -255,6 +260,7 @@ def time_calls(f):
print(maxd)
def xopen(f):
if f == '-':
return sys.stdin
......@@ -262,6 +268,7 @@ def xopen(f):
return os.popen(f, 'r')
return open(f)
def time_tpc(f):
f, thresh = f
if f == '-':
......@@ -307,11 +314,14 @@ def time_tpc(f):
t = time(line)
d = sub(t1, t)
if d >= thresh:
print('c', t1, cid, sub(t1, t2), vs, sub(t2, t3), sub(t3, t))
print('c', t1, cid, sub(t1, t2),
vs, sub(t2, t3), sub(t3, t))
del transactions[cid]
newobre = re.compile(r"storea\(.*, '\\x00\\x00\\x00\\x00\\x00")
def time_trans(f):
f, thresh = f
if f == '-':
......@@ -363,8 +373,8 @@ def time_trans(f):
t = time(line)
d = sub(t1, t)
if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \
sub(t0, t1), sub(t1, t2), vs, \
print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs,
sub(t2, t), 'abort')
del transactions[cid]
elif ' calling tpc_finish(' in line:
......@@ -377,11 +387,12 @@ def time_trans(f):
t = time(line)
d = sub(t1, t)
if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \
sub(t0, t1), sub(t1, t2), vs, \
print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs,
sub(t2, t3), sub(t3, t))
del transactions[cid]
def minute(f, slice=16, detail=1, summary=1):
f, = f
......@@ -405,10 +416,9 @@ def minute(f, slice=16, detail=1, summary=1):
for line in f:
line = line.strip()
if (line.find('returns') > 0
or line.find('storea') > 0
or line.find('tpc_abort') > 0
):
if line.find('returns') > 0 or \
line.find('storea') > 0 or \
line.find('tpc_abort') > 0:
client = connidre.search(line).group(1)
m = line[:slice]
if m != mlast:
......@@ -452,12 +462,13 @@ def minute(f, slice=16, detail=1, summary=1):
print('Summary: \t', '\t'.join(('min', '10%', '25%', 'med',
'75%', '90%', 'max', 'mean')))
print("n=%6d\t" % len(cls), '-'*62)
print('Clients: \t', '\t'.join(map(str,stats(cls))))
print('Reads: \t', '\t'.join(map(str,stats(rs))))
print('Stores: \t', '\t'.join(map(str,stats(ss))))
print('Commits: \t', '\t'.join(map(str,stats(cs))))
print('Aborts: \t', '\t'.join(map(str,stats(aborts))))
print('Trans: \t', '\t'.join(map(str,stats(ts))))
print('Clients: \t', '\t'.join(map(str, stats(cls))))
print('Reads: \t', '\t'.join(map(str, stats(rs))))
print('Stores: \t', '\t'.join(map(str, stats(ss))))
print('Commits: \t', '\t'.join(map(str, stats(cs))))
print('Aborts: \t', '\t'.join(map(str, stats(aborts))))
print('Trans: \t', '\t'.join(map(str, stats(ts))))
def stats(s):
s.sort()
......@@ -468,13 +479,14 @@ def stats(s):
ni = n + 1
for p in .1, .25, .5, .75, .90:
lp = ni*p
l = int(lp)
lp_int = int(lp)
if lp < 1 or lp > n:
out.append('-')
elif abs(lp-l) < .00001:
out.append(s[l-1])
elif abs(lp-lp_int) < .00001:
out.append(s[lp_int-1])
else:
out.append(int(s[l-1] + (lp - l) * (s[l] - s[l-1])))
out.append(
int(s[lp_int-1] + (lp - lp_int) * (s[lp_int] - s[lp_int-1])))
mean = 0.0
for v in s:
......@@ -484,24 +496,31 @@ def stats(s):
return out
def minutes(f):
minute(f, 16, detail=0)
def hour(f):
minute(f, 13)
def day(f):
minute(f, 10)
def hours(f):
minute(f, 13, detail=0)
def days(f):
minute(f, 10, detail=0)
new_connection_idre = re.compile(
r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):")
def verify(f):
f, = f
......@@ -527,6 +546,7 @@ def verify(f):
d = sub(t1, time(line))
print(cid, t1, n, d, n and (d*1000.0/n) or '-')
def recovery(f):
f, = f
......@@ -542,16 +562,16 @@ def recovery(f):
n += 1
if line.find('RecoveryServer') < 0:
continue
l = line.find('sending transaction ')
if l > 0 and last.find('sending transaction ') > 0:
trans.append(line[l+20:].strip())
pos = line.find('sending transaction ')
if pos > 0 and last.find('sending transaction ') > 0:
trans.append(line[pos+20:].strip())
else:
if trans:
if len(trans) > 1:
print(" ... %s similar records skipped ..." % (
len(trans) - 1))
print(n, last.strip())
trans=[]
trans = []
print(n, line.strip())
last = line
......@@ -561,6 +581,5 @@ def recovery(f):
print(n, last.strip())
if __name__ == '__main__':
globals()[sys.argv[1]](sys.argv[2:])
......@@ -47,6 +47,7 @@ from ZEO.ClientStorage import ClientStorage
ZEO_VERSION = 2
def setup_logging():
# Set up logging to stderr which will show messages originating
# at severity ERROR or higher.
......@@ -59,6 +60,7 @@ def setup_logging():
handler.setFormatter(fmt)
root.addHandler(handler)
def check_server(addr, storage, write):
t0 = time.time()
if ZEO_VERSION == 2:
......@@ -97,11 +99,13 @@ def check_server(addr, storage, write):
t1 = time.time()
print("Elapsed time: %.2f" % (t1 - t0))
def usage(exit=1):
print(__doc__)
print(" ".join(sys.argv))
sys.exit(exit)
def main():
host = None
port = None
......@@ -123,7 +127,7 @@ def main():
elif o == '--nowrite':
write = 0
elif o == '-1':
ZEO_VERSION = 1
ZEO_VERSION = 1 # NOQA: F841 unused variable
except Exception as err:
s = str(err)
if s:
......@@ -143,6 +147,7 @@ def main():
setup_logging()
check_server(addr, storage, write)
if __name__ == "__main__":
try:
main()
......
......@@ -14,8 +14,9 @@
REPR_LIMIT = 60
def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes."
"""Return an object repr limited to REPR_LIMIT bytes."""
# Some of the objects being repr'd are large strings. A lot of memory
# would be wasted to repr them and then truncate, so they are treated
......
......@@ -17,6 +17,7 @@ from ZODB.Connection import TransactionMetaData
from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_unpickle
class TransUndoStorageWithCache(object):
def checkUndoInvalidation(self):
......
......@@ -20,12 +20,12 @@ from persistent.TimeStamp import TimeStamp
from ZODB.Connection import TransactionMetaData
from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
import ZEO.ClientStorage
from ZEO.Exceptions import ClientDisconnected
from ZEO.tests.TestThread import TestThread
ZERO = b'\0'*8
class WorkerThread(TestThread):
# run the entire test in a thread so that the blocking call for
......@@ -62,6 +62,7 @@ class WorkerThread(TestThread):
self.ready.set()
future.result(9)
class CommitLockTests(object):
NUM_CLIENTS = 5
......@@ -99,7 +100,7 @@ class CommitLockTests(object):
for i in range(self.NUM_CLIENTS):
storage = self._new_storage_client()
txn = TransactionMetaData()
tid = self._get_timestamp()
self._get_timestamp()
t = WorkerThread(self, storage, txn)
self._threads.append(t)
......@@ -118,9 +119,10 @@ class CommitLockTests(object):
def _get_timestamp(self):
t = time.time()
t = TimeStamp(*time.gmtime(t)[:5]+(t%60,))
t = TimeStamp(*time.gmtime(t)[:5]+(t % 60,))
return repr(t)
class CommitLockVoteTests(CommitLockTests):
def checkCommitLockVoteFinish(self):
......
......@@ -26,11 +26,10 @@ from ZEO.tests import forker
from ZODB.Connection import TransactionMetaData
from ZODB.DB import DB
from ZODB.POSException import ReadOnlyError, ConflictError
from ZODB.POSException import ReadOnlyError
from ZODB.tests.StorageTestBase import StorageTestBase
from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle
import ZODB.tests.util
import transaction
......@@ -40,6 +39,7 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests')
ZERO = '\0'*8
class TestClientStorage(ClientStorage):
test_connection = False
......@@ -51,6 +51,7 @@ class TestClientStorage(ClientStorage):
self.connection_count_for_tests += 1
self.verify_result = conn.verify_result
class DummyDB(object):
def invalidate(self, *args, **kwargs):
pass
......@@ -93,7 +94,7 @@ class CommonSetupTearDown(StorageTestBase):
self._storage.close()
if hasattr(self._storage, 'cleanup'):
logging.debug("cleanup storage %s" %
self._storage.__name__)
self._storage.__name__)
self._storage.cleanup()
for stop in self._servers:
stop()
......@@ -113,7 +114,7 @@ class CommonSetupTearDown(StorageTestBase):
for dummy in range(5):
try:
os.unlink(path)
except:
except: # NOQA: E722 bare except
time.sleep(0.5)
else:
need_to_delete = False
......@@ -188,7 +189,7 @@ class CommonSetupTearDown(StorageTestBase):
stop = self._servers[index]
if stop is not None:
stop()
self._servers[index] = lambda : None
self._servers[index] = lambda: None
def pollUp(self, timeout=30.0, storage=None):
if storage is None:
......@@ -271,7 +272,6 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ReadOnlyError, self._dostore)
self._storage.close()
def checkDisconnectionError(self):
# Make sure we get a ClientDisconnected when we try to read an
# object when we're not connected to a storage server and the
......@@ -374,7 +374,7 @@ class ConnectionTests(CommonSetupTearDown):
pickle, rev = self._storage.load(oid, '')
newobj = zodb_unpickle(pickle)
self.assertEqual(newobj, obj)
newobj.value = 42 # .value *should* be 42 forever after now, not 13
newobj.value = 42 # .value *should* be 42 forever after now, not 13
self._dostore(oid, data=newobj, revid=rev)
self._storage.close()
......@@ -416,6 +416,7 @@ class ConnectionTests(CommonSetupTearDown):
def checkBadMessage2(self):
# just like a real message, but with an unpicklable argument
global Hack
class Hack(object):
pass
......@@ -505,7 +506,7 @@ class ConnectionTests(CommonSetupTearDown):
r1["a"] = MinPO("a")
transaction.commit()
self.assertEqual(r1._p_state, 0) # up-to-date
self.assertEqual(r1._p_state, 0) # up-to-date
db2 = DB(self.openClientStorage())
r2 = db2.open().root()
......@@ -524,9 +525,9 @@ class ConnectionTests(CommonSetupTearDown):
if r1._p_state == -1:
break
time.sleep(i / 10.0)
self.assertEqual(r1._p_state, -1) # ghost
self.assertEqual(r1._p_state, -1) # ghost
r1.keys() # unghostify
r1.keys() # unghostify
self.assertEqual(r1._p_serial, r2._p_serial)
self.assertEqual(r1["b"].value, "b")
......@@ -551,6 +552,7 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '')
class SSLConnectionTests(ConnectionTests):
def getServerConfig(self, addr, ro_svr):
......@@ -585,13 +587,13 @@ class InvqTests(CommonSetupTearDown):
revid2 = self._dostore(oid2, revid2)
forker.wait_until(
lambda :
lambda:
perstorage.lastTransaction() == self._storage.lastTransaction())
perstorage.load(oid, '')
perstorage.close()
forker.wait_until(lambda : os.path.exists('test-1.zec'))
forker.wait_until(lambda: os.path.exists('test-1.zec'))
revid = self._dostore(oid, revid)
......@@ -617,7 +619,7 @@ class InvqTests(CommonSetupTearDown):
revid = self._dostore(oid, revid)
forker.wait_until(
"Client has seen all of the transactions from the server",
lambda :
lambda:
perstorage.lastTransaction() == self._storage.lastTransaction()
)
perstorage.load(oid, '')
......@@ -635,6 +637,7 @@ class InvqTests(CommonSetupTearDown):
perstorage.close()
class ReconnectionTests(CommonSetupTearDown):
# The setUp() starts a server automatically. In order for its
# state to persist, we set the class variable keep to 1. In
......@@ -798,7 +801,7 @@ class ReconnectionTests(CommonSetupTearDown):
# Start a read-write server
self.startServer(index=1, read_only=0, keep=0)
# After a while, stores should work
for i in range(300): # Try for 30 seconds
for i in range(300): # Try for 30 seconds
try:
self._dostore()
break
......@@ -840,7 +843,7 @@ class ReconnectionTests(CommonSetupTearDown):
revid = self._dostore(oid, revid)
forker.wait_until(
"Client has seen all of the transactions from the server",
lambda :
lambda:
perstorage.lastTransaction() == self._storage.lastTransaction()
)
perstorage.load(oid, '')
......@@ -894,7 +897,6 @@ class ReconnectionTests(CommonSetupTearDown):
# Module ZEO.ClientStorage, line 709, in _update_cache
# KeyError: ...
def checkReconnection(self):
# Check that the client reconnects when a server restarts.
......@@ -952,6 +954,7 @@ class ReconnectionTests(CommonSetupTearDown):
self.assertTrue(did_a_store)
self._storage.close()
class TimeoutTests(CommonSetupTearDown):
timeout = 1
......@@ -967,9 +970,8 @@ class TimeoutTests(CommonSetupTearDown):
# Make sure it's logged as CRITICAL
with open("server.log") as f:
for line in f:
if (('Transaction timeout after' in line) and
('CRITICAL ZEO.StorageServer' in line)
):
if ('Transaction timeout after' in line) and \
('CRITICAL ZEO.StorageServer' in line):
break
else:
self.fail('bad logging')
......@@ -1002,7 +1004,7 @@ class TimeoutTests(CommonSetupTearDown):
t = TransactionMetaData()
old_connection_count = storage.connection_count_for_tests
storage.tpc_begin(t)
revid1 = storage.store(oid, ZERO, zodb_pickle(obj), '', t)
storage.store(oid, ZERO, zodb_pickle(obj), '', t)
storage.tpc_vote(t)
# Now sleep long enough for the storage to time out
time.sleep(3)
......@@ -1021,6 +1023,7 @@ class TimeoutTests(CommonSetupTearDown):
# or the server.
self.assertRaises(KeyError, storage.load, oid, '')
class MSTThread(threading.Thread):
__super_init = threading.Thread.__init__
......@@ -1054,7 +1057,7 @@ class MSTThread(threading.Thread):
# Begin a transaction
t = TransactionMetaData()
for c in clients:
#print("%s.%s.%s begin" % (tname, c.__name, i))
# print("%s.%s.%s begin" % (tname, c.__name, i))
c.tpc_begin(t)
for j in range(testcase.nobj):
......@@ -1063,18 +1066,18 @@ class MSTThread(threading.Thread):
oid = c.new_oid()
c.__oids.append(oid)
data = MinPO("%s.%s.t%d.o%d" % (tname, c.__name, i, j))
#print(data.value)
# print(data.value)
data = zodb_pickle(data)
c.store(oid, ZERO, data, '', t)
# Vote on all servers and handle serials
for c in clients:
#print("%s.%s.%s vote" % (tname, c.__name, i))
# print("%s.%s.%s vote" % (tname, c.__name, i))
c.tpc_vote(t)
# Finish on all servers
for c in clients:
#print("%s.%s.%s finish\n" % (tname, c.__name, i))
# print("%s.%s.%s finish\n" % (tname, c.__name, i))
c.tpc_finish(t)
for c in clients:
......@@ -1090,7 +1093,7 @@ class MSTThread(threading.Thread):
for c in self.clients:
try:
c.close()
except:
except: # NOQA: E722 bare except
pass
......@@ -1101,6 +1104,7 @@ def short_timeout(self):
yield
self._storage._server.timeout = old
# Run IPv6 tests if V6 sockets are supported
try:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
......
......@@ -41,6 +41,7 @@ from ZODB.POSException import ReadConflictError, ConflictError
# thought they added (i.e., the keys for which transaction.commit()
# did not raise any exception).
class FailableThread(TestThread):
# mixin class
......@@ -52,7 +53,7 @@ class FailableThread(TestThread):
def testrun(self):
try:
self._testrun()
except:
except: # NOQA: E722 bare except
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
......@@ -81,12 +82,11 @@ class StressTask(object):
tree[key] = self.threadnum
def commit(self):
cn = self.cn
key = self.startnum
self.tm.get().note(u"add key %s" % key)
try:
self.tm.get().commit()
except ConflictError as msg:
except ConflictError:
self.tm.abort()
else:
if self.sleep:
......@@ -98,15 +98,18 @@ class StressTask(object):
self.tm.get().abort()
self.cn.close()
def _runTasks(rounds, *tasks):
'''run *task* interleaved for *rounds* rounds.'''
def commit(run, actions):
actions.append(':')
for t in run:
t.commit()
del run[:]
r = Random()
r.seed(1064589285) # make it deterministic
r.seed(1064589285) # make it deterministic
run = []
actions = []
try:
......@@ -117,7 +120,7 @@ def _runTasks(rounds, *tasks):
run.append(t)
t.doStep()
actions.append(repr(t.startnum))
commit(run,actions)
commit(run, actions)
# stderr.write(' '.join(actions)+'\n')
finally:
for t in tasks:
......@@ -160,13 +163,14 @@ class StressThread(FailableThread):
self.commitdict[self] = 1
if self.sleep:
time.sleep(self.sleep)
except (ReadConflictError, ConflictError) as msg:
except (ReadConflictError, ConflictError):
tm.abort()
else:
self.added_keys.append(key)
key += self.step
cn.close()
class LargeUpdatesThread(FailableThread):
# A thread that performs a lot of updates. It attempts to modify
......@@ -195,7 +199,7 @@ class LargeUpdatesThread(FailableThread):
# print("%d getting tree abort" % self.threadnum)
transaction.abort()
keys_added = {} # set of keys we commit
keys_added = {} # set of keys we commit
tkeys = []
while not self.stop.isSet():
......@@ -212,7 +216,7 @@ class LargeUpdatesThread(FailableThread):
for key in keys:
try:
tree[key] = self.threadnum
except (ReadConflictError, ConflictError) as msg:
except (ReadConflictError, ConflictError): # as msg:
# print("%d setting key %s" % (self.threadnum, msg))
transaction.abort()
break
......@@ -224,7 +228,7 @@ class LargeUpdatesThread(FailableThread):
self.commitdict[self] = 1
if self.sleep:
time.sleep(self.sleep)
except ConflictError as msg:
except ConflictError: # as msg
# print("%d commit %s" % (self.threadnum, msg))
transaction.abort()
continue
......@@ -234,6 +238,7 @@ class LargeUpdatesThread(FailableThread):
self.added_keys = keys_added.keys()
cn.close()
class InvalidationTests(object):
# Minimum # of seconds the main thread lets the workers run. The
......@@ -261,7 +266,7 @@ class InvalidationTests(object):
transaction.abort()
else:
raise
except:
except: # NOQA: E722 bare except
display(tree)
raise
......
......@@ -21,6 +21,7 @@ from ZODB.Connection import TransactionMetaData
from ..asyncio.testing import AsyncRPC
class IterationTests(object):
def _assertIteratorIdsEmpty(self):
......@@ -44,7 +45,7 @@ class IterationTests(object):
# everything goes away as expected.
gc.enable()
gc.collect()
gc.collect() # sometimes PyPy needs it twice to clear weak refs
gc.collect() # sometimes PyPy needs it twice to clear weak refs
self._storage._iterator_gc()
......@@ -147,7 +148,6 @@ class IterationTests(object):
self._dostore()
six.advance_iterator(self._storage.iterator())
iid = list(self._storage._iterator_ids)[0]
t = TransactionMetaData()
self._storage.tpc_begin(t)
# Show that after disconnecting, the client side GCs the iterators
......@@ -176,12 +176,12 @@ def iterator_sane_after_reconnect():
Start a server:
>>> addr, adminaddr = start_server(
>>> addr, adminaddr = start_server( # NOQA: F821 undefined
... '<filestorage>\npath fs\n</filestorage>', keep=1)
Open a client storage to it and commit a some transactions:
>>> import ZEO, ZODB, transaction
>>> import ZEO, ZODB
>>> client = ZEO.client(addr)
>>> db = ZODB.DB(client)
>>> conn = db.open()
......@@ -196,10 +196,11 @@ Create an iterator:
Restart the storage:
>>> stop_server(adminaddr)
>>> wait_disconnected(client)
>>> _ = start_server('<filestorage>\npath fs\n</filestorage>', addr=addr)
>>> wait_connected(client)
>>> stop_server(adminaddr) # NOQA: F821 undefined
>>> wait_disconnected(client) # NOQA: F821 undefined
>>> _ = start_server( # NOQA: F821 undefined
... '<filestorage>\npath fs\n</filestorage>', addr=addr)
>>> wait_connected(client) # NOQA: F821 undefined
Now, we'll create a second iterator:
......
......@@ -16,6 +16,7 @@ import threading
import sys
import six
class TestThread(threading.Thread):
"""Base class for defining threads that run from unittest.
......@@ -46,12 +47,14 @@ class TestThread(threading.Thread):
def run(self):
try:
self.testrun()
except:
except: # NOQA: E722 blank except
self._exc_info = sys.exc_info()
def cleanup(self, timeout=15):
self.join(timeout)
if self._exc_info:
six.reraise(self._exc_info[0], self._exc_info[1], self._exc_info[2])
six.reraise(self._exc_info[0],
self._exc_info[1],
self._exc_info[2])
if self.is_alive():
self._testcase.fail("Thread did not finish: %s" % self)
......@@ -21,6 +21,7 @@ import ZEO.Exceptions
ZERO = '\0'*8
class BasicThread(threading.Thread):
def __init__(self, storage, doNextEvent, threadStartedEvent):
self.storage = storage
......@@ -123,7 +124,6 @@ class ThreadTests(object):
# Helper for checkMTStores
def mtstorehelper(self):
name = threading.currentThread().getName()
objs = []
for i in range(10):
objs.append(MinPO("X" * 200000))
......
......@@ -41,7 +41,6 @@ from ZEO._compat import Pickler, Unpickler, PY3, BytesIO
from ZEO.Exceptions import AuthError
from .monitor import StorageStats, StatsServer
from .zrpc.connection import ManagedServerConnection, Delay, MTDelay, Result
from .zrpc.server import Dispatcher
from ZODB.Connection import TransactionMetaData
from ZODB.loglevels import BLATHER
from ZODB.POSException import StorageError, StorageTransactionError
......@@ -53,6 +52,7 @@ ResolvedSerial = b'rs'
logger = logging.getLogger('ZEO.StorageServer')
def log(message, level=logging.INFO, label='', exc_info=False):
"""Internal helper to log a message."""
if label:
......@@ -152,7 +152,7 @@ class ZEOStorage(object):
info = self.get_info()
if not info['supportsUndo']:
self.undoLog = self.undoInfo = lambda *a,**k: ()
self.undoLog = self.undoInfo = lambda *a, **k: ()
self.getTid = storage.getTid
self.load = storage.load
......@@ -164,7 +164,7 @@ class ZEOStorage(object):
try:
fn = storage.getExtensionMethods
except AttributeError:
pass # no extension methods
pass # no extension methods
else:
d = fn()
self._extensions.update(d)
......@@ -182,14 +182,14 @@ class ZEOStorage(object):
"Falling back to using _transaction attribute, which\n."
"is icky.",
logging.ERROR)
self.tpc_transaction = lambda : storage._transaction
self.tpc_transaction = lambda: storage._transaction
else:
raise
def history(self,tid,size=1):
def history(self, tid, size=1):
# This caters for storages which still accept
# a version parameter.
return self.storage.history(tid,size=size)
return self.storage.history(tid, size=size)
def _check_tid(self, tid, exc=None):
if self.read_only:
......@@ -253,8 +253,7 @@ class ZEOStorage(object):
def get_info(self):
storage = self.storage
supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)()
supportsUndo = (getattr(storage, 'supportsUndo', lambda: False)()
and self.connection.peer_protocol_version >= b'Z310')
# Communicate the backend storage interfaces to the client
......@@ -448,7 +447,7 @@ class ZEOStorage(object):
def _try_to_vote(self, delay=None):
if self.connection is None:
return # We're disconnected
return # We're disconnected
if delay is not None and delay.sent:
# as a consequence of the unlocking strategy, _try_to_vote
# may be called multiple times for delayed
......@@ -473,7 +472,6 @@ class ZEOStorage(object):
if not getattr(self, op)(*args):
break
# Blob support
while self.blob_log and not self.store_failed:
oid, oldserial, data, blobfilename = self.blob_log.pop()
......@@ -547,7 +545,7 @@ class ZEOStorage(object):
def storeBlobEnd(self, oid, serial, data, id):
self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
assert self.txnlog is not None # effectively not allowed after undo
fd, tempname = self.blob_tempfile
self.blob_tempfile = None
os.close(fd)
......@@ -555,14 +553,12 @@ class ZEOStorage(object):
def storeBlobShared(self, oid, serial, data, filename, id):
self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
assert self.txnlog is not None # effectively not allowed after undo
# Reconstruct the full path from the filename in the OID directory
if (os.path.sep in filename
or not (filename.endswith('.tmp')
or filename[:-1].endswith('.tmp')
)
):
if os.path.sep in filename or \
not (filename.endswith('.tmp')
or filename[:-1].endswith('.tmp')):
logger.critical(
"We're under attack! (bad filename to storeBlobShared, %r)",
filename)
......@@ -590,7 +586,7 @@ class ZEOStorage(object):
(oid_repr(oid), str(err)), BLATHER)
if not isinstance(err, TransactionError):
# Unexpected errors are logged and passed to the client
self.log("%s error: %s, %s" % ((op,)+ sys.exc_info()[:2]),
self.log("%s error: %s, %s" % ((op,) + sys.exc_info()[:2]),
logging.ERROR, exc_info=True)
err = self._marshal_error(err)
# The exception is reported back as newserial for this oid
......@@ -691,7 +687,7 @@ class ZEOStorage(object):
pickler.fast = 1
try:
pickler.dump(error)
except:
except: # NOQA: E722 bare except
msg = "Couldn't pickle storage exception: %s" % repr(error)
self.log(msg, logging.ERROR)
error = StorageServerError(msg)
......@@ -758,6 +754,7 @@ class ZEOStorage(object):
def set_client_label(self, label):
self.log_label = str(label)+' '+_addr_label(self.connection.addr)
class StorageServerDB(object):
def __init__(self, server, storage_id):
......@@ -776,6 +773,7 @@ class StorageServerDB(object):
transform_record_data = untransform_record_data = lambda self, data: data
class StorageServer(object):
"""The server side implementation of ZEO.
......@@ -876,7 +874,6 @@ class StorageServer(object):
log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg))
self._lock = threading.Lock()
self._commit_locks = {}
self._waiting = dict((name, []) for name in storages)
......@@ -942,7 +939,6 @@ class StorageServer(object):
self.invq[name] = list(lastInvalidations(self.invq_bound))
self.invq[name].reverse()
def _setup_auth(self, protocol):
# Can't be done in global scope, because of cyclic references
from .auth import get_module
......@@ -976,7 +972,6 @@ class StorageServer(object):
"does not match storage realm %r"
% (self.database.realm, self.auth_realm))
def new_connection(self, sock, addr):
"""Internal: factory to create a new connection.
......@@ -1050,7 +1045,6 @@ class StorageServer(object):
except DisconnectedError:
pass
def invalidate(self, conn, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients.
......@@ -1096,7 +1090,6 @@ class StorageServer(object):
# b. A connection is closes while we are iterating. We'll need
# to cactch and ignore Disconnected errors.
if invalidated:
invq = self.invq[storage_id]
if len(invq) >= self.invq_bound:
......@@ -1156,9 +1149,10 @@ class StorageServer(object):
asyncore.loop(timeout, map=self.socket_map)
except Exception:
if not self.__closed:
raise # Unexpected exc
raise # Unexpected exc
__thread = None
def start_thread(self, daemon=True):
self.__thread = thread = threading.Thread(target=self.loop)
thread.setName("StorageServer(%s)" % _addr_label(self.addr))
......@@ -1166,6 +1160,7 @@ class StorageServer(object):
thread.start()
__closed = False
def close(self, join_timeout=1):
"""Close the dispatcher so that there are no new connections.
......@@ -1187,7 +1182,7 @@ class StorageServer(object):
for conn in connections[:]:
try:
conn.connection.close()
except:
except: # NOQA: E722 bare except
pass
for name, storage in six.iteritems(self.storages):
......@@ -1282,7 +1277,6 @@ class StorageServer(object):
except Exception:
logger.exception("Calling unlock callback")
def stop_waiting(self, zeostore):
storage_id = zeostore.storage_id
waiting = self._waiting[storage_id]
......@@ -1307,7 +1301,8 @@ class StorageServer(object):
status = self.stats[storage_id].__dict__.copy()
status['connections'] = len(status['connections'])
status['waiting'] = len(self._waiting[storage_id])
status['timeout-thread-is-alive'] = self.timeouts[storage_id].is_alive()
status['timeout-thread-is-alive'] = \
self.timeouts[storage_id].is_alive()
last_transaction = self.storages[storage_id].lastTransaction()
last_transaction_hex = codecs.encode(last_transaction, 'hex_codec')
if PY3:
......@@ -1320,6 +1315,7 @@ class StorageServer(object):
return dict((storage_id, self.server_status(storage_id))
for storage_id in self.storages)
def _level_for_waiting(waiting):
if len(waiting) > 9:
return logging.CRITICAL
......@@ -1328,6 +1324,7 @@ def _level_for_waiting(waiting):
else:
return logging.DEBUG
class StubTimeoutThread(object):
def begin(self, client):
......@@ -1336,7 +1333,8 @@ class StubTimeoutThread(object):
def end(self, client):
pass
is_alive = lambda self: 'stub'
def is_alive(self):
return 'stub'
class TimeoutThread(threading.Thread):
......@@ -1352,7 +1350,7 @@ class TimeoutThread(threading.Thread):
self._timeout = timeout
self._client = None
self._deadline = None
self._cond = threading.Condition() # Protects _client and _deadline
self._cond = threading.Condition() # Protects _client and _deadline
def begin(self, client):
# Called from the restart code the "main" thread, whenever the
......@@ -1382,14 +1380,14 @@ class TimeoutThread(threading.Thread):
if howlong <= 0:
# Prevent reporting timeout more than once
self._deadline = None
client = self._client # For the howlong <= 0 branch below
client = self._client # For the howlong <= 0 branch below
if howlong <= 0:
client.log("Transaction timeout after %s seconds" %
self._timeout, logging.CRITICAL)
try:
client.connection.call_from_thread(client.connection.close)
except:
except: # NOQA: E722 bare except
client.log("Timeout failure", logging.CRITICAL,
exc_info=sys.exc_info())
self.end(client)
......@@ -1485,6 +1483,7 @@ class ClientStub(object):
self.rpc.callAsyncIterator(store())
class ClientStub308(ClientStub):
def invalidateTransaction(self, tid, args):
......@@ -1494,6 +1493,7 @@ class ClientStub308(ClientStub):
def invalidateVerify(self, oid):
ClientStub.invalidateVerify(self, (oid, ''))
class ZEOStorage308Adapter(object):
def __init__(self, storage):
......@@ -1503,7 +1503,7 @@ class ZEOStorage308Adapter(object):
return self is other or self.storage is other
def getSerial(self, oid):
return self.storage.loadEx(oid)[1] # Z200
return self.storage.loadEx(oid)[1] # Z200
def history(self, oid, version, size=1):
if version:
......@@ -1573,6 +1573,7 @@ class ZEOStorage308Adapter(object):
def __getattr__(self, name):
return getattr(self.storage, name)
def _addr_label(addr):
if isinstance(addr, six.binary_type):
return addr.decode('ascii')
......@@ -1582,6 +1583,7 @@ def _addr_label(addr):
host, port = addr
return str(host) + ":" + str(port)
class CommitLog(object):
def __init__(self):
......@@ -1624,14 +1626,17 @@ class CommitLog(object):
self.file.close()
self.file = None
class ServerEvent(object):
def __init__(self, server, **kw):
self.__dict__.update(kw)
self.server = server
class Serving(ServerEvent):
pass
class Closed(ServerEvent):
pass
......@@ -14,6 +14,7 @@
_auth_modules = {}
def get_module(name):
if name == 'sha':
from auth_sha import StorageClass, SHAClient, Database
......@@ -24,6 +25,7 @@ def get_module(name):
else:
return _auth_modules.get(name)
def register_module(name, storage_class, client, db):
if name in _auth_modules:
raise TypeError("%s is already registred" % name)
......
......@@ -45,6 +45,7 @@ from ..StorageServer import ZEOStorage
from ZEO.Exceptions import AuthError
from ..hash import sha1
def get_random_bytes(n=8):
try:
b = os.urandom(n)
......@@ -53,9 +54,11 @@ def get_random_bytes(n=8):
b = b"".join(L)
return b
def hexdigest(s):
return sha1(s.encode()).hexdigest()
class DigestDatabase(Database):
def __init__(self, filename, realm=None):
Database.__init__(self, filename, realm)
......@@ -69,6 +72,7 @@ class DigestDatabase(Database):
dig = hexdigest("%s:%s:%s" % (username, self.realm, password))
self._users[username] = dig
def session_key(h_up, nonce):
# The hash itself is a bit too short to be a session key.
# HMAC wants a 64-byte key. We don't want to use h_up
......@@ -77,6 +81,7 @@ def session_key(h_up, nonce):
return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() +
h_up.encode('utf-8')[:44])
class StorageClass(ZEOStorage):
def set_database(self, database):
assert isinstance(database, DigestDatabase)
......@@ -124,6 +129,7 @@ class StorageClass(ZEOStorage):
extensions = [auth_get_challenge, auth_response]
class DigestClient(Client):
extensions = ["auth_get_challenge", "auth_response"]
......
......@@ -22,6 +22,7 @@ from __future__ import print_function
import os
from ..hash import sha1
class Client(object):
# Subclass should override to list the names of methods that
# will be called on the server.
......@@ -32,11 +33,13 @@ class Client(object):
for m in self.extensions:
setattr(self.stub, m, self.stub.extensionMethod(m))
def sort(L):
"""Sort a list in-place and return it."""
L.sort()
return L
class Database(object):
"""Abstracts a password database.
......@@ -49,6 +52,7 @@ class Database(object):
produced from the password string.
"""
realm = None
def __init__(self, filename, realm=None):
"""Creates a new Database
......
......@@ -3,24 +3,26 @@
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC(object):
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object.
key: key for the keyed hash object.
......@@ -49,8 +51,8 @@ class HMAC(object):
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
# def clear(self):
# raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
......@@ -85,7 +87,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
......
......@@ -47,6 +47,7 @@ else:
if zeo_dist is not None:
zeo_version = zeo_dist.version
class StorageStats(object):
"""Per-storage usage statistics."""
......@@ -113,6 +114,7 @@ class StorageStats(object):
print("Conflicts:", self.conflicts, file=f)
print("Conflicts resolved:", self.conflicts_resolved, file=f)
class StatsClient(asyncore.dispatcher):
def __init__(self, sock, addr):
......@@ -144,6 +146,7 @@ class StatsClient(asyncore.dispatcher):
if self.closed and not self.buf:
asyncore.dispatcher.close(self)
class StatsServer(asyncore.dispatcher):
StatsConnectionClass = StatsClient
......
......@@ -49,21 +49,24 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address
def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown.
import asyncore
asyncore.close_all()
class ZEOOptionsMixin(object):
storages = None
......@@ -75,14 +78,18 @@ class ZEOOptionsMixin(object):
self.monitor_family, self.monitor_address = parse_binding_address(arg)
def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*!
from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object):
def __init__(self, name, path):
self._name = name
self.path = path
self.stop = None
def getSectionName(self):
return self._name
if not self.storages:
self.storages = []
name = str(1 + len(self.storages))
......@@ -90,6 +97,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf)
testing_exit_immediately = False
def handle_test(self, *args):
self.testing_exit_immediately = True
......@@ -117,6 +125,7 @@ class ZEOOptionsMixin(object):
self.add('pid_file', 'zeo.pid_filename',
None, 'pid-file=')
class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__
......@@ -179,8 +188,8 @@ class ZEOServer(object):
root.addHandler(handler)
def check_socket(self):
if (isinstance(self.options.address, tuple) and
self.options.address[1] is None):
if isinstance(self.options.address, tuple) and \
self.options.address[1] is None:
self.options.address = self.options.address[0], 0
return
if self.can_connect(self.options.family, self.options.address):
......@@ -224,7 +233,7 @@ class ZEOServer(object):
self.setup_win32_signals()
return
if hasattr(signal, 'SIGXFSZ'):
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
init_signames()
for sig, name in signames.items():
method = getattr(self, "handle_" + name.lower(), None)
......@@ -244,12 +253,12 @@ class ZEOServer(object):
"will *not* be installed.")
return
SignalHandler = Signals.Signals.SignalHandler
if SignalHandler is not None: # may be None if no pywin32.
if SignalHandler is not None: # may be None if no pywin32.
SignalHandler.registerHandler(signal.SIGTERM,
windows_shutdown_handler)
SignalHandler.registerHandler(signal.SIGINT,
windows_shutdown_handler)
SIGUSR2 = 12 # not in signal module on Windows.
SIGUSR2 = 12 # not in signal module on Windows.
SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2)
def create_server(self):
......@@ -275,20 +284,21 @@ class ZEOServer(object):
def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
if self.options.config_logger is None or \
os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
level=logging.WARNING)
return
loggers = [self.options.config_logger]
if os.name == "posix":
for l in loggers:
l.reopen()
for logger in loggers:
logger.reopen()
log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers:
for f in l.handler_factories:
else: # nt - same rotation code as in Zope's Signals/Signals.py
for logger in loggers:
for f in logger.handler_factories:
handler = f()
if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate()
......@@ -347,14 +357,14 @@ def create_server(storages, options):
return StorageServer(
options.address,
storages,
read_only = options.read_only,
invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout,
monitor_address = options.monitor_address,
auth_protocol = options.auth_protocol,
auth_database = options.auth_database,
auth_realm = options.auth_realm,
read_only=options.read_only,
invalidation_queue_size=options.invalidation_queue_size,
invalidation_age=options.invalidation_age,
transaction_timeout=options.transaction_timeout,
monitor_address=options.monitor_address,
auth_protocol=options.auth_protocol,
auth_database=options.auth_database,
auth_realm=options.auth_realm,
)
......@@ -362,6 +372,7 @@ def create_server(storages, options):
signames = None
def signame(sig):
"""Return a symbolic name for a signal.
......@@ -373,6 +384,7 @@ def signame(sig):
init_signames()
return signames.get(sig) or "signal %d" % sig
def init_signames():
global signames
signames = {}
......@@ -392,5 +404,6 @@ def main(args=None):
s = ZEOServer(options)
s.main()
if __name__ == "__main__":
main()
......@@ -6,24 +6,26 @@
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC(object):
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object.
key: key for the keyed hash object.
......@@ -56,8 +58,8 @@ class HMAC(object):
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
# def clear(self):
# raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
......@@ -92,7 +94,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
......
......@@ -34,6 +34,7 @@ from six.moves import map
def client_timeout():
return 30.0
def client_loop(map):
read = asyncore.read
write = asyncore.write
......@@ -52,7 +53,7 @@ def client_loop(map):
r, w, e = select.select(r, w, e, client_timeout())
except (select.error, RuntimeError) as err:
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
# which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno != errno.EINTR:
if err_errno == errno.EBADF:
......@@ -114,14 +115,13 @@ def client_loop(map):
continue
_exception(obj)
except:
except: # NOQA: E722 bare except
if map:
try:
logging.getLogger(__name__+'.client_loop').critical(
'A ZEO client loop failed.',
exc_info=sys.exc_info())
except:
except: # NOQA: E722 bare except
pass
for fd, obj in map.items():
......@@ -129,14 +129,14 @@ def client_loop(map):
continue
try:
obj.mgr.client.close()
except:
except: # NOQA: E722 bare except
map.pop(fd, None)
try:
logging.getLogger(__name__+'.client_loop'
).critical(
"Couldn't close a dispatcher.",
exc_info=sys.exc_info())
except:
except: # NOQA: E722 bare except
pass
......@@ -152,11 +152,11 @@ class ConnectionManager(object):
self.tmin = min(tmin, tmax)
self.tmax = tmax
self.cond = threading.Condition(threading.Lock())
self.connection = None # Protected by self.cond
self.connection = None # Protected by self.cond
self.closed = 0
# If thread is not None, then there is a helper thread
# attempting to connect.
self.thread = None # Protected by self.cond
self.thread = None # Protected by self.cond
def new_addrs(self, addrs):
self.addrlist = self._parse_addrs(addrs)
......@@ -189,7 +189,8 @@ class ConnectionManager(object):
for addr in addrs:
addr_type = self._guess_type(addr)
if addr_type is None:
raise ValueError("unknown address in list: %s" % repr(addr))
raise ValueError(
"unknown address in list: %s" % repr(addr))
addrlist.append((addr_type, addr))
return addrlist
......@@ -197,10 +198,10 @@ class ConnectionManager(object):
if isinstance(addr, str):
return socket.AF_UNIX
if (len(addr) == 2
and isinstance(addr[0], str)
and isinstance(addr[1], int)):
return socket.AF_INET # also denotes IPv6
if len(addr) == 2 and \
isinstance(addr[0], str) and \
isinstance(addr[1], int):
return socket.AF_INET # also denotes IPv6
# not anything I know about
return None
......@@ -226,7 +227,7 @@ class ConnectionManager(object):
if obj is not self.trigger:
try:
obj.close()
except:
except: # NOQA: E722 bare except
logging.getLogger(__name__+'.'+self.__class__.__name__
).critical(
"Couldn't close a dispatcher.",
......@@ -237,7 +238,7 @@ class ConnectionManager(object):
try:
self.loop_thread.join(9)
except RuntimeError:
pass # we are the thread :)
pass # we are the thread :)
self.trigger.close()
def attempt_connect(self):
......@@ -304,7 +305,7 @@ class ConnectionManager(object):
self.connection = conn
if preferred:
self.thread = None
self.cond.notifyAll() # Wake up connect(sync=1)
self.cond.notifyAll() # Wake up connect(sync=1)
finally:
self.cond.release()
......@@ -331,6 +332,7 @@ class ConnectionManager(object):
finally:
self.cond.release()
# When trying to do a connect on a non-blocking socket, some outcomes
# are expected. Set _CONNECT_IN_PROGRESS to the errno value(s) expected
# when an initial connect can't complete immediately. Set _CONNECT_OK
......@@ -342,10 +344,11 @@ if hasattr(errno, "WSAEWOULDBLOCK"): # Windows
# seen this.
_CONNECT_IN_PROGRESS = (errno.WSAEWOULDBLOCK,)
# Win98: WSAEISCONN; Win2K: WSAEINVAL
_CONNECT_OK = (0, errno.WSAEISCONN, errno.WSAEINVAL)
_CONNECT_OK = (0, errno.WSAEISCONN, errno.WSAEINVAL)
else: # Unix
_CONNECT_IN_PROGRESS = (errno.EINPROGRESS,)
_CONNECT_OK = (0, errno.EISCONN)
_CONNECT_OK = (0, errno.EISCONN)
class ConnectThread(threading.Thread):
"""Thread that tries to connect to server given one or more addresses.
......@@ -455,7 +458,7 @@ class ConnectThread(threading.Thread):
) in socket.getaddrinfo(host or 'localhost', port,
socket.AF_INET,
socket.SOCK_STREAM
): # prune non-TCP results
): # prune non-TCP results
# for IPv6, drop flowinfo, and restrict addresses
# to [host]:port
yield family, sockaddr[:2]
......@@ -495,7 +498,7 @@ class ConnectThread(threading.Thread):
break
try:
r, w, x = select.select([], connecting, connecting, 1.0)
log("CT: select() %d, %d, %d" % tuple(map(len, (r,w,x))))
log("CT: select() %d, %d, %d" % tuple(map(len, (r, w, x))))
except select.error as msg:
log("CT: select failed; msg=%s" % str(msg),
level=logging.WARNING)
......@@ -509,7 +512,7 @@ class ConnectThread(threading.Thread):
for wrap in w:
wrap.connect_procedure()
if wrap.state == "notified":
del wrappers[wrap] # Don't close this one
del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys():
wrap.close()
return 1
......@@ -526,7 +529,7 @@ class ConnectThread(threading.Thread):
else:
wrap.notify_client()
if wrap.state == "notified":
del wrappers[wrap] # Don't close this one
del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys():
wrap.close()
return -1
......@@ -602,7 +605,7 @@ class ConnectWrapper(object):
to do app-level check of the connection.
"""
self.conn = ManagedClientConnection(self.sock, self.addr, self.mgr)
self.sock = None # The socket is now owned by the connection
self.sock = None # The socket is now owned by the connection
try:
self.preferred = self.client.testConnection(self.conn)
self.state = "tested"
......@@ -610,7 +613,7 @@ class ConnectWrapper(object):
log("CW: ReadOnlyError in testConnection (%s)" % repr(self.addr))
self.close()
return
except:
except: # NOQA: E722 bare except
log("CW: error in testConnection (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True)
self.close()
......@@ -629,7 +632,7 @@ class ConnectWrapper(object):
"""
try:
self.client.notifyConnected(self.conn)
except:
except: # NOQA: E722 bare except
log("CW: error in notifyConnected (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True)
self.close()
......
......@@ -26,12 +26,13 @@ from .log import short_repr, log
from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException
REPLY = ".reply" # message name used for replies
REPLY = ".reply" # message name used for replies
exception_type_type = type(Exception)
debug_zrpc = False
class Delay(object):
"""Used to delay response to client for synchronous calls.
......@@ -57,7 +58,9 @@ class Delay(object):
def __repr__(self):
return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self), self.msgid, self.conn, self.sent)
self.__class__.__name__, id(self), self.msgid,
self.conn, self.sent)
class Result(Delay):
......@@ -69,6 +72,7 @@ class Result(Delay):
conn.send_reply(msgid, reply, False)
callback()
class MTDelay(Delay):
def __init__(self):
......@@ -147,6 +151,7 @@ class MTDelay(Delay):
# supply a handshake() method appropriate for their role in protocol
# negotiation.
class Connection(smac.SizedMessageAsyncConnection, object):
"""Dispatcher for RPC on object on both sides of socket.
......@@ -294,7 +299,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.fast_encode = marshal.fast_encode
self.closed = False
self.peer_protocol_version = None # set in recv_handshake()
self.peer_protocol_version = None # set in recv_handshake()
assert tag in b"CS"
self.tag = tag
......@@ -359,7 +364,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.addr)
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__
def log(self, message, level=BLATHER, exc_info=False):
self.logger.log(level, self.log_label + message, exc_info=exc_info)
......@@ -441,7 +446,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
try:
self.message_output(self.fast_encode(msgid, 0, REPLY, ret))
self.poll()
except:
except: # NOQA: E722 bare except
# Fall back to normal version for better error handling
self.send_reply(msgid, ret)
......@@ -520,10 +525,10 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, (err_type, err_value))
except: # see above
except: # NOQA: E722 bare except; see above
try:
r = short_repr(err_value)
except:
except: # NOQA: E722 bare except
r = "<unreprable>"
err = ZRPCError("Couldn't pickle error %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
......@@ -656,10 +661,10 @@ class ManagedServerConnection(Connection):
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, ret)
except: # see above
except: # NOQA: E722 bare except; see above
try:
r = short_repr(ret)
except:
except: # NOQA: E722 bare except
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
......@@ -669,6 +674,7 @@ class ManagedServerConnection(Connection):
poll = smac.SizedMessageAsyncConnection.handle_write
def server_loop(map):
while len(map) > 1:
try:
......@@ -680,6 +686,7 @@ def server_loop(map):
for o in tuple(map.values()):
o.close()
class ManagedClientConnection(Connection):
"""Client-side Connection subclass."""
__super_init = Connection.__init__
......@@ -740,7 +747,7 @@ class ManagedClientConnection(Connection):
# are queued for the duration. The client will send its own
# handshake after the server's handshake is seen, in recv_handshake()
# below. It will then send any messages queued while waiting.
assert self.queue_output # the constructor already set this
assert self.queue_output # the constructor already set this
def recv_handshake(self, proto):
# The protocol to use is the older of our and the server's preferred
......@@ -778,11 +785,11 @@ class ManagedClientConnection(Connection):
raise DisconnectedError()
msgid = self.send_call(method, args)
r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
if isinstance(r_args, tuple) and len(r_args) > 1 and \
type(r_args[0]) == exception_type_type and \
issubclass(r_args[0], Exception):
inst = r_args[1]
raise inst # error raised by server
raise inst # error raised by server
else:
return r_args
......@@ -821,11 +828,11 @@ class ManagedClientConnection(Connection):
def _deferred_wait(self, msgid):
r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
if isinstance(r_args, tuple) and \
type(r_args[0]) == exception_type_type and \
issubclass(r_args[0], Exception):
inst = r_args[1]
raise inst # error raised by server
raise inst # error raised by server
else:
return r_args
......
......@@ -14,9 +14,11 @@
from ZODB import POSException
from ZEO.Exceptions import ClientDisconnected
class ZRPCError(POSException.StorageError):
pass
class DisconnectedError(ZRPCError, ClientDisconnected):
"""The database storage is disconnected from the storage server.
......
......@@ -17,24 +17,29 @@ import logging
from ZODB.loglevels import BLATHER
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
logger = logging.getLogger('ZEO.zrpc')
_label = "%s" % os.getpid()
def new_label():
global _label
_label = str(os.getpid())
def log(message, level=BLATHER, label=None, exc_info=False):
label = label or _label
if LOG_THREAD_ID:
label = label + ':' + threading.currentThread().getName()
logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info)
REPR_LIMIT = 60
def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes."
......
......@@ -19,7 +19,8 @@ from .log import log, short_repr
PY2 = not PY3
def encode(*args): # args: (msgid, flags, name, args)
def encode(*args): # args: (msgid, flags, name, args)
# (We used to have a global pickler, but that's not thread-safe. :-( )
# It's not thread safe if, in the couse of pickling, we call the
......@@ -41,7 +42,6 @@ def encode(*args): # args: (msgid, flags, name, args)
return res
if PY3:
# XXX: Py3: Needs optimization.
fast_encode = encode
......@@ -50,48 +50,57 @@ elif PYPY:
# every time, getvalue() only works once
fast_encode = encode
else:
def fast_encode():
# Only use in cases where you *know* the data contains only basic
# Python objects
pickler = Pickler(1)
pickler.fast = 1
dump = pickler.dump
def fast_encode(*args):
return dump(args, 1)
return fast_encode
fast_encode = fast_encode()
def decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global
try:
unpickler.find_class = find_global # PyPy, zodbpickle, the non-c-accelerated version
# PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = find_global
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR)
raise
def server_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global
try:
unpickler.find_class = server_find_global # PyPy, zodbpickle, the non-c-accelerated version
# PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = server_find_global
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR)
raise
_globals = globals()
_silly = ('__doc__',)
......@@ -102,6 +111,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__',
)
def find_global(module, name):
"""Helper for message unpickler"""
try:
......@@ -114,7 +124,8 @@ def find_global(module, name):
except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES)
safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe:
return r
......@@ -124,6 +135,7 @@ def find_global(module, name):
raise ZRPCError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name):
"""Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES:
......
......@@ -13,6 +13,7 @@
##############################################################################
import asyncore
import socket
import time
# _has_dualstack: True if the dual-stack sockets are supported
try:
......@@ -39,6 +40,7 @@ import logging
# Export the main asyncore loop
loop = asyncore.loop
class Dispatcher(asyncore.dispatcher):
"""A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__
......@@ -74,7 +76,7 @@ class Dispatcher(asyncore.dispatcher):
for i in range(25):
try:
self.bind(self.addr)
except Exception as exc:
except Exception:
log("bind failed %s waiting", i)
if i == 24:
raise
......@@ -98,7 +100,6 @@ class Dispatcher(asyncore.dispatcher):
log("accepted failed: %s" % msg)
return
# We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below,
......@@ -111,12 +112,12 @@ class Dispatcher(asyncore.dispatcher):
# closed, but I don't see a way to do that. :(
# Drop flow-info from IPv6 addresses
if addr: # Sometimes None on Mac. See above.
if addr: # Sometimes None on Mac. See above.
addr = addr[:2]
try:
c = self.factory(sock, addr)
except:
except: # NOQA: E722 bare except
if sock.fileno() in asyncore.socket_map:
del asyncore.socket_map[sock.fileno()]
logger.exception("Error in handle_accept")
......
......@@ -67,19 +67,20 @@ MAC_BIT = 0x80000000
_close_marker = object()
class SizedMessageAsyncConnection(asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close
__closed = True # Marker indicating that we're closed
__closed = True # Marker indicating that we're closed
socket = None # to outwit Sam's getattr
socket = None # to outwit Sam's getattr
def __init__(self, sock, addr, map=None):
self.addr = addr
# __input_lock protects __inp, __input_len, __state, __msg_size
self.__input_lock = threading.Lock()
self.__inp = None # None, a single String, or a list
self.__inp = None # None, a single String, or a list
self.__input_len = 0
# Instance variables __state, __msg_size and __has_mac work together:
# when __state == 0:
......@@ -168,7 +169,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
d = self.recv(8192)
except socket.error as err:
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
# which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_read_errors:
return
......@@ -190,7 +191,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
else:
self.__inp.append(d)
self.__input_len = input_len
return # keep waiting for more input
return # keep waiting for more input
# load all previous input and d into single string inp
if isinstance(inp, six.binary_type):
......@@ -298,15 +299,15 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
# ensure the above mentioned "output" invariant
output.insert(0, v)
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
# which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_write_errors:
break # we couldn't write anything
break # we couldn't write anything
raise
if n < len(v):
output.append(v[n:])
break # we can't write any more
break # we can't write any more
def handle_close(self):
self.close()
......
......@@ -21,7 +21,7 @@ import socket
import errno
from ZODB.utils import positive_id
from ZEO._compat import thread, get_ident
from ZEO._compat import thread
# Original comments follow; they're hard to follow in the context of
# ZEO's use of triggers. TODO: rewrite from a ZEO perspective.
......@@ -56,6 +56,7 @@ from ZEO._compat import thread, get_ident
# new data onto a channel's outgoing data queue at the same time that
# the main thread is trying to remove some]
class _triggerbase(object):
"""OS-independent base class for OS-dependent trigger class."""
......@@ -127,7 +128,7 @@ class _triggerbase(object):
return
try:
thunk[0](*thunk[1:])
except:
except: # NOQA: E722 bare except
nil, t, v, tbinfo = asyncore.compact_traceback()
print(('exception in trigger thunk:'
' (%s:%s %s)' % (t, v, tbinfo)))
......@@ -135,6 +136,7 @@ class _triggerbase(object):
def __repr__(self):
return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self))
if os.name == 'posix':
class trigger(_triggerbase, asyncore.file_dispatcher):
......@@ -187,39 +189,39 @@ else:
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1)
try:
w.connect(connect_address)
break # success
except socket.error as detail:
if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
w.close()
raise BindError("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1)
try:
w.connect(connect_address)
break # success
except socket.error as detail:
if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
w.close()
raise BindError("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
r, addr = a.accept() # r becomes asyncore's (self.)socket
a.close()
......
......@@ -16,7 +16,6 @@ from __future__ import print_function
import random
import sys
import time
......@@ -56,8 +55,8 @@ def encode_format(fmt):
fmt = fmt.replace(*xform)
return fmt
runner = _forker.runner
runner = _forker.runner
stop_runner = _forker.stop_runner
start_zeo_server = _forker.start_zeo_server
......@@ -70,6 +69,7 @@ else:
shutdown_zeo_server = _forker.shutdown_zeo_server
def get_port(ignored=None):
"""Return a port that is not in use.
......@@ -107,6 +107,7 @@ def get_port(ignored=None):
s1.close()
raise RuntimeError("Can't find port")
def can_connect(port):
c = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
......@@ -119,6 +120,7 @@ def can_connect(port):
finally:
c.close()
def setUp(test):
ZODB.tests.util.setUp(test)
......@@ -194,9 +196,11 @@ def wait_until(label=None, func=None, timeout=30, onfail=None):
return onfail()
time.sleep(0.01)
def wait_connected(storage):
wait_until("storage is connected", storage.is_connected)
def wait_disconnected(storage):
wait_until("storage is disconnected",
lambda: not storage.is_connected())
......
......@@ -34,6 +34,7 @@ import ZEO.asyncio.tests
import ZEO.StorageServer
import ZODB.MappingStorage
class StorageServer(ZEO.StorageServer.StorageServer):
def __init__(self, addr='test_addr', storages=None, **kw):
......@@ -41,6 +42,7 @@ class StorageServer(ZEO.StorageServer.StorageServer):
storages = {'1': ZODB.MappingStorage.MappingStorage()}
ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw)
def client(server, name='client'):
zs = ZEO.StorageServer.ZEOStorage(server)
protocol = ZEO.asyncio.tests.server_protocol(
......
......@@ -18,7 +18,21 @@ from __future__ import print_function
# FOR A PARTICULAR PURPOSE
#
##############################################################################
usage="""Test speed of a ZODB storage
import asyncore
import getopt
import os
import sys
import time
import persistent
import transaction
import ZODB
from ZODB.POSException import ConflictError
from ZEO.tests import forker
usage = """Test speed of a ZODB storage
Options:
......@@ -48,41 +62,40 @@ Options:
-t n Number of concurrent threads to run.
"""
import asyncore
import sys, os, getopt, time
##sys.path.insert(0, os.getcwd())
import persistent
import transaction
import ZODB
from ZODB.POSException import ConflictError
from ZEO.tests import forker
class P(persistent.Persistent):
pass
fs_name = "zeo-speed.fs"
class ZEOExit(asyncore.file_dispatcher):
"""Used to exit ZEO.StorageServer when run is done"""
def writable(self):
return 0
def readable(self):
return 1
def handle_read(self):
buf = self.recv(4)
assert buf == "done"
self.delete_fs()
os._exit(0)
def handle_close(self):
print("Parent process exited unexpectedly")
self.delete_fs()
os._exit(0)
def delete_fs(self):
os.unlink(fs_name)
os.unlink(fs_name + ".lock")
os.unlink(fs_name + ".tmp")
def work(db, results, nrep, compress, data, detailed, minimize, threadno=None):
for j in range(nrep):
for r in 1, 10, 100, 1000:
......@@ -98,7 +111,7 @@ def work(db, results, nrep, compress, data, detailed, minimize, threadno=None):
if key in rt:
p = rt[key]
else:
rt[key] = p =P()
rt[key] = p = P()
for i in range(r):
v = getattr(p, str(i), P())
if compress is not None:
......@@ -121,46 +134,49 @@ def work(db, results, nrep, compress, data, detailed, minimize, threadno=None):
print("%s\t%s\t%.4f\t%d\t%d" % (j, r, t, conflicts,
threadno))
results[r].append((t, conflicts))
rt=d=p=v=None # release all references
rt = p = v = None # release all references
if minimize:
time.sleep(3)
jar.cacheMinimize()
def main(args):
opts, args = getopt.getopt(args, 'zd:n:Ds:LMt:U')
s = None
compress = None
data=sys.argv[0]
nrep=5
minimize=0
detailed=1
data = sys.argv[0]
nrep = 5
minimize = 0
detailed = 1
cache = None
domain = 'AF_INET'
threads = 1
for o, v in opts:
if o=='-n': nrep = int(v)
elif o=='-d': data = v
elif o=='-s': s = v
elif o=='-z':
if o == '-n':
nrep = int(v)
elif o == '-d':
data = v
elif o == '-s':
s = v
elif o == '-z':
import zlib
compress = zlib.compress
elif o=='-L':
minimize=1
elif o=='-M':
detailed=0
elif o=='-D':
elif o == '-L':
minimize = 1
elif o == '-M':
detailed = 0
elif o == '-D':
global debug
os.environ['STUPID_LOG_FILE']=''
os.environ['STUPID_LOG_SEVERITY']='-999'
os.environ['STUPID_LOG_FILE'] = ''
os.environ['STUPID_LOG_SEVERITY'] = '-999'
debug = 1
elif o == '-C':
cache = 'speed'
cache = 'speed' # NOQA: F841 unused variable
elif o == '-U':
domain = 'AF_UNIX'
elif o == '-t':
threads = int(v)
zeo_pipe = None
if s:
s = __import__(s, globals(), globals(), ('__doc__',))
s = s.Storage
......@@ -169,25 +185,25 @@ def main(args):
s, server, pid = forker.start_zeo("FileStorage",
(fs_name, 1), domain=domain)
data=open(data).read()
db=ZODB.DB(s,
# disable cache deactivation
cache_size=4000,
cache_deactivate_after=6000,)
data = open(data).read()
db = ZODB.DB(s,
# disable cache deactivation
cache_size=4000,
cache_deactivate_after=6000)
print("Beginning work...")
results={1:[], 10:[], 100:[], 1000:[]}
results = {1: [], 10: [], 100: [], 1000: []}
if threads > 1:
import threading
l = []
thread_list = []
for i in range(threads):
t = threading.Thread(target=work,
args=(db, results, nrep, compress, data,
detailed, minimize, i))
l.append(t)
for t in l:
thread_list.append(t)
for t in thread_list:
t.start()
for t in l:
for t in thread_list:
t.join()
else:
......@@ -202,21 +218,24 @@ def main(args):
print("num\tmean\tmin\tmax")
for r in 1, 10, 100, 1000:
times = []
for time, conf in results[r]:
times.append(time)
for time_val, conf in results[r]:
times.append(time_val)
t = mean(times)
print("%d\t%.4f\t%.4f\t%.4f" % (r, t, min(times), max(times)))
def mean(l):
def mean(lst):
tot = 0
for v in l:
for v in lst:
tot = tot + v
return tot / len(l)
return tot / len(lst)
# def compress(s):
# c = zlib.compressobj()
# o = c.compress(s)
# return o + c.flush()
##def compress(s):
## c = zlib.compressobj()
## o = c.compress(s)
## return o + c.flush()
if __name__=='__main__':
if __name__ == '__main__':
main(sys.argv[1:])
......@@ -36,20 +36,22 @@ MAX_DEPTH = 20
MIN_OBJSIZE = 128
MAX_OBJSIZE = 2048
def an_object():
"""Return an object suitable for a PersistentMapping key"""
size = random.randrange(MIN_OBJSIZE, MAX_OBJSIZE)
if os.path.exists("/dev/urandom"):
f = open("/dev/urandom")
buf = f.read(size)
f.close()
fp = open("/dev/urandom")
buf = fp.read(size)
fp.close()
return buf
else:
f = open(MinPO.__file__)
l = list(f.read(size))
f.close()
random.shuffle(l)
return "".join(l)
fp = open(MinPO.__file__)
lst = list(fp.read(size))
fp.close()
random.shuffle(lst)
return "".join(lst)
def setup(cn):
"""Initialize the database with some objects"""
......@@ -63,6 +65,7 @@ def setup(cn):
transaction.commit()
cn.close()
def work(cn):
"""Do some work with a transaction"""
cn.sync()
......@@ -74,11 +77,13 @@ def work(cn):
obj.value = an_object()
transaction.commit()
def main():
# Yuck! Need to cleanup forker so that the API is consistent
# across Unix and Windows, at least if that's possible.
if os.name == "nt":
zaddr, tport, pid = forker.start_zeo_server('MappingStorage', ())
def exitserver():
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -87,6 +92,7 @@ def main():
else:
zaddr = '', random.randrange(20000, 30000)
pid, exitobj = forker.start_zeo_server(MappingStorage(), zaddr)
def exitserver():
exitobj.close()
......@@ -97,6 +103,7 @@ def main():
exitserver()
def start_child(zaddr):
pid = os.fork()
......@@ -107,6 +114,7 @@ def start_child(zaddr):
finally:
os._exit(0)
def _start_child(zaddr):
storage = ClientStorage(zaddr, debug=1, min_disconnect_poll=0.5, wait=1)
db = ZODB.DB(storage, pool_size=NUM_CONNECTIONS)
......@@ -133,5 +141,6 @@ def _start_child(zaddr):
c.__count += 1
work(c)
if __name__ == "__main__":
main()
......@@ -21,6 +21,7 @@ from ZODB.config import storageFromString
from .forker import start_zeo_server
from .threaded import threaded_server_tests
class ZEOConfigTestBase(setupstack.TestCase):
setUp = setupstack.setUpDirectory
......@@ -52,23 +53,21 @@ class ZEOConfigTestBase(setupstack.TestCase):
</clientstorage>
""".format(settings))
def _client_assertions(
self, client, addr,
connected=True,
cache_size=20 * (1<<20),
cache_path=None,
blob_dir=None,
shared_blob_dir=False,
blob_cache_size=None,
blob_cache_size_check=10,
read_only=False,
read_only_fallback=False,
server_sync=False,
wait_timeout=30,
client_label=None,
storage='1',
name=None,
):
def _client_assertions(self, client, addr,
connected=True,
cache_size=20 * (1 << 20),
cache_path=None,
blob_dir=None,
shared_blob_dir=False,
blob_cache_size=None,
blob_cache_size_check=10,
read_only=False,
read_only_fallback=False,
server_sync=False,
wait_timeout=30,
client_label=None,
storage='1',
name=None):
self.assertEqual(client.is_connected(), connected)
self.assertEqual(client._addr, [addr])
self.assertEqual(client._cache.maxsize, cache_size)
......@@ -88,6 +87,7 @@ class ZEOConfigTestBase(setupstack.TestCase):
self.assertEqual(client.__name__,
name if name is not None else str(client._addr))
class ZEOConfigTest(ZEOConfigTestBase):
def test_default_zeo_config(self, **client_settings):
......@@ -101,18 +101,17 @@ class ZEOConfigTest(ZEOConfigTestBase):
def test_client_variations(self):
for name, value in dict(
cache_size=4200,
cache_path='test',
blob_dir='blobs',
blob_cache_size=424242,
read_only=True,
read_only_fallback=True,
server_sync=True,
wait_timeout=33,
client_label='test_client',
name='Test'
).items():
for name, value in dict(cache_size=4200,
cache_path='test',
blob_dir='blobs',
blob_cache_size=424242,
read_only=True,
read_only_fallback=True,
server_sync=True,
wait_timeout=33,
client_label='test_client',
name='Test',
).items():
params = {name: value}
self.test_default_zeo_config(**params)
......@@ -120,6 +119,7 @@ class ZEOConfigTest(ZEOConfigTestBase):
self.test_default_zeo_config(blob_cache_size=424242,
blob_cache_size_check=50)
def test_suite():
suite = unittest.makeSuite(ZEOConfigTest)
suite.layer = threaded_server_tests
......
......@@ -29,10 +29,9 @@ else:
import unittest
import ZODB.tests.util
import ZEO
from . import forker
class FileStorageConfig(object):
def getConfig(self, path, create, read_only):
return """\
......@@ -44,57 +43,56 @@ class FileStorageConfig(object):
create and 'yes' or 'no',
read_only and 'yes' or 'no')
class MappingStorageConfig(object):
def getConfig(self, path, create, read_only):
return """<mappingstorage 1/>"""
class FileStorageConnectionTests(
FileStorageConfig,
ConnectionTests.ConnectionTests,
InvalidationTests.InvalidationTests
):
FileStorageConfig,
ConnectionTests.ConnectionTests,
InvalidationTests.InvalidationTests):
"""FileStorage-specific connection tests."""
class FileStorageReconnectionTests(
FileStorageConfig,
ConnectionTests.ReconnectionTests,
):
FileStorageConfig,
ConnectionTests.ReconnectionTests):
"""FileStorage-specific re-connection tests."""
# Run this at level 1 because MappingStorage can't do reconnection tests
class FileStorageInvqTests(
FileStorageConfig,
ConnectionTests.InvqTests
):
FileStorageConfig,
ConnectionTests.InvqTests):
"""FileStorage-specific invalidation queue tests."""
class FileStorageTimeoutTests(
FileStorageConfig,
ConnectionTests.TimeoutTests
):
FileStorageConfig,
ConnectionTests.TimeoutTests):
pass
class MappingStorageConnectionTests(
MappingStorageConfig,
ConnectionTests.ConnectionTests
):
MappingStorageConfig,
ConnectionTests.ConnectionTests):
"""Mapping storage connection tests."""
# The ReconnectionTests can't work with MappingStorage because it's only an
# in-memory storage and has no persistent state.
class MappingStorageTimeoutTests(
MappingStorageConfig,
ConnectionTests.TimeoutTests
):
MappingStorageConfig,
ConnectionTests.TimeoutTests):
pass
class SSLConnectionTests(
MappingStorageConfig,
ConnectionTests.SSLConnectionTests,
):
MappingStorageConfig,
ConnectionTests.SSLConnectionTests):
pass
......@@ -108,6 +106,7 @@ test_classes = [FileStorageConnectionTests,
if not forker.ZEO4_SERVER:
test_classes.append(SSLConnectionTests)
def invalidations_while_connecting():
r"""
As soon as a client registers with a server, it will recieve
......@@ -122,7 +121,7 @@ This tests tries to provoke this bug by:
- starting a server
>>> addr, _ = start_server()
>>> addr, _ = start_server() # NOQA: F821 undefined name
- opening a client to the server that writes some objects, filling
it's cache at the same time,
......@@ -182,7 +181,9 @@ This tests tries to provoke this bug by:
... db = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr, client='x'))
... with lock:
... #logging.getLogger('ZEO').debug('Locked %s' % c)
... @wait_until("connected and we have caught up", timeout=199)
... msg = "connected and we have caught up"
...
... @wait_until(msg, timeout=199) # NOQA: F821 undefined var
... def _():
... if (db.storage.is_connected()
... and db.storage.lastTransaction()
......@@ -228,6 +229,7 @@ This tests tries to provoke this bug by:
>>> db2.close()
"""
def test_suite():
suite = unittest.TestSuite()
......
......@@ -14,7 +14,6 @@
import doctest
import unittest
import ZEO.asyncio.testing
class FakeStorageBase(object):
......@@ -22,7 +21,7 @@ class FakeStorageBase(object):
if name in ('getTid', 'history', 'load', 'loadSerial',
'lastTransaction', 'getSize', 'getName', 'supportsUndo',
'tpc_transaction'):
return lambda *a, **k: None
return lambda *a, **k: None
raise AttributeError(name)
def isReadOnly(self):
......@@ -31,11 +30,12 @@ class FakeStorageBase(object):
def __len__(self):
return 4
class FakeStorage(FakeStorageBase):
def record_iternext(self, next=None):
if next == None:
next = '0'
if next is None:
next = '0'
next = str(int(next) + 1)
oid = next
if next == '4':
......@@ -43,6 +43,7 @@ class FakeStorage(FakeStorageBase):
return oid, oid*8, 'data ' + oid, next
class FakeServer(object):
storages = {
'1': FakeStorage(),
......@@ -55,13 +56,17 @@ class FakeServer(object):
client_conflict_resolution = False
class FakeConnection(object):
protocol_version = b'Z4'
addr = 'test'
call_soon_threadsafe = lambda f, *a: f(*a)
def call_soon_threadsafe(f, *a):
return f(*a)
async_ = async_threadsafe = None
def test_server_record_iternext():
"""
......@@ -99,6 +104,7 @@ The storage info also reflects the fact that record_iternext is supported.
"""
def test_client_record_iternext():
"""Test client storage delegation to the network client
......@@ -143,8 +149,10 @@ Now we'll have our way with it's private _server attr:
"""
def test_suite():
return doctest.DocTestSuite()
if __name__ == '__main__':
unittest.main(defaultTest='test_suite')
......@@ -16,15 +16,18 @@ import unittest
from ZEO.TransactionBuffer import TransactionBuffer
def random_string(size):
"""Return a random string of size size."""
l = [chr(random.randrange(256)) for i in range(size)]
return "".join(l)
lst = [chr(random.randrange(256)) for i in range(size)]
return "".join(lst)
def new_store_data():
"""Return arbitrary data to use as argument to store() method."""
return random_string(8), random_string(random.randrange(1000))
def store(tbuf, resolved=False):
data = new_store_data()
tbuf.store(*data)
......@@ -32,6 +35,7 @@ def store(tbuf, resolved=False):
tbuf.server_resolve(data[0])
return data
class TransBufTests(unittest.TestCase):
def checkTypicalUsage(self):
......@@ -54,5 +58,6 @@ class TransBufTests(unittest.TestCase):
self.assertEqual(resolved, data[i][1])
tbuf.close()
def test_suite():
return unittest.makeSuite(TransBufTests, 'check')
......@@ -14,7 +14,6 @@
"""Test suite for ZEO based on ZODB.tests."""
from __future__ import print_function
import multiprocessing
import re
from ZEO.ClientStorage import ClientStorage
from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests
......@@ -41,7 +40,6 @@ import re
import shutil
import signal
import stat
import ssl
import sys
import tempfile
import threading
......@@ -62,11 +60,15 @@ from . import testssl
logger = logging.getLogger('ZEO.tests.testZEO')
class DummyDB(object):
def invalidate(self, *args):
pass
def invalidateCache(*unused):
pass
transform_record_data = untransform_record_data = lambda self, v: v
......@@ -76,7 +78,6 @@ class CreativeGetState(persistent.Persistent):
return super(CreativeGetState, self).__getstate__()
class Test_convenience_functions(unittest.TestCase):
def test_ZEO_client_convenience(self):
......@@ -206,9 +207,8 @@ class MiscZEOTests(object):
for n in range(30):
time.sleep(.1)
data, serial = storage2.load(oid, '')
if (serial == revid2 and
zodb_unpickle(data) == MinPO('second')
):
if serial == revid2 and \
zodb_unpickle(data) == MinPO('second'):
break
else:
raise AssertionError('Invalidation message was not sent!')
......@@ -230,9 +230,10 @@ class MiscZEOTests(object):
self.assertNotEqual(ZODB.utils.z64, storage3.lastTransaction())
storage3.close()
class GenericTestBase(
# Base class for all ZODB tests
StorageTestBase.StorageTestBase):
# Base class for all ZODB tests
StorageTestBase.StorageTestBase):
shared_blob_dir = False
blob_cache_dir = None
......@@ -259,8 +260,9 @@ class GenericTestBase(
)
self._storage.registerDB(DummyDB())
# _new_storage_client opens another ClientStorage to the same storage server
# self._storage is connected to. It is used by both ZEO and ZODB tests.
# _new_storage_client opens another ClientStorage to the same storage
# server self._storage is connected to. It is used by both ZEO and ZODB
# tests.
def _new_storage_client(self):
client = ZEO.ClientStorage.ClientStorage(
self._storage._addr, wait=1, **self._client_options())
......@@ -283,21 +285,20 @@ class GenericTestBase(
stop()
StorageTestBase.StorageTestBase.tearDown(self)
class GenericTests(
GenericTestBase,
# ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage,
PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage,
MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests,
):
GenericTestBase,
# ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage,
PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage,
MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests):
"""Combine tests from various origins in one class.
"""
......@@ -347,17 +348,17 @@ class GenericTests(
thread.join(voted and .1 or 9)
return thread
class FullGenericTests(
GenericTests,
Cache.TransUndoStorageWithCache,
ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage,
PackableStorage.PackableUndoStorage,
RevisionStorage.RevisionStorage,
TransactionalUndoStorage.TransactionalUndoStorage,
IteratorStorage.IteratorStorage,
IterationTests.IterationTests,
):
GenericTests,
Cache.TransUndoStorageWithCache,
ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage,
PackableStorage.PackableUndoStorage,
RevisionStorage.RevisionStorage,
TransactionalUndoStorage.TransactionalUndoStorage,
IteratorStorage.IteratorStorage,
IterationTests.IterationTests):
"""Extend GenericTests with tests that MappingStorage can't pass."""
def checkPackUndoLog(self):
......@@ -374,7 +375,7 @@ class FullGenericTests(
except AttributeError:
# ...unless we're on Python 2, which doesn't have the __wrapped__
# attribute.
if bytes is not str: # pragma: no cover Python 3
if bytes is not str: # pragma: no cover Python 3
raise
unbound_func = PackableStorage.PackableUndoStorage.checkPackUndoLog
wrapper_func = unbound_func.__func__
......@@ -457,8 +458,7 @@ class FileStorageTests(FullGenericTests):
self._storage))
# This is communicated using ClientStorage's _info object:
self.assertEqual(self._expected_interfaces,
self._storage._info['interfaces']
)
self._storage._info['interfaces'])
class FileStorageSSLTests(FileStorageTests):
......@@ -492,6 +492,7 @@ class FileStorageHexTests(FileStorageTests):
</hexstorage>
"""
class FileStorageClientHexTests(FileStorageHexTests):
use_extension_bytes = True
......@@ -509,10 +510,10 @@ class FileStorageClientHexTests(FileStorageHexTests):
def _wrap_client(self, client):
return ZODB.tests.hexstorage.HexStorage(client)
class ClientConflictResolutionTests(
GenericTestBase,
ConflictResolution.ConflictResolvingStorage,
):
GenericTestBase,
ConflictResolution.ConflictResolvingStorage):
def getConfig(self):
return '<mappingstorage>\n</mappingstorage>\n'
......@@ -520,7 +521,9 @@ class ClientConflictResolutionTests(
def getZEOConfig(self):
# Using '' can result in binding to :: and cause problems
# connecting to the MTAcceptor on Travis CI
return forker.ZEOConfig(('127.0.0.1', 0), client_conflict_resolution=True)
return forker.ZEOConfig(('127.0.0.1', 0),
client_conflict_resolution=True)
class MappingStorageTests(GenericTests):
"""ZEO backed by a Mapping storage."""
......@@ -538,9 +541,8 @@ class MappingStorageTests(GenericTests):
# to construct our iterator, which we don't, so we disable this test.
pass
class DemoStorageTests(
GenericTests,
):
class DemoStorageTests(GenericTests):
def getConfig(self):
return """
......@@ -557,9 +559,10 @@ class DemoStorageTests(
pass
def checkPackWithMultiDatabaseReferences(self):
pass # DemoStorage pack doesn't do gc
pass # DemoStorage pack doesn't do gc
checkPackAllRevisions = checkPackWithMultiDatabaseReferences
class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
def getConfig(self, path, create, read_only):
......@@ -573,7 +576,6 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
handler = zope.testing.loggingsupport.InstalledHandler(
'ZEO.asyncio.client')
# We no longer implement the event loop, we we no longer know
# how to break it. We'll just stop it instead for now.
self._storage._server.loop.call_soon_threadsafe(
......@@ -581,7 +583,7 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
forker.wait_until(
'disconnected',
lambda : not self._storage.is_connected()
lambda: not self._storage.is_connected()
)
log = str(handler)
......@@ -614,12 +616,15 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
class DummyDB(object):
_invalidatedCache = 0
def invalidateCache(self):
self._invalidatedCache += 1
def invalidate(*a, **k):
pass
transform_record_data = untransform_record_data = \
lambda self, data: data
lambda self, data: data
db = DummyDB()
storage.registerDB(db)
......@@ -660,7 +665,6 @@ class CommonBlobTests(object):
blob_cache_dir = 'blob_cache'
def checkStoreBlob(self):
import transaction
from ZODB.blob import Blob
from ZODB.tests.StorageTestBase import ZERO
from ZODB.tests.StorageTestBase import zodb_pickle
......@@ -681,7 +685,7 @@ class CommonBlobTests(object):
self._storage.storeBlob(oid, ZERO, data, tfname, '', t)
self._storage.tpc_vote(t)
revid = self._storage.tpc_finish(t)
except:
except: # NOQA: E722 bare except
self._storage.tpc_abort(t)
raise
self.assertTrue(not os.path.exists(tfname))
......@@ -703,7 +707,6 @@ class CommonBlobTests(object):
def checkLoadBlob(self):
from ZODB.blob import Blob
from ZODB.tests.StorageTestBase import zodb_pickle, ZERO
import transaction
somedata = b'a' * 10
......@@ -720,7 +723,7 @@ class CommonBlobTests(object):
self._storage.storeBlob(oid, ZERO, data, tfname, '', t)
self._storage.tpc_vote(t)
serial = self._storage.tpc_finish(t)
except:
except: # NOQA: E722 bare except
self._storage.tpc_abort(t)
raise
......@@ -732,7 +735,7 @@ class CommonBlobTests(object):
def checkTemporaryDirectory(self):
self.assertEqual(os.path.join(self.blob_cache_dir, 'tmp'),
self._storage.temporaryDirectory())
self._storage.temporaryDirectory())
def checkTransactionBufferCleanup(self):
oid = self._storage.new_oid()
......@@ -749,7 +752,6 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests):
"""ZEO backed by a BlobStorage-adapted FileStorage."""
def checkStoreAndLoadBlob(self):
import transaction
from ZODB.blob import Blob
from ZODB.tests.StorageTestBase import ZERO
from ZODB.tests.StorageTestBase import zodb_pickle
......@@ -785,7 +787,7 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests):
self._storage.storeBlob(oid, ZERO, data, tfname, '', t)
self._storage.tpc_vote(t)
revid = self._storage.tpc_finish(t)
except:
except: # NOQA: E722 bare except
self._storage.tpc_abort(t)
raise
......@@ -812,11 +814,9 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests):
returns = []
threads = [
threading.Thread(
target=lambda :
target=lambda:
returns.append(self._storage.loadBlob(oid, revid))
)
for i in range(10)
]
) for i in range(10)]
[thread.start() for thread in threads]
[thread.join() for thread in threads]
[self.assertEqual(r, filename) for r in returns]
......@@ -828,18 +828,21 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests):
blob_cache_dir = 'blobs'
shared_blob_dir = True
class FauxConn(object):
addr = 'x'
protocol_version = ZEO.asyncio.server.best_protocol_version
peer_protocol_version = protocol_version
serials = []
def async_(self, method, *args):
if method == 'serialnos':
self.serials.extend(args[0])
call_soon_threadsafe = async_threadsafe = async_
class StorageServerWrapper(object):
def __init__(self, server, storage_id):
......@@ -881,13 +884,14 @@ class StorageServerWrapper(object):
def tpc_abort(self, transaction):
self.server.tpc_abort(id(transaction))
def tpc_finish(self, transaction, func = lambda: None):
def tpc_finish(self, transaction, func=lambda: None):
self.server.tpc_finish(id(transaction)).set_sender(0, self)
return self._result
def multiple_storages_invalidation_queue_is_not_insane():
"""
>>> from ZEO.StorageServer import StorageServer, ZEOStorage
>>> from ZEO.StorageServer import StorageServer
>>> from ZODB.FileStorage import FileStorage
>>> from ZODB.DB import DB
>>> from persistent.mapping import PersistentMapping
......@@ -926,6 +930,7 @@ def multiple_storages_invalidation_queue_is_not_insane():
>>> fs1.close(); fs2.close()
"""
def getInvalidationsAfterServerRestart():
"""
......@@ -969,12 +974,10 @@ If a storage implements the method lastInvalidations, as FileStorage
does, then the storage server will populate its invalidation data
structure using lastTransactions.
>>> tid, oids = s.getInvalidations(last[-10])
>>> tid == last[-1]
True
>>> from ZODB.utils import u64
>>> sorted([int(u64(oid)) for oid in oids])
[0, 92, 93, 94, 95, 96, 97, 98, 99, 100]
......@@ -1023,13 +1026,14 @@ that were only created.
>>> fs.close()
"""
def tpc_finish_error():
r"""Server errors in tpc_finish weren't handled properly.
If there are errors applying changes to the client cache, don't
leave the cache in an inconsistent state.
>>> addr, admin = start_server()
>>> addr, admin = start_server() # NOQA: F821 undefined
>>> client = ZEO.client(addr)
>>> db = ZODB.DB(client)
......@@ -1070,16 +1074,17 @@ def tpc_finish_error():
>>> db.close()
>>> stop_server(admin)
>>> stop_server(admin) # NOQA: F821 undefined
"""
def test_prefetch(self):
"""The client storage prefetch method pre-fetches from the server
>>> count = 999
>>> import ZEO
>>> addr, stop = start_server()
>>> addr, stop = start_server() # NOQA: F821 undefined
>>> conn = ZEO.connection(addr)
>>> root = conn.root()
>>> cls = root.__class__
......@@ -1102,7 +1107,7 @@ def test_prefetch(self):
But it is filled eventually:
>>> from zope.testing.wait import wait
>>> wait(lambda : len(storage._cache) > count)
>>> wait(lambda: len(storage._cache) > count)
>>> loads = storage.server_status()['loads']
......@@ -1117,15 +1122,16 @@ def test_prefetch(self):
>>> conn.close()
"""
def client_has_newer_data_than_server():
"""It is bad if a client has newer data than the server.
>>> db = ZODB.DB('Data.fs')
>>> db.close()
>>> r = shutil.copyfile('Data.fs', 'Data.save')
>>> addr, admin = start_server(keep=1)
>>> addr, admin = start_server(keep=1) # NOQA: F821 undefined
>>> db = ZEO.DB(addr, name='client', max_disconnect_poll=.01)
>>> wait_connected(db.storage)
>>> wait_connected(db.storage) # NOQA: F821 undefined
>>> conn = db.open()
>>> conn.root().x = 1
>>> transaction.commit()
......@@ -1134,7 +1140,7 @@ def client_has_newer_data_than_server():
the new data. Now, we'll stop the server, put back the old data, and
see what happens. :)
>>> stop_server(admin)
>>> stop_server(admin) # NOQA: F821 undefined
>>> r = shutil.copyfile('Data.save', 'Data.fs')
>>> import zope.testing.loggingsupport
......@@ -1142,9 +1148,9 @@ def client_has_newer_data_than_server():
... 'ZEO', level=logging.ERROR)
>>> formatter = logging.Formatter('%(name)s %(levelname)s %(message)s')
>>> _, admin = start_server(addr=addr)
>>> _, admin = start_server(addr=addr) # NOQA: F821 undefined
>>> wait_until('got enough errors', lambda:
>>> wait_until('got enough errors', lambda: # NOQA: F821 undefined
... len([x for x in handler.records
... if x.levelname == 'CRITICAL' and
... 'Client cache is out of sync with the server.' in x.msg
......@@ -1154,15 +1160,16 @@ def client_has_newer_data_than_server():
>>> db.close()
>>> handler.uninstall()
>>> stop_server(admin)
>>> stop_server(admin) # NOQA: F821 undefined
"""
def history_over_zeo():
"""
>>> addr, _ = start_server()
>>> addr, _ = start_server() # NOQA: F821 undefined
>>> db = ZEO.DB(addr)
>>> wait_connected(db.storage)
>>> wait_connected(db.storage) # NOQA: F821 undefined
>>> conn = db.open()
>>> conn.root().x = 0
>>> transaction.commit()
......@@ -1172,9 +1179,10 @@ def history_over_zeo():
>>> db.close()
"""
def dont_log_poskeyerrors_on_server():
"""
>>> addr, admin = start_server(log='server.log')
>>> addr, admin = start_server(log='server.log') # NOQA: F821 undefined
>>> cs = ClientStorage(addr)
>>> cs.load(ZODB.utils.p64(1))
Traceback (most recent call last):
......@@ -1182,16 +1190,17 @@ def dont_log_poskeyerrors_on_server():
POSKeyError: 0x01
>>> cs.close()
>>> stop_server(admin)
>>> stop_server(admin) # NOQA: F821 undefined
>>> with open('server.log') as f:
... 'POSKeyError' in f.read()
False
"""
def open_convenience():
"""Often, we just want to open a single connection.
>>> addr, _ = start_server(path='data.fs')
>>> addr, _ = start_server(path='data.fs') # NOQA: F821 undefined
>>> conn = ZEO.connection(addr)
>>> conn.root()
{}
......@@ -1210,9 +1219,10 @@ def open_convenience():
>>> db.close()
"""
def client_asyncore_thread_has_name():
"""
>>> addr, _ = start_server()
>>> addr, _ = start_server() # NOQA: F821 undefined
>>> db = ZEO.DB(addr)
>>> any(t for t in threading.enumerate()
... if ' zeo client networking thread' in t.getName())
......@@ -1220,6 +1230,7 @@ def client_asyncore_thread_has_name():
>>> db.close()
"""
def runzeo_without_configfile():
r"""
>>> with open('runzeo', 'w') as r:
......@@ -1251,11 +1262,12 @@ def runzeo_without_configfile():
>>> proc.stdout.close()
"""
def close_client_storage_w_invalidations():
r"""
Invalidations could cause errors when closing client storages,
>>> addr, _ = start_server()
>>> addr, _ = start_server() # NOQA: F821 undefined
>>> writing = threading.Event()
>>> def mad_write_thread():
... global writing
......@@ -1280,10 +1292,11 @@ Invalidations could cause errors when closing client storages,
>>> thread.join(1)
"""
def convenient_to_pass_port_to_client_and_ZEO_dot_client():
"""Jim hates typing
>>> addr, _ = start_server()
>>> addr, _ = start_server() # NOQA: F821 undefined
>>> client = ZEO.client(addr[1])
>>> client.__name__ == "('127.0.0.1', %s)" % addr[1]
True
......@@ -1291,12 +1304,14 @@ def convenient_to_pass_port_to_client_and_ZEO_dot_client():
>>> client.close()
"""
@forker.skip_if_testing_client_against_zeo4
def test_server_status():
"""
You can get server status using the server_status method.
>>> addr, _ = start_server(zeo_conf=dict(transaction_timeout=1))
>>> addr, _ = start_server( # NOQA: F821 undefined
... zeo_conf=dict(transaction_timeout=1))
>>> db = ZEO.DB(addr)
>>> pprint.pprint(db.storage.server_status(), width=40)
{'aborts': 0,
......@@ -1316,12 +1331,14 @@ def test_server_status():
>>> db.close()
"""
@forker.skip_if_testing_client_against_zeo4
def test_ruok():
"""
You can also get server status using the ruok protocol.
>>> addr, _ = start_server(zeo_conf=dict(transaction_timeout=1))
>>> addr, _ = start_server( # NOQA: F821 undefined
... zeo_conf=dict(transaction_timeout=1))
>>> db = ZEO.DB(addr) # force a transaction :)
>>> import json, socket, struct
>>> s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -1349,6 +1366,7 @@ def test_ruok():
>>> db.close(); s.close()
"""
def client_labels():
"""
When looking at server logs, for servers with lots of clients coming
......@@ -1358,10 +1376,10 @@ log entries with actual clients. It's possible, sort of, but tedious.
You can make this easier by passing a label to the ClientStorage
constructor.
>>> addr, _ = start_server(log='server.log')
>>> addr, _ = start_server(log='server.log') # NOQA: F821 undefined
>>> db = ZEO.DB(addr, client_label='test-label-1')
>>> db.close()
>>> @wait_until
>>> @wait_until # NOQA: F821 undefined
... def check_for_test_label_1():
... with open('server.log') as f:
... for line in f:
......@@ -1382,7 +1400,7 @@ You can specify the client label via a configuration file as well:
... </zodb>
... ''' % addr[1])
>>> db.close()
>>> @wait_until
>>> @wait_until # NOQA: F821 undefined
... def check_for_test_label_2():
... with open('server.log') as f:
... for line in f:
......@@ -1393,6 +1411,7 @@ You can specify the client label via a configuration file as well:
"""
def invalidate_client_cache_entry_on_server_commit_error():
"""
......@@ -1400,7 +1419,7 @@ When the serials returned during commit includes an error, typically a
conflict error, invalidate the cache entry. This is important when
the cache is messed up.
>>> addr, _ = start_server()
>>> addr, _ = start_server() # NOQA: F821 undefined
>>> conn1 = ZEO.connection(addr)
>>> conn1.root.x = conn1.root().__class__()
>>> transaction.commit()
......@@ -1473,6 +1492,8 @@ sys.path[:] = %(path)r
%(src)s
"""
def generate_script(name, src):
with open(name, 'w') as f:
f.write(script_template % dict(
......@@ -1481,10 +1502,12 @@ def generate_script(name, src):
src=src,
))
def read(filename):
with open(filename) as f:
return f.read()
def runzeo_logrotate_on_sigusr2():
"""
>>> from ZEO.tests.forker import get_port
......@@ -1506,10 +1529,10 @@ def runzeo_logrotate_on_sigusr2():
... import ZEO.runzeo
... ZEO.runzeo.main()
... ''')
>>> import subprocess, signal
>>> import subprocess
>>> p = subprocess.Popen([sys.executable, 's', '-Cc'], close_fds=True)
>>> wait_until('started',
... lambda : os.path.exists('l') and ('listening on' in read('l'))
>>> wait_until('started', # NOQA: F821 undefined
... lambda: os.path.exists('l') and ('listening on' in read('l'))
... )
>>> oldlog = read('l')
......@@ -1518,7 +1541,8 @@ def runzeo_logrotate_on_sigusr2():
>>> s = ClientStorage(port)
>>> s.close()
>>> wait_until('See logging', lambda : ('Log files ' in read('l')))
>>> wait_until('See logging', # NOQA: F821 undefined
... lambda: ('Log files ' in read('l')))
>>> read('o') == oldlog # No new data in old log
True
......@@ -1528,10 +1552,11 @@ def runzeo_logrotate_on_sigusr2():
>>> _ = p.wait()
"""
def unix_domain_sockets():
"""Make sure unix domain sockets work
>>> addr, _ = start_server(port='./sock')
>>> addr, _ = start_server(port='./sock') # NOQA: F821 undefined
>>> c = ZEO.connection(addr)
>>> c.root.x = 1
......@@ -1539,6 +1564,7 @@ def unix_domain_sockets():
>>> c.close()
"""
def gracefully_handle_abort_while_storing_many_blobs():
r"""
......@@ -1548,7 +1574,7 @@ def gracefully_handle_abort_while_storing_many_blobs():
>>> handler = logging.StreamHandler(sys.stdout)
>>> logging.getLogger().addHandler(handler)
>>> addr, _ = start_server(blob_dir='blobs')
>>> addr, _ = start_server(blob_dir='blobs') # NOQA: F821 undefined
>>> client = ZEO.client(addr, blob_dir='cblobs')
>>> c = ZODB.connection(client)
>>> c.root.x = ZODB.blob.Blob(b'z'*(1<<20))
......@@ -1578,6 +1604,7 @@ call to the server. we'd get some sort of error here.
"""
def ClientDisconnected_errors_are_TransientErrors():
"""
>>> from ZEO.Exceptions import ClientDisconnected
......@@ -1586,6 +1613,7 @@ def ClientDisconnected_errors_are_TransientErrors():
True
"""
if not os.environ.get('ZEO4_SERVER'):
if os.environ.get('ZEO_MSGPACK'):
def test_runzeo_msgpack_support():
......@@ -1620,11 +1648,13 @@ if WIN:
del runzeo_logrotate_on_sigusr2
del unix_domain_sockets
def work_with_multiprocessing_process(name, addr, q):
conn = ZEO.connection(addr)
q.put((name, conn.root.x))
conn.close()
class MultiprocessingTests(unittest.TestCase):
layer = ZODB.tests.util.MininalTestLayer('work_with_multiprocessing')
......@@ -1634,9 +1664,9 @@ class MultiprocessingTests(unittest.TestCase):
# Gaaa, zope.testing.runner.FakeInputContinueGenerator has no close
if not hasattr(sys.stdin, 'close'):
sys.stdin.close = lambda : None
sys.stdin.close = lambda: None
if not hasattr(sys.stdin, 'fileno'):
sys.stdin.fileno = lambda : -1
sys.stdin.fileno = lambda: -1
self.globs = {}
forker.setUp(self)
......@@ -1651,12 +1681,13 @@ class MultiprocessingTests(unittest.TestCase):
for i in range(3)]
_ = [p.start() for p in processes]
self.assertEqual(sorted(q.get(timeout=300) for p in processes),
[(0, 1), (1, 1), (2, 1)])
[(0, 1), (1, 1), (2, 1)])
_ = [p.join(30) for p in processes]
conn.close()
zope.testing.setupstack.tearDown(self)
@forker.skip_if_testing_client_against_zeo4
def quick_close_doesnt_kill_server():
r"""
......@@ -1664,7 +1695,7 @@ def quick_close_doesnt_kill_server():
Start a server:
>>> from .testssl import server_config, client_ssl
>>> addr, _ = start_server(zeo_conf=server_config)
>>> addr, _ = start_server(zeo_conf=server_config) # NOQA: F821 undefined
Now connect and immediately disconnect. This caused the server to
die in the past:
......@@ -1678,7 +1709,10 @@ def quick_close_doesnt_kill_server():
... s.close()
>>> print("\n\nXXX WARNING: running quick_close_doesnt_kill_server with ssl as hack pending http://bugs.python.org/issue27386\n", file=sys.stderr) # Intentional long line to be annoying till this is fixed
>>> print("\n\nXXX WARNING: running quick_close_doesnt_kill_server "
... "with ssl as hack pending http://bugs.python.org/issue27386\n",
... file=sys.stderr) # Intentional long line to be annoying
... # until this is fixed
Now we should be able to connect as normal:
......@@ -1689,10 +1723,11 @@ def quick_close_doesnt_kill_server():
>>> db.close()
"""
def can_use_empty_string_for_local_host_on_client():
"""We should be able to spell localhost with ''.
>>> (_, port), _ = start_server()
>>> (_, port), _ = start_server() # NOQA: F821 undefined name
>>> conn = ZEO.connection(('', port))
>>> conn.root()
{}
......@@ -1702,6 +1737,7 @@ def can_use_empty_string_for_local_host_on_client():
>>> conn.close()
"""
slow_test_classes = [
BlobAdaptedFileStorageTests, BlobWritableCacheTests,
MappingStorageTests, DemoStorageTests,
......@@ -1713,6 +1749,7 @@ if not forker.ZEO4_SERVER:
quick_test_classes = [FileStorageRecoveryTests, ZRPCConnectionTests]
class ServerManagingClientStorage(ClientStorage):
def __init__(self, name, blob_dir, shared=False, extrafsoptions=''):
......@@ -1743,11 +1780,13 @@ class ServerManagingClientStorage(ClientStorage):
ClientStorage.close(self)
zope.testing.setupstack.tearDown(self)
def create_storage_shared(name, blob_dir):
return ServerManagingClientStorage(name, blob_dir, True)
class ServerManagingClientStorageForIExternalGCTest(
ServerManagingClientStorage):
ServerManagingClientStorage):
def pack(self, t=None, referencesf=None):
ServerManagingClientStorage.pack(self, t, referencesf, wait=True)
......@@ -1756,6 +1795,7 @@ class ServerManagingClientStorageForIExternalGCTest(
self._cache.clear()
ZEO.ClientStorage._check_blob_cache_size(self.blob_dir, 0)
def test_suite():
suite = unittest.TestSuite((
unittest.makeSuite(Test_convenience_functions),
......@@ -1769,7 +1809,8 @@ def test_suite():
'last-transaction'),
(re.compile("ZODB.POSException.ConflictError"), "ConflictError"),
(re.compile("ZODB.POSException.POSKeyError"), "POSKeyError"),
(re.compile("ZEO.Exceptions.ClientStorageError"), "ClientStorageError"),
(re.compile("ZEO.Exceptions.ClientStorageError"),
"ClientStorageError"),
(re.compile(r"\[Errno \d+\]"), '[Errno N]'),
(re.compile(r"loads=\d+\.\d+"), 'loads=42.42'),
# Python 3 drops the u prefix
......@@ -1810,7 +1851,7 @@ def test_suite():
),
)
zeo.addTest(PackableStorage.IExternalGC_suite(
lambda :
lambda:
ServerManagingClientStorageForIExternalGCTest(
'data.fs', 'blobs', extrafsoptions='pack-gc false')
))
......
......@@ -27,6 +27,7 @@ import ZODB.FileStorage
import ZODB.tests.util
import ZODB.utils
def proper_handling_of_blob_conflicts():
r"""
......@@ -108,6 +109,7 @@ The transaction is aborted by the server:
>>> fs.close()
"""
def proper_handling_of_errors_in_restart():
r"""
......@@ -149,6 +151,7 @@ We can start another client and get the storage lock.
>>> server.close()
"""
def errors_in_vote_should_clear_lock():
"""
......@@ -409,6 +412,7 @@ If clients disconnect while waiting, they will be dequeued:
>>> server.close()
"""
def lock_sanity_check():
r"""
On one occasion with 3.10.0a1 in production, we had a case where a
......@@ -492,6 +496,7 @@ ZEOStorage as closed and see if trying to get a lock cleans it up:
>>> server.close()
"""
def test_suite():
return unittest.TestSuite((
doctest.DocTestSuite(
......@@ -506,5 +511,6 @@ def test_suite():
),
))
if __name__ == '__main__':
unittest.main(defaultTest='test_suite')
......@@ -27,6 +27,7 @@ from zdaemon.tests.testzdoptions import TestZDOptions
# supplies the empty string.
DEFAULT_BINDING_HOST = ""
class TestZEOOptions(TestZDOptions):
OptionsClass = ZEOOptions
......@@ -59,7 +60,7 @@ class TestZEOOptions(TestZDOptions):
# Hide the base class test_configure
pass
def test_default_help(self): pass # disable silly test w spurious failures
def test_default_help(self): pass # disable silly test w spurious failures
def test_defaults_with_schema(self):
options = self.OptionsClass()
......@@ -106,5 +107,6 @@ def test_suite():
suite.addTest(unittest.makeSuite(cls))
return suite
if __name__ == "__main__":
unittest.main(defaultTest='test_suite')
import unittest
import mock
import os
from ZEO._compat import PY3
from ZEO.runzeo import ZEOServer
......@@ -11,7 +10,8 @@ class TestStorageServer(object):
def __init__(self, fail_create_server):
self.called = []
if fail_create_server: raise RuntimeError()
if fail_create_server:
raise RuntimeError()
def close(self):
self.called.append("close")
......@@ -49,7 +49,8 @@ class TestZEOServer(ZEOServer):
def loop_forever(self):
self.called.append("loop_forever")
if self.fail_loop_forever: raise RuntimeError()
if self.fail_loop_forever:
raise RuntimeError()
def close_server(self):
self.called.append("close_server")
......@@ -87,7 +88,7 @@ class CloseServerTests(unittest.TestCase):
"setup_signals",
"create_server",
"loop_forever",
"close_server", # New
"close_server", # New
"clear_socket",
"remove_pidfile",
])
......@@ -138,6 +139,7 @@ class CloseServerTests(unittest.TestCase):
self.assertEqual(hasattr(zeo, "server"), True)
self.assertEqual(zeo.server, None)
@mock.patch('os.unlink')
class TestZEOServerSocket(unittest.TestCase):
......
......@@ -14,7 +14,7 @@
"""Basic unit tests for a client cache."""
from __future__ import print_function
from ZODB.utils import p64, repr_to_oid
from ZODB.utils import p64, u64, z64, repr_to_oid
import doctest
import os
import re
......@@ -28,8 +28,6 @@ import ZODB.tests.util
import zope.testing.setupstack
import zope.testing.renormalizing
import ZEO.cache
from ZODB.utils import p64, u64, z64
n1 = p64(1)
n2 = p64(2)
......@@ -47,8 +45,8 @@ def hexprint(file):
printable = ""
hex = ""
for character in line:
if (character in string.printable
and not ord(character) in [12,13,9]):
if character in string.printable and \
not ord(character) in [12, 13, 9]:
printable += character
else:
printable += '.'
......@@ -63,8 +61,11 @@ def hexprint(file):
def oid(o):
repr = '%016x' % o
return repr_to_oid(repr)
tid = oid
class CacheTests(ZODB.tests.util.TestCase):
def setUp(self):
......@@ -207,30 +208,30 @@ class CacheTests(ZODB.tests.util.TestCase):
self.assertTrue(1 not in cache.noncurrent)
def testVeryLargeCaches(self):
cache = ZEO.cache.ClientCache('cache', size=(1<<32)+(1<<20))
cache = ZEO.cache.ClientCache('cache', size=(1 << 32)+(1 << 20))
cache.store(n1, n2, None, b"x")
cache.close()
cache = ZEO.cache.ClientCache('cache', size=(1<<33)+(1<<20))
cache = ZEO.cache.ClientCache('cache', size=(1 << 33)+(1 << 20))
self.assertEqual(cache.load(n1), (b'x', n2))
cache.close()
def testConversionOfLargeFreeBlocks(self):
with open('cache', 'wb') as f:
f.write(ZEO.cache.magic+
f.write(ZEO.cache.magic +
b'\0'*8 +
b'f'+struct.pack(">I", (1<<32)-12)
b'f'+struct.pack(">I", (1 << 32)-12)
)
f.seek((1<<32)-1)
f.seek((1 << 32)-1)
f.write(b'x')
cache = ZEO.cache.ClientCache('cache', size=1<<32)
cache = ZEO.cache.ClientCache('cache', size=1 << 32)
cache.close()
cache = ZEO.cache.ClientCache('cache', size=1<<32)
cache = ZEO.cache.ClientCache('cache', size=1 << 32)
cache.close()
with open('cache', 'rb') as f:
f.seek(12)
self.assertEqual(f.read(1), b'f')
self.assertEqual(struct.unpack(">I", f.read(4))[0],
ZEO.cache.max_block_size)
ZEO.cache.max_block_size)
if not sys.platform.startswith('linux'):
# On platforms without sparse files, these tests are just way
......@@ -277,7 +278,7 @@ class CacheTests(ZODB.tests.util.TestCase):
self.assertEqual(os.path.getsize(
'cache'), ZEO.cache.ZEC_HEADER_SIZE+small*recsize+extra)
self.assertEqual(set(u64(oid) for (oid, tid) in cache.contents()),
set(range(small)))
set(range(small)))
for i in range(100, 110):
cache.store(p64(i), n1, None, data)
......@@ -297,7 +298,7 @@ class CacheTests(ZODB.tests.util.TestCase):
'cache', size=ZEO.cache.ZEC_HEADER_SIZE+small*recsize+extra)
self.assertEqual(len(cache), expected_len)
self.assertEqual(set(u64(oid) for (oid, tid) in cache.contents()),
expected_oids)
expected_oids)
# Now make it bigger
cache.close()
......@@ -308,8 +309,7 @@ class CacheTests(ZODB.tests.util.TestCase):
self.assertEqual(os.path.getsize(
'cache'), ZEO.cache.ZEC_HEADER_SIZE+large*recsize+extra)
self.assertEqual(set(u64(oid) for (oid, tid) in cache.contents()),
expected_oids)
expected_oids)
for i in range(200, 305):
cache.store(p64(i), n1, None, data)
......@@ -321,7 +321,7 @@ class CacheTests(ZODB.tests.util.TestCase):
list(range(106, 110)) +
list(range(200, 305)))
self.assertEqual(set(u64(oid) for (oid, tid) in cache.contents()),
expected_oids)
expected_oids)
# Make sure we can reopen with same size
cache.close()
......@@ -329,7 +329,7 @@ class CacheTests(ZODB.tests.util.TestCase):
'cache', size=ZEO.cache.ZEC_HEADER_SIZE+large*recsize+extra)
self.assertEqual(len(cache), expected_len)
self.assertEqual(set(u64(oid) for (oid, tid) in cache.contents()),
expected_oids)
expected_oids)
# Cleanup
cache.close()
......@@ -356,6 +356,7 @@ class CacheTests(ZODB.tests.util.TestCase):
self.assertEqual(cache.loadBefore(oid, n2), (b'first', n1, n2))
self.assertEqual(cache.loadBefore(oid, n3), (b'second', n2, None))
def kill_does_not_cause_cache_corruption():
r"""
......@@ -363,7 +364,7 @@ If we kill a process while a cache is being written to, the cache
isn't corrupted. To see this, we'll write a little script that
writes records to a cache file repeatedly.
>>> import os, random, sys, time
>>> import os, sys
>>> with open('t', 'w') as f:
... _ = f.write('''
... import os, random, sys, time
......@@ -402,6 +403,7 @@ writes records to a cache file repeatedly.
"""
def full_cache_is_valid():
r"""
......@@ -419,6 +421,7 @@ still be used.
>>> cache.close()
"""
def cannot_open_same_cache_file_twice():
r"""
>>> import ZEO.cache
......@@ -432,6 +435,7 @@ LockError: Couldn't lock 'cache.lock'
>>> cache.close()
"""
def broken_non_current():
r"""
......@@ -467,6 +471,7 @@ Couldn't find non-current
# def bad_magic_number(): See rename_bad_cache_file
def cache_trace_analysis():
r"""
Check to make sure the cache analysis scripts work.
......@@ -585,19 +590,19 @@ Check to make sure the cache analysis scripts work.
Jul 11 12:11:43 20 947 0000000000000000 0000000000000000 -
Jul 11 12:11:43 52 947 0000000000000002 0000000000000000 - 602
Jul 11 12:11:44 20 124b 0000000000000000 0000000000000000 -
Jul 11 12:11:44 52 124b 0000000000000002 0000000000000000 - 1418
Jul 11 12:11:44 52 ... 124b 0000000000000002 0000000000000000 - 1418
...
Jul 11 15:14:55 52 10cc 00000000000003e9 0000000000000000 - 1306
Jul 11 15:14:55 52 ... 10cc 00000000000003e9 0000000000000000 - 1306
Jul 11 15:14:56 20 18a7 0000000000000000 0000000000000000 -
Jul 11 15:14:56 52 18a7 00000000000003e9 0000000000000000 - 1610
Jul 11 15:14:57 22 18b5 000000000000031d 0000000000000000 - 1636
Jul 11 15:14:56 52 ... 18a7 00000000000003e9 0000000000000000 - 1610
Jul 11 15:14:57 22 ... 18b5 000000000000031d 0000000000000000 - 1636
Jul 11 15:14:58 20 b8a 0000000000000000 0000000000000000 -
Jul 11 15:14:58 52 b8a 00000000000003e9 0000000000000000 - 838
Jul 11 15:14:59 22 1085 0000000000000357 0000000000000000 - 217
Jul 11 15:00-14 818 291 30 609 35.6%
Jul 11 15:15:00 22 1072 000000000000037e 0000000000000000 - 204
Jul 11 15:15:01 20 16c5 0000000000000000 0000000000000000 -
Jul 11 15:15:01 52 16c5 00000000000003e9 0000000000000000 - 1712
Jul 11 15:15:01 52 ... 16c5 00000000000003e9 0000000000000000 - 1712
Jul 11 15:15-15 2 1 0 1 50.0%
<BLANKLINE>
Read 18,876 trace records (641,776 bytes) in 0.0 seconds
......@@ -1001,6 +1006,7 @@ Cleanup:
"""
def cache_simul_properly_handles_load_miss_after_eviction_and_inval():
r"""
......@@ -1031,6 +1037,7 @@ Now try to do simulation:
"""
def invalidations_with_current_tid_dont_wreck_cache():
"""
>>> cache = ZEO.cache.ClientCache('cache', 1000)
......@@ -1049,6 +1056,7 @@ def invalidations_with_current_tid_dont_wreck_cache():
>>> logging.getLogger().setLevel(old_level)
"""
def rename_bad_cache_file():
"""
An attempt to open a bad cache file will cause it to be dropped and recreated.
......@@ -1098,6 +1106,7 @@ An attempt to open a bad cache file will cause it to be dropped and recreated.
>>> logging.getLogger().setLevel(old_level)
"""
def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(CacheTests))
......@@ -1105,10 +1114,9 @@ def test_suite():
doctest.DocTestSuite(
setUp=zope.testing.setupstack.setUpDirectory,
tearDown=zope.testing.setupstack.tearDown,
checker=ZODB.tests.util.checker + \
zope.testing.renormalizing.RENormalizing([
(re.compile(r'31\.3%'), '31.2%'),
]),
checker=(ZODB.tests.util.checker +
zope.testing.renormalizing.RENormalizing([
(re.compile(r'31\.3%'), '31.2%')])),
)
)
return suite
......@@ -11,6 +11,7 @@ import ZEO.StorageServer
from . import forker
from .threaded import threaded_server_tests
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class ClientAuthTests(setupstack.TestCase):
......@@ -54,8 +55,8 @@ class ClientAuthTests(setupstack.TestCase):
stop()
def test_suite():
suite = unittest.makeSuite(ClientAuthTests)
suite.layer = threaded_server_tests
return suite
......@@ -15,11 +15,13 @@ import ZEO
from . import forker
from .utils import StorageServer
class Var(object):
def __eq__(self, other):
self.value = other
return True
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
......@@ -62,7 +64,6 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 2)
# Now, we'll create a server that expects the client to
# resolve conflicts:
......@@ -93,9 +94,9 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
self.assertEqual(
zs.vote(3),
[dict(oid=ob._p_oid,
serials=(tid2, tid1),
data=writer.serialize(ob),
)],
serials=(tid2, tid1),
data=writer.serialize(ob),
)],
)
# Now, it's up to the client to resolve the conflict. It can
......@@ -119,17 +120,18 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
addr, stop = ZEO.server(os.path.join(path, 'data.fs'), threaded=False)
db = ZEO.DB(addr)
with db.transaction() as conn:
conn.root.l = Length(0)
conn.root.len = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
conn2.root.len.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
conn.root.len.change(1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
self.assertEqual(conn2.root.len.value, 2)
db.close(); stop()
db.close()
stop()
# Now, do conflict resolution on the client.
addr2, stop = ZEO.server(
......@@ -140,18 +142,20 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
db = ZEO.DB(addr2)
with db.transaction() as conn:
conn.root.l = Length(0)
conn.root.len = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
conn2.root.len.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
conn.root.len.change(1)
self.assertEqual(conn2.root.l.value, 1)
self.assertEqual(conn2.root.len.value, 1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
self.assertEqual(conn2.root.len.value, 2)
db.close()
stop()
db.close(); stop()
def test_suite():
return unittest.makeSuite(ClientSideConflictResolutionTests)
......@@ -17,7 +17,7 @@ class MarshalTests(unittest.TestCase):
# this is an example (1) of Zope2's arguments for
# undoInfo call. Arguments are encoded by ZEO client
# and decoded by server. The operation must be idempotent.
# (1) https://github.com/zopefoundation/Zope/blob/2.13/src/App/Undo.py#L111
# (1) https://github.com/zopefoundation/Zope/blob/2.13/src/App/Undo.py#L111 # NOQA: E501 line too long
args = (0, 20, {'user_name': Prefix('test')})
# test against repr because Prefix __eq__ operator
# doesn't compare Prefix with Prefix but only
......
import unittest
from zope.testing import setupstack
from .. import server, client
......@@ -13,6 +11,7 @@ else:
server_ping_method = 'ping'
server_zss = 'zeo_storages_by_storage_id'
class SyncTests(setupstack.TestCase):
def instrument(self):
......@@ -22,6 +21,7 @@ class SyncTests(setupstack.TestCase):
[zs] = getattr(server.server, server_zss)['1']
orig_ping = getattr(zs, server_ping_method)
def ping():
self.__ping_calls += 1
return orig_ping()
......
......@@ -15,11 +15,12 @@ from .threaded import threaded_server_tests
here = os.path.dirname(__file__)
server_cert = os.path.join(here, 'server.pem')
server_key = os.path.join(here, 'server_key.pem')
server_key = os.path.join(here, 'server_key.pem')
serverpw_cert = os.path.join(here, 'serverpw.pem')
serverpw_key = os.path.join(here, 'serverpw_key.pem')
serverpw_key = os.path.join(here, 'serverpw_key.pem')
client_cert = os.path.join(here, 'client.pem')
client_key = os.path.join(here, 'client_key.pem')
client_key = os.path.join(here, 'client_key.pem')
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class SSLConfigTest(ZEOConfigTestBase):
......@@ -117,6 +118,7 @@ class SSLConfigTest(ZEOConfigTestBase):
)
stop()
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
@mock.patch(('asyncio' if PY3 else 'trollius') + '.ensure_future')
@mock.patch(('asyncio' if PY3 else 'trollius') + '.set_event_loop')
......@@ -133,14 +135,13 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
server.close()
def assert_context(
self,
server,
factory, context,
cert=(server_cert, server_key, None),
verify_mode=ssl.CERT_REQUIRED,
check_hostname=False,
cafile=None, capath=None,
):
self,
server,
factory, context,
cert=(server_cert, server_key, None),
verify_mode=ssl.CERT_REQUIRED,
check_hostname=False,
cafile=None, capath=None):
factory.assert_called_with(
ssl.Purpose.CLIENT_AUTH if server else ssl.Purpose.SERVER_AUTH,
cafile=cafile, capath=capath)
......@@ -180,72 +181,75 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
authenticate=here,
)
context = server.acceptor.ssl_context
self.assert_context(True,
factory, context, (server_cert, server_key, pwfunc), capath=here)
self.assert_context(True,
factory,
context,
(server_cert, server_key, pwfunc),
capath=here)
server.close()
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_no_ssl(self, ClientStorage, factory, *_):
client = ssl_client()
ssl_client()
self.assertFalse('ssl' in ClientStorage.call_args[1])
self.assertFalse('ssl_server_hostname' in ClientStorage.call_args[1])
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_signed(
self, ClientStorage, factory, *_
):
client = ssl_client(certificate=client_cert, key=client_key)
self, ClientStorage, factory, *_):
ssl_client(certificate=client_cert, key=client_key)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(False,
factory, context, (client_cert, client_key, None),
check_hostname=True)
factory,
context,
(client_cert, client_key, None),
check_hostname=True)
context.load_default_certs.assert_called_with()
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_dir(
self, ClientStorage, factory, *_
):
client = ssl_client(
self, ClientStorage, factory, *_):
ssl_client(
certificate=client_cert, key=client_key, authenticate=here)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(False,
factory, context, (client_cert, client_key, None),
capath=here,
check_hostname=True,
)
factory,
context,
(client_cert, client_key, None),
capath=here,
check_hostname=True)
context.load_default_certs.assert_not_called()
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_file(
self, ClientStorage, factory, *_
):
client = ssl_client(
self, ClientStorage, factory, *_):
ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(False,
factory, context, (client_cert, client_key, None),
cafile=server_cert,
check_hostname=True,
)
factory,
context,
(client_cert, client_key, None),
cafile=server_cert,
check_hostname=True)
context.load_default_certs.assert_not_called()
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_pw(
self, ClientStorage, factory, *_
):
client = ssl_client(
self, ClientStorage, factory, *_):
ssl_client(
certificate=client_cert, key=client_key,
password_function='ZEO.tests.testssl.pwfunc',
authenticate=server_cert)
......@@ -253,48 +257,51 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(False,
factory, context, (client_cert, client_key, pwfunc),
cafile=server_cert,
check_hostname=True,
)
factory,
context,
(client_cert, client_key, pwfunc),
cafile=server_cert,
check_hostname=True)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_hostname(
self, ClientStorage, factory, *_
):
client = ssl_client(
self, ClientStorage, factory, *_):
ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert,
server_hostname='example.com')
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
'example.com')
self.assert_context(False,
factory, context, (client_cert, client_key, None),
cafile=server_cert,
check_hostname=True,
)
factory,
context,
(client_cert, client_key, None),
cafile=server_cert,
check_hostname=True)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_check_hostname(
self, ClientStorage, factory, *_
):
client = ssl_client(
self, ClientStorage, factory, *_):
ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert,
check_hostname=False)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(False,
factory, context, (client_cert, client_key, None),
cafile=server_cert,
check_hostname=False,
)
factory,
context,
(client_cert, client_key, None),
cafile=server_cert,
check_hostname=False)
def args(*a, **kw):
return a, kw
def ssl_conf(**ssl_settings):
if ssl_settings:
ssl_conf = '<ssl>\n' + '\n'.join(
......@@ -306,6 +313,7 @@ def ssl_conf(**ssl_settings):
return ssl_conf
def ssl_client(**ssl_settings):
return storageFromString(
"""%import ZEO
......@@ -317,6 +325,7 @@ def ssl_client(**ssl_settings):
""".format(ssl_conf(**ssl_settings))
)
def create_server(**ssl_settings):
with open('conf', 'w') as f:
f.write(
......@@ -336,7 +345,9 @@ def create_server(**ssl_settings):
s.create_server()
return s.server
pwfunc = lambda : '1234'
def pwfunc():
return '1234'
def test_suite():
......@@ -347,8 +358,8 @@ def test_suite():
suite.layer = threaded_server_tests
return suite
# Helpers for other tests:
# Helpers for other tests:
server_config = """
<zeo>
address 127.0.0.1:0
......@@ -360,6 +371,7 @@ server_config = """
</zeo>
""".format(server_cert, server_key, client_cert)
def client_ssl(cafile=server_key,
client_cert=client_cert,
client_key=client_key,
......@@ -373,11 +385,11 @@ def client_ssl(cafile=server_key,
return context
# See
# https://discuss.pivotal.io/hc/en-us/articles/202653388-How-to-renew-an-expired-Apache-Web-Server-self-signed-certificate-using-the-OpenSSL-tool
# https://discuss.pivotal.io/hc/en-us/articles/202653388-How-to-renew-an-expired-Apache-Web-Server-self-signed-certificate-using-the-OpenSSL-tool # NOQA: E501
# for instructions on updating the server.pem (the certificate) if
# needed. server.pem.csr is the request.
# This should do it:
# openssl x509 -req -days 999999 -in src/ZEO/tests/server.pem.csr -signkey src/ZEO/tests/server_key.pem -out src/ZEO/tests/server.pem
# openssl x509 -req -days 999999 -in src/ZEO/tests/server.pem.csr -signkey src/ZEO/tests/server_key.pem -out src/ZEO/tests/server.pem # NOQA: E501
# If you need to create a new key first:
# openssl genrsa -out server_key.pem 2048
# These two files should then be copied to client_key.pem and client.pem.
......@@ -9,4 +9,3 @@ import ZODB.tests.util
threaded_server_tests = ZODB.tests.util.MininalTestLayer(
'threaded_server_tests')
......@@ -3,6 +3,7 @@
import ZEO.StorageServer
from ..asyncio.server import best_protocol_version
class ServerProtocol(object):
method = ('register', )
......@@ -17,6 +18,7 @@ class ServerProtocol(object):
zs.notify_connected(self)
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -30,6 +32,7 @@ class ServerProtocol(object):
async_threadsafe = async_
class StorageServer(object):
"""Create a client interface to a StorageServer.
......
......@@ -15,6 +15,7 @@
import os
def parentdir(p, n=1):
"""Return the ancestor of p from n levels up."""
d = p
......@@ -25,6 +26,7 @@ def parentdir(p, n=1):
n -= 1
return d
class Environment(object):
"""Determine location of the Data.fs & ZEO_SERVER.pid files.
......
......@@ -3,6 +3,7 @@
import os
import sys
def ssl_config(section, server):
import ssl
......@@ -10,9 +11,9 @@ def ssl_config(section, server):
auth = section.authenticate
if auth:
if os.path.isdir(auth):
capath=auth
capath = auth
elif auth != 'DYNAMIC':
cafile=auth
cafile = auth
context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH if server else ssl.Purpose.SERVER_AUTH,
......@@ -44,12 +45,15 @@ def ssl_config(section, server):
return context, section.server_hostname
def server_ssl(section):
return ssl_config(section, True)
def client_ssl(section):
return ssl_config(section, False)
class ClientStorageConfig(object):
def __init__(self, config):
......@@ -86,6 +90,6 @@ class ClientStorageConfig(object):
name=config.name,
read_only=config.read_only,
read_only_fallback=config.read_only_fallback,
server_sync = config.server_sync,
server_sync=config.server_sync,
wait_timeout=config.wait_timeout,
**options)
......@@ -20,6 +20,7 @@ import os
import ZEO
import zdaemon.zdctl
# Main program
def main(args=None):
options = zdaemon.zdctl.ZDCtlOptions()
......@@ -27,5 +28,6 @@ def main(args=None):
options.schemafile = "zeoctl.xml"
zdaemon.zdctl.main(args, options)
if __name__ == "__main__":
main()
......@@ -38,10 +38,12 @@ extras =
basepython = python3
skip_install = true
deps =
flake8
check-manifest
check-python-versions >= 0.19.1
wheel
commands =
flake8 src setup.py
check-manifest
check-python-versions
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment