Commit 7e5b78fc authored by Jim Fulton

Async changes driven by ClientStorage integration

- Fixed tpc_finish:

  - Use tid from server to update cache.

  - Accept and call a callback function.

- Implemented flow control

- Added connection/disconnection notification (to client storage).

- Implemented get_peername.

- Implemented is_read_only.

- Renamed callAsync to async (death to Camels!)
parent 4cfe3662
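
For orientation, here is a minimal, self-contained sketch of the flow-control scheme the hunks below add to the real Protocol class: length-prefixed messages are written straight to the transport until asyncio calls pause_writing(), after which they are buffered and flushed again from resume_writing(). The names mirror the diff, but this is an illustration, not the actual ZEO code.

import asyncio
from struct import pack

class BufferingProtocol(asyncio.Protocol):
    # Illustrative stand-in for the real Protocol changed below.

    def __init__(self):
        self.output = []   # messages queued while paused
        self.paused = []   # truthy while the transport is over its high-water mark

    def connection_made(self, transport):
        self.transport = transport

    def _write(self, message):
        # Length-prefix each message; queue it if asyncio asked us to pause.
        if self.paused:
            self.output.append(message)
        else:
            self.transport.writelines((pack(">I", len(message)), message))

    def pause_writing(self):
        self.paused.append(1)

    def resume_writing(self):
        del self.paused[:]
        while self.output and not self.paused:
            message = self.output.pop(0)
            self.transport.writelines((pack(">I", len(message)), message))
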
@@ -51,7 +51,9 @@ class Protocol(asyncio.Protocol):
self.client = client
self.connect_poll = connect_poll
self.futures = {} # { message_id -> future }
self.input = []
self.input = [] # Buffer when assembling messages
self.output = [] # Buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
@@ -98,15 +100,57 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport):
logger.info("Connected %s", self)
self.transport = transport
paused = self.paused
output = self.output
append = output.append
writelines = transport.writelines
from struct import pack
def write(message):
if paused:
append(message)
else:
writelines((pack(">I", len(message)), message))
self._write = write
def writeit(data):
# Note: we don't bother combining messages. Iterators
# will be used with blobs, in which case the individual
# messages will be big to begin with.
data = iter(data)
for message in data:
writelines((pack(">I", len(message)), message))
if paused:
append(data)
break
self._writeit = writeit
def pause_writing(self):
self.paused.append(1)
def resume_writing(self):
paused = self.paused
del paused[:]
output = self.output
writelines = self.transport.writelines
from struct import pack
while output and not paused:
message = output.pop(0)
if isinstance(message, bytes):
writelines((pack(">I", len(message)), message))
else:
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
output.insert(0, data)
break
def get_peername(self):
return self.transport.get_extra_info('peername')
def connection_lost(self, exc):
if exc is None:
# we were closed
@@ -232,6 +276,9 @@ class Protocol(asyncio.Protocol):
def call_async(self, method, args):
self._write(dumps((0, True, method, args), 3))
def call_async_iter(self, it):
self._writeit(dumps((0, True, method, args), 3) for method, args in it)
message_id = 0
def call(self, future, method, args):
self.message_id += 1
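
A quick note on call_async_iter, based on the hunk above: it takes an iterable of (method, args) pairs and pickles each one only as the transport is ready for it, so a large payload never has to be materialized up front. Roughly (protocol here stands for a connected Protocol instance; the method names and arguments are illustrative):

# Wire-equivalent to two call_async() calls, but the pickling of the second
# message is deferred if the transport pauses after the first:
protocol.call_async_iter([('a', (1,)), ('b', (2,))])

protocol.call_async('a', (1,))
protocol.call_async('b', (2,))
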
@@ -304,6 +351,8 @@ class Client:
def disconnected(self, protocol=None):
if protocol is None or protocol is self.protocol:
if protocol is self.protocol:
self.client.notify_disconnected()
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol = None
@@ -403,6 +452,10 @@ class Client:
self.cache.setLastTid(server_tid)
self.ready = True
self.connected.set_result(None)
self.client.notify_connected(self)
def get_peername(self):
return self.protocol.get_peername()
def call_async_threadsafe(self, future, method, args):
if self.ready:
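
To make the new notifications concrete, here is a hypothetical minimal wrapper playing the ClientStorage role: the Client calls notify_connected(self) once it is ready (as in the hunk above) and notify_disconnected() when the active protocol drops. Only the two method names come from this commit; the bodies are made up.

class StorageWrapper:
    # Hypothetical stand-in for ClientStorage.

    def notify_connected(self, client):
        # Called with the async Client once it is connected and ready.
        self.client = client
        print('connected to', client.get_peername())

    def notify_disconnected(self):
        # Called when the active protocol's connection is lost.
        self.client = None
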
@@ -411,6 +464,13 @@ class Client:
else:
future.set_exception(ZEO.Exceptions.ClientDisconnected())
def call_async_iter_threadsafe(self, future, it):
if self.ready:
self.protocol.call_async_iter(it)
future.set_result(None)
else:
future.set_exception(ZEO.Exceptions.ClientDisconnected())
def _when_ready(self, func, result_future, *args):
@self.connected.add_done_callback
@@ -463,16 +523,17 @@ class Client:
else:
self._when_ready(self.load_before_threadsafe, future, oid, tid)
def tpc_finish_threadsafe(self, future, tid, updates):
def tpc_finish_threadsafe(self, future, tid, updates, f):
if self.ready:
@self.protocol.promise('tpc_finish', tid)
def committed(_):
def committed(tid):
cache = self.cache
for oid, s, data in updates:
for oid, data, resolved in updates:
cache.invalidate(oid, tid)
if data and s != ResolvedSerial:
if data and not resolved:
cache.store(oid, tid, None, data)
cache.setLastTid(tid)
f(tid)
future.set_result(None)
committed.catch(future.set_exception)
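
The new signature pushes three things through from the caller: the tid the client proposed, an iterable of (oid, data, resolved) updates for the cache, and a callback f that receives the tid the server actually assigned. A hedged sketch of how a caller might drive it through the ClientRunner method added further down (runner, the oids, the data, and the callback are all illustrative):

def update_last_transaction(tid):
    # Illustrative callback: receives the tid assigned by the server,
    # which may differ from the tid the client proposed.
    print('committed as', tid)

updates = [
    (b'1' * 8, b'pickled state 1', False),  # stored in the cache under the new tid
    (b'2' * 8, b'pickled state 2', True),   # conflict-resolved: invalidated, not re-stored
]
runner.tpc_finish(b't' * 8, updates, update_last_transaction)
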
@@ -536,21 +597,35 @@ class ClientRunner:
def call(self, method, *args, timeout=None):
return self.__call(self.client.call_threadsafe, method, args)
def callAsync(self, method, *args):
def async(self, method, *args):
return self.__call(self.client.call_async_threadsafe, method, args)
def async_iter(self, it):
return self.__call(self.client.call_async_iter_threadsafe, it)
def load(self, oid):
return self.__call(self.client.load_threadsafe, oid)
def load_before(self, oid, tid):
return self.__call(self.client.load_before_threadsafe, oid, tid)
def tpc_finish(self, tid, updates):
return self.__call(self.client.tpc_finish_threadsafe, tid, updates)
def tpc_finish(self, tid, updates, f):
return self.__call(self.client.tpc_finish_threadsafe, tid, updates, f)
def is_connected(self):
return self.client.ready
def is_read_only(self):
try:
protocol = self.client.protocol
except AttributeError:
return True
else:
if protocol is None:
return True
else:
return protocol.read_only
def close(self):
self.__call(self.client.close_threadsafe)
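
Note that is_read_only is written defensively: before any protocol exists, or while disconnected, it reports True rather than raising. A sketch of the kind of guard a caller might build on it (the helper and the use of ReadOnlyError are assumptions, not part of this commit):

from ZODB.POSException import ReadOnlyError

def check_writable(runner):
    # Hypothetical guard a storage wrapper might apply before a write;
    # "no connection yet" is treated the same as a read-only server.
    if runner.is_read_only():
        raise ReadOnlyError()
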
......
@@ -17,7 +17,7 @@ class Loop:
def _connect(self, future, protocol_factory):
self.protocol = protocol = protocol_factory()
self.transport = transport = Transport()
self.transport = transport = Transport(protocol)
protocol.connection_made(transport)
future.set_result((transport, protocol))
@@ -60,14 +60,26 @@ class Loop:
class Transport:
def __init__(self):
capacity = 1 << 64
paused = False
extra = dict(peername='1.2.3.4')
def __init__(self, protocol):
self.data = []
self.protocol = protocol
def write(self, data):
self.data.append(data)
self.check_pause()
def writelines(self, lines):
self.data.extend(lines)
self.check_pause()
def check_pause(self):
if len(self.data) > self.capacity and not self.paused:
self.paused = True
self.protocol.pause_writing()
def pop(self, count=None):
if count:
@@ -76,8 +88,17 @@ class Transport:
else:
r = self.data[:]
del self.data[:]
self.check_resume()
return r
def check_resume(self):
if len(self.data) < self.capacity and self.paused:
self.paused = False
self.protocol.resume_writing()
closed = False
def close(self):
self.closed = True
def get_extra_info(self, name):
return self.extra[name]
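
The fake Transport's default capacity of 1 << 64 means it never pauses unless a test shrinks it. A hypothetical direct exercise of its back-pressure (transport and protocol come from the test fixture returned by start()):

transport.capacity = 1
transport.writelines([b'one', b'two'])   # len(data) > capacity -> pause_writing()
assert transport.paused
transport.pop()                          # drains the buffer -> resume_writing()
assert not transport.paused
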
@@ -20,6 +20,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
def start(self,
addrs=(('127.0.0.1', 8200), ), loop_addrs=None,
read_only=False,
finish_start=False,
):
# To create a client, we need to specify an address, a client
# object and a cache.
@@ -43,6 +44,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
loop.protocol.data_received(
sized(pickle.dumps((message_id, False, '.reply', result), 3)))
if finish_start:
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, 'a'*8)
return (wrapper, cache, self.loop, self.client, protocol, transport,
send, respond)
@@ -82,9 +94,12 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(f1.done())
# If we try to make an async call, we get an immediate error:
f2 = self.callAsync('bar', 3, 4)
f2 = self.async('bar', 3, 4)
self.assert_(isinstance(f2.exception(), ClientDisconnected))
# The wrapper object (ClientStorage) hasn't been notified:
wrapper.notify_connected.assert_not_called()
# Let's respond to those first 2 calls:
respond(1, None)
@@ -96,11 +111,14 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(cache.getLastTid(), 'a'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'foo', (1, 2)))
# The wrapper object (ClientStorage) has been notified:
wrapper.notify_connected.assert_called_with(client)
respond(3, 42)
self.assertEqual(f1.result(), 42)
# Now we can make async calls:
f2 = self.callAsync('bar', 3, 4)
f2 = self.async('bar', 3, 4)
self.assert_(f2.done() and f2.exception() is None)
self.assertEqual(parse(transport.pop()), (0, True, 'bar', (3, 4)))
@@ -144,26 +162,32 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache
# with committed data. To do this, we pass a (oid, tid, data)
# with committed data. To do this, we pass an (oid, data, resolved)
# iterable to tpc_finish_threadsafe.
from ZODB.ConflictResolution import ResolvedSerial
tids = []
def finished_cb(tid):
tids.append(tid)
committed = self.tpc_finish(
b'd'*8,
[(b'2'*8, b'd'*8, 'committed 2'),
(b'1'*8, ResolvedSerial, 'committed 3'),
(b'4'*8, b'd'*8, 'committed 4'),
])
[(b'2'*8, 'committed 2', False),
(b'1'*8, 'committed 3', True),
(b'4'*8, 'committed 4', False),
],
finished_cb)
self.assertFalse(committed.done() or
cache.load(b'2'*8) or
cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(parse(transport.pop()),
(7, False, 'tpc_finish', (b'd'*8,)))
respond(7, None)
respond(7, b'e'*8)
self.assertEqual(committed.result(), None)
self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'd'*8))
self.assertEqual(cache.load(b'4'*8), ('committed 4', b'd'*8))
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8))
self.assertEqual(cache.load(b'4'*8), ('committed 4', b'e'*8))
self.assertEqual(tids.pop(), b'e'*8)
# If the protocol is disconnected, it will reconnect and will
# resolve outstanding requests with exceptions:
@@ -176,7 +200,10 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
)
exc = TypeError(43)
wrapper.notify_disconnected.assert_not_called()
wrapper.notify_connected.reset_mock()
protocol.connection_lost(exc)
wrapper.notify_disconnected.assert_called_with()
self.assertEqual(loaded.exception(), exc)
self.assertEqual(f1.exception(), exc)
@@ -195,12 +222,14 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
wrapper.notify_connected.assert_not_called()
respond(1, None)
respond(2, b'd'*8)
respond(2, b'e'*8)
wrapper.notify_connected.assert_called_with(client)
# Because the server tid matches the cache tid, we're done connecting
self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'd'*8)
self.assertEqual(cache.getLastTid(), b'e'*8)
# Because we were able to update the cache, we didn't have to
# invalidate the database cache:
@@ -364,6 +393,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(addrs, (), read_only=Fallback))
self.assertTrue(self.is_read_only())
# We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport
@@ -376,6 +407,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
])
# We respond with a read-only exception:
respond(1, (ReadOnlyError, ReadOnlyError()))
self.assertTrue(self.is_read_only())
# The client tries for a read-only connection:
self.assertEqual(self.parse(transport.pop()),
@@ -385,6 +417,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# We respond successfully:
respond(3, None)
respond(4, 'b'*8)
self.assertTrue(self.is_read_only())
# At this point, the client is ready and using the protocol,
# and the protocol is read-only:
@@ -402,9 +435,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
self.assertTrue(self.is_read_only())
# We respond and the writable connection succeeds:
respond(1, None)
self.assertFalse(self.is_read_only())
# Now, the original protocol is closed, and the client is
# no longer ready:
@@ -463,6 +498,46 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
wrapper.invalidateTransaction.assert_called_with(b'e'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
def test_flow_control(self):
# When sending a lot of data (blobs), we don't want to fill up
# memory behind a slow socket. Asyncio's flow-control helper
# seems a bit complicated. We'd rather pass an iterator that's
# consumed only as fast as the transport can send.
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(finish_start=True))
# Give the transport a small capacity:
transport.capacity = 2
self.async('foo')
self.async('bar')
self.async('baz')
self.async('splat')
# The first two were sent; the rest were queued.
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'foo', ()), (0, True, 'bar', ())])
# But popping them allowed sending to resume:
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'baz', ()), (0, True, 'splat', ())])
# This is especially handy with iterators:
self.async_iter((name, ()) for name in 'abcde')
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'a', ()), (0, True, 'b', ())])
self.assertEqual(self.parse(transport.pop()),
[(0, True, 'c', ()), (0, True, 'd', ())])
self.assertEqual(self.parse(transport.pop()),
(0, True, 'e', ()))
self.assertEqual(self.parse(transport.pop()),
[])
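
This is the pattern intended for blobs: hand async_iter a generator over file chunks so that, behind a slow socket, only a couple of chunks are pickled and buffered at any one time. A hypothetical sketch (the send_blob helper, the 'storeBlobChunk' method name, and the chunk size are made up for illustration; runner stands for a ClientRunner):

def send_blob(runner, oid, path, chunk_size=1 << 16):
    # Hypothetical helper: streams a blob file as one-way calls.  The
    # generator is consumed only as fast as the transport drains, so the
    # file is never read into memory all at once.
    def chunks():
        with open(path, 'rb') as f:
            while True:
                data = f.read(chunk_size)
                if not data:
                    break
                yield ('storeBlobChunk', (oid, data))  # method name is illustrative
    return runner.async_iter(chunks())
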
def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(finish_start=True))
self.assertEqual(client.get_peername(), '1.2.3.4')
def unsized(self, data, unpickle=False):
result = []
while data:
......