Commit 7cdcc013 authored by Jim Fulton's avatar Jim Fulton

Updated connect/verify to use generators

parent f96776b3
...@@ -159,6 +159,7 @@ class Protocol(base.Protocol): ...@@ -159,6 +159,7 @@ class Protocol(base.Protocol):
self.closed = True self.closed = True
self.client.disconnected(self) self.client.disconnected(self)
@future_generator
def finish_connect(self, protocol_version): def finish_connect(self, protocol_version):
# We use a promise model rather than coroutines here because # We use a promise model rather than coroutines here because
...@@ -182,56 +183,29 @@ class Protocol(base.Protocol): ...@@ -182,56 +183,29 @@ class Protocol(base.Protocol):
self.client.register_failed( self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version)) self, ZEO.Exceptions.ProtocolError(protocol_version))
return return
self._write(self.protocol_version) self._write(self.protocol_version)
register = self.promise( try:
try:
server_tid = yield self.fut(
'register', self.storage_key, 'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False, self.read_only if self.read_only is not Fallback else False,
) )
if self.read_only is not Fallback: except ZODB.POSException.ReadOnlyError:
# Get lastTransaction in flight right away to make
# successful connection quicker, but only if we're not
# doing read-only fallback. If we might need to retry, we
# can't send lastTransaction because if the registration
# fails, it will be seen as an invalid message and the
# connection will close. :( It would be a lot better of
# registere returned the last transaction (and info while
# it's at it).
lastTransaction = self.promise('lastTransaction')
else:
lastTransaction = None # to make python happy
@register
def registered(_):
if self.read_only is Fallback: if self.read_only is Fallback:
self.read_only = False
r_lastTransaction = self.promise('lastTransaction')
else:
r_lastTransaction = lastTransaction
self.client.registered(self, r_lastTransaction)
@register.catch
def register_failed(exc):
if (isinstance(exc, ZODB.POSException.ReadOnlyError) and
self.read_only is Fallback):
# We tried a write connection, degrade to a read-only one
self.read_only = True self.read_only = True
logger.info("%s write connection failed. Trying read-only", server_tid = yield self.fut(
self) 'register', self.storage_key, True)
register = self.promise('register', self.storage_key, True) else:
# get lastTransaction in flight. raise
lastTransaction = self.promise('lastTransaction')
@register
def registered(_):
self.client.registered(self, lastTransaction)
@register.catch
def register_failed(exc):
self.client.register_failed(self, exc)
else: else:
if self.read_only is Fallback:
self.read_only = False
except Exception as exc:
self.client.register_failed(self, exc) self.client.register_failed(self, exc)
else:
self.client.registered(self, server_tid)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
...@@ -272,6 +246,9 @@ class Protocol(base.Protocol): ...@@ -272,6 +246,9 @@ class Protocol(base.Protocol):
def promise(self, method, *args): def promise(self, method, *args):
return self.call(Promise(), method, args) return self.call(Promise(), method, args)
def fut(self, method, *args):
return self.call(Fut(), method, args)
def load_before(self, oid, tid): def load_before(self, oid, tid):
# Special-case loadBefore, so we collapse outstanding requests # Special-case loadBefore, so we collapse outstanding requests
message_id = (oid, tid) message_id = (oid, tid)
...@@ -405,18 +382,18 @@ class Client(object): ...@@ -405,18 +382,18 @@ class Client(object):
for addr in self.addrs for addr in self.addrs
] ]
def registered(self, protocol, last_transaction_promise): def registered(self, protocol, server_tid):
if self.protocol is None: if self.protocol is None:
self.protocol = protocol self.protocol = protocol
if not (self.read_only is Fallback and protocol.read_only): if not (self.read_only is Fallback and protocol.read_only):
# We're happy with this protocol. Tell the others to # We're happy with this protocol. Tell the others to
# stop trying. # stop trying.
self._clear_protocols(protocol) self._clear_protocols(protocol)
self.verify(last_transaction_promise) self.verify(server_tid)
elif (self.read_only is Fallback and not protocol.read_only and elif (self.read_only is Fallback and not protocol.read_only and
self.protocol.read_only): self.protocol.read_only):
self.upgrade(protocol) self.upgrade(protocol)
self.verify(last_transaction_promise) self.verify(server_tid)
else: else:
protocol.close() # too late, we went home with another protocol.close() # too late, we went home with another
...@@ -434,11 +411,14 @@ class Client(object): ...@@ -434,11 +411,14 @@ class Client(object):
self.try_connecting) self.try_connecting)
verify_result = None # for tests verify_result = None # for tests
def verify(self, last_transaction_promise):
@future_generator
def verify(self, server_tid):
protocol = self.protocol protocol = self.protocol
if server_tid is None:
server_tid = yield protocol.fut('lastTransaction')
@last_transaction_promise try:
def finish_verify(server_tid):
cache = self.cache cache = self.cache
if cache: if cache:
cache_tid = cache.getLastTid() cache_tid = cache.getLastTid()
...@@ -447,7 +427,6 @@ class Client(object): ...@@ -447,7 +427,6 @@ class Client(object):
logger.error("Non-empty cache w/o tid -- clearing") logger.error("Non-empty cache w/o tid -- clearing")
cache.clear() cache.clear()
self.client.invalidateCache() self.client.invalidateCache()
self.finished_verify(server_tid)
elif cache_tid > server_tid: elif cache_tid > server_tid:
self.verify_result = "Cache newer than server" self.verify_result = "Cache newer than server"
logger.critical( logger.critical(
...@@ -456,17 +435,14 @@ class Client(object): ...@@ -456,17 +435,14 @@ class Client(object):
server_tid, cache_tid, protocol) server_tid, cache_tid, protocol)
elif cache_tid == server_tid: elif cache_tid == server_tid:
self.verify_result = "Cache up to date" self.verify_result = "Cache up to date"
self.finished_verify(server_tid)
else: else:
@protocol.promise('getInvalidations', cache_tid) vdata = yield protocol.fut('getInvalidations', cache_tid)
def verify_invalidations(vdata):
if vdata: if vdata:
self.verify_result = "quick verification" self.verify_result = "quick verification"
tid, oids = vdata server_tid, oids = vdata
for oid in oids: for oid in oids:
cache.invalidate(oid, None) cache.invalidate(oid, None)
self.client.invalidateTransaction(tid, oids) self.client.invalidateTransaction(server_tid, oids)
return tid
else: else:
# cache is too old # cache is too old
self.verify_result = "cache too old, clearing" self.verify_result = "cache too old, clearing"
...@@ -481,37 +457,33 @@ class Client(object): ...@@ -481,37 +457,33 @@ class Client(object):
) )
self.cache.clear() self.cache.clear()
self.client.invalidateCache() self.client.invalidateCache()
return server_tid
verify_invalidations(
self.finished_verify,
self.connected.set_exception,
)
else: else:
self.verify_result = "empty cache" self.verify_result = "empty cache"
self.finished_verify(server_tid)
@finish_verify.catch except Exception as exc:
def verify_failed(exc):
del self.protocol del self.protocol
self.register_failed(protocol, exc) self.register_failed(protocol, exc)
else:
def finished_verify(self, server_tid):
# The cache is validated and the last tid we got from the server. # The cache is validated and the last tid we got from the server.
# Set ready so we apply any invalidations that follow. # Set ready so we apply any invalidations that follow.
# We've been ignoring them up to this point. # We've been ignoring them up to this point.
self.cache.setLastTid(server_tid) self.cache.setLastTid(server_tid)
self.ready = True self.ready = True
@self.protocol.promise('get_info') try:
def got_info(info): info = yield protocol.fut('get_info')
self.client.notify_connected(self, info) except Exception as exc:
self.connected.set_result(None) # This is weird. We were connected and verified our cache, but
# Now we errored getting info.
@got_info.catch # XXX Need a test fpr this. The lone before is what we
def failed_info(exc): # had, but it's wrong.
self.register_failed(self, exc) self.register_failed(self, exc)
else:
self.client.notify_connected(self, info)
self.connected.set_result(None)
def get_peername(self): def get_peername(self):
return self.protocol.get_peername() return self.protocol.get_peername()
...@@ -822,6 +794,28 @@ class ClientThread(ClientRunner): ...@@ -822,6 +794,28 @@ class ClientThread(ClientRunner):
if self.exception: if self.exception:
raise self.exception raise self.exception
class Fut(object):
"""Lightweight future that calls it's callback immediately rather than soon
"""
def add_done_callback(self, cb):
self.cb = cb
exc = None
def set_exception(self, exc):
self.exc = exc
self.cb(self)
def set_result(self, result):
self._result = result
self.cb(self)
def result(self):
if self.exc:
raise self.exc
else:
return self._result
class Promise(object): class Promise(object):
"""Lightweight future with a partial promise API. """Lightweight future with a partial promise API.
......
...@@ -72,6 +72,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -72,6 +72,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
maxDiff = None maxDiff = None
def tearDown(self):
self.client.close()
super(ClientTests, self)
def start(self, def start(self,
addrs=(('127.0.0.1', 8200), ), loop_addrs=None, addrs=(('127.0.0.1', 8200), ), loop_addrs=None,
read_only=False, read_only=False,
...@@ -96,12 +100,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -96,12 +100,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
if finish_start: if finish_start:
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.pop(2, False), b'Z3101') self.assertEqual(self.pop(2, False), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
self.pop(4)
self.assertEqual(self.pop(), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
self.respond(3, dict(length=42)) self.respond(3, dict(length=42))
...@@ -135,12 +136,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -135,12 +136,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# The client sends back a handshake, and registers the # The client sends back a handshake, and registers the
# storage, and requests the last transaction. # storage, and requests the last transaction.
self.assertEqual(self.pop(2, False), b'Z5') self.assertEqual(self.pop(2, False), b'Z5')
self.assertEqual(self.pop(), self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
# Actually, the client isn't connected until it initializes it's cache: # The client isn't connected until it initializes it's cache:
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
# If we try to make calls while the client is *initially* # If we try to make calls while the client is *initially*
...@@ -163,9 +161,13 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -163,9 +161,13 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# The wrapper object (ClientStorage) hasn't been notified: # The wrapper object (ClientStorage) hasn't been notified:
self.assertFalse(wrapper.notify_connected.called) self.assertFalse(wrapper.notify_connected.called)
# Let's respond to those first 2 calls: # Let's respond to the register call:
self.respond(1, None) self.respond(1, None)
# The client requests the last transaction:
self.assertEqual(self.pop(), (2, False, 'lastTransaction', ()))
# We respond
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
# After verification, the client requests info: # After verification, the client requests info:
...@@ -298,15 +300,14 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -298,15 +300,14 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# protocol: # protocol:
protocol.data_received(sized(b'Z310')) protocol.data_received(sized(b'Z310'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z310') self.assertEqual(self.unsized(transport.pop(2)), b'Z310')
self.assertEqual(self.pop(), self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.assertFalse(wrapper.notify_connected.called) self.assertFalse(wrapper.notify_connected.called)
self.respond(1, None)
self.respond(2, b'e'*8) # If the register response is a tid, then the client won't
self.assertEqual(self.pop(), (3, False, 'get_info', ())) # request lastTransaction
self.respond(3, dict(length=42)) self.respond(1, b'e'*8)
self.assertEqual(self.pop(), (2, False, 'get_info', ()))
self.respond(2, dict(length=42))
# Because the server tid matches the cache tid, we're done connecting # Because the server tid matches the cache tid, we're done connecting
wrapper.notify_connected.assert_called_with(client, {'length': 42}) wrapper.notify_connected.assert_called_with(client, {'length': 42})
...@@ -335,12 +336,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -335,12 +336,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, b'e'*8) self.respond(2, b'e'*8)
self.pop(4)
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
...@@ -373,12 +371,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -373,12 +371,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, b'e'*8) self.respond(2, b'e'*8)
self.pop(4)
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
...@@ -445,12 +440,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -445,12 +440,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
cache.setLastTid('b'*8) cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
self.pop()
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
self.assert_(8 < delay < 10) self.assert_(8 < delay < 10)
...@@ -462,12 +454,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -462,12 +454,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.respond(2, 'b'*8) self.respond(2, 'b'*8)
self.pop(4)
self.assertEqual(self.pop(), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
self.respond(3, dict(length=42)) self.respond(3, dict(length=42))
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
...@@ -493,12 +482,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -493,12 +482,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
# The client tries for a read-only connection: # The client tries for a read-only connection:
self.assertEqual(self.pop(), self.assertEqual(self.pop(), (2, False, 'register', ('TEST', True)))
[(2, False, 'register', ('TEST', True)),
(3, False, 'lastTransaction', ()),
])
# We respond with successfully: # We respond with successfully:
self.respond(2, None) self.respond(2, None)
self.pop(2)
self.respond(3, 'b'*8) self.respond(3, 'b'*8)
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
...@@ -525,12 +512,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -525,12 +512,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We respond and the writable connection succeeds: # We respond and the writable connection succeeds:
self.respond(1, None) self.respond(1, None)
self.assertFalse(self.is_read_only())
# at this point, a lastTransaction request is emitted: # at this point, a lastTransaction request is emitted:
self.assertEqual(self.parse(loop.transport.pop()), self.assertEqual(self.parse(loop.transport.pop()),
(2, False, 'lastTransaction', ())) (2, False, 'lastTransaction', ()))
self.assertFalse(self.is_read_only())
# Now, the original protocol is closed, and the client is # Now, the original protocol is closed, and the client is
# no-longer ready: # no-longer ready:
...@@ -554,11 +541,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -554,11 +541,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport = self.start() wrapper, cache, loop, client, protocol, transport = self.start()
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.pop(4)
self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False)
self.respond(2, b'a'*8) self.respond(2, b'a'*8)
self.send('invalidateTransaction', b'c'*8, [b'1'*8], no_output=False) self.send('invalidateTransaction', b'c'*8, [b'1'*8], no_output=False)
...@@ -575,11 +559,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -575,11 +559,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.respond(1, None) self.respond(1, None)
self.pop(4)
self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False)
self.respond(2, b'c'*8) self.respond(2, b'c'*8)
self.send('invalidateTransaction', b'e'*8, [b'1'*8], no_output=False) self.send('invalidateTransaction', b'e'*8, [b'1'*8], no_output=False)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment