Commit b6ec0eca authored by Jim Fulton's avatar Jim Fulton

Added protocol negotiation and test fixes

parent 7e5b78fc
......@@ -32,6 +32,8 @@ class Protocol(asyncio.Protocol):
transport = protocol_version = None
protocols = b"Z309", b"Z310", b"Z3101"
def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1):
"""Create a client interface
......@@ -179,8 +181,12 @@ class Protocol(asyncio.Protocol):
# lastTid before processing (and possibly missing) subsequent
# invalidations.
self.protocol_version = protocol_version
self._write(protocol_version) # XXX protocol negotiation
self.protocol_version = min(protocol_version, self.protocols[-1])
if self.protocol_version not in self.protocols:
self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version))
return
self._write(self.protocol_version)
register = self.promise(
'register', self.storage_key,
......@@ -351,7 +357,7 @@ class Client:
def disconnected(self, protocol=None):
if protocol is None or protocol is self.protocol:
if protocol is self.protocol:
if protocol is self.protocol and protocol is not None:
self.client.notify_disconnected()
self.ready = False
self.connected = concurrent.futures.Future()
......@@ -544,14 +550,6 @@ class Client:
self.close()
future.set_result(None)
# Methods called by the server:
client_methods = (
'invalidateTransaction', 'serialnos', 'info',
'receiveBlobStart', 'receiveBlobChunk', 'receiveBlobStop',
)
client_delegated = client_methods[1:]
def invalidateTransaction(self, tid, oids):
if self.ready:
for oid in oids:
......
......@@ -10,6 +10,7 @@ import pdb
import pickle
import struct
import unittest
import ZEO.Exceptions
from .testing import Loop
from .client import ClientRunner, Fallback
......@@ -33,6 +34,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# so we don't have to actually make any network connection.
loop = Loop(addrs if loop_addrs is None else loop_addrs)
self.setup_delegation(loop)
self.assertFalse(wrapper.notify_disconnected.called)
protocol = loop.protocol
transport = loop.transport
......@@ -45,8 +47,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
sized(pickle.dumps((message_id, False, '.reply', result), 3)))
if finish_start:
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
......@@ -70,16 +72,19 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
self.assertFalse(wrapper.notify_disconnected.called)
# The client isn't connected until the server sends it some data.
self.assertFalse(client.connected.done() or transport.data)
# The server sends the client some data:
protocol.data_received(sized(b'Z101'))
# The server sends the client it's protocol. In this case,
# it's a very high one. The client will send it's highest that
# it can use.
protocol.data_received(sized(b'Z99999'))
# The client sends back a handshake, and registers the
# storage, and requests the last transaction.
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
......@@ -98,7 +103,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assert_(isinstance(f2.exception(), ClientDisconnected))
# The wrapper object (ClientStorage) hasn't been notified:
wrapper.notify_connected.assert_not_called()
self.assertFalse(wrapper.notify_connected.called)
# Let's respond to those first 2 calls:
......@@ -139,6 +144,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Let's send an invalidation:
send('invalidateTransaction', b'b'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'b'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
# Now, if we try to load current again, we'll make a server request.
loaded = self.load(b'1'*8)
......@@ -200,7 +206,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
)
exc = TypeError(43)
wrapper.notify_disconnected.assert_not_called()
self.assertFalse(wrapper.notify_disconnected.called)
wrapper.notify_connected.reset_mock()
protocol.connection_lost(exc)
wrapper.notify_disconnected.assert_called_with()
......@@ -216,13 +222,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# and we have a new incomplete connect future:
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
# This time we'll send a lower protocol version. The client
# will send it back, because it's lower than the client's
# protocol:
protocol.data_received(sized(b'Z310'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z310')
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
wrapper.notify_connected.assert_not_called()
self.assertFalse(wrapper.notify_connected.called)
respond(1, None)
respond(2, b'e'*8)
wrapper.notify_connected.assert_called_with(client)
......@@ -233,7 +243,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Because we were able to update the cache, we didn't have to
# invalidate the database cache:
wrapper.invalidateTransaction.assert_not_called()
self.assertFalse(wrapper.invalidateTransaction.called)
# The close method closes the connection and cache:
client.close()
......@@ -252,8 +262,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
cache.store(b'2'*8, b'a'*8, None, '2 data')
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
......@@ -278,7 +288,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Because we were able to update the cache, we didn't have to
# invalidate the database cache:
wrapper.invalidateCache.assert_not_called()
self.assertFalse(wrapper.invalidateCache.called)
def test_cache_way_behind(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
......@@ -289,8 +299,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertTrue(cache)
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
......@@ -341,8 +351,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(sorted(loop.connecting), addrs[:1])
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
respond(1, None)
# Now, when the first connection fails, it won't be retried,
......@@ -359,8 +369,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.start())
cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
......@@ -377,8 +387,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(transport is loop.transport)
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
......@@ -398,8 +408,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
# We see that the client tried a writable connection:
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
......@@ -429,8 +439,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# We connect the second address:
loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z101')
loop.protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(loop.transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
......@@ -461,8 +471,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# While we're verifying, invalidations are ignored
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
......@@ -484,8 +494,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Similarly, invalidations aren't processed while reconnecting:
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
......@@ -533,6 +543,16 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(self.parse(transport.pop()),
[])
def test_bad_protocol(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
with mock.patch("ZEO.asyncio.client.logger.error") as error:
self.assertFalse(error.called)
protocol.data_received(sized(b'Z200'))
self.assert_(isinstance(error.call_args[0][1],
ZEO.Exceptions.ProtocolError))
def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(finish_start=True))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment