Commit 3d02129b authored by Jim Fulton's avatar Jim Fulton

Merge remote-tracking branch 'origin/asyncio' into simplify-server-commit-lock-management

Conflicts:
	src/ZEO/StorageServer.py
parents f4bcea77 815f39d1
Changelog Changelog
========= =========
- Fixed: SSL clients of servers with signed certs didn't load default
certs and were unable to connect.
5.0.0a0 (2016-07-08) 5.0.0a0 (2016-07-08)
-------------------- --------------------
......
...@@ -490,7 +490,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -490,7 +490,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
return self._call('record_iternext', next) return self._call('record_iternext', next)
def getTid(self, oid): def getTid(self, oid):
# XXX deprecated: used by storage server for full cache verification. # XXX deprecated: but ZODB tests use this. They shouldn't
return self._call('getTid', oid) return self._call('getTid', oid)
def loadSerial(self, oid, serial): def loadSerial(self, oid, serial):
...@@ -504,8 +504,15 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -504,8 +504,15 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
return result[:2] return result[:2]
def loadBefore(self, oid, tid): def loadBefore(self, oid, tid):
result = self._cache.loadBefore(oid, tid)
if result:
return result
return self._server.load_before(oid, tid) return self._server.load_before(oid, tid)
def prefetch(self, oids, tid):
self._server.prefetch(oids, tid)
def new_oid(self): def new_oid(self):
"""Storage API: return a new object identifier. """Storage API: return a new object identifier.
""" """
......
...@@ -76,10 +76,6 @@ registered_methods = set(( 'get_info', 'lastTransaction', ...@@ -76,10 +76,6 @@ registered_methods = set(( 'get_info', 'lastTransaction',
class ZEOStorage: class ZEOStorage:
"""Proxy to underlying storage for a single remote client.""" """Proxy to underlying storage for a single remote client."""
# A list of extension methods. A subclass with extra methods
# should override.
extensions = []
connected = connection = stats = storage = storage_id = transaction = None connected = connection = stats = storage = storage_id = transaction = None
blob_tempfile = None blob_tempfile = None
log_label = 'unconnected' log_label = 'unconnected'
...@@ -91,10 +87,6 @@ class ZEOStorage: ...@@ -91,10 +87,6 @@ class ZEOStorage:
self.client_conflict_resolution = server.client_conflict_resolution self.client_conflict_resolution = server.client_conflict_resolution
# timeout and stats will be initialized in register() # timeout and stats will be initialized in register()
self.read_only = read_only self.read_only = read_only
# The authentication protocol may define extra methods.
self._extensions = {}
for func in self.extensions:
self._extensions[func.__name__] = None
self._iterators = {} self._iterators = {}
self._iterator_ids = itertools.count() self._iterator_ids = itertools.count()
# Stores the last item that was handed out for a # Stores the last item that was handed out for a
...@@ -149,23 +141,13 @@ class ZEOStorage: ...@@ -149,23 +141,13 @@ class ZEOStorage:
if not info['supportsUndo']: if not info['supportsUndo']:
self.undoLog = self.undoInfo = lambda *a,**k: () self.undoLog = self.undoInfo = lambda *a,**k: ()
# XXX deprecated: but ZODB tests use getTid. They shouldn't
self.getTid = storage.getTid self.getTid = storage.getTid
self.load = storage.load
self.loadSerial = storage.loadSerial self.loadSerial = storage.loadSerial
record_iternext = getattr(storage, 'record_iternext', None) record_iternext = getattr(storage, 'record_iternext', None)
if record_iternext is not None: if record_iternext is not None:
self.record_iternext = record_iternext self.record_iternext = record_iternext
try:
fn = storage.getExtensionMethods
except AttributeError:
pass # no extension methods
else:
d = fn()
self._extensions.update(d)
for name in d:
assert not hasattr(self, name)
setattr(self, name, getattr(storage, name))
self.lastTransaction = storage.lastTransaction self.lastTransaction = storage.lastTransaction
try: try:
...@@ -252,7 +234,6 @@ class ZEOStorage: ...@@ -252,7 +234,6 @@ class ZEOStorage:
'size': storage.getSize(), 'size': storage.getSize(),
'name': storage.getName(), 'name': storage.getName(),
'supportsUndo': supportsUndo, 'supportsUndo': supportsUndo,
'extensionMethods': self.getExtensionMethods(),
'supports_record_iternext': hasattr(self, 'record_iternext'), 'supports_record_iternext': hasattr(self, 'record_iternext'),
'interfaces': tuple(interfaces), 'interfaces': tuple(interfaces),
} }
...@@ -262,13 +243,6 @@ class ZEOStorage: ...@@ -262,13 +243,6 @@ class ZEOStorage:
'size': self.storage.getSize(), 'size': self.storage.getSize(),
} }
def getExtensionMethods(self):
return self._extensions
def loadEx(self, oid):
self.stats.loads += 1
return self.storage.load(oid, '')
def loadBefore(self, oid, tid): def loadBefore(self, oid, tid):
self.stats.loads += 1 self.stats.loads += 1
return self.storage.loadBefore(oid, tid) return self.storage.loadBefore(oid, tid)
...@@ -737,6 +711,7 @@ class StorageServer: ...@@ -737,6 +711,7 @@ class StorageServer:
self._lock = Lock() self._lock = Lock()
self.ssl = ssl # For dev convenience
self.read_only = read_only self.read_only = read_only
self.database = None self.database = None
......
...@@ -8,6 +8,7 @@ else: ...@@ -8,6 +8,7 @@ else:
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZODB.ConflictResolution import ResolvedSerial from ZODB.ConflictResolution import ResolvedSerial
import concurrent.futures import concurrent.futures
import functools
import logging import logging
import random import random
import threading import threading
...@@ -27,6 +28,37 @@ Fallback = object() ...@@ -27,6 +28,37 @@ Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests local_random = random.Random() # use separate generator to facilitate tests
def future_generator(func):
"""Decorates a generator that generates futures
"""
@functools.wraps(func)
def call_generator(*args, **kw):
gen = func(*args, **kw)
try:
f = next(gen)
except StopIteration:
gen.close()
else:
def store(gen, future):
@future.add_done_callback
def _(future):
try:
try:
result = future.result()
except Exception as exc:
f = gen.throw(exc)
else:
f = gen.send(result)
except StopIteration:
gen.close()
else:
store(gen, f)
store(gen, f)
return call_generator
class Protocol(base.Protocol): class Protocol(base.Protocol):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -127,15 +159,9 @@ class Protocol(base.Protocol): ...@@ -127,15 +159,9 @@ class Protocol(base.Protocol):
self.closed = True self.closed = True
self.client.disconnected(self) self.client.disconnected(self)
@future_generator
def finish_connect(self, protocol_version): def finish_connect(self, protocol_version):
# The future implementation we use differs from
# We use a promise model rather than coroutines here because
# for the most part, this class is reactive and coroutines
# aren't a good model of it's activities. During
# initialization, however, we use promises to provide an
# imperative flow.
# The promise(/future) implementation we use differs from
# asyncio.Future in that callbacks are called immediately, # asyncio.Future in that callbacks are called immediately,
# rather than using the loops call_soon. We want to avoid a # rather than using the loops call_soon. We want to avoid a
# race between invalidations and cache initialization. In # race between invalidations and cache initialization. In
...@@ -150,56 +176,29 @@ class Protocol(base.Protocol): ...@@ -150,56 +176,29 @@ class Protocol(base.Protocol):
self.client.register_failed( self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version)) self, ZEO.Exceptions.ProtocolError(protocol_version))
return return
self._write(self.protocol_version) self._write(self.protocol_version)
register = self.promise( try:
try:
server_tid = yield self.fut(
'register', self.storage_key, 'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False, self.read_only if self.read_only is not Fallback else False,
) )
if self.read_only is not Fallback: except ZODB.POSException.ReadOnlyError:
# Get lastTransaction in flight right away to make
# successful connection quicker, but only if we're not
# doing read-only fallback. If we might need to retry, we
# can't send lastTransaction because if the registration
# fails, it will be seen as an invalid message and the
# connection will close. :( It would be a lot better of
# registere returned the last transaction (and info while
# it's at it).
lastTransaction = self.promise('lastTransaction')
else:
lastTransaction = None # to make python happy
@register
def registered(_):
if self.read_only is Fallback: if self.read_only is Fallback:
self.read_only = False
r_lastTransaction = self.promise('lastTransaction')
else:
r_lastTransaction = lastTransaction
self.client.registered(self, r_lastTransaction)
@register.catch
def register_failed(exc):
if (isinstance(exc, ZODB.POSException.ReadOnlyError) and
self.read_only is Fallback):
# We tried a write connection, degrade to a read-only one
self.read_only = True self.read_only = True
logger.info("%s write connection failed. Trying read-only", server_tid = yield self.fut(
self) 'register', self.storage_key, True)
register = self.promise('register', self.storage_key, True)
# get lastTransaction in flight.
lastTransaction = self.promise('lastTransaction')
@register
def registered(_):
self.client.registered(self, lastTransaction)
@register.catch
def register_failed(exc):
self.client.register_failed(self, exc)
else: else:
raise
else:
if self.read_only is Fallback:
self.read_only = False
except Exception as exc:
self.client.register_failed(self, exc) self.client.register_failed(self, exc)
else:
self.client.registered(self, server_tid)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
...@@ -237,8 +236,19 @@ class Protocol(base.Protocol): ...@@ -237,8 +236,19 @@ class Protocol(base.Protocol):
self._write(self.encode(self.message_id, False, method, args)) self._write(self.encode(self.message_id, False, method, args))
return future return future
def promise(self, method, *args): def fut(self, method, *args):
return self.call(Promise(), method, args) return self.call(Fut(), method, args)
def load_before(self, oid, tid):
# Special-case loadBefore, so we collapse outstanding requests
message_id = (oid, tid)
future = self.futures.get(message_id)
if future is None:
future = asyncio.Future(loop=self.loop)
self.futures[message_id] = future
self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid)))
return future
# Methods called by the server. # Methods called by the server.
# WARNING WARNING we can't call methods that call back to us # WARNING WARNING we can't call methods that call back to us
...@@ -362,18 +372,18 @@ class Client(object): ...@@ -362,18 +372,18 @@ class Client(object):
for addr in self.addrs for addr in self.addrs
] ]
def registered(self, protocol, last_transaction_promise): def registered(self, protocol, server_tid):
if self.protocol is None: if self.protocol is None:
self.protocol = protocol self.protocol = protocol
if not (self.read_only is Fallback and protocol.read_only): if not (self.read_only is Fallback and protocol.read_only):
# We're happy with this protocol. Tell the others to # We're happy with this protocol. Tell the others to
# stop trying. # stop trying.
self._clear_protocols(protocol) self._clear_protocols(protocol)
self.verify(last_transaction_promise) self.verify(server_tid)
elif (self.read_only is Fallback and not protocol.read_only and elif (self.read_only is Fallback and not protocol.read_only and
self.protocol.read_only): self.protocol.read_only):
self.upgrade(protocol) self.upgrade(protocol)
self.verify(last_transaction_promise) self.verify(server_tid)
else: else:
protocol.close() # too late, we went home with another protocol.close() # too late, we went home with another
...@@ -391,11 +401,14 @@ class Client(object): ...@@ -391,11 +401,14 @@ class Client(object):
self.try_connecting) self.try_connecting)
verify_result = None # for tests verify_result = None # for tests
def verify(self, last_transaction_promise):
@future_generator
def verify(self, server_tid):
protocol = self.protocol protocol = self.protocol
if server_tid is None:
server_tid = yield protocol.fut('lastTransaction')
@last_transaction_promise try:
def finish_verify(server_tid):
cache = self.cache cache = self.cache
if cache: if cache:
cache_tid = cache.getLastTid() cache_tid = cache.getLastTid()
...@@ -404,7 +417,6 @@ class Client(object): ...@@ -404,7 +417,6 @@ class Client(object):
logger.error("Non-empty cache w/o tid -- clearing") logger.error("Non-empty cache w/o tid -- clearing")
cache.clear() cache.clear()
self.client.invalidateCache() self.client.invalidateCache()
self.finished_verify(server_tid)
elif cache_tid > server_tid: elif cache_tid > server_tid:
self.verify_result = "Cache newer than server" self.verify_result = "Cache newer than server"
logger.critical( logger.critical(
...@@ -413,17 +425,14 @@ class Client(object): ...@@ -413,17 +425,14 @@ class Client(object):
server_tid, cache_tid, protocol) server_tid, cache_tid, protocol)
elif cache_tid == server_tid: elif cache_tid == server_tid:
self.verify_result = "Cache up to date" self.verify_result = "Cache up to date"
self.finished_verify(server_tid)
else: else:
@protocol.promise('getInvalidations', cache_tid) vdata = yield protocol.fut('getInvalidations', cache_tid)
def verify_invalidations(vdata):
if vdata: if vdata:
self.verify_result = "quick verification" self.verify_result = "quick verification"
tid, oids = vdata server_tid, oids = vdata
for oid in oids: for oid in oids:
cache.invalidate(oid, None) cache.invalidate(oid, None)
self.client.invalidateTransaction(tid, oids) self.client.invalidateTransaction(server_tid, oids)
return tid
else: else:
# cache is too old # cache is too old
self.verify_result = "cache too old, clearing" self.verify_result = "cache too old, clearing"
...@@ -438,37 +447,33 @@ class Client(object): ...@@ -438,37 +447,33 @@ class Client(object):
) )
self.cache.clear() self.cache.clear()
self.client.invalidateCache() self.client.invalidateCache()
return server_tid
verify_invalidations(
self.finished_verify,
self.connected.set_exception,
)
else: else:
self.verify_result = "empty cache" self.verify_result = "empty cache"
self.finished_verify(server_tid)
@finish_verify.catch except Exception as exc:
def verify_failed(exc):
del self.protocol del self.protocol
self.register_failed(protocol, exc) self.register_failed(protocol, exc)
else:
def finished_verify(self, server_tid):
# The cache is validated and the last tid we got from the server. # The cache is validated and the last tid we got from the server.
# Set ready so we apply any invalidations that follow. # Set ready so we apply any invalidations that follow.
# We've been ignoring them up to this point. # We've been ignoring them up to this point.
self.cache.setLastTid(server_tid) self.cache.setLastTid(server_tid)
self.ready = True self.ready = True
@self.protocol.promise('get_info') try:
def got_info(info): info = yield protocol.fut('get_info')
self.client.notify_connected(self, info) except Exception as exc:
self.connected.set_result(None) # This is weird. We were connected and verified our cache, but
# Now we errored getting info.
@got_info.catch # XXX Need a test fpr this. The lone before is what we
def failed_info(exc): # had, but it's wrong.
self.register_failed(self, exc) self.register_failed(self, exc)
else:
self.client.notify_connected(self, info)
self.connected.set_result(None)
def get_peername(self): def get_peername(self):
return self.protocol.get_peername() return self.protocol.get_peername()
...@@ -514,27 +519,49 @@ class Client(object): ...@@ -514,27 +519,49 @@ class Client(object):
# Special methods because they update the cache. # Special methods because they update the cache.
@future_generator
def load_before_threadsafe(self, future, oid, tid): def load_before_threadsafe(self, future, oid, tid):
data = self.cache.loadBefore(oid, tid) data = self.cache.loadBefore(oid, tid)
if data is not None: if data is not None:
future.set_result(data) future.set_result(data)
elif self.ready: elif self.ready:
@self.protocol.promise('loadBefore', oid, tid) try:
def load_before(data): data = yield self.protocol.load_before(oid, tid)
except Exception as exc:
future.set_exception(exc)
else:
future.set_result(data) future.set_result(data)
if data: if data:
data, start, end = data data, start, end = data
self.cache.store(oid, start, end, data) self.cache.store(oid, start, end, data)
load_before.catch(future.set_exception)
else: else:
self._when_ready(self.load_before_threadsafe, future, oid, tid) self._when_ready(self.load_before_threadsafe, future, oid, tid)
@future_generator
def _prefetch(self, oid, tid):
try:
data = yield self.protocol.load_before(oid, tid)
if data:
data, start, end = data
self.cache.store(oid, start, end, data)
except Exception:
logger.exception("prefetch %r %r" % (oid, tid))
def prefetch(self, future, oids, tid):
if self.ready:
for oid in oids:
if self.cache.loadBefore(oid, tid) is None:
self._prefetch(oid, tid)
future.set_result(None)
else:
future.set_exception(ClientDisconnected())
@future_generator
def tpc_finish_threadsafe(self, future, tid, updates, f): def tpc_finish_threadsafe(self, future, tid, updates, f):
if self.ready: if self.ready:
@self.protocol.promise('tpc_finish', tid)
def committed(tid):
try: try:
tid = yield self.protocol.fut('tpc_finish', tid)
cache = self.cache cache = self.cache
for oid, data, resolved in updates: for oid, data, resolved in updates:
cache.invalidate(oid, tid) cache.invalidate(oid, tid)
...@@ -552,8 +579,6 @@ class Client(object): ...@@ -552,8 +579,6 @@ class Client(object):
else: else:
f(tid) f(tid)
future.set_result(tid) future.set_result(tid)
committed.catch(future.set_exception)
else: else:
future.set_exception(ClientDisconnected()) future.set_exception(ClientDisconnected())
...@@ -648,6 +673,9 @@ class ClientRunner(object): ...@@ -648,6 +673,9 @@ class ClientRunner(object):
def async_iter(self, it): def async_iter(self, it):
return self.__call(self.client.call_async_iter_threadsafe, it) return self.__call(self.client.call_async_iter_threadsafe, it)
def prefetch(self, oids, tid):
return self.__call(self.client.prefetch, oids, tid)
def load_before(self, oid, tid): def load_before(self, oid, tid):
return self.__call(self.client.load_before_threadsafe, oid, tid) return self.__call(self.client.load_before_threadsafe, oid, tid)
...@@ -754,95 +782,24 @@ class ClientThread(ClientRunner): ...@@ -754,95 +782,24 @@ class ClientThread(ClientRunner):
if self.exception: if self.exception:
raise self.exception raise self.exception
class Promise(object): class Fut(object):
"""Lightweight future with a partial promise API. """Lightweight future that calls it's callback immediately rather than soon
These are lighweight because they call callbacks synchronously
rather than through an event loop, and because they ony support
single callbacks.
""" """
# Note that we can know that they are completed after callbacks def add_done_callback(self, cb):
# are set up because they're used to make network requests. self.cb = cb
# Requests are made by writing to a transport. Because we're used
# in a single-threaded protocol, we can't get a response and be
# completed if the callbacks are set in the same code that
# created the promise, which they are.
next = success_callback = error_callback = cancelled = None
def __call__(self, success_callback = None, error_callback = None):
"""Set the promises success and error handlers and beget a new promise
The promise returned provides for promise chaining, providing
a sane imperative flow. Let's call this the "next" promise.
Any results or exceptions generated by the promise or it's
callbacks are passed on to the next promise.
When the promise completes successfully, if a success callback
isn't set, then the next promise is completed with the
successfull result. If a success callback is provided, it's
called. If the call succeeds, and the result is a promise,
them the result is called with the next promise's set_result
and set_exception methods, chaining the result and next
promise. If the result isn't a promise, then the next promise
is completed with it by calling set_result. If the success
callback fails, then it's exception is passed to
next.set_exception.
If the promise completes with an error and the error callback
isn't set, then the exception is passed to the next promises
set_exception. If an error handler is provided, it's called
and if it doesn't error, then the original exception is passed
to the next promise's set_exception. If there error handler
errors, then that exception is passed to the next promise's
set_exception.
"""
self.next = self.__class__()
self.success_callback = success_callback
self.error_callback = error_callback
return self.next
def cancel(self):
self.set_exception(concurrent.futures.CancelledError)
def catch(self, error_callback):
self.error_callback = error_callback
exc = None
def set_exception(self, exc): def set_exception(self, exc):
self._notify(None, exc) self.exc = exc
self.cb(self)
def set_result(self, result): def set_result(self, result):
self._notify(result, None) self._result = result
self.cb(self)
def _notify(self, result, exc): def result(self):
next = self.next if self.exc:
if exc is not None: raise self.exc
if self.error_callback is not None:
try:
result = self.error_callback(exc)
except Exception:
logger.exception("Exception handling error %s", exc)
if next is not None:
next.set_exception(exc)
else:
if next is not None:
next.set_result(result)
elif next is not None:
next.set_exception(exc)
else:
if self.success_callback is not None:
try:
result = self.success_callback(result)
except Exception as exc:
logger.exception("Exception in success callback")
if next is not None:
next.set_exception(exc)
else:
if next is not None:
if isinstance(result, Promise):
result(next.set_result, next.set_exception)
else: else:
next.set_result(result) return self._result
elif next is not None:
next.set_result(result)
...@@ -12,7 +12,6 @@ from ZODB.POSException import ReadOnlyError ...@@ -12,7 +12,6 @@ from ZODB.POSException import ReadOnlyError
import collections import collections
import logging import logging
import pdb
import struct import struct
import unittest import unittest
...@@ -71,6 +70,12 @@ class Base(object): ...@@ -71,6 +70,12 @@ class Base(object):
class ClientTests(Base, setupstack.TestCase, ClientRunner): class ClientTests(Base, setupstack.TestCase, ClientRunner):
maxDiff = None
def tearDown(self):
self.client.close()
super(ClientTests, self)
def start(self, def start(self,
addrs=(('127.0.0.1', 8200), ), loop_addrs=None, addrs=(('127.0.0.1', 8200), ), loop_addrs=None,
read_only=False, read_only=False,
...@@ -95,12 +100,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -95,12 +100,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
if finish_start: if finish_start:
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.pop(2, False), b'Z3101') self.assertEqual(self.pop(2, False), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
self.pop(4)
self.assertEqual(self.pop(), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
self.respond(3, dict(length=42)) self.respond(3, dict(length=42))
...@@ -134,12 +136,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -134,12 +136,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# The client sends back a handshake, and registers the # The client sends back a handshake, and registers the
# storage, and requests the last transaction. # storage, and requests the last transaction.
self.assertEqual(self.pop(2, False), b'Z5') self.assertEqual(self.pop(2, False), b'Z5')
self.assertEqual(self.pop(), self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
# Actually, the client isn't connected until it initializes it's cache: # The client isn't connected until it initializes it's cache:
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
# If we try to make calls while the client is *initially* # If we try to make calls while the client is *initially*
...@@ -162,9 +161,13 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -162,9 +161,13 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# The wrapper object (ClientStorage) hasn't been notified: # The wrapper object (ClientStorage) hasn't been notified:
self.assertFalse(wrapper.notify_connected.called) self.assertFalse(wrapper.notify_connected.called)
# Let's respond to those first 2 calls: # Let's respond to the register call:
self.respond(1, None) self.respond(1, None)
# The client requests the last transaction:
self.assertEqual(self.pop(), (2, False, 'lastTransaction', ()))
# We respond
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
# After verification, the client requests info: # After verification, the client requests info:
...@@ -191,9 +194,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -191,9 +194,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# Loading objects gets special handling to leverage the cache. # Loading objects gets special handling to leverage the cache.
loaded = self.load_before(b'1'*8, m64) loaded = self.load_before(b'1'*8, m64)
# The data wasn't in the cache, so we make a server call: # The data wasn't in the cache, so we made a server call:
self.assertEqual(self.pop(), (5, False, 'loadBefore', (b'1'*8, m64))) self.assertEqual(self.pop(),
self.respond(5, (b'data', b'a'*8, None)) ((b'1'*8, m64), False, 'loadBefore', (b'1'*8, m64)))
# Note load_before uses the oid as the message id.
self.respond((b'1'*8, m64), (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None)) self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
# If we make another request, it will be satisfied from the cache: # If we make another request, it will be satisfied from the cache:
...@@ -206,9 +211,16 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -206,9 +211,16 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# Now, if we try to load current again, we'll make a server request. # Now, if we try to load current again, we'll make a server request.
loaded = self.load_before(b'1'*8, m64) loaded = self.load_before(b'1'*8, m64)
self.assertEqual(self.pop(), (6, False, 'loadBefore', (b'1'*8, m64)))
self.respond(6, (b'data2', b'b'*8, None)) # Note that if we make another request for the same object,
# the requests will be collapsed:
loaded2 = self.load_before(b'1'*8, m64)
self.assertEqual(self.pop(),
((b'1'*8, m64), False, 'loadBefore', (b'1'*8, m64)))
self.respond((b'1'*8, m64), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None)) self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
# Loading non-current data may also be satisfied from cache # Loading non-current data may also be satisfied from cache
loaded = self.load_before(b'1'*8, b'b'*8) loaded = self.load_before(b'1'*8, b'b'*8)
...@@ -219,9 +231,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -219,9 +231,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8) loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(self.pop(), self.assertEqual(
(7, False, 'loadBefore', (b'1'*8, b'_'*8))) self.pop(),
self.respond(7, (b'data0', b'^'*8, b'_'*8)) ((b'1'*8, b'_'*8), False, 'loadBefore', (b'1'*8, b'_'*8)))
self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8)) self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache # When committing transactions, we need to update the cache
...@@ -244,8 +257,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -244,8 +257,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
cache.load(b'4'*8)) cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8)) self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(self.pop(), self.assertEqual(self.pop(),
(8, False, 'tpc_finish', (b'd'*8,))) (5, False, 'tpc_finish', (b'd'*8,)))
self.respond(8, b'e'*8) self.respond(5, b'e'*8)
self.assertEqual(committed.result(), b'e'*8) self.assertEqual(committed.result(), b'e'*8)
self.assertEqual(cache.load(b'1'*8), None) self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8)) self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8))
...@@ -257,8 +270,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -257,8 +270,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, m64) loaded = self.load_before(b'1'*8, m64)
f1 = self.call('foo', 1, 2) f1 = self.call('foo', 1, 2)
self.assertFalse(loaded.done() or f1.done()) self.assertFalse(loaded.done() or f1.done())
self.assertEqual(self.pop(), [(9, False, 'loadBefore', (b'1'*8, m64)), self.assertEqual(self.pop(),
(10, False, 'foo', (1, 2))], [((b'1'*8, m64), False, 'loadBefore', (b'1'*8, m64)),
(6, False, 'foo', (1, 2))],
) )
exc = TypeError(43) exc = TypeError(43)
...@@ -286,15 +300,14 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -286,15 +300,14 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# protocol: # protocol:
protocol.data_received(sized(b'Z310')) protocol.data_received(sized(b'Z310'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z310') self.assertEqual(self.unsized(transport.pop(2)), b'Z310')
self.assertEqual(self.pop(), self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.assertFalse(wrapper.notify_connected.called) self.assertFalse(wrapper.notify_connected.called)
self.respond(1, None)
self.respond(2, b'e'*8) # If the register response is a tid, then the client won't
self.assertEqual(self.pop(), (3, False, 'get_info', ())) # request lastTransaction
self.respond(3, dict(length=42)) self.respond(1, b'e'*8)
self.assertEqual(self.pop(), (2, False, 'get_info', ()))
self.respond(2, dict(length=42))
# Because the server tid matches the cache tid, we're done connecting # Because the server tid matches the cache tid, we're done connecting
wrapper.notify_connected.assert_called_with(client, {'length': 42}) wrapper.notify_connected.assert_called_with(client, {'length': 42})
...@@ -323,12 +336,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -323,12 +336,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, b'e'*8) self.respond(2, b'e'*8)
self.pop(4)
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
...@@ -361,12 +371,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -361,12 +371,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, b'e'*8) self.respond(2, b'e'*8)
self.pop(4)
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
...@@ -433,12 +440,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -433,12 +440,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
cache.setLastTid('b'*8) cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
self.pop()
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
self.assert_(8 < delay < 10) self.assert_(8 < delay < 10)
...@@ -450,12 +454,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -450,12 +454,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, 'b'*8) self.respond(2, 'b'*8)
self.pop(4)
self.assertEqual(self.pop(), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
self.respond(3, dict(length=42)) self.respond(3, dict(length=42))
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
...@@ -481,12 +482,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -481,12 +482,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
# The client tries for a read-only connection: # The client tries for a read-only connection:
self.assertEqual(self.pop(), self.assertEqual(self.pop(), (2, False, 'register', ('TEST', True)))
[(2, False, 'register', ('TEST', True)),
(3, False, 'lastTransaction', ()),
])
# We respond with successfully: # We respond with successfully:
self.respond(2, None) self.respond(2, None)
self.pop(2)
self.respond(3, 'b'*8) self.respond(3, 'b'*8)
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
...@@ -513,12 +512,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -513,12 +512,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We respond and the writable connection succeeds: # We respond and the writable connection succeeds:
self.respond(1, None) self.respond(1, None)
self.assertFalse(self.is_read_only())
# at this point, a lastTransaction request is emitted: # at this point, a lastTransaction request is emitted:
self.assertEqual(self.parse(loop.transport.pop()), self.assertEqual(self.parse(loop.transport.pop()),
(2, False, 'lastTransaction', ())) (2, False, 'lastTransaction', ()))
self.assertFalse(self.is_read_only())
# Now, the original protocol is closed, and the client is # Now, the original protocol is closed, and the client is
# no-longer ready: # no-longer ready:
...@@ -542,11 +541,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -542,11 +541,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport = self.start() wrapper, cache, loop, client, protocol, transport = self.start()
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.pop(4)
self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False)
self.respond(2, b'a'*8) self.respond(2, b'a'*8)
self.send('invalidateTransaction', b'c'*8, [b'1'*8], no_output=False) self.send('invalidateTransaction', b'c'*8, [b'1'*8], no_output=False)
...@@ -563,11 +559,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -563,11 +559,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.pop(4)
self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False)
self.respond(2, b'c'*8) self.respond(2, b'c'*8)
self.send('invalidateTransaction', b'e'*8, [b'1'*8], no_output=False) self.send('invalidateTransaction', b'e'*8, [b'1'*8], no_output=False)
...@@ -720,7 +713,9 @@ class MemoryCache(object): ...@@ -720,7 +713,9 @@ class MemoryCache(object):
def store(self, oid, start_tid, end_tid, data): def store(self, oid, start_tid, end_tid, data):
assert start_tid is not None assert start_tid is not None
revisions = self.data[oid] revisions = self.data[oid]
revisions.append((start_tid, end_tid, data)) data = (start_tid, end_tid, data)
if not revisions or data != revisions[-1]:
revisions.append(data)
revisions.sort() revisions.sort()
def loadBefore(self, oid, tid): def loadBefore(self, oid, tid):
......
...@@ -90,14 +90,6 @@ class IServeable(zope.interface.Interface): ...@@ -90,14 +90,6 @@ class IServeable(zope.interface.Interface):
"""Interface provided by storages that can be served by ZEO """Interface provided by storages that can be served by ZEO
""" """
def getTid(oid):
"""The last transaction to change an object
Return the transaction id of the last transaction that committed a
change to an object with the given object id.
"""
def tpc_transaction(): def tpc_transaction():
"""The current transaction being committed. """The current transaction being committed.
......
...@@ -110,6 +110,7 @@ def runner(config, qin, qout, timeout=None, ...@@ -110,6 +110,7 @@ def runner(config, qin, qout, timeout=None,
options = ZEO.runzeo.ZEOOptions() options = ZEO.runzeo.ZEOOptions()
options.realize(['-C', config]) options.realize(['-C', config])
server = ZEO.runzeo.ZEOServer(options) server = ZEO.runzeo.ZEOServer(options)
globals()[(name if name else 'last') + '_server'] = server
server.open_storages() server.open_storages()
server.clear_socket() server.clear_socket()
server.create_server() server.create_server()
......
...@@ -492,6 +492,50 @@ ZEOStorage as closed and see if trying to get a lock cleans it up: ...@@ -492,6 +492,50 @@ ZEOStorage as closed and see if trying to get a lock cleans it up:
>>> logging.getLogger('ZEO').removeHandler(handler) >>> logging.getLogger('ZEO').removeHandler(handler)
""" """
def test_prefetch(self):
"""The client storage prefetch method pre-fetches from the server
>>> count = 999
>>> import ZEO
>>> addr, stop = ZEO.server()
>>> conn = ZEO.connection(addr)
>>> root = conn.root()
>>> cls = root.__class__
>>> for i in range(count):
... root[i] = cls()
>>> conn.transaction_manager.commit()
>>> oids = [root[i]._p_oid for i in range(count)]
>>> conn.close()
>>> conn = ZEO.connection(addr)
>>> storage = conn.db().storage
>>> len(storage._cache)
1
>>> storage.prefetch(oids, conn._storage._start)
The prefetch returns before the cache is filled:
>>> len(storage._cache) < count
True
But it is filled eventually:
>>> from zope.testing.wait import wait
>>> wait(lambda : len(storage._cache) > count)
>>> loads = storage.server_status()['loads']
Now if we reload the data, it will be satisfied from the cache:
>>> for oid in oids:
... _ = conn._storage.load(oid)
>>> storage.server_status()['loads'] == loads
True
>>> conn.close()
>>> stop()
"""
def test_suite(): def test_suite():
return unittest.TestSuite(( return unittest.TestSuite((
......
...@@ -195,6 +195,8 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase): ...@@ -195,6 +195,8 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
factory, context, (client_cert, client_key, None), factory, context, (client_cert, client_key, None),
check_hostname=True) check_hostname=True)
context.load_default_certs.assert_called_with()
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_dir( def test_ssl_mockiavellian_client_auth_dir(
...@@ -210,6 +212,7 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase): ...@@ -210,6 +212,7 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
capath=here, capath=here,
check_hostname=True, check_hostname=True,
) )
context.load_default_certs.assert_not_called()
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
...@@ -226,6 +229,7 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase): ...@@ -226,6 +229,7 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
cafile=server_cert, cafile=server_cert,
check_hostname=True, check_hostname=True,
) )
context.load_default_certs.assert_not_called()
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
...@@ -345,7 +349,10 @@ server_config = """ ...@@ -345,7 +349,10 @@ server_config = """
</zeo> </zeo>
""".format(server_cert, server_key, client_cert) """.format(server_cert, server_key, client_cert)
def client_ssl(): def client_ssl(cafile=server_key,
client_cert=client_cert,
client_key=client_key,
):
context = ssl.create_default_context( context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH, cafile=server_cert) ssl.Purpose.CLIENT_AUTH, cafile=server_cert)
...@@ -353,3 +360,7 @@ def client_ssl(): ...@@ -353,3 +360,7 @@ def client_ssl():
context.verify_mode = ssl.CERT_REQUIRED context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = False context.check_hostname = False
return context return context
# Here's a command to create a cert/key pair:
# openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem \
# -days 999999 -nodes -batch
...@@ -11,12 +11,16 @@ def ssl_config(section, server): ...@@ -11,12 +11,16 @@ def ssl_config(section, server):
if auth: if auth:
if os.path.isdir(auth): if os.path.isdir(auth):
capath=auth capath=auth
else: elif auth != 'DYNAMIC':
cafile=auth cafile=auth
context = ssl.create_default_context( context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH, cafile=cafile, capath=capath) ssl.Purpose.CLIENT_AUTH, cafile=cafile, capath=capath)
if not auth:
assert not server
context.load_default_certs()
if section.certificate: if section.certificate:
password = section.password_function password = section.password_function
if password: if password:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment