Commit 3f31236b authored by Jim Fulton's avatar Jim Fulton

Checkpointing....

Many tests passing. Quite a few still failing.
parent bdbc36dd
...@@ -118,7 +118,7 @@ setup(name="ZEO", ...@@ -118,7 +118,7 @@ setup(name="ZEO",
install_requires = [ install_requires = [
'ZODB >= 4.2.0b1', 'ZODB >= 4.2.0b1',
'six', 'six',
'transaction', 'transaction >= 1.6.0',
'persistent >= 4.1.0', 'persistent >= 4.1.0',
'zc.lockfile', 'zc.lockfile',
'ZConfig', 'ZConfig',
......
This diff is collapsed.
...@@ -26,3 +26,7 @@ class ClientDisconnected(ClientStorageError): ...@@ -26,3 +26,7 @@ class ClientDisconnected(ClientStorageError):
class AuthError(StorageError): class AuthError(StorageError):
"""The client provided invalid authentication credentials.""" """The client provided invalid authentication credentials."""
class ProtocolError(ClientStorageError):
"""A client contacted a server with an incomparible protocol
"""
This diff is collapsed.
...@@ -1574,10 +1574,6 @@ class ZEOStorage308Adapter: ...@@ -1574,10 +1574,6 @@ class ZEOStorage308Adapter:
abortVersion = commitVersion abortVersion = commitVersion
def zeoLoad(self, oid): # Z200
p, s = self.storage.loadEx(oid)
return p, s, '', None, None
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.storage, name) return getattr(self.storage, name)
......
...@@ -21,44 +21,24 @@ is used to store the data until a commit or abort. ...@@ -21,44 +21,24 @@ is used to store the data until a commit or abort.
# A faster implementation might store trans data in memory until it # A faster implementation might store trans data in memory until it
# reaches a certain size. # reaches a certain size.
from threading import Lock
import os import os
import tempfile import tempfile
import ZODB.blob import ZODB.blob
from ZODB.ConflictResolution import ResolvedSerial
from ZEO._compat import Pickler, Unpickler from ZEO._compat import Pickler, Unpickler
class TransactionBuffer: class TransactionBuffer:
# Valid call sequences:
#
# ((store | invalidate)* begin_iterate next* clear)* close
#
# get_size can be called any time
# The TransactionBuffer is used by client storage to hold update # The TransactionBuffer is used by client storage to hold update
# data until the tpc_finish(). It is normally used by a single # data until the tpc_finish(). It is only used by a single
# thread, because only one thread can be in the two-phase commit # thread, because only one thread can be in the two-phase commit
# at one time. # at one time.
# It is possible, however, for one thread to close the storage def __init__(self, connection_generation):
# while another thread is in the two-phase commit. We must use self.connection_generation = connection_generation
# a lock to guard against this race, because unpredictable things
# can happen in Python if one thread closes a file that another
# thread is reading. In a debug build, an assert() can fail.
# Caution: If an operation is performed on a closed TransactionBuffer,
# it has no effect and does not raise an exception. The only time
# this should occur is when a ClientStorage is closed in one
# thread while another thread is in its tpc_finish(). It's not
# clear what should happen in this case. If the tpc_finish()
# completes without error, the Connection using it could have
# inconsistent data. This should have minimal effect, though,
# because the Connection is connected to a closed storage.
def __init__(self):
self.file = tempfile.TemporaryFile(suffix=".tbuf") self.file = tempfile.TemporaryFile(suffix=".tbuf")
self.lock = Lock()
self.closed = 0
self.count = 0 self.count = 0
self.size = 0 self.size = 0
self.blobs = [] self.blobs = []
...@@ -66,89 +46,45 @@ class TransactionBuffer: ...@@ -66,89 +46,45 @@ class TransactionBuffer:
# stored are builtin types -- strings or None. # stored are builtin types -- strings or None.
self.pickler = Pickler(self.file, 1) self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1 self.pickler.fast = 1
self.serials = {} # processed { oid -> serial }
self.exception = None
def close(self): def close(self):
self.clear() self.file.close()
self.lock.acquire()
try:
self.closed = 1
try:
self.file.close()
except OSError:
pass
finally:
self.lock.release()
def store(self, oid, data): def store(self, oid, data):
"""Store oid, version, data for later retrieval""" """Store oid, version, data for later retrieval"""
self.lock.acquire() self.pickler.dump((oid, data))
try: self.count += 1
if self.closed: # Estimate per-record cache size
return self.size = self.size + (data and len(data) or 0) + 31
self.pickler.dump((oid, data))
self.count += 1 def serial(self, oid, serial):
# Estimate per-record cache size if isinstance(serial, Exception):
self.size = self.size + (data and len(data) or 0) + 31 self.exception = serial
finally: else:
self.lock.release() self.serials[oid] = serial
def storeBlob(self, oid, blobfilename): def storeBlob(self, oid, blobfilename):
self.blobs.append((oid, blobfilename)) self.blobs.append((oid, blobfilename))
def invalidate(self, oid):
self.lock.acquire()
try:
if self.closed:
return
self.pickler.dump((oid, None))
self.count += 1
finally:
self.lock.release()
def clear(self):
"""Mark the buffer as empty"""
self.lock.acquire()
try:
if self.closed:
return
self.file.seek(0)
self.count = 0
self.size = 0
while self.blobs:
oid, blobfilename = self.blobs.pop()
if os.path.exists(blobfilename):
ZODB.blob.remove_committed(blobfilename)
finally:
self.lock.release()
def __iter__(self):
self.lock.acquire()
try:
if self.closed:
return
self.file.flush()
self.file.seek(0)
return TBIterator(self.file, self.count)
finally:
self.lock.release()
class TBIterator(object):
def __init__(self, f, count):
self.file = f
self.count = count
self.unpickler = Unpickler(f)
def __iter__(self): def __iter__(self):
return self self.file.seek(0)
unpickler = Unpickler(self.file)
def __next__(self): serials = self.serials
"""Return next tuple of data or None if EOF"""
if self.count == 0: # Gaaaa, this is awkward. There can be entries in serials that
self.file.seek(0) # aren't in the buffer, because undo. Entries can be repeated
self.size = 0 # in the buffer, because ZODB. (Maybe this is a bug now, but
raise StopIteration # it may be a feature later.
oid_ver_data = self.unpickler.load()
self.count -= 1 seen = set()
return oid_ver_data for i in range(self.count):
next = __next__ oid, data = unpickler.load()
seen.add(oid)
yield oid, data, serials[oid] == ResolvedSerial
# We may have leftover serials because undo
for oid, serial in serials.items():
if oid not in seen:
yield oid, None, serial == ResolvedSerial
...@@ -6,6 +6,7 @@ import concurrent.futures ...@@ -6,6 +6,7 @@ import concurrent.futures
import logging import logging
import random import random
import threading import threading
import traceback
import ZEO.Exceptions import ZEO.Exceptions
import ZODB.POSException import ZODB.POSException
...@@ -73,7 +74,8 @@ class Protocol(asyncio.Protocol): ...@@ -73,7 +74,8 @@ class Protocol(asyncio.Protocol):
if self.transport is not None: if self.transport is not None:
self.transport.close() self.transport.close()
for future in self.futures.values(): for future in self.futures.values():
future.set_exception(Closed()) future.set_exception(
ZEO.Exceptions.ClientDisconnected("Closed"))
self.futures.clear() self.futures.clear()
def protocol_factory(self): def protocol_factory(self):
...@@ -156,8 +158,7 @@ class Protocol(asyncio.Protocol): ...@@ -156,8 +158,7 @@ class Protocol(asyncio.Protocol):
return self.transport.get_extra_info('peername') return self.transport.get_extra_info('peername')
def connection_lost(self, exc): def connection_lost(self, exc):
if exc is None: if self.closed:
# we were closed
for f in self.futures.values(): for f in self.futures.values():
f.cancel() f.cancel()
else: else:
...@@ -320,7 +321,8 @@ class Client: ...@@ -320,7 +321,8 @@ class Client:
# connect. # connect.
protocol = None protocol = None
ready = False ready = None # Tri-value: None=Never connected, True=connected,
# False=Disconnected
def __init__(self, loop, def __init__(self, loop,
addrs, client, cache, storage_key, read_only, connect_poll, addrs, client, cache, storage_key, read_only, connect_poll,
...@@ -350,7 +352,9 @@ class Client: ...@@ -350,7 +352,9 @@ class Client:
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
self.protocol.close() self.ready = False
if self.protocol is not None:
self.protocol.close()
self.cache.close() self.cache.close()
self._clear_protocols() self._clear_protocols()
...@@ -364,7 +368,8 @@ class Client: ...@@ -364,7 +368,8 @@ class Client:
if protocol is None or protocol is self.protocol: if protocol is None or protocol is self.protocol:
if protocol is self.protocol and protocol is not None: if protocol is self.protocol and protocol is not None:
self.client.notify_disconnected() self.client.notify_disconnected()
self.ready = False if self.ready:
self.ready = False
self.connected = concurrent.futures.Future() self.connected = concurrent.futures.Future()
self.protocol = None self.protocol = None
self._clear_protocols() self._clear_protocols()
...@@ -468,8 +473,8 @@ class Client: ...@@ -468,8 +473,8 @@ class Client:
@self.protocol.promise('get_info') @self.protocol.promise('get_info')
def got_info(info): def got_info(info):
self.connected.set_result(None)
self.client.notify_connected(self, info) self.client.notify_connected(self, info)
self.connected.set_result(None)
@got_info.catch @got_info.catch
def failed_info(exc): def failed_info(exc):
...@@ -497,16 +502,21 @@ class Client: ...@@ -497,16 +502,21 @@ class Client:
def _when_ready(self, func, result_future, *args): def _when_ready(self, func, result_future, *args):
@self.connected.add_done_callback if self.ready is None:
def done(future): # We started without waiting for a connection. (prob tests :( )
e = future.exception() result_future.set_exception(
if e is not None: ZEO.Exceptions.ClientDisconnected("never connected"))
result_future.set_exception(e) else:
else: @self.connected.add_done_callback
if self.ready: def done(future):
func(result_future, *args) e = future.exception()
if e is not None:
future.set_exception(e)
else: else:
self._when_ready(func, result_future, *args) if self.ready:
func(result_future, *args)
else:
self._when_ready(func, result_future, *args)
def call_threadsafe(self, future, method, args): def call_threadsafe(self, future, method, args):
if self.ready: if self.ready:
...@@ -541,7 +551,7 @@ class Client: ...@@ -541,7 +551,7 @@ class Client:
future.set_result(data) future.set_result(data)
if data: if data:
data, start, end = data data, start, end = data
self.cache.store(oid, start, end, data) self.cache.store(oid, start, end, data)
load_before.catch(future.set_exception) load_before.catch(future.set_exception)
else: else:
...@@ -558,7 +568,7 @@ class Client: ...@@ -558,7 +568,7 @@ class Client:
cache.store(oid, tid, None, data) cache.store(oid, tid, None, data)
cache.setLastTid(tid) cache.setLastTid(tid)
f(tid) f(tid)
future.set_result(None) future.set_result(tid)
committed.catch(future.set_exception) committed.catch(future.set_exception)
else: else:
...@@ -579,6 +589,17 @@ class Client: ...@@ -579,6 +589,17 @@ class Client:
def protocol_version(self): def protocol_version(self):
return self.protocol.protocol_version return self.protocol.protocol_version
def is_read_only(self):
try:
protocol = self.protocol
except AttributeError:
return self.read_only
else:
if protocol is None:
return self.read_only
else:
return protocol.read_only
class ClientRunner: class ClientRunner:
def set_options(self, addrs, wrapper, cache, storage_key, read_only, def set_options(self, addrs, wrapper, cache, storage_key, read_only,
...@@ -591,6 +612,8 @@ class ClientRunner: ...@@ -591,6 +612,8 @@ class ClientRunner:
def setup_delegation(self, loop): def setup_delegation(self, loop):
self.loop = loop self.loop = loop
self.client = Client(loop, *self.__args) self.client = Client(loop, *self.__args)
self.call_threadsafe = self.client.call_threadsafe
self.call_async_threadsafe = self.client.call_async_threadsafe
from concurrent.futures import Future from concurrent.futures import Future
call_soon_threadsafe = loop.call_soon_threadsafe call_soon_threadsafe = loop.call_soon_threadsafe
...@@ -614,10 +637,17 @@ class ClientRunner: ...@@ -614,10 +637,17 @@ class ClientRunner:
return future.result(self.timeout if timeout is False else timeout) return future.result(self.timeout if timeout is False else timeout)
def call(self, method, *args, timeout=None): def call(self, method, *args, timeout=None):
return self.__call(self.client.call_threadsafe, method, args) return self.__call(self.call_threadsafe, method, args)
def call_future(self, method, *args):
# for tests
result = concurrent.futures.Future()
self.loop.call_soon_threadsafe(
self.call_threadsafe, result, method, args)
return result
def async(self, method, *args): def async(self, method, *args):
return self.__call(self.client.call_async_threadsafe, method, args) return self.__call(self.call_async_threadsafe, method, args)
def async_iter(self, it): def async_iter(self, it):
return self.__call(self.client.call_async_iter_threadsafe, it) return self.__call(self.client.call_async_iter_threadsafe, it)
...@@ -648,6 +678,12 @@ class ClientRunner: ...@@ -648,6 +678,12 @@ class ClientRunner:
def close(self): def close(self):
self.__call(self.client.close_threadsafe) self.__call(self.client.close_threadsafe)
# Short circuit from now on. We're closed.
def call_closed(*a, **k):
raise ZEO.Exceptions.ClientDisconnected('closed')
self.__call = call_closed
def new_addr(self, addrs): def new_addr(self, addrs):
# This usually doesn't have an immediate effect, since the # This usually doesn't have an immediate effect, since the
# addrs aren't used until the client disconnects.xs # addrs aren't used until the client disconnects.xs
...@@ -663,21 +699,45 @@ class ClientThread(ClientRunner): ...@@ -663,21 +699,45 @@ class ClientThread(ClientRunner):
def __init__(self, addrs, client, cache, def __init__(self, addrs, client, cache,
storage_key='1', read_only=False, timeout=30, storage_key='1', read_only=False, timeout=30,
disconnect_poll=1): disconnect_poll=1, wait=True):
self.set_options(addrs, client, cache, storage_key, read_only, self.set_options(addrs, client, cache, storage_key, read_only,
timeout, disconnect_poll) timeout, disconnect_poll)
threading.Thread( self.thread = threading.Thread(
target=self.run, target=self.run,
name='zeo_client_'+storage_key, name='zeo_client_'+storage_key,
daemon=True, daemon=True,
).start() )
self.connected.result(timeout) self.started = threading.Event()
self.thread.start()
self.started.wait()
if wait:
self.connected.result(timeout)
exception = None
def run(self): def run(self):
loop = asyncio.new_event_loop() try:
asyncio.set_event_loop(loop) loop = asyncio.new_event_loop()
self.setup_delegation(loop) asyncio.set_event_loop(loop)
loop.run_forever() self.setup_delegation(loop)
self.started.set()
loop.run_forever()
except Exception as exc:
logger.exception("Client thread")
self.exception = exc
raise
else:
loop.close()
logger.debug('Stopping client thread')
closed = False
def close(self):
if not self.closed:
self.closed = True
super().close()
self.loop.call_soon_threadsafe(self.loop.stop)
self.thread.join(9)
if self.exception:
raise self.exception
class Promise: class Promise:
"""Lightweight future with a partial promise API. """Lightweight future with a partial promise API.
......
...@@ -102,3 +102,12 @@ class Transport: ...@@ -102,3 +102,12 @@ class Transport:
def get_extra_info(self, name): def get_extra_info(self, name):
return self.extra[name] return self.extra[name]
class AsyncRPC:
"""Adapt an asyncio API to an RPC to help hysterical tests
"""
def __init__(self, api):
self.api = api
def __getattr__(self, name):
return lambda *a, **kw: self.api.call(name, *a, **kw)
...@@ -60,17 +60,11 @@ class WorkerThread(TestThread): ...@@ -60,17 +60,11 @@ class WorkerThread(TestThread):
# coordinate the action of multiple threads that all call # coordinate the action of multiple threads that all call
# vote(). This method sends the vote call, then sets the # vote(). This method sends the vote call, then sets the
# event saying vote was called, then waits for the vote # event saying vote was called, then waits for the vote
# response. It digs deep into the implementation of the client. # response.
# This method is a replacement for: future = self.storage._server.call_future('vote', id(self.trans))
# self.ready.set()
# self.storage.tpc_vote(self.trans)
rpc = self.storage._server.rpc
msgid = rpc._deferred_call('vote', id(self.trans))
self.ready.set() self.ready.set()
rpc._deferred_wait(msgid) future.result(9)
self.storage._check_serials()
class CommitLockTests: class CommitLockTests:
......
...@@ -19,7 +19,6 @@ import asyncore ...@@ -19,7 +19,6 @@ import asyncore
import threading import threading
import logging import logging
import ZEO.ServerStub
from ZEO.ClientStorage import ClientStorage from ZEO.ClientStorage import ClientStorage
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZEO.zrpc.marshal import encode from ZEO.zrpc.marshal import encode
...@@ -40,20 +39,10 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests') ...@@ -40,20 +39,10 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests')
ZERO = '\0'*8 ZERO = '\0'*8
class TestServerStub(ZEO.ServerStub.StorageServer):
__super_getInvalidations = ZEO.ServerStub.StorageServer.getInvalidations
def getInvalidations(self, tid):
# squirrel the results away for inspection by test case
self._last_invals = self.__super_getInvalidations(tid)
return self._last_invals
class TestClientStorage(ClientStorage): class TestClientStorage(ClientStorage):
test_connection = False test_connection = False
StorageServerStubClass = TestServerStub
connection_count_for_tests = 0 connection_count_for_tests = 0
def notifyConnected(self, conn): def notifyConnected(self, conn):
...@@ -592,7 +581,6 @@ class InvqTests(CommonSetupTearDown): ...@@ -592,7 +581,6 @@ class InvqTests(CommonSetupTearDown):
def checkQuickVerificationWith2Clients(self): def checkQuickVerificationWith2Clients(self):
perstorage = self.openClientStorage(cache="test", cache_size=4000) perstorage = self.openClientStorage(cache="test", cache_size=4000)
self.assertEqual(perstorage.verify_result, "empty cache")
self._storage = self.openClientStorage() self._storage = self.openClientStorage()
oid = self._storage.new_oid() oid = self._storage.new_oid()
...@@ -624,8 +612,6 @@ class InvqTests(CommonSetupTearDown): ...@@ -624,8 +612,6 @@ class InvqTests(CommonSetupTearDown):
label="perstorage.verify_result to be quick verification") label="perstorage.verify_result to be quick verification")
self.assertEqual(perstorage.verify_result, "quick verification") self.assertEqual(perstorage.verify_result, "quick verification")
self.assertEqual(perstorage._server._last_invals,
(revid, [oid]))
self.assertEqual(perstorage.load(oid, ''), self.assertEqual(perstorage.load(oid, ''),
self._storage.load(oid, '')) self._storage.load(oid, ''))
......
...@@ -17,6 +17,8 @@ import transaction ...@@ -17,6 +17,8 @@ import transaction
import six import six
import gc import gc
from ..asyncio.testing import AsyncRPC
class IterationTests: class IterationTests:
def _assertIteratorIdsEmpty(self): def _assertIteratorIdsEmpty(self):
...@@ -52,7 +54,7 @@ class IterationTests: ...@@ -52,7 +54,7 @@ class IterationTests:
def checkIteratorGCProtocol(self): def checkIteratorGCProtocol(self):
# Test garbage collection on protocol level. # Test garbage collection on protocol level.
server = self._storage._server server = AsyncRPC(self._storage._server)
iid = server.iterator_start(None, None) iid = server.iterator_start(None, None)
# None signals the end of iteration. # None signals the end of iteration.
...@@ -79,7 +81,7 @@ class IterationTests: ...@@ -79,7 +81,7 @@ class IterationTests:
self.assertEquals(0, len(self._storage._iterator_ids)) self.assertEquals(0, len(self._storage._iterator_ids))
# The iterator has run through, so the server has already disposed it. # The iterator has run through, so the server has already disposed it.
self.assertRaises(KeyError, self._storage._server.iterator_next, iid) self.assertRaises(KeyError, self._storage._call, 'iterator_next', iid)
def checkIteratorGCSpanTransactions(self): def checkIteratorGCSpanTransactions(self):
# Keep a hard reference to the iterator so it won't be automatically # Keep a hard reference to the iterator so it won't be automatically
...@@ -112,7 +114,7 @@ class IterationTests: ...@@ -112,7 +114,7 @@ class IterationTests:
self._storage._iterators._last_gc = -1 self._storage._iterators._last_gc = -1
self._dostore() self._dostore()
self._assertIteratorIdsEmpty() self._assertIteratorIdsEmpty()
self.assertRaises(KeyError, self._storage._server.iterator_next, iid) self.assertRaises(KeyError, self._storage._call, 'iterator_next', iid)
def checkIteratorGCStorageTPCAborting(self): def checkIteratorGCStorageTPCAborting(self):
# The odd little jig we do below arises from the fact that the # The odd little jig we do below arises from the fact that the
...@@ -129,7 +131,7 @@ class IterationTests: ...@@ -129,7 +131,7 @@ class IterationTests:
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
self._assertIteratorIdsEmpty() self._assertIteratorIdsEmpty()
self.assertRaises(KeyError, self._storage._server.iterator_next, iid) self.assertRaises(KeyError, self._storage._call, 'iterator_next', iid)
def checkIteratorGCStorageDisconnect(self): def checkIteratorGCStorageDisconnect(self):
...@@ -146,7 +148,7 @@ class IterationTests: ...@@ -146,7 +148,7 @@ class IterationTests:
# Show that after disconnecting, the client side GCs the iterators # Show that after disconnecting, the client side GCs the iterators
# as well. I'm calling this directly to avoid accidentally # as well. I'm calling this directly to avoid accidentally
# calling tpc_abort implicitly. # calling tpc_abort implicitly.
self._storage.notifyDisconnected() self._storage.notify_disconnected()
self.assertEquals(0, len(self._storage._iterator_ids)) self.assertEquals(0, len(self._storage._iterator_ids))
def checkIteratorParallel(self): def checkIteratorParallel(self):
......
...@@ -17,7 +17,7 @@ import threading ...@@ -17,7 +17,7 @@ import threading
import transaction import transaction
from ZODB.tests.StorageTestBase import zodb_pickle, MinPO from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
import ZEO.ClientStorage import ZEO.Exceptions
ZERO = '\0'*8 ZERO = '\0'*8
...@@ -54,7 +54,7 @@ class GetsThroughVoteThread(BasicThread): ...@@ -54,7 +54,7 @@ class GetsThroughVoteThread(BasicThread):
self.doNextEvent.wait(10) self.doNextEvent.wait(10)
try: try:
self.storage.tpc_finish(self.trans) self.storage.tpc_finish(self.trans)
except ZEO.ClientStorage.ClientStorageError: except ZEO.Exceptions.ClientStorageError:
self.gotValueError = 1 self.gotValueError = 1
self.storage.tpc_abort(self.trans) self.storage.tpc_abort(self.trans)
...@@ -67,7 +67,7 @@ class GetsThroughBeginThread(BasicThread): ...@@ -67,7 +67,7 @@ class GetsThroughBeginThread(BasicThread):
def run(self): def run(self):
try: try:
self.storage.tpc_begin(self.trans) self.storage.tpc_begin(self.trans)
except ZEO.ClientStorage.ClientStorageError: except ZEO.Exceptions.ClientStorageError:
self.gotValueError = 1 self.gotValueError = 1
......
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Implements plaintext password authentication. The password is stored in
an SHA hash in the Database. The client sends over the plaintext
password, and the SHA hashing is done on the server side.
This mechanism offers *no network security at all*; the only security
is provided by not storing plaintext passwords on disk.
"""
from ZEO.hash import sha1
from ZEO.StorageServer import ZEOStorage
from ZEO.auth import register_module
from ZEO.auth.base import Client, Database
def session_key(username, realm, password):
key = "%s:%s:%s" % (username, realm, password)
return sha1(key.encode('utf-8')).hexdigest().encode('ascii')
class StorageClass(ZEOStorage):
def auth(self, username, password):
try:
dbpw = self.database.get_password(username)
except LookupError:
return 0
password_dig = sha1(password.encode('utf-8')).hexdigest()
if dbpw == password_dig:
self.connection.setSessionKey(session_key(username,
self.database.realm,
password))
return self._finish_auth(dbpw == password_dig)
class PlaintextClient(Client):
extensions = ["auth"]
def start(self, username, realm, password):
if self.stub.auth(username, password):
return session_key(username, realm, password)
else:
return None
register_module("plaintext", StorageClass, PlaintextClient, Database)
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Test suite for AuthZEO."""
import os
import tempfile
import time
import unittest
from ZEO import zeopasswd
from ZEO.Exceptions import ClientDisconnected
from ZEO.tests.ConnectionTests import CommonSetupTearDown
class _AuthTest(CommonSetupTearDown):
__super_getServerConfig = CommonSetupTearDown.getServerConfig
__super_setUp = CommonSetupTearDown.setUp
__super_tearDown = CommonSetupTearDown.tearDown
realm = None
def setUp(self):
fd, self.pwfile = tempfile.mkstemp('pwfile')
os.close(fd)
if self.realm:
self.pwdb = self.dbclass(self.pwfile, self.realm)
else:
self.pwdb = self.dbclass(self.pwfile)
self.pwdb.add_user("foo", "bar")
self.pwdb.save()
self._checkZEOpasswd()
self.__super_setUp()
def _checkZEOpasswd(self):
args = ["-f", self.pwfile, "-p", self.protocol]
if self.protocol == "plaintext":
from ZEO.auth.base import Database
zeopasswd.main(args + ["-d", "foo"], Database)
zeopasswd.main(args + ["foo", "bar"], Database)
else:
zeopasswd.main(args + ["-d", "foo"])
zeopasswd.main(args + ["foo", "bar"])
def tearDown(self):
os.remove(self.pwfile)
self.__super_tearDown()
def getConfig(self, path, create, read_only):
return "<mappingstorage 1/>"
def getServerConfig(self, addr, ro_svr):
zconf = self.__super_getServerConfig(addr, ro_svr)
zconf.authentication_protocol = self.protocol
zconf.authentication_database = self.pwfile
zconf.authentication_realm = self.realm
return zconf
def wait(self):
for i in range(25):
time.sleep(0.1)
if self._storage.test_connection:
return
self.fail("Timed out waiting for client to authenticate")
def testOK(self):
# Sleep for 0.2 seconds to give the server some time to start up
# seems to be needed before and after creating the storage
self._storage = self.openClientStorage(wait=0, username="foo",
password="bar", realm=self.realm)
self.wait()
self.assert_(self._storage._connection)
self._storage._connection.poll()
self.assert_(self._storage.is_connected())
# Make a call to make sure the mechanism is working
self._storage.undoInfo()
def testNOK(self):
self._storage = self.openClientStorage(wait=0, username="foo",
password="noogie",
realm=self.realm)
self.wait()
# If the test established a connection, then it failed.
self.failIf(self._storage._connection)
def testUnauthenticatedMessage(self):
# Test that an unauthenticated message is rejected by the server
# if it was sent after the connection was authenticated.
self._storage = self.openClientStorage(wait=0, username="foo",
password="bar", realm=self.realm)
# Sleep for 0.2 seconds to give the server some time to start up
# seems to be needed before and after creating the storage
self.wait()
self._storage.undoInfo()
# Manually clear the state of the hmac connection
self._storage._connection._SizedMessageAsyncConnection__hmac_send = None
# Once the client stops using the hmac, it should be disconnected.
self.assertRaises(ClientDisconnected, self._storage.undoInfo)
class PlainTextAuth(_AuthTest):
import ZEO.tests.auth_plaintext
protocol = "plaintext"
database = "authdb.sha"
dbclass = ZEO.tests.auth_plaintext.Database
realm = "Plaintext Realm"
class DigestAuth(_AuthTest):
import ZEO.auth.auth_digest
protocol = "digest"
database = "authdb.digest"
dbclass = ZEO.auth.auth_digest.DigestDatabase
realm = "Digest Realm"
test_classes = [PlainTextAuth, DigestAuth]
def test_suite():
suite = unittest.TestSuite()
for klass in test_classes:
sub = unittest.makeSuite(klass)
suite.addTest(sub)
return suite
if __name__ == "__main__":
unittest.main(defaultTest='test_suite')
...@@ -46,7 +46,6 @@ import threading ...@@ -46,7 +46,6 @@ import threading
import time import time
import transaction import transaction
import unittest import unittest
import ZEO.ServerStub
import ZEO.StorageServer import ZEO.StorageServer
import ZEO.tests.ConnectionTests import ZEO.tests.ConnectionTests
import ZEO.zrpc.connection import ZEO.zrpc.connection
...@@ -1721,7 +1720,7 @@ def can_use_empty_string_for_local_host_on_client(): ...@@ -1721,7 +1720,7 @@ def can_use_empty_string_for_local_host_on_client():
""" """
slow_test_classes = [ slow_test_classes = [
BlobAdaptedFileStorageTests, BlobWritableCacheTests, #BlobAdaptedFileStorageTests, BlobWritableCacheTests,
MappingStorageTests, DemoStorageTests, MappingStorageTests, DemoStorageTests,
FileStorageTests, FileStorageHexTests, FileStorageClientHexTests, FileStorageTests, FileStorageHexTests, FileStorageClientHexTests,
] ]
...@@ -1733,12 +1732,6 @@ quick_test_classes = [ ...@@ -1733,12 +1732,6 @@ quick_test_classes = [
class ServerManagingClientStorage(ClientStorage): class ServerManagingClientStorage(ClientStorage):
class StorageServerStubClass(ZEO.ServerStub.StorageServer):
# Wait for abort for the benefit of blob_transaction.txt
def tpc_abort(self, id):
self.rpc.call('tpc_abort', id)
def __init__(self, name, blob_dir, shared=False, extrafsoptions=''): def __init__(self, name, blob_dir, shared=False, extrafsoptions=''):
if shared: if shared:
server_blob_dir = blob_dir server_blob_dir = blob_dir
......
...@@ -174,9 +174,6 @@ def main(): ...@@ -174,9 +174,6 @@ def main():
zo.realize(["-C", configfile]) zo.realize(["-C", configfile])
addr = zo.address addr = zo.address
if zo.auth_protocol == "plaintext":
__import__('ZEO.tests.auth_plaintext')
if isinstance(addr, tuple): if isinstance(addr, tuple):
test_addr = addr[0], addr[1]+1 test_addr = addr[0], addr[1]+1
else: else:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment