Commit b6ec0eca authored by Jim Fulton's avatar Jim Fulton

Added protocol negotiation and test fixes

parent 7e5b78fc
...@@ -32,6 +32,8 @@ class Protocol(asyncio.Protocol): ...@@ -32,6 +32,8 @@ class Protocol(asyncio.Protocol):
transport = protocol_version = None transport = protocol_version = None
protocols = b"Z309", b"Z310", b"Z3101"
def __init__(self, loop, def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1): addr, client, storage_key, read_only, connect_poll=1):
"""Create a client interface """Create a client interface
...@@ -179,8 +181,12 @@ class Protocol(asyncio.Protocol): ...@@ -179,8 +181,12 @@ class Protocol(asyncio.Protocol):
# lastTid before processing (and possibly missing) subsequent # lastTid before processing (and possibly missing) subsequent
# invalidations. # invalidations.
self.protocol_version = protocol_version self.protocol_version = min(protocol_version, self.protocols[-1])
self._write(protocol_version) # XXX protocol negotiation if self.protocol_version not in self.protocols:
self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version))
return
self._write(self.protocol_version)
register = self.promise( register = self.promise(
'register', self.storage_key, 'register', self.storage_key,
...@@ -351,7 +357,7 @@ class Client: ...@@ -351,7 +357,7 @@ class Client:
def disconnected(self, protocol=None): def disconnected(self, protocol=None):
if protocol is None or protocol is self.protocol: if protocol is None or protocol is self.protocol:
if protocol is self.protocol: if protocol is self.protocol and protocol is not None:
self.client.notify_disconnected() self.client.notify_disconnected()
self.ready = False self.ready = False
self.connected = concurrent.futures.Future() self.connected = concurrent.futures.Future()
...@@ -544,14 +550,6 @@ class Client: ...@@ -544,14 +550,6 @@ class Client:
self.close() self.close()
future.set_result(None) future.set_result(None)
# Methods called by the server:
client_methods = (
'invalidateTransaction', 'serialnos', 'info',
'receiveBlobStart', 'receiveBlobChunk', 'receiveBlobStop',
)
client_delegated = client_methods[1:]
def invalidateTransaction(self, tid, oids): def invalidateTransaction(self, tid, oids):
if self.ready: if self.ready:
for oid in oids: for oid in oids:
......
...@@ -10,6 +10,7 @@ import pdb ...@@ -10,6 +10,7 @@ import pdb
import pickle import pickle
import struct import struct
import unittest import unittest
import ZEO.Exceptions
from .testing import Loop from .testing import Loop
from .client import ClientRunner, Fallback from .client import ClientRunner, Fallback
...@@ -33,6 +34,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -33,6 +34,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# so we don't have to actually make any network connection. # so we don't have to actually make any network connection.
loop = Loop(addrs if loop_addrs is None else loop_addrs) loop = Loop(addrs if loop_addrs is None else loop_addrs)
self.setup_delegation(loop) self.setup_delegation(loop)
self.assertFalse(wrapper.notify_disconnected.called)
protocol = loop.protocol protocol = loop.protocol
transport = loop.transport transport = loop.transport
...@@ -45,8 +47,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -45,8 +47,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
sized(pickle.dumps((message_id, False, '.reply', result), 3))) sized(pickle.dumps((message_id, False, '.reply', result), 3)))
if finish_start: if finish_start:
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
parse = self.parse parse = self.parse
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
...@@ -70,16 +72,19 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -70,16 +72,19 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start()) self.start())
self.assertFalse(wrapper.notify_disconnected.called)
# The client isn't connected until the server sends it some data. # The client isn't connected until the server sends it some data.
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
# The server sends the client some data: # The server sends the client it's protocol. In this case,
protocol.data_received(sized(b'Z101')) # it's a very high one. The client will send it's highest that
# it can use.
protocol.data_received(sized(b'Z99999'))
# The client sends back a handshake, and registers the # The client sends back a handshake, and registers the
# storage, and requests the last transaction. # storage, and requests the last transaction.
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
parse = self.parse parse = self.parse
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
...@@ -98,7 +103,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -98,7 +103,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assert_(isinstance(f2.exception(), ClientDisconnected)) self.assert_(isinstance(f2.exception(), ClientDisconnected))
# The wrapper object (ClientStorage) hasn't been notified: # The wrapper object (ClientStorage) hasn't been notified:
wrapper.notify_connected.assert_not_called() self.assertFalse(wrapper.notify_connected.called)
# Let's respond to those first 2 calls: # Let's respond to those first 2 calls:
...@@ -139,6 +144,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -139,6 +144,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Let's send an invalidation: # Let's send an invalidation:
send('invalidateTransaction', b'b'*8, [b'1'*8]) send('invalidateTransaction', b'b'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'b'*8, [b'1'*8]) wrapper.invalidateTransaction.assert_called_with(b'b'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
# Now, if we try to load current again, we'll make a server request. # Now, if we try to load current again, we'll make a server request.
loaded = self.load(b'1'*8) loaded = self.load(b'1'*8)
...@@ -200,7 +206,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -200,7 +206,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
) )
exc = TypeError(43) exc = TypeError(43)
wrapper.notify_disconnected.assert_not_called() self.assertFalse(wrapper.notify_disconnected.called)
wrapper.notify_connected.reset_mock() wrapper.notify_connected.reset_mock()
protocol.connection_lost(exc) protocol.connection_lost(exc)
wrapper.notify_disconnected.assert_called_with() wrapper.notify_disconnected.assert_called_with()
...@@ -216,13 +222,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -216,13 +222,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# and we have a new incomplete connect future: # and we have a new incomplete connect future:
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') # This time we'll send a lower protocol version. The client
# will send it back, because it's lower than the client's
# protocol:
protocol.data_received(sized(b'Z310'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z310')
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
wrapper.notify_connected.assert_not_called() self.assertFalse(wrapper.notify_connected.called)
respond(1, None) respond(1, None)
respond(2, b'e'*8) respond(2, b'e'*8)
wrapper.notify_connected.assert_called_with(client) wrapper.notify_connected.assert_called_with(client)
...@@ -233,7 +243,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -233,7 +243,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Because we were able to update the cache, we didn't have to # Because we were able to update the cache, we didn't have to
# invalidate the database cache: # invalidate the database cache:
wrapper.invalidateTransaction.assert_not_called() self.assertFalse(wrapper.invalidateTransaction.called)
# The close method closes the connection and cache: # The close method closes the connection and cache:
client.close() client.close()
...@@ -252,8 +262,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -252,8 +262,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
cache.store(b'2'*8, b'a'*8, None, '2 data') cache.store(b'2'*8, b'a'*8, None, '2 data')
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
...@@ -278,7 +288,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -278,7 +288,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Because we were able to update the cache, we didn't have to # Because we were able to update the cache, we didn't have to
# invalidate the database cache: # invalidate the database cache:
wrapper.invalidateCache.assert_not_called() self.assertFalse(wrapper.invalidateCache.called)
def test_cache_way_behind(self): def test_cache_way_behind(self):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport, send, respond = (
...@@ -289,8 +299,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -289,8 +299,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertTrue(cache) self.assertTrue(cache)
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
...@@ -341,8 +351,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -341,8 +351,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(sorted(loop.connecting), addrs[:1]) self.assertEqual(sorted(loop.connecting), addrs[:1])
protocol = loop.protocol protocol = loop.protocol
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
respond(1, None) respond(1, None)
# Now, when the first connection fails, it won't be retried, # Now, when the first connection fails, it won't be retried,
...@@ -359,8 +369,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -359,8 +369,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.start()) self.start())
cache.store(b'4'*8, b'a'*8, None, '4 data') cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.setLastTid('b'*8) cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
parse = self.parse parse = self.parse
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
...@@ -377,8 +387,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -377,8 +387,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(transport is loop.transport) self.assertFalse(transport is loop.transport)
protocol = loop.protocol protocol = loop.protocol
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
...@@ -398,8 +408,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -398,8 +408,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# We'll treat the first address as read-only and we'll let it connect: # We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0]) loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport protocol, transport = loop.protocol, loop.transport
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
# We see that the client tried a writable connection: # We see that the client tried a writable connection:
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
...@@ -429,8 +439,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -429,8 +439,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# We connect the second address: # We connect the second address:
loop.connect_connecting(addrs[1]) loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(b'Z101')) loop.protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z101') self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(loop.transport.pop()), self.assertEqual(self.parse(loop.transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
...@@ -461,8 +471,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -461,8 +471,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# While we're verifying, invalidations are ignored # While we're verifying, invalidations are ignored
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start()) self.start())
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
...@@ -484,8 +494,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -484,8 +494,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Similarly, invalidations aren't processed while reconnecting: # Similarly, invalidations aren't processed while reconnecting:
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
...@@ -533,6 +543,16 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -533,6 +543,16 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
[]) [])
def test_bad_protocol(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
with mock.patch("ZEO.asyncio.client.logger.error") as error:
self.assertFalse(error.called)
protocol.data_received(sized(b'Z200'))
self.assert_(isinstance(error.call_args[0][1],
ZEO.Exceptions.ProtocolError))
def test_get_peername(self): def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(finish_start=True)) self.start(finish_start=True))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment