Commit 9e154d03 authored by Jeremy Hylton's avatar Jeremy Hylton

Undo merge of ZEO-ZRPC-Dev branch into the trunk (for now).

parent 48bcb3a7
...@@ -144,20 +144,18 @@ file 0 and file 1. ...@@ -144,20 +144,18 @@ file 0 and file 1.
""" """
__version__ = "$Revision: 1.19 $"[11:-2] __version__ = "$Revision: 1.20 $"[11:-2]
import os, tempfile import os, tempfile
from struct import pack, unpack from struct import pack, unpack
from thread import allocate_lock from thread import allocate_lock
import sys
import zLOG import zLOG
def log(msg, level=zLOG.INFO):
zLOG.LOG("ZEC", level, msg)
magic='ZEC0' magic='ZEC0'
def LOG(msg, level=zLOG.BLATHER):
zLOG.LOG("ZEC", level, msg)
class ClientCache: class ClientCache:
def __init__(self, storage='', size=20000000, client=None, var=None): def __init__(self, storage='', size=20000000, client=None, var=None):
...@@ -213,14 +211,16 @@ class ClientCache: ...@@ -213,14 +211,16 @@ class ClientCache:
f[0].write(magic) f[0].write(magic)
current=0 current=0
log("cache opened. current = %s" % current)
self._limit=size/2 self._limit=size/2
self._current=current self._current=current
def close(self):
try:
self._f[self._current].close()
except (os.error, ValueError):
pass
def open(self): def open(self):
# XXX open is overloaded to perform two tasks for
# optimization reasons
self._acquire() self._acquire()
try: try:
self._index=index={} self._index=index={}
...@@ -235,19 +235,6 @@ class ClientCache: ...@@ -235,19 +235,6 @@ class ClientCache:
return serial.items() return serial.items()
finally: self._release() finally: self._release()
def close(self):
for f in self._f:
if f is not None:
f.close()
def verify(self, verifyFunc):
"""Call the verifyFunc on every object in the cache.
verifyFunc(oid, serialno, version)
"""
for oid, (s, vs) in self.open():
verifyFunc(oid, s, vs)
def invalidate(self, oid, version): def invalidate(self, oid, version):
self._acquire() self._acquire()
try: try:
...@@ -386,6 +373,8 @@ class ClientCache: ...@@ -386,6 +373,8 @@ class ClientCache:
self._f[current]=open(self._p[current],'w+b') self._f[current]=open(self._p[current],'w+b')
else: else:
# Temporary cache file: # Temporary cache file:
if self._f[current] is not None:
self._f[current].close()
self._f[current] = tempfile.TemporaryFile(suffix='.zec') self._f[current] = tempfile.TemporaryFile(suffix='.zec')
self._f[current].write(magic) self._f[current].write(magic)
self._pos=pos=4 self._pos=pos=4
...@@ -394,57 +383,55 @@ class ClientCache: ...@@ -394,57 +383,55 @@ class ClientCache:
def store(self, oid, p, s, version, pv, sv): def store(self, oid, p, s, version, pv, sv):
self._acquire() self._acquire()
try: try: self._store(oid, p, s, version, pv, sv)
self._store(oid, p, s, version, pv, sv) finally: self._release()
finally:
self._release()
def _store(self, oid, p, s, version, pv, sv): def _store(self, oid, p, s, version, pv, sv):
if not s: if not s:
p = '' p=''
s = '\0\0\0\0\0\0\0\0' s='\0\0\0\0\0\0\0\0'
tlen = 31 + len(p) tlen=31+len(p)
if version: if version:
tlen = tlen + len(version) + 12 + len(pv) tlen=tlen+len(version)+12+len(pv)
vlen = len(version) vlen=len(version)
else: else:
vlen = 0 vlen=0
stlen = pack(">I", tlen) pos=self._pos
# accumulate various data to write into a list current=self._current
l = [oid, 'v', stlen, pack(">HI", vlen, len(p)), s] f=self._f[current]
if p: f.seek(pos)
l.append(p) stlen=pack(">I",tlen)
write=f.write
write(oid+'v'+stlen+pack(">HI", vlen, len(p))+s)
if p: write(p)
if version: if version:
l.extend([version, write(version)
pack(">I", len(pv)), write(pack(">I", len(pv)))
pv, sv]) write(pv)
l.append(stlen) write(sv)
f = self._f[self._current]
f.seek(self._pos) write(stlen)
f.write("".join(l))
if self._current:
self._index[oid] = - self._pos
else:
self._index[oid] = self._pos
self._pos += tlen if current: self._index[oid]=-pos
else: self._index[oid]=pos
self._pos=pos+tlen
def read_index(index, serial, f, current): def read_index(index, serial, f, current):
LOG("read_index(%s)" % f.name)
seek=f.seek seek=f.seek
read=f.read read=f.read
pos=4 pos=4
seek(0,2)
size=f.tell()
while 1: while 1:
f.seek(pos) seek(pos)
h=read(27) h=read(27)
if len(h)==27 and h[8] in 'vni': if len(h)==27 and h[8] in 'vni':
tlen, vlen, dlen = unpack(">iHi", h[9:19]) tlen, vlen, dlen = unpack(">iHi", h[9:19])
else: tlen=-1 else:
break
if tlen <= 0 or vlen < 0 or dlen < 0 or vlen+dlen > tlen: if tlen <= 0 or vlen < 0 or dlen < 0 or vlen+dlen > tlen:
break break
...@@ -479,3 +466,15 @@ def read_index(index, serial, f, current): ...@@ -479,3 +466,15 @@ def read_index(index, serial, f, current):
except: pass except: pass
return pos return pos
def main(files):
for file in files:
print file
index = {}
serial = {}
read_index(index, serial, open(file), 0)
print index.keys()
if __name__ == "__main__":
import sys
main(sys.argv[1:])
...@@ -83,167 +83,178 @@ ...@@ -83,167 +83,178 @@
# #
############################################################################## ##############################################################################
"""Network ZODB storage client """Network ZODB storage client
XXX support multiple outstanding requests up until the vote?
XXX is_connected() vis ClientDisconnected error
""" """
__version__='$Revision: 1.36 $'[11:-2]
import cPickle
import os
import socket
import string
import struct
import sys
import tempfile
import thread
import threading
import time
from types import TupleType, StringType
from struct import pack, unpack
import ExtensionClass, Sync, ThreadLock __version__='$Revision: 1.37 $'[11:-2]
import ClientCache
import zrpc2 import struct, time, os, socket, string, Sync, zrpc, ClientCache
import ServerStub import tempfile, Invalidator, ExtensionClass, thread
from TransactionBuffer import TransactionBuffer import ThreadedAsync
from ZODB import POSException now=time.time
from struct import pack, unpack
from ZODB import POSException, BaseStorage
from ZODB.TimeStamp import TimeStamp from ZODB.TimeStamp import TimeStamp
from zLOG import LOG, PROBLEM, INFO, BLATHER from zLOG import LOG, PROBLEM, INFO
from Exceptions import Disconnected
def log2(type, msg, subsys="ClientStorage %d" % os.getpid()): try: from ZODB.ConflictResolution import ResolvedSerial
LOG(subsys, type, msg) except: ResolvedSerial='rs'
try: TupleType=type(())
from ZODB.ConflictResolution import ResolvedSerial
except ImportError:
ResolvedSerial = 'rs'
class ClientStorageError(POSException.StorageError): class ClientStorageError(POSException.StorageError):
"""An error occured in the ZEO Client Storage""" """An error occured in the ZEO Client Storage"""
class UnrecognizedResult(ClientStorageError): class UnrecognizedResult(ClientStorageError):
"""A server call returned an unrecognized result""" """A server call returned an unrecognized result
"""
class ClientDisconnected(ClientStorageError, Disconnected):
"""The database storage is disconnected from the storage."""
def get_timestamp(prev_ts): class ClientDisconnected(ClientStorageError):
t = time.time() """The database storage is disconnected from the storage.
t = apply(TimeStamp, (time.gmtime(t)[:5] + (t % 60,))) """
t = t.laterThan(prev_ts)
return t
class DisconnectedServerStub:
"""Raise ClientDisconnected on all attribute access."""
def __getattr__(self, attr):
raise ClientDisconnected()
disconnected_stub = DisconnectedServerStub() class ClientStorage(ExtensionClass.Base, BaseStorage.BaseStorage):
class ClientStorage: _connected=_async=0
__begin='tpc_begin_sync'
def __init__(self, addr, storage='1', cache_size=20000000, def __init__(self, connection, storage='1', cache_size=20000000,
name='', client='', debug=0, var=None, name='', client='', debug=0, var=None,
min_disconnect_poll=5, max_disconnect_poll=300, min_disconnect_poll=5, max_disconnect_poll=300,
wait_for_server_on_startup=0, read_only=0): wait_for_server_on_startup=1):
self._server = disconnected_stub
self._is_read_only = read_only
self._storage = storage
self._info = {'length': 0, 'size': 0, 'name': 'ZEO Client',
'supportsUndo':0, 'supportsVersions': 0}
self._tbuf = TransactionBuffer() # Decide whether to use non-temporary files
self._db = None client=client or os.environ.get('ZEO_CLIENT','')
self._oids = []
# XXX It's confusing to have _serial, _serials, and _seriald.
self._serials = []
self._seriald = {}
self._basic_init(name or str(addr)) self._connection=connection
self._storage=storage
self._debug=debug
self._wait_for_server_on_startup=wait_for_server_on_startup
# Decide whether to use non-temporary files self._info={'length': 0, 'size': 0, 'name': 'ZEO Client',
client = client or os.environ.get('ZEO_CLIENT', '') 'supportsUndo':0, 'supportsVersions': 0,
self._cache = ClientCache.ClientCache(storage, cache_size, }
client=client, var=var)
self._cache.open() # XXX
self._rpc_mgr = zrpc2.ConnectionManager(addr, self, self._call=zrpc.asyncRPC(connection, debug=debug,
#debug=debug,
tmin=min_disconnect_poll, tmin=min_disconnect_poll,
tmax=max_disconnect_poll) tmax=max_disconnect_poll)
# XXX What if we can only get a read-only connection and we name = name or str(connection)
# want a read-write connection? Looks like the current code
# will block forever.
if wait_for_server_on_startup: self.closed = 0
self._rpc_mgr.connect(sync=1) self._tfile=tempfile.TemporaryFile()
else: self._oids=[]
if not self._rpc_mgr.attempt_connect(): self._serials=[]
self._rpc_mgr.connect() self._seriald={}
def _basic_init(self, name): ClientStorage.inheritedAttribute('__init__')(self, name)
"""Handle initialization activites of BaseStorage"""
self.__name__ = name self.__lock_acquire=self._lock_acquire
# A ClientStorage only allows one client to commit at a time. self._cache=ClientCache.ClientCache(
# A client enters the commit state by finding tpc_tid set to storage, cache_size, client=client, var=var)
# None and updating it to the new transaction's id. The
# tpc_tid variable is protected by tpc_cond.
self.tpc_cond = threading.Condition()
self._transaction = None
# Prevent multiple new_oid calls from going out. The _oids
# variable should only be modified while holding the
# oid_cond.
self.oid_cond = threading.Condition()
commit_lock = thread.allocate_lock() ThreadedAsync.register_loop_callback(self.becomeAsync)
self._commit_lock_acquire = commit_lock.acquire
self._commit_lock_release = commit_lock.release
t = time.time() # IMPORTANT: Note that we aren't fully "there" yet.
t = self._ts = apply(TimeStamp,(time.gmtime(t)[:5]+(t%60,))) # In particular, we don't actually connect to the server
self._serial = `t` # until we have a controlling database set with registerDB
self._oid='\0\0\0\0\0\0\0\0' # below.
def registerDB(self, db, limit): def registerDB(self, db, limit):
"""Register that the storage is controlled by the given DB.""" """Register that the storage is controlled by the given DB.
log2(INFO, "registerDB(%s, %s)" % (repr(db), repr(limit))) """
self._db = db
def is_connected(self): # Among other things, we know that our data methods won't get
if self._server is disconnected_stub: # called until after this call.
return 0
else: self.invalidator = Invalidator.Invalidator(db.invalidate,
return 1 self._cache.invalidate)
def out_of_band_hook(
code, args,
get_hook={
'b': (self.invalidator.begin, 0),
'i': (self.invalidator.invalidate, 1),
'e': (self.invalidator.end, 0),
'I': (self.invalidator.Invalidate, 1),
'U': (self._commit_lock_release, 0),
's': (self._serials.append, 1),
'S': (self._info.update, 1),
}.get):
hook = get_hook(code, None)
if hook is None: return
hook, flag = hook
if flag: hook(args)
else: hook()
self._call.setOutOfBand(out_of_band_hook)
def notifyConnected(self, c): # Now that we have our callback system in place, we can
log2(INFO, "Connected to storage") # try to connect
stub = ServerStub.StorageServer(c)
self._oids = [] self._startup()
# XXX Why is this synchronous? If it were async, verification def _startup(self):
# would start faster.
stub.register(str(self._storage), self._is_read_only)
self.verify_cache(stub)
# Don't make the server available to clients until after if not self._call.connect(not self._wait_for_server_on_startup):
# validating the cache
self._server = stub # If we can't connect right away, go ahead and open the cache
# and start a separate thread to try and reconnect.
LOG("ClientStorage", PROBLEM, "Failed to connect to storage")
self._cache.open()
thread.start_new_thread(self._call.connect,(0,))
# If the connect succeeds then this work will be done by
# notifyConnected
def notifyConnected(self, s):
LOG("ClientStorage", INFO, "Connected to storage")
self._lock_acquire()
try:
# We let the connection keep coming up now that
# we have the storage lock. This way, we know no calls
# will be made while in the process of coming up.
self._call.finishConnect(s)
if self.closed:
return
self._connected=1
self._oids=[]
# we do synchronous commits until we are sure that
# we have and are ready for a main loop.
# Hm. This is a little silly. If self._async, then
# we will really never do a synchronous commit.
# See below.
self.__begin='tpc_begin_sync'
self._call.message_output(str(self._storage))
### This seems silly. We should get the info asynchronously.
# self._info.update(self._call('get_info'))
cached=self._cache.open()
### This is a little expensive for large caches
if cached:
self._call.sendMessage('beginZeoVerify')
for oid, (s, vs) in cached:
self._call.sendMessage('zeoVerify', oid, s, vs)
self._call.sendMessage('endZeoVerify')
finally: self._lock_release()
if self._async:
import asyncore
self.becomeAsync(asyncore.socket_map)
def verify_cache(self, server):
server.beginZeoVerify()
self._cache.verify(server.zeoVerify)
server.endZeoVerify()
### Is there a race condition between notifyConnected and ### Is there a race condition between notifyConnected and
### notifyDisconnected? In Particular, what if we get ### notifyDisconnected? In Particular, what if we get
...@@ -257,345 +268,363 @@ class ClientStorage: ...@@ -257,345 +268,363 @@ class ClientStorage:
### in the middle of notifyDisconnected, because *it's* ### in the middle of notifyDisconnected, because *it's*
### responsible for starting the thread that makes the connection. ### responsible for starting the thread that makes the connection.
def notifyDisconnected(self): def notifyDisconnected(self, ignored):
log2(PROBLEM, "Disconnected from storage") LOG("ClientStorage", PROBLEM, "Disconnected from storage")
self._server = disconnected_stub self._connected=0
if self._transaction: self._transaction=None
self._transaction = None thread.start_new_thread(self._call.connect,(0,))
self.tpc_cond.notifyAll() if self._transaction is not None:
self.tpc_cond.release()
def __len__(self):
return self._info['length']
def getName(self):
return "%s (%s)" % (self.__name__, "XXX")
def getSize(self):
return self._info['size']
def supportsUndo(self):
return self._info['supportsUndo']
def supportsVersions(self):
return self._info['supportsVersions']
def supportsTransactionalUndo(self):
try: try:
return self._info['supportsTransactionalUndo'] self._commit_lock_release()
except KeyError: except:
return 0 pass
def isReadOnly(self):
return self._is_read_only
def _check_trans(self, trans, exc=None):
if self._transaction is not trans:
if exc is None:
return 0
else:
raise exc(self._transaction, trans)
return 1
def _check_tid(self, tid, exc=None):
# XXX Is all this locking unnecessary? The only way to
# begin a transaction is to call tpc_begin(). If we assume
# clients are single-threaded and well-behaved, i.e. they call
# tpc_begin() first, then there appears to be no need for
# locking. If _check_tid() is called and self.tpc_tid != tid,
# then there is no way it can be come equal during the call.
# Thus, there should be no race.
if self.tpc_tid != tid:
if exc is None:
return 0
else:
raise exc(self.tpc_tid, tid)
return 1
# XXX But I'm not sure
self.tpc_cond.acquire() def becomeAsync(self, map):
self._lock_acquire()
try: try:
if self.tpc_tid != tid: self._async=1
if exc is None: if self._connected:
return 0 self._call.setLoop(map, getWakeup())
else: self.__begin='tpc_begin'
raise exc(self.tpc_tid, tid) finally: self._lock_release()
return 1
finally: def __len__(self): return self._info['length']
self.tpc_cond.release()
def abortVersion(self, src, transaction): def abortVersion(self, src, transaction):
if self._is_read_only: if transaction is not self._transaction:
raise POSException.ReadOnlyError() raise POSException.StorageTransactionError(self, transaction)
self._check_trans(transaction, self._lock_acquire()
POSException.StorageTransactionError) try:
oids = self._server.abortVersion(src, self._serial) oids=self._call('abortVersion', src, self._serial)
vlen = pack(">H", len(src))
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, src) self._tfile.write("i%s%s%s" % (oid, vlen, src))
return oids return oids
finally: self._lock_release()
def close(self): def close(self):
self._rpc_mgr.close() self._lock_acquire()
if self._cache is not None: try:
LOG("ClientStorage", INFO, "close")
self._call.closeIntensionally()
try:
self._tfile.close()
except os.error:
# On Windows, this can fail if it is called more than
# once, because it tries to delete the file each
# time.
pass
self._cache.close() self._cache.close()
if self.invalidator is not None:
self.invalidator.close()
self.invalidator = None
self.closed = 1
finally: self._lock_release()
def commitVersion(self, src, dest, transaction): def commitVersion(self, src, dest, transaction):
if self._is_read_only: if transaction is not self._transaction:
raise POSException.ReadOnlyError() raise POSException.StorageTransactionError(self, transaction)
self._check_trans(transaction, self._lock_acquire()
POSException.StorageTransactionError) try:
oids = self._server.commitVersion(src, dest, self._serial) oids=self._call('commitVersion', src, dest, self._serial)
if dest: if dest:
vlen = pack(">H", len(src))
# just invalidate our version data # just invalidate our version data
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, src) self._tfile.write("i%s%s%s" % (oid, vlen, src))
else: else:
vlen = pack(">H", len(dest))
# dest is '', so invalidate version and non-version # dest is '', so invalidate version and non-version
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, dest) self._tfile.write("i%s%s%s" % (oid, vlen, dest))
return oids return oids
finally: self._lock_release()
def getName(self):
return "%s (%s)" % (
self.__name__,
self._connected and 'connected' or 'disconnected')
def getSize(self): return self._info['size']
def history(self, oid, version, length=1): def history(self, oid, version, length=1):
return self._server.history(oid, version, length) self._lock_acquire()
try: return self._call('history', oid, version, length)
finally: self._lock_release()
def loadSerial(self, oid, serial): def loadSerial(self, oid, serial):
return self._server.loadSerial(oid, serial) self._lock_acquire()
try: return self._call('loadSerial', oid, serial)
finally: self._lock_release()
def load(self, oid, version, _stuff=None): def load(self, oid, version, _stuff=None):
p = self._cache.load(oid, version) self._lock_acquire()
if p: try:
return p cache=self._cache
if self._server is None: p = cache.load(oid, version)
raise ClientDisconnected() if p: return p
p, s, v, pv, sv = self._server.zeoLoad(oid) p, s, v, pv, sv = self._call('zeoLoad', oid)
self._cache.checkSize(0) cache.checkSize(0)
self._cache.store(oid, p, s, v, pv, sv) cache.store(oid, p, s, v, pv, sv)
if v and version and v == version: if not v or not version or version != v:
return pv, sv if s: return p, s
else:
if s:
return p, s
raise KeyError, oid # no non-version data for this raise KeyError, oid # no non-version data for this
return pv, sv
finally: self._lock_release()
def modifiedInVersion(self, oid): def modifiedInVersion(self, oid):
v = self._cache.modifiedInVersion(oid) self._lock_acquire()
if v is not None: try:
return v v=self._cache.modifiedInVersion(oid)
return self._server.modifiedInVersion(oid) if v is not None: return v
return self._call('modifiedInVersion', oid)
finally: self._lock_release()
def new_oid(self, last=None): def new_oid(self, last=None):
if self._is_read_only: self._lock_acquire()
raise POSException.ReadOnlyError() try:
# We want to avoid a situation where multiple oid requests are oids=self._oids
# made at the same time. if not oids:
self.oid_cond.acquire() oids[:]=self._call('new_oids')
if not self._oids: oids.reverse()
self._oids = self._server.new_oids()
self._oids.reverse() return oids.pop()
self.oid_cond.notifyAll() finally: self._lock_release()
oid = self._oids.pop()
self.oid_cond.release()
return oid
def pack(self, t=None, rf=None, wait=0, days=0): def pack(self, t=None, rf=None, wait=0, days=0):
if self._is_read_only:
raise POSException.ReadOnlyError()
# Note that we ignore the rf argument. The server # Note that we ignore the rf argument. The server
# will provide it's own implementation. # will provide it's own implementation.
if t is None: if t is None: t=time.time()
t = time.time() t=t-(days*86400)
t = t - (days * 86400) self._lock_acquire()
return self._server.pack(t, wait) try: return self._call('pack', t, wait)
finally: self._lock_release()
def store(self, oid, serial, data, version, transaction):
if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction)
self._lock_acquire()
try:
serial=self._call.sendMessage('storea', oid, serial,
data, version, self._serial)
write=self._tfile.write
buf = string.join(("s", oid,
pack(">HI", len(version), len(data)),
version, data), "")
write(buf)
def _check_serials(self):
if self._serials: if self._serials:
l = len(self._serials) s=self._serials
r = self._serials[:l] l=len(s)
del self._serials[:l] r=s[:l]
for oid, s in r: del s[:l]
if isinstance(s, Exception): d=self._seriald
raise s for oid, s in r: d[oid]=s
self._seriald[oid] = s
return r return r
def store(self, oid, serial, data, version, transaction): return serial
if self._is_read_only:
raise POSException.ReadOnlyError() finally: self._lock_release()
self._check_trans(transaction, POSException.StorageTransactionError)
self._server.storea(oid, serial, data, version, self._serial)
self._tbuf.store(oid, version, data)
return self._check_serials()
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
self._lock_acquire()
try:
if transaction is not self._transaction: if transaction is not self._transaction:
return return
self._server.vote(self._serial) self._call('vote', self._serial)
return self._check_serials()
if self._serials:
s=self._serials
l=len(s)
r=s[:l]
del s[:l]
d=self._seriald
for oid, s in r: d[oid]=s
return r
finally: self._lock_release()
def supportsUndo(self):
return self._info['supportsUndo']
def supportsVersions(self):
return self._info['supportsVersions']
def supportsTransactionalUndo(self):
try:
return self._info['supportsTransactionalUndo']
except KeyError:
return 0
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
if transaction is not self._transaction: self._lock_acquire()
return try:
self._server.tpc_abort(self._serial) if transaction is not self._transaction: return
self._tbuf.clear() self._call('tpc_abort', self._serial)
self._transaction=None
self._tfile.seek(0)
self._seriald.clear() self._seriald.clear()
del self._serials[:] del self._serials[:]
self._transaction = None self._commit_lock_release()
self.tpc_cond.notify() finally: self._lock_release()
self.tpc_cond.release()
def tpc_begin(self, transaction): def tpc_begin(self, transaction):
self.tpc_cond.acquire() self._lock_acquire()
while self._transaction is not None: try:
if self._transaction == transaction: if self._transaction is transaction: return
self.tpc_cond.release()
return user=transaction.user
self.tpc_cond.wait() desc=transaction.description
ext=transaction._extension
if self._server is None: while 1:
self.tpc_cond.release() self._lock_release()
raise ClientDisconnected() self._commit_lock_acquire()
self._lock_acquire()
self._ts = get_timestamp(self._ts) # We've got the local commit lock. Now get
id = `self._ts` # a (tentative) transaction time stamp.
self._transaction = transaction t=time.time()
t=apply(TimeStamp,(time.gmtime(t)[:5]+(t%60,)))
self._ts=t=t.laterThan(self._ts)
id=`t`
try: try:
r = self._server.tpc_begin(id, if not self._connected:
transaction.user, raise ClientDisconnected(
transaction.description, "This action is temporarily unavailable.<p>")
transaction._extension) r=self._call(self.__begin, id, user, desc, ext)
except: except:
# If _server is None, then the client disconnected during # XXX can't seem to guarantee that the lock is held here.
# the tpc_begin() and notifyDisconnected() will have self._commit_lock_release()
# released the lock.
if self._server is not disconnected_stub:
self.tpc_cond.release()
raise raise
self._serial = id if r is None: break
# We have *BOTH* the local and distributed commit
# lock, now we can actually get ready to get started.
self._serial=id
self._tfile.seek(0)
self._seriald.clear() self._seriald.clear()
del self._serials[:] del self._serials[:]
def tpc_finish(self, transaction, f=None): self._transaction=transaction
if transaction is not self._transaction:
return
if f is not None: # XXX what is f()?
f()
self._server.tpc_finish(self._serial) finally: self._lock_release()
r = self._check_serials() def tpc_finish(self, transaction, f=None):
assert r is None or len(r) == 0, "unhandled serialnos: %s" % r self._lock_acquire()
try:
self._update_cache() if transaction is not self._transaction: return
if f is not None: f()
self._transaction = None self._call('tpc_finish', self._serial,
self.tpc_cond.notify() transaction.user,
self.tpc_cond.release() transaction.description,
transaction._extension)
def _update_cache(self): seriald=self._seriald
# Iterate over the objects in the transaction buffer and if self._serials:
# update or invalidate the cache. s=self._serials
self._cache.checkSize(self._tbuf.get_size()) l=len(s)
self._tbuf.begin_iterate() r=s[:l]
while 1: del s[:l]
try: for oid, s in r: seriald[oid]=s
t = self._tbuf.next()
except ValueError, msg: tfile=self._tfile
seek=tfile.seek
read=tfile.read
cache=self._cache
size=tfile.tell()
cache.checkSize(size)
seek(0)
i=0
while i < size:
opcode=read(1)
if opcode == "s":
oid=read(8)
s=seriald[oid]
h=read(6)
vlen, dlen = unpack(">HI", h)
if vlen: v=read(vlen)
else: v=''
p=read(dlen)
if len(p) != dlen:
raise ClientStorageError, ( raise ClientStorageError, (
"Unexpected error reading temporary file in " "Unexpected end of file in client storage "
"client storage: %s" % msg) "temporary file."
if t is None: )
break if s==ResolvedSerial:
oid, v, p = t
if p is None: # an invalidation
s = None
else:
s = self._seriald[oid]
if s == ResolvedSerial or s is None:
self._cache.invalidate(oid, v) self._cache.invalidate(oid, v)
else: else:
self._cache.update(oid, s, v, p) self._cache.update(oid, s, v, p)
self._tbuf.clear() i=i+15+vlen+dlen
elif opcode == "i":
oid=read(8)
h=read(2)
vlen=unpack(">H", h)[0]
v=read(vlen)
self._cache.invalidate(oid, v)
i=i+11+vlen
seek(0)
self._transaction=None
self._commit_lock_release()
finally: self._lock_release()
def transactionalUndo(self, trans_id, trans): def transactionalUndo(self, trans_id, trans):
if self._is_read_only: self._lock_acquire()
raise POSException.ReadOnlyError() try:
self._check_trans(trans, POSException.StorageTransactionError) if trans is not self._transaction:
oids = self._server.transactionalUndo(trans_id, self._serial) raise POSException.StorageTransactionError(self, transaction)
oids = self._call('transactionalUndo', trans_id, self._serial)
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, '') # write invalidation records with no version
self._tfile.write("i%s\000\000" % oid)
return oids return oids
finally: self._lock_release()
def undo(self, transaction_id): def undo(self, transaction_id):
if self._is_read_only: self._lock_acquire()
raise POSException.ReadOnlyError() try:
# XXX what are the sync issues here? oids=self._call('undo', transaction_id)
oids = self._server.undo(transaction_id) cinvalidate=self._cache.invalidate
for oid in oids: for oid in oids:
self._cache.invalidate(oid, '') cinvalidate(oid,'')
return oids return oids
finally: self._lock_release()
def undoInfo(self, first=0, last=-20, specification=None): def undoInfo(self, first=0, last=-20, specification=None):
return self._server.undoInfo(first, last, specification) self._lock_acquire()
try:
return self._call('undoInfo', first, last, specification)
finally: self._lock_release()
def undoLog(self, first, last, filter=None): def undoLog(self, first, last, filter=None):
if filter is not None: if filter is not None: return ()
return () # XXX can't pass a filter to server
return self._server.undoLog(first, last) # Eek! self._lock_acquire()
try: return self._call('undoLog', first, last) # Eek!
finally: self._lock_release()
def versionEmpty(self, version): def versionEmpty(self, version):
return self._server.versionEmpty(version) self._lock_acquire()
try: return self._call('versionEmpty', version)
finally: self._lock_release()
def versions(self, max=None): def versions(self, max=None):
return self._server.versions(max) self._lock_acquire()
try: return self._call('versions', max)
# below are methods invoked by the StorageServer finally: self._lock_release()
def serialno(self, arg):
self._serials.append(arg)
def info(self, dict):
self._info.update(dict)
def begin(self): def sync(self): self._call.sync()
self._tfile = tempfile.TemporaryFile()
self._pickler = cPickle.Pickler(self._tfile, 1)
self._pickler.fast = 1 # Don't use the memo
def invalidate(self, args):
if self._pickler is None:
return
self._pickler.dump(args)
def end(self):
if self._pickler is None:
return
self._pickler.dump((0,0))
## self._pickler.dump = None
self._tfile.seek(0)
unpick = cPickle.Unpickler(self._tfile)
self._tfile = None
while 1:
oid, version = unpick.load()
if not oid:
break
self._cache.invalidate(oid, version=version)
self._db.invalidate(oid, version=version)
def Invalidate(self, args):
# XXX _db could be None
for oid, version in args:
self._cache.invalidate(oid, version=version)
try:
self._db.invalidate(oid, version=version)
except AttributeError, msg:
log2(PROBLEM,
"Invalidate(%s, %s) failed for _db: %s" % (repr(oid),
repr(version),
msg))
def getWakeup(_w=[]):
if _w: return _w[0]
import trigger
t=trigger.trigger().pull_trigger
_w.append(t)
return t
"""Stub for interface exported by ClientStorage"""
class ClientStorage:
def __init__(self, rpc):
self.rpc = rpc
def beginVerify(self):
self.rpc.callAsync('begin')
# XXX what's the difference between these two?
def invalidate(self, args):
self.rpc.callAsync('invalidate', args)
def Invalidate(self, args):
self.rpc.callAsync('Invalidate', args)
def endVerify(self):
self.rpc.callAsync('end')
def serialno(self, arg):
self.rpc.callAsync('serialno', arg)
def info(self, arg):
self.rpc.callAsync('info', arg)
"""Exceptions for ZEO."""
class Disconnected(Exception):
"""Exception raised when a ZEO client is disconnected from the
ZEO server."""
"""Stub for interface exposed by StorageServer"""
class StorageServer:
def __init__(self, rpc):
self.rpc = rpc
def register(self, storage_name, read_only):
self.rpc.call('register', storage_name, read_only)
def get_info(self):
return self.rpc.call('get_info')
def get_size_info(self):
return self.rpc.call('get_size_info')
def beginZeoVerify(self):
self.rpc.callAsync('beginZeoVerify')
def zeoVerify(self, oid, s, sv):
self.rpc.callAsync('zeoVerify', oid, s, sv)
def endZeoVerify(self):
self.rpc.callAsync('endZeoVerify')
def new_oids(self, n=None):
if n is None:
return self.rpc.call('new_oids')
else:
return self.rpc.call('new_oids', n)
def pack(self, t, wait=None):
if wait is None:
self.rpc.call('pack', t)
else:
self.rpc.call('pack', t, wait)
def zeoLoad(self, oid):
return self.rpc.call('zeoLoad', oid)
def storea(self, oid, serial, data, version, id):
self.rpc.callAsync('storea', oid, serial, data, version, id)
def tpc_begin(self, id, user, descr, ext):
return self.rpc.call('tpc_begin', id, user, descr, ext)
def vote(self, trans_id):
return self.rpc.call('vote', trans_id)
def tpc_finish(self, id):
return self.rpc.call('tpc_finish', id)
def tpc_abort(self, id):
self.rpc.callAsync('tpc_abort', id)
def abortVersion(self, src, id):
return self.rpc.call('abortVersion', src, id)
def commitVersion(self, src, dest, id):
return self.rpc.call('commitVersion', src, dest, id)
def history(self, oid, version, length=None):
if length is not None:
return self.rpc.call('history', oid, version)
else:
return self.rpc.call('history', oid, version, length)
def load(self, oid, version):
return self.rpc.call('load', oid, version)
def loadSerial(self, oid, serial):
return self.rpc.call('loadSerial', oid, serial)
def modifiedInVersion(self, oid):
return self.rpc.call('modifiedInVersion', oid)
def new_oid(self, last=None):
if last is None:
return self.rpc.call('new_oid')
else:
return self.rpc.call('new_oid', last)
def store(self, oid, serial, data, version, trans):
return self.rpc.call('store', oid, serial, data, version, trans)
def transactionalUndo(self, trans_id, trans):
return self.rpc.call('transactionalUndo', trans_id, trans)
def undo(self, trans_id):
return self.rpc.call('undo', trans_id)
def undoLog(self, first, last):
# XXX filter not allowed across RPC
return self.rpc.call('undoLog', first, last)
def undoInfo(self, first, last, spec):
return self.rpc.call('undoInfo', first, last, spec)
def versionEmpty(self, vers):
return self.rpc.call('versionEmpty', vers)
def versions(self, max=None):
if max is None:
return self.rpc.call('versions')
else:
return self.rpc.call('versions', max)
############################################################################## #############################################################################
# #
# Zope Public License (ZPL) Version 1.0 # Zope Public License (ZPL) Version 1.0
# ------------------------------------- # -------------------------------------
...@@ -82,394 +82,527 @@ ...@@ -82,394 +82,527 @@
# attributions are listed in the accompanying credits file. # attributions are listed in the accompanying credits file.
# #
############################################################################## ##############################################################################
"""Network ZODB storage server
This server acts as a front-end for one or more real storages, like
file storage or Berkeley storage.
XXX Need some basic access control-- a declaration of the methods __version__ = "$Revision: 1.34 $"[11:-2]
exported for invocation by the server.
"""
import asyncore import asyncore, socket, string, sys, os
from smac import SizedMessageAsyncConnection
from ZODB import POSException
import cPickle import cPickle
import os from cPickle import Unpickler
import sys from ZODB.POSException import TransactionError, UndoError, VersionCommitError
import threading
import types
import ClientStub
import zrpc2
import zLOG
from zrpc2 import Dispatcher, Handler, ManagedServerConnection, Delay
from ZODB.POSException import StorageError, StorageTransactionError, \
TransactionError, ReadOnlyError
from ZODB.referencesf import referencesf
from ZODB.Transaction import Transaction from ZODB.Transaction import Transaction
import traceback
from zLOG import LOG, INFO, ERROR, TRACE, BLATHER
from ZODB.referencesf import referencesf
from thread import start_new_thread
from cStringIO import StringIO
from ZEO import trigger
from ZEO import asyncwrap
from types import StringType
class StorageServerError(POSException.StorageError): pass
max_blather=120
def blather(*args):
accum = []
total_len = 0
for arg in args:
if not isinstance(arg, StringType):
arg = str(arg)
accum.append(arg)
total_len = total_len + len(arg)
if total_len >= max_blather:
break
m = string.join(accum)
if len(m) > max_blather: m = m[:max_blather] + ' ...'
LOG('ZEO Server', TRACE, m)
# We create a special fast pickler! This allows us # We create a special fast pickler! This allows us
# to create slightly more efficient pickles and # to create slightly more efficient pickles and
# to create them a tad faster. # to create them a tad faster.
pickler = cPickle.Pickler() pickler=cPickle.Pickler()
pickler.fast = 1 # Don't use the memo pickler.fast=1 # Don't use the memo
dump = pickler.dump dump=pickler.dump
def log(message, level=zLOG.INFO, label="ZEO Server:%s" % os.getpid(), class StorageServer(asyncore.dispatcher):
error=None):
zLOG.LOG(label, level, message, error=error)
class StorageServerError(StorageError): def __init__(self, connection, storages):
pass
self.__storages=storages
for n, s in storages.items():
init_storage(s)
self.__connections={}
self.__get_connections=self.__connections.get
self._pack_trigger = trigger.trigger()
asyncore.dispatcher.__init__(self)
class StorageServer: if type(connection) is type(''):
def __init__(self, addr, storages, read_only=0): self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
# XXX should read_only be a per-storage option? not yet... try: os.unlink(connection)
self.addr = addr except: pass
self.storages = storages
self.read_only = read_only
self.connections = {}
for name, store in storages.items():
fixup_storage(store)
self.dispatcher = Dispatcher(addr, factory=self.newConnection,
reuse_addr=1)
def newConnection(self, sock, addr, nil):
c = ManagedServerConnection(sock, addr, None, self)
c.register_object(StorageProxy(self, c))
return c
def register(self, storage_id, proxy):
"""Register a connection's use with a particular storage.
This information is needed to handle invalidation.
"""
l = self.connections.get(storage_id)
if l is None:
l = self.connections[storage_id] = []
# intialize waiting list
self.storages[storage_id]._StorageProxy__waiting = []
l.append(proxy)
def invalidate(self, conn, storage_id, invalidated=(), info=0):
for p in self.connections[storage_id]:
if invalidated and p is not conn:
p.client.Invalidate(invalidated)
else: else:
p.client.info(info) self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.set_reuse_addr()
def close_server(self):
# Close the dispatcher so that there are no new connections. LOG('ZEO Server', INFO, 'Listening on %s' % repr(connection))
self.dispatcher.close() self.bind(connection)
for storage in self.storages.values(): self.listen(5)
storage.close()
# Force the asyncore mainloop to exit by hackery, i.e. close def register_connection(self, connection, storage_id):
# every socket in the map. loop() will return when the map is storage=self.__storages.get(storage_id, None)
# empty. if storage is None:
for s in asyncore.socket_map.values(): LOG('ZEO Server', ERROR, "Unknown storage_id: %s" % storage_id)
connection.close()
return None, None
connections=self.__get_connections(storage_id, None)
if connections is None:
self.__connections[storage_id]=connections=[]
connections.append(connection)
return storage, storage_id
def unregister_connection(self, connection, storage_id):
connections=self.__get_connections(storage_id, None)
if connections:
n=[]
for c in connections:
if c is not connection:
n.append(c)
self.__connections[storage_id]=n
def invalidate(self, connection, storage_id, invalidated=(), info=0,
dump=dump):
for c in self.__connections[storage_id]:
if invalidated and c is not connection:
c.message_output('I'+dump(invalidated, 1))
if info:
c.message_output('S'+dump(info, 1))
def writable(self): return 0
def handle_read(self): pass
def readable(self): return 1
def handle_connect (self): pass
def handle_accept(self):
try: try:
s.close() sock, addr = self.accept()
except socket.error:
sys.stderr.write('warning: accept failed\n')
else:
ZEOConnection(self, sock, addr)
def log_info(self, message, type='info'):
if type=='error': type=ERROR
else: type=INFO
LOG('ZEO Server', type, message)
log=log_info
storage_methods={}
for n in (
'get_info', 'abortVersion', 'commitVersion',
'history', 'load', 'loadSerial',
'modifiedInVersion', 'new_oid', 'new_oids', 'pack', 'store',
'storea', 'tpc_abort', 'tpc_begin', 'tpc_begin_sync',
'tpc_finish', 'undo', 'undoLog', 'undoInfo', 'versionEmpty', 'versions',
'transactionalUndo',
'vote', 'zeoLoad', 'zeoVerify', 'beginZeoVerify', 'endZeoVerify',
):
storage_methods[n]=1
storage_method=storage_methods.has_key
def find_global(module, name,
global_dict=globals(), silly=('__doc__',)):
try: m=__import__(module, global_dict, global_dict, silly)
except: except:
pass raise StorageServerError, (
"Couldn\'t import global module %s" % module)
def close(self, conn): try: r=getattr(m, name)
# XXX who calls this? except:
# called when conn is closed raise StorageServerError, (
# way too inefficient "Couldn\'t find global %s in module %s" % (name, module))
removed = 0
for sid, cl in self.connections.items(): safe=getattr(r, '__no_side_effects__', 0)
if conn.obj in cl: if safe: return r
cl.remove(conn.obj)
removed = 1 raise StorageServerError, 'Unsafe global, %s.%s' % (module, name)
class StorageProxy(Handler): _noreturn=[]
def __init__(self, server, conn): class ZEOConnection(SizedMessageAsyncConnection):
self.server = server
self.client = ClientStub.ClientStorage(conn) _transaction=None
self.__storage = None __storage=__storage_id=None
self.__invalidated = []
self._transaction = None def __init__(self, server, sock, addr):
self.__server=server
def __repr__(self): self.__invalidated=[]
tid = self._transaction and repr(self._transaction.id) self.__closed=None
if self.__storage: if __debug__: debug='ZEO Server'
stid = self.__storage._transaction and \ else: debug=0
repr(self.__storage._transaction.id) SizedMessageAsyncConnection.__init__(self, sock, addr, debug=debug)
LOG('ZEO Server', INFO, 'Connect %s %s' % (id(self), `addr`))
def close(self):
t=self._transaction
if (t is not None and self.__storage is not None and
self.__storage._transaction is t):
self.tpc_abort(t.id)
else: else:
stid = None self._transaction=None
return "<StorageProxy %X trans=%s s_trans=%s>" % (id(self), tid, self.__invalidated=[]
stid)
self.__server.unregister_connection(self, self.__storage_id)
def _log(self, msg, level=zLOG.INFO, error=None, pid=os.getpid()): self.__closed=1
zLOG.LOG("ZEO Server %s %X" % (pid, id(self)), SizedMessageAsyncConnection.close(self)
level, msg, error=error) LOG('ZEO Server', INFO, 'Close %s' % id(self))
def setup_delegation(self): def message_input(self, message,
"""Delegate several methods to the storage""" dump=dump, Unpickler=Unpickler, StringIO=StringIO,
self.undoInfo = self.__storage.undoInfo None=None):
self.undoLog = self.__storage.undoLog if __debug__:
self.versionEmpty = self.__storage.versionEmpty if len(message) > max_blather:
self.versions = self.__storage.versions tmp = `message[:max_blather]`
self.history = self.__storage.history
self.load = self.__storage.load
self.loadSerial = self.__storage.loadSerial
def _check_tid(self, tid, exc=None):
caller = sys._getframe().f_back.f_code.co_name
if self._transaction is None:
self._log("no current transaction: %s()" % caller,
zLOG.PROBLEM)
if exc is not None:
raise exc(None, tid)
else: else:
return 0 tmp = `message`
if self._transaction.id != tid: blather('message_input', id(self), tmp)
self._log("%s(%s) invalid; current transaction = %s" % \
(caller, repr(tid), repr(self._transaction.id)), if self.__storage is None:
zLOG.PROBLEM) # This is the first communication from the client
if exc is not None: self.__storage, self.__storage_id = (
raise exc(self._transaction.id, tid) self.__server.register_connection(self, message))
# Send info back asynchronously, so client need not ask
self.message_output('S'+dump(self.get_info(), 1))
return
try:
# Unpickle carefully.
unpickler=Unpickler(StringIO(message))
unpickler.find_global=find_global
args=unpickler.load()
name, args = args[0], args[1:]
if __debug__:
apply(blather,
("call", id(self), ":", name,) + args)
if not storage_method(name):
raise 'Invalid Method Name', name
if hasattr(self, name):
r=apply(getattr(self, name), args)
else: else:
return 0 r=apply(getattr(self.__storage, name), args)
return 1 if r is _noreturn: return
except (UndoError, VersionCommitError):
# These are normal usage errors. No need to leg them
self.return_error(sys.exc_info()[0], sys.exc_info()[1])
return
except:
LOG('ZEO Server', ERROR, 'error', error=sys.exc_info())
self.return_error(sys.exc_info()[0], sys.exc_info()[1])
return
def register(self, storage_id, read_only): if __debug__:
"""Select the storage that this client will use blather("%s R: %s" % (id(self), `r`))
This method must be the first one called by the client. r=dump(r,1)
""" self.message_output('R'+r)
storage = self.server.storages.get(storage_id)
if storage is None: def return_error(self, err_type, err_value, type=type, dump=dump):
self._log("unknown storage_id: %s" % storage_id) if type(err_value) is not type(self):
raise ValueError, "unknown storage: %s" % storage_id err_value = err_type, err_value
if __debug__:
blather("%s E: %s" % (id(self), `err_value`))
try: r=dump(err_value, 1)
except:
# Ugh, must be an unpicklable exception
r=StorageServerError("Couldn't pickle error %s" % `r`)
dump('',1) # clear pickler
r=dump(r,1)
if not read_only and (self.server.read_only or storage.isReadOnly()): self.message_output('E'+r)
raise ReadOnlyError()
self.__storage_id = storage_id
self.__storage = storage
self.setup_delegation()
self.server.register(storage_id, self)
self._log("registered storage %s: %s" % (storage_id, storage))
def get_info(self): def get_info(self):
return {'length': len(self.__storage), storage=self.__storage
'size': self.__storage.getSize(), info = {
'name': self.__storage.getName(), 'length': len(storage),
'supportsUndo': self.__storage.supportsUndo(), 'size': storage.getSize(),
'supportsVersions': self.__storage.supportsVersions(), 'name': storage.getName(),
'supportsTransactionalUndo':
self.__storage.supportsTransactionalUndo(),
} }
for feature in ('supportsUndo',
'supportsVersions',
'supportsTransactionalUndo',):
if hasattr(storage, feature):
info[feature] = getattr(storage, feature)()
else:
info[feature] = 0
return info
def get_size_info(self): def get_size_info(self):
return {'length': len(self.__storage), storage=self.__storage
'size': self.__storage.getSize(), return {
'length': len(storage),
'size': storage.getSize(),
} }
def zeoLoad(self, oid): def zeoLoad(self, oid):
v = self.__storage.modifiedInVersion(oid) storage=self.__storage
if v: v=storage.modifiedInVersion(oid)
pv, sv = self.__storage.load(oid, v) if v: pv, sv = storage.load(oid, v)
else: else: pv=sv=None
pv = sv = None
try: try:
p, s = self.__storage.load(oid, '') p, s = storage.load(oid,'')
except KeyError: except KeyError:
if sv: if sv:
# Created in version, no non-version data # Created in version, no non-version data
p = s = None p=s=None
else: else:
raise raise
return p, s, v, pv, sv return p, s, v, pv, sv
def beginZeoVerify(self):
self.client.beginVerify()
def zeoVerify(self, oid, s, sv): def beginZeoVerify(self):
try: self.message_output('bN.')
p, os, v, pv, osv = self.zeoLoad(oid) return _noreturn
except: # except what?
return None def zeoVerify(self, oid, s, sv,
dump=dump):
try: p, os, v, pv, osv = self.zeoLoad(oid)
except: return _noreturn
p=pv=None # free the pickles
if os != s: if os != s:
self.client.invalidate((oid, '')) self.message_output('i'+dump((oid, ''),1))
elif osv != sv: elif osv != sv:
self.client.invalidate((oid, v)) self.message_output('i'+dump((oid, v),1))
return _noreturn
def endZeoVerify(self): def endZeoVerify(self):
self.client.endVerify() self.message_output('eN.')
return _noreturn
def new_oids(self, n=100):
new_oid=self.__storage.new_oid
if n < 0: n=1
r=range(n)
for i in r: r[i]=new_oid()
return r
def pack(self, t, wait=0): def pack(self, t, wait=0):
t = threading.Thread(target=self._pack, args=(t, wait)) start_new_thread(self._pack, (t,wait))
t.start() if wait: return _noreturn
def _pack(self, t, wait=0): def _pack(self, t, wait=0):
try: try:
LOG('ZEO Server', BLATHER, 'pack begin')
self.__storage.pack(t, referencesf) self.__storage.pack(t, referencesf)
LOG('ZEO Server', BLATHER, 'pack end')
except: except:
self._log('ZEO Server', zLOG.ERROR, LOG('ZEO Server', ERROR,
'Pack failed for %s' % self.__storage_id, 'Pack failed for %s' % self.__storage_id,
error=sys.exc_info()) error=sys.exc_info())
if wait: if wait:
raise self.return_error(sys.exc_info()[0], sys.exc_info()[1])
self.__server._pack_trigger.pull_trigger()
else:
if wait:
self.message_output('RN.')
self.__server._pack_trigger.pull_trigger()
else: else:
if not wait:
# Broadcast new size statistics # Broadcast new size statistics
self.server.invalidate(0, self.__storage_id, (), self.__server.invalidate(0, self.__storage_id, (),
self.get_size_info()) self.get_size_info())
def abortVersion(self, src, id): def abortVersion(self, src, id):
self._check_tid(id, exc=StorageTransactionError) t=self._transaction
oids = self.__storage.abortVersion(src, self._transaction) if t is None or id != t.id:
for oid in oids: raise POSException.StorageTransactionError(self, id)
self.__invalidated.append((oid, src)) oids=self.__storage.abortVersion(src, t)
a=self.__invalidated.append
for oid in oids: a((oid,src))
return oids return oids
def commitVersion(self, src, dest, id): def commitVersion(self, src, dest, id):
self._check_tid(id, exc=StorageTransactionError) t=self._transaction
oids = self.__storage.commitVersion(src, dest, self._transaction) if t is None or id != t.id:
raise POSException.StorageTransactionError(self, id)
oids=self.__storage.commitVersion(src, dest, t)
a=self.__invalidated.append
for oid in oids: for oid in oids:
self.__invalidated.append((oid, dest)) a((oid,dest))
if dest: if dest: a((oid,src))
self.__invalidated.append((oid, src))
return oids return oids
def storea(self, oid, serial, data, version, id): def storea(self, oid, serial, data, version, id,
self._check_tid(id, exc=StorageTransactionError) dump=dump):
try: try:
# XXX does this stmt need to be in the try/except? t=self._transaction
if t is None or id != t.id:
raise POSException.StorageTransactionError(self, id)
newserial = self.__storage.store(oid, serial, data, version, newserial=self.__storage.store(oid, serial, data, version, t)
self._transaction)
except TransactionError, v: except TransactionError, v:
# This is a normal transaction error such as a conflict error # This is a normal transaction errorm such as a conflict error
# or a version lock or conflict error. It doesn't need to be # or a version lock or conflict error. It doen't need to be
# logged. # logged.
self._log("transaction error: %s" % repr(v)) newserial=v
newserial = v
except: except:
# all errors need to be serialized to prevent unexpected # all errors need to be serialized to prevent unexpected
# returns, which would screw up the return handling. # returns, which would screw up the return handling.
# IOW, Anything that ends up here is evil enough to be logged. # IOW, Anything that ends up here is evil enough to be logged.
error = sys.exc_info() LOG('ZEO Server', ERROR, 'store error', error=sys.exc_info())
self._log('store error: %s: %s' % (error[0], error[1]), newserial=sys.exc_info()[1]
zLOG.ERROR, error=error)
newserial = sys.exc_info()[1]
else: else:
if serial != '\0\0\0\0\0\0\0\0': if serial != '\0\0\0\0\0\0\0\0':
self.__invalidated.append((oid, version)) self.__invalidated.append((oid, version))
try: try: r=dump((oid,newserial), 1)
nil = dump(newserial, 1)
except: except:
self._log("couldn't pickle newserial: %s" % repr(newserial), # We got a pickling error, must be because the
zLOG.ERROR) # newserial is an unpicklable exception.
dump('', 1) # clear pickler r=StorageServerError("Couldn't pickle exception %s" % `newserial`)
r = StorageServerError("Couldn't pickle exception %s" % \ dump('',1) # clear pickler
`newserial`) r=dump((oid, r),1)
newserial = r
self.client.serialno((oid, newserial)) self.message_output('s'+r)
return _noreturn
def vote(self, id): def vote(self, id):
self._check_tid(id, exc=StorageTransactionError) t=self._transaction
self.__storage.tpc_vote(self._transaction) if t is None or id != t.id:
raise POSException.StorageTransactionError(self, id)
return self.__storage.tpc_vote(t)
def transactionalUndo(self, trans_id, id): def transactionalUndo(self, trans_id, id):
self._check_tid(id, exc=StorageTransactionError) t=self._transaction
if t is None or id != t.id:
raise POSException.StorageTransactionError(self, id)
return self.__storage.transactionalUndo(trans_id, self._transaction) return self.__storage.transactionalUndo(trans_id, self._transaction)
def undo(self, transaction_id): def undo(self, transaction_id):
oids = self.__storage.undo(transaction_id) oids=self.__storage.undo(transaction_id)
if oids: if oids:
self.server.invalidate(self, self.__storage_id, self.__server.invalidate(
map(lambda oid: (oid, None, ''), oids)) self, self.__storage_id, map(lambda oid: (oid,None), oids)
)
return oids return oids
return () return ()
# When multiple clients are using a single storage, there are several def tpc_abort(self, id):
# different _transaction attributes to keep track of. Each t=self._transaction
# StorageProxy object has a single _transaction that refers to its if t is None or id != t.id: return
# current transaction. The storage (self.__storage) has another r=self.__storage.tpc_abort(t)
# _transaction that is used for the *real* transaction.
storage=self.__storage
try: waiting=storage.__waiting
except: waiting=storage.__waiting=[]
while waiting:
f, args = waiting.pop(0)
if apply(f,args): break
self._transaction=None
self.__invalidated=[]
# The real trick comes with the __waiting queue for a storage. def unlock(self):
# When a StorageProxy pulls a new transaction from the queue, it if self.__closed: return
# must inform the new transaction's proxy. (The two proxies may self.message_output('UN.')
# be the same.) The new transaction's proxy sets its _transaction
# and continues from there.
def tpc_begin(self, id, user, description, ext): def tpc_begin(self, id, user, description, ext):
if self._transaction is not None: t=self._transaction
if self._transaction.id == id: if t is not None:
self._log("duplicate tpc_begin(%s)" % repr(id)) if id == t.id: return
return
else: else:
raise StorageTransactionError("Multiple simultaneous tpc_begin" raise StorageServerError(
" requests from one client.") "Multiple simultaneous tpc_begin requests from the same "
"client."
t = Transaction() )
t.id = id storage=self.__storage
t.user = user if storage._transaction is not None:
t.description = description try: waiting=storage.__waiting
t._extension = ext except: waiting=storage.__waiting=[]
waiting.append((self.unlock, ()))
if self.__storage._transaction is not None: return 1 # Return a flag indicating a lock condition.
d = zrpc2.Delay()
self.__storage.__waiting.append((d, self, t)) self._transaction=t=Transaction()
return d t.id=id
t.user=user
self._transaction = t t.description=description
self.__storage.tpc_begin(t) t._extension=ext
self.__invalidated = [] storage.tpc_begin(t)
self.__invalidated=[]
def tpc_finish(self, id):
if not self._check_tid(id): def tpc_begin_sync(self, id, user, description, ext):
return if self.__closed: return
t=self._transaction
if t is not None and id == t.id: return
storage=self.__storage
if storage._transaction is None:
self.try_again_sync(id, user, description, ext)
else:
try: waiting=storage.__waiting
except: waiting=storage.__waiting=[]
waiting.append((self.try_again_sync, (id, user, description, ext)))
return _noreturn
def try_again_sync(self, id, user, description, ext):
storage=self.__storage
if storage._transaction is None:
self._transaction=t=Transaction()
t.id=id
t.user=user
t.description=description
storage.tpc_begin(t)
self.__invalidated=[]
self.message_output('RN.')
return 1
def tpc_finish(self, id, user, description, ext):
t=self._transaction
if id != t.id: return
r = self.__storage.tpc_finish(self._transaction) storage=self.__storage
assert self.__storage._transaction is None r=storage.tpc_finish(t)
try: waiting=storage.__waiting
except: waiting=storage.__waiting=[]
while waiting:
f, args = waiting.pop(0)
if apply(f,args): break
self._transaction=None
if self.__invalidated: if self.__invalidated:
self.server.invalidate(self, self.__storage_id, self.__server.invalidate(self, self.__storage_id,
self.__invalidated, self.__invalidated,
self.get_size_info()) self.get_size_info())
self.__invalidated=[]
if not self._handle_waiting(): def init_storage(storage):
self._transaction = None if not hasattr(storage,'tpc_vote'): storage.tpc_vote=lambda *args: None
self.__invalidated = []
def tpc_abort(self, id): if __name__=='__main__':
if not self._check_tid(id): import ZODB.FileStorage
return name, port = sys.argv[1:3]
r = self.__storage.tpc_abort(self._transaction) blather(name, port)
assert self.__storage._transaction is None try:
port='', int(port)
if not self._handle_waiting(): except:
self._transaction = None pass
self.__invalidated = []
def _restart_delayed_transaction(self, delay, trans):
self._transaction = trans
self.__storage.tpc_begin(trans)
self.__invalidated = []
assert self._transaction.id == self.__storage._transaction.id
delay.reply(None)
def _handle_waiting(self):
if self.__storage.__waiting:
delay, proxy, trans = self.__storage.__waiting.pop(0)
proxy._restart_delayed_transaction(delay, trans)
if self is proxy:
return 1
def new_oids(self, n=100): d = {'1': ZODB.FileStorage.FileStorage(name)}
"""Return a sequence of n new oids, where n defaults to 100""" StorageServer(port, d)
if n < 0: asyncwrap.loop()
n = 1
return [self.__storage.new_oid() for i in range(n)]
def fixup_storage(storage):
# backwards compatibility hack
if not hasattr(storage,'tpc_vote'):
storage.tpc_vote = lambda *args: None
"""A TransactionBuffer store transaction updates until commit or abort.
A transaction may generate enough data that it is not practical to
always hold pending updates in memory. Instead, a TransactionBuffer
is used to store the data until a commit or abort.
"""
# XXX Figure out what a sensible storage format is
# XXX A faster implementation might store trans data in memory until
# it reaches a certain size.
import tempfile
import cPickle
class TransactionBuffer:
def __init__(self):
self.file = tempfile.TemporaryFile()
self.count = 0
self.size = 0
# It's safe to use a fast pickler because the only objects
# stored are builtin types -- strings or None.
self.pickler = cPickle.Pickler(self.file, 1)
self.pickler.fast = 1
def store(self, oid, version, data):
"""Store oid, version, data for later retrieval"""
self.pickler.dump((oid, version, data))
self.count += 1
# Estimate per-record cache size
self.size = self.size + len(data) + (27 + 12)
if version:
self.size = self.size + len(version) + 4
def invalidate(self, oid, version):
self.pickler.dump((oid, version, None))
self.count += 1
def clear(self):
"""Mark the buffer as empty"""
self.file.seek(0)
self.count = 0
self.size = 0
# XXX unchecked constraints:
# 1. can't call store() after begin_iterate()
# 2. must call clear() after iteration finishes
def begin_iterate(self):
"""Move the file pointer in advance of iteration"""
self.file.flush()
self.file.seek(0)
self.unpickler = cPickle.Unpickler(self.file)
def next(self):
"""Return next tuple of data or None if EOF"""
if self.count == 0:
del self.unpickler
return None
oid_ver_data = self.unpickler.load()
self.count -= 1
return oid_ver_data
def get_size(self):
"""Return size of data stored in buffer (just a hint)."""
return self.size
...@@ -85,14 +85,11 @@ ...@@ -85,14 +85,11 @@
"""Sized message async connections """Sized message async connections
""" """
__version__ = "$Revision: 1.12 $"[11:-2] __version__ = "$Revision: 1.13 $"[11:-2]
import asyncore, struct
from Exceptions import Disconnected
from zLOG import LOG, TRACE, ERROR, INFO, BLATHER
from types import StringType
import asyncore, string, struct, zLOG, sys, Acquisition
import socket, errno import socket, errno
from zLOG import LOG, TRACE, ERROR, INFO
# Use the dictionary to make sure we get the minimum number of errno # Use the dictionary to make sure we get the minimum number of errno
# entries. We expect that EWOULDBLOCK == EAGAIN on most systems -- # entries. We expect that EWOULDBLOCK == EAGAIN on most systems --
...@@ -112,101 +109,81 @@ tmp_dict = {errno.EAGAIN: 0, ...@@ -112,101 +109,81 @@ tmp_dict = {errno.EAGAIN: 0,
expected_socket_write_errors = tuple(tmp_dict.keys()) expected_socket_write_errors = tuple(tmp_dict.keys())
del tmp_dict del tmp_dict
class SizedMessageAsyncConnection(asyncore.dispatcher): class SizedMessageAsyncConnection(Acquisition.Explicit, asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close
__closed = 1 # Marker indicating that we're closed
socket = None # to outwit Sam's getattr __append=None # Marker indicating that we're closed
READ_SIZE = 8096 socket=None # to outwit Sam's getattr
def __init__(self, sock, addr, map=None, debug=None): def __init__(self, sock, addr, map=None, debug=None):
self.__super_init(sock, map) SizedMessageAsyncConnection.inheritedAttribute(
self.addr = addr '__init__')(self, sock, map)
self.addr=addr
if debug is not None: if debug is not None:
self._debug = debug self._debug=debug
elif not hasattr(self, '_debug'): elif not hasattr(self, '_debug'):
self._debug = __debug__ and 'smac' self._debug=__debug__ and 'smac'
self.__state = None self.__state=None
self.__inp = None # None, a single String, or a list self.__inp=None
self.__input_len = 0 self.__inpl=0
self.__msg_size = 4 self.__l=4
self.__output = [] self.__output=output=[]
self.__closed = None self.__append=output.append
self.__pop=output.pop
# XXX avoid expensive getattr calls?
def __nonzero__(self): def handle_read(self,
return 1 join=string.join, StringType=type(''), _type=type,
_None=None):
def handle_read(self):
# Use a single __inp buffer and integer indexes to make this
# fast.
try: try:
d=self.recv(8096) d=self.recv(8096)
except socket.error, err: except socket.error, err:
if err[0] in expected_socket_read_errors: if err[0] in expected_socket_read_errors:
return return
raise raise
if not d: if not d: return
return
input_len = self.__input_len + len(d)
msg_size = self.__msg_size
state = self.__state
inp = self.__inp inp=self.__inp
if msg_size > input_len: if inp is _None:
if inp is None: inp=d
self.__inp = d elif _type(inp) is StringType:
elif type(self.__inp) is StringType: inp=[inp,d]
self.__inp = [self.__inp, d]
else:
self.__inp.append(d)
self.__input_len = input_len
return # keep waiting for more input
# load all previous input and d into single string inp
if isinstance(inp, StringType):
inp = inp + d
elif inp is None:
inp = d
else: else:
inp.append(d) inp.append(d)
inp = "".join(inp)
offset = 0 inpl=self.__inpl+len(d)
while (offset + msg_size) <= input_len: l=self.__l
msg = inp[offset:offset + msg_size]
offset = offset + msg_size while 1:
if state is None:
if l <= inpl:
# Woo hoo, we have enough data
if _type(inp) is not StringType: inp=join(inp,'')
d=inp[:l]
inp=inp[l:]
inpl=inpl-l
if self.__state is _None:
# waiting for message # waiting for message
msg_size = struct.unpack(">i", msg)[0] l=struct.unpack(">i",d)[0]
state = 1 self.__state=1
else: else:
msg_size = 4 l=4
state = None self.__state=_None
self.message_input(msg) self.message_input(d)
else:
self.__state = state break # not enough data
self.__msg_size = msg_size
self.__inp = inp[offset:]
self.__input_len = input_len - offset
def readable(self): self.__l=l
return 1 self.__inp=inp
self.__inpl=inpl
def writable(self): def readable(self): return 1
if len(self.__output) == 0: def writable(self): return not not self.__output
return 0
else:
return 1
def handle_write(self): def handle_write(self):
output = self.__output output=self.__output
while output: while output:
v = output[0] v=output[0]
try: try:
n=self.send(v) n=self.send(v)
except socket.error, err: except socket.error, err:
...@@ -214,33 +191,42 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -214,33 +191,42 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
break # we couldn't write anything break # we couldn't write anything
raise raise
if n < len(v): if n < len(v):
output[0] = v[n:] output[0]=v[n:]
break # we can't write any more break # we can't write any more
else: else:
del output[0] del output[0]
#break # waaa
def handle_close(self): def handle_close(self):
self.close() self.close()
def message_output(self, message): def message_output(self, message,
if __debug__: pack=struct.pack, len=len):
if self._debug: if self._debug:
if len(message) > 40: if len(message) > 40: m=message[:40]+' ...'
m = message[:40]+' ...' else: m=message
else:
m = message
LOG(self._debug, TRACE, 'message_output %s' % `m`) LOG(self._debug, TRACE, 'message_output %s' % `m`)
if self.__closed is not None: append=self.__append
raise Disconnected, ( if append is None:
"This action is temporarily unavailable." raise Disconnected("This action is temporarily unavailable.<p>")
"<p>"
) append(pack(">i",len(message))+message)
# do two separate appends to avoid copying the message string
self.__output.append(struct.pack(">i", len(message))) def log_info(self, message, type='info'):
self.__output.append(message) if type=='error': type=ERROR
else: type=INFO
LOG('ZEO', type, message)
log=log_info
def close(self): def close(self):
if self.__closed is None: if self.__append is not None:
self.__closed = 1 self.__append=None
self.__super_close() SizedMessageAsyncConnection.inheritedAttribute('close')(self)
class Disconnected(Exception):
"""The client has become disconnected from the server
"""
...@@ -86,13 +86,10 @@ ...@@ -86,13 +86,10 @@
"""Start the server storage. """Start the server storage.
""" """
__version__ = "$Revision: 1.27 $"[11:-2] __version__ = "$Revision: 1.28 $"[11:-2]
import sys, os, getopt, string import sys, os, getopt, string
import StorageServer
import asyncore
def directory(p, n=1): def directory(p, n=1):
d=p d=p
while n: while n:
...@@ -118,11 +115,9 @@ def get_storage(m, n, cache={}): ...@@ -118,11 +115,9 @@ def get_storage(m, n, cache={}):
def main(argv): def main(argv):
me=argv[0] me=argv[0]
sys.path[:]==filter(None, sys.path)
sys.path.insert(0, directory(me, 2)) sys.path.insert(0, directory(me, 2))
# XXX hack for profiling support
global unix, storages, zeo_pid, asyncore
args=[] args=[]
last='' last=''
for a in argv[1:]: for a in argv[1:]:
...@@ -135,13 +130,25 @@ def main(argv): ...@@ -135,13 +130,25 @@ def main(argv):
args.append(a) args.append(a)
last=a last=a
INSTANCE_HOME=os.environ.get('INSTANCE_HOME', directory(me, 4)) if os.environ.has_key('INSTANCE_HOME'):
INSTANCE_HOME=os.environ['INSTANCE_HOME']
elif os.path.isdir(os.path.join(directory(me, 4),'var')):
INSTANCE_HOME=directory(me, 4)
else:
INSTANCE_HOME=os.getcwd()
if os.path.isdir(os.path.join(INSTANCE_HOME, 'var')):
var=os.path.join(INSTANCE_HOME, 'var')
else:
var=INSTANCE_HOME
zeo_pid=os.environ.get('ZEO_SERVER_PID', zeo_pid=os.environ.get('ZEO_SERVER_PID',
os.path.join(INSTANCE_HOME, 'var', 'ZEO_SERVER.pid') os.path.join(var, 'ZEO_SERVER.pid')
) )
fs=os.path.join(INSTANCE_HOME, 'var', 'Data.fs') opts, args = getopt.getopt(args, 'p:Ddh:U:sS:u:')
fs=os.path.join(var, 'Data.fs')
usage="""%s [options] [filename] usage="""%s [options] [filename]
...@@ -149,14 +156,17 @@ def main(argv): ...@@ -149,14 +156,17 @@ def main(argv):
-D -- Run in debug mode -D -- Run in debug mode
-d -- Generate detailed debug logging without running
in the foreground.
-U -- Unix-domain socket file to listen on -U -- Unix-domain socket file to listen on
-u username or uid number -u username or uid number
The username to run the ZEO server as. You may want to run The username to run the ZEO server as. You may want to run
the ZEO server as 'nobody' or some other user with limited the ZEO server as 'nobody' or some other user with limited
resouces. The only works under Unix, and if ZServer is resouces. The only works under Unix, and if the storage
started by root. server is started by root.
-p port -- port to listen on -p port -- port to listen on
...@@ -179,42 +189,23 @@ def main(argv): ...@@ -179,42 +189,23 @@ def main(argv):
attr_name -- This is the name to which the storage object attr_name -- This is the name to which the storage object
is assigned in the module. is assigned in the module.
-P file -- Run under profile and dump output to file. Implies the
-s flag.
if no file name is specified, then %s is used. if no file name is specified, then %s is used.
""" % (me, fs) """ % (me, fs)
try:
opts, args = getopt.getopt(args, 'p:Dh:U:sS:u:P:')
except getopt.error, msg:
print usage
print msg
sys.exit(1)
port=None port=None
debug=0 debug=detailed=0
host='' host=''
unix=None unix=None
Z=1 Z=1
UID='nobody' UID='nobody'
prof = None
for o, v in opts: for o, v in opts:
if o=='-p': port=string.atoi(v) if o=='-p': port=string.atoi(v)
elif o=='-h': host=v elif o=='-h': host=v
elif o=='-U': unix=v elif o=='-U': unix=v
elif o=='-u': UID=v elif o=='-u': UID=v
elif o=='-D': debug=1 elif o=='-D': debug=1
elif o=='-d': detailed=1
elif o=='-s': Z=0 elif o=='-s': Z=0
elif o=='-P': prof = v
if prof:
Z = 0
try:
from ZServer.medusa import asyncore
sys.modules['asyncore']=asyncore
except: pass
if port is None and unix is None: if port is None and unix is None:
print usage print usage
...@@ -228,9 +219,10 @@ def main(argv): ...@@ -228,9 +219,10 @@ def main(argv):
sys.exit(1) sys.exit(1)
fs=args[0] fs=args[0]
__builtins__.__debug__=debug
if debug: os.environ['Z_DEBUG_MODE']='1' if debug: os.environ['Z_DEBUG_MODE']='1'
if detailed: os.environ['STUPID_LOG_SEVERITY']='-99999'
from zLOG import LOG, INFO, ERROR from zLOG import LOG, INFO, ERROR
# Try to set uid to "-u" -provided uid. # Try to set uid to "-u" -provided uid.
...@@ -271,6 +263,10 @@ def main(argv): ...@@ -271,6 +263,10 @@ def main(argv):
import zdaemon import zdaemon
zdaemon.run(sys.argv, '') zdaemon.run(sys.argv, '')
try:
import ZEO.StorageServer, asyncore
storages={} storages={}
for o, v in opts: for o, v in opts:
if o=='-S': if o=='-S':
...@@ -297,9 +293,10 @@ def main(argv): ...@@ -297,9 +293,10 @@ def main(argv):
signal.signal(signal.SIGINT, signal.signal(signal.SIGINT,
lambda sig, frame, s=storages: shutdown(s, 0) lambda sig, frame, s=storages: shutdown(s, 0)
) )
signal.signal(signal.SIGHUP, rotate_logs_handler) try: signal.signal(signal.SIGHUP, rotate_logs_handler)
except: pass
finally: pass except: pass
items=storages.items() items=storages.items()
items.sort() items.sort()
...@@ -308,25 +305,40 @@ def main(argv): ...@@ -308,25 +305,40 @@ def main(argv):
if not unix: unix=host, port if not unix: unix=host, port
if prof: ZEO.StorageServer.StorageServer(unix, storages)
cmds = \
"StorageServer.StorageServer(unix, storages);" \ try: ppid, pid = os.getppid(), os.getpid()
'open(zeo_pid,"w").write("%s %s" % (os.getppid(), os.getpid()));' \ except: pass # getpid not supported
"asyncore.loop()" else: open(zeo_pid,'w').write("%s %s" % (ppid, pid))
import profile
profile.run(cmds, prof) except:
else: # Log startup exception and tell zdaemon not to restart us.
StorageServer.StorageServer(unix, storages) info=sys.exc_info()
open(zeo_pid,'w').write("%s %s" % (os.getppid(), os.getpid())) try:
import zLOG
zLOG.LOG("z2", zLOG.PANIC, "Startup exception",
error=info)
except:
pass
import traceback
apply(traceback.print_exception, info)
sys.exit(0)
asyncore.loop() asyncore.loop()
def rotate_logs(): def rotate_logs():
import zLOG import zLOG
if hasattr(zLOG.log_write, 'reinitialize'): if hasattr(zLOG.log_write, 'reinitialize'):
zLOG.log_write.reinitialize() zLOG.log_write.reinitialize()
else: else:
# Hm, lets at least try to take care of the stupid logger: # Hm, lets at least try to take care of the stupid logger:
zLOG._stupid_dest=None if hasattr(zLOG, '_set_stupid_dest'):
zLOG._set_stupid_dest(None)
else:
zLOG._stupid_dest = None
def rotate_logs_handler(signum, frame): def rotate_logs_handler(signum, frame):
rotate_logs() rotate_logs()
...@@ -347,7 +359,7 @@ def shutdown(storages, die=1): ...@@ -347,7 +359,7 @@ def shutdown(storages, die=1):
for storage in storages.values(): for storage in storages.values():
try: storage.close() try: storage.close()
finally: pass except: pass
try: try:
from zLOG import LOG, INFO from zLOG import LOG, INFO
......
# Copyright (c) 2001 Zope Corporation and Contributors. All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 1.1 (ZPL). A copy of the ZPL should accompany this
# distribution. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL
# EXPRESS OR IMPLIED WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST
# INFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
"""Library for forking storage server and connecting client storage""" """Library for forking storage server and connecting client storage"""
import asyncore import asyncore
import os import os
import profile
import random import random
import socket import socket
import sys import sys
import traceback
import types import types
import ZEO.ClientStorage, ZEO.StorageServer import ZEO.ClientStorage, ZEO.StorageServer
# Change value of PROFILE to enable server-side profiling
PROFILE = 0 PROFILE = 0
if PROFILE:
import hotshot
def get_port(): def get_port():
"""Return a port that is not in use. """Return a port that is not in use.
...@@ -78,11 +66,9 @@ else: ...@@ -78,11 +66,9 @@ else:
buf = self.recv(4) buf = self.recv(4)
if buf: if buf:
assert buf == "done" assert buf == "done"
server.close_server()
asyncore.socket_map.clear() asyncore.socket_map.clear()
def handle_close(self): def handle_close(self):
server.close_server()
asyncore.socket_map.clear() asyncore.socket_map.clear()
class ZEOClientExit: class ZEOClientExit:
...@@ -91,27 +77,20 @@ else: ...@@ -91,27 +77,20 @@ else:
self.pipe = pipe self.pipe = pipe
def close(self): def close(self):
try:
os.write(self.pipe, "done") os.write(self.pipe, "done")
os.close(self.pipe) os.close(self.pipe)
except os.error:
pass
def start_zeo_server(storage, addr): def start_zeo_server(storage, addr):
rd, wr = os.pipe() rd, wr = os.pipe()
pid = os.fork() pid = os.fork()
if pid == 0: if pid == 0:
try:
if PROFILE: if PROFILE:
p = hotshot.Profile("stats.s.%d" % os.getpid()) p = profile.Profile()
p.runctx("run_server(storage, addr, rd, wr)", p.runctx("run_server(storage, addr, rd, wr)", globals(),
globals(), locals()) locals())
p.close() p.dump_stats("stats.s.%d" % os.getpid())
else: else:
run_server(storage, addr, rd, wr) run_server(storage, addr, rd, wr)
except:
print "Exception in ZEO server process"
traceback.print_exc()
os._exit(0) os._exit(0)
else: else:
os.close(rd) os.close(rd)
...@@ -119,11 +98,11 @@ else: ...@@ -119,11 +98,11 @@ else:
def run_server(storage, addr, rd, wr): def run_server(storage, addr, rd, wr):
# in the child, run the storage server # in the child, run the storage server
global server
os.close(wr) os.close(wr)
ZEOServerExit(rd) ZEOServerExit(rd)
server = ZEO.StorageServer.StorageServer(addr, {'1':storage}) serv = ZEO.StorageServer.StorageServer(addr, {'1':storage})
asyncore.loop() asyncore.loop()
os.close(rd)
storage.close() storage.close()
if isinstance(addr, types.StringType): if isinstance(addr, types.StringType):
os.unlink(addr) os.unlink(addr)
...@@ -149,7 +128,6 @@ else: ...@@ -149,7 +128,6 @@ else:
s = ZEO.ClientStorage.ClientStorage(addr, storage_id, s = ZEO.ClientStorage.ClientStorage(addr, storage_id,
debug=1, client=cache, debug=1, client=cache,
cache_size=cache_size, cache_size=cache_size,
min_disconnect_poll=0.5, min_disconnect_poll=0.5)
wait_for_server_on_startup=1)
return s, exit, pid return s, exit, pid
import random
import unittest
from ZEO.TransactionBuffer import TransactionBuffer
def random_string(size):
"""Return a random string of size size."""
l = [chr(random.randrange(256)) for i in range(size)]
return "".join(l)
def new_store_data():
"""Return arbitrary data to use as argument to store() method."""
return random_string(8), '', random_string(random.randrange(1000))
def new_invalidate_data():
"""Return arbitrary data to use as argument to invalidate() method."""
return random_string(8), ''
class TransBufTests(unittest.TestCase):
def checkTypicalUsage(self):
tbuf = TransactionBuffer()
tbuf.store(*new_store_data())
tbuf.invalidate(*new_invalidate_data())
tbuf.begin_iterate()
while 1:
o = tbuf.next()
if o is None:
break
tbuf.clear()
def doUpdates(self, tbuf):
data = []
for i in range(10):
d = new_store_data()
tbuf.store(*d)
data.append(d)
d = new_invalidate_data()
tbuf.invalidate(*d)
data.append(d)
tbuf.begin_iterate()
for i in range(len(data)):
x = tbuf.next()
if x[2] is None:
# the tbuf add a dummy None to invalidates
x = x[:2]
self.assertEqual(x, data[i])
def checkOrderPreserved(self):
tbuf = TransactionBuffer()
self.doUpdates(tbuf)
def checkReusable(self):
tbuf = TransactionBuffer()
self.doUpdates(tbuf)
tbuf.clear()
self.doUpdates(tbuf)
tbuf.clear()
self.doUpdates(tbuf)
def test_suite():
return unittest.makeSuite(TransBufTests, 'check')
...@@ -85,14 +85,11 @@ ...@@ -85,14 +85,11 @@
"""Sized message async connections """Sized message async connections
""" """
__version__ = "$Revision: 1.12 $"[11:-2] __version__ = "$Revision: 1.13 $"[11:-2]
import asyncore, struct
from Exceptions import Disconnected
from zLOG import LOG, TRACE, ERROR, INFO, BLATHER
from types import StringType
import asyncore, string, struct, zLOG, sys, Acquisition
import socket, errno import socket, errno
from zLOG import LOG, TRACE, ERROR, INFO
# Use the dictionary to make sure we get the minimum number of errno # Use the dictionary to make sure we get the minimum number of errno
# entries. We expect that EWOULDBLOCK == EAGAIN on most systems -- # entries. We expect that EWOULDBLOCK == EAGAIN on most systems --
...@@ -112,101 +109,81 @@ tmp_dict = {errno.EAGAIN: 0, ...@@ -112,101 +109,81 @@ tmp_dict = {errno.EAGAIN: 0,
expected_socket_write_errors = tuple(tmp_dict.keys()) expected_socket_write_errors = tuple(tmp_dict.keys())
del tmp_dict del tmp_dict
class SizedMessageAsyncConnection(asyncore.dispatcher): class SizedMessageAsyncConnection(Acquisition.Explicit, asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close
__closed = 1 # Marker indicating that we're closed
socket = None # to outwit Sam's getattr __append=None # Marker indicating that we're closed
READ_SIZE = 8096 socket=None # to outwit Sam's getattr
def __init__(self, sock, addr, map=None, debug=None): def __init__(self, sock, addr, map=None, debug=None):
self.__super_init(sock, map) SizedMessageAsyncConnection.inheritedAttribute(
self.addr = addr '__init__')(self, sock, map)
self.addr=addr
if debug is not None: if debug is not None:
self._debug = debug self._debug=debug
elif not hasattr(self, '_debug'): elif not hasattr(self, '_debug'):
self._debug = __debug__ and 'smac' self._debug=__debug__ and 'smac'
self.__state = None self.__state=None
self.__inp = None # None, a single String, or a list self.__inp=None
self.__input_len = 0 self.__inpl=0
self.__msg_size = 4 self.__l=4
self.__output = [] self.__output=output=[]
self.__closed = None self.__append=output.append
self.__pop=output.pop
# XXX avoid expensive getattr calls?
def __nonzero__(self): def handle_read(self,
return 1 join=string.join, StringType=type(''), _type=type,
_None=None):
def handle_read(self):
# Use a single __inp buffer and integer indexes to make this
# fast.
try: try:
d=self.recv(8096) d=self.recv(8096)
except socket.error, err: except socket.error, err:
if err[0] in expected_socket_read_errors: if err[0] in expected_socket_read_errors:
return return
raise raise
if not d: if not d: return
return
input_len = self.__input_len + len(d)
msg_size = self.__msg_size
state = self.__state
inp = self.__inp inp=self.__inp
if msg_size > input_len: if inp is _None:
if inp is None: inp=d
self.__inp = d elif _type(inp) is StringType:
elif type(self.__inp) is StringType: inp=[inp,d]
self.__inp = [self.__inp, d]
else:
self.__inp.append(d)
self.__input_len = input_len
return # keep waiting for more input
# load all previous input and d into single string inp
if isinstance(inp, StringType):
inp = inp + d
elif inp is None:
inp = d
else: else:
inp.append(d) inp.append(d)
inp = "".join(inp)
offset = 0 inpl=self.__inpl+len(d)
while (offset + msg_size) <= input_len: l=self.__l
msg = inp[offset:offset + msg_size]
offset = offset + msg_size while 1:
if state is None:
if l <= inpl:
# Woo hoo, we have enough data
if _type(inp) is not StringType: inp=join(inp,'')
d=inp[:l]
inp=inp[l:]
inpl=inpl-l
if self.__state is _None:
# waiting for message # waiting for message
msg_size = struct.unpack(">i", msg)[0] l=struct.unpack(">i",d)[0]
state = 1 self.__state=1
else: else:
msg_size = 4 l=4
state = None self.__state=_None
self.message_input(msg) self.message_input(d)
else:
self.__state = state break # not enough data
self.__msg_size = msg_size
self.__inp = inp[offset:]
self.__input_len = input_len - offset
def readable(self): self.__l=l
return 1 self.__inp=inp
self.__inpl=inpl
def writable(self): def readable(self): return 1
if len(self.__output) == 0: def writable(self): return not not self.__output
return 0
else:
return 1
def handle_write(self): def handle_write(self):
output = self.__output output=self.__output
while output: while output:
v = output[0] v=output[0]
try: try:
n=self.send(v) n=self.send(v)
except socket.error, err: except socket.error, err:
...@@ -214,33 +191,42 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -214,33 +191,42 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
break # we couldn't write anything break # we couldn't write anything
raise raise
if n < len(v): if n < len(v):
output[0] = v[n:] output[0]=v[n:]
break # we can't write any more break # we can't write any more
else: else:
del output[0] del output[0]
#break # waaa
def handle_close(self): def handle_close(self):
self.close() self.close()
def message_output(self, message): def message_output(self, message,
if __debug__: pack=struct.pack, len=len):
if self._debug: if self._debug:
if len(message) > 40: if len(message) > 40: m=message[:40]+' ...'
m = message[:40]+' ...' else: m=message
else:
m = message
LOG(self._debug, TRACE, 'message_output %s' % `m`) LOG(self._debug, TRACE, 'message_output %s' % `m`)
if self.__closed is not None: append=self.__append
raise Disconnected, ( if append is None:
"This action is temporarily unavailable." raise Disconnected("This action is temporarily unavailable.<p>")
"<p>"
) append(pack(">i",len(message))+message)
# do two separate appends to avoid copying the message string
self.__output.append(struct.pack(">i", len(message))) def log_info(self, message, type='info'):
self.__output.append(message) if type=='error': type=ERROR
else: type=INFO
LOG('ZEO', type, message)
log=log_info
def close(self): def close(self):
if self.__closed is None: if self.__append is not None:
self.__closed = 1 self.__append=None
self.__super_close() SizedMessageAsyncConnection.inheritedAttribute('close')(self)
class Disconnected(Exception):
"""The client has become disconnected from the server
"""
"""RPC protocol for ZEO based on asyncore
The basic protocol is as:
a pickled tuple containing: msgid, flags, method, args
msgid is an integer.
flags is an integer.
The only currently defined flag is ASYNC (0x1), which means
the client does not expect a reply.
method is a string specifying the method to invoke.
For a reply, the method is ".reply".
args is a tuple of the argument to pass to method.
XXX need to specify a version number that describes the protocol.
allow for future revision.
XXX support multiple outstanding calls
XXX factor out common pattern of deciding what protocol to use based
on whether address is tuple or string
"""
import asyncore
import errno
import cPickle
import os
import select
import socket
import sys
import threading
import thread
import time
import traceback
import types
from cStringIO import StringIO
from ZODB import POSException
from ZEO import smac, trigger
from Exceptions import Disconnected
import zLOG
import ThreadedAsync
from Exceptions import Disconnected
REPLY = ".reply" # message name used for replies
ASYNC = 1
_label = "zrpc:%s" % os.getpid()
def new_label():
global _label
_label = "zrpc:%s" % os.getpid()
def log(message, level=zLOG.BLATHER, label=None, error=None):
zLOG.LOG(label or _label, level, message, error=error)
class ZRPCError(POSException.StorageError):
pass
class DecodingError(ZRPCError):
"""A ZRPC message could not be decoded."""
class DisconnectedError(ZRPCError, Disconnected):
"""The database storage is disconnected from the storage server."""
# Export the mainloop function from asycnore to zrpc clients
loop = asyncore.loop
def connect(addr, client=None):
if type(addr) == types.TupleType:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
else:
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.connect(addr)
c = Connection(s, addr, client)
return c
class Marshaller:
"""Marshal requests and replies to second across network"""
# It's okay to share a single Pickler as long as it's in fast
# mode, which means that it doesn't have a memo.
pickler = cPickle.Pickler()
pickler.fast = 1
pickle = pickler.dump
errors = (cPickle.UnpickleableError,
cPickle.UnpicklingError,
cPickle.PickleError,
cPickle.PicklingError)
def encode(self, msgid, flags, name, args):
"""Returns an encoded message"""
return self.pickle((msgid, flags, name, args), 1)
def decode(self, msg):
"""Decodes msg and returns its parts"""
unpickler = cPickle.Unpickler(StringIO(msg))
unpickler.find_global = find_global
try:
return unpickler.load() # msgid, flags, name, args
except (cPickle.UnpicklingError, IndexError), err_msg:
log("can't decode %s" % repr(msg), level=zLOG.ERROR)
raise DecodingError(msg)
class Delay:
"""Used to delay response to client for synchronous calls
When a synchronous call is made and the original handler returns
without handling the call, it returns a Delay object that prevents
the mainloop from sending a response.
"""
def set_sender(self, msgid, send_reply):
self.msgid = msgid
self.send_reply = send_reply
def reply(self, obj):
self.send_reply(self.msgid, obj)
class Connection(smac.SizedMessageAsyncConnection):
"""Dispatcher for RPC on object
The connection supports synchronous calls, which expect a return,
and asynchronous calls that do not.
It uses the Marshaller class to handle encoding and decoding of
method calls are arguments.
A Connection is designed for use in a multithreaded application,
where a synchronous call must block until a response is ready.
The current design only allows a single synchronous call to be
outstanding.
"""
__super_init = smac.SizedMessageAsyncConnection.__init__
__super_close = smac.SizedMessageAsyncConnection.close
__super_writable = smac.SizedMessageAsyncConnection.writable
def __init__(self, sock, addr, obj=None):
self.msgid = 0
self.obj = obj
self.marshal = Marshaller()
self.closed = 0
self.async = 0
# The reply lock is used to block when a synchronous call is
# waiting for a response
self.__super_init(sock, addr)
self._map = {self._fileno: self}
self._prepare_async()
self.__call_lock = thread.allocate_lock()
self.__reply_lock = thread.allocate_lock()
self.__reply_lock.acquire()
if isinstance(obj, Handler):
self.set_caller = 1
else:
self.set_caller = 0
def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.addr)
def close(self):
if self.closed:
return
self.closed = 1
self.__super_close()
def register_object(self, obj):
"""Register obj as the true object to invoke methods on"""
self.obj = obj
def message_input(self, message):
"""Decoding an incoming message and dispatch it"""
# XXX Not sure what to do with errors that reach this level.
# Need to catch ZRPCErrors in handle_reply() and
# handle_request() so that they get back to the client.
try:
msgid, flags, name, args = self.marshal.decode(message)
except DecodingError, msg:
return self.return_error(None, None, sys.exc_info()[0],
sys.exc_info()[1])
if __debug__:
log("recv msg: %s, %s, %s, %s" % (msgid, flags, name,
repr(args)[:40]),
level=zLOG.DEBUG)
if name == REPLY:
self.handle_reply(msgid, flags, args)
else:
self.handle_request(msgid, flags, name, args)
def handle_reply(self, msgid, flags, args):
if __debug__:
log("recv reply: %s, %s, %s" % (msgid, flags, str(args)[:40]),
level=zLOG.DEBUG)
self.__reply = msgid, flags, args
self.__reply_lock.release() # will fail if lock is unlocked
def handle_request(self, msgid, flags, name, args):
if __debug__:
log("call %s%s on %s" % (name, repr(args)[:40], repr(self.obj)),
zLOG.DEBUG)
if not self.check_method(name):
raise ZRPCError("Invalid method name: %s on %s" % (name,
`self.obj`))
meth = getattr(self.obj, name)
try:
if self.set_caller:
self.obj.set_caller(self)
try:
ret = meth(*args)
finally:
self.obj.clear_caller()
else:
ret = meth(*args)
except (POSException.UndoError,
POSException.VersionCommitError), msg:
error = sys.exc_info()
log("%s() raised exception: %s" % (name, msg), zLOG.ERROR, error)
return self.return_error(msgid, flags, error[0], error[1])
except Exception, msg:
error = sys.exc_info()
log("%s() raised exception: %s" % (name, msg), zLOG.ERROR, error)
return self.return_error(msgid, flags, error[0], error[1])
if flags & ASYNC:
if ret is not None:
log("async method %s returned value %s" % (name, repr(ret)),
zLOG.ERROR)
raise ZRPCError("async method returned value")
else:
if __debug__:
log("%s return %s" % (name, repr(ret)[:40]), zLOG.DEBUG)
if isinstance(ret, Delay):
ret.set_sender(msgid, self.send_reply)
else:
self.send_reply(msgid, ret)
def handle_error(self):
self.log_error()
self.close()
def log_error(self, msg="No error message supplied"):
error = sys.exc_info()
log(msg, zLOG.ERROR, error=error)
del error
def check_method(self, name):
# XXX minimal security check should go here: Is name exported?
return hasattr(self.obj, name)
def send_reply(self, msgid, ret):
msg = self.marshal.encode(msgid, 0, REPLY, ret)
self.message_output(msg)
def return_error(self, msgid, flags, err_type, err_value):
if flags is None:
self.log_error("Exception raised during decoding")
return
if flags & ASYNC:
self.log_error("Asynchronous call raised exception: %s" % self)
return
if type(err_value) is not types.InstanceType:
err_value = err_type, err_value
try:
msg = self.marshal.encode(msgid, 0, REPLY, (err_type, err_value))
except self.marshal.errors:
err = ZRPCError("Couldn't pickle error %s" % `err_value`)
msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
self._do_io()
# The next two methods are used by clients to invoke methods on
# remote objects
# XXX Should revise design to allow multiple outstanding
# synchronous calls
def call(self, method, *args):
self.__call_lock.acquire()
try:
return self._call(method, args)
finally:
self.__call_lock.release()
def _call(self, method, args):
if self.closed:
raise DisconnectedError("This action is temporarily unavailable")
msgid = self.msgid
self.msgid = self.msgid + 1
if __debug__:
log("send msg: %d, 0, %s, ..." % (msgid, method))
self.message_output(self.marshal.encode(msgid, 0, method, args))
self.__reply = None
# lock is currently held
self._do_io(wait=1)
# lock is held again...
r_msgid, r_flags, r_args = self.__reply
self.__reply_lock.acquire()
assert r_msgid == msgid, "%s != %s: %s" % (r_msgid, msgid, r_args)
if type(r_args) == types.TupleType \
and type(r_args[0]) == types.ClassType \
and issubclass(r_args[0], Exception):
raise r_args[1] # error raised by server
return r_args
def callAsync(self, method, *args):
self.__call_lock.acquire()
try:
self._callAsync(method, args)
finally:
self.__call_lock.release()
def _callAsync(self, method, args):
if self.closed:
raise DisconnectedError("This action is temporarily unavailable")
msgid = self.msgid
self.msgid += 1
if __debug__:
log("send msg: %d, %d, %s, ..." % (msgid, ASYNC, method))
self.message_output(self.marshal.encode(msgid, ASYNC, method, args))
self._do_io()
# handle IO, possibly in async mode
def sync(self):
pass # XXX what is this supposed to do?
def _prepare_async(self):
self._async = 0
ThreadedAsync.register_loop_callback(self.set_async)
# XXX If we are not in async mode, this will cause dead
# Connections to be leaked.
def set_async(self, map):
# XXX do we need a lock around this? I'm not sure there is
# any harm to a race with _do_io().
self._async = 1
self.trigger = trigger.trigger()
def is_async(self):
return self._async
def _do_io(self, wait=0): # XXX need better name
# XXX invariant? lock must be held when calling with wait==1
# otherwise, in non-async mode, there will be no poll
if __debug__:
log("_do_io(wait=%d), async=%d" % (wait, self.is_async()),
level=zLOG.DEBUG)
if self.is_async():
self.trigger.pull_trigger()
if wait:
self.__reply_lock.acquire()
# wait until reply...
self.__reply_lock.release()
else:
if wait:
# do loop only if lock is already acquired
while not self.__reply_lock.acquire(0):
asyncore.poll(10.0, self._map)
if self.closed:
raise Disconnected()
self.__reply_lock.release()
else:
asyncore.poll(0.0, self._map)
# XXX it seems that we need to release before returning if
# called with wait==1. perhaps the caller need not acquire
# upon return...
class ServerConnection(Connection):
# XXX this is a hack
def _do_io(self, wait=0):
"""If this is a server, there is no explicit IO to do"""
pass
class ConnectionManager:
"""Keeps a connection up over time"""
# XXX requires that obj implement notifyConnected and
# notifyDisconnected. make this optional?
def __init__(self, addr, obj=None, debug=1, tmin=1, tmax=180):
self.set_addr(addr)
self.obj = obj
self.tmin = tmin
self.tmax = tmax
self.debug = debug
self.connected = 0
self.connection = None
# If _thread is not None, then there is a helper thread
# attempting to connect. _thread is protected by _connect_lock.
self._thread = None
self._connect_lock = threading.Lock()
self.trigger = None
self.async = 0
self.closed = 0
ThreadedAsync.register_loop_callback(self.set_async)
def __repr__(self):
return "<%s for %s>" % (self.__class__.__name__, self.addr)
def set_addr(self, addr):
"Set one or more addresses to use for server."
# For backwards compatibility (and simplicity?) the
# constructor accepts a single address in the addr argument --
# a string for a Unix domain socket or a 2-tuple with a
# hostname and port. It can also accept a list of such addresses.
addr_type = self._guess_type(addr)
if addr_type is not None:
self.addr = [(addr_type, addr)]
else:
self.addr = []
for a in addr:
addr_type = self._guess_type(a)
if addr_type is None:
raise ValueError, "unknown address in list: %s" % repr(a)
self.addr.append((addr_type, a))
def _guess_type(self, addr):
if isinstance(addr, types.StringType):
return socket.AF_UNIX
if (len(addr) == 2
and isinstance(addr[0], types.StringType)
and isinstance(addr[1], types.IntType)):
return socket.AF_INET
# not anything I know about
return None
def close(self):
"""Prevent ConnectionManager from opening new connections"""
self.closed = 1
self._connect_lock.acquire()
try:
if self._thread is not None:
self._thread.join()
finally:
self._connect_lock.release()
if self.connection:
self.connection.close()
def register_object(self, obj):
self.obj = obj
def set_async(self, map):
# XXX need each connection started with async==0 to have a callback
self.async = 1 # XXX needs to be set on the Connection
self.trigger = trigger.trigger()
def connect(self, sync=0):
if self.connected == 1:
return
self._connect_lock.acquire()
try:
if self._thread is None:
zLOG.LOG(_label, zLOG.BLATHER,
"starting thread to connect to server")
self._thread = threading.Thread(target=self.__m_connect)
self._thread.start()
if sync:
try:
self._thread.join()
except AttributeError:
# probably means the thread exited quickly
pass
finally:
self._connect_lock.release()
def attempt_connect(self):
# XXX will _attempt_connects() take too long? think select().
self._attempt_connects()
return self.connected
def notify_closed(self, conn):
self.connected = 0
self.connection = None
self.obj.notifyDisconnected()
if not self.closed:
self.connect()
class Connected(Exception):
def __init__(self, sock):
self.sock = sock
def __m_connect(self):
# a new __connect that handles multiple addresses
try:
delay = self.tmin
while not (self.closed or self._attempt_connects()):
time.sleep(delay)
delay *= 2
if delay > self.tmax:
delay = self.tmax
finally:
self._thread = None
def _attempt_connects(self):
"Return true if any connect attempt succeeds."
sockets = {}
zLOG.LOG(_label, zLOG.BLATHER,
"attempting connection on %d sockets" % len(self.addr))
try:
for domain, addr in self.addr:
if __debug__:
zLOG.LOG(_label, zLOG.DEBUG,
"attempt connection to %s" % repr(addr))
s = socket.socket(domain, socket.SOCK_STREAM)
s.setblocking(0)
# XXX can still block for a while if addr requires DNS
e = self._connect_ex(s, addr)
if e is not None:
sockets[s] = addr
# next wait until the actually connect
while sockets:
if self.closed:
for s in sockets.keys():
s.close()
return 0
try:
r, w, x = select.select([], sockets.keys(), [], 1.0)
except select.error:
continue
for s in w:
e = self._connect_ex(s, sockets[s])
if e is None:
del sockets[s]
except self.Connected, container:
s = container.sock
del sockets[s]
# close all the other sockets
for s in sockets.keys():
s.close()
return 1
return 0
def _connect_ex(self, s, addr):
"""Call s.connect_ex(addr) and return true if loop should continue.
We have to handle several possible return values from
connect_ex(). If the socket is connected and the initial ZEO
setup works, we're done. Report success by raising an
exception. Yes, the is odd, but we need to bail out of the
select() loop in the caller and an exception is a principled
way to do the abort.
If the socket sonnects and the initial ZEO setup fails or the
connect_ex() returns an error, we close the socket and ignore it.
If connect_ex() returns EINPROGRESS, we need to try again later.
"""
e = s.connect_ex(addr)
if e == errno.EINPROGRESS:
return 1
elif e == 0:
c = self._test_connection(s, addr)
zLOG.LOG(_label, zLOG.DEBUG, "connected to %s" % repr(addr))
if c:
self.connected = 1
raise self.Connected(s)
else:
if __debug__:
zLOG.LOG(_label, zLOG.DEBUG,
"error connecting to %s: %s" % (addr,
errno.errorcode[e]))
s.close()
def _test_connection(self, s, addr):
c = ManagedConnection(s, addr, self.obj, self)
try:
self.obj.notifyConnected(c)
self.connection = c
return 1
except:
# XXX zLOG the error
c.close()
return 0
class ManagedServerConnection(ServerConnection):
"""A connection that notifies its ConnectionManager of closing"""
__super_init = Connection.__init__
__super_close = Connection.close
def __init__(self, sock, addr, obj, mgr):
self.__mgr = mgr
self.__super_init(sock, addr, obj)
def close(self):
self.__super_close()
self.__mgr.close(self)
class ManagedConnection(Connection):
"""A connection that notifies its ConnectionManager of closing.
A managed connection also defers the ThreadedAsync work to its
manager.
"""
__super_init = Connection.__init__
__super_close = Connection.close
def __init__(self, sock, addr, obj, mgr):
self.__mgr = mgr
if self.__mgr.async:
self.__async = 1
self.trigger = self.__mgr.trigger
else:
self.__async = None
self.__super_init(sock, addr, obj)
def _prepare_async(self):
# Don't do the register_loop_callback that the superclass does
pass
def is_async(self):
if self.__async:
return 1
async = self.__mgr.async
if async:
self.__async = 1
self.trigger = self.__mgr.trigger
return async
def close(self):
self.__super_close()
self.__mgr.notify_closed(self)
class Dispatcher(asyncore.dispatcher):
"""A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__
reuse_addr = 1
def __init__(self, addr, obj=None, factory=Connection, reuse_addr=None):
self.__super_init()
self.addr = addr
self.obj = obj
self.factory = factory
self.clients = []
if reuse_addr is not None:
self.reuse_addr = reuse_addr
self._open_socket()
def _open_socket(self):
if type(self.addr) == types.TupleType:
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
else:
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.set_reuse_addr()
self.bind(self.addr)
self.listen(5)
def writable(self):
return 0
def readable(self):
return 1
def handle_accept(self):
try:
sock, addr = self.accept()
except socket.error, msg:
log("accepted failed: %s" % msg)
return
c = self.factory(sock, addr, self.obj)
log("connect from %s: %s" % (repr(addr), c))
self.clients.append(c)
class Handler:
"""Base class used to handle RPC caller discovery"""
def set_caller(self, addr):
self.__caller = addr
def get_caller(self):
return self.__caller
def clear_caller(self):
self.__caller = None
_globals = globals()
_silly = ('__doc__',)
def find_global(module, name):
"""Helper for message unpickler"""
try:
m = __import__(module, _globals, _globals, _silly)
except ImportError, msg:
raise ZRPCError("import error %s: %s" % (module, msg))
try:
r = getattr(m, name)
except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0)
if safe:
return r
if type(r) == types.ClassType and issubclass(r, Exception):
return r
raise ZRPCError("Unsafe global: %s.%s" % (module, name))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment