Commit e302ff45 authored by Jim Fulton's avatar Jim Fulton

Some refactoring

Renamed connection_timeout to connect_poll and use it when reconnecting.

Optimized hanfline of first messagem, as we did in zrpc, because it
only occurs once. :)

More/better comments.

Move loop argument to front of constructor arguments.

Added close/close_threadsafe to wait for connection on close.

Added is_connected.

Added new_addr.
parent 107f1077
...@@ -29,10 +29,11 @@ class Protocol(asyncio.Protocol): ...@@ -29,10 +29,11 @@ class Protocol(asyncio.Protocol):
# One place where special care was required was in cache setup on # One place where special care was required was in cache setup on
# connect. See finish connect below. # connect. See finish connect below.
transport = protocol_version = None transport = protocol_version = None
def __init__(self, addr, client, storage_key, read_only, loop, def __init__(self, loop,
connect_timeout=1): addr, client, storage_key, read_only, connect_poll=1):
"""Create a client interface """Create a client interface
addr is either a host,port tuple or a string file name. addr is either a host,port tuple or a string file name.
...@@ -48,9 +49,13 @@ class Protocol(asyncio.Protocol): ...@@ -48,9 +49,13 @@ class Protocol(asyncio.Protocol):
self.name = "%s(%r, %r, %r)" % ( self.name = "%s(%r, %r, %r)" % (
self.__class__.__name__, addr, storage_key, read_only) self.__class__.__name__, addr, storage_key, read_only)
self.client = client self.client = client
self.connect_timeout = connect_timeout self.connect_poll = connect_poll
self.futures = {} # { message_id -> future } self.futures = {} # { message_id -> future }
self.input = [] self.input = []
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
self.connect() self.connect()
def __repr__(self): def __repr__(self):
...@@ -85,8 +90,10 @@ class Protocol(asyncio.Protocol): ...@@ -85,8 +90,10 @@ class Protocol(asyncio.Protocol):
if future.exception() is not None: if future.exception() is not None:
# keep trying # keep trying
if not self.closed: if not self.closed:
self.loop.call_later(1 + local_random.random(), self.loop.call_later(
self.connect) self.connect_poll + local_random.random(),
self.connect,
)
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) logger.info("Connected %s", self)
...@@ -112,22 +119,24 @@ class Protocol(asyncio.Protocol): ...@@ -112,22 +119,24 @@ class Protocol(asyncio.Protocol):
self.client.disconnected(self) self.client.disconnected(self)
def finish_connect(self, protocol_version): def finish_connect(self, protocol_version):
# We use a promise model rather than coroutines here because # We use a promise model rather than coroutines here because
# for the most part, this class is reactive a coroutines # for the most part, this class is reactive and coroutines
# aren't a good model of it's activities. During # aren't a good model of it's activities. During
# initialization, however, we use promises to provide an # initialization, however, we use promises to provide an
# impertive flow. # imperative flow.
# The promise(/future) implementation we use differs from # The promise(/future) implementation we use differs from
# asyncio.Future in that callbacks are called immediately, # asyncio.Future in that callbacks are called immediately,
# rather than using the loops call_soon. We want to avoid a # rather than using the loops call_soon. We want to avoid a
# race between invalidations and cache initialization. In # race between invalidations and cache initialization. In
# particular, after calling lastTransaction or # particular, after getting a response from lastTransaction or
# getInvalidations, we want to make sure we set the cache's # getInvalidations, we want to make sure we set the cache's
# lastTid before processing subsequent invalidations. # lastTid before processing (and possibly missing) subsequent
# invalidations.
self.protocol_version = protocol_version self.protocol_version = protocol_version
self._write(protocol_version) self._write(protocol_version) # XXX protocol negotiation
register = self.promise( register = self.promise(
'register', self.storage_key, 'register', self.storage_key,
...@@ -168,6 +177,9 @@ class Protocol(asyncio.Protocol): ...@@ -168,6 +177,9 @@ class Protocol(asyncio.Protocol):
want = 4 want = 4
getting_size = True getting_size = True
def data_received(self, data): def data_received(self, data):
# Low-level input handler collects data into sized messages.
self.got += len(data) self.got += len(data)
self.input.append(data) self.input.append(data)
while self.got >= self.want: while self.got >= self.want:
...@@ -193,27 +205,29 @@ class Protocol(asyncio.Protocol): ...@@ -193,27 +205,29 @@ class Protocol(asyncio.Protocol):
self.getting_size = True self.getting_size = True
self.message_received(collected) self.message_received(collected)
def first_message_received(self, data):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
self.finish_connect(data)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
if self.protocol_version is None: msgid, async, name, args = loads(data)
self.finish_connect(data) if name == '.reply':
future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and
type(args[0]) == self.exception_type_type and
issubclass(args[0], Exception)
):
future.set_exception(args[1])
else:
future.set_result(args)
else: else:
msgid, async, name, args = loads(data) assert async # clients only get async calls
if name == '.reply': if name in self.client_methods:
future = self.futures.pop(msgid) getattr(self.client, name)(*args)
if (isinstance(args, tuple) and len(args) > 1 and
type(args[0]) == self.exception_type_type and
issubclass(args[0], Exception)
):
future.set_exception(args[1])
else:
future.set_result(args)
else: else:
assert async # clients only get async calls raise AttributeError(name)
if name in self.client_methods:
getattr(self.client, name)(*args)
else:
raise AttributeError(name)
def call_async(self, method, args): def call_async(self, method, args):
self._write(dumps((0, True, method, args), 3)) self._write(dumps((0, True, method, args), 3))
...@@ -250,7 +264,9 @@ class Client: ...@@ -250,7 +264,9 @@ class Client:
protocol = None protocol = None
ready = False ready = False
def __init__(self, addrs, client, cache, storage_key, read_only, loop): def __init__(self, loop,
addrs, client, cache, storage_key, read_only, connect_poll,
register_failed_poll=9):
"""Create a client interface """Create a client interface
addr is either a host,port tuple or a string file name. addr is either a host,port tuple or a string file name.
...@@ -263,6 +279,8 @@ class Client: ...@@ -263,6 +279,8 @@ class Client:
self.addrs = addrs self.addrs = addrs
self.storage_key = storage_key self.storage_key = storage_key
self.read_only = read_only self.read_only = read_only
self.connect_poll = connect_poll
self.register_failed_poll = register_failed_poll
self.client = client self.client = client
for name in Protocol.client_delegated: for name in Protocol.client_delegated:
setattr(self, name, getattr(client, name)) setattr(self, name, getattr(client, name))
...@@ -302,8 +320,9 @@ class Client: ...@@ -302,8 +320,9 @@ class Client:
def try_connecting(self): def try_connecting(self):
if not self.closed: if not self.closed:
self.protocols = [ self.protocols = [
Protocol(addr, self, self.storage_key, self.read_only, Protocol(self.loop, addr, self,
self.loop) self.storage_key, self.read_only, self.connect_poll,
)
for addr in self.addrs for addr in self.addrs
] ]
...@@ -330,7 +349,9 @@ class Client: ...@@ -330,7 +349,9 @@ class Client:
if (self.protocol is None and not if (self.protocol is None and not
any(not p.closed for p in self.protocols) any(not p.closed for p in self.protocols)
): ):
self.loop.call_later(9 + local_random.random(), self.try_connecting) self.loop.call_later(
self.register_failed_poll + local_random.random(),
self.try_connecting)
def verify(self, last_transaction_promise): def verify(self, last_transaction_promise):
protocol = self.protocol protocol = self.protocol
...@@ -357,6 +378,7 @@ class Client: ...@@ -357,6 +378,7 @@ class Client:
tid, oids = vdata tid, oids = vdata
for oid in oids: for oid in oids:
cache.invalidate(oid, None) cache.invalidate(oid, None)
self.client.invalidateTransaction(tid, oids)
return tid return tid
else: else:
# cache is too old # cache is too old
...@@ -365,8 +387,10 @@ class Client: ...@@ -365,8 +387,10 @@ class Client:
self.client.invalidateCache() self.client.invalidateCache()
return server_tid return server_tid
verify_invalidations(self.finished_verify, verify_invalidations(
self.connected.set_exception) self.finished_verify,
self.connected.set_exception,
)
else: else:
self.finished_verify(server_tid) self.finished_verify(server_tid)
...@@ -455,6 +479,10 @@ class Client: ...@@ -455,6 +479,10 @@ class Client:
else: else:
future.set_exception(ClientDisconnected()) future.set_exception(ClientDisconnected())
def close_threadsafe(self, future):
self.close()
future.set_result(None)
# Methods called by the server: # Methods called by the server:
client_methods = ( client_methods = (
...@@ -474,14 +502,15 @@ class Client: ...@@ -474,14 +502,15 @@ class Client:
class ClientRunner: class ClientRunner:
def set_options(self, addrs, wrapper, cache, storage_key, read_only, def set_options(self, addrs, wrapper, cache, storage_key, read_only,
timeout=30): timeout=30, disconnect_poll=1):
self.__args = addrs, wrapper, cache, storage_key, read_only self.__args = (addrs, wrapper, cache, storage_key, read_only,
disconnect_poll)
self.timeout = timeout self.timeout = timeout
self.connected = concurrent.futures.Future() self.connected = concurrent.futures.Future()
def setup_delegation(self, loop): def setup_delegation(self, loop):
self.loop = loop self.loop = loop
self.client = Client(*self.__args, loop=loop) self.client = Client(loop, *self.__args)
from concurrent.futures import Future from concurrent.futures import Future
call_soon_threadsafe = loop.call_soon_threadsafe call_soon_threadsafe = loop.call_soon_threadsafe
...@@ -519,6 +548,17 @@ class ClientRunner: ...@@ -519,6 +548,17 @@ class ClientRunner:
def tpc_finish(self, tid, updates): def tpc_finish(self, tid, updates):
return self.__call(self.client.tpc_finish_threadsafe, tid, updates) return self.__call(self.client.tpc_finish_threadsafe, tid, updates)
def is_connected(self):
return self.client.ready
def close(self):
self.__call(self.client.close_threadsafe)
def new_addr(self, addrs):
# This usually doesn't have an immediate effect, since the
# addrs aren't used until the client disconnects.xs
self.client.addrs = addrs
class ClientThread(ClientRunner): class ClientThread(ClientRunner):
"""Thread wrapper for client interface """Thread wrapper for client interface
...@@ -529,7 +569,8 @@ class ClientThread(ClientRunner): ...@@ -529,7 +569,8 @@ class ClientThread(ClientRunner):
def __init__(self, addrs, client, cache, def __init__(self, addrs, client, cache,
storage_key='1', read_only=False, timeout=30): storage_key='1', read_only=False, timeout=30):
self.set_options(addrs, client, cache, storage_key, read_only, timeout) self.set_options(addrs, client, cache, storage_key, read_only,
timeout, disconnect_poll)
threading.Thread( threading.Thread(
target=self.run, target=self.run,
args=(addr, client, cache, storage_key, read_only), args=(addr, client, cache, storage_key, read_only),
...@@ -562,6 +603,32 @@ class Promise: ...@@ -562,6 +603,32 @@ class Promise:
next = success_callback = error_callback = cancelled = None next = success_callback = error_callback = cancelled = None
def __call__(self, success_callback = None, error_callback = None): def __call__(self, success_callback = None, error_callback = None):
"""Set the promises success and error handlers and beget a new promise
The promise returned provides for promise chaining, providing
a sane imperative flow. Let's call this the "next" promise.
Any results or exceptions generated by the promise or it's
callbacks are passed on to the next promise.
When the promise completes successfully, if a success callback
isn't set, then the next promise is completed with the
successfull result. If a success callback is provided, it's
called. If the call succeeds, and the result is a promise,
them the result is called with the next promise's set_result
and set_exception methods, chaining the result and next
promise. If the result isn't a promise, then the next promise
is completed with it by calling set_result. If the success
callback fails, then it's exception is passed to
next.set_exception.
If the promise completes with an error and the error callback
isn't set, then the exception is passed to the next promises
set_exception. If an error handler is provided, it's called
and if it doesn't error, then the original exception is passed
to the next promise's set_exception. If there error handler
errors, then that exception is passed to the next promise's
set_exception.
"""
self.next = self.__class__() self.next = self.__class__()
self.success_callback = success_callback self.success_callback = success_callback
self.error_callback = error_callback self.error_callback = error_callback
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment