Commit 7d20e5bd authored by Grégory Wisniewski's avatar Grégory Wisniewski

undoLog is broken, make the iterator use a workaround.

undoLog doesn't work when first is non-zero, this breaks iterator and
cannot be fixed for undoLog at the moment.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2550 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f6b30dec
......@@ -1127,7 +1127,32 @@ class Application(object):
for k, v in loads(extension).items():
txn_info[k] = v
def __undoLog(self, first, last, filter=None, block=0, with_oids=False):
def _getTransactionInformation(self, tid):
cell_list = self._getCellListForTID(tid, readable=True)
shuffle(cell_list)
cell_list.sort(key=self.cp.getCellSortKey)
for cell in cell_list:
conn = self.cp.getConnForCell(cell)
if conn is not None:
self.local_var.txn_info = 0
self.local_var.txn_ext = 0
try:
self._askStorage(conn,
Packets.AskTransactionInformation(tid))
except ConnectionClosed:
continue
if isinstance(self.local_var.txn_info, dict):
break
if self.local_var.txn_info in (-1, 0):
# TID not found at all
raise NeoException, 'Data inconsistency detected: ' \
'transaction info for TID %r could not ' \
'be found' % (tid, )
return (self.local_var.txn_info, self.local_var.txn_ext)
def undoLog(self, first, last, filter=None, block=0):
# XXX: undoLog is broken
if last < 0:
# See FileStorage.py for explanation
last = first - last
......@@ -1161,51 +1186,51 @@ class Application(object):
undo_info = []
append = undo_info.append
for tid in ordered_tids:
cell_list = self._getCellListForTID(tid, readable=True)
shuffle(cell_list)
cell_list.sort(key=self.cp.getCellSortKey)
for cell in cell_list:
conn = self.cp.getConnForCell(cell)
if conn is not None:
self.local_var.txn_info = 0
self.local_var.txn_ext = 0
try:
self._askStorage(conn,
Packets.AskTransactionInformation(tid))
except ConnectionClosed:
continue
if isinstance(self.local_var.txn_info, dict):
break
if self.local_var.txn_info in (-1, 0):
# TID not found at all
raise NeoException, 'Data inconsistency detected: ' \
'transaction info for TID %r could not ' \
'be found' % (tid, )
(txn_info, txn_ext) = self._getTransactionInformation(tid)
if filter is None or filter(self.local_var.txn_info):
txn_info = self.local_var.txn_info
txn_info.pop('packed')
if not with_oids:
txn_info.pop("oids")
self._insertMetadata(txn_info, self.local_var.txn_ext)
else:
txn_info['ext'] = loads(self.local_var.txn_ext)
append(txn_info)
if len(undo_info) >= last - first:
break
# Check we return at least one element, otherwise call
# again but extend offset
if len(undo_info) == 0 and not block:
undo_info = self.__undoLog(first=first, last=last*5, filter=filter,
block=1, with_oids=with_oids)
undo_info = self.undoLog(first=first, last=last*5, filter=filter,
block=1)
return undo_info
def undoLog(self, first, last, filter=None, block=0):
return self.__undoLog(first, last, filter, block)
def transactionLog(self, first, last):
return self.__undoLog(first, last, with_oids=True)
def transactionLog(self, start, stop, limit):
node_map = self.pt.getNodeMap()
node_list = node_map.keys()
node_list.sort(key=self.cp.getCellSortKey)
partition_set = set(range(self.pt.getPartitions()))
queue = self.local_var.queue
# request a tid list for each partition
self.local_var.tids_from = set()
for node in node_list:
conn = self.cp.getConnForNode(node)
request_set = set(node_map[node]) & partition_set
if conn is None or not request_set:
continue
partition_set -= set(request_set)
packet = Packets.AskTIDsFrom(start, stop, limit, request_set)
conn.ask(packet, queue=queue)
if not partition_set:
break
assert not partition_set
self.waitResponses()
# request transactions informations
txn_list = []
append = txn_list.append
tid = None
for tid in sorted(self.local_var.tids_from):
(txn_info, txn_ext) = self._getTransactionInformation(tid)
txn_info['ext'] = loads(self.local_var.txn_ext)
append(txn_info)
return (tid, txn_list)
def history(self, oid, version=None, size=1, filter=None):
# Get history informations for object first
......@@ -1297,7 +1322,9 @@ class Application(object):
assert real_tid == tid, (real_tid, tid)
transaction_iter.close()
def iterator(self, start=None, stop=None):
def iterator(self, start, stop):
if start is None:
start = ZERO_TID
return Iterator(self, start, stop)
def lastTransaction(self):
......
......@@ -95,6 +95,11 @@ class StorageAnswersHandler(AnswerBaseHandler):
if tid != self.app.getTID():
raise NEOStorageError('Wrong TID, transaction not started')
def answerTIDsFrom(self, conn, tid_list):
neo.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
assert not self.app.local_var.tids_from.intersection(set(tid_list))
self.app.local_var.tids_from.update(tid_list)
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
# transaction information are returned as a dict
......
......@@ -18,10 +18,12 @@
from ZODB import BaseStorage
from zope.interface import implements
import ZODB.interfaces
from neo import util
from neo.util import u64, add64
from neo.client.exception import NEOStorageCreationUndoneError
from neo.client.exception import NEOStorageNotFoundError
CHUNK_LENGTH = 100
class Record(BaseStorage.DataRecord):
""" TBaseStorageransaction record yielded by the Transaction object """
......@@ -29,8 +31,8 @@ class Record(BaseStorage.DataRecord):
BaseStorage.DataRecord.__init__(self, oid, tid, data, prev)
def __str__(self):
oid = util.u64(self.oid)
tid = util.u64(self.tid)
oid = u64(self.oid)
tid = u64(self.tid)
args = (oid, tid, len(self.data), self.data_txn)
return 'Record %s:%s: %s (%s)' % args
......@@ -86,7 +88,7 @@ class Transaction(BaseStorage.TransactionRecord):
return record
def __str__(self):
tid = util.u64(self.tid)
tid = u64(self.tid)
args = (tid, self.user, self.status)
return 'Transaction #%s: %s %s' % args
......@@ -97,17 +99,15 @@ class Iterator(object):
def __init__(self, app, start, stop):
self.app = app
self.txn_list = []
assert None not in (start, stop)
self._start = start
self._stop = stop
# next index to load from storage nodes
self._next = 0
# index of current iteration
self._index = 0
self._closed = False
# OID -> previous TID mapping
# TODO: prune old entries while walking ?
self._prev_serial_dict = {}
if start is not None:
self.txn_list = self._skip(start)
def __iter__(self):
return self
......@@ -118,41 +118,21 @@ class Iterator(object):
raise IndexError, index
return self.next()
def _read(self):
""" Request more transactions """
chunk = self.app.transactionLog(self._next, self._next + 100)
if not chunk:
# nothing more
raise StopIteration
self._next += len(chunk)
return chunk
def _skip(self, start):
""" Skip transactions until 'start' is reached """
chunk = self._read()
while chunk[0]['id'] < start:
chunk = self._read()
if chunk[-1]['id'] < start:
for index, txn in enumerate(reversed(chunk)):
if txn['id'] >= start:
break
# keep only greater transactions
chunk = chunk[:-index]
return chunk
def next(self):
""" Return an iterator for the next transaction"""
if self._closed:
raise IOError, 'iterator closed'
if not self.txn_list:
self.txn_list = self._read()
txn = self.txn_list.pop()
(max_tid, chunk) = self.app.transactionLog(self._start, self._stop,
CHUNK_LENGTH)
if not chunk:
# nothing more
raise StopIteration
self._start = add64(max_tid, 1)
self.txn_list = chunk
txn = self.txn_list.pop(0)
self._index += 1
tid = txn['id']
stop = self._stop
if stop is not None and stop < tid:
# stop reached
raise StopIteration
user = txn['user_name']
desc = txn['description']
oid_list = txn['oids']
......
......@@ -1098,12 +1098,29 @@ class AskTIDsFrom(Packet):
S -> S.
"""
_header_format = '!8s8sLL'
_list_entry_format = 'L'
_list_entry_len = calcsize(_list_entry_format)
def _encode(self, min_tid, max_tid, length, partition):
return pack(self._header_format, min_tid, max_tid, length, partition)
def _encode(self, min_tid, max_tid, length, partition_list):
body = [pack(self._header_format, min_tid, max_tid, length,
len(partition_list))]
list_entry_format = self._list_entry_format
for partition in partition_list:
body.append(pack(list_entry_format, partition))
return ''.join(body)
def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition
body = StringIO(body)
read = body.read
header = unpack(self._header_format, read(self._header_len))
min_tid, max_tid, length, list_length = header
list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len
partition_list = []
for _ in xrange(list_length):
partition = unpack(list_entry_format, read(list_entry_len))[0]
partition_list.append(partition)
return (min_tid, max_tid, length, partition_list)
class AnswerTIDsFrom(AnswerTIDs):
"""
......
......@@ -86,6 +86,17 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
self._askStoreObject(conn, oid, serial, compression, checksum, data,
data_serial, tid, time.time())
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
app = self.app
getReplicationTIDList = app.dm.getReplicationTIDList
partitions = app.pt.getPartitions()
tid_list = []
extend = tid_list.extend
for partition in partition_list:
extend(getReplicationTIDList(min_tid, max_tid, length,
partitions, partition))
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askTIDs(self, conn, first, last, partition):
# This method is complicated, because I must return TIDs only
# about usable partitions assigned to me.
......
......@@ -190,7 +190,7 @@ class ReplicationHandler(EventHandler):
partition_id = replicator.getCurrentRID()
max_tid = replicator.getCurrentCriticalTID()
replicator.getTIDsFrom(min_tid, max_tid, length, partition_id)
return Packets.AskTIDsFrom(min_tid, max_tid, length, partition_id)
return Packets.AskTIDsFrom(min_tid, max_tid, length, [partition_id])
def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
replicator = self.app.replicator
......
......@@ -30,7 +30,9 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition):
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
assert len(partition_list) == 1, partition_list
partition = partition_list[0]
app = self.app
tid_list = app.dm.getReplicationTIDList(min_tid, max_tid, length,
app.pt.getPartitions(), partition)
......
......@@ -426,10 +426,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length)
self.assertEqual(ppartition, rid)
self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
min_tid = self.getNextTID()
......@@ -453,10 +453,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid)
self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
# CheckSerialRange
def test_answerCheckSerialFullRangeIdenticalChunkWithNext(self):
......
......@@ -119,7 +119,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.app.pt = Mock({'getPartitions': 1})
tid = self.getNextTID()
tid2 = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, tid2, 2, 1)
self.operation.askTIDsFrom(conn, tid, tid2, 2, [1])
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, tid2, 2, 1, 1)
......
......@@ -591,12 +591,12 @@ class ProtocolTests(NeoUnitTestBase):
def test_AskTIDsFrom(self):
tid = self.getNextTID()
tid2 = self.getNextTID()
p = Packets.AskTIDsFrom(tid, tid2, 1000, 5)
p = Packets.AskTIDsFrom(tid, tid2, 1000, [5])
min_tid, max_tid, length, partition = p.decode()
self.assertEqual(min_tid, tid)
self.assertEqual(max_tid, tid2)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
self.assertEqual(partition, [5])
def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment