Commit 723ddd40 authored by Jim Fulton's avatar Jim Fulton

Refactored cache verification to fix threading bugs during connection.

Changed connections to work with unset (None) clients.  Messages
aren't forwarded until the client is set.  This is to prevent sending
spurious invalidation messages until a client is ready to receive them.
parent 1d61dc56
...@@ -491,6 +491,7 @@ class ClientStorage(object): ...@@ -491,6 +491,7 @@ class ClientStorage(object):
# If we are upgrading from a read-only fallback connection, # If we are upgrading from a read-only fallback connection,
# we must close the old connection to prevent it from being # we must close the old connection to prevent it from being
# used while the cache is verified against the new connection. # used while the cache is verified against the new connection.
self._connection.register_object(None) # Don't call me!
self._connection.close() self._connection.close()
self._connection = None self._connection = None
self._ready.clear() self._ready.clear()
...@@ -558,54 +559,6 @@ class ClientStorage(object): ...@@ -558,54 +559,6 @@ class ClientStorage(object):
else: else:
return '%s:%s' % (self._storage, self._server_addr) return '%s:%s' % (self._storage, self._server_addr)
def verify_cache(self, server):
    """Internal routine called to verify the cache.

    Decides between three verification strategies, from cheapest to
    most expensive, and returns a string naming the path taken.  The
    return value (indicating which path we took) is used by the test
    suite.

    Returns one of:
      "no verification"    -- cache's last tid matches the server's
      "quick verification" -- server replayed the missed invalidations
      "full verification"  -- every cached object is checked via zeoVerify
    """
    # If verify_cache() finishes the cache verification process,
    # it should set self._server.  If it goes through full cache
    # verification, then endVerify() should set self._server.
    last_inval_tid = self._cache.getLastTid()
    if last_inval_tid is not None:
        ltid = server.lastTransaction()
        if ltid == last_inval_tid:
            # Cache is already current; mark the client ready.
            log2("No verification necessary (last_inval_tid up-to-date)")
            self._server = server
            self._ready.set()
            return "no verification"

        # log some hints about last transaction
        log2("last inval tid: %r %s\n"
             % (last_inval_tid, tid2time(last_inval_tid)))
        log2("last transaction: %r %s" %
             (ltid, ltid and tid2time(ltid)))

        # Ask the server to replay the invalidations we missed while
        # disconnected; returns None if it can't (e.g. too far behind).
        pair = server.getInvalidations(last_inval_tid)
        if pair is not None:
            log2("Recovering %d invalidations" % len(pair[1]))
            self.invalidateTransaction(*pair)
            self._server = server
            self._ready.set()
            return "quick verification"

    log2("Verifying cache")
    # setup tempfile to hold zeoVerify results
    self._tfile = tempfile.TemporaryFile(suffix=".inv")
    self._pickler = cPickle.Pickler(self._tfile, 1)
    self._pickler.fast = 1  # Don't use the memo

    # TODO: should batch these operations for efficiency; would need
    # to acquire lock ...
    for oid, tid, version in self._cache.contents():
        server.verify(oid, version, tid)
    # Full verification is asynchronous: endVerify() (the server
    # callback) will promote _pending_server to _server when done.
    self._pending_server = server
    server.endZeoVerify()
    return "full verification"
### Is there a race condition between notifyConnected and ### Is there a race condition between notifyConnected and
### notifyDisconnected? In Particular, what if we get ### notifyDisconnected? In Particular, what if we get
### notifyDisconnected in the middle of notifyConnected? ### notifyDisconnected in the middle of notifyConnected?
...@@ -1162,7 +1115,7 @@ class ClientStorage(object): ...@@ -1162,7 +1115,7 @@ class ClientStorage(object):
return return
for oid, version, data in self._tbuf: for oid, version, data in self._tbuf:
self._cache.invalidate(oid, version, tid) self._cache.invalidate(oid, version, tid, False)
# If data is None, we just invalidate. # If data is None, we just invalidate.
if data is not None: if data is not None:
s = self._seriald[oid] s = self._seriald[oid]
...@@ -1224,8 +1177,6 @@ class ClientStorage(object): ...@@ -1224,8 +1177,6 @@ class ClientStorage(object):
"""Storage API: return a sequence of versions in the storage.""" """Storage API: return a sequence of versions in the storage."""
return self._server.versions(max) return self._server.versions(max)
# Below are methods invoked by the StorageServer
def serialnos(self, args): def serialnos(self, args):
"""Server callback to pass a list of changed (oid, serial) pairs.""" """Server callback to pass a list of changed (oid, serial) pairs."""
self._serials.extend(args) self._serials.extend(args)
...@@ -1234,6 +1185,57 @@ class ClientStorage(object): ...@@ -1234,6 +1185,57 @@ class ClientStorage(object):
"""Server callback to update the info dictionary.""" """Server callback to update the info dictionary."""
self._info.update(dict) self._info.update(dict)
def verify_cache(self, server):
    """Internal routine called to verify the cache.

    Decides between three verification strategies, from cheapest to
    most expensive, and returns a string naming the path taken.  The
    return value (indicating which path we took) is used by the test
    suite.

    Returns one of:
      "no verification"    -- cache's last tid matches the server's
      "quick verification" -- server replayed the missed invalidations
      "full verification"  -- every cached object is checked via zeoVerify

    NOTE(review): the pickler/tempfile are set up and the connection's
    object is registered *before* reading the cache's last tid, so any
    invalidations that arrive during verification are queued in the
    pickler rather than lost -- this ordering appears to be the point
    of this refactoring; preserve it.
    """
    self._pending_server = server

    # setup tempfile to hold zeoVerify results and interim
    # invalidation results
    self._tfile = tempfile.TemporaryFile(suffix=".inv")
    self._pickler = cPickle.Pickler(self._tfile, 1)
    self._pickler.fast = 1  # Don't use the memo

    # allow incoming invalidations:
    self._connection.register_object(self)

    # If verify_cache() finishes the cache verification process,
    # it should set self._server.  If it goes through full cache
    # verification, then endVerify() should set self._server.

    last_inval_tid = self._cache.getLastTid()
    if last_inval_tid is not None:
        ltid = server.lastTransaction()
        if ltid == last_inval_tid:
            # Cache is already current; mark the client ready.
            log2("No verification necessary (last_inval_tid up-to-date)")
            self.finish_verification()
            return "no verification"

        # log some hints about last transaction
        log2("last inval tid: %r %s\n"
             % (last_inval_tid, tid2time(last_inval_tid)))
        log2("last transaction: %r %s" %
             (ltid, ltid and tid2time(ltid)))

        # Ask the server to replay the invalidations we missed while
        # disconnected; returns None if it can't (e.g. too far behind).
        pair = server.getInvalidations(last_inval_tid)
        if pair is not None:
            log2("Recovering %d invalidations" % len(pair[1]))
            self.finish_verification(pair)
            return "quick verification"

    log2("Verifying cache")
    # TODO: should batch these operations for efficiency; would need
    # to acquire lock ...
    for oid, tid, version in self._cache.contents():
        server.verify(oid, version, tid)
    # Full verification is asynchronous: the server's endZeoVerify()
    # triggers the endVerify() callback, which finishes verification.
    server.endZeoVerify()
    return "full verification"
def invalidateVerify(self, args): def invalidateVerify(self, args):
"""Server callback to invalidate an (oid, version) pair. """Server callback to invalidate an (oid, version) pair.
...@@ -1245,67 +1247,92 @@ class ClientStorage(object): ...@@ -1245,67 +1247,92 @@ class ClientStorage(object):
# This should never happen. TODO: assert it doesn't, or log # This should never happen. TODO: assert it doesn't, or log
# if it does. # if it does.
return return
self._pickler.dump(args) oid, version = args
self._pickler.dump((oid, version, None))
def _process_invalidations(self, invs): def endVerify(self):
# Invalidations are sent by the ZEO server as a sequence of """Server callback to signal end of cache validation."""
# oid, version pairs. The DB's invalidate() method expects a
# dictionary of oids. log2("endVerify finishing")
self.finish_verification()
log2("endVerify finished")
def finish_verification(self, catch_up=None):
self._lock.acquire() self._lock.acquire()
try: try:
# versions maps version names to dictionary of invalidations if catch_up:
versions = {} # process catch-up invalidations
for oid, version, tid in invs: tid, invalidations = catch_up
if oid == self._load_oid: self._process_invalidations(
self._load_status = 0 (oid, version, tid)
self._cache.invalidate(oid, version, tid) for oid, version in invalidations
oids = versions.get((version, tid)) )
if not oids:
versions[(version, tid)] = [oid]
else:
oids.append(oid)
if self._db is not None:
for (version, tid), d in versions.items():
self._db.invalidate(tid, d, version=version)
finally:
self._lock.release()
def endVerify(self):
"""Server callback to signal end of cache validation."""
if self._pickler is None: if self._pickler is None:
return return
# write end-of-data marker # write end-of-data marker
self._pickler.dump((None, None)) self._pickler.dump((None, None, None))
self._pickler = None self._pickler = None
self._tfile.seek(0) self._tfile.seek(0)
f = self._tfile unpickler = cPickle.Unpickler(self._tfile)
min_tid = self._cache.getLastTid()
def InvalidationLogIterator():
while 1:
oid, version, tid = unpickler.load()
if oid is None:
break
if ((tid is None)
or (min_tid is None)
or (tid > min_tid)
):
yield oid, version, tid
self._process_invalidations(InvalidationLogIterator())
self._tfile.close()
self._tfile = None self._tfile = None
self._process_invalidations(InvalidationLogIterator(f)) finally:
f.close() self._lock.release()
log2("endVerify finishing")
self._server = self._pending_server self._server = self._pending_server
self._ready.set() self._ready.set()
self._pending_conn = None self._pending_server = None
log2("endVerify finished")
def invalidateTransaction(self, tid, args): def invalidateTransaction(self, tid, args):
"""Invalidate objects modified by tid.""" """Server callback: Invalidate objects modified by tid."""
self._lock.acquire() self._lock.acquire()
try: try:
self._cache.setLastTid(tid)
finally:
self._lock.release()
if self._pickler is not None: if self._pickler is not None:
log2("Transactional invalidation during cache verification", log2("Transactional invalidation during cache verification",
level=BLATHER) level=BLATHER)
for t in args: for oid, version in args:
self._pickler.dump(t) self._pickler.dump((oid, version, tid))
return return
self._process_invalidations([(oid, version, tid) self._process_invalidations([(oid, version, tid)
for oid, version in args]) for oid, version in args])
finally:
self._lock.release()
def _process_invalidations(self, invs):
    """Apply a batch of invalidations to the cache and the DB.

    Invalidations are sent by the ZEO server as a sequence of
    (oid, version, tid) triples.  The DB's invalidate() method
    expects a dictionary of oids, so the triples are grouped by
    (version, tid) before being forwarded.
    """
    # Group invalidated oids by (version, tid) so that each group can
    # be handed to the DB's invalidate() in a single call.
    by_version = {}
    for oid, version, tid in invs:
        if oid == self._load_oid:
            # A load of this oid is in flight; flag it as stale.
            self._load_status = 0
        self._cache.invalidate(oid, version, tid)
        by_version.setdefault((version, tid), []).append(oid)
    if self._db is not None:
        for (version, tid), oids in by_version.items():
            self._db.invalidate(tid, oids, version=version)
# The following are for compatibility with protocol version 2.0.0 # The following are for compatibility with protocol version 2.0.0
...@@ -1315,11 +1342,3 @@ class ClientStorage(object): ...@@ -1315,11 +1342,3 @@ class ClientStorage(object):
invalidate = invalidateVerify invalidate = invalidateVerify
end = endVerify end = endVerify
Invalidate = invalidateTrans Invalidate = invalidateTrans
def InvalidationLogIterator(fileobj):
    """Generate (oid, version, None) triples from a pickled log.

    *fileobj* holds a sequence of pickled (oid, version) pairs,
    terminated by a pair whose oid is None.  The terminator itself is
    not yielded.
    """
    load = cPickle.Unpickler(fileobj).load
    while True:
        oid, version = load()
        if oid is None:
            return
        yield oid, version, None
Invalidations while connecting
==============================
As soon as a client registers with a server, it will receive
invalidations from the server.  The client must be careful to queue
these invalidations until it is ready to deal with them.  At the time
of the writing of this test, clients weren't careful enough about
queuing invalidations.  This led to cache corruption in the form of
both low-level file corruption as well as out-of-date records marked
as current.
This tests tries to provoke this bug by:
- starting a server
>>> import ZEO.tests.testZEO, ZEO.tests.forker
>>> addr = 'localhost', ZEO.tests.testZEO.get_port()
>>> zconf = ZEO.tests.forker.ZEOConfig(addr)
>>> sconf = '<filestorage 1>\npath Data.fs\n</filestorage>\n'
>>> _, adminaddr, pid, conf_path = ZEO.tests.forker.start_zeo_server(
... sconf, zconf, addr[1])
- opening a client to the server that writes some objects, filling
it's cache at the same time,
>>> import ZEO.ClientStorage, ZODB.tests.MinPO, transaction
>>> db = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr, client='x'))
>>> conn = db.open()
>>> nobs = 1000
>>> for i in range(nobs):
... conn.root()[i] = ZODB.tests.MinPO.MinPO(0)
>>> transaction.commit()
- disconnecting the first client (closing it with a persistent cache),
>>> db.close()
- starting a second client that writes objects more or less
constantly,
>>> import random, threading
>>> stop = False
>>> db2 = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr))
>>> tm = transaction.TransactionManager()
>>> conn2 = db2.open(transaction_manager=tm)
>>> random = random.Random(0)
>>> lock = threading.Lock()
>>> def run():
... while 1:
... i = random.randint(0, nobs-1)
... if stop:
... return
... lock.acquire()
... try:
... conn2.root()[i].value += 1
... tm.commit()
... finally:
... lock.release()
... time.sleep(0)
>>> thread = threading.Thread(target=run)
>>> thread.start()
- restarting the first client, and
- testing for cache validity.
>>> import zope.testing.loggingsupport, logging
>>> handler = zope.testing.loggingsupport.InstalledHandler(
... 'ZEO', level=logging.ERROR)
>>> import time
>>> for c in range(10):
... time.sleep(.1)
... db = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr, client='x'))
... _ = lock.acquire()
... try:
... time.sleep(.1)
... assert (db._storage.lastTransaction()
... == db._storage._server.lastTransaction()), (
... db._storage.lastTransaction(),
... db._storage._server.lastTransaction())
... conn = db.open()
... for i in range(1000):
... if conn.root()[i].value != conn2.root()[i].value:
... print 'bad', c, i, conn.root()[i].value,
... print conn2.root()[i].value
... finally:
... _ = lock.release()
... db.close()
>>> stop = True
>>> thread.join(10)
>>> thread.isAlive()
False
>>> for record in handler.records:
... print record.name, record.levelname
... print handler.format(record)
>>> handler.uninstall()
>>> db.close()
>>> db2.close()
...@@ -21,7 +21,7 @@ platform-dependent scaffolding. ...@@ -21,7 +21,7 @@ platform-dependent scaffolding.
import unittest import unittest
# Import the actual test class # Import the actual test class
from ZEO.tests import ConnectionTests, InvalidationTests from ZEO.tests import ConnectionTests, InvalidationTests
from zope.testing import doctest, setupstack
class FileStorageConfig: class FileStorageConfig:
def getConfig(self, path, create, read_only): def getConfig(self, path, create, read_only):
...@@ -135,6 +135,10 @@ def test_suite(): ...@@ -135,6 +135,10 @@ def test_suite():
for klass in test_classes: for klass in test_classes:
sub = unittest.makeSuite(klass, 'check') sub = unittest.makeSuite(klass, 'check')
suite.addTest(sub) suite.addTest(sub)
suite.addTest(doctest.DocFileSuite(
'invalidations_while_connecting.test',
setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown,
))
return suite return suite
......
...@@ -447,8 +447,7 @@ class ConnectWrapper: ...@@ -447,8 +447,7 @@ class ConnectWrapper:
Call the client's testConnection(), giving the client a chance Call the client's testConnection(), giving the client a chance
to do app-level check of the connection. to do app-level check of the connection.
""" """
self.conn = ManagedClientConnection(self.sock, self.addr, self.conn = ManagedClientConnection(self.sock, self.addr, self.mgr)
self.client, self.mgr)
self.sock = None # The socket is now owned by the connection self.sock = None # The socket is now owned by the connection
try: try:
self.preferred = self.client.testConnection(self.conn) self.preferred = self.client.testConnection(self.conn)
......
...@@ -555,14 +555,23 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -555,14 +555,23 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.replies_cond.release() self.replies_cond.release()
def handle_request(self, msgid, flags, name, args): def handle_request(self, msgid, flags, name, args):
if not self.check_method(name): obj = self.obj
msg = "Invalid method name: %s on %s" % (name, repr(self.obj))
if name.startswith('_') or not hasattr(obj, name):
if obj is None:
if __debug__:
self.log("no object calling %s%s"
% (name, short_repr(args)),
level=logging.DEBUG)
return
msg = "Invalid method name: %s on %s" % (name, repr(obj))
raise ZRPCError(msg) raise ZRPCError(msg)
if __debug__: if __debug__:
self.log("calling %s%s" % (name, short_repr(args)), self.log("calling %s%s" % (name, short_repr(args)),
level=logging.DEBUG) level=logging.DEBUG)
meth = getattr(self.obj, name) meth = getattr(obj, name)
try: try:
self.waiting_for_reply = True self.waiting_for_reply = True
try: try:
...@@ -601,12 +610,6 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -601,12 +610,6 @@ class Connection(smac.SizedMessageAsyncConnection, object):
level=logging.ERROR, exc_info=True) level=logging.ERROR, exc_info=True)
self.close() self.close()
def check_method(self, name):
# TODO: This is hardly "secure".
if name.startswith('_'):
return None
return hasattr(self.obj, name)
def send_reply(self, msgid, ret): def send_reply(self, msgid, ret):
# encode() can pass on a wide variety of exceptions from cPickle. # encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case # While a bare `except` is generally poor practice, in this case
...@@ -897,7 +900,7 @@ class ManagedClientConnection(Connection): ...@@ -897,7 +900,7 @@ class ManagedClientConnection(Connection):
__super_close = Connection.close __super_close = Connection.close
base_message_output = Connection.message_output base_message_output = Connection.message_output
def __init__(self, sock, addr, obj, mgr): def __init__(self, sock, addr, mgr):
self.mgr = mgr self.mgr = mgr
# We can't use the base smac's message_output directly because the # We can't use the base smac's message_output directly because the
...@@ -914,7 +917,7 @@ class ManagedClientConnection(Connection): ...@@ -914,7 +917,7 @@ class ManagedClientConnection(Connection):
self.queue_output = True self.queue_output = True
self.queued_messages = [] self.queued_messages = []
self.__super_init(sock, addr, obj, tag='C', map=client_map) self.__super_init(sock, addr, None, tag='C', map=client_map)
self.thr_async = True self.thr_async = True
self.trigger = client_trigger self.trigger = client_trigger
client_trigger.pull_trigger() client_trigger.pull_trigger()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment