Commit 765c3b9f authored by Jim Fulton's avatar Jim Fulton Committed by GitHub

Merge pull request #54 from zopefoundation/zeo4-server-support

Zeo4 server support
parents 96659e2c 5ba506e7
......@@ -14,6 +14,15 @@ matrix:
- os: linux
python: 3.5
env: ZEO_MTACCEPTOR=1
- os: linux
python: 2.7
env: ZEO4_SERVER=1
- os: linux
python: 3.4
env: ZEO4_SERVER=1
- os: linux
python: 3.5
env: ZEO4_SERVER=1
install:
- pip install -U setuptools
- python bootstrap.py
......
Changelog
=========
- Fixed bugs in using the ZEO 5 client with ZEO 4 servers.
5.0.0a2 (2016-07-30)
--------------------
......
......@@ -949,8 +949,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def serialnos(self, args):
"""Server callback to pass a list of changed (oid, serial) pairs.
"""
for oid, s in args:
self._tbuf.serial(oid, s)
self._tbuf.serialnos(args)
def info(self, dict):
"""Server callback to update the info dictionary."""
......
......@@ -92,3 +92,17 @@ class TransactionBuffer:
for oid in server_resolved:
if oid not in seen:
yield oid, None, True
# Support ZEO4:
def serialnos(self, args):
for oid in args:
if isinstance(oid, bytes):
self.server_resolved.add(oid)
else:
oid, serial = oid
if isinstance(serial, Exception):
self.exception = serial
elif serial == b'rs':
self.server_resolved.add(oid)
......@@ -313,6 +313,12 @@ class Client(object):
self.protocols = ()
self.disconnected(None)
# Work around odd behavior of ZEO4 server. It may send
# invalidations for transactions later than the result of
# getInvalidations. While we support ZEO 4 servers, we'll
# need to keep an invalidation queue. :(
self.verify_invalidation_queue = []
def new_addrs(self, addrs):
self.addrs = addrs
if self.trying_to_connect():
......@@ -409,6 +415,8 @@ class Client(object):
@future_generator
def verify(self, server_tid):
self.verify_invalidation_queue = [] # See comment in init :(
protocol = self.protocol
if server_tid is None:
server_tid = yield protocol.fut('lastTransaction')
......@@ -465,6 +473,12 @@ class Client(object):
self.cache.setLastTid(server_tid)
self.ready = True
# Gaaaa, ZEO 4 work around. See comment in __init__. :(
for tid, oids in self.verify_invalidation_queue:
if tid > server_tid:
self.invalidateTransaction(tid, oids)
self.verify_invalidation_queue = []
try:
info = yield protocol.fut('get_info')
except Exception as exc:
......@@ -597,15 +611,23 @@ class Client(object):
self.cache.invalidate(oid, tid)
self.client.invalidateTransaction(tid, oids)
self.cache.setLastTid(tid)
else:
self.verify_invalidation_queue.append((tid, oids))
def serialnos(self, serials):
# Method called by ZEO4 storage servers.
# Before delegating, check for errors (likely ConflictErrors)
# and invalidate the oids they're associated with. In the
# past, this was done by the client, but now we control the
# cache and this is our last chance, as the client won't call
# back into us when there's an error.
for oid, serial in serials:
if isinstance(serial, Exception):
for oid in serials:
if isinstance(oid, bytes):
self.cache.invalidate(oid, None)
else:
oid, serial = oid
if isinstance(serial, Exception) or serial == b'rs':
self.cache.invalidate(oid, None)
self.client.serialnos(serials)
......
======================
Copy of ZEO 4 server
======================
This copy was made by first converting the ZEO 4 server code to use
relative imports. The code was tested with ZEO 4 before copying. It
was unchanged aside from the relative imports.
The ZEO 4 server is used for tests if the ZEO4_SERVER environment
variable is set to a non-empty value.
##############################################################################
#
# Copyright (c) 2001, 2002, 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""The StorageServer class and the exception that it may raise.
This server acts as a front-end for one or more real storages, like
file storage or Berkeley storage.
TODO: Need some basic access control-- a declaration of the methods
exported for invocation by the server.
"""
import asyncore
import codecs
import itertools
import logging
import os
import sys
import tempfile
import threading
import time
import transaction
import warnings
from .zrpc.error import DisconnectedError
import ZODB.blob
import ZODB.event
import ZODB.serialize
import ZODB.TimeStamp
import zope.interface
import six
from ZEO._compat import Pickler, Unpickler, PY3, BytesIO
from ZEO.Exceptions import AuthError
from .monitor import StorageStats, StatsServer
from .zrpc.connection import ManagedServerConnection, Delay, MTDelay, Result
from .zrpc.server import Dispatcher
from ZODB.ConflictResolution import ResolvedSerial
from ZODB.loglevels import BLATHER
from ZODB.POSException import StorageError, StorageTransactionError
from ZODB.POSException import TransactionError, ReadOnlyError, ConflictError
from ZODB.serialize import referencesf
from ZODB.utils import oid_repr, p64, u64, z64
logger = logging.getLogger('ZEO.StorageServer')
def log(message, level=logging.INFO, label='', exc_info=False):
"""Internal helper to log a message."""
if label:
message = "(%s) %s" % (label, message)
logger.log(level, message, exc_info=exc_info)
class StorageServerError(StorageError):
"""Error reported when an unpicklable exception is raised."""
class ZEOStorage:
"""Proxy to underlying storage for a single remote client."""
# A list of extension methods. A subclass with extra methods
# should override.
extensions = []
def __init__(self, server, read_only=0, auth_realm=None):
self.server = server
# timeout and stats will be initialized in register()
self.stats = None
self.connection = None
self.client = None
self.storage = None
self.storage_id = "uninitialized"
self.transaction = None
self.read_only = read_only
self.log_label = 'unconnected'
self.locked = False # Don't have storage lock
self.verifying = 0
self.store_failed = 0
self.authenticated = 0
self.auth_realm = auth_realm
self.blob_tempfile = None
# The authentication protocol may define extra methods.
self._extensions = {}
for func in self.extensions:
self._extensions[func.__name__] = None
self._iterators = {}
self._iterator_ids = itertools.count()
# Stores the last item that was handed out for a
# transaction iterator.
self._txn_iterators_last = {}
def _finish_auth(self, authenticated):
if not self.auth_realm:
return 1
self.authenticated = authenticated
return authenticated
def set_database(self, database):
self.database = database
def notifyConnected(self, conn):
self.connection = conn
assert conn.peer_protocol_version is not None
if conn.peer_protocol_version < b'Z309':
self.client = ClientStub308(conn)
conn.register_object(ZEOStorage308Adapter(self))
else:
self.client = ClientStub(conn)
self.log_label = _addr_label(conn.addr)
def notifyDisconnected(self):
# When this storage closes, we must ensure that it aborts
# any pending transaction.
if self.transaction is not None:
self.log("disconnected during %s transaction"
% (self.locked and 'locked' or 'unlocked'))
self.tpc_abort(self.transaction.id)
else:
self.log("disconnected")
self.connection = None
def __repr__(self):
tid = self.transaction and repr(self.transaction.id)
if self.storage:
stid = (self.tpc_transaction() and
repr(self.tpc_transaction().id))
else:
stid = None
name = self.__class__.__name__
return "<%s %X trans=%s s_trans=%s>" % (name, id(self), tid, stid)
def log(self, msg, level=logging.INFO, exc_info=False):
log(msg, level=level, label=self.log_label, exc_info=exc_info)
def setup_delegation(self):
"""Delegate several methods to the storage
"""
# Called from register
storage = self.storage
info = self.get_info()
if not info['supportsUndo']:
self.undoLog = self.undoInfo = lambda *a,**k: ()
self.getTid = storage.getTid
self.load = storage.load
self.loadSerial = storage.loadSerial
record_iternext = getattr(storage, 'record_iternext', None)
if record_iternext is not None:
self.record_iternext = record_iternext
try:
fn = storage.getExtensionMethods
except AttributeError:
pass # no extension methods
else:
d = fn()
self._extensions.update(d)
for name in d:
assert not hasattr(self, name)
setattr(self, name, getattr(storage, name))
self.lastTransaction = storage.lastTransaction
try:
self.tpc_transaction = storage.tpc_transaction
except AttributeError:
if hasattr(storage, '_transaction'):
log("Storage %r doesn't have a tpc_transaction method.\n"
"See ZEO.interfaces.IServeable."
"Falling back to using _transaction attribute, which\n."
"is icky.",
logging.ERROR)
self.tpc_transaction = lambda : storage._transaction
else:
raise
def history(self,tid,size=1):
# This caters for storages which still accept
# a version parameter.
return self.storage.history(tid,size=size)
def _check_tid(self, tid, exc=None):
if self.read_only:
raise ReadOnlyError()
if self.transaction is None:
caller = sys._getframe().f_back.f_code.co_name
self.log("no current transaction: %s()" % caller,
level=logging.WARNING)
if exc is not None:
raise exc(None, tid)
else:
return 0
if self.transaction.id != tid:
caller = sys._getframe().f_back.f_code.co_name
self.log("%s(%s) invalid; current transaction = %s" %
(caller, repr(tid), repr(self.transaction.id)),
logging.WARNING)
if exc is not None:
raise exc(self.transaction.id, tid)
else:
return 0
return 1
def getAuthProtocol(self):
"""Return string specifying name of authentication module to use.
The module name should be auth_%s where %s is auth_protocol."""
protocol = self.server.auth_protocol
if not protocol or protocol == 'none':
return None
return protocol
def register(self, storage_id, read_only):
"""Select the storage that this client will use
This method must be the first one called by the client.
For authenticated storages this method will be called by the client
immediately after authentication is finished.
"""
if self.auth_realm and not self.authenticated:
raise AuthError("Client was never authenticated with server!")
if self.storage is not None:
self.log("duplicate register() call")
raise ValueError("duplicate register() call")
storage = self.server.storages.get(storage_id)
if storage is None:
self.log("unknown storage_id: %s" % storage_id)
raise ValueError("unknown storage: %s" % storage_id)
if not read_only and (self.read_only or storage.isReadOnly()):
raise ReadOnlyError()
self.read_only = self.read_only or read_only
self.storage_id = storage_id
self.storage = storage
self.setup_delegation()
self.stats = self.server.register_connection(storage_id, self)
def get_info(self):
storage = self.storage
supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)()
and self.connection.peer_protocol_version >= b'Z310')
# Communicate the backend storage interfaces to the client
storage_provides = zope.interface.providedBy(storage)
interfaces = []
for candidate in storage_provides.__iro__:
interfaces.append((candidate.__module__, candidate.__name__))
return {'length': len(storage),
'size': storage.getSize(),
'name': storage.getName(),
'supportsUndo': supportsUndo,
'extensionMethods': self.getExtensionMethods(),
'supports_record_iternext': hasattr(self, 'record_iternext'),
'interfaces': tuple(interfaces),
}
def get_size_info(self):
return {'length': len(self.storage),
'size': self.storage.getSize(),
}
def getExtensionMethods(self):
return self._extensions
def loadEx(self, oid):
self.stats.loads += 1
return self.storage.load(oid, '')
def loadBefore(self, oid, tid):
self.stats.loads += 1
return self.storage.loadBefore(oid, tid)
def getInvalidations(self, tid):
invtid, invlist = self.server.get_invalidations(self.storage_id, tid)
if invtid is None:
return None
self.log("Return %d invalidations up to tid %s"
% (len(invlist), u64(invtid)))
return invtid, invlist
def verify(self, oid, tid):
try:
t = self.getTid(oid)
except KeyError:
self.client.invalidateVerify(oid)
else:
if tid != t:
self.client.invalidateVerify(oid)
def zeoVerify(self, oid, s):
if not self.verifying:
self.verifying = 1
self.stats.verifying_clients += 1
try:
os = self.getTid(oid)
except KeyError:
self.client.invalidateVerify((oid, ''))
# It's not clear what we should do now. The KeyError
# could be caused by an object uncreation, in which case
# invalidation is right. It could be an application bug
# that left a dangling reference, in which case it's bad.
else:
if s != os:
self.client.invalidateVerify((oid, ''))
def endZeoVerify(self):
if self.verifying:
self.stats.verifying_clients -= 1
self.verifying = 0
self.client.endVerify()
def pack(self, time, wait=1):
# Yes, you can pack a read-only server or storage!
if wait:
return run_in_thread(self._pack_impl, time)
else:
# If the client isn't waiting for a reply, start a thread
# and forget about it.
t = threading.Thread(target=self._pack_impl, args=(time,))
t.setName("zeo storage packing thread")
t.start()
return None
def _pack_impl(self, time):
self.log("pack(time=%s) started..." % repr(time))
self.storage.pack(time, referencesf)
self.log("pack(time=%s) complete" % repr(time))
# Broadcast new size statistics
self.server.invalidate(0, self.storage_id, None,
(), self.get_size_info())
def new_oids(self, n=100):
"""Return a sequence of n new oids, where n defaults to 100"""
n = min(n, 100)
if self.read_only:
raise ReadOnlyError()
if n <= 0:
n = 1
return [self.storage.new_oid() for i in range(n)]
# undoLog and undoInfo are potentially slow methods
def undoInfo(self, first, last, spec):
return run_in_thread(self.storage.undoInfo, first, last, spec)
def undoLog(self, first, last):
return run_in_thread(self.storage.undoLog, first, last)
def tpc_begin(self, id, user, description, ext, tid=None, status=" "):
if self.read_only:
raise ReadOnlyError()
if self.transaction is not None:
if self.transaction.id == id:
self.log("duplicate tpc_begin(%s)" % repr(id))
return
else:
raise StorageTransactionError("Multiple simultaneous tpc_begin"
" requests from one client.")
t = transaction.Transaction()
t.id = id
t.user = user
t.description = description
t._extension = ext
self.serials = []
self.invalidated = []
self.txnlog = CommitLog()
self.blob_log = []
self.tid = tid
self.status = status
self.store_failed = 0
self.stats.active_txns += 1
# Assign the transaction attribute last. This is so we don't
# think we've entered TPC until everything is set. Why?
# Because if we have an error after this, the server will
# think it is in TPC and the client will think it isn't. At
# that point, the client will keep trying to enter TPC and
# server won't let it. Errors *after* the tpc_begin call will
# cause the client to abort the transaction.
# (Also see https://bugs.launchpad.net/zodb/+bug/374737.)
self.transaction = t
def tpc_finish(self, id):
if not self._check_tid(id):
return
assert self.locked, "finished called wo lock"
self.stats.commits += 1
self.storage.tpc_finish(self.transaction, self._invalidate)
# Note that the tid is still current because we still hold the
# commit lock. We'll relinquish it in _clear_transaction.
tid = self.storage.lastTransaction()
# Return the tid, for cache invalidation optimization
return Result(tid, self._clear_transaction)
def _invalidate(self, tid):
if self.invalidated:
self.server.invalidate(self, self.storage_id, tid,
self.invalidated, self.get_size_info())
def tpc_abort(self, tid):
if not self._check_tid(tid):
return
self.stats.aborts += 1
self.storage.tpc_abort(self.transaction)
self._clear_transaction()
def _clear_transaction(self):
# Common code at end of tpc_finish() and tpc_abort()
if self.locked:
self.server.unlock_storage(self)
self.locked = 0
if self.transaction is not None:
self.server.stop_waiting(self)
self.transaction = None
self.stats.active_txns -= 1
if self.txnlog is not None:
self.txnlog.close()
self.txnlog = None
for oid, oldserial, data, blobfilename in self.blob_log:
ZODB.blob.remove_committed(blobfilename)
del self.blob_log
def vote(self, tid):
self._check_tid(tid, exc=StorageTransactionError)
if self.locked or self.server.already_waiting(self):
raise StorageTransactionError(
'Already voting (%s)' % (self.locked and 'locked' or 'waiting')
)
return self._try_to_vote()
def _try_to_vote(self, delay=None):
if self.connection is None:
return # We're disconnected
if delay is not None and delay.sent:
# as a consequence of the unlocking strategy, _try_to_vote
# may be called multiple times for delayed
# transactions. The first call will mark the delay as
# sent. We should skip if the delay was already sent.
return
self.locked, delay = self.server.lock_storage(self, delay)
if self.locked:
try:
self.log(
"Preparing to commit transaction: %d objects, %d bytes"
% (self.txnlog.stores, self.txnlog.size()),
level=BLATHER)
if (self.tid is not None) or (self.status != ' '):
self.storage.tpc_begin(self.transaction,
self.tid, self.status)
else:
self.storage.tpc_begin(self.transaction)
for op, args in self.txnlog:
if not getattr(self, op)(*args):
break
# Blob support
while self.blob_log and not self.store_failed:
oid, oldserial, data, blobfilename = self.blob_log.pop()
self._store(oid, oldserial, data, blobfilename)
if not self.store_failed:
# Only call tpc_vote of no store call failed,
# otherwise the serialnos() call will deliver an
# exception that will be handled by the client in
# its tpc_vote() method.
serials = self.storage.tpc_vote(self.transaction)
if serials:
self.serials.extend(serials)
self.client.serialnos(self.serials)
except Exception:
self.storage.tpc_abort(self.transaction)
self._clear_transaction()
if delay is not None:
delay.error(sys.exc_info())
else:
raise
else:
if delay is not None:
delay.reply(None)
else:
return None
else:
return delay
def _unlock_callback(self, delay):
connection = self.connection
if connection is None:
self.server.stop_waiting(self)
else:
connection.call_from_thread(self._try_to_vote, delay)
# The public methods of the ZEO client API do not do the real work.
# They defer work until after the storage lock has been acquired.
# Most of the real implementations are in methods beginning with
# an _.
def deleteObject(self, oid, serial, id):
self._check_tid(id, exc=StorageTransactionError)
self.stats.stores += 1
self.txnlog.delete(oid, serial)
def storea(self, oid, serial, data, id):
self._check_tid(id, exc=StorageTransactionError)
self.stats.stores += 1
self.txnlog.store(oid, serial, data)
def checkCurrentSerialInTransaction(self, oid, serial, id):
self._check_tid(id, exc=StorageTransactionError)
self.txnlog.checkread(oid, serial)
def restorea(self, oid, serial, data, prev_txn, id):
self._check_tid(id, exc=StorageTransactionError)
self.stats.stores += 1
self.txnlog.restore(oid, serial, data, prev_txn)
def storeBlobStart(self):
assert self.blob_tempfile is None
self.blob_tempfile = tempfile.mkstemp(
dir=self.storage.temporaryDirectory())
def storeBlobChunk(self, chunk):
os.write(self.blob_tempfile[0], chunk)
def storeBlobEnd(self, oid, serial, data, id):
self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
fd, tempname = self.blob_tempfile
self.blob_tempfile = None
os.close(fd)
self.blob_log.append((oid, serial, data, tempname))
def storeBlobShared(self, oid, serial, data, filename, id):
self._check_tid(id, exc=StorageTransactionError)
assert self.txnlog is not None # effectively not allowed after undo
# Reconstruct the full path from the filename in the OID directory
if (os.path.sep in filename
or not (filename.endswith('.tmp')
or filename[:-1].endswith('.tmp')
)
):
logger.critical(
"We're under attack! (bad filename to storeBlobShared, %r)",
filename)
raise ValueError(filename)
filename = os.path.join(self.storage.fshelper.getPathForOID(oid),
filename)
self.blob_log.append((oid, serial, data, filename))
def sendBlob(self, oid, serial):
self.client.storeBlob(oid, serial, self.storage.loadBlob(oid, serial))
def undo(*a, **k):
raise NotImplementedError
def undoa(self, trans_id, tid):
self._check_tid(tid, exc=StorageTransactionError)
self.txnlog.undo(trans_id)
def _op_error(self, oid, err, op):
self.store_failed = 1
if isinstance(err, ConflictError):
self.stats.conflicts += 1
self.log("conflict error oid=%s msg=%s" %
(oid_repr(oid), str(err)), BLATHER)
if not isinstance(err, TransactionError):
# Unexpected errors are logged and passed to the client
self.log("%s error: %s, %s" % ((op,)+ sys.exc_info()[:2]),
logging.ERROR, exc_info=True)
err = self._marshal_error(err)
# The exception is reported back as newserial for this oid
self.serials.append((oid, err))
def _delete(self, oid, serial):
err = None
try:
self.storage.deleteObject(oid, serial, self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
err = e
self._op_error(oid, err, 'delete')
return err is None
def _checkread(self, oid, serial):
err = None
try:
self.storage.checkCurrentSerialInTransaction(
oid, serial, self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
err = e
self._op_error(oid, err, 'checkCurrentSerialInTransaction')
return err is None
def _store(self, oid, serial, data, blobfile=None):
err = None
try:
if blobfile is None:
newserial = self.storage.store(
oid, serial, data, '', self.transaction)
else:
newserial = self.storage.storeBlob(
oid, serial, data, blobfile, '', self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as error:
self._op_error(oid, error, 'store')
err = error
else:
if serial != b"\0\0\0\0\0\0\0\0":
self.invalidated.append(oid)
if isinstance(newserial, bytes):
newserial = [(oid, newserial)]
for oid, s in newserial or ():
if s == ResolvedSerial:
self.stats.conflicts_resolved += 1
self.log("conflict resolved oid=%s"
% oid_repr(oid), BLATHER)
self.serials.append((oid, s))
return err is None
def _restore(self, oid, serial, data, prev_txn):
err = None
try:
self.storage.restore(oid, serial, data, '', prev_txn,
self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as err:
self._op_error(oid, err, 'restore')
return err is None
def _undo(self, trans_id):
err = None
try:
tid, oids = self.storage.undo(trans_id, self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
err = e
self._op_error(z64, err, 'undo')
else:
self.invalidated.extend(oids)
self.serials.extend((oid, ResolvedSerial) for oid in oids)
return err is None
def _marshal_error(self, error):
# Try to pickle the exception. If it can't be pickled,
# the RPC response would fail, so use something that can be pickled.
if PY3:
pickler = Pickler(BytesIO(), 3)
else:
# The pure-python version requires at least one argument (PyPy)
pickler = Pickler(0)
pickler.fast = 1
try:
pickler.dump(error)
except:
msg = "Couldn't pickle storage exception: %s" % repr(error)
self.log(msg, logging.ERROR)
error = StorageServerError(msg)
return error
# IStorageIteration support
def iterator_start(self, start, stop):
iid = next(self._iterator_ids)
self._iterators[iid] = iter(self.storage.iterator(start, stop))
return iid
def iterator_next(self, iid):
iterator = self._iterators[iid]
try:
info = next(iterator)
except StopIteration:
del self._iterators[iid]
item = None
if iid in self._txn_iterators_last:
del self._txn_iterators_last[iid]
else:
item = (info.tid,
info.status,
info.user,
info.description,
info.extension)
# Keep a reference to the last iterator result to allow starting a
# record iterator off it.
self._txn_iterators_last[iid] = info
return item
def iterator_record_start(self, txn_iid, tid):
record_iid = next(self._iterator_ids)
txn_info = self._txn_iterators_last[txn_iid]
if txn_info.tid != tid:
raise Exception(
'Out-of-order request for record iterator for transaction %r'
% tid)
self._iterators[record_iid] = iter(txn_info)
return record_iid
def iterator_record_next(self, iid):
iterator = self._iterators[iid]
try:
info = next(iterator)
except StopIteration:
del self._iterators[iid]
item = None
else:
item = (info.oid,
info.tid,
info.data,
info.data_txn)
return item
def iterator_gc(self, iids):
for iid in iids:
self._iterators.pop(iid, None)
def server_status(self):
return self.server.server_status(self.storage_id)
def set_client_label(self, label):
self.log_label = str(label)+' '+_addr_label(self.connection.addr)
class StorageServerDB:
def __init__(self, server, storage_id):
self.server = server
self.storage_id = storage_id
self.references = ZODB.serialize.referencesf
def invalidate(self, tid, oids, version=''):
if version:
raise StorageServerError("Versions aren't supported.")
storage_id = self.storage_id
self.server.invalidate(None, storage_id, tid, oids)
def invalidateCache(self):
self.server._invalidateCache(self.storage_id)
transform_record_data = untransform_record_data = lambda self, data: data
class StorageServer:
"""The server side implementation of ZEO.
The StorageServer is the 'manager' for incoming connections. Each
connection is associated with its own ZEOStorage instance (defined
below). The StorageServer may handle multiple storages; each
ZEOStorage instance only handles a single storage.
"""
# Classes we instantiate. A subclass might override.
from .zrpc.server import Dispatcher as DispatcherClass
ZEOStorageClass = ZEOStorage
ManagedServerConnectionClass = ManagedServerConnection
def __init__(self, addr, storages,
read_only=0,
invalidation_queue_size=100,
invalidation_age=None,
transaction_timeout=None,
monitor_address=None,
auth_protocol=None,
auth_database=None,
auth_realm=None,
):
"""StorageServer constructor.
This is typically invoked from the start.py script.
Arguments (the first two are required and positional):
addr -- the address at which the server should listen. This
can be a tuple (host, port) to signify a TCP/IP connection
or a pathname string to signify a Unix domain socket
connection. A hostname may be a DNS name or a dotted IP
address.
storages -- a dictionary giving the storage(s) to handle. The
keys are the storage names, the values are the storage
instances, typically FileStorage or Berkeley storage
instances. By convention, storage names are typically
strings representing small integers starting at '1'.
read_only -- an optional flag saying whether the server should
operate in read-only mode. Defaults to false. Note that
even if the server is operating in writable mode,
individual storages may still be read-only. But if the
server is in read-only mode, no write operations are
allowed, even if the storages are writable. Note that
pack() is considered a read-only operation.
invalidation_queue_size -- The storage server keeps a queue
of the objects modified by the last N transactions, where
N == invalidation_queue_size. This queue is used to
speed client cache verification when a client disconnects
for a short period of time.
invalidation_age --
If the invalidation queue isn't big enough to support a
quick verification, but the last transaction seen by a
client is younger than the invalidation age, then
invalidations will be computed by iterating over
transactions later than the given transaction.
transaction_timeout -- The maximum amount of time to wait for
a transaction to commit after acquiring the storage lock.
If the transaction takes too long, the client connection
will be closed and the transaction aborted.
monitor_address -- The address at which the monitor server
should listen. If specified, a monitor server is started.
The monitor server provides server statistics in a simple
text format.
auth_protocol -- The name of the authentication protocol to use.
Examples are "digest" and "srp".
auth_database -- The name of the password database filename.
It should be in a format compatible with the authentication
protocol used; for instance, "sha" and "srp" require different
formats.
Note that to implement an authentication protocol, a server
and client authentication mechanism must be implemented in a
auth_* module, which should be stored inside the "auth"
subdirectory. This module may also define a DatabaseClass
variable that should indicate what database should be used
by the authenticator.
"""
self.addr = addr
self.storages = storages
msg = ", ".join(
["%s:%s:%s" % (name, storage.isReadOnly() and "RO" or "RW",
storage.getName())
for name, storage in storages.items()])
log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg))
self._lock = threading.Lock()
self._commit_locks = {}
self._waiting = dict((name, []) for name in storages)
self.read_only = read_only
self.auth_protocol = auth_protocol
self.auth_database = auth_database
self.auth_realm = auth_realm
self.database = None
if auth_protocol:
self._setup_auth(auth_protocol)
# A list, by server, of at most invalidation_queue_size invalidations.
# The list is kept in sorted order with the most recent
# invalidation at the front. The list never has more than
# self.invq_bound elements.
self.invq_bound = invalidation_queue_size
self.invq = {}
for name, storage in storages.items():
self._setup_invq(name, storage)
storage.registerDB(StorageServerDB(self, name))
self.invalidation_age = invalidation_age
self.connections = {}
self.socket_map = {}
self.dispatcher = self.DispatcherClass(
addr, factory=self.new_connection, map=self.socket_map)
if len(self.addr) == 2 and self.addr[1] == 0 and self.addr[0]:
self.addr = self.dispatcher.socket.getsockname()
ZODB.event.notify(
Serving(self, address=self.dispatcher.socket.getsockname()))
self.stats = {}
self.timeouts = {}
for name in self.storages.keys():
self.connections[name] = []
self.stats[name] = StorageStats(self.connections[name])
if transaction_timeout is None:
# An object with no-op methods
timeout = StubTimeoutThread()
else:
timeout = TimeoutThread(transaction_timeout)
timeout.setName("TimeoutThread for %s" % name)
timeout.start()
self.timeouts[name] = timeout
if monitor_address:
warnings.warn(
"The monitor server is deprecated. Use the server_status\n"
"ZEO method instead.",
DeprecationWarning)
self.monitor = StatsServer(monitor_address, self.stats)
else:
self.monitor = None
def _setup_invq(self, name, storage):
lastInvalidations = getattr(storage, 'lastInvalidations', None)
if lastInvalidations is None:
# Using None below doesn't look right, but the first
# element in invq is never used. See get_invalidations.
# (If it was used, it would generate an error, which would
# be good. :) Doing this allows clients that were up to
# date when a server was restarted to pick up transactions
# it subsequently missed.
self.invq[name] = [(storage.lastTransaction() or z64, None)]
else:
self.invq[name] = list(lastInvalidations(self.invq_bound))
self.invq[name].reverse()
def _setup_auth(self, protocol):
# Can't be done in global scope, because of cyclic references
from .auth import get_module
name = self.__class__.__name__
module = get_module(protocol)
if not module:
log("%s: no such an auth protocol: %s" % (name, protocol))
return
storage_class, client, db_class = module
if not storage_class or not issubclass(storage_class, ZEOStorage):
log(("%s: %s isn't a valid protocol, must have a StorageClass" %
(name, protocol)))
self.auth_protocol = None
return
self.ZEOStorageClass = storage_class
log("%s: using auth protocol: %s" % (name, protocol))
# We create a Database instance here for use with the authenticator
# modules. Having one instance allows it to be shared between multiple
# storages, avoiding the need to bloat each with a new authenticator
# Database that would contain the same info, and also avoiding any
# possibly synchronization issues between them.
self.database = db_class(self.auth_database)
if self.database.realm != self.auth_realm:
raise ValueError("password database realm %r "
"does not match storage realm %r"
% (self.database.realm, self.auth_realm))
def new_connection(self, sock, addr):
"""Internal: factory to create a new connection.
This is called by the Dispatcher class in ZEO.zrpc.server
whenever accept() returns a socket for a new incoming
connection.
"""
if self.auth_protocol and self.database:
zstorage = self.ZEOStorageClass(self, self.read_only,
auth_realm=self.auth_realm)
zstorage.set_database(self.database)
else:
zstorage = self.ZEOStorageClass(self, self.read_only)
c = self.ManagedServerConnectionClass(sock, addr, zstorage, self)
log("new connection %s: %s" % (addr, repr(c)), logging.DEBUG)
return c
def register_connection(self, storage_id, conn):
"""Internal: register a connection with a particular storage.
This is called by ZEOStorage.register().
The dictionary self.connections maps each storage name to a
list of current connections for that storage; this information
is needed to handle invalidation. This function updates this
dictionary.
Returns the timeout and stats objects for the appropriate storage.
"""
self.connections[storage_id].append(conn)
return self.stats[storage_id]
def _invalidateCache(self, storage_id):
"""We need to invalidate any caches we have.
This basically means telling our clients to
invalidate/revalidate their caches. We do this by closing them
and making them reconnect.
"""
# This method can be called from foreign threads. We have to
# worry about interaction with the main thread.
# 1. We modify self.invq which is read by get_invalidations
# below. This is why get_invalidations makes a copy of
# self.invq.
# 2. We access connections. There are two dangers:
#
# a. We miss a new connection. This is not a problem because
# if a client connects after we get the list of connections,
# then it will have to read the invalidation queue, which
# has already been reset.
#
# b. A connection is closes while we are iterating. This
# doesn't matter, bacause we can call should_close on a closed
# connection.
# Rebuild invq
self._setup_invq(storage_id, self.storages[storage_id])
# Make a copy since we are going to be mutating the
# connections indirectoy by closing them. We don't care about
# later transactions since they will have to validate their
# caches anyway.
for p in self.connections[storage_id][:]:
try:
p.connection.should_close()
p.connection.trigger.pull_trigger()
except DisconnectedError:
pass
def invalidate(self, conn, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients.
This is called from several ZEOStorage methods.
invalidated is a sequence of oids.
This can do three different things:
- If the invalidated argument is non-empty, it broadcasts
invalidateTransaction() messages to all clients of the given
storage except the current client (the conn argument).
- If the invalidated argument is empty and the info argument
is a non-empty dictionary, it broadcasts info() messages to
all clients of the given storage, including the current
client.
- If both the invalidated argument and the info argument are
non-empty, it broadcasts invalidateTransaction() messages to all
clients except the current, and sends an info() message to
the current client.
"""
# This method can be called from foreign threads. We have to
# worry about interaction with the main thread.
# 1. We modify self.invq which is read by get_invalidations
# below. This is why get_invalidations makes a copy of
# self.invq.
# 2. We access connections. There are two dangers:
#
# a. We miss a new connection. This is not a problem because
# we are called while the storage lock is held. A new
# connection that tries to read data won't read committed
# data without first recieving an invalidation. Also, if a
# client connects after getting the list of connections,
# then it will have to read the invalidation queue, which
# has been updated to reflect the invalidations.
#
# b. A connection is closes while we are iterating. We'll need
# to cactch and ignore Disconnected errors.
if invalidated:
invq = self.invq[storage_id]
if len(invq) >= self.invq_bound:
invq.pop()
invq.insert(0, (tid, invalidated))
for p in self.connections[storage_id]:
try:
if invalidated and p is not conn:
p.client.invalidateTransaction(tid, invalidated)
elif info is not None:
p.client.info(info)
except DisconnectedError:
pass
def get_invalidations(self, storage_id, tid):
"""Return a tid and list of all objects invalidation since tid.
The tid is the most recent transaction id seen by the client.
Returns None if it is unable to provide a complete list
of invalidations for tid. In this case, client should
do full cache verification.
"""
# We make a copy of invq because it might be modified by a
# foreign (other than main thread) calling invalidate above.
invq = self.invq[storage_id][:]
oids = set()
latest_tid = None
if invq and invq[-1][0] <= tid:
# We have needed data in the queue
for _tid, L in invq:
if _tid <= tid:
break
oids.update(L)
latest_tid = invq[0][0]
elif (self.invalidation_age and
(self.invalidation_age >
(time.time()-ZODB.TimeStamp.TimeStamp(tid).timeTime())
)
):
for t in self.storages[storage_id].iterator(p64(u64(tid)+1)):
for r in t:
oids.add(r.oid)
latest_tid = t.tid
elif not invq:
log("invq empty")
else:
log("tid to old for invq %s < %s" % (u64(tid), u64(invq[-1][0])))
return latest_tid, list(oids)
def loop(self, timeout=30):
try:
asyncore.loop(timeout, map=self.socket_map)
except Exception:
if not self.__closed:
raise # Unexpected exc
__thread = None
def start_thread(self, daemon=True):
self.__thread = thread = threading.Thread(target=self.loop)
thread.setName("StorageServer(%s)" % _addr_label(self.addr))
thread.setDaemon(daemon)
thread.start()
__closed = False
def close(self, join_timeout=1):
"""Close the dispatcher so that there are no new connections.
This is only called from the test suite, AFAICT.
"""
if self.__closed:
return
self.__closed = True
# Stop accepting connections
self.dispatcher.close()
if self.monitor is not None:
self.monitor.close()
ZODB.event.notify(Closed(self))
# Close open client connections
for sid, connections in self.connections.items():
for conn in connections[:]:
try:
conn.connection.close()
except:
pass
for name, storage in six.iteritems(self.storages):
logger.info("closing storage %r", name)
storage.close()
if self.__thread is not None:
self.__thread.join(join_timeout)
def close_conn(self, conn):
"""Internal: remove the given connection from self.connections.
This is the inverse of register_connection().
"""
for cl in self.connections.values():
if conn.obj in cl:
cl.remove(conn.obj)
def lock_storage(self, zeostore, delay):
storage_id = zeostore.storage_id
waiting = self._waiting[storage_id]
with self._lock:
if storage_id in self._commit_locks:
# The lock is held by another zeostore
locked = self._commit_locks[storage_id]
assert locked is not zeostore, (storage_id, delay)
if locked.connection is None:
locked.log("Still locked after disconnected. Unlocking.",
logging.CRITICAL)
if locked.transaction:
locked.storage.tpc_abort(locked.transaction)
del self._commit_locks[storage_id]
# yuck: have to manipulate lock to appease with :(
self._lock.release()
try:
return self.lock_storage(zeostore, delay)
finally:
self._lock.acquire()
if delay is None:
# New request, queue it
assert not [i for i in waiting if i[0] is zeostore
], "already waiting"
delay = Delay()
waiting.append((zeostore, delay))
zeostore.log("(%r) queue lock: transactions waiting: %s"
% (storage_id, len(waiting)),
_level_for_waiting(waiting)
)
return False, delay
else:
self._commit_locks[storage_id] = zeostore
self.timeouts[storage_id].begin(zeostore)
self.stats[storage_id].lock_time = time.time()
if delay is not None:
# we were waiting, stop
waiting[:] = [i for i in waiting if i[0] is not zeostore]
zeostore.log("(%r) lock: transactions waiting: %s"
% (storage_id, len(waiting)),
_level_for_waiting(waiting)
)
return True, delay
def unlock_storage(self, zeostore):
storage_id = zeostore.storage_id
waiting = self._waiting[storage_id]
with self._lock:
assert self._commit_locks[storage_id] is zeostore
del self._commit_locks[storage_id]
self.timeouts[storage_id].end(zeostore)
self.stats[storage_id].lock_time = None
callbacks = waiting[:]
if callbacks:
assert not [i for i in waiting if i[0] is zeostore
], "waiting while unlocking"
zeostore.log("(%r) unlock: transactions waiting: %s"
% (storage_id, len(callbacks)),
_level_for_waiting(callbacks)
)
for zeostore, delay in callbacks:
try:
zeostore._unlock_callback(delay)
except (SystemExit, KeyboardInterrupt):
raise
except Exception:
logger.exception("Calling unlock callback")
def stop_waiting(self, zeostore):
storage_id = zeostore.storage_id
waiting = self._waiting[storage_id]
with self._lock:
new_waiting = [i for i in waiting if i[0] is not zeostore]
if len(new_waiting) == len(waiting):
return
waiting[:] = new_waiting
zeostore.log("(%r) dequeue lock: transactions waiting: %s"
% (storage_id, len(waiting)),
_level_for_waiting(waiting)
)
def already_waiting(self, zeostore):
storage_id = zeostore.storage_id
waiting = self._waiting[storage_id]
with self._lock:
return bool([i for i in waiting if i[0] is zeostore])
def server_status(self, storage_id):
status = self.stats[storage_id].__dict__.copy()
status['connections'] = len(status['connections'])
status['waiting'] = len(self._waiting[storage_id])
status['timeout-thread-is-alive'] = self.timeouts[storage_id].isAlive()
last_transaction = self.storages[storage_id].lastTransaction()
last_transaction_hex = codecs.encode(last_transaction, 'hex_codec')
if PY3:
# doctests and maybe clients expect a str, not bytes
last_transaction_hex = str(last_transaction_hex, 'ascii')
status['last-transaction'] = last_transaction_hex
return status
def ruok(self):
return dict((storage_id, self.server_status(storage_id))
for storage_id in self.storages)
def _level_for_waiting(waiting):
if len(waiting) > 9:
return logging.CRITICAL
if len(waiting) > 3:
return logging.WARNING
else:
return logging.DEBUG
class StubTimeoutThread:
def begin(self, client):
pass
def end(self, client):
pass
isAlive = lambda self: 'stub'
class TimeoutThread(threading.Thread):
"""Monitors transaction progress and generates timeouts."""
# There is one TimeoutThread per storage, because there's one
# transaction lock per storage.
def __init__(self, timeout):
threading.Thread.__init__(self)
self.setName("TimeoutThread")
self.setDaemon(1)
self._timeout = timeout
self._client = None
self._deadline = None
self._cond = threading.Condition() # Protects _client and _deadline
def begin(self, client):
# Called from the restart code the "main" thread, whenever the
# storage lock is being acquired. (Serialized by asyncore.)
with self._cond:
assert self._client is None
self._client = client
self._deadline = time.time() + self._timeout
self._cond.notify()
def end(self, client):
# Called from the "main" thread whenever the storage lock is
# being released. (Serialized by asyncore.)
with self._cond:
assert self._client is not None
assert self._client is client
self._client = None
self._deadline = None
def run(self):
# Code running in the thread.
while 1:
with self._cond:
while self._deadline is None:
self._cond.wait()
howlong = self._deadline - time.time()
if howlong <= 0:
# Prevent reporting timeout more than once
self._deadline = None
client = self._client # For the howlong <= 0 branch below
if howlong <= 0:
client.log("Transaction timeout after %s seconds" %
self._timeout, logging.CRITICAL)
try:
client.connection.call_from_thread(client.connection.close)
except:
client.log("Timeout failure", logging.CRITICAL,
exc_info=sys.exc_info())
self.end(client)
else:
time.sleep(howlong)
def run_in_thread(method, *args):
t = SlowMethodThread(method, args)
t.start()
return t.delay
class SlowMethodThread(threading.Thread):
"""Thread to run potentially slow storage methods.
Clients can use the delay attribute to access the MTDelay object
used to send a zrpc response at the right time.
"""
# Some storage methods can take a long time to complete. If we
# run these methods via a standard asyncore read handler, they
# will block all other server activity until they complete. To
# avoid blocking, we spawn a separate thread, return an MTDelay()
# object, and have the thread reply() when it finishes.
def __init__(self, method, args):
threading.Thread.__init__(self)
self.setName("SlowMethodThread for %s" % method.__name__)
self._method = method
self._args = args
self.delay = MTDelay()
def run(self):
try:
result = self._method(*self._args)
except (SystemExit, KeyboardInterrupt):
raise
except Exception:
self.delay.error(sys.exc_info())
else:
self.delay.reply(result)
class ClientStub:
def __init__(self, rpc):
self.rpc = rpc
def beginVerify(self):
self.rpc.callAsync('beginVerify')
def invalidateVerify(self, args):
self.rpc.callAsync('invalidateVerify', args)
def endVerify(self):
self.rpc.callAsync('endVerify')
def invalidateTransaction(self, tid, args):
# Note that this method is *always* called from a different
# thread than self.rpc's async thread. It is the only method
# for which this is true and requires special consideration!
# callAsyncNoSend is important here because:
# - callAsyncNoPoll isn't appropriate because
# the network thread may not wake up for a long time,
# delaying invalidations for too long. (This is demonstrateed
# by a test failure.)
# - callAsync isn't appropriate because (on the server) it tries
# to write to the socket. If self.rpc's network thread also
# tries to write at the ame time, we can run into problems
# because handle_write isn't thread safe.
self.rpc.callAsyncNoSend('invalidateTransaction', tid, args)
def serialnos(self, arg):
self.rpc.callAsyncNoPoll('serialnos', arg)
def info(self, arg):
self.rpc.callAsyncNoPoll('info', arg)
def storeBlob(self, oid, serial, blobfilename):
def store():
yield ('receiveBlobStart', (oid, serial))
f = open(blobfilename, 'rb')
while 1:
chunk = f.read(59000)
if not chunk:
break
yield ('receiveBlobChunk', (oid, serial, chunk, ))
f.close()
yield ('receiveBlobStop', (oid, serial))
self.rpc.callAsyncIterator(store())
class ClientStub308(ClientStub):
def invalidateTransaction(self, tid, args):
ClientStub.invalidateTransaction(
self, tid, [(arg, '') for arg in args])
def invalidateVerify(self, oid):
ClientStub.invalidateVerify(self, (oid, ''))
class ZEOStorage308Adapter:
def __init__(self, storage):
self.storage = storage
def __eq__(self, other):
return self is other or self.storage is other
def getSerial(self, oid):
return self.storage.loadEx(oid)[1] # Z200
def history(self, oid, version, size=1):
if version:
raise ValueError("Versions aren't supported.")
return self.storage.history(oid, size=size)
def getInvalidations(self, tid):
result = self.storage.getInvalidations(tid)
if result is not None:
result = result[0], [(oid, '') for oid in result[1]]
return result
def verify(self, oid, version, tid):
if version:
raise StorageServerError("Versions aren't supported.")
return self.storage.verify(oid, tid)
def loadEx(self, oid, version=''):
if version:
raise StorageServerError("Versions aren't supported.")
data, serial = self.storage.loadEx(oid)
return data, serial, ''
def storea(self, oid, serial, data, version, id):
if version:
raise StorageServerError("Versions aren't supported.")
self.storage.storea(oid, serial, data, id)
def storeBlobEnd(self, oid, serial, data, version, id):
if version:
raise StorageServerError("Versions aren't supported.")
self.storage.storeBlobEnd(oid, serial, data, id)
def storeBlobShared(self, oid, serial, data, filename, version, id):
if version:
raise StorageServerError("Versions aren't supported.")
self.storage.storeBlobShared(oid, serial, data, filename, id)
def getInfo(self):
result = self.storage.getInfo()
result['supportsVersions'] = False
return result
def zeoVerify(self, oid, s, sv=None):
if sv:
raise StorageServerError("Versions aren't supported.")
self.storage.zeoVerify(oid, s)
def modifiedInVersion(self, oid):
return ''
def versions(self):
return ()
def versionEmpty(self, version):
return True
def commitVersion(self, *a, **k):
raise NotImplementedError
abortVersion = commitVersion
def zeoLoad(self, oid): # Z200
p, s = self.storage.loadEx(oid)
return p, s, '', None, None
def __getattr__(self, name):
return getattr(self.storage, name)
def _addr_label(addr):
if isinstance(addr, six.binary_type):
return addr.decode('ascii')
if isinstance(addr, six.string_types):
return addr
else:
host, port = addr
return str(host) + ":" + str(port)
class CommitLog:
def __init__(self):
self.file = tempfile.TemporaryFile(suffix=".comit-log")
self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1
self.stores = 0
def size(self):
return self.file.tell()
def delete(self, oid, serial):
self.pickler.dump(('_delete', (oid, serial)))
self.stores += 1
def checkread(self, oid, serial):
self.pickler.dump(('_checkread', (oid, serial)))
self.stores += 1
def store(self, oid, serial, data):
self.pickler.dump(('_store', (oid, serial, data)))
self.stores += 1
def restore(self, oid, serial, data, prev_txn):
self.pickler.dump(('_restore', (oid, serial, data, prev_txn)))
self.stores += 1
def undo(self, transaction_id):
self.pickler.dump(('_undo', (transaction_id, )))
self.stores += 1
def __iter__(self):
self.file.seek(0)
unpickler = Unpickler(self.file)
for i in range(self.stores):
yield unpickler.load()
def close(self):
if self.file:
self.file.close()
self.file = None
class ServerEvent:
def __init__(self, server, **kw):
self.__dict__.update(kw)
self.server = server
class Serving(ServerEvent):
pass
class Closed(ServerEvent):
pass
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
_auth_modules = {}
def get_module(name):
if name == 'sha':
from auth_sha import StorageClass, SHAClient, Database
return StorageClass, SHAClient, Database
elif name == 'digest':
from .auth_digest import StorageClass, DigestClient, DigestDatabase
return StorageClass, DigestClient, DigestDatabase
else:
return _auth_modules.get(name)
def register_module(name, storage_class, client, db):
if name in _auth_modules:
raise TypeError("%s is already registred" % name)
_auth_modules[name] = storage_class, client, db
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Digest authentication for ZEO
This authentication mechanism follows the design of HTTP digest
authentication (RFC 2069). It is a simple challenge-response protocol
that does not send passwords in the clear, but does not offer strong
security. The RFC discusses many of the limitations of this kind of
protocol.
Guard the password database as if it contained plaintext passwords.
It stores the hash of a username and password. This does not expose
the plaintext password, but it is sensitive nonetheless. An attacker
with the hash can impersonate the real user. This is a limitation of
the simple digest scheme.
HTTP is a stateless protocol, and ZEO is a stateful protocol. The
security requirements are quite different as a result. The HTTP
protocol uses a nonce as a challenge. The ZEO protocol requires a
separate session key that is used for message authentication. We
generate a second nonce for this purpose; the hash of nonce and
user/realm/password is used as the session key.
TODO: I'm not sure if this is a sound approach; SRP would be preferred.
"""
import os
import random
import struct
import time
from .base import Database, Client
from ..StorageServer import ZEOStorage
from ZEO.Exceptions import AuthError
from ..hash import sha1
def get_random_bytes(n=8):
try:
b = os.urandom(n)
except NotImplementedError:
L = [chr(random.randint(0, 255)) for i in range(n)]
b = b"".join(L)
return b
def hexdigest(s):
return sha1(s.encode()).hexdigest()
class DigestDatabase(Database):
def __init__(self, filename, realm=None):
Database.__init__(self, filename, realm)
# Initialize a key used to build the nonce for a challenge.
# We need one key for the lifetime of the server, so it
# is convenient to store in on the database.
self.noncekey = get_random_bytes(8)
def _store_password(self, username, password):
dig = hexdigest("%s:%s:%s" % (username, self.realm, password))
self._users[username] = dig
def session_key(h_up, nonce):
# The hash itself is a bit too short to be a session key.
# HMAC wants a 64-byte key. We don't want to use h_up
# directly because it would never change over time. Instead
# use the hash plus part of h_up.
return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() +
h_up.encode('utf-8')[:44])
class StorageClass(ZEOStorage):
def set_database(self, database):
assert isinstance(database, DigestDatabase)
self.database = database
self.noncekey = database.noncekey
def _get_time(self):
# Return a string representing the current time.
t = int(time.time())
return struct.pack("i", t)
def _get_nonce(self):
# RFC 2069 recommends a nonce of the form
# H(client-IP ":" time-stamp ":" private-key)
dig = sha1()
dig.update(str(self.connection.addr).encode('latin-1'))
dig.update(self._get_time())
dig.update(self.noncekey)
return dig.hexdigest()
def auth_get_challenge(self):
"""Return realm, challenge, and nonce."""
self._challenge = self._get_nonce()
self._key_nonce = self._get_nonce()
return self.auth_realm, self._challenge, self._key_nonce
def auth_response(self, resp):
# verify client response
user, challenge, response = resp
# Since zrpc is a stateful protocol, we just store the nonce
# we sent to the client. It will need to generate a new
# nonce for a new connection anyway.
if self._challenge != challenge:
raise ValueError("invalid challenge")
# lookup user in database
h_up = self.database.get_password(user)
# regeneration resp from user, password, and nonce
check = hexdigest("%s:%s" % (h_up, challenge))
if check == response:
self.connection.setSessionKey(session_key(h_up, self._key_nonce))
return self._finish_auth(check == response)
extensions = [auth_get_challenge, auth_response]
class DigestClient(Client):
extensions = ["auth_get_challenge", "auth_response"]
def start(self, username, realm, password):
_realm, challenge, nonce = self.stub.auth_get_challenge()
if _realm != realm:
raise AuthError("expected realm %r, got realm %r"
% (_realm, realm))
h_up = hexdigest("%s:%s:%s" % (username, realm, password))
resp_dig = hexdigest("%s:%s" % (h_up, challenge))
result = self.stub.auth_response((username, challenge, resp_dig))
if result:
return session_key(h_up, nonce)
else:
return None
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Base classes for defining an authentication protocol.
Database -- abstract base class for password database
Client -- abstract base class for authentication client
"""
from __future__ import print_function
from __future__ import print_function
import os
from ..hash import sha1
class Client:
# Subclass should override to list the names of methods that
# will be called on the server.
extensions = []
def __init__(self, stub):
self.stub = stub
for m in self.extensions:
setattr(self.stub, m, self.stub.extensionMethod(m))
def sort(L):
"""Sort a list in-place and return it."""
L.sort()
return L
class Database:
"""Abstracts a password database.
This class is used both in the authentication process (via
get_password()) and by client scripts that manage the password
database file.
The password file is a simple, colon-separated text file mapping
usernames to password hashes. The hashes are SHA hex digests
produced from the password string.
"""
realm = None
def __init__(self, filename, realm=None):
"""Creates a new Database
filename: a string containing the full pathname of
the password database file. Must be readable by the user
running ZEO. Must be writeable by any client script that
accesses the database.
realm: the realm name (a string)
"""
self._users = {}
self.filename = filename
self.load()
if realm:
if self.realm and self.realm != realm:
raise ValueError("Specified realm %r differs from database "
"realm %r" % (realm or '', self.realm))
else:
self.realm = realm
def save(self, fd=None):
filename = self.filename
needs_closed = False
if not fd:
fd = open(filename, 'w')
needs_closed = True
try:
if self.realm:
print("realm", self.realm, file=fd)
for username in sorted(self._users.keys()):
print("%s: %s" % (username, self._users[username]), file=fd)
finally:
if needs_closed:
fd.close()
def load(self):
filename = self.filename
if not filename:
return
if not os.path.exists(filename):
return
with open(filename) as fd:
L = fd.readlines()
if not L:
return
if L[0].startswith("realm "):
line = L.pop(0).strip()
self.realm = line[len("realm "):]
for line in L:
username, hash = line.strip().split(":", 1)
self._users[username] = hash.strip()
def _store_password(self, username, password):
self._users[username] = self.hash(password)
def get_password(self, username):
"""Returns password hash for specified username.
Callers must check for LookupError, which is raised in
the case of a non-existent user specified."""
if username not in self._users:
raise LookupError("No such user: %s" % username)
return self._users[username]
def hash(self, s):
return sha1(s.encode()).hexdigest()
def add_user(self, username, password):
if username in self._users:
raise LookupError("User %s already exists" % username)
self._store_password(username, password)
def del_user(self, username):
if username not in self._users:
raise LookupError("No such user: %s" % username)
del self._users[username]
def change_password(self, username, password):
if username not in self._users:
raise LookupError("No such user: %s" % username)
self._store_password(username, password)
"""HMAC (Keyed-Hashing for Message Authentication) Python module.
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC:
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
"""Create a new HMAC object.
key: key for the keyed hash object.
msg: Initial input for the hash, if provided.
digestmod: A module supporting PEP 247. Defaults to the md5 module.
"""
if digestmod is None:
import md5
digestmod = md5
self.digestmod = digestmod
self.outer = digestmod.new()
self.inner = digestmod.new()
self.digest_size = digestmod.digest_size
blocksize = 64
ipad = "\x36" * blocksize
opad = "\x5C" * blocksize
if len(key) > blocksize:
key = digestmod.new(key).digest()
key = key + chr(0) * (blocksize - len(key))
self.outer.update(_strxor(key, opad))
self.inner.update(_strxor(key, ipad))
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
"""
self.inner.update(msg)
def copy(self):
"""Return a separate copy of this hashing object.
An update to this copy won't affect the original object.
"""
other = HMAC("")
other.digestmod = self.digestmod
other.inner = self.inner.copy()
other.outer = self.outer.copy()
return other
def digest(self):
"""Return the hash value of this hashing object.
This returns a string containing 8-bit data. The object is
not altered in any way by this function; you can continue
updating the object after calling this function.
"""
h = self.outer.copy()
h.update(self.inner.digest())
return h.digest()
def hexdigest(self):
"""Like digest(), but returns a string of hexadecimal digits instead.
"""
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
msg: if available, will immediately be hashed into the object's starting
state.
You can now feed arbitrary strings into the object using its update()
method, and can ask for the hash value at any time by calling its digest()
method.
"""
return HMAC(key, msg, digestmod)
<component>
<sectiontype name="zeo">
<description>
The content of a ZEO section describe operational parameters
of a ZEO server except for the storage(s) to be served.
</description>
<key name="address" datatype="socket-binding-address"
required="yes">
<description>
The address at which the server should listen. This can be in
the form 'host:port' to signify a TCP/IP connection or a
pathname string to signify a Unix domain socket connection (at
least one '/' is required). A hostname may be a DNS name or a
dotted IP address. If the hostname is omitted, the platform's
default behavior is used when binding the listening socket (''
is passed to socket.bind() as the hostname portion of the
address).
</description>
</key>
<key name="read-only" datatype="boolean"
required="no"
default="false">
<description>
Flag indicating whether the server should operate in read-only
mode. Defaults to false. Note that even if the server is
operating in writable mode, individual storages may still be
read-only. But if the server is in read-only mode, no write
operations are allowed, even if the storages are writable. Note
that pack() is considered a read-only operation.
</description>
</key>
<key name="invalidation-queue-size" datatype="integer"
required="no"
default="100">
<description>
The storage server keeps a queue of the objects modified by the
last N transactions, where N == invalidation_queue_size. This
queue is used to speed client cache verification when a client
disconnects for a short period of time.
</description>
</key>
<key name="invalidation-age" datatype="float" required="no">
<description>
The maximum age of a client for which quick-verification
invalidations will be provided by iterating over the served
storage. This option should only be used if the served storage
supports efficient iteration from a starting point near the
end of the transaction history (e.g. end of file).
</description>
</key>
<key name="monitor-address" datatype="socket-binding-address"
required="no">
<description>
The address at which the monitor server should listen. If
specified, a monitor server is started. The monitor server
provides server statistics in a simple text format. This can
be in the form 'host:port' to signify a TCP/IP connection or a
pathname string to signify a Unix domain socket connection (at
least one '/' is required). A hostname may be a DNS name or a
dotted IP address. If the hostname is omitted, the platform's
default behavior is used when binding the listening socket (''
is passed to socket.bind() as the hostname portion of the
address).
</description>
</key>
<key name="transaction-timeout" datatype="integer"
required="no">
<description>
The maximum amount of time to wait for a transaction to commit
after acquiring the storage lock, specified in seconds. If the
transaction takes too long, the client connection will be closed
and the transaction aborted.
</description>
</key>
<key name="authentication-protocol" required="no">
<description>
The name of the protocol used for authentication. The
only protocol provided with ZEO is "digest," but extensions
may provide other protocols.
</description>
</key>
<key name="authentication-database" required="no">
<description>
The path of the database containing authentication credentials.
</description>
</key>
<key name="authentication-realm" required="no">
<description>
The authentication realm of the server. Some authentication
schemes use a realm to identify the logical set of usernames
that are accepted by this server.
</description>
</key>
<key name="pid-filename" datatype="existing-dirpath"
required="no">
<description>
The full path to the file in which to write the ZEO server's Process ID
at startup. If omitted, $INSTANCE/var/ZEO.pid is used.
</description>
<metadefault>$INSTANCE/var/ZEO.pid (or $clienthome/ZEO.pid)</metadefault>
</key>
<!-- DM 2006-06-12: added option -->
<key name="drop-cache-rather-verify" datatype="boolean"
required="no" default="false">
<description>
indicates that the cache should be dropped rather than
verified when the verification optimization is not
available (e.g. when the ZEO server restarted).
</description>
</key>
</sectiontype>
</component>
##############################################################################
#
# Copyright (c) 2008 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""In Python 2.6, the "sha" and "md5" modules have been deprecated
in favor of using hashlib for both. This class allows for compatibility
between versions."""
try:
import hashlib
sha1 = hashlib.sha1
new = sha1
except ImportError:
import sha
sha1 = sha.new
new = sha1
digest_size = sha.digest_size
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Monitor behavior of ZEO server and record statistics.
"""
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
import asyncore
import socket
import time
import logging
zeo_version = 'unknown'
try:
import pkg_resources
except ImportError:
pass
else:
zeo_dist = pkg_resources.working_set.find(
pkg_resources.Requirement.parse('ZODB3')
)
if zeo_dist is not None:
zeo_version = zeo_dist.version
class StorageStats:
"""Per-storage usage statistics."""
def __init__(self, connections=None):
self.connections = connections
self.loads = 0
self.stores = 0
self.commits = 0
self.aborts = 0
self.active_txns = 0
self.verifying_clients = 0
self.lock_time = None
self.conflicts = 0
self.conflicts_resolved = 0
self.start = time.ctime()
@property
def clients(self):
return len(self.connections)
def parse(self, s):
# parse the dump format
lines = s.split("\n")
for line in lines:
field, value = line.split(":", 1)
if field == "Server started":
self.start = value
elif field == "Clients":
# Hack because we use this both on the server and on
# the client where there are no connections.
self.connections = [0] * int(value)
elif field == "Clients verifying":
self.verifying_clients = int(value)
elif field == "Active transactions":
self.active_txns = int(value)
elif field == "Commit lock held for":
# This assumes
self.lock_time = time.time() - int(value)
elif field == "Commits":
self.commits = int(value)
elif field == "Aborts":
self.aborts = int(value)
elif field == "Loads":
self.loads = int(value)
elif field == "Stores":
self.stores = int(value)
elif field == "Conflicts":
self.conflicts = int(value)
elif field == "Conflicts resolved":
self.conflicts_resolved = int(value)
def dump(self, f):
print("Server started:", self.start, file=f)
print("Clients:", self.clients, file=f)
print("Clients verifying:", self.verifying_clients, file=f)
print("Active transactions:", self.active_txns, file=f)
if self.lock_time:
howlong = time.time() - self.lock_time
print("Commit lock held for:", int(howlong), file=f)
print("Commits:", self.commits, file=f)
print("Aborts:", self.aborts, file=f)
print("Loads:", self.loads, file=f)
print("Stores:", self.stores, file=f)
print("Conflicts:", self.conflicts, file=f)
print("Conflicts resolved:", self.conflicts_resolved, file=f)
class StatsClient(asyncore.dispatcher):
def __init__(self, sock, addr):
asyncore.dispatcher.__init__(self, sock)
self.buf = []
self.closed = 0
def close(self):
self.closed = 1
# The socket is closed after all the data is written.
# See handle_write().
def write(self, s):
self.buf.append(s)
def writable(self):
return len(self.buf)
def readable(self):
return 0
def handle_write(self):
s = "".join(self.buf)
self.buf = []
n = self.socket.send(s.encode('ascii'))
if n < len(s):
self.buf.append(s[:n])
if self.closed and not self.buf:
asyncore.dispatcher.close(self)
class StatsServer(asyncore.dispatcher):
StatsConnectionClass = StatsClient
def __init__(self, addr, stats):
asyncore.dispatcher.__init__(self)
self.addr = addr
self.stats = stats
if type(self.addr) == tuple:
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
else:
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.set_reuse_addr()
logger = logging.getLogger('ZEO.monitor')
logger.info("listening on %s", repr(self.addr))
self.bind(self.addr)
self.listen(5)
def writable(self):
return 0
def readable(self):
return 1
def handle_accept(self):
try:
sock, addr = self.accept()
except socket.error:
return
f = self.StatsConnectionClass(sock, addr)
self.dump(f)
f.close()
def dump(self, f):
print("ZEO monitor server version %s" % zeo_version, file=f)
print(time.ctime(), file=f)
print(file=f)
L = sorted(self.stats.keys())
for k in L:
stats = self.stats[k]
print("Storage:", k, file=f)
stats.dump(f)
print(file=f)
##############################################################################
#
# Copyright (c) 2001, 2002, 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Start the ZEO storage server.
Usage: %s [-C URL] [-a ADDRESS] [-f FILENAME] [-h]
Options:
-C/--configuration URL -- configuration file or URL
-a/--address ADDRESS -- server address of the form PORT, HOST:PORT, or PATH
(a PATH must contain at least one "/")
-f/--filename FILENAME -- filename for FileStorage
-t/--timeout TIMEOUT -- transaction timeout in seconds (default no timeout)
-h/--help -- print this usage message and exit
-m/--monitor ADDRESS -- address of monitor server ([HOST:]PORT or PATH)
--pid-file PATH -- relative path to output file containing this process's pid;
default $(INSTANCE_HOME)/var/ZEO.pid but only if envar
INSTANCE_HOME is defined
Unless -C is specified, -a and -f are required.
"""
from __future__ import print_function
from __future__ import print_function
# The code here is designed to be reused by other, similar servers.
# For the forseeable future, it must work under Python 2.1 as well as
# 2.2 and above.
import asyncore
import os
import sys
import signal
import socket
import logging
import ZConfig.datatypes
from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address
def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown.
import asyncore
asyncore.close_all()
class ZEOOptionsMixin:
storages = None
def handle_address(self, arg):
self.family, self.address = parse_binding_address(arg)
def handle_monitor_address(self, arg):
self.monitor_family, self.monitor_address = parse_binding_address(arg)
def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig:
def __init__(self, name, path):
self._name = name
self.path = path
self.stop = None
def getSectionName(self):
return self._name
if not self.storages:
self.storages = []
name = str(1 + len(self.storages))
conf = FileStorage(FSConfig(name, arg))
self.storages.append(conf)
testing_exit_immediately = False
def handle_test(self, *args):
self.testing_exit_immediately = True
def add_zeo_options(self):
self.add(None, None, None, "test", self.handle_test)
self.add(None, None, "a:", "address=", self.handle_address)
self.add(None, None, "f:", "filename=", self.handle_filename)
self.add("family", "zeo.address.family")
self.add("address", "zeo.address.address",
required="no server address specified; use -a or -C")
self.add("read_only", "zeo.read_only", default=0)
self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
default=100)
self.add("invalidation_age", "zeo.invalidation_age")
self.add("transaction_timeout", "zeo.transaction_timeout",
"t:", "timeout=", float)
self.add("monitor_address", "zeo.monitor_address.address",
"m:", "monitor=", self.handle_monitor_address)
self.add('auth_protocol', 'zeo.authentication_protocol',
None, 'auth-protocol=', default=None)
self.add('auth_database', 'zeo.authentication_database',
None, 'auth-database=')
self.add('auth_realm', 'zeo.authentication_realm',
None, 'auth-realm=')
self.add('pid_file', 'zeo.pid_filename',
None, 'pid-file=')
class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__
logsectionname = "eventlog"
schemadir = os.path.dirname(__file__)
def __init__(self):
ZDOptions.__init__(self)
self.add_zeo_options()
self.add("storages", "storages",
required="no storages specified; use -f or -C")
def realize(self, *a, **k):
ZDOptions.realize(self, *a, **k)
nunnamed = [s for s in self.storages if s.name is None]
if nunnamed:
if len(nunnamed) > 1:
return self.usage("No more than one storage may be unnamed.")
if [s for s in self.storages if s.name == '1']:
return self.usage(
"Can't have an unnamed storage and a storage named 1.")
for s in self.storages:
if s.name is None:
s.name = '1'
break
class ZEOServer:
def __init__(self, options):
self.options = options
def main(self):
self.setup_default_logging()
self.check_socket()
self.clear_socket()
self.make_pidfile()
try:
self.open_storages()
self.setup_signals()
self.create_server()
self.loop_forever()
finally:
self.server.close()
self.clear_socket()
self.remove_pidfile()
def setup_default_logging(self):
if self.options.config_logger is not None:
return
# No log file is configured; default to stderr.
root = logging.getLogger()
root.setLevel(logging.INFO)
fmt = logging.Formatter(
"------\n%(asctime)s %(levelname)s %(name)s %(message)s",
"%Y-%m-%dT%H:%M:%S")
handler = logging.StreamHandler()
handler.setFormatter(fmt)
root.addHandler(handler)
def check_socket(self):
if (isinstance(self.options.address, tuple) and
self.options.address[1] is None):
self.options.address = self.options.address[0], 0
return
if self.can_connect(self.options.family, self.options.address):
self.options.usage("address %s already in use" %
repr(self.options.address))
def can_connect(self, family, address):
s = socket.socket(family, socket.SOCK_STREAM)
try:
s.connect(address)
except socket.error:
return 0
else:
s.close()
return 1
def clear_socket(self):
if isinstance(self.options.address, type("")):
try:
os.unlink(self.options.address)
except os.error:
pass
def open_storages(self):
self.storages = {}
for opener in self.options.storages:
log("opening storage %r using %s"
% (opener.name, opener.__class__.__name__))
self.storages[opener.name] = opener.open()
def setup_signals(self):
"""Set up signal handlers.
The signal handler for SIGFOO is a method handle_sigfoo().
If no handler method is defined for a signal, the signal
action is not changed from its initial value. The handler
method is called without additional arguments.
"""
if os.name != "posix":
if os.name == "nt":
self.setup_win32_signals()
return
if hasattr(signal, 'SIGXFSZ'):
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
init_signames()
for sig, name in signames.items():
method = getattr(self, "handle_" + name.lower(), None)
if method is not None:
def wrapper(sig_dummy, frame_dummy, method=method):
method()
signal.signal(sig, wrapper)
def setup_win32_signals(self):
# Borrow the Zope Signals package win32 support, if available.
# Signals does a check/log for the availability of pywin32.
try:
import Signals.Signals
except ImportError:
logger.debug("Signals package not found. "
"Windows-specific signal handler "
"will *not* be installed.")
return
SignalHandler = Signals.Signals.SignalHandler
if SignalHandler is not None: # may be None if no pywin32.
SignalHandler.registerHandler(signal.SIGTERM,
windows_shutdown_handler)
SignalHandler.registerHandler(signal.SIGINT,
windows_shutdown_handler)
SIGUSR2 = 12 # not in signal module on Windows.
SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2)
def create_server(self):
self.server = create_server(self.storages, self.options)
def loop_forever(self):
if self.options.testing_exit_immediately:
print("testing exit immediately")
else:
self.server.loop()
def handle_sigterm(self):
log("terminated by SIGTERM")
sys.exit(0)
def handle_sigint(self):
log("terminated by SIGINT")
sys.exit(0)
def handle_sighup(self):
log("restarted by SIGHUP")
sys.exit(1)
def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
level=logging.WARNING)
return
loggers = [self.options.config_logger]
if os.name == "posix":
for l in loggers:
l.reopen()
log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers:
for f in l.handler_factories:
handler = f()
if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate()
log("Log files rotation complete", level=logging.INFO)
def _get_pidfile(self):
pidfile = self.options.pid_file
# 'pidfile' is marked as not required.
if not pidfile:
# Try to find a reasonable location if the pidfile is not
# set. If we are running in a Zope environment, we can
# safely assume INSTANCE_HOME.
instance_home = os.environ.get("INSTANCE_HOME")
if not instance_home:
# If all our attempts failed, just log a message and
# proceed.
logger.debug("'pidfile' option not set, and 'INSTANCE_HOME' "
"environment variable could not be found. "
"Cannot guess pidfile location.")
return
self.options.pid_file = os.path.join(instance_home,
"var", "ZEO.pid")
def make_pidfile(self):
if not self.options.read_only:
self._get_pidfile()
pidfile = self.options.pid_file
if pidfile is None:
return
pid = os.getpid()
try:
if os.path.exists(pidfile):
os.unlink(pidfile)
f = open(pidfile, 'w')
print(pid, file=f)
f.close()
log("created PID file '%s'" % pidfile)
except IOError:
logger.error("PID file '%s' cannot be opened" % pidfile)
def remove_pidfile(self):
if not self.options.read_only:
pidfile = self.options.pid_file
if pidfile is None:
return
try:
if os.path.exists(pidfile):
os.unlink(pidfile)
log("removed PID file '%s'" % pidfile)
except IOError:
logger.error("PID file '%s' could not be removed" % pidfile)
def create_server(storages, options):
from .StorageServer import StorageServer
return StorageServer(
options.address,
storages,
read_only = options.read_only,
invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout,
monitor_address = options.monitor_address,
auth_protocol = options.auth_protocol,
auth_database = options.auth_database,
auth_realm = options.auth_realm,
)
# Signal names
signames = None
def signame(sig):
"""Return a symbolic name for a signal.
Return "signal NNN" if there is no corresponding SIG name in the
signal module.
"""
if signames is None:
init_signames()
return signames.get(sig) or "signal %d" % sig
def init_signames():
global signames
signames = {}
for name, sig in signal.__dict__.items():
k_startswith = getattr(name, "startswith", None)
if k_startswith is None:
continue
if k_startswith("SIG") and not k_startswith("SIG_"):
signames[sig] = name
# Main program
def main(args=None):
options = ZEOOptions()
options.realize(args)
s = ZEOServer(options)
s.main()
if __name__ == "__main__":
main()
<schema>
<!-- note that zeoctl.xml is a closely related schema which should
match this schema, but should require the "runner" section -->
<description>
This schema describes the configuration of the ZEO storage server
process.
</description>
<!-- Use the storage types defined by ZODB. -->
<import package="ZODB"/>
<!-- Use the ZEO server information structure. -->
<import package="ZEO.tests.ZEO4"/>
<import package="ZConfig.components.logger"/>
<!-- runner control -->
<import package="zdaemon"/>
<section type="zeo" name="*" required="yes" attribute="zeo" />
<section type="runner" name="*" required="no" attribute="runner" />
<multisection name="*" type="ZODB.storage"
attribute="storages"
required="yes">
<description>
One or more storages that are provided by the ZEO server. The
section names are used as the storage names, and must be unique
within each ZEO storage server. Traditionally, these names
represent small integers starting at '1'.
</description>
</multisection>
<section name="*" type="eventlog" attribute="eventlog" required="no" />
</schema>
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
# zrpc is a package with the following modules
# client -- manages connection creation to remote server
# connection -- object dispatcher
# log -- logging helper
# error -- exceptions raised by zrpc
# marshal -- internal, handles basic protocol issues
# server -- manages incoming connections from remote clients
# smac -- sized message async connections
# trigger -- medusa's trigger
# zrpc is not an advertised subpackage of ZEO; its interfaces are internal
# This file is a slightly modified copy of Python 2.3's Lib/hmac.py.
# This file is under the Python Software Foundation (PSF) license.
"""HMAC (Keyed-Hashing for Message Authentication) Python module.
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC:
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
"""Create a new HMAC object.
key: key for the keyed hash object.
msg: Initial input for the hash, if provided.
digestmod: A module supporting PEP 247. Defaults to the md5 module.
"""
if digestmod is None:
import md5
digestmod = md5
self.digestmod = digestmod
self.outer = digestmod.new()
self.inner = digestmod.new()
# Python 2.1 and 2.2 differ about the correct spelling
try:
self.digest_size = digestmod.digestsize
except AttributeError:
self.digest_size = digestmod.digest_size
blocksize = 64
ipad = "\x36" * blocksize
opad = "\x5C" * blocksize
if len(key) > blocksize:
key = digestmod.new(key).digest()
key = key + chr(0) * (blocksize - len(key))
self.outer.update(_strxor(key, opad))
self.inner.update(_strxor(key, ipad))
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
"""
self.inner.update(msg)
def copy(self):
"""Return a separate copy of this hashing object.
An update to this copy won't affect the original object.
"""
other = HMAC("")
other.digestmod = self.digestmod
other.inner = self.inner.copy()
other.outer = self.outer.copy()
return other
def digest(self):
"""Return the hash value of this hashing object.
This returns a string containing 8-bit data. The object is
not altered in any way by this function; you can continue
updating the object after calling this function.
"""
h = self.outer.copy()
h.update(self.inner.digest())
return h.digest()
def hexdigest(self):
"""Like digest(), but returns a string of hexadecimal digits instead.
"""
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
msg: if available, will immediately be hashed into the object's starting
state.
You can now feed arbitrary strings into the object using its update()
method, and can ask for the hash value at any time by calling its digest()
method.
"""
return HMAC(key, msg, digestmod)
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import asyncore
import errno
import logging
import select
import socket
import sys
import threading
import time
from . import trigger
from .connection import ManagedClientConnection
from .log import log
from .error import DisconnectedError
from ZODB.POSException import ReadOnlyError
from ZODB.loglevels import BLATHER
from six.moves import map
def client_timeout():
return 30.0
def client_loop(map):
read = asyncore.read
write = asyncore.write
_exception = asyncore._exception
while map:
try:
# The next two lines intentionally don't use
# iterators. Other threads can close dispatchers, causeing
# the socket map to shrink.
r = e = map.keys()
w = [fd for (fd, obj) in map.items() if obj.writable()]
try:
r, w, e = select.select(r, w, e, client_timeout())
except (select.error, RuntimeError) as err:
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno != errno.EINTR:
if err_errno == errno.EBADF:
# If a connection is closed while we are
# calling select on it, we can get a bad
# file-descriptor error. We'll check for this
# case by looking for entries in r and w that
# are not in the socket map.
if [fd for fd in r if fd not in map]:
continue
if [fd for fd in w if fd not in map]:
continue
# Hm, on Mac OS X, we could get a run time
# error and end up here, but retrying select
# would work. Let's try:
select.select(r, w, e, 0)
# we survived, keep going :)
continue
raise
else:
continue
if not map:
break
if not (r or w or e):
# The line intentionally doesn't use iterators. Other
# threads can close dispatchers, causeing the socket
# map to shrink.
for obj in map.values():
if isinstance(obj, ManagedClientConnection):
# Send a heartbeat message as a reply to a
# non-existent message id.
try:
obj.send_reply(-1, None)
except DisconnectedError:
pass
continue
for fd in r:
obj = map.get(fd)
if obj is None:
continue
read(obj)
for fd in w:
obj = map.get(fd)
if obj is None:
continue
write(obj)
for fd in e:
obj = map.get(fd)
if obj is None:
continue
_exception(obj)
except:
if map:
try:
logging.getLogger(__name__+'.client_loop').critical(
'A ZEO client loop failed.',
exc_info=sys.exc_info())
except:
pass
for fd, obj in map.items():
if not hasattr(obj, 'mgr'):
continue
try:
obj.mgr.client.close()
except:
map.pop(fd, None)
try:
logging.getLogger(__name__+'.client_loop'
).critical(
"Couldn't close a dispatcher.",
exc_info=sys.exc_info())
except:
pass
class ConnectionManager(object):
"""Keeps a connection up over time"""
sync_wait = 30
def __init__(self, addrs, client, tmin=1, tmax=180):
self.client = client
self._start_asyncore_loop()
self.addrlist = self._parse_addrs(addrs)
self.tmin = min(tmin, tmax)
self.tmax = tmax
self.cond = threading.Condition(threading.Lock())
self.connection = None # Protected by self.cond
self.closed = 0
# If thread is not None, then there is a helper thread
# attempting to connect.
self.thread = None # Protected by self.cond
def new_addrs(self, addrs):
self.addrlist = self._parse_addrs(addrs)
def _start_asyncore_loop(self):
self.map = {}
self.trigger = trigger.trigger(self.map)
self.loop_thread = threading.Thread(
name="%s zeo client networking thread" % self.client.__name__,
target=client_loop, args=(self.map,))
self.loop_thread.setDaemon(True)
self.loop_thread.start()
def __repr__(self):
return "<%s for %s>" % (self.__class__.__name__, self.addrlist)
def _parse_addrs(self, addrs):
# Return a list of (addr_type, addr) pairs.
# For backwards compatibility (and simplicity?) the
# constructor accepts a single address in the addrs argument --
# a string for a Unix domain socket or a 2-tuple with a
# hostname and port. It can also accept a list of such addresses.
addr_type = self._guess_type(addrs)
if addr_type is not None:
return [(addr_type, addrs)]
else:
addrlist = []
for addr in addrs:
addr_type = self._guess_type(addr)
if addr_type is None:
raise ValueError("unknown address in list: %s" % repr(addr))
addrlist.append((addr_type, addr))
return addrlist
def _guess_type(self, addr):
if isinstance(addr, str):
return socket.AF_UNIX
if (len(addr) == 2
and isinstance(addr[0], str)
and isinstance(addr[1], int)):
return socket.AF_INET # also denotes IPv6
# not anything I know about
return None
def close(self):
"""Prevent ConnectionManager from opening new connections"""
self.closed = 1
self.cond.acquire()
try:
t = self.thread
self.thread = None
finally:
self.cond.release()
if t is not None:
log("CM.close(): stopping and joining thread")
t.stop()
t.join(30)
if t.isAlive():
log("CM.close(): self.thread.join() timed out",
level=logging.WARNING)
for fd, obj in list(self.map.items()):
if obj is not self.trigger:
try:
obj.close()
except:
logging.getLogger(__name__+'.'+self.__class__.__name__
).critical(
"Couldn't close a dispatcher.",
exc_info=sys.exc_info())
self.map.clear()
self.trigger.pull_trigger()
try:
self.loop_thread.join(9)
except RuntimeError:
pass # we are the thread :)
self.trigger.close()
def attempt_connect(self):
"""Attempt a connection to the server without blocking too long.
There isn't a crisp definition for too long. When a
ClientStorage is created, it attempts to connect to the
server. If the server isn't immediately available, it can
operate from the cache. This method will start the background
connection thread and wait a little while to see if it
finishes quickly.
"""
# Will a single attempt take too long?
# Answer: it depends -- normally, you'll connect or get a
# connection refused error very quickly. Packet-eating
# firewalls and other mishaps may cause the connect to take a
# long time to time out though. It's also possible that you
# connect quickly to a slow server, and the attempt includes
# at least one roundtrip to the server (the register() call).
# But that's as fast as you can expect it to be.
self.connect()
self.cond.acquire()
try:
t = self.thread
conn = self.connection
finally:
self.cond.release()
if t is not None and conn is None:
event = t.one_attempt
event.wait()
self.cond.acquire()
try:
conn = self.connection
finally:
self.cond.release()
return conn is not None
def connect(self, sync=0):
self.cond.acquire()
try:
if self.connection is not None:
return
t = self.thread
if t is None:
log("CM.connect(): starting ConnectThread")
self.thread = t = ConnectThread(self, self.client)
t.setDaemon(1)
t.start()
if sync:
while self.connection is None and t.isAlive():
self.cond.wait(self.sync_wait)
if self.connection is None:
log("CM.connect(sync=1): still waiting...")
assert self.connection is not None
finally:
self.cond.release()
def connect_done(self, conn, preferred):
# Called by ConnectWrapper.notify_client() after notifying the client
log("CM.connect_done(preferred=%s)" % preferred)
self.cond.acquire()
try:
self.connection = conn
if preferred:
self.thread = None
self.cond.notifyAll() # Wake up connect(sync=1)
finally:
self.cond.release()
def close_conn(self, conn):
# Called by the connection when it is closed
self.cond.acquire()
try:
if conn is not self.connection:
# Closing a non-current connection
log("CM.close_conn() non-current", level=BLATHER)
return
log("CM.close_conn()")
self.connection = None
finally:
self.cond.release()
self.client.notifyDisconnected()
if not self.closed:
self.connect()
def is_connected(self):
self.cond.acquire()
try:
return self.connection is not None
finally:
self.cond.release()
# When trying to do a connect on a non-blocking socket, some outcomes
# are expected. Set _CONNECT_IN_PROGRESS to the errno value(s) expected
# when an initial connect can't complete immediately. Set _CONNECT_OK
# to the errno value(s) expected if the connect succeeds *or* if it's
# already connected (our code can attempt redundant connects).
if hasattr(errno, "WSAEWOULDBLOCK"): # Windows
# Caution: The official Winsock docs claim that WSAEALREADY should be
# treated as yet another "in progress" indicator, but we've never
# seen this.
_CONNECT_IN_PROGRESS = (errno.WSAEWOULDBLOCK,)
# Win98: WSAEISCONN; Win2K: WSAEINVAL
_CONNECT_OK = (0, errno.WSAEISCONN, errno.WSAEINVAL)
else: # Unix
_CONNECT_IN_PROGRESS = (errno.EINPROGRESS,)
_CONNECT_OK = (0, errno.EISCONN)
class ConnectThread(threading.Thread):
"""Thread that tries to connect to server given one or more addresses.
The thread is passed a ConnectionManager and the manager's client
as arguments. It calls testConnection() on the client when a
socket connects; that should return 1 or 0 indicating whether this
is a preferred or a fallback connection. It may also raise an
exception, in which case the connection is abandoned.
The thread will continue to run, attempting connections, until a
preferred connection is seen and successfully handed over to the
manager and client.
As soon as testConnection() finds a preferred connection, or after
all sockets have been tried and at least one fallback connection
has been seen, notifyConnected(connection) is called on the client
and connect_done() on the manager. If this was a preferred
connection, the thread then exits; otherwise, it keeps trying
until it gets a preferred connection, and then reconnects the
client using that connection.
"""
__super_init = threading.Thread.__init__
# We don't expect clients to call any methods of this Thread other
# than close() and those defined by the Thread API.
def __init__(self, mgr, client):
self.__super_init(name="Connect(%s)" % mgr.addrlist)
self.mgr = mgr
self.client = client
self.stopped = 0
self.one_attempt = threading.Event()
# A ConnectThread keeps track of whether it has finished a
# call to try_connecting(). This allows the ConnectionManager
# to make an attempt to connect right away, but not block for
# too long if the server isn't immediately available.
def stop(self):
self.stopped = 1
def run(self):
delay = self.mgr.tmin
success = 0
# Don't wait too long the first time.
# TODO: make timeout configurable?
attempt_timeout = 5
while not self.stopped:
success = self.try_connecting(attempt_timeout)
if not self.one_attempt.isSet():
self.one_attempt.set()
attempt_timeout = 75
if success > 0:
break
time.sleep(delay)
if self.mgr.is_connected():
log("CT: still trying to replace fallback connection",
level=logging.INFO)
delay = min(delay*2, self.mgr.tmax)
log("CT: exiting thread: %s" % self.getName())
def try_connecting(self, timeout):
"""Try connecting to all self.mgr.addrlist addresses.
Return 1 if a preferred connection was found; 0 if no
connection was found; and -1 if a fallback connection was
found.
If no connection is found within timeout seconds, return 0.
"""
log("CT: attempting to connect on %d sockets" % len(self.mgr.addrlist))
deadline = time.time() + timeout
wrappers = self._create_wrappers()
for wrap in wrappers.keys():
if wrap.state == "notified":
return 1
try:
if time.time() > deadline:
return 0
r = self._connect_wrappers(wrappers, deadline)
if r is not None:
return r
if time.time() > deadline:
return 0
r = self._fallback_wrappers(wrappers, deadline)
if r is not None:
return r
# Alas, no luck.
assert not wrappers
finally:
for wrap in wrappers.keys():
wrap.close()
del wrappers
return 0
def _expand_addrlist(self):
for domain, addr in self.mgr.addrlist:
# AF_INET really means either IPv4 or IPv6, possibly
# indirected by DNS. By design, DNS lookup is deferred
# until connections get established, so that DNS
# reconfiguration can affect failover
if domain == socket.AF_INET:
host, port = addr
for (family, socktype, proto, cannoname, sockaddr
) in socket.getaddrinfo(host or 'localhost', port,
socket.AF_INET,
socket.SOCK_STREAM
): # prune non-TCP results
# for IPv6, drop flowinfo, and restrict addresses
# to [host]:port
yield family, sockaddr[:2]
else:
yield domain, addr
def _create_wrappers(self):
# Create socket wrappers
wrappers = {} # keys are active wrappers
for domain, addr in self._expand_addrlist():
wrap = ConnectWrapper(domain, addr, self.mgr, self.client)
wrap.connect_procedure()
if wrap.state == "notified":
for w in wrappers.keys():
w.close()
return {wrap: wrap}
if wrap.state != "closed":
wrappers[wrap] = wrap
return wrappers
def _connect_wrappers(self, wrappers, deadline):
# Next wait until they all actually connect (or fail)
# The deadline is necessary, because we'd wait forever if a
# sockets never connects or fails.
while wrappers:
if self.stopped:
for wrap in wrappers.keys():
wrap.close()
return 0
# Select connecting wrappers
connecting = [wrap
for wrap in wrappers.keys()
if wrap.state == "connecting"]
if not connecting:
break
if time.time() > deadline:
break
try:
r, w, x = select.select([], connecting, connecting, 1.0)
log("CT: select() %d, %d, %d" % tuple(map(len, (r,w,x))))
except select.error as msg:
log("CT: select failed; msg=%s" % str(msg),
level=logging.WARNING)
continue
# Exceptable wrappers are in trouble; close these suckers
for wrap in x:
log("CT: closing troubled socket %s" % str(wrap.addr))
del wrappers[wrap]
wrap.close()
# Writable sockets are connected
for wrap in w:
wrap.connect_procedure()
if wrap.state == "notified":
del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys():
wrap.close()
return 1
if wrap.state == "closed":
del wrappers[wrap]
def _fallback_wrappers(self, wrappers, deadline):
# If we've got wrappers left at this point, they're fallback
# connections. Try notifying them until one succeeds.
for wrap in list(wrappers.keys()):
assert wrap.state == "tested" and wrap.preferred == 0
if self.mgr.is_connected():
wrap.close()
else:
wrap.notify_client()
if wrap.state == "notified":
del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys():
wrap.close()
return -1
assert wrap.state == "closed"
del wrappers[wrap]
# TODO: should check deadline
class ConnectWrapper:
"""An object that handles the connection procedure for one socket.
This is a little state machine with states:
closed
opened
connecting
connected
tested
notified
"""
def __init__(self, domain, addr, mgr, client):
"""Store arguments and create non-blocking socket."""
self.domain = domain
self.addr = addr
self.mgr = mgr
self.client = client
# These attributes are part of the interface
self.state = "closed"
self.sock = None
self.conn = None
self.preferred = 0
log("CW: attempt to connect to %s" % repr(addr))
try:
self.sock = socket.socket(domain, socket.SOCK_STREAM)
except socket.error as err:
log("CW: can't create socket, domain=%s: %s" % (domain, err),
level=logging.ERROR)
self.close()
return
self.sock.setblocking(0)
self.state = "opened"
def connect_procedure(self):
"""Call sock.connect_ex(addr) and interpret result."""
if self.state in ("opened", "connecting"):
try:
err = self.sock.connect_ex(self.addr)
except socket.error as msg:
log("CW: connect_ex(%r) failed: %s" % (self.addr, msg),
level=logging.ERROR)
self.close()
return
log("CW: connect_ex(%s) returned %s" %
(self.addr, errno.errorcode.get(err) or str(err)))
if err in _CONNECT_IN_PROGRESS:
self.state = "connecting"
return
if err not in _CONNECT_OK:
log("CW: error connecting to %s: %s" %
(self.addr, errno.errorcode.get(err) or str(err)),
level=logging.WARNING)
self.close()
return
self.state = "connected"
if self.state == "connected":
self.test_connection()
def test_connection(self):
"""Establish and test a connection at the zrpc level.
Call the client's testConnection(), giving the client a chance
to do app-level check of the connection.
"""
self.conn = ManagedClientConnection(self.sock, self.addr, self.mgr)
self.sock = None # The socket is now owned by the connection
try:
self.preferred = self.client.testConnection(self.conn)
self.state = "tested"
except ReadOnlyError:
log("CW: ReadOnlyError in testConnection (%s)" % repr(self.addr))
self.close()
return
except:
log("CW: error in testConnection (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True)
self.close()
return
if self.preferred:
self.notify_client()
def notify_client(self):
"""Call the client's notifyConnected().
If this succeeds, call the manager's connect_done().
If the client is already connected, we assume it's a fallback
connection, and the new connection must be a preferred
connection. The client will close the old connection.
"""
try:
self.client.notifyConnected(self.conn)
except:
log("CW: error in notifyConnected (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True)
self.close()
return
self.state = "notified"
self.mgr.connect_done(self.conn, self.preferred)
def close(self):
"""Close the socket and reset everything."""
self.state = "closed"
self.mgr = self.client = None
self.preferred = 0
if self.conn is not None:
# Closing the ZRPC connection will eventually close the
# socket, somewhere in asyncore. Guido asks: Why do we care?
self.conn.close()
self.conn = None
if self.sock is not None:
self.sock.close()
self.sock = None
def fileno(self):
return self.sock.fileno()
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import asyncore
import errno
import json
import sys
import threading
import logging
from . import marshal
from . import trigger
from . import smac
from .error import ZRPCError, DisconnectedError
from .log import short_repr, log
from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException
REPLY = ".reply" # message name used for replies
exception_type_type = type(Exception)
debug_zrpc = False
class Delay:
"""Used to delay response to client for synchronous calls.
When a synchronous call is made and the original handler returns
without handling the call, it returns a Delay object that prevents
the mainloop from sending a response.
"""
msgid = conn = sent = None
def set_sender(self, msgid, conn):
self.msgid = msgid
self.conn = conn
def reply(self, obj):
self.sent = 'reply'
self.conn.send_reply(self.msgid, obj)
def error(self, exc_info):
self.sent = 'error'
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.conn.return_error(self.msgid, *exc_info[:2])
def __repr__(self):
return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self), self.msgid, self.conn, self.sent)
class Result(Delay):
def __init__(self, *args):
self.args = args
def set_sender(self, msgid, conn):
reply, callback = self.args
conn.send_reply(msgid, reply, False)
callback()
class MTDelay(Delay):
def __init__(self):
self.ready = threading.Event()
def set_sender(self, *args):
Delay.set_sender(self, *args)
self.ready.set()
def reply(self, obj):
self.ready.wait()
self.conn.call_from_thread(self.conn.send_reply, self.msgid, obj)
def error(self, exc_info):
self.ready.wait()
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.conn.call_from_thread(Delay.error, self, exc_info)
# PROTOCOL NEGOTIATION
#
# The code implementing protocol version 2.0.0 (which is deployed
# in the field and cannot be changed) *only* talks to peers that
# send a handshake indicating protocol version 2.0.0. In that
# version, both the client and the server immediately send out
# their protocol handshake when a connection is established,
# without waiting for their peer, and disconnect when a different
# handshake is receive.
#
# The new protocol uses this to enable new clients to talk to
# 2.0.0 servers. In the new protocol:
#
# The server sends its protocol handshake to the client at once.
#
# The client waits until it receives the server's protocol handshake
# before sending its own handshake. The client sends the lower of its
# own protocol version and the server protocol version, allowing it to
# talk to servers using later protocol versions (2.0.2 and higher) as
# well: the effective protocol used will be the lower of the client
# and server protocol. However, this changed in ZODB 3.3.1 (and
# should have changed in ZODB 3.3) because an older server doesn't
# support MVCC methods required by 3.3 clients.
#
# [Ugly details: In order to treat the first received message (protocol
# handshake) differently than all later messages, both client and server
# start by patching their message_input() method to refer to their
# recv_handshake() method instead. In addition, the client has to arrange
# to queue (delay) outgoing messages until it receives the server's
# handshake, so that the first message the client sends to the server is
# the client's handshake. This multiply-special treatment of the first
# message is delicate, and several asyncore and thread subtleties were
# handled unsafely before ZODB 3.2.6.
# ]
#
# The ZEO modules ClientStorage and ServerStub have backwards
# compatibility code for dealing with the previous version of the
# protocol. The client accepts the old version of some messages,
# and will not send new messages when talking to an old server.
#
# As long as the client hasn't sent its handshake, it can't send
# anything else; output messages are queued during this time.
# (Output can happen because the connection testing machinery can
# start sending requests before the handshake is received.)
#
# UPGRADING FROM ZEO 2.0.0 TO NEWER VERSIONS:
#
# Because a new client can talk to an old server, but not vice
# versa, all clients should be upgraded before upgrading any
# servers. Protocol upgrades beyond 2.0.1 will not have this
# restriction, because clients using protocol 2.0.1 or later can
# talk to both older and newer servers.
#
# No compatibility with protocol version 1 is provided.
# Connection is abstract (it must be derived from). ManagedServerConnection
# and ManagedClientConnection are the concrete subclasses. They need to
# supply a handshake() method appropriate for their role in protocol
# negotiation.
class Connection(smac.SizedMessageAsyncConnection, object):
"""Dispatcher for RPC on object on both sides of socket.
The connection supports synchronous calls, which expect a return,
and asynchronous calls, which do not.
It uses the Marshaller class to handle encoding and decoding of
method calls and arguments. Marshaller uses pickle to encode
arbitrary Python objects. The code here doesn't ever see the wire
format.
A Connection is designed for use in a multithreaded application,
where a synchronous call must block until a response is ready.
A socket connection between a client and a server allows either
side to invoke methods on the other side. The processes on each
end of the socket use a Connection object to manage communication.
The Connection deals with decoded RPC messages. They are
represented as four-tuples containing: msgid, flags, method name,
and a tuple of method arguments.
The msgid starts at zero and is incremented by one each time a
method call message is sent. Each side of the connection has a
separate msgid state.
When one side of the connection (the client) calls a method, it
sends a message with a new msgid. The other side (the server),
replies with a message that has the same msgid, the string
".reply" (the global variable REPLY) as the method name, and the
actual return value in the args position. Note that each side of
the Connection can initiate a call, in which case it will be the
client for that particular call.
The protocol also supports asynchronous calls. The client does
not wait for a return value for an asynchronous call.
If a method call raises an Exception, the exception is propagated
back to the client via the REPLY message. The client side will
raise any exception it receives instead of returning the value to
the caller.
"""
__super_init = smac.SizedMessageAsyncConnection.__init__
__super_close = smac.SizedMessageAsyncConnection.close
__super_setSessionKey = smac.SizedMessageAsyncConnection.setSessionKey
# Protocol history:
#
# Z200 -- Original ZEO 2.0 protocol
#
# Z201 -- Added invalidateTransaction() to client.
# Renamed several client methods.
# Added several sever methods:
# lastTransaction()
# getAuthProtocol() and scheme-specific authentication methods
# getExtensionMethods().
# getInvalidations().
#
# Z303 -- named after the ZODB release 3.3
# Added methods for MVCC:
# loadBefore()
# A Z303 client cannot talk to a Z201 server, because the latter
# doesn't support MVCC. A Z201 client can talk to a Z303 server,
# but because (at least) the type of the root object changed
# from ZODB.PersistentMapping to persistent.mapping, the older
# client can't actually make progress if a Z303 client created,
# or ever modified, the root.
#
# Z308 -- named after the ZODB release 3.8
# Added blob-support server methods:
# sendBlob
# storeBlobStart
# storeBlobChunk
# storeBlobEnd
# storeBlobShared
# Added blob-support client methods:
# receiveBlobStart
# receiveBlobChunk
# receiveBlobStop
#
# Z309 -- named after the ZODB release 3.9
# New server methods:
# restorea, iterator_start, iterator_next,
# iterator_record_start, iterator_record_next,
# iterator_gc
#
# Z310 -- named after the ZODB release 3.10
# New server methods:
# undoa
# Doesn't support undo for older clients.
# Undone oid info returned by vote.
#
# Z3101 -- checkCurrentSerialInTransaction
#
# Z4 -- checkCurrentSerialInTransaction
# No-longer call load.
# Protocol variables:
# Our preferred protocol.
current_protocol = b"Z4"
# If we're a client, an exhaustive list of the server protocols we
# can accept.
servers_we_can_talk_to = [b"Z308", b"Z309", b"Z310", b"Z3101",
current_protocol]
# If we're a server, an exhaustive list of the client protocols we
# can accept.
clients_we_can_talk_to = [
b"Z200", b"Z201", b"Z303", b"Z308", b"Z309", b"Z310", b"Z3101",
current_protocol]
# This is pretty excruciating. Details:
#
# 3.3 server 3.2 client
# server sends Z303 to client
# client computes min(Z303, Z201) == Z201 as the protocol to use
# client sends Z201 to server
# OK, because Z201 is in the server's clients_we_can_talk_to
#
# 3.2 server 3.3 client
# server sends Z201 to client
# client computes min(Z303, Z201) == Z201 as the protocol to use
# Z201 isn't in the client's servers_we_can_talk_to, so client
# raises exception
#
# 3.3 server 3.3 client
# server sends Z303 to client
# client computes min(Z303, Z303) == Z303 as the protocol to use
# Z303 is in the client's servers_we_can_talk_to, so client
# sends Z303 to server
# OK, because Z303 is in the server's clients_we_can_talk_to
# Exception types that should not be logged:
unlogged_exception_types = ()
# Client constructor passes b'C' for tag, server constructor b'S'. This
# is used in log messages, and to determine whether we can speak with
# our peer.
def __init__(self, sock, addr, obj, tag, map=None):
self.obj = None
self.decode = marshal.decode
self.encode = marshal.encode
self.fast_encode = marshal.fast_encode
self.closed = False
self.peer_protocol_version = None # set in recv_handshake()
assert tag in b"CS"
self.tag = tag
self.logger = logging.getLogger('ZEO.zrpc.Connection(%r)' % tag)
if isinstance(addr, tuple):
self.log_label = "(%s:%d) " % addr
else:
self.log_label = "(%s) " % addr
# Supply our own socket map, so that we don't get registered with
# the asyncore socket map just yet. The initial protocol messages
# are treated very specially, and we dare not get invoked by asyncore
# before that special-case setup is complete. Some of that setup
# occurs near the end of this constructor, and the rest is done by
# a concrete subclass's handshake() method. Unfortunately, because
# we ultimately derive from asyncore.dispatcher, it's not possible
# to invoke the superclass constructor without asyncore stuffing
# us into _some_ socket map.
ourmap = {}
self.__super_init(sock, addr, map=ourmap)
# The singleton dict is used in synchronous mode when a method
# needs to call into asyncore to try to force some I/O to occur.
# The singleton dict is a socket map containing only this object.
self._singleton = {self._fileno: self}
# waiting_for_reply is used internally to indicate whether
# a call is in progress. setting a session key is deferred
# until after the call returns.
self.waiting_for_reply = False
self.delay_sesskey = None
self.register_object(obj)
# The first message we see is a protocol handshake. message_input()
# is temporarily replaced by recv_handshake() to treat that message
# specially. revc_handshake() does "del self.message_input", which
# uncovers the normal message_input() method thereafter.
self.message_input = self.recv_handshake
# Server and client need to do different things for protocol
# negotiation, and handshake() is implemented differently in each.
self.handshake()
# Now it's safe to register with asyncore's socket map; it was not
# safe before message_input was replaced, or before handshake() was
# invoked.
# Obscure: in Python 2.4, the base asyncore.dispatcher class grew
# a ._map attribute, which is used instead of asyncore's global
# socket map when ._map isn't None. Because we passed `ourmap` to
# the base class constructor above, in 2.4 asyncore believes we want
# to use `ourmap` instead of the global socket map -- but we don't.
# So we have to replace our ._map with the global socket map, and
# update the global socket map with `ourmap`. Replacing our ._map
# isn't necessary before Python 2.4, but doesn't hurt then (it just
# gives us an unused attribute in 2.3); updating the global socket
# map is necessary regardless of Python version.
if map is None:
map = asyncore.socket_map
self._map = map
map.update(ourmap)
def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.addr)
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__
def log(self, message, level=BLATHER, exc_info=False):
self.logger.log(level, self.log_label + message, exc_info=exc_info)
def close(self):
self.mgr.close_conn(self)
if self.closed:
return
self._singleton.clear()
self.closed = True
self.__super_close()
self.trigger.pull_trigger()
def register_object(self, obj):
"""Register obj as the true object to invoke methods on."""
self.obj = obj
# Subclass must implement. handshake() is called by the constructor,
# near its end, but before self is added to asyncore's socket map.
# When a connection is created the first message sent is a 4-byte
# protocol version. This allows the protocol to evolve over time, and
# lets servers handle clients using multiple versions of the protocol.
# In general, the server's handshake() just needs to send the server's
# preferred protocol; the client's also needs to queue (delay) outgoing
# messages until it sees the handshake from the server.
def handshake(self):
raise NotImplementedError
# Replaces message_input() for the first message received. Records the
# protocol sent by the peer in `peer_protocol_version`, restores the
# normal message_input() method, and raises an exception if the peer's
# protocol is unacceptable. That's all the server needs to do. The
# client needs to do additional work in response to the server's
# handshake, and extends this method.
def recv_handshake(self, proto):
# Extended by ManagedClientConnection.
del self.message_input # uncover normal-case message_input()
self.peer_protocol_version = proto
if self.tag == b'C':
good_protos = self.servers_we_can_talk_to
else:
assert self.tag == b'S'
good_protos = self.clients_we_can_talk_to
if proto in good_protos:
self.log("received handshake %r" % proto, level=logging.INFO)
else:
self.log("bad handshake %s" % short_repr(proto),
level=logging.ERROR)
raise ZRPCError("bad handshake %r" % proto)
def message_input(self, message):
"""Decode an incoming message and dispatch it"""
# If something goes wrong during decoding, the marshaller
# will raise an exception. The exception will ultimately
# result in asycnore calling handle_error(), which will
# close the connection.
msgid, async, name, args = self.decode(message)
if debug_zrpc:
self.log("recv msg: %s, %s, %s, %s" % (msgid, async, name,
short_repr(args)),
level=TRACE)
if name == 'loadEx':
# Special case and inline the heck out of load case:
try:
ret = self.obj.loadEx(*args)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as msg:
if not isinstance(msg, self.unlogged_exception_types):
self.log("%s() raised exception: %s" % (name, msg),
logging.ERROR, exc_info=True)
self.return_error(msgid, *sys.exc_info()[:2])
else:
try:
self.message_output(self.fast_encode(msgid, 0, REPLY, ret))
self.poll()
except:
# Fall back to normal version for better error handling
self.send_reply(msgid, ret)
elif name == REPLY:
assert not async
self.handle_reply(msgid, args)
else:
self.handle_request(msgid, async, name, args)
def handle_request(self, msgid, async, name, args):
obj = self.obj
if name.startswith('_') or not hasattr(obj, name):
if obj is None:
if debug_zrpc:
self.log("no object calling %s%s"
% (name, short_repr(args)),
level=logging.DEBUG)
return
msg = "Invalid method name: %s on %s" % (name, repr(obj))
raise ZRPCError(msg)
if debug_zrpc:
self.log("calling %s%s" % (name, short_repr(args)),
level=logging.DEBUG)
meth = getattr(obj, name)
try:
self.waiting_for_reply = True
try:
ret = meth(*args)
finally:
self.waiting_for_reply = False
except (SystemExit, KeyboardInterrupt):
raise
except Exception as msg:
if not isinstance(msg, self.unlogged_exception_types):
self.log("%s() raised exception: %s" % (name, msg),
logging.ERROR, exc_info=True)
error = sys.exc_info()[:2]
if async:
self.log("Asynchronous call raised exception: %s" % self,
level=logging.ERROR, exc_info=True)
else:
self.return_error(msgid, *error)
return
if async:
if ret is not None:
raise ZRPCError("async method %s returned value %s" %
(name, short_repr(ret)))
else:
if debug_zrpc:
self.log("%s returns %s" % (name, short_repr(ret)),
logging.DEBUG)
if isinstance(ret, Delay):
ret.set_sender(msgid, self)
else:
self.send_reply(msgid, ret, not self.delay_sesskey)
if self.delay_sesskey:
self.__super_setSessionKey(self.delay_sesskey)
self.delay_sesskey = None
def return_error(self, msgid, err_type, err_value):
# Note that, ideally, this should be defined soley for
# servers, but a test arranges to get it called on
# a client. Too much trouble to fix it now. :/
if not isinstance(err_value, Exception):
err_value = err_type, err_value
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, (err_type, err_value))
except: # see above
try:
r = short_repr(err_value)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle error %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
self.poll()
def handle_error(self):
if sys.exc_info()[0] == SystemExit:
raise sys.exc_info()
self.log("Error caught in asyncore",
level=logging.ERROR, exc_info=True)
self.close()
def setSessionKey(self, key):
if self.waiting_for_reply:
self.delay_sesskey = key
else:
self.__super_setSessionKey(key)
def send_call(self, method, args, async=False):
# send a message and return its msgid
if async:
msgid = 0
else:
msgid = self._new_msgid()
if debug_zrpc:
self.log("send msg: %d, %d, %s, ..." % (msgid, async, method),
level=TRACE)
buf = self.encode(msgid, async, method, args)
self.message_output(buf)
return msgid
def callAsync(self, method, *args):
if self.closed:
raise DisconnectedError()
self.send_call(method, args, 1)
self.poll()
def callAsyncNoPoll(self, method, *args):
# Like CallAsync but doesn't poll. This exists so that we can
# send invalidations atomically to all clients without
# allowing any client to sneak in a load request.
if self.closed:
raise DisconnectedError()
self.send_call(method, args, 1)
def callAsyncNoSend(self, method, *args):
# Like CallAsync but doesn't poll. This exists so that we can
# send invalidations atomically to all clients without
# allowing any client to sneak in a load request.
if self.closed:
raise DisconnectedError()
self.send_call(method, args, 1)
self.call_from_thread()
def callAsyncIterator(self, iterator):
"""Queue a sequence of calls using an iterator
The calls will not be interleaved with other calls from the same
client.
"""
self.message_output(self.encode(0, 1, method, args)
for method, args in iterator)
def handle_reply(self, msgid, ret):
assert msgid == -1 and ret is None
def poll(self):
"""Invoke asyncore mainloop to get pending message out."""
if debug_zrpc:
self.log("poll()", level=TRACE)
self.trigger.pull_trigger()
# import cProfile, time
class ManagedServerConnection(Connection):
"""Server-side Connection subclass."""
# Exception types that should not be logged:
unlogged_exception_types = (ZODB.POSException.POSKeyError, )
def __init__(self, sock, addr, obj, mgr):
self.mgr = mgr
map = {}
Connection.__init__(self, sock, addr, obj, b'S', map=map)
self.decode = marshal.server_decode
self.trigger = trigger.trigger(map)
self.call_from_thread = self.trigger.pull_trigger
t = threading.Thread(target=server_loop, args=(map,))
t.setName("ManagedServerConnection thread")
t.setDaemon(True)
t.start()
# self.profile = cProfile.Profile()
# def message_input(self, message):
# self.profile.enable()
# try:
# Connection.message_input(self, message)
# finally:
# self.profile.disable()
def handshake(self):
# Send the server's preferred protocol to the client.
self.message_output(self.current_protocol)
def recv_handshake(self, proto):
if proto == b'ruok':
self.message_output(json.dumps(self.mgr.ruok()).encode("ascii"))
self.poll()
Connection.close(self)
else:
Connection.recv_handshake(self, proto)
self.obj.notifyConnected(self)
def close(self):
self.obj.notifyDisconnected()
Connection.close(self)
# self.profile.dump_stats(str(time.time())+'.stats')
def send_reply(self, msgid, ret, immediately=True):
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, ret)
except: # see above
try:
r = short_repr(ret)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
if immediately:
self.poll()
poll = smac.SizedMessageAsyncConnection.handle_write
def server_loop(map):
while len(map) > 1:
try:
asyncore.poll(30.0, map)
except Exception as v:
if v.args[0] != errno.EBADF:
raise
for o in tuple(map.values()):
o.close()
class ManagedClientConnection(Connection):
"""Client-side Connection subclass."""
__super_init = Connection.__init__
base_message_output = Connection.message_output
def __init__(self, sock, addr, mgr):
self.mgr = mgr
# We can't use the base smac's message_output directly because the
# client needs to queue outgoing messages until it's seen the
# initial protocol handshake from the server. So we have our own
# message_ouput() method, and support for initial queueing. This is
# a delicate design, requiring an output mutex to be wholly
# thread-safe.
# Caution: we must set this up before calling the base class
# constructor, because the latter registers us with asyncore;
# we need to guarantee that we'll queue outgoing messages before
# asyncore learns about us.
self.output_lock = threading.Lock()
self.queue_output = True
self.queued_messages = []
# msgid_lock guards access to msgid
self.msgid = 0
self.msgid_lock = threading.Lock()
# replies_cond is used to block when a synchronous call is
# waiting for a response
self.replies_cond = threading.Condition()
self.replies = {}
self.__super_init(sock, addr, None, tag=b'C', map=mgr.map)
self.trigger = mgr.trigger
self.call_from_thread = self.trigger.pull_trigger
self.call_from_thread()
def close(self):
Connection.close(self)
self.replies_cond.acquire()
self.replies_cond.notifyAll()
self.replies_cond.release()
# Our message_ouput() queues messages until recv_handshake() gets the
# protocol handshake from the server.
def message_output(self, message):
self.output_lock.acquire()
try:
if self.queue_output:
self.queued_messages.append(message)
else:
assert not self.queued_messages
self.base_message_output(message)
finally:
self.output_lock.release()
def handshake(self):
# The client waits to see the server's handshake. Outgoing messages
# are queued for the duration. The client will send its own
# handshake after the server's handshake is seen, in recv_handshake()
# below. It will then send any messages queued while waiting.
assert self.queue_output # the constructor already set this
def recv_handshake(self, proto):
# The protocol to use is the older of our and the server's preferred
# protocols.
proto = min(proto, self.current_protocol)
# Restore the normal message_input method, and raise an exception
# if the protocol version is too old.
Connection.recv_handshake(self, proto)
# Tell the server the protocol in use, then send any messages that
# were queued while waiting to hear the server's protocol, and stop
# queueing messages.
self.output_lock.acquire()
try:
self.base_message_output(proto)
for message in self.queued_messages:
self.base_message_output(message)
self.queued_messages = []
self.queue_output = False
finally:
self.output_lock.release()
def _new_msgid(self):
self.msgid_lock.acquire()
try:
msgid = self.msgid
self.msgid = self.msgid + 1
return msgid
finally:
self.msgid_lock.release()
def call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply."""
if debug_zrpc:
self.log("wait(%d)" % msgid, level=TRACE)
self.trigger.pull_trigger()
self.replies_cond.acquire()
try:
while 1:
if self.closed:
raise DisconnectedError()
reply = self.replies.get(msgid, self)
if reply is not self:
del self.replies[msgid]
if debug_zrpc:
self.log("wait(%d): reply=%s" %
(msgid, short_repr(reply)), level=TRACE)
return reply
self.replies_cond.wait()
finally:
self.replies_cond.release()
# For testing purposes, it is useful to begin a synchronous call
# but not block waiting for its response.
def _deferred_call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
self.trigger.pull_trigger()
return msgid
def _deferred_wait(self, msgid):
r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def handle_reply(self, msgid, args):
if debug_zrpc:
self.log("recv reply: %s, %s"
% (msgid, short_repr(args)), level=TRACE)
self.replies_cond.acquire()
try:
self.replies[msgid] = args
self.replies_cond.notifyAll()
finally:
self.replies_cond.release()
def send_reply(self, msgid, ret):
# Whimper. Used to send heartbeat
assert msgid == -1 and ret is None
self.message_output(b'(J\xff\xff\xff\xffK\x00U\x06.replyNt.')
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
from ZODB import POSException
from ZEO.Exceptions import ClientDisconnected
class ZRPCError(POSException.StorageError):
pass
class DisconnectedError(ZRPCError, ClientDisconnected):
"""The database storage is disconnected from the storage server.
The error occurred because a problem in the low-level RPC connection,
or because the connection was closed.
"""
# This subclass is raised when zrpc catches the error.
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import os
import threading
import logging
from ZODB.loglevels import BLATHER
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
logger = logging.getLogger('ZEO.zrpc')
_label = "%s" % os.getpid()
def new_label():
global _label
_label = str(os.getpid())
def log(message, level=BLATHER, label=None, exc_info=False):
label = label or _label
if LOG_THREAD_ID:
label = label + ':' + threading.currentThread().getName()
logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info)
REPR_LIMIT = 60
def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes."
# Some of the objects being repr'd are large strings. A lot of memory
# would be wasted to repr them and then truncate, so they are treated
# specially in this function.
# Also handle short repr of a tuple containing a long string.
# This strategy works well for arguments to StorageServer methods.
# The oid is usually first and will get included in its entirety.
# The pickle is near the beginning, too, and you can often fit the
# module name in the pickle.
if isinstance(obj, str):
if len(obj) > REPR_LIMIT:
r = repr(obj[:REPR_LIMIT])
else:
r = repr(obj)
if len(r) > REPR_LIMIT:
r = r[:REPR_LIMIT-4] + '...' + r[-1]
return r
elif isinstance(obj, (list, tuple)):
elts = []
size = 0
for elt in obj:
r = short_repr(elt)
elts.append(r)
size += len(r)
if size > REPR_LIMIT:
break
if isinstance(obj, tuple):
r = "(%s)" % (", ".join(elts))
else:
r = "[%s]" % (", ".join(elts))
else:
r = repr(obj)
if len(r) > REPR_LIMIT:
return r[:REPR_LIMIT] + '...'
else:
return r
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import logging
from ZEO._compat import Unpickler, Pickler, BytesIO, PY3, PYPY
from .error import ZRPCError
from .log import log, short_repr
def encode(*args): # args: (msgid, flags, name, args)
# (We used to have a global pickler, but that's not thread-safe. :-( )
# It's not thread safe if, in the couse of pickling, we call the
# Python interpeter, which releases the GIL.
# Note that args may contain very large binary pickles already; for
# this reason, it's important to use proto 1 (or higher) pickles here
# too. For a long time, this used proto 0 pickles, and that can
# bloat our pickle to 4x the size (due to high-bit and control bytes
# being represented by \xij escapes in proto 0).
# Undocumented: cPickle.Pickler accepts a lone protocol argument;
# pickle.py does not.
if PY3:
# XXX: Py3: Needs optimization.
f = BytesIO()
pickler = Pickler(f, 3)
pickler.fast = 1
pickler.dump(args)
res = f.getvalue()
return res
else:
pickler = Pickler(1)
pickler.fast = 1
# Only CPython's cPickle supports dumping
# and returning in one operation:
# return pickler.dump(args, 1)
# For PyPy we must return the value; fortunately this
# works the same on CPython and is no more expensive
pickler.dump(args)
return pickler.getvalue()
if PY3:
# XXX: Py3: Needs optimization.
fast_encode = encode
elif PYPY:
# can't use the python-2 branch, need a new pickler
# every time, getvalue() only works once
fast_encode = encode
else:
def fast_encode():
# Only use in cases where you *know* the data contains only basic
# Python objects
pickler = Pickler(1)
pickler.fast = 1
dump = pickler.dump
def fast_encode(*args):
return dump(args, 1)
return fast_encode
fast_encode = fast_encode()
def decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global
try:
unpickler.find_class = find_global # PyPy, zodbpickle, the non-c-accelerated version
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR)
raise
def server_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global
try:
unpickler.find_class = server_find_global # PyPy, zodbpickle, the non-c-accelerated version
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR)
raise
_globals = globals()
_silly = ('__doc__',)
exception_type_type = type(Exception)
def find_global(module, name):
"""Helper for message unpickler"""
try:
m = __import__(module, _globals, _globals, _silly)
except ImportError as msg:
raise ZRPCError("import error %s: %s" % (module, msg))
try:
r = getattr(m, name)
except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0)
if safe:
return r
# TODO: is there a better way to do this?
if type(r) == exception_type_type and issubclass(r, Exception):
return r
raise ZRPCError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name):
"""Helper for message unpickler"""
try:
if module != 'ZopeUndo.Prefix':
raise ImportError
m = __import__(module, _globals, _globals, _silly)
except ImportError as msg:
raise ZRPCError("import error %s: %s" % (module, msg))
try:
r = getattr(m, name)
except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name))
return r
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import asyncore
import socket
# _has_dualstack: True if the dual-stack sockets are supported
try:
# Check whether IPv6 sockets can be created
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
except (socket.error, AttributeError):
_has_dualstack = False
else:
# Check whether enabling dualstack (disabling v6only) works
try:
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
except (socket.error, AttributeError):
_has_dualstack = False
else:
_has_dualstack = True
s.close()
del s
from .connection import Connection
from .log import log
from .log import logger
import logging
# Export the main asyncore loop
loop = asyncore.loop
class Dispatcher(asyncore.dispatcher):
"""A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__
def __init__(self, addr, factory=Connection, map=None):
self.__super_init(map=map)
self.addr = addr
self.factory = factory
self._open_socket()
def _open_socket(self):
if type(self.addr) == tuple:
if self.addr[0] == '' and _has_dualstack:
# Wildcard listen on all interfaces, both IPv4 and
# IPv6 if possible
self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
self.socket.setsockopt(
socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
elif ':' in self.addr[0]:
self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
if _has_dualstack:
# On Linux, IPV6_V6ONLY is off by default.
# If the user explicitly asked for IPv6, don't bind to IPv4
self.socket.setsockopt(
socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
else:
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
else:
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.set_reuse_addr()
log("listening on %s" % str(self.addr), logging.INFO)
for i in range(25):
try:
self.bind(self.addr)
except Exception as exc:
log("bind failed %s waiting", i)
if i == 24:
raise
else:
time.sleep(5)
else:
break
self.listen(5)
def writable(self):
return 0
def readable(self):
return 1
def handle_accept(self):
try:
sock, addr = self.accept()
except socket.error as msg:
log("accepted failed: %s" % msg)
return
# We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below,
# quick_close_doesnt_kill_server, causes addr to be None and
# we'd have to write a test for the non-None case, which is
# *even* harder to provoke. :/ So we'll leave things as they
# are for now.
# It might be better to check whether the socket has been
# closed, but I don't see a way to do that. :(
# Drop flow-info from IPv6 addresses
if addr: # Sometimes None on Mac. See above.
addr = addr[:2]
try:
c = self.factory(sock, addr)
except:
if sock.fileno() in asyncore.socket_map:
del asyncore.socket_map[sock.fileno()]
logger.exception("Error in handle_accept")
else:
log("connect from %s: %s" % (repr(addr), c))
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Sized Message Async Connections.
This class extends the basic asyncore layer with a record-marking
layer. The message_output() method accepts an arbitrary sized string
as its argument. It sends over the wire the length of the string
encoded using struct.pack('>I') and the string itself. The receiver
passes the original string to message_input().
This layer also supports an optional message authentication code
(MAC). If a session key is present, it uses HMAC-SHA-1 to generate a
20-byte MAC. If a MAC is present, the high-order bit of the length
is set to 1 and the MAC immediately follows the length.
"""
import asyncore
import errno
import six
try:
import hmac
except ImportError:
from . import _hmac as hmac
import socket
import struct
import threading
from .log import log
from .error import DisconnectedError
from .. import hash as ZEO_hash
# Use the dictionary to make sure we get the minimum number of errno
# entries. We expect that EWOULDBLOCK == EAGAIN on most systems --
# or that only one is actually used.
tmp_dict = {errno.EWOULDBLOCK: 0,
errno.EAGAIN: 0,
errno.EINTR: 0,
}
expected_socket_read_errors = tuple(tmp_dict.keys())
tmp_dict = {errno.EAGAIN: 0,
errno.EWOULDBLOCK: 0,
errno.ENOBUFS: 0,
errno.EINTR: 0,
}
expected_socket_write_errors = tuple(tmp_dict.keys())
del tmp_dict
# We chose 60000 as the socket limit by looking at the largest strings
# that we could pass to send() without blocking.
SEND_SIZE = 60000
MAC_BIT = 0x80000000
_close_marker = object()
class SizedMessageAsyncConnection(asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close
__closed = True # Marker indicating that we're closed
socket = None # to outwit Sam's getattr
def __init__(self, sock, addr, map=None):
self.addr = addr
# __input_lock protects __inp, __input_len, __state, __msg_size
self.__input_lock = threading.Lock()
self.__inp = None # None, a single String, or a list
self.__input_len = 0
# Instance variables __state, __msg_size and __has_mac work together:
# when __state == 0:
# __msg_size == 4, and the next thing read is a message size;
# __has_mac is set according to the MAC_BIT in the header
# when __state == 1:
# __msg_size is variable, and the next thing read is a message.
# __has_mac indicates if we're in MAC mode or not (and
# therefore, if we need to check the mac header)
# The next thing read is always of length __msg_size.
# The state alternates between 0 and 1.
self.__state = 0
self.__has_mac = 0
self.__msg_size = 4
self.__output_messages = []
self.__output = []
self.__closed = False
# Each side of the connection sends and receives messages. A
# MAC is generated for each message and depends on each
# previous MAC; the state of the MAC generator depends on the
# history of operations it has performed. So the MACs must be
# generated in the same order they are verified.
# Each side is guaranteed to receive messages in the order
# they are sent, but there is no ordering constraint between
# message sends and receives. If the two sides are A and B
# and message An indicates the nth message sent by A, then
# A1 A2 B1 B2 and A1 B1 B2 A2 are both legitimate total
# orderings of the messages.
# As a result, there must be seperate MAC generators for each
# side of the connection. If not, the generator state would
# be different after A1 A2 B1 B2 than it would be after
# A1 B1 B2 A2; if the generator state was different, the MAC
# could not be verified.
self.__hmac_send = None
self.__hmac_recv = None
self.__super_init(sock, map)
# asyncore overwrites addr with the getpeername result
# restore our value
self.addr = addr
def setSessionKey(self, sesskey):
log("set session key %r" % sesskey)
# Low-level construction is now delayed until data are sent.
# This is to allow use of iterators that generate messages
# only when we're ready to do I/O so that we can effeciently
# transmit large files. Because we delay messages, we also
# have to delay setting the session key to retain proper
# ordering.
# The low-level output queue supports strings, a special close
# marker, and iterators. It doesn't support callbacks. We
# can create a allback by providing an iterator that doesn't
# yield anything.
# The hack fucntion below is a callback in iterator's
# clothing. :) It never yields anything, but is a generator
# and thus iterator, because it contains a yield statement.
def hack():
self.__hmac_send = hmac.HMAC(sesskey, digestmod=ZEO_hash)
self.__hmac_recv = hmac.HMAC(sesskey, digestmod=ZEO_hash)
if False:
yield b''
self.message_output(hack())
def get_addr(self):
return self.addr
# TODO: avoid expensive getattr calls? Can't remember exactly what
# this comment was supposed to mean, but it has something to do
# with the way asyncore uses getattr and uses if sock:
def __nonzero__(self):
return 1
def handle_read(self):
self.__input_lock.acquire()
try:
# Use a single __inp buffer and integer indexes to make this fast.
try:
d = self.recv(8192)
except socket.error as err:
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_read_errors:
return
raise
if not d:
return
input_len = self.__input_len + len(d)
msg_size = self.__msg_size
state = self.__state
has_mac = self.__has_mac
inp = self.__inp
if msg_size > input_len:
if inp is None:
self.__inp = d
elif isinstance(self.__inp, six.binary_type):
self.__inp = [self.__inp, d]
else:
self.__inp.append(d)
self.__input_len = input_len
return # keep waiting for more input
# load all previous input and d into single string inp
if isinstance(inp, six.binary_type):
inp = inp + d
elif inp is None:
inp = d
else:
inp.append(d)
inp = b"".join(inp)
offset = 0
while (offset + msg_size) <= input_len:
msg = inp[offset:offset + msg_size]
offset = offset + msg_size
if not state:
msg_size = struct.unpack(">I", msg)[0]
has_mac = msg_size & MAC_BIT
if has_mac:
msg_size ^= MAC_BIT
msg_size += 20
elif self.__hmac_send:
raise ValueError("Received message without MAC")
state = 1
else:
msg_size = 4
state = 0
# Obscure: We call message_input() with __input_lock
# held!!! And message_input() may end up calling
# message_output(), which has its own lock. But
# message_output() cannot call message_input(), so
# the locking order is always consistent, which
# prevents deadlock. Also, message_input() may
# take a long time, because it can cause an
# incoming call to be handled. During all this
# time, the __input_lock is held. That's a good
# thing, because it serializes incoming calls.
if has_mac:
mac = msg[:20]
msg = msg[20:]
if self.__hmac_recv:
self.__hmac_recv.update(msg)
_mac = self.__hmac_recv.digest()
if mac != _mac:
raise ValueError("MAC failed: %r != %r"
% (_mac, mac))
else:
log("Received MAC but no session key set")
elif self.__hmac_send:
raise ValueError("Received message without MAC")
self.message_input(msg)
self.__state = state
self.__has_mac = has_mac
self.__msg_size = msg_size
self.__inp = inp[offset:]
self.__input_len = input_len - offset
finally:
self.__input_lock.release()
def readable(self):
return True
def writable(self):
return bool(self.__output_messages or self.__output)
def should_close(self):
self.__output_messages.append(_close_marker)
def handle_write(self):
output = self.__output
messages = self.__output_messages
while output or messages:
# Process queued messages until we have enough output
size = sum((len(s) for s in output))
while (size <= SEND_SIZE) and messages:
message = messages[0]
if isinstance(message, six.binary_type):
size += self.__message_output(messages.pop(0), output)
elif isinstance(message, six.text_type):
# XXX This can silently lead to data loss and client hangs
# if asserts aren't enabled. Encountered this under Python3
# and 'ruok' protocol
assert False, "Got a unicode message: %s" % repr(message)
elif message is _close_marker:
del messages[:]
del output[:]
return self.close()
else:
try:
message = six.advance_iterator(message)
except StopIteration:
messages.pop(0)
else:
assert(isinstance(message, six.binary_type))
size += self.__message_output(message, output)
v = b"".join(output)
del output[:]
try:
n = self.send(v)
except socket.error as err:
# Fix for https://bugs.launchpad.net/zodb/+bug/182833
# ensure the above mentioned "output" invariant
output.insert(0, v)
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_write_errors:
break # we couldn't write anything
raise
if n < len(v):
output.append(v[n:])
break # we can't write any more
def handle_close(self):
self.close()
def message_output(self, message):
if self.__closed:
raise DisconnectedError(
"This action is temporarily unavailable.<p>")
self.__output_messages.append(message)
def __message_output(self, message, output):
# do two separate appends to avoid copying the message string
size = 4
if self.__hmac_send:
output.append(struct.pack(">I", len(message) | MAC_BIT))
self.__hmac_send.update(message)
output.append(self.__hmac_send.digest())
size += 20
else:
output.append(struct.pack(">I", len(message)))
if len(message) <= SEND_SIZE:
output.append(message)
else:
for i in range(0, len(message), SEND_SIZE):
output.append(message[i:i+SEND_SIZE])
return size + len(message)
def close(self):
if not self.__closed:
self.__closed = True
self.__super_close()
from __future__ import print_function
##############################################################################
#
# Copyright (c) 2001-2005 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
from __future__ import with_statement
import asyncore
import os
import socket
import errno
from ZODB.utils import positive_id
from ZEO._compat import thread, get_ident
# Original comments follow; they're hard to follow in the context of
# ZEO's use of triggers. TODO: rewrite from a ZEO perspective.
# Wake up a call to select() running in the main thread.
#
# This is useful in a context where you are using Medusa's I/O
# subsystem to deliver data, but the data is generated by another
# thread. Normally, if Medusa is in the middle of a call to
# select(), new output data generated by another thread will have
# to sit until the call to select() either times out or returns.
# If the trigger is 'pulled' by another thread, it should immediately
# generate a READ event on the trigger object, which will force the
# select() invocation to return.
#
# A common use for this facility: letting Medusa manage I/O for a
# large number of connections; but routing each request through a
# thread chosen from a fixed-size thread pool. When a thread is
# acquired, a transaction is performed, but output data is
# accumulated into buffers that will be emptied more efficiently
# by Medusa. [picture a server that can process database queries
# rapidly, but doesn't want to tie up threads waiting to send data
# to low-bandwidth connections]
#
# The other major feature provided by this class is the ability to
# move work back into the main thread: if you call pull_trigger()
# with a thunk argument, when select() wakes up and receives the
# event it will call your thunk from within that thread. The main
# purpose of this is to remove the need to wrap thread locks around
# Medusa's data structures, which normally do not need them. [To see
# why this is true, imagine this scenario: A thread tries to push some
# new data onto a channel's outgoing data queue at the same time that
# the main thread is trying to remove some]
class _triggerbase(object):
"""OS-independent base class for OS-dependent trigger class."""
kind = None # subclass must set to "pipe" or "loopback"; used by repr
def __init__(self):
self._closed = False
# `lock` protects the `thunks` list from being traversed and
# appended to simultaneously.
self.lock = thread.allocate_lock()
# List of no-argument callbacks to invoke when the trigger is
# pulled. These run in the thread running the asyncore mainloop,
# regardless of which thread pulls the trigger.
self.thunks = []
def readable(self):
return 1
def writable(self):
return 0
def handle_connect(self):
pass
def handle_close(self):
self.close()
# Override the asyncore close() method, because it doesn't know about
# (so can't close) all the gimmicks we have open. Subclass must
# supply a _close() method to do platform-specific closing work. _close()
# will be called iff we're not already closed.
def close(self):
if not self._closed:
self._closed = True
self.del_channel()
self._close() # subclass does OS-specific stuff
def _close(self): # see close() above; subclass must supply
raise NotImplementedError
def pull_trigger(self, *thunk):
if thunk:
with self.lock:
self.thunks.append(thunk)
try:
self._physical_pull()
except Exception:
if not self._closed:
raise
# Subclass must supply _physical_pull, which does whatever the OS
# needs to do to provoke the "write" end of the trigger.
def _physical_pull(self):
raise NotImplementedError
def handle_read(self):
try:
self.recv(8192)
except socket.error:
return
while 1:
with self.lock:
if self.thunks:
thunk = self.thunks.pop(0)
else:
return
try:
thunk[0](*thunk[1:])
except:
nil, t, v, tbinfo = asyncore.compact_traceback()
print(('exception in trigger thunk:'
' (%s:%s %s)' % (t, v, tbinfo)))
def __repr__(self):
return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self))
if os.name == 'posix':
class trigger(_triggerbase, asyncore.file_dispatcher):
kind = "pipe"
def __init__(self, map=None):
_triggerbase.__init__(self)
r, self.trigger = os.pipe()
asyncore.file_dispatcher.__init__(self, r, map)
if self.socket.fd != r:
# Starting in Python 2.6, the descriptor passed to
# file_dispatcher gets duped and assigned to
# self.socket.fd. This breals the instantiation semantics and
# is a bug imo. I dount it will get fixed, but maybe
# it will. Who knows. For that reason, we test for the
# fd changing rather than just checking the Python version.
os.close(r)
def _close(self):
os.close(self.trigger)
asyncore.file_dispatcher.close(self)
def _physical_pull(self):
os.write(self.trigger, b'x')
else:
# Windows version; uses just sockets, because a pipe isn't select'able
# on Windows.
class BindError(Exception):
pass
class trigger(_triggerbase, asyncore.dispatcher):
kind = "loopback"
def __init__(self, map=None):
_triggerbase.__init__(self)
# Get a pair of connected sockets. The trigger is the 'w'
# end of the pair, which is connected to 'r'. 'r' is put
# in the asyncore socket map. "pulling the trigger" then
# means writing something on w, which will wake up r.
w = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up asyncore's
# select() ASAP.
w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1)
try:
w.connect(connect_address)
break # success
except socket.error as detail:
if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
w.close()
raise BindError("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
r, addr = a.accept() # r becomes asyncore's (self.)socket
a.close()
self.trigger = w
asyncore.dispatcher.__init__(self, r, map)
def _close(self):
# self.socket is r, and self.trigger is w, from __init__
self.socket.close()
self.trigger.close()
def _physical_pull(self):
self.trigger.send('x')
......@@ -19,11 +19,13 @@ import random
import sys
import time
import errno
import multiprocessing
import socket
import subprocess
import logging
import tempfile
import six
from six.moves.queue import Empty
import ZODB.tests.util
import zope.testing.setupstack
from ZEO._compat import StringIO
......@@ -32,6 +34,13 @@ logger = logging.getLogger('ZEO.tests.forker')
DEBUG = os.environ.get('ZEO_TEST_SERVER_DEBUG')
ZEO4_SERVER = os.environ.get('ZEO4_SERVER')
skip_if_testing_client_against_zeo4 = (
(lambda func: None)
if ZEO4_SERVER else
(lambda func: func)
)
class ZEOConfig:
"""Class to generate ZEO configuration file. """
......@@ -104,17 +113,24 @@ def runner(config, qin, qout, timeout=None,
))
try:
import ZEO.runzeo, threading
from six.moves.queue import Empty
import threading
options = ZEO.runzeo.ZEOOptions()
if ZEO4_SERVER:
from .ZEO4 import runzeo
else:
from .. import runzeo
options = runzeo.ZEOOptions()
options.realize(['-C', config])
server = ZEO.runzeo.ZEOServer(options)
server = runzeo.ZEOServer(options)
globals()[(name if name else 'last') + '_server'] = server
server.open_storages()
server.clear_socket()
server.create_server()
logger.debug('SERVER CREATED')
if ZEO4_SERVER:
qout.put(server.server.addr)
else:
qout.put(server.server.acceptor.addr)
logger.debug('ADDRESS SENT')
thread = threading.Thread(
......@@ -125,7 +141,7 @@ def runner(config, qin, qout, timeout=None,
thread.start()
try:
qin.get(timeout=timeout)
qin.get(timeout=timeout) # wait for shutdown
except Empty:
pass
server.server.close()
......@@ -140,10 +156,6 @@ def runner(config, qin, qout, timeout=None,
pass
qout.put(thread.is_alive())
qin.get(timeout=11) # ack
if hasattr(qout, 'close'):
qout.close()
qout.cancel_join_thread()
except Exception:
logger.exception("In server thread")
......@@ -156,7 +168,6 @@ def runner(config, qin, qout, timeout=None,
def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None):
qin.put('stop')
dirty = qout.get(timeout=stop_timeout)
qin.put('ack')
if dirty:
print("WARNING SERVER DIDN'T STOP CLEANLY", file=sys.stderr)
......@@ -171,9 +182,6 @@ def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None):
thread.join(stop_timeout)
os.remove(config)
if hasattr(qin, 'close'):
qin.close()
qin.cancel_join_thread()
def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
path='Data.fs', protocol=None, blob_dir=None,
......@@ -222,7 +230,7 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
from six.moves.queue import Queue
else:
from multiprocessing import Process as Thread
from multiprocessing import Queue
Queue = ThreadlessQueue
qin = Queue()
qout = Queue()
......@@ -405,3 +413,17 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG):
def whine(*message):
print(*message, file=sys.stderr)
sys.stderr.flush()
class ThreadlessQueue(object):
def __init__(self):
self.cin, self.cout = multiprocessing.Pipe(False)
def put(self, v):
self.cout.send(v)
def get(self, timeout=None):
if self.cin.poll(timeout):
return self.cin.recv()
else:
raise Empty()
......@@ -27,9 +27,12 @@ if os.environ.get('USE_ZOPE_TESTING_DOCTEST'):
else:
import doctest
import unittest
import ZEO.tests.forker
import ZODB.tests.util
import ZEO
from . import forker
class FileStorageConfig:
def getConfig(self, path, create, read_only):
return """\
......@@ -79,12 +82,6 @@ class MappingStorageConnectionTests(
):
"""Mapping storage connection tests."""
class SSLConnectionTests(
MappingStorageConfig,
ConnectionTests.SSLConnectionTests,
):
pass
# The ReconnectionTests can't work with MappingStorage because it's only an
# in-memory storage and has no persistent state.
......@@ -107,8 +104,9 @@ test_classes = [FileStorageConnectionTests,
FileStorageTimeoutTests,
MappingStorageConnectionTests,
MappingStorageTimeoutTests,
SSLConnectionTests,
]
if not forker.ZEO4_SERVER:
test_classes.append(SSLConnectionTests)
def invalidations_while_connecting():
r"""
......@@ -129,7 +127,7 @@ This tests tries to provoke this bug by:
- opening a client to the server that writes some objects, filling
it's cache at the same time,
>>> import ZODB.tests.MinPO, transaction
>>> import ZEO, ZODB.tests.MinPO, transaction
>>> db = ZEO.DB(addr, client='x')
>>> conn = db.open()
>>> nobs = 1000
......@@ -205,9 +203,9 @@ This tests tries to provoke this bug by:
... record = handler.records.pop(0)
... print(record.name, record.levelname, end=' ')
... print(handler.format(record))
... if bad:
... with open('server-%s.log' % addr[1]) as f:
... print(f.read())
... #if bad:
... # with open('server.log') as f:
... # print(f.read())
... #else:
... # logging.getLogger('ZEO').debug('GOOD %s' % c)
... db.close()
......@@ -236,7 +234,7 @@ def test_suite():
sub = unittest.makeSuite(klass, 'check')
suite.addTest(sub)
suite.addTest(doctest.DocTestSuite(
setUp=ZEO.tests.forker.setUp, tearDown=setupstack.tearDown,
setUp=forker.setUp, tearDown=setupstack.tearDown,
))
suite.layer = ZODB.tests.util.MininalTestLayer('ZEO Connection Tests')
return suite
......@@ -340,7 +340,6 @@ class FileStorageTests(FullGenericTests):
self._storage._info['interfaces']
)
class FileStorageSSLTests(FileStorageTests):
def getZEOConfig(self):
......@@ -1123,6 +1122,7 @@ def convenient_to_pass_port_to_client_and_ZEO_dot_client():
>>> client.close()
"""
@forker.skip_if_testing_client_against_zeo4
def test_server_status():
"""
You can get server status using the server_status method.
......@@ -1147,6 +1147,7 @@ def test_server_status():
>>> db.close()
"""
@forker.skip_if_testing_client_against_zeo4
def test_ruok():
"""
You can also get server status using the ruok protocol.
......@@ -1453,6 +1454,7 @@ class MultiprocessingTests(unittest.TestCase):
conn.close()
zope.testing.setupstack.tearDown(self)
@forker.skip_if_testing_client_against_zeo4
def quick_close_doesnt_kill_server():
r"""
......@@ -1500,9 +1502,11 @@ def can_use_empty_string_for_local_host_on_client():
slow_test_classes = [
BlobAdaptedFileStorageTests, BlobWritableCacheTests,
MappingStorageTests, DemoStorageTests,
FileStorageTests, FileStorageSSLTests,
FileStorageTests,
FileStorageHexTests, FileStorageClientHexTests,
]
if not forker.ZEO4_SERVER:
slow_test_classes.append(FileStorageSSLTests)
quick_test_classes = [FileStorageRecoveryTests, ZRPCConnectionTests]
......@@ -1582,6 +1586,8 @@ def test_suite():
"ClientDisconnected"),
)),
))
if not forker.ZEO4_SERVER:
# ZEO 4 doesn't support client-side conflict resolution
zeo.addTest(unittest.makeSuite(ClientConflictResolutionTests, 'check'))
zeo.layer = ZODB.tests.util.MininalTestLayer('testZeo-misc')
suite.addTest(zeo)
......@@ -1592,7 +1598,16 @@ def test_suite():
'zdoptions.test',
'drop_cache_rather_than_verify.txt', 'client-config.test',
'protocols.test', 'zeo_blob_cache.test', 'invalidation-age.txt',
'dynamic_server_ports.test', '../nagios.rst',
'../nagios.rst',
setUp=forker.setUp, tearDown=zope.testing.setupstack.tearDown,
checker=renormalizing.RENormalizing(patterns),
globs={'print_function': print_function},
),
)
if not forker.ZEO4_SERVER:
zeo.addTest(
doctest.DocFileSuite(
'dynamic_server_ports.test',
setUp=forker.setUp, tearDown=zope.testing.setupstack.tearDown,
checker=renormalizing.RENormalizing(patterns),
globs={'print_function': print_function},
......
......@@ -8,6 +8,9 @@ import unittest
import ZEO.StorageServer
from . import forker
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class ClientAuthTests(setupstack.TestCase):
def setUp(self):
......@@ -50,7 +53,6 @@ class ClientAuthTests(setupstack.TestCase):
stop()
def test_suite():
return unittest.makeSuite(ClientAuthTests)
......@@ -9,6 +9,7 @@ from ZODB.broken import find_global
import ZEO
from . import forker
from .utils import StorageServer
class Var(object):
......@@ -16,10 +17,10 @@ class Var(object):
self.value = other
return True
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
def test_server_side(self):
# First, verify default conflict resolution.
server = StorageServer(self, DemoStorage())
zs = server.zs
......
......@@ -10,6 +10,7 @@ from ..Exceptions import ClientDisconnected
from .. import runzeo
from .testConfig import ZEOConfigTestBase
from . import forker
here = os.path.dirname(__file__)
server_cert = os.path.join(here, 'server.pem')
......@@ -19,6 +20,7 @@ serverpw_key = os.path.join(here, 'serverpw_key.pem')
client_cert = os.path.join(here, 'client.pem')
client_key = os.path.join(here, 'client_key.pem')
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class SSLConfigTest(ZEOConfigTestBase):
def test_ssl_basic(self):
......@@ -114,6 +116,7 @@ class SSLConfigTest(ZEOConfigTestBase):
)
stop()
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
@mock.patch(('asyncio' if PY3 else 'trollius') + '.async')
@mock.patch(('asyncio' if PY3 else 'trollius') + '.set_event_loop')
@mock.patch(('asyncio' if PY3 else 'trollius') + '.new_event_loop')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment