Commit 55be3ee8 authored by Vincent Pelletier's avatar Vincent Pelletier

Implement distributed packing.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2287 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f2720778
...@@ -74,6 +74,11 @@ class Storage(BaseStorage.BaseStorage, ...@@ -74,6 +74,11 @@ class Storage(BaseStorage.BaseStorage,
return self.app.store(oid=oid, serial=serial, return self.app.store(oid=oid, serial=serial,
data=data, version=version, transaction=transaction) data=data, version=version, transaction=transaction)
@check_read_only
def deleteObject(oid, serial, transaction):
self.app.store(oid=oid, serial=serial, data='', version=None,
transaction=transaction)
def getSerial(self, oid): def getSerial(self, oid):
try: try:
return self.app.getSerial(oid = oid) return self.app.getSerial(oid = oid)
...@@ -154,8 +159,11 @@ class Storage(BaseStorage.BaseStorage, ...@@ -154,8 +159,11 @@ class Storage(BaseStorage.BaseStorage,
def restore(self, oid, serial, data, version, prev_txn, transaction): def restore(self, oid, serial, data, version, prev_txn, transaction):
raise NotImplementedError raise NotImplementedError
def pack(self, t, referencesf): def pack(self, t, referencesf, gc=False):
raise NotImplementedError if gc:
logging.warning('Garbage Collection is not available in NEO, '
'please use an external tool. Packing without GC.')
self.app.pack(t)
def lastSerial(self): def lastSerial(self):
# seems unused # seems unused
......
...@@ -24,12 +24,13 @@ import time ...@@ -24,12 +24,13 @@ import time
from ZODB.POSException import UndoError, StorageTransactionError, ConflictError from ZODB.POSException import UndoError, StorageTransactionError, ConflictError
from ZODB.ConflictResolution import ResolvedSerial from ZODB.ConflictResolution import ResolvedSerial
from persistent.TimeStamp import TimeStamp
from neo import setupLog from neo import setupLog
setupLog('CLIENT', verbose=True) setupLog('CLIENT', verbose=True)
from neo import logging from neo import logging
from neo.protocol import NodeTypes, Packets, INVALID_PARTITION from neo.protocol import NodeTypes, Packets, INVALID_PARTITION, ZERO_TID
from neo.event import EventManager from neo.event import EventManager
from neo.util import makeChecksum as real_makeChecksum, dump from neo.util import makeChecksum as real_makeChecksum, dump
from neo.locking import Lock from neo.locking import Lock
...@@ -1229,3 +1230,9 @@ class Application(object): ...@@ -1229,3 +1230,9 @@ class Application(object):
def isTransactionVoted(self): def isTransactionVoted(self):
return self.local_var.txn_voted return self.local_var.txn_voted
def pack(self, t):
tid = repr(TimeStamp(*time.gmtime(t)[:5] + (t % 60, )))
if tid == ZERO_TID:
raise NEOStorageError('Invalid pack time')
self._askPrimary(Packets.AskPack(tid))
...@@ -21,6 +21,7 @@ from neo.client.handlers import BaseHandler, AnswerBaseHandler ...@@ -21,6 +21,7 @@ from neo.client.handlers import BaseHandler, AnswerBaseHandler
from neo.pt import MTPartitionTable as PartitionTable from neo.pt import MTPartitionTable as PartitionTable
from neo.protocol import NodeTypes, NodeStates, ProtocolError from neo.protocol import NodeTypes, NodeStates, ProtocolError
from neo.util import dump from neo.util import dump
from neo.client.exception import NEOStorageError
class PrimaryBootstrapHandler(AnswerBaseHandler): class PrimaryBootstrapHandler(AnswerBaseHandler):
""" Bootstrap handler used when looking for the primary master """ """ Bootstrap handler used when looking for the primary master """
...@@ -170,3 +171,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -170,3 +171,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
raise ProtocolError('Wrong TID, transaction not started') raise ProtocolError('Wrong TID, transaction not started')
self.app.setTransactionFinished() self.app.setTransactionFinished()
def answerPack(self, conn, status):
if not status:
raise NEOStorageError('Already packing')
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from ZODB import BaseStorage from ZODB import BaseStorage
from neo import util from neo import util
from neo.client.exception import NEOStorageCreationUndoneError from neo.client.exception import NEOStorageCreationUndoneError
from neo.client.exception import NEOStorageNotFoundError
class Record(BaseStorage.DataRecord): class Record(BaseStorage.DataRecord):
""" TBaseStorageransaction record yielded by the Transaction object """ """ TBaseStorageransaction record yielded by the Transaction object """
...@@ -60,17 +61,27 @@ class Transaction(BaseStorage.TransactionRecord): ...@@ -60,17 +61,27 @@ class Transaction(BaseStorage.TransactionRecord):
app = self.app app = self.app
oid_list = self.oid_list oid_list = self.oid_list
oid_index = self.oid_index oid_index = self.oid_index
if self.oid_index >= len(oid_list): oid_len = len(oid_list)
# no more records for this transaction
self.oid_index = 0
raise StopIteration
oid = oid_list[oid_index]
self.oid_index = oid_index + 1
# load an object # load an object
while oid_index < oid_len:
oid = oid_list[oid_index]
try: try:
data, _, next_tid = app._load(oid, serial=self.tid) data, _, next_tid = app._load(oid, serial=self.tid)
except NEOStorageCreationUndoneError: except NEOStorageCreationUndoneError:
data = next_tid = None data = next_tid = None
except NEOStorageNotFoundError:
# Transactions are not updated after a pack, so their object
# will not be found in the database. Skip them.
oid_list.pop(oid_index)
oid_len -= 1
continue
oid_index += 1
break
else:
# no more records for this transaction
self.oid_index = 0
raise StopIteration
self.oid_index = oid_index
record = Record(oid, self.tid, '', data, record = Record(oid, self.tid, '', data,
self.prev_serial_dict.get(oid)) self.prev_serial_dict.get(oid))
if next_tid is None: if next_tid is None:
......
...@@ -353,6 +353,12 @@ class EventHandler(object): ...@@ -353,6 +353,12 @@ class EventHandler(object):
def answerBarrier(self, conn): def answerBarrier(self, conn):
pass pass
def askPack(self, conn, tid):
raise UnexpectedPacketError
def answerPack(self, conn, status):
raise UnexpectedPacketError
# Error packet handlers. # Error packet handlers.
def error(self, conn, code, message): def error(self, conn, code, message):
...@@ -468,6 +474,8 @@ class EventHandler(object): ...@@ -468,6 +474,8 @@ class EventHandler(object):
d[Packets.AnswerHasLock] = self.answerHasLock d[Packets.AnswerHasLock] = self.answerHasLock
d[Packets.AskBarrier] = self.askBarrier d[Packets.AskBarrier] = self.askBarrier
d[Packets.AnswerBarrier] = self.answerBarrier d[Packets.AnswerBarrier] = self.answerBarrier
d[Packets.AskPack] = self.askPack
d[Packets.AnswerPack] = self.answerPack
return d return d
......
...@@ -40,6 +40,7 @@ from neo.live_debug import register as registerLiveDebugger ...@@ -40,6 +40,7 @@ from neo.live_debug import register as registerLiveDebugger
class Application(object): class Application(object):
"""The master node application.""" """The master node application."""
packing = None
def __init__(self, config): def __init__(self, config):
......
...@@ -88,3 +88,15 @@ class ClientServiceHandler(MasterHandler): ...@@ -88,3 +88,15 @@ class ClientServiceHandler(MasterHandler):
node = self.app.nm.getByUUID(conn.getUUID()) node = self.app.nm.getByUUID(conn.getUUID())
app.tm.prepare(node, tid, oid_list, used_uuid_set, conn.getPeerId()) app.tm.prepare(node, tid, oid_list, used_uuid_set, conn.getPeerId())
def askPack(self, conn, tid):
app = self.app
if app.packing is None:
storage_list = self.app.nm.getStorageList(only_identified=True)
app.packing = (conn, conn.getPeerId(),
set(x.getUUID() for x in storage_list))
p = Packets.AskPack(tid)
for storage in storage_list:
storage.getConnection().ask(p)
else:
conn.answer(Packets.AnswerPack(False))
...@@ -22,6 +22,7 @@ from neo.protocol import CellStates, Packets ...@@ -22,6 +22,7 @@ from neo.protocol import CellStates, Packets
from neo.master.handlers import BaseServiceHandler from neo.master.handlers import BaseServiceHandler
from neo.exception import OperationFailure from neo.exception import OperationFailure
from neo.util import dump from neo.util import dump
from neo.connector import ConnectorConnectionClosedException
class StorageServiceHandler(BaseServiceHandler): class StorageServiceHandler(BaseServiceHandler):
...@@ -46,6 +47,9 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -46,6 +47,9 @@ class StorageServiceHandler(BaseServiceHandler):
# if a transaction is known, this means that it's being committed # if a transaction is known, this means that it's being committed
if transaction.forget(uuid): if transaction.forget(uuid):
self._afterLock(tid) self._afterLock(tid)
packing = self.app.packing
if packing is not None:
self.answerPack(conn, False)
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
...@@ -124,4 +128,15 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -124,4 +128,15 @@ class StorageServiceHandler(BaseServiceHandler):
break break
self.app.broadcastPartitionChanges(cell_list) self.app.broadcastPartitionChanges(cell_list)
def answerPack(self, conn, status):
app = self.app
if app.packing is not None:
client, msg_id, uid_set = app.packing
uid_set.remove(conn.getUUID())
if not uid_set:
app.packing = None
try:
client.answer(Packets.AnswerPack(True), msg_id=msg_id)
except ConnectorConnectionClosedException:
pass
...@@ -1634,6 +1634,32 @@ class AskBarrier(Packet): ...@@ -1634,6 +1634,32 @@ class AskBarrier(Packet):
class AnswerBarrier(Packet): class AnswerBarrier(Packet):
pass pass
class AskPack(Packet):
"""
Request a pack at given TID.
C -> M
M -> S
"""
def _encode(self, tid):
return _encodeTID(tid)
def _decode(self, body):
return (_decodeTID(body), )
class AnswerPack(Packet):
"""
Inform that packing it over.
S -> M
M -> C
"""
_header_format = '!H'
def _encode(self, status):
return pack(self._header_format, int(status))
def _decode(self, body):
return (bool(unpack(self._header_format, body)[0]), )
class Error(Packet): class Error(Packet):
""" """
Error is a special type of message, because this can be sent against Error is a special type of message, because this can be sent against
...@@ -1873,6 +1899,10 @@ class PacketRegistry(dict): ...@@ -1873,6 +1899,10 @@ class PacketRegistry(dict):
0x037, 0x037,
AskBarrier, AskBarrier,
AnswerBarrier) AnswerBarrier)
AskPack, AnswerPack = register(
0x0038,
AskPack,
AnswerPack)
# build a "singleton" # build a "singleton"
Packets = PacketRegistry() Packets = PacketRegistry()
......
...@@ -323,3 +323,19 @@ class DatabaseManager(object): ...@@ -323,3 +323,19 @@ class DatabaseManager(object):
the given list.""" the given list."""
raise NotImplementedError raise NotImplementedError
def pack(self, tid, updateObjectDataForPack):
"""Prune all non-current object revisions at given tid.
updateObjectDataForPack is a function called for each deleted object
and revision with:
- OID
- packed TID
- new value_serial
If object data was moved to an after-pack-tid revision, this
parameter contains the TID of that revision, allowing to backlink
to it.
- getObjectData function
To call if value_serial is None and an object needs to be updated.
Takes no parameter, returns a 3-tuple: compression, checksum,
value
"""
...@@ -206,6 +206,16 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -206,6 +206,16 @@ class MySQLDatabaseManager(DatabaseManager):
value = "'%s'" % (e(str(value)), ) value = "'%s'" % (e(str(value)), )
q("""REPLACE INTO config VALUES ('%s', %s)""" % (key, value)) q("""REPLACE INTO config VALUES ('%s', %s)""" % (key, value))
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
def _getPackTID(self):
try:
result = int(self.getConfiguration('_pack_tid'))
except KeyError:
result = -1
return result
def getPartitionTable(self): def getPartitionTable(self):
q = self.query q = self.query
cell_list = q("""SELECT rid, uuid, state FROM pt""") cell_list = q("""SELECT rid, uuid, state FROM pt""")
...@@ -618,9 +628,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -618,9 +628,11 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
oid = util.u64(oid) oid = util.u64(oid)
p64 = util.p64 p64 = util.p64
pack_tid = self._getPackTID()
r = q("""SELECT serial, LENGTH(value), value_serial FROM obj r = q("""SELECT serial, LENGTH(value), value_serial FROM obj
WHERE oid = %d ORDER BY serial DESC LIMIT %d, %d""" \ WHERE oid = %d AND serial >= %d
% (oid, offset, length)) ORDER BY serial DESC LIMIT %d, %d""" \
% (oid, pack_tid, offset, length))
if r: if r:
result = [] result = []
append = result.append append = result.append
...@@ -683,3 +695,92 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -683,3 +695,92 @@ class MySQLDatabaseManager(DatabaseManager):
% (oid, ','.join([str(util.u64(serial)) for serial in serial_list]))) % (oid, ','.join([str(util.u64(serial)) for serial in serial_list])))
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack):
q = self.query
p64 = util.p64
# Before deleting this objects revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
value_serial = None
for table in ('obj', 'tobj'):
for (serial, ) in q('SELECT serial FROM %(table)s WHERE '
'oid = %(oid)d AND serial >= %(max_serial)d AND '
'value_serial = %(orig_serial)d ORDER BY serial ASC' % {
'table': table,
'oid': oid,
'orig_serial': orig_serial,
'max_serial': max_serial,
}):
if value_serial is None:
# First found, copy data to it and mark its serial for
# future reference.
value_serial = serial
q('REPLACE INTO %(table)s (oid, serial, compression, '
'checksum, value, value_serial) SELECT oid, '
'%(serial)d, compression, checksum, value, NULL FROM '
'obj WHERE oid = %(oid)d AND serial = %(orig_serial)d' \
% {
'table': table,
'oid': oid,
'serial': serial,
'orig_serial': orig_serial,
})
else:
q('REPLACE INTO %(table)s (oid, serial, value_serial) '
'VALUES (%(oid)d, %(serial)d, %(value_serial)d)' % {
'table': table,
'oid': oid,
'serial': serial,
'value_serial': value_serial,
})
def getObjectData():
assert value_serial is None
return q('SELECT compression, checksum, value FROM obj WHERE '
'oid = %(oid)d AND serial = %(orig_serial)d' % {
'oid': oid,
'orig_serial': orig_serial,
})[0]
if value_serial:
value_serial = p64(value_serial)
updateObjectDataForPack(p64(oid), p64(orig_serial), value_serial,
getObjectData)
def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture)
q = self.query
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
self.begin()
try:
self._setPackTID(tid)
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, '
'MAX(serial) FROM obj WHERE serial <= %(tid)d '
'GROUP BY oid' % {'tid': tid}):
if q('SELECT LENGTH(value) FROM obj WHERE oid = %(oid)d AND '
'serial = %(max_serial)d' % {
'oid': oid,
'max_serial': max_serial,
})[0][0] == 0:
count += 1
max_serial += 1
if count:
# There are things to delete for this object
for (serial, ) in q('SELECT serial FROM obj WHERE '
'oid=%(oid)d AND serial < %(max_serial)d' % {
'oid': oid,
'max_serial': max_serial,
}):
updatePackFuture(oid, serial, max_serial,
updateObjectDataForPack)
q('DELETE FROM obj WHERE oid=%(oid)d AND '
'serial=%(serial)d' % {
'oid': oid,
'serial': serial
})
except:
self.rollback()
raise
self.commit()
...@@ -64,3 +64,11 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -64,3 +64,11 @@ class MasterOperationHandler(BaseMasterHandler):
raise ProtocolError('Unknown transaction') raise ProtocolError('Unknown transaction')
# TODO: send an answer # TODO: send an answer
self.app.tm.unlock(tid) self.app.tm.unlock(tid)
def askPack(self, conn, tid):
app = self.app
logging.info('Pack started, up to %s...', dump(tid))
app.dm.pack(tid, app.tm.updateObjectDataForPack)
logging.info('Pack finished.')
conn.answer(Packets.AnswerPack(True))
...@@ -285,3 +285,19 @@ class TransactionManager(object): ...@@ -285,3 +285,19 @@ class TransactionManager(object):
for oid, tid in self._store_lock_dict.items(): for oid, tid in self._store_lock_dict.items():
logging.info(' %r by %r', dump(oid), dump(tid)) logging.info(' %r by %r', dump(oid), dump(tid))
def updateObjectDataForPack(self, oid, orig_serial, new_serial,
getObjectData):
lock_tid = self.getLockingTID(oid)
if lock_tid is not None:
transaction = self._transaction_dict[lock_tid]
oid, compression, checksum, data, value_serial = \
transaction.getObject(oid)
if value_serial == orig_serial:
if new_serial:
value_serial = new_serial
else:
compression, checksum, data = getObjectData()
value_serial = None
transaction.addObject(oid, compression, checksum, data,
value_serial)
...@@ -156,7 +156,7 @@ class NeoTestBase(unittest.TestCase): ...@@ -156,7 +156,7 @@ class NeoTestBase(unittest.TestCase):
}) })
def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000), def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000),
is_server=False, connector=None): is_server=False, connector=None, peer_id=None):
if connector is None: if connector is None:
connector = self.getFakeConnector() connector = self.getFakeConnector()
return Mock({ return Mock({
...@@ -166,6 +166,7 @@ class NeoTestBase(unittest.TestCase): ...@@ -166,6 +166,7 @@ class NeoTestBase(unittest.TestCase):
'__repr__': 'FakeConnection', '__repr__': 'FakeConnection',
'__nonzero__': 0, '__nonzero__': 0,
'getConnector': connector, 'getConnector': connector,
'getPeerId': peer_id,
}) })
def checkProtocolErrorRaised(self, method, *args, **kwargs): def checkProtocolErrorRaised(self, method, *args, **kwargs):
......
...@@ -25,6 +25,7 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError ...@@ -25,6 +25,7 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError from neo.client.exception import NEOStorageDoesNotExistError
from neo.protocol import Packet, Packets, Errors, INVALID_TID, INVALID_SERIAL from neo.protocol import Packet, Packets, Errors, INVALID_TID, INVALID_SERIAL
from neo.util import makeChecksum from neo.util import makeChecksum
import time
def _getMasterConnection(self): def _getMasterConnection(self):
if self.master_conn is None: if self.master_conn is None:
...@@ -1164,6 +1165,21 @@ class ClientApplicationTests(NeoTestBase): ...@@ -1164,6 +1165,21 @@ class ClientApplicationTests(NeoTestBase):
self.assertTrue(hasattr(app1_local, property_id)) self.assertTrue(hasattr(app1_local, property_id))
self.assertFalse(hasattr(app2_local, property_id)) self.assertFalse(hasattr(app2_local, property_id))
def test_pack(self):
app = self.getApp()
marker = []
def askPrimary(packet):
marker.append(packet)
app._askPrimary = askPrimary
# XXX: could not identify a value causing TimeStamp to return ZERO_TID
#self.assertRaises(NEOStorageError, app.pack, )
self.assertEqual(len(marker), 0)
now = time.time()
app.pack(now)
self.assertEqual(len(marker), 1)
self.assertEqual(marker[0].getType(), Packets.AskPack)
# XXX: how to validate packet content ?
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -23,6 +23,7 @@ from neo.protocol import NodeTypes, NodeStates ...@@ -23,6 +23,7 @@ from neo.protocol import NodeTypes, NodeStates
from neo.client.handlers.master import PrimaryBootstrapHandler from neo.client.handlers.master import PrimaryBootstrapHandler
from neo.client.handlers.master import PrimaryNotificationsHandler, \ from neo.client.handlers.master import PrimaryNotificationsHandler, \
PrimaryAnswersHandler PrimaryAnswersHandler
from neo.client.exception import NEOStorageError
MARKER = [] MARKER = []
...@@ -255,6 +256,10 @@ class MasterAnswersHandlerTests(MasterHandlerTests): ...@@ -255,6 +256,10 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
calls = app.mockGetNamedCalls('setTransactionFinished') calls = app.mockGetNamedCalls('setTransactionFinished')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
# Check it doesn't raise
self.handler.answerPack(None, True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.protocol import NodeTypes, NodeStates from neo.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -154,6 +154,34 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -154,6 +154,34 @@ class MasterClientHandlerTests(NeoTestBase):
self.__testWithMethod(self.service.connectionClosed, self.__testWithMethod(self.service.connectionClosed,
NodeStates.TEMPORARILY_DOWN) NodeStates.TEMPORARILY_DOWN)
def test_askPack(self):
self.assertEqual(self.app.packing, None)
self.app.nm.createClient()
tid = self.getNextTID()
peer_id = 42
conn = self.getFakeConnection(peer_id=peer_id)
storage_uuid = self.identifyToMasterNode()
storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack,
decode=True)[0]
self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id)
self.assertEqual(self.app.packing[2], set([storage_uuid, ]))
# Asking again to pack will cause an immediate error
storage_uuid = self.identifyToMasterNode()
storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack,
decode=True)[0]
self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
from mock import Mock from mock import Mock
from struct import pack from struct import pack
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.protocol import NodeTypes, NodeStates from neo.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.storage import StorageServiceHandler from neo.master.handlers.storage import StorageServiceHandler
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -247,6 +247,30 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -247,6 +247,30 @@ class MasterStorageHandlerTests(NeoTestBase):
# T3: action not significant to this transacion, so no response # T3: action not significant to this transacion, so no response
self.checkNoPacketSent(cconn3, check_notify=False) self.checkNoPacketSent(cconn3, check_notify=False)
def test_answerPack(self):
# Note: incomming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage()
node2, conn2 = self._getStorage()
self.app.packing = None
# Does nothing
self.service.answerPack(None, False)
client_conn = Mock({
'getPeerId': 512,
})
client_peer_id = 42
self.app.packing = (client_conn, client_peer_id, set([conn1.getUUID(),
conn2.getUUID()]))
self.service.answerPack(conn1, False)
self.checkNoPacketSent(client_conn)
self.assertEqual(self.app.packing[2], set([conn2.getUUID(), ]))
self.service.answerPack(conn2, False)
status = self.checkAnswerPacket(client_conn, Packets.AnswerPack,
decode=True)[0]
# TODO: verify packet peer id
self.assertTrue(status)
self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -23,7 +23,7 @@ from neo.storage.app import Application ...@@ -23,7 +23,7 @@ from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler from neo.storage.handlers.master import MasterOperationHandler
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo.protocol import CellStates, ProtocolError from neo.protocol import CellStates, ProtocolError, Packets
from neo.protocol import INVALID_TID, INVALID_OID from neo.protocol import INVALID_TID, INVALID_OID
class StorageMasterHandlerTests(NeoTestBase): class StorageMasterHandlerTests(NeoTestBase):
...@@ -196,5 +196,16 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -196,5 +196,16 @@ class StorageMasterHandlerTests(NeoTestBase):
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs((INVALID_TID, )) calls[0].checkArgs((INVALID_TID, ))
def test_askPack(self):
self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection()
tid = self.getNextTID()
self.operation.askPack(conn, tid)
calls = self.app.dm.mockGetNamedCalls('pack')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, self.app.tm.updateObjectDataForPack)
# Content has no meaning here, don't check.
self.checkAnswerPacket(conn, Packets.AnswerPack)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -334,5 +334,57 @@ class TransactionManagerTests(NeoTestBase): ...@@ -334,5 +334,57 @@ class TransactionManagerTests(NeoTestBase):
self.manager.storeObject(tid1, serial1, *obj1) self.manager.storeObject(tid1, serial1, *obj1)
self.assertEqual(self.manager.getLockingTID(oid1), tid1) self.assertEqual(self.manager.getLockingTID(oid1), tid1)
def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID()
oid = self.getOID(1)
orig_serial = self.getNextTID()
uuid = self.getNewUUID()
locking_serial = self.getNextTID()
other_serial = self.getNextTID()
new_serial = self.getNextTID()
compression = 1
checksum = 42
value = 'foo'
self.manager.register(uuid, locking_serial)
def getObjectData():
return (compression, checksum, value)
# Object not known, nothing happens
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
# Object known, but doesn't point at orig_serial, it is not updated
self.manager.storeObject(locking_serial, ram_serial, oid, 0, 512,
'bar', None)
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, other_serial)
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
# Object known and points at undone data it gets updated
# ...with data_serial: getObjectData must not be called
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, new_serial,
None)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, None, None, None, new_serial))
# with data
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None,
getObjectData)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, compression, checksum, value, None))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -635,6 +635,18 @@ class ProtocolTests(NeoTestBase): ...@@ -635,6 +635,18 @@ class ProtocolTests(NeoTestBase):
def test_AnswerObjectHistoryFrom(self): def test_AnswerObjectHistoryFrom(self):
self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom) self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom)
def test_AskPack(self):
tid = self.getNextTID()
p = Packets.AskPack(tid)
ptid = p.decode()[0]
self.assertEqual(ptid, tid)
def test_AnswerPack(self):
status = True
p = Packets.AnswerPack(status)
pstatus = p.decode()[0]
self.assertEqual(pstatus, status)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -17,14 +17,18 @@ ...@@ -17,14 +17,18 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest import unittest
from ZODB.tests.PackableStorage import PackableStorage try:
from ZODB.tests.PackableStorage import PackableStorageWithOptionalGC
except ImportError:
from ZODB.tests.PackableStorage import PackableStorage as \
PackableStorageWithOptionalGC
from ZODB.tests.PackableStorage import PackableUndoStorage from ZODB.tests.PackableStorage import PackableUndoStorage
from ZODB.tests.StorageTestBase import StorageTestBase from ZODB.tests.StorageTestBase import StorageTestBase
from neo.tests.zodb import ZODBTestCase from neo.tests.zodb import ZODBTestCase
class PackableTests(ZODBTestCase, StorageTestBase, PackableStorage, class PackableTests(ZODBTestCase, StorageTestBase,
PackableUndoStorage): PackableStorageWithOptionalGC, PackableUndoStorage):
pass pass
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -76,7 +76,7 @@ ZODB_TEST_MODULES = [ ...@@ -76,7 +76,7 @@ ZODB_TEST_MODULES = [
('neo.tests.zodb.testHistory', 'check'), ('neo.tests.zodb.testHistory', 'check'),
('neo.tests.zodb.testIterator', 'check'), ('neo.tests.zodb.testIterator', 'check'),
('neo.tests.zodb.testMT', 'check'), ('neo.tests.zodb.testMT', 'check'),
# ('neo.tests.zodb.testPack', 'check'), ('neo.tests.zodb.testPack', 'check'),
('neo.tests.zodb.testPersistent', 'check'), ('neo.tests.zodb.testPersistent', 'check'),
('neo.tests.zodb.testReadOnly', 'check'), ('neo.tests.zodb.testReadOnly', 'check'),
('neo.tests.zodb.testRevision', 'check'), ('neo.tests.zodb.testRevision', 'check'),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment