Commit 7e5b78fc authored by Jim Fulton's avatar Jim Fulton

Async changes driven by ClientStorage integration

- Fixed tpc_finish:

  - Use tid from server to update cache.

  - Accept and call callback function.

- Implemented flow control

- Added connection/disconnection notification (to client storage).

- implemented get_peername.

- implemented is_read_only

- renamed callAsync to async (death to Camels!)
parent 4cfe3662
...@@ -51,7 +51,9 @@ class Protocol(asyncio.Protocol): ...@@ -51,7 +51,9 @@ class Protocol(asyncio.Protocol):
self.client = client self.client = client
self.connect_poll = connect_poll self.connect_poll = connect_poll
self.futures = {} # { message_id -> future } self.futures = {} # { message_id -> future }
self.input = [] self.input = [] # Buffer when assembling messages
self.output = [] # Buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently # Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received self.message_received = self.first_message_received
...@@ -98,15 +100,57 @@ class Protocol(asyncio.Protocol): ...@@ -98,15 +100,57 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) logger.info("Connected %s", self)
self.transport = transport self.transport = transport
paused = self.paused
output = self.output
append = output.append
writelines = transport.writelines writelines = transport.writelines
from struct import pack from struct import pack
def write(message): def write(message):
writelines((pack(">I", len(message)), message)) if paused:
append(message)
else:
writelines((pack(">I", len(message)), message))
self._write = write self._write = write
def writeit(data):
# Note, don't worry about combining messages. Iters
# will be used with blobs, in which case, the individual
# messages will be big to begin with.
data = iter(data)
for message in data:
writelines((pack(">I", len(message)), message))
if paused:
append(data)
break
self._writeit = writeit
def pause_writing(self):
self.paused.append(1)
def resume_writing(self):
paused = self.paused
del paused[:]
output = self.output
writelines = self.transport.writelines
from struct import pack
while output and not paused:
message = output.pop(0)
if isinstance(message, bytes):
writelines((pack(">I", len(message)), message))
else:
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
output.insert(0, data)
break
def get_peername(self):
return self.transport.get_extra_info('peername')
def connection_lost(self, exc): def connection_lost(self, exc):
if exc is None: if exc is None:
# we were closed # we were closed
...@@ -232,6 +276,9 @@ class Protocol(asyncio.Protocol): ...@@ -232,6 +276,9 @@ class Protocol(asyncio.Protocol):
def call_async(self, method, args): def call_async(self, method, args):
self._write(dumps((0, True, method, args), 3)) self._write(dumps((0, True, method, args), 3))
def call_async_iter(self, it):
self._writeit(dumps((0, True, method, args), 3) for method, args in it)
message_id = 0 message_id = 0
def call(self, future, method, args): def call(self, future, method, args):
self.message_id += 1 self.message_id += 1
...@@ -304,6 +351,8 @@ class Client: ...@@ -304,6 +351,8 @@ class Client:
def disconnected(self, protocol=None): def disconnected(self, protocol=None):
if protocol is None or protocol is self.protocol: if protocol is None or protocol is self.protocol:
if protocol is self.protocol:
self.client.notify_disconnected()
self.ready = False self.ready = False
self.connected = concurrent.futures.Future() self.connected = concurrent.futures.Future()
self.protocol = None self.protocol = None
...@@ -403,6 +452,10 @@ class Client: ...@@ -403,6 +452,10 @@ class Client:
self.cache.setLastTid(server_tid) self.cache.setLastTid(server_tid)
self.ready = True self.ready = True
self.connected.set_result(None) self.connected.set_result(None)
self.client.notify_connected(self)
def get_peername(self):
return self.protocol.get_peername()
def call_async_threadsafe(self, future, method, args): def call_async_threadsafe(self, future, method, args):
if self.ready: if self.ready:
...@@ -411,6 +464,13 @@ class Client: ...@@ -411,6 +464,13 @@ class Client:
else: else:
future.set_exception(ZEO.Exceptions.ClientDisconnected()) future.set_exception(ZEO.Exceptions.ClientDisconnected())
def call_async_iter_threadsafe(self, future, it):
if self.ready:
self.protocol.call_async_iter(it)
future.set_result(None)
else:
future.set_exception(ZEO.Exceptions.ClientDisconnected())
def _when_ready(self, func, result_future, *args): def _when_ready(self, func, result_future, *args):
@self.connected.add_done_callback @self.connected.add_done_callback
...@@ -463,16 +523,17 @@ class Client: ...@@ -463,16 +523,17 @@ class Client:
else: else:
self._when_ready(self.load_before_threadsafe, future, oid, tid) self._when_ready(self.load_before_threadsafe, future, oid, tid)
def tpc_finish_threadsafe(self, future, tid, updates): def tpc_finish_threadsafe(self, future, tid, updates, f):
if self.ready: if self.ready:
@self.protocol.promise('tpc_finish', tid) @self.protocol.promise('tpc_finish', tid)
def committed(_): def committed(tid):
cache = self.cache cache = self.cache
for oid, s, data in updates: for oid, data, resolved in updates:
cache.invalidate(oid, tid) cache.invalidate(oid, tid)
if data and s != ResolvedSerial: if data and not resolved:
cache.store(oid, tid, None, data) cache.store(oid, tid, None, data)
cache.setLastTid(tid) cache.setLastTid(tid)
f(tid)
future.set_result(None) future.set_result(None)
committed.catch(future.set_exception) committed.catch(future.set_exception)
...@@ -536,21 +597,35 @@ class ClientRunner: ...@@ -536,21 +597,35 @@ class ClientRunner:
def call(self, method, *args, timeout=None): def call(self, method, *args, timeout=None):
return self.__call(self.client.call_threadsafe, method, args) return self.__call(self.client.call_threadsafe, method, args)
def callAsync(self, method, *args): def async(self, method, *args):
return self.__call(self.client.call_async_threadsafe, method, args) return self.__call(self.client.call_async_threadsafe, method, args)
def async_iter(self, it):
return self.__call(self.client.call_async_iter_threadsafe, it)
def load(self, oid): def load(self, oid):
return self.__call(self.client.load_threadsafe, oid) return self.__call(self.client.load_threadsafe, oid)
def load_before(self, oid, tid): def load_before(self, oid, tid):
return self.__call(self.client.load_before_threadsafe, oid, tid) return self.__call(self.client.load_before_threadsafe, oid, tid)
def tpc_finish(self, tid, updates): def tpc_finish(self, tid, updates, f):
return self.__call(self.client.tpc_finish_threadsafe, tid, updates) return self.__call(self.client.tpc_finish_threadsafe, tid, updates, f)
def is_connected(self): def is_connected(self):
return self.client.ready return self.client.ready
def is_read_only(self):
try:
protocol = self.client.protocol
except AttributeError:
return True
else:
if protocol is None:
return True
else:
return protocol.read_only
def close(self): def close(self):
self.__call(self.client.close_threadsafe) self.__call(self.client.close_threadsafe)
......
...@@ -17,7 +17,7 @@ class Loop: ...@@ -17,7 +17,7 @@ class Loop:
def _connect(self, future, protocol_factory): def _connect(self, future, protocol_factory):
self.protocol = protocol = protocol_factory() self.protocol = protocol = protocol_factory()
self.transport = transport = Transport() self.transport = transport = Transport(protocol)
protocol.connection_made(transport) protocol.connection_made(transport)
future.set_result((transport, protocol)) future.set_result((transport, protocol))
...@@ -60,14 +60,26 @@ class Loop: ...@@ -60,14 +60,26 @@ class Loop:
class Transport: class Transport:
def __init__(self): capacity = 1 << 64
paused = False
extra = dict(peername='1.2.3.4')
def __init__(self, protocol):
self.data = [] self.data = []
self.protocol = protocol
def write(self, data): def write(self, data):
self.data.append(data) self.data.append(data)
self.check_pause()
def writelines(self, lines): def writelines(self, lines):
self.data.extend(lines) self.data.extend(lines)
self.check_pause()
def check_pause(self):
if len(self.data) > self.capacity and not self.paused:
self.paused = True
self.protocol.pause_writing()
def pop(self, count=None): def pop(self, count=None):
if count: if count:
...@@ -76,8 +88,17 @@ class Transport: ...@@ -76,8 +88,17 @@ class Transport:
else: else:
r = self.data[:] r = self.data[:]
del self.data[:] del self.data[:]
self.check_resume()
return r return r
def check_resume(self):
if len(self.data) < self.capacity and self.paused:
self.paused = False
self.protocol.resume_writing()
closed = False closed = False
def close(self): def close(self):
self.closed = True self.closed = True
def get_extra_info(self, name):
return self.extra[name]
...@@ -20,6 +20,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -20,6 +20,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
def start(self, def start(self,
addrs=(('127.0.0.1', 8200), ), loop_addrs=None, addrs=(('127.0.0.1', 8200), ), loop_addrs=None,
read_only=False, read_only=False,
finish_start=False,
): ):
# To create a client, we need to specify an address, a client # To create a client, we need to specify an address, a client
# object and a cache. # object and a cache.
...@@ -43,6 +44,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -43,6 +44,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
loop.protocol.data_received( loop.protocol.data_received(
sized(pickle.dumps((message_id, False, '.reply', result), 3))) sized(pickle.dumps((message_id, False, '.reply', result), 3)))
if finish_start:
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, 'a'*8)
return (wrapper, cache, self.loop, self.client, protocol, transport, return (wrapper, cache, self.loop, self.client, protocol, transport,
send, respond) send, respond)
...@@ -82,9 +94,12 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -82,9 +94,12 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(f1.done()) self.assertFalse(f1.done())
# If we try to make an async call, we get an immediate error: # If we try to make an async call, we get an immediate error:
f2 = self.callAsync('bar', 3, 4) f2 = self.async('bar', 3, 4)
self.assert_(isinstance(f2.exception(), ClientDisconnected)) self.assert_(isinstance(f2.exception(), ClientDisconnected))
# The wrapper object (ClientStorage) hasn't been notified:
wrapper.notify_connected.assert_not_called()
# Let's respond to those first 2 calls: # Let's respond to those first 2 calls:
respond(1, None) respond(1, None)
...@@ -96,11 +111,14 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -96,11 +111,14 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(cache.getLastTid(), 'a'*8) self.assertEqual(cache.getLastTid(), 'a'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'foo', (1, 2))) self.assertEqual(parse(transport.pop()), (3, False, 'foo', (1, 2)))
# The wrapper object (ClientStorage) has been notified:
wrapper.notify_connected.assert_called_with(client)
respond(3, 42) respond(3, 42)
self.assertEqual(f1.result(), 42) self.assertEqual(f1.result(), 42)
# Now we can make async calls: # Now we can make async calls:
f2 = self.callAsync('bar', 3, 4) f2 = self.async('bar', 3, 4)
self.assert_(f2.done() and f2.exception() is None) self.assert_(f2.done() and f2.exception() is None)
self.assertEqual(parse(transport.pop()), (0, True, 'bar', (3, 4))) self.assertEqual(parse(transport.pop()), (0, True, 'bar', (3, 4)))
...@@ -144,26 +162,32 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -144,26 +162,32 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8)) self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache # When committing transactions, we need to update the cache
# with committed data. To do this, we pass a (oid, tid, data) # with committed data. To do this, we pass a (oid, data, resolved)
# iteratable to tpc_finish_threadsafe. # iteratable to tpc_finish_threadsafe.
from ZODB.ConflictResolution import ResolvedSerial
tids = []
def finished_cb(tid):
tids.append(tid)
committed = self.tpc_finish( committed = self.tpc_finish(
b'd'*8, b'd'*8,
[(b'2'*8, b'd'*8, 'committed 2'), [(b'2'*8, 'committed 2', False),
(b'1'*8, ResolvedSerial, 'committed 3'), (b'1'*8, 'committed 3', True),
(b'4'*8, b'd'*8, 'committed 4'), (b'4'*8, 'committed 4', False),
]) ],
finished_cb)
self.assertFalse(committed.done() or self.assertFalse(committed.done() or
cache.load(b'2'*8) or cache.load(b'2'*8) or
cache.load(b'4'*8)) cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8)) self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
(7, False, 'tpc_finish', (b'd'*8,))) (7, False, 'tpc_finish', (b'd'*8,)))
respond(7, None) respond(7, b'e'*8)
self.assertEqual(committed.result(), None) self.assertEqual(committed.result(), None)
self.assertEqual(cache.load(b'1'*8), None) self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'd'*8)) self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8))
self.assertEqual(cache.load(b'4'*8), ('committed 4', b'd'*8)) self.assertEqual(cache.load(b'4'*8), ('committed 4', b'e'*8))
self.assertEqual(tids.pop(), b'e'*8)
# If the protocol is disconnected, it will reconnect and will # If the protocol is disconnected, it will reconnect and will
# resolve outstanding requests with exceptions: # resolve outstanding requests with exceptions:
...@@ -176,7 +200,10 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -176,7 +200,10 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
) )
exc = TypeError(43) exc = TypeError(43)
wrapper.notify_disconnected.assert_not_called()
wrapper.notify_connected.reset_mock()
protocol.connection_lost(exc) protocol.connection_lost(exc)
wrapper.notify_disconnected.assert_called_with()
self.assertEqual(loaded.exception(), exc) self.assertEqual(loaded.exception(), exc)
self.assertEqual(f1.exception(), exc) self.assertEqual(f1.exception(), exc)
...@@ -195,12 +222,14 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -195,12 +222,14 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
wrapper.notify_connected.assert_not_called()
respond(1, None) respond(1, None)
respond(2, b'd'*8) respond(2, b'e'*8)
wrapper.notify_connected.assert_called_with(client)
# Because the server tid matches the cache tid, we're done connecting # Because the server tid matches the cache tid, we're done connecting
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'd'*8) self.assertEqual(cache.getLastTid(), b'e'*8)
# Because we were able to update the cache, we didn't have to # Because we were able to update the cache, we didn't have to
# invalidate the database cache: # invalidate the database cache:
...@@ -364,6 +393,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -364,6 +393,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(addrs, (), read_only=Fallback)) self.start(addrs, (), read_only=Fallback))
self.assertTrue(self.is_read_only())
# We'll treat the first address as read-only and we'll let it connect: # We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0]) loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport protocol, transport = loop.protocol, loop.transport
...@@ -376,6 +407,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -376,6 +407,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
]) ])
# We respond with a read-only exception: # We respond with a read-only exception:
respond(1, (ReadOnlyError, ReadOnlyError())) respond(1, (ReadOnlyError, ReadOnlyError()))
self.assertTrue(self.is_read_only())
# The client tries for a read-only connection: # The client tries for a read-only connection:
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
...@@ -385,6 +417,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -385,6 +417,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# We respond with successfully: # We respond with successfully:
respond(3, None) respond(3, None)
respond(4, 'b'*8) respond(4, 'b'*8)
self.assertTrue(self.is_read_only())
# At this point, the client is ready and using the protocol, # At this point, the client is ready and using the protocol,
# and the protocol is read-only: # and the protocol is read-only:
...@@ -402,9 +435,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -402,9 +435,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
self.assertTrue(self.is_read_only())
# We respond and the writable connection succeeds: # We respond and the writable connection succeeds:
respond(1, None) respond(1, None)
self.assertFalse(self.is_read_only())
# Now, the original protocol is closed, and the client is # Now, the original protocol is closed, and the client is
# no-longer ready: # no-longer ready:
...@@ -463,6 +498,46 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -463,6 +498,46 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
wrapper.invalidateTransaction.assert_called_with(b'e'*8, [b'1'*8]) wrapper.invalidateTransaction.assert_called_with(b'e'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock() wrapper.invalidateTransaction.reset_mock()
def test_flow_control(self):
# When sending a lot of data (blobs), we don't want to fill up
# memory behind a slow socket. Asycio's flow control helper
# seems a bit complicated. We'd rather pass an iterator that's
# consumed as we can.
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(finish_start=True))
# Give the transport a small capacity:
transport.capacity = 2
self.async('foo')
self.async('bar')
self.async('baz')
self.async('splat')
# The first 2 were sent, but the remaining were queued.
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'foo', ()), (0, True, 'bar', ())])
# But popping them allowed sending to resume:
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'baz', ()), (0, True, 'splat', ())])
# This is especially handy with iterators:
self.async_iter((name, ()) for name in 'abcde')
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'a', ()), (0, True, 'b', ())])
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'c', ()), (0, True, 'd', ())])
self.assertEqual(self.parse(transport.pop()),
(0, True, 'e', ()))
self.assertEqual(self.parse(transport.pop()),
[])
def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(finish_start=True))
self.assertEqual(client.get_peername(), '1.2.3.4')
def unsized(self, data, unpickle=False): def unsized(self, data, unpickle=False):
result = [] result = []
while data: while data:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment