Commit 71a0de50 authored by Vincent Pelletier's avatar Vincent Pelletier

Implement revision-aware caching.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2531 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 6b799777
...@@ -43,7 +43,7 @@ from neo.client.handlers import storage, master ...@@ -43,7 +43,7 @@ from neo.client.handlers import storage, master
from neo.dispatcher import Dispatcher, ForgottenPacket from neo.dispatcher import Dispatcher, ForgottenPacket
from neo.client.poll import ThreadedPoll, psThreadedPoll from neo.client.poll import ThreadedPoll, psThreadedPoll
from neo.client.iterator import Iterator from neo.client.iterator import Iterator
from neo.client.mq import MQ from neo.client.mq import MQ, MQIndex
from neo.client.pool import ConnectionPool from neo.client.pool import ConnectionPool
from neo.util import u64, parseMasterList from neo.util import u64, parseMasterList
from neo.profiling import profiler_decorator, PROFILING_ENABLED from neo.profiling import profiler_decorator, PROFILING_ENABLED
...@@ -119,6 +119,139 @@ class ThreadContext(object): ...@@ -119,6 +119,139 @@ class ThreadContext(object):
'last_transaction': None, 'last_transaction': None,
} }
class RevisionIndex(MQIndex):
"""
This cache index allows accessing a specifig revision of a cached object.
It requires cache key to be a 2-tuple, composed of oid and revision.
Note: it is expected that rather few revisions are held in cache, with few
lookups for old revisions, so they are held in a simple sorted list
Note2: all methods here must be called with cache lock acquired.
"""
def __init__(self):
# key: oid
# value: tid list, from highest to lowest
self._oid_dict = {}
# key: oid
# value: tid list, from lowest to highest
self._invalidated = {}
def clear(self):
self._oid_dict.clear()
self._invalidated.clear()
def remove(self, key):
oid_dict = self._oid_dict
oid, tid = key
tid_list = oid_dict[oid]
tid_list.remove(tid)
if not tid_list:
# No more serial known for this object, drop entirely
del oid_dict[oid]
self._invalidated.pop(oid, None)
def add(self, key):
oid_dict = self._oid_dict
oid, tid = key
try:
serial_list = oid_dict[oid]
except KeyError:
serial_list = oid_dict[oid] = []
else:
assert tid not in serial_list
if not(serial_list) or tid > serial_list[0]:
serial_list.insert(0, tid)
else:
serial_list.insert(0, tid)
serial_list.sort(reverse=True)
invalidated = self._invalidated
try:
tid_list = invalidated[oid]
except KeyError:
pass
else:
try:
tid_list.remove(tid)
except ValueError:
pass
else:
if not tid_list:
del invalidated[oid]
def invalidate(self, oid_list, tid):
"""
Mark object invalidated by given transaction.
Must be called with increasing TID values (which is standard for
ZODB).
"""
invalidated = self._invalidated
oid_dict = self._oid_dict
for oid in (x for x in oid_list if x in oid_dict):
try:
tid_list = invalidated[oid]
except KeyError:
tid_list = invalidated[oid] = []
assert not tid_list or tid > tid_list[-1], (dump(oid), dump(tid),
dump(tid_list[-1]))
tid_list.append(tid)
def getSerialBefore(self, oid, tid):
"""
Get the first tid in cache which value is lower that given tid.
"""
# WARNING: return-intensive to save on indentation
oid_list = self._oid_dict.get(oid)
if oid_list is None:
# Unknown oid
return None
for result in oid_list:
if result < tid:
# Candidate found
break
else:
# No candidate in cache.
return None
# Check if there is a chance that an intermediate revision would
# exist, while missing from cache.
try:
inv_tid_list = self._invalidated[oid]
except KeyError:
return result
# Remember: inv_tid_list is sorted in ascending order.
for inv_tid in inv_tid_list:
if tid < inv_tid:
# We don't care about invalidations past requested TID.
break
elif result < inv_tid < tid:
# An invalidation was received between candidate revision,
# and before requested TID: there is a matching revision we
# don't know of, so we cannot answer.
return None
return result
def getLatestSerial(self, oid):
"""
Get the latest tid for given object.
"""
result = self._oid_dict.get(oid)
if result is not None:
result = result[0]
try:
tid_list = self._invalidated[oid]
except KeyError:
pass
else:
if result < tid_list[-1]:
# An invalidation happened from a transaction later than our
# most recent view of this object, so we cannot answer.
result = None
return result
def getSerialList(self, oid):
"""
Get the list of all serials cache knows about for given object.
"""
return self._oid_dict.get(oid, [])[:]
class Application(object): class Application(object):
"""The client node application.""" """The client node application."""
...@@ -147,6 +280,8 @@ class Application(object): ...@@ -147,6 +280,8 @@ class Application(object):
# no self-assigned UUID, primary master will supply us one # no self-assigned UUID, primary master will supply us one
self.uuid = None self.uuid = None
self.mq_cache = MQ() self.mq_cache = MQ()
self.cache_revision_index = RevisionIndex()
self.mq_cache.addIndex(self.cache_revision_index)
self.new_oid_list = [] self.new_oid_list = []
self.last_oid = '\0' * 8 self.last_oid = '\0' * 8
self.storage_event_handler = storage.StorageEventHandler(self) self.storage_event_handler = storage.StorageEventHandler(self)
...@@ -429,7 +564,7 @@ class Application(object): ...@@ -429,7 +564,7 @@ class Application(object):
return int(u64(self.last_oid)) return int(u64(self.last_oid))
@profiler_decorator @profiler_decorator
def _load(self, oid, serial=None, tid=None, cache=0): def _load(self, oid, serial=None, tid=None):
""" """
Internal method which manage load, loadSerial and loadBefore. Internal method which manage load, loadSerial and loadBefore.
OID and TID (serial) parameters are expected packed. OID and TID (serial) parameters are expected packed.
...@@ -441,8 +576,6 @@ class Application(object): ...@@ -441,8 +576,6 @@ class Application(object):
tid tid
If given, the excluded upper bound serial at which OID is desired. If given, the excluded upper bound serial at which OID is desired.
serial should be None. serial should be None.
cache
Store data in cache for future lookups.
Return value: (3-tuple) Return value: (3-tuple)
- Object data (None if object creation was undone). - Object data (None if object creation was undone).
...@@ -471,7 +604,6 @@ class Application(object): ...@@ -471,7 +604,6 @@ class Application(object):
if not self.local_var.barrier_done: if not self.local_var.barrier_done:
self.invalidationBarrier() self.invalidationBarrier()
self.local_var.barrier_done = True self.local_var.barrier_done = True
if cache:
try: try:
result = self._loadFromCache(oid, serial, tid) result = self._loadFromCache(oid, serial, tid)
except KeyError: except KeyError:
...@@ -480,10 +612,9 @@ class Application(object): ...@@ -480,10 +612,9 @@ class Application(object):
return result return result
data, start_serial, end_serial = self._loadFromStorage(oid, serial, data, start_serial, end_serial = self._loadFromStorage(oid, serial,
tid) tid)
if cache:
self._cache_lock_acquire() self._cache_lock_acquire()
try: try:
self.mq_cache[oid] = start_serial, data self.mq_cache[(oid, start_serial)] = data, end_serial
finally: finally:
self._cache_lock_release() self._cache_lock_release()
if data == '': if data == '':
...@@ -555,16 +686,25 @@ class Application(object): ...@@ -555,16 +686,25 @@ class Application(object):
""" """
self._cache_lock_acquire() self._cache_lock_acquire()
try: try:
tid, data = self.mq_cache[oid] if at_tid is not None:
neo.logging.debug('load oid %s is cached', dump(oid)) tid = at_tid
return (data, tid, None) elif before_tid is not None:
tid = self.cache_revision_index.getSerialBefore(oid,
before_tid)
else:
tid = self.cache_revision_index.getLatestSerial(oid)
if tid is None:
raise KeyError
# Raises KeyError on miss
data, next_tid = self.mq_cache[(oid, tid)]
return (data, tid, next_tid)
finally: finally:
self._cache_lock_release() self._cache_lock_release()
@profiler_decorator @profiler_decorator
def load(self, oid, version=None): def load(self, oid, version=None):
"""Load an object for a given oid.""" """Load an object for a given oid."""
result = self._load(oid, cache=1)[:2] result = self._load(oid)[:2]
# Start a network barrier, so we get all invalidations *after* we # Start a network barrier, so we get all invalidations *after* we
# received data. This ensures we get any invalidation message that # received data. This ensures we get any invalidation message that
# would have been about the version we loaded. # would have been about the version we loaded.
...@@ -578,7 +718,6 @@ class Application(object): ...@@ -578,7 +718,6 @@ class Application(object):
@profiler_decorator @profiler_decorator
def loadSerial(self, oid, serial): def loadSerial(self, oid, serial):
"""Load an object for a given oid and serial.""" """Load an object for a given oid and serial."""
# Do not try in cache as it manages only up-to-date object
neo.logging.debug('loading %s at %s', dump(oid), dump(serial)) neo.logging.debug('loading %s at %s', dump(oid), dump(serial))
return self._load(oid, serial=serial)[0] return self._load(oid, serial=serial)[0]
...@@ -586,7 +725,6 @@ class Application(object): ...@@ -586,7 +725,6 @@ class Application(object):
@profiler_decorator @profiler_decorator
def loadBefore(self, oid, tid): def loadBefore(self, oid, tid):
"""Load an object for a given oid before tid committed.""" """Load an object for a given oid before tid committed."""
# Do not try in cache as it manages only up-to-date object
neo.logging.debug('loading %s before %s', dump(oid), dump(tid)) neo.logging.debug('loading %s before %s', dump(oid), dump(tid))
return self._load(oid, tid=tid) return self._load(oid, tid=tid)
...@@ -878,12 +1016,30 @@ class Application(object): ...@@ -878,12 +1016,30 @@ class Application(object):
self._cache_lock_acquire() self._cache_lock_acquire()
try: try:
mq_cache = self.mq_cache mq_cache = self.mq_cache
update = mq_cache.update
def updateNextSerial(value):
data, next_tid = value
assert next_tid is None, (dump(oid), dump(base_tid),
dump(next_tid))
return (data, tid)
get_baseTID = local_var.object_serial_dict.get
for oid, data in local_var.data_dict.iteritems(): for oid, data in local_var.data_dict.iteritems():
if data is None:
# this is just a remain of
# checkCurrentSerialInTransaction call, ignore (no data
# was modified).
continue
# Update ex-latest value in cache
base_tid = get_baseTID(oid)
try:
update((oid, base_tid), updateNextSerial)
except KeyError:
pass
if data == '': if data == '':
if oid in mq_cache: self.cache_revision_index.invalidate([oid], tid)
del mq_cache[oid]
else: else:
mq_cache[oid] = tid, data # Store in cache with no next_tid
mq_cache[(oid, tid)] = (data, None)
finally: finally:
self._cache_lock_release() self._cache_lock_release()
local_var.clear() local_var.clear()
...@@ -1234,6 +1390,15 @@ class Application(object): ...@@ -1234,6 +1390,15 @@ class Application(object):
if tid == ZERO_TID: if tid == ZERO_TID:
raise NEOStorageError('Invalid pack time') raise NEOStorageError('Invalid pack time')
self._askPrimary(Packets.AskPack(tid)) self._askPrimary(Packets.AskPack(tid))
# XXX: this is only needed to make ZODB unit tests pass.
# It should not be otherwise required (clients should be free to load
# old data as long as it is available in cache, event if it was pruned
# by a pack), so don't bother invalidating on other clients.
self._cache_lock_acquire()
try:
self.mq_cache.clear()
finally:
self._cache_lock_release()
def getLastTID(self, oid): def getLastTID(self, oid):
return self._load(oid)[1] return self._load(oid)[1]
......
...@@ -123,10 +123,7 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -123,10 +123,7 @@ class PrimaryNotificationsHandler(BaseHandler):
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
# ZODB required a dict with oid as key, so create it # ZODB required a dict with oid as key, so create it
mq_cache = app.mq_cache app.cache_revision_index.invalidate(oid_list, tid)
for oid in oid_list:
if oid in mq_cache:
del mq_cache[oid]
db = app.getDB() db = app.getDB()
if db is not None: if db is not None:
db.invalidate(tid, dict.fromkeys(oid_list, tid)) db.invalidate(tid, dict.fromkeys(oid_list, tid))
......
...@@ -21,7 +21,7 @@ from cPickle import dumps ...@@ -21,7 +21,7 @@ from cPickle import dumps
from mock import Mock, ReturnValues from mock import Mock, ReturnValues
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
from neo.tests import NeoUnitTestBase from neo.tests import NeoUnitTestBase
from neo.client.app import Application from neo.client.app import Application, RevisionIndex
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError from neo.client.exception import NEOStorageDoesNotExistError
from neo.protocol import Packet, Packets, Errors, INVALID_TID, INVALID_SERIAL from neo.protocol import Packet, Packets, Errors, INVALID_TID, INVALID_SERIAL
...@@ -208,7 +208,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -208,7 +208,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None) an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None)
# connection to SN close # connection to SN close
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidNotFound('') packet = Errors.OidNotFound('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -224,7 +225,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -224,7 +225,8 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkAskObject(conn) self.checkAskObject(conn)
Application._waitMessage = _waitMessage Application._waitMessage = _waitMessage
# object not found in NEO -> NEOStorageNotFoundError # object not found in NEO -> NEOStorageNotFoundError
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidNotFound('') packet = Errors.OidNotFound('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -254,7 +256,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -254,7 +256,7 @@ class ClientApplicationTests(NeoUnitTestBase):
result = app.load(oid) result = app.load(oid)
self.assertEquals(result, ('OBJ', tid1)) self.assertEquals(result, ('OBJ', tid1))
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue((oid, tid1) in mq)
# object is now cached, try to reload it # object is now cached, try to reload it
conn = Mock({ conn = Mock({
'getAddress': ('127.0.0.1', 0), 'getAddress': ('127.0.0.1', 0),
...@@ -272,7 +274,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -272,7 +274,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
# object not found in NEO -> NEOStorageNotFoundError # object not found in NEO -> NEOStorageNotFoundError
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidNotFound('') packet = Errors.OidNotFound('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -285,10 +288,10 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -285,10 +288,10 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2) self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2)
self.checkAskObject(conn) self.checkAskObject(conn)
# object should not have been cached # object should not have been cached
self.assertFalse(oid in mq) self.assertFalse((oid, tid2) in mq)
# now a cached version ewxists but should not be hit # now a cached version ewxists but should not be hit
mq.store(oid, (tid2, 'WRONG')) mq.store((oid, tid2), ('WRONG', None))
self.assertTrue(oid in mq) self.assertTrue((oid, tid2) in mq)
another_object = (1, oid, tid2, INVALID_SERIAL, 0, another_object = (1, oid, tid2, INVALID_SERIAL, 0,
makeChecksum('RIGHT'), 'RIGHT', None) makeChecksum('RIGHT'), 'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
...@@ -302,7 +305,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -302,7 +305,7 @@ class ClientApplicationTests(NeoUnitTestBase):
result = app.loadSerial(oid, tid1) result = app.loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT') self.assertEquals(result, 'RIGHT')
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue((oid, tid2) in mq)
def test_loadBefore(self): def test_loadBefore(self):
app = self.getApp() app = self.getApp()
...@@ -313,7 +316,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -313,7 +316,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
tid3 = self.makeTID(3) tid3 = self.makeTID(3)
# object not found in NEO -> NEOStorageDoesNotExistError # object not found in NEO -> NEOStorageDoesNotExistError
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidDoesNotExist('') packet = Errors.OidDoesNotExist('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -337,11 +341,12 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -337,11 +341,12 @@ class ClientApplicationTests(NeoUnitTestBase):
app.local_var.asked_object = an_object[:-1] app.local_var.asked_object = an_object[:-1]
self.assertRaises(NEOStorageError, app.loadBefore, oid, tid1) self.assertRaises(NEOStorageError, app.loadBefore, oid, tid1)
# object should not have been cached # object should not have been cached
self.assertFalse(oid in mq) self.assertFalse((oid, tid1) in mq)
# as for loadSerial, the object is cached but should be loaded from db # as for loadSerial, the object is cached but should be loaded from db
mq.store(oid, (tid1, 'WRONG')) mq.store((oid, tid1), ('WRONG', tid2))
self.assertTrue(oid in mq) self.assertTrue((oid, tid1) in mq)
another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'), app.cache_revision_index.invalidate([oid], tid2)
another_object = (1, oid, tid2, tid3, 0, makeChecksum('RIGHT'),
'RIGHT', None) 'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
packet.setId(0) packet.setId(0)
...@@ -352,9 +357,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -352,9 +357,9 @@ class ClientApplicationTests(NeoUnitTestBase):
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = another_object app.local_var.asked_object = another_object
result = app.loadBefore(oid, tid3) result = app.loadBefore(oid, tid3)
self.assertEquals(result, ('RIGHT', tid1, tid2)) self.assertEquals(result, ('RIGHT', tid2, tid3))
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue((oid, tid1) in mq)
def test_tpc_begin(self): def test_tpc_begin(self):
app = self.getApp() app = self.getApp()
...@@ -1156,6 +1161,90 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -1156,6 +1161,90 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertEqual(marker[0].getType(), Packets.AskPack) self.assertEqual(marker[0].getType(), Packets.AskPack)
# XXX: how to validate packet content ? # XXX: how to validate packet content ?
def test_RevisionIndex_1(self):
# Test add, getLatestSerial, getSerialList and clear
# without invalidations
oid1 = self.getOID(1)
oid2 = self.getOID(2)
tid1 = self.getOID(1)
tid2 = self.getOID(2)
tid3 = self.getOID(3)
ri = RevisionIndex()
# index is empty
self.assertEqual(ri.getSerialList(oid1), [])
ri.add((oid1, tid1))
# now, it knows oid1 at tid1
self.assertEqual(ri.getLatestSerial(oid1), tid1)
self.assertEqual(ri.getSerialList(oid1), [tid1])
self.assertEqual(ri.getSerialList(oid2), [])
ri.add((oid1, tid2))
# and at tid2
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
ri.remove((oid1, tid1))
# oid1 at tid1 was pruned from cache
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2])
ri.remove((oid1, tid2))
# oid1 is completely priuned from cache
self.assertEqual(ri.getLatestSerial(oid1), None)
self.assertEqual(ri.getSerialList(oid1), [])
ri.add((oid1, tid2))
ri.add((oid1, tid1))
# oid1 is populated, but in non-chronological order, check index
# still answers consistent result.
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
ri.add((oid2, tid3))
# which is not affected by the addition of oid2 at tid3
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
ri.clear()
# index is empty again
self.assertEqual(ri.getSerialList(oid1), [])
self.assertEqual(ri.getSerialList(oid2), [])
def test_RevisionIndex_2(self):
# Test getLatestSerial & getSerialBefore with invalidations
oid1 = self.getOID(1)
tid1 = self.getOID(1)
tid2 = self.getOID(2)
tid3 = self.getOID(3)
tid4 = self.getOID(4)
tid5 = self.getOID(5)
tid6 = self.getOID(6)
ri = RevisionIndex()
ri.add((oid1, tid1))
ri.add((oid1, tid2))
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialBefore(oid1, tid2), tid1)
self.assertEqual(ri.getSerialBefore(oid1, tid3), tid2)
self.assertEqual(ri.getSerialBefore(oid1, tid4), tid2)
ri.invalidate([oid1], tid3)
# We don't have the latest data in cache, return None
self.assertEqual(ri.getLatestSerial(oid1), None)
self.assertEqual(ri.getSerialBefore(oid1, tid2), tid1)
self.assertEqual(ri.getSerialBefore(oid1, tid3), tid2)
# There is a gap between the last version we have and requested one,
# return None
self.assertEqual(ri.getSerialBefore(oid1, tid4), None)
ri.add((oid1, tid3))
# No gap anymore, tid3 found.
self.assertEqual(ri.getLatestSerial(oid1), tid3)
self.assertEqual(ri.getSerialBefore(oid1, tid4), tid3)
ri.invalidate([oid1], tid4)
ri.invalidate([oid1], tid5)
# A bigger gap...
self.assertEqual(ri.getLatestSerial(oid1), None)
self.assertEqual(ri.getSerialBefore(oid1, tid5), None)
self.assertEqual(ri.getSerialBefore(oid1, tid6), None)
# not entirely filled.
ri.add((oid1, tid5))
# Still, we know the latest and what is before tid6
self.assertEqual(ri.getLatestSerial(oid1), tid5)
self.assertEqual(ri.getSerialBefore(oid1, tid5), None)
self.assertEqual(ri.getSerialBefore(oid1, tid6), tid5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -156,17 +156,22 @@ class MasterNotificationsHandlerTests(MasterHandlerTests): ...@@ -156,17 +156,22 @@ class MasterNotificationsHandlerTests(MasterHandlerTests):
def test_invalidateObjects(self): def test_invalidateObjects(self):
conn = self.getConnection() conn = self.getConnection()
tid = self.getNextTID() tid = self.getNextTID()
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2, oid3 = self.getOID(1), self.getOID(2), self.getOID(3)
self.app.mq_cache = { self.app.mq_cache = {
oid1: tid, (oid1, tid): ('bla', None),
oid2: tid, (oid2, tid): ('bla', None),
} }
self.handler.invalidateObjects(conn, tid, [oid1]) self.app.cache_revision_index = Mock({
self.assertFalse(oid1 in self.app.mq_cache) 'invalidate': None,
self.assertTrue(oid2 in self.app.mq_cache) })
self.handler.invalidateObjects(conn, tid, [oid1, oid3])
cache_calls = self.app.cache_revision_index.mockGetNamedCalls(
'invalidate')
self.assertEqual(len(cache_calls), 1)
cache_calls[0].checkArgs([oid1, oid3], tid)
invalidation_calls = self.db.mockGetNamedCalls('invalidate') invalidation_calls = self.db.mockGetNamedCalls('invalidate')
self.assertEqual(len(invalidation_calls), 1) self.assertEqual(len(invalidation_calls), 1)
invalidation_calls[0].checkArgs(tid, {oid1:tid}) invalidation_calls[0].checkArgs(tid, {oid1:tid, oid3:tid})
def test_notifyPartitionChanges(self): def test_notifyPartitionChanges(self):
conn = self.getConnection() conn = self.getConnection()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment