Commit d61a34c0 authored by Kirill Smelkov's avatar Kirill Smelkov

.

parents f5cf9484 42fd89bc
...@@ -6,5 +6,5 @@ ...@@ -6,5 +6,5 @@
/build/ /build/
/dist/ /dist/
/htmlcov/ /htmlcov/
/mock.py /neo/tests/mock.py
/neoppod.egg-info/ /neoppod.egg-info/
Change History Change History
============== ==============
1.7.1 (2017-01-18)
------------------
- Replication:
- Fixed possibly wrong knowledge of cells' backup_tid when resuming backup.
In such case, 'neoctl print ids' gave false impression that the backup
cluster was up-to-date. This also resulted in an inconsistent database
when leaving backup mode before that the issue resolved by itself.
- Storage nodes now select the partition which is furthest behind. Previous
criterion was such that in case of high upstream activity, the backup could
even be stuck looping on a subset of partitions.
- Fixed replication of unfinished imported transactions.
- Fixed abort before vote, to free the storage space used by the transaction.
A new 'prune_orphan' neoctl command was added to delete unreferenced raw data
in the database.
- Removed short storage option -R to reset the db.
Help is reworded to clarify that --reset exits once done.
- The application receiving buffer size has been increased.
This speeds up transfer of big packets.
- The master raised AttributeError at exit during recovery.
- At startup, the importer storage backend connected twice to the destination
database.
1.7.0 (2016-12-19) 1.7.0 (2016-12-19)
------------------ ------------------
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -71,6 +71,7 @@ class AdminEventHandler(EventHandler): ...@@ -71,6 +71,7 @@ class AdminEventHandler(EventHandler):
setNodeState = forward_ask(Packets.SetNodeState) setNodeState = forward_ask(Packets.SetNodeState)
checkReplicas = forward_ask(Packets.CheckReplicas) checkReplicas = forward_ask(Packets.CheckReplicas)
truncate = forward_ask(Packets.Truncate) truncate = forward_ask(Packets.Truncate)
repair = forward_ask(Packets.Repair)
class MasterEventHandler(EventHandler): class MasterEventHandler(EventHandler):
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -191,11 +191,6 @@ class Storage(BaseStorage.BaseStorage, ...@@ -191,11 +191,6 @@ class Storage(BaseStorage.BaseStorage,
# seems used only by FileStorage # seems used only by FileStorage
raise NotImplementedError raise NotImplementedError
def cleanup(self):
# Used in unit tests to remove local database files.
# We have no such thing, so make this method a no-op.
pass
def close(self): def close(self):
# WARNING: This does not handle the case where an app is shared by # WARNING: This does not handle the case where an app is shared by
# several Storage instances, but this is something that only # several Storage instances, but this is something that only
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -544,6 +544,8 @@ class Application(ThreadedApplication): ...@@ -544,6 +544,8 @@ class Application(ThreadedApplication):
# A later serial has already been resolved, skip. # A later serial has already been resolved, skip.
resolved_serial_set.update(conflict_serial_set) resolved_serial_set.update(conflict_serial_set)
continue continue
if self.last_tid < conflict_serial:
self.sync() # possible late invalidation (very rare)
try: try:
new_data = tryToResolveConflict(oid, conflict_serial, new_data = tryToResolveConflict(oid, conflict_serial,
serial, data) serial, data)
......
# #
# Copyright (C) 2011-2016 Nexedi SA # Copyright (C) 2011-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2015-2016 Nexedi SA # Copyright (C) 2015-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -145,7 +145,7 @@ class SocketConnector(object): ...@@ -145,7 +145,7 @@ class SocketConnector(object):
def receive(self, read_buf): def receive(self, read_buf):
try: try:
data = self.socket.recv(4096) data = self.socket.recv(65536)
except socket.error, e: except socket.error, e:
self._error('recv', e) self._error('recv', e)
if data: if data:
...@@ -155,6 +155,7 @@ class SocketConnector(object): ...@@ -155,6 +155,7 @@ class SocketConnector(object):
raise ConnectorException raise ConnectorException
def send(self): def send(self):
# XXX: unefficient for big packets
msg = ''.join(self.queued) msg = ''.join(self.queued)
if msg: if msg:
try: try:
......
# #
# Copyright (C) 2010-2016 Nexedi SA # Copyright (C) 2010-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os, thread import os
from time import time from time import time
from select import epoll, EPOLLIN, EPOLLOUT, EPOLLERR, EPOLLHUP from select import epoll, EPOLLIN, EPOLLOUT, EPOLLERR, EPOLLHUP
from errno import EAGAIN, EEXIST, EINTR, ENOENT from errno import EAGAIN, EEXIST, EINTR, ENOENT
...@@ -35,7 +35,6 @@ class EpollEventManager(object): ...@@ -35,7 +35,6 @@ class EpollEventManager(object):
"""This class manages connections and events based on epoll(5).""" """This class manages connections and events based on epoll(5)."""
_timeout = None _timeout = None
_trigger_exit = False
def __init__(self): def __init__(self):
self.connection_dict = {} self.connection_dict = {}
...@@ -43,6 +42,7 @@ class EpollEventManager(object): ...@@ -43,6 +42,7 @@ class EpollEventManager(object):
self.writer_set = set() self.writer_set = set()
self.epoll = epoll() self.epoll = epoll()
self._pending_processing = [] self._pending_processing = []
self._trigger_list = []
self._trigger_fd, w = os.pipe() self._trigger_fd, w = os.pipe()
os.close(w) os.close(w)
self._trigger_lock = Lock() self._trigger_lock = Lock()
...@@ -231,9 +231,12 @@ class EpollEventManager(object): ...@@ -231,9 +231,12 @@ class EpollEventManager(object):
if fd == self._trigger_fd: if fd == self._trigger_fd:
with self._trigger_lock: with self._trigger_lock:
self.epoll.unregister(fd) self.epoll.unregister(fd)
if self._trigger_exit: action_list = self._trigger_list
del self._trigger_exit try:
thread.exit() while action_list:
action_list.pop(0)()
finally:
del action_list[:]
continue continue
if conn.readable(): if conn.readable():
self._addPendingConnection(conn) self._addPendingConnection(conn)
...@@ -253,9 +256,9 @@ class EpollEventManager(object): ...@@ -253,9 +256,9 @@ class EpollEventManager(object):
def setTimeout(self, *args): def setTimeout(self, *args):
self._timeout, self._on_timeout = args self._timeout, self._on_timeout = args
def wakeup(self, exit=False): def wakeup(self, *actions):
with self._trigger_lock: with self._trigger_lock:
self._trigger_exit |= exit self._trigger_list += actions
try: try:
self.epoll.register(self._trigger_fd) self.epoll.register(self._trigger_fd)
except IOError, e: except IOError, e:
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2015-2016 Nexedi SA # Copyright (C) 2015-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2015-2016 Nexedi SA # Copyright (C) 2015-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
PROTOCOL_VERSION = 8 PROTOCOL_VERSION = 9
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -237,14 +237,10 @@ class Packet(object): ...@@ -237,14 +237,10 @@ class Packet(object):
_id = None _id = None
poll_thread = False poll_thread = False
def __init__(self, *args, **kw): def __init__(self, *args):
assert self._code is not None, "Packet class not registered" assert self._code is not None, "Packet class not registered"
if args or kw: if args:
args = list(args)
buf = StringIO() buf = StringIO()
# load named arguments
for item in self._fmt._items[len(args):]:
args.append(kw.get(item._name))
self._fmt.encode(buf.write, args) self._fmt.encode(buf.write, args)
self._body = buf.getvalue() self._body = buf.getvalue()
else: else:
...@@ -1176,6 +1172,25 @@ class SetClusterState(Packet): ...@@ -1176,6 +1172,25 @@ class SetClusterState(Packet):
_answer = Error _answer = Error
class Repair(Packet):
"""
Ask storage nodes to repair their databases. ctl -> A -> M
"""
_flags = map(PBoolean, ('dry_run',
# 'prune_orphan' (commented because it's the only option for the moment)
))
_fmt = PStruct('repair',
PFUUIDList,
*_flags)
_answer = Error
class RepairOne(Packet):
"""
See Repair. M -> S
"""
_fmt = PStruct('repair', *Repair._flags)
class ClusterInformation(Packet): class ClusterInformation(Packet):
""" """
Notify information about the cluster Notify information about the cluster
...@@ -1685,6 +1700,10 @@ class Packets(dict): ...@@ -1685,6 +1700,10 @@ class Packets(dict):
TweakPartitionTable, ignore_when_closed=False) TweakPartitionTable, ignore_when_closed=False)
SetClusterState = register( SetClusterState = register(
SetClusterState, ignore_when_closed=False) SetClusterState, ignore_when_closed=False)
Repair = register(
Repair)
NotifyRepair = register(
RepairOne)
NotifyClusterInformation = register( NotifyClusterInformation = register(
ClusterInformation) ClusterInformation)
AskClusterState, AnswerClusterState = register( AskClusterState, AnswerClusterState = register(
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -258,30 +258,34 @@ class PartitionTable(object): ...@@ -258,30 +258,34 @@ class PartitionTable(object):
partition on the line (here, line length is 11 to keep the docstring partition on the line (here, line length is 11 to keep the docstring
width under 80 column). width under 80 column).
""" """
node_list = sorted(self.count_dict)
result = ['pt: node %u: %s, %s' % (i, uuid_str(node.getUUID()), result = ['pt: node %u: %s, %s' % (i, uuid_str(node.getUUID()),
protocol.node_state_prefix_dict[node.getState()]) protocol.node_state_prefix_dict[node.getState()])
for i, node in enumerate(node_list)] for i, node in enumerate(sorted(self.count_dict))]
append = result.append append = result.append
line = [] line = []
max_line_len = 20 # XXX: hardcoded number of partitions per line max_line_len = 20 # XXX: hardcoded number of partitions per line
cell_state_dict = protocol.cell_state_prefix_dict
prefix = 0 prefix = 0
prefix_len = int(math.ceil(math.log10(self.np))) prefix_len = int(math.ceil(math.log10(self.np)))
for offset, row in enumerate(self.partition_list): for offset, row in enumerate(self.formatRows()):
if len(line) == max_line_len: if len(line) == max_line_len:
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line))) append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
line = [] line = []
prefix = offset prefix = offset
line.append(row)
if line:
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
return result
def formatRows(self):
node_list = sorted(self.count_dict)
cell_state_dict = protocol.cell_state_prefix_dict
for row in self.partition_list:
if row is None: if row is None:
line.append('X' * len(node_list)) yield 'X' * len(node_list)
else: else:
cell_dict = {x.getNode(): cell_state_dict[x.getState()] cell_dict = {x.getNode(): cell_state_dict[x.getState()]
for x in row} for x in row}
line.append(''.join(cell_dict.get(x, '.') for x in node_list)) yield ''.join(cell_dict.get(x, '.') for x in node_list)
if line:
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
return result
def operational(self): def operational(self):
if not self.filled(): if not self.filled():
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading, weakref import thread, threading, weakref
from . import logging from . import logging
from .app import BaseApplication from .app import BaseApplication
from .connection import ConnectionClosed from .connection import ConnectionClosed
...@@ -69,7 +69,7 @@ class ThreadedApplication(BaseApplication): ...@@ -69,7 +69,7 @@ class ThreadedApplication(BaseApplication):
conn.close() conn.close()
# Stop polling thread # Stop polling thread
logging.debug('Stopping %s', self.poll_thread) logging.debug('Stopping %s', self.poll_thread)
self.em.wakeup(True) self.em.wakeup(thread.exit)
else: else:
super(ThreadedApplication, self).close() super(ThreadedApplication, self).close()
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (C) 2012-2016 Nexedi SA # Copyright (C) 2012-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -64,6 +64,9 @@ class AdministrationHandler(MasterHandler): ...@@ -64,6 +64,9 @@ class AdministrationHandler(MasterHandler):
for node in storage_list: for node in storage_list:
assert node.isPending(), node assert node.isPending(), node
if node.getConnection().isPending(): if node.getConnection().isPending():
# XXX: It's wrong to use ProtocolError here. We must reply
# less aggressively because the admin has no way to
# know that there's still pending activity.
raise ProtocolError('Cannot exit recovery now: node %r is ' raise ProtocolError('Cannot exit recovery now: node %r is '
'entering cluster' % (node, )) 'entering cluster' % (node, ))
app._startup_allowed = True app._startup_allowed = True
...@@ -147,6 +150,19 @@ class AdministrationHandler(MasterHandler): ...@@ -147,6 +150,19 @@ class AdministrationHandler(MasterHandler):
logging.warning('No node added') logging.warning('No node added')
conn.answer(Errors.Ack('No node added')) conn.answer(Errors.Ack('No node added'))
def repair(self, conn, uuid_list, *args):
getByUUID = self.app.nm.getByUUID
node_list = []
for uuid in uuid_list:
node = getByUUID(uuid)
if node is None or not (node.isStorage() and node.isIdentified()):
raise ProtocolError("invalid storage node %s" % uuid_str(uuid))
node_list.append(node)
repair = Packets.NotifyRepair(*args)
for node in node_list:
node.notify(repair)
conn.answer(Errors.Ack(''))
def tweakPartitionTable(self, conn, uuid_list): def tweakPartitionTable(self, conn, uuid_list):
app = self.app app = self.app
state = app.getClusterState() state = app.getClusterState()
......
# #
# Copyright (C) 2012-2016 Nexedi SA # Copyright (C) 2012-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -53,7 +53,10 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -53,7 +53,10 @@ class StorageServiceHandler(BaseServiceHandler):
last_tid = app.pt.getBackupTid(min) last_tid = app.pt.getBackupTid(min)
pending_list = () pending_list = ()
else: else:
last_tid = app.tm.getLastTID() # This can't be app.tm.getLastTID() for imported transactions,
# because outdated cells must at least wait that they're locked
# at source side. For normal transactions, it would not matter.
last_tid = app.getLastTransaction()
pending_list = app.tm.registerForNotification(conn.getUUID()) pending_list = app.tm.registerForNotification(conn.getUUID())
p = Packets.AnswerUnfinishedTransactions(last_tid, pending_list) p = Packets.AnswerUnfinishedTransactions(last_tid, pending_list)
conn.answer(p) conn.answer(p)
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -137,7 +137,7 @@ class RecoveryManager(MasterHandler): ...@@ -137,7 +137,7 @@ class RecoveryManager(MasterHandler):
logging.warning("Waiting for %r to come back." logging.warning("Waiting for %r to come back."
" No other node has version %s of the partition table.", " No other node has version %s of the partition table.",
node, self.target_ptid) node, self.target_ptid)
if node.getState() == new_state: if node is None or node.getState() == new_state:
return return
node.setState(new_state) node.setState(new_state)
# broadcast to all so that admin nodes gets informed # broadcast to all so that admin nodes gets informed
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -36,6 +36,7 @@ action_dict = { ...@@ -36,6 +36,7 @@ action_dict = {
'tweak': 'tweakPartitionTable', 'tweak': 'tweakPartitionTable',
'drop': 'dropNode', 'drop': 'dropNode',
'kill': 'killNode', 'kill': 'killNode',
'prune_orphan': 'pruneOrphan',
'truncate': 'truncate', 'truncate': 'truncate',
} }
...@@ -146,20 +147,20 @@ class TerminalNeoCTL(object): ...@@ -146,20 +147,20 @@ class TerminalNeoCTL(object):
assert len(params) == 0 assert len(params) == 0
return self.neoctl.startCluster() return self.neoctl.startCluster()
def _getStorageList(self, params):
if len(params) == 1 and params[0] == 'all':
node_list = self.neoctl.getNodeList(NodeTypes.STORAGE)
return [node[2] for node in node_list]
return map(self.asNode, params)
def enableStorageList(self, params): def enableStorageList(self, params):
""" """
Enable cluster to make use of pending storages. Enable cluster to make use of pending storages.
Parameters: all Parameters: node [node [...]]
node [node [...]] node: if "all", add all pending storage nodes,
node: if "all", add all pending storage nodes.
otherwise, the list of storage nodes to enable. otherwise, the list of storage nodes to enable.
""" """
if len(params) == 1 and params[0] == 'all': return self.neoctl.enableStorageList(self._getStorageList(params))
node_list = self.neoctl.getNodeList(NodeTypes.STORAGE)
uuid_list = [node[2] for node in node_list]
else:
uuid_list = map(self.asNode, params)
return self.neoctl.enableStorageList(uuid_list)
def tweakPartitionTable(self, params): def tweakPartitionTable(self, params):
""" """
...@@ -189,6 +190,20 @@ class TerminalNeoCTL(object): ...@@ -189,6 +190,20 @@ class TerminalNeoCTL(object):
""" """
return uuid_str(self.neoctl.getPrimary()) return uuid_str(self.neoctl.getPrimary())
def pruneOrphan(self, params):
"""
Fix database by deleting unreferenced raw data
This can take a long time.
Parameters: dry_run node [node [...]]
dry_run: 0 or 1
node: if "all", ask all connected storage nodes to repair,
otherwise, only the given list of storage nodes.
"""
dry_run = "01".index(params.pop(0))
return self.neoctl.repair(self._getStorageList(params), dry_run)
def truncate(self, params): def truncate(self, params):
""" """
Truncate the database at the given tid. Truncate the database at the given tid.
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -172,6 +172,12 @@ class NeoCTL(BaseApplication): ...@@ -172,6 +172,12 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response) raise RuntimeError(response)
return response[1] return response[1]
def repair(self, *args):
response = self.__ask(Packets.Repair(*args))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
raise RuntimeError(response)
return response[2]
def truncate(self, tid): def truncate(self, tid):
response = self.__ask(Packets.Truncate(tid)) response = self.__ask(Packets.Truncate(tid))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK: if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# neoadmin - run an administrator node of NEO # neoadmin - run an administrator node of NEO
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# neoadmin - run an administrator node of NEO # neoadmin - run an administrator node of NEO
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# neolog - read a NEO log # neolog - read a NEO log
# #
# Copyright (C) 2012-2016 Nexedi SA # Copyright (C) 2012-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# neomaster - run a master node of NEO # neomaster - run a master node of NEO
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# neomaster - run a master node of NEO # neomaster - run a master node of NEO
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# neostorage - run a storage node of NEO # neostorage - run a storage node of NEO
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -24,14 +24,14 @@ from neo.lib.config import getServerOptionParser, ConfigurationManager ...@@ -24,14 +24,14 @@ from neo.lib.config import getServerOptionParser, ConfigurationManager
parser = getServerOptionParser() parser = getServerOptionParser()
parser.add_option('-u', '--uuid', help='specify an UUID to use for this ' \ parser.add_option('-u', '--uuid', help='specify an UUID to use for this ' \
'process. Previously assigned UUID takes precedence (ie ' \ 'process. Previously assigned UUID takes precedence (ie ' \
'you should always use -R with this switch)') 'you should always use --reset with this switch)')
parser.add_option('-R', '--reset', action = 'store_true',
help = 'remove an existing database if any')
parser.add_option('-a', '--adapter', help = 'database adapter to use') parser.add_option('-a', '--adapter', help = 'database adapter to use')
parser.add_option('-d', '--database', help = 'database connections string') parser.add_option('-d', '--database', help = 'database connections string')
parser.add_option('-e', '--engine', help = 'database engine') parser.add_option('-e', '--engine', help = 'database engine')
parser.add_option('-w', '--wait', help='seconds to wait for backend to be ' parser.add_option('-w', '--wait', help='seconds to wait for backend to be '
'available, before erroring-out (-1 = infinite)', type='float', default=0) 'available, before erroring-out (-1 = infinite)', type='float', default=0)
parser.add_option('--reset', action='store_true',
help='remove an existing database if any, and exit')
defaults = dict( defaults = dict(
bind = '127.0.0.1', bind = '127.0.0.1',
......
#!/usr/bin/env python #!/usr/bin/env python
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -23,6 +23,7 @@ import os ...@@ -23,6 +23,7 @@ import os
import re import re
from collections import Counter, defaultdict from collections import Counter, defaultdict
from cStringIO import StringIO from cStringIO import StringIO
from fnmatch import fnmatchcase
from unittest.runner import _WritelnDecorator from unittest.runner import _WritelnDecorator
if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]): if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]):
...@@ -32,7 +33,8 @@ if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]): ...@@ -32,7 +33,8 @@ if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]):
coverage.neotestrunner = [] coverage.neotestrunner = []
coverage.start() coverage.start()
from neo.tests import getTempDirectory, __dict__ as neo_tests__dict__ from neo.tests import getTempDirectory, NeoTestBase, Patch, \
__dict__ as neo_tests__dict__
from neo.tests.benchmark import BenchmarkRunner from neo.tests.benchmark import BenchmarkRunner
# list of test modules # list of test modules
...@@ -64,7 +66,6 @@ UNIT_TEST_MODULES = [ ...@@ -64,7 +66,6 @@ UNIT_TEST_MODULES = [
# client application # client application
'neo.tests.client.testClientApp', 'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler', 'neo.tests.client.testMasterHandler',
'neo.tests.client.testStorageHandler',
'neo.tests.client.testConnectionPool', 'neo.tests.client.testConnectionPool',
# light functional tests # light functional tests
'neo.tests.threaded.test', 'neo.tests.threaded.test',
...@@ -114,17 +115,33 @@ class NeoTestRunner(unittest.TextTestResult): ...@@ -114,17 +115,33 @@ class NeoTestRunner(unittest.TextTestResult):
def wasSuccessful(self): def wasSuccessful(self):
return not (self.failures or self.errors or self.unexpectedSuccesses) return not (self.failures or self.errors or self.unexpectedSuccesses)
def run(self, name, modules): def run(self, name, modules, only):
print '\n', name
suite = unittest.TestSuite() suite = unittest.TestSuite()
loader = unittest.defaultTestLoader loader = unittest.TestLoader()
if only:
exclude = only[0] == '!'
test_only = only[exclude + 1:]
only = only[exclude]
if test_only:
def getTestCaseNames(testCaseClass):
tests = loader.__class__.getTestCaseNames(
loader, testCaseClass)
x = testCaseClass.__name__ + '.'
return [t for t in tests
if exclude != any(fnmatchcase(x + t, o)
for o in test_only)]
loader.getTestCaseNames = getTestCaseNames
if not only:
only = '*'
else:
print '\n', name
for test_module in modules: for test_module in modules:
# load prefix if supplied # load prefix if supplied
if isinstance(test_module, tuple): if isinstance(test_module, tuple):
test_module, prefix = test_module test_module, loader.testMethodPrefix = test_module
loader.testMethodPrefix = prefix if only and not (exclude and test_only or
else: exclude != fnmatchcase(test_module, only)):
loader.testMethodPrefix = 'test' continue
try: try:
test_module = __import__(test_module, globals(), locals(), ['*']) test_module = __import__(test_module, globals(), locals(), ['*'])
except ImportError, err: except ImportError, err:
...@@ -135,7 +152,11 @@ class NeoTestRunner(unittest.TextTestResult): ...@@ -135,7 +152,11 @@ class NeoTestRunner(unittest.TextTestResult):
# NOTE it is also possible to run individual tests via `python -m unittest ...` # NOTE it is also possible to run individual tests via `python -m unittest ...`
if 1 or test_module.__name__ == 'neo.tests.functional.testStorage': if 1 or test_module.__name__ == 'neo.tests.functional.testStorage':
suite.addTests(loader.loadTestsFromModule(test_module)) suite.addTests(loader.loadTestsFromModule(test_module))
try:
suite.run(self) suite.run(self)
finally:
# Workaround weird behaviour of Python.
self._previousTestClass = None
def startTest(self, test): def startTest(self, test):
super(NeoTestRunner, self).startTest(test) super(NeoTestRunner, self).startTest(test)
...@@ -203,7 +224,8 @@ class NeoTestRunner(unittest.TextTestResult): ...@@ -203,7 +224,8 @@ class NeoTestRunner(unittest.TextTestResult):
for test in self.unexpectedSuccesses: for test in self.unexpectedSuccesses:
body.write("UNEXPECTED SUCCESS: %s\n" % self.getDescription(test)) body.write("UNEXPECTED SUCCESS: %s\n" % self.getDescription(test))
self.stream = _WritelnDecorator(body) self.stream = _WritelnDecorator(body)
self.printErrors() self.printErrorList('ERROR', self.errors)
self.printErrorList('FAIL', self.failures)
return subject, body.getvalue() return subject, body.getvalue()
class TestRunner(BenchmarkRunner): class TestRunner(BenchmarkRunner):
...@@ -211,6 +233,11 @@ class TestRunner(BenchmarkRunner): ...@@ -211,6 +233,11 @@ class TestRunner(BenchmarkRunner):
def add_options(self, parser): def add_options(self, parser):
parser.add_option('-c', '--coverage', action='store_true', parser.add_option('-c', '--coverage', action='store_true',
help='Enable coverage') help='Enable coverage')
parser.add_option('-C', '--cov-unit', action='store_true',
help='Same as -c but output 1 file per test,'
' in the temporary test directory')
parser.add_option('-l', '--loop', type='int', default=1,
help='Repeat tests several times')
parser.add_option('-f', '--functional', action='store_true', parser.add_option('-f', '--functional', action='store_true',
help='Functional tests') help='Functional tests')
parser.add_option('-u', '--unit', action='store_true', parser.add_option('-u', '--unit', action='store_true',
...@@ -219,7 +246,12 @@ class TestRunner(BenchmarkRunner): ...@@ -219,7 +246,12 @@ class TestRunner(BenchmarkRunner):
help='ZODB test suite running on a NEO') help='ZODB test suite running on a NEO')
parser.add_option('-v', '--verbose', action='store_true', parser.add_option('-v', '--verbose', action='store_true',
help='Verbose output') help='Verbose output')
parser.usage += " [[!] module [test...]]"
parser.format_epilog = lambda _: """ parser.format_epilog = lambda _: """
Positional:
Filter by given module/test. These arguments are shell patterns.
This implies -ufz if none of this option is passed.
Environment Variables: Environment Variables:
NEO_TESTS_ADAPTER Default is SQLite for threaded clusters, NEO_TESTS_ADAPTER Default is SQLite for threaded clusters,
MySQL otherwise. MySQL otherwise.
...@@ -241,27 +273,51 @@ Environment Variables: ...@@ -241,27 +273,51 @@ Environment Variables:
""" % neo_tests__dict__ """ % neo_tests__dict__
def load_options(self, options, args): def load_options(self, options, args):
if not (options.unit or options.functional or options.zodb or args): if options.coverage and options.cov_unit:
sys.exit('-c conflicts with -C')
if not (options.unit or options.functional or options.zodb):
if not args:
sys.exit('Nothing to run, please give one of -f, -u, -z') sys.exit('Nothing to run, please give one of -f, -u, -z')
options.unit = options.functional = options.zodb = True
return dict( return dict(
loop = options.loop,
unit = options.unit, unit = options.unit,
functional = options.functional, functional = options.functional,
zodb = options.zodb, zodb = options.zodb,
verbosity = 2 if options.verbose else 1, verbosity = 2 if options.verbose else 1,
coverage = options.coverage, coverage = options.coverage,
cov_unit = options.cov_unit,
only = args,
) )
def start(self): def start(self):
config = self._config config = self._config
only = config.only
# run requested tests # run requested tests
runner = NeoTestRunner(config.title or 'Neo', config.verbosity) runner = NeoTestRunner(config.title or 'Neo', config.verbosity)
if config.cov_unit:
from coverage import Coverage
cov_dir = runner.temp_directory + '/coverage'
os.mkdir(cov_dir)
@Patch(NeoTestBase)
def setUp(orig, self):
orig(self)
self.__coverage = Coverage('%s/%s' % (cov_dir, self.id()))
self.__coverage.start()
@Patch(NeoTestBase)
def _tearDown(orig, self, success):
self.__coverage.stop()
self.__coverage.save()
del self.__coverage
orig(self, success)
try: try:
for _ in xrange(config.loop):
if config.unit: if config.unit:
runner.run('Unit tests', UNIT_TEST_MODULES) runner.run('Unit tests', UNIT_TEST_MODULES, only)
if config.functional: if config.functional:
runner.run('Functional tests', FUNC_TEST_MODULES) runner.run('Functional tests', FUNC_TEST_MODULES, only)
if config.zodb: if config.zodb:
runner.run('ZODB tests', ZODB_TEST_MODULES) runner.run('ZODB tests', ZODB_TEST_MODULES, only)
except KeyboardInterrupt: except KeyboardInterrupt:
config['mail_to'] = None config['mail_to'] = None
traceback.print_exc() traceback.print_exc()
...@@ -270,7 +326,13 @@ Environment Variables: ...@@ -270,7 +326,13 @@ Environment Variables:
if coverage.neotestrunner: if coverage.neotestrunner:
coverage.combine(coverage.neotestrunner) coverage.combine(coverage.neotestrunner)
coverage.save() coverage.save()
if runner.dots:
print
# build report # build report
if only and not config.mail_to:
runner._buildSummary = lambda *args: (
runner.__class__._buildSummary(runner, *args)[0], '')
self.build_report = str
self._successful = runner.wasSuccessful() self._successful = runner.wasSuccessful()
return runner.buildReport(self.add_status) return runner.buildReport(self.add_status)
......
#!/usr/bin/env python #!/usr/bin/env python
# #
# Copyright (C) 2011-2016 Nexedi SA # Copyright (C) 2011-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2012-2016 Nexedi SA # Copyright (C) 2012-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2014-2016 Nexedi SA # Copyright (C) 2014-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -281,7 +281,6 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -281,7 +281,6 @@ class ImporterDatabaseManager(DatabaseManager):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
super(ImporterDatabaseManager, self).__init__(*args, **kw) super(ImporterDatabaseManager, self).__init__(*args, **kw)
self.db._connect()
implements(self, """_getNextTID checkSerialRange checkTIDRange implements(self, """_getNextTID checkSerialRange checkTIDRange
deleteObject deleteTransaction dropPartitions getLastTID deleteObject deleteTransaction dropPartitions getLastTID
getReplicationObjectList getTIDList nonempty""".split()) getReplicationObjectList getTIDList nonempty""".split())
...@@ -305,10 +304,13 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -305,10 +304,13 @@ class ImporterDatabaseManager(DatabaseManager):
getPartitionTable changePartitionTable getPartitionTable changePartitionTable
getUnfinishedTIDDict dropUnfinishedData abortTransaction getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction storeTransaction lockTransaction unlockTransaction
storeData _pruneData deferCommit storeData getOrphanList _pruneData deferCommit
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
def _connect(self):
pass
def commit(self): def commit(self):
self.db.commit() self.db.commit()
self._last_commit = time.time() self._last_commit = time.time()
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager
from functools import wraps from functools import wraps
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
...@@ -54,6 +56,9 @@ class DatabaseManager(object): ...@@ -54,6 +56,9 @@ class DatabaseManager(object):
ENGINES = () ENGINES = ()
_deferred = 0
_duplicating = _repairing = None
def __init__(self, database, engine=None, wait=0): def __init__(self, database, engine=None, wait=0):
""" """
Initialize the object. Initialize the object.
...@@ -64,22 +69,42 @@ class DatabaseManager(object): ...@@ -64,22 +69,42 @@ class DatabaseManager(object):
% (engine, self.ENGINES)) % (engine, self.ENGINES))
self._engine = engine self._engine = engine
self._wait = wait self._wait = wait
self._deferred = 0
self._parse(database) self._parse(database)
self._connect()
def __getattr__(self, attr): def __getattr__(self, attr):
if attr == "_getPartition": if attr == "_getPartition":
np = self.getNumPartitions() np = self.getNumPartitions()
value = lambda x: x % np value = lambda x: x % np
else: elif self._duplicating is None:
return self.__getattribute__(attr) return self.__getattribute__(attr)
else:
value = getattr(self._duplicating, attr)
setattr(self, attr, value) setattr(self, attr, value)
return value return value
@contextmanager
def _duplicate(self):
cls = self.__class__
db = cls.__new__(cls)
db._duplicating = self
try:
db._connect()
finally:
del db._duplicating
try:
yield db
finally:
db.close()
@abstract @abstract
def _parse(self, database): def _parse(self, database):
"""Called during instantiation, to process database parameter.""" """Called during instantiation, to process database parameter."""
@abstract
def _connect(self):
"""Connect to the database"""
def setup(self, reset=0): def setup(self, reset=0):
"""Set up a database, discarding existing data first if reset is True """Set up a database, discarding existing data first if reset is True
""" """
...@@ -415,6 +440,15 @@ class DatabaseManager(object): ...@@ -415,6 +440,15 @@ class DatabaseManager(object):
is always the case at tpc_vote. is always the case at tpc_vote.
""" """
@abstract
def getOrphanList(self):
"""Return the list of data id that is not referenced by the obj table
This is a repair method, and it's usually expensive.
There was a bug that did not free data of transactions that were
aborted before vote. This method is used to reclaim the wasted space.
"""
@abstract @abstract
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
"""To be overridden by the backend to delete any unreferenced data """To be overridden by the backend to delete any unreferenced data
...@@ -423,6 +457,8 @@ class DatabaseManager(object): ...@@ -423,6 +457,8 @@ class DatabaseManager(object):
- not in self._uncommitted_data - not in self._uncommitted_data
- and not referenced by a fully-committed object (storage should have - and not referenced by a fully-committed object (storage should have
an index or a refcount of all data ids of all objects) an index or a refcount of all data ids of all objects)
The returned value is the number of deleted rows from the data table.
""" """
@abstract @abstract
...@@ -588,6 +624,37 @@ class DatabaseManager(object): ...@@ -588,6 +624,37 @@ class DatabaseManager(object):
self._setTruncateTID(None) self._setTruncateTID(None)
self.commit() self.commit()
def repair(self, weak_app, dry_run):
t = self._repairing
if t and t.is_alive():
logging.error('already repairing')
return
def repair():
l = threading.Lock()
l.acquire()
def finalize():
try:
if data_id_list and not dry_run:
self.commit()
logging.info("repair: deleted %s orphan records",
self._pruneData(data_id_list))
self.commit()
finally:
l.release()
try:
with self._duplicate() as db:
data_id_list = db.getOrphanList()
logging.info("repair: found %s records that may be orphan",
len(data_id_list))
weak_app().em.wakeup(finalize)
l.acquire()
finally:
del self._repairing
logging.info("repair: done")
t = self._repairing = threading.Thread(target=repair)
t.daemon = 1
t.start()
@abstract @abstract
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -56,12 +56,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -56,12 +56,6 @@ class MySQLDatabaseManager(DatabaseManager):
_max_allowed_packet = 32769 * 1024 _max_allowed_packet = 32769 * 1024
def __init__(self, *args, **kw):
super(MySQLDatabaseManager, self).__init__(*args, **kw)
self.conn = None
self._config = {}
self._connect()
def _parse(self, database): def _parse(self, database):
""" Get the database credentials (username, password, database) """ """ Get the database credentials (username, password, database) """
# expected pattern : [user[:password]@]database[(~|.|/)unix_socket] # expected pattern : [user[:password]@]database[(~|.|/)unix_socket]
...@@ -93,6 +87,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -93,6 +87,7 @@ class MySQLDatabaseManager(DatabaseManager):
logging.exception('Connection to MySQL failed, retrying.') logging.exception('Connection to MySQL failed, retrying.')
time.sleep(1) time.sleep(1)
self._active = 0 self._active = 0
self._config = {}
conn = self.conn conn = self.conn
conn.autocommit(False) conn.autocommit(False)
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1)) conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
...@@ -475,6 +470,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -475,6 +470,11 @@ class MySQLDatabaseManager(DatabaseManager):
_structLL = struct.Struct(">LL") _structLL = struct.Struct(">LL")
_unpackLL = _structLL.unpack _unpackLL = _structLL.unpack
def getOrphanList(self):
return [x for x, in self.query(
"SELECT id FROM data LEFT JOIN obj ON (id=data_id)"
" WHERE data_id IS NULL")]
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
if data_id_list: if data_id_list:
...@@ -495,6 +495,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -495,6 +495,8 @@ class MySQLDatabaseManager(DatabaseManager):
if bigid_list: if bigid_list:
q("DELETE FROM bigdata WHERE id IN (%s)" q("DELETE FROM bigdata WHERE id IN (%s)"
% ",".join(map(str, bigid_list))) % ",".join(map(str, bigid_list)))
return len(id_list)
return 0
def _bigData(self, value): def _bigData(self, value):
bigdata_id, length = self._unpackLL(value) bigdata_id, length = self._unpackLL(value)
...@@ -582,11 +584,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -582,11 +584,8 @@ class MySQLDatabaseManager(DatabaseManager):
def abortTransaction(self, ttid): def abortTransaction(self, ttid):
ttid = util.u64(ttid) ttid = util.u64(ttid)
q = self.query q = self.query
sql = " FROM tobj WHERE tid=%s" % ttid q("DELETE FROM tobj WHERE tid=%s" % ttid)
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("DELETE" + sql)
q("DELETE FROM ttrans WHERE ttid=%s" % ttid) q("DELETE FROM ttrans WHERE ttid=%s" % ttid)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid): def deleteTransaction(self, tid):
tid = util.u64(tid) tid = util.u64(tid)
......
# #
# Copyright (C) 2012-2016 Nexedi SA # Copyright (C) 2012-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -69,11 +69,6 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -69,11 +69,6 @@ class SQLiteDatabaseManager(DatabaseManager):
VERSION = 1 VERSION = 1
def __init__(self, *args, **kw):
super(SQLiteDatabaseManager, self).__init__(*args, **kw)
self._config = {}
self._connect()
def _parse(self, database): def _parse(self, database):
self.db = os.path.expanduser(database) self.db = os.path.expanduser(database)
...@@ -83,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -83,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self): def _connect(self):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False) self.conn = sqlite3.connect(self.db, check_same_thread=False)
self._config = {}
def _commit(self): def _commit(self):
retry_if_locked(self.conn.commit) retry_if_locked(self.conn.commit)
...@@ -376,6 +372,11 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -376,6 +372,11 @@ class SQLiteDatabaseManager(DatabaseManager):
packed, buffer(''.join(oid_list)), packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext), u64(ttid))) buffer(user), buffer(desc), buffer(ext), u64(ttid)))
def getOrphanList(self):
return [x for x, in self.query(
"SELECT id FROM data LEFT JOIN obj ON (id=data_id)"
" WHERE data_id IS NULL")]
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
if data_id_list: if data_id_list:
...@@ -385,6 +386,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -385,6 +386,8 @@ class SQLiteDatabaseManager(DatabaseManager):
% ",".join(map(str, data_id_list)))) % ",".join(map(str, data_id_list))))
q("DELETE FROM data WHERE id IN (%s)" q("DELETE FROM data WHERE id IN (%s)"
% ",".join(map(str, data_id_list))) % ",".join(map(str, data_id_list)))
return len(data_id_list)
return 0
def storeData(self, checksum, data, compression, def storeData(self, checksum, data, compression,
_dup=unique_constraint_message("data", "hash", "compression")): _dup=unique_constraint_message("data", "hash", "compression")):
...@@ -439,11 +442,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -439,11 +442,8 @@ class SQLiteDatabaseManager(DatabaseManager):
def abortTransaction(self, ttid): def abortTransaction(self, ttid):
args = util.u64(ttid), args = util.u64(ttid),
q = self.query q = self.query
sql = " FROM tobj WHERE tid=?" q("DELETE FROM tobj WHERE tid=?", args)
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("DELETE" + sql, args)
q("DELETE FROM ttrans WHERE ttid=?", args) q("DELETE FROM ttrans WHERE ttid=?", args)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid): def deleteTransaction(self, tid):
tid = util.u64(tid) tid = util.u64(tid)
......
# #
# Copyright (C) 2010-2016 Nexedi SA # Copyright (C) 2010-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import weakref
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.exception import PrimaryFailure, StoppedOperation from neo.lib.exception import PrimaryFailure, StoppedOperation
...@@ -59,3 +60,7 @@ class BaseMasterHandler(EventHandler): ...@@ -59,3 +60,7 @@ class BaseMasterHandler(EventHandler):
def askFinalTID(self, conn, ttid): def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid))) conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
def notifyRepair(self, conn, *args):
app = self.app
app.dm.repair(weakref.ref(app), *args)
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -77,7 +77,7 @@ class ClientOperationHandler(EventHandler): ...@@ -77,7 +77,7 @@ class ClientOperationHandler(EventHandler):
checksum, data, data_serial, unlock) checksum, data, data_serial, unlock)
except ConflictError, err: except ConflictError, err:
# resolvable or not # resolvable or not
conn.answer(Packets.AnswerStoreObject(1, oid, err.getTID())) conn.answer(Packets.AnswerStoreObject(1, oid, err.tid))
except DelayedError: except DelayedError:
# locked by a previous transaction, retry later # locked by a previous transaction, retry later
# If we are unlocking, we want queueEvent to raise # If we are unlocking, we want queueEvent to raise
...@@ -194,8 +194,7 @@ class ClientOperationHandler(EventHandler): ...@@ -194,8 +194,7 @@ class ClientOperationHandler(EventHandler):
self.app.tm.checkCurrentSerial(ttid, serial, oid) self.app.tm.checkCurrentSerial(ttid, serial, oid)
except ConflictError, err: except ConflictError, err:
# resolvable or not # resolvable or not
conn.answer(Packets.AnswerCheckCurrentSerial(1, oid, conn.answer(Packets.AnswerCheckCurrentSerial(1, oid, err.tid))
err.getTID()))
except DelayedError: except DelayedError:
# locked by a previous transaction, retry later # locked by a previous transaction, retry later
try: try:
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -184,6 +184,10 @@ class StorageOperationHandler(EventHandler): ...@@ -184,6 +184,10 @@ class StorageOperationHandler(EventHandler):
if app.tm.isLockedTid(max_tid): if app.tm.isLockedTid(max_tid):
# Wow, backup cluster is fast. Requested transactions are still in # Wow, backup cluster is fast. Requested transactions are still in
# ttrans/ttobj so wait a little. # ttrans/ttobj so wait a little.
# This can also happen for internal replication, when
# NotifyTransactionFinished(M->S) + AskFetchTransactions(S->S)
# is faster than
# NotifyUnlockInformation(M->S)
app.queueEvent(self.askFetchTransactions, conn, app.queueEvent(self.askFetchTransactions, conn,
(partition, length, min_tid, max_tid, tid_list)) (partition, length, min_tid, max_tid, tid_list))
return return
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -92,6 +92,7 @@ class Replicator(object): ...@@ -92,6 +92,7 @@ class Replicator(object):
def setUnfinishedTIDList(self, max_tid, ttid_list, offset_list): def setUnfinishedTIDList(self, max_tid, ttid_list, offset_list):
"""This is a callback from MasterOperationHandler.""" """This is a callback from MasterOperationHandler."""
assert self.ttid_set.issubset(ttid_list), (self.ttid_set, ttid_list)
if ttid_list: if ttid_list:
self.ttid_set.update(ttid_list) self.ttid_set.update(ttid_list)
max_ttid = max(ttid_list) max_ttid = max(ttid_list)
......
# #
# Copyright (C) 2010-2016 Nexedi SA # Copyright (C) 2010-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -27,10 +27,7 @@ class ConflictError(Exception): ...@@ -27,10 +27,7 @@ class ConflictError(Exception):
def __init__(self, tid): def __init__(self, tid):
Exception.__init__(self) Exception.__init__(self)
self._tid = tid self.tid = tid
def getTID(self):
return self._tid
class DelayedError(Exception): class DelayedError(Exception):
...@@ -47,76 +44,41 @@ class Transaction(object): ...@@ -47,76 +44,41 @@ class Transaction(object):
""" """
Container for a pending transaction Container for a pending transaction
""" """
_tid = None tid = None
has_trans = False has_trans = False
def __init__(self, uuid, ttid): def __init__(self, uuid, ttid):
self._uuid = uuid
self._ttid = ttid
self._object_dict = {}
self._locked = False
self._birth = time() self._birth = time()
self._checked_set = set() self.uuid = uuid
# Consider using lists.
self.store_dict = {}
self.checked_set = set()
def __repr__(self): def __repr__(self):
return "<%s(ttid=%r, tid=%r, uuid=%r, locked=%r, age=%.2fs) at 0x%x>" \ return "<%s(tid=%r, uuid=%r, age=%.2fs) at 0x%x>" \
% (self.__class__.__name__, % (self.__class__.__name__,
dump(self._ttid), dump(self.tid),
dump(self._tid), uuid_str(self.uuid),
uuid_str(self._uuid),
self.isLocked(),
time() - self._birth, time() - self._birth,
id(self)) id(self))
def addCheckedObject(self, oid): def check(self, oid):
assert oid not in self._object_dict, dump(oid) assert oid not in self.store_dict, dump(oid)
self._checked_set.add(oid) assert oid not in self.checked_set, dump(oid)
self.checked_set.add(oid)
def getTTID(self):
return self._ttid
def setTID(self, tid):
assert self._tid is None, dump(self._tid)
assert tid is not None
self._tid = tid
def getTID(self):
return self._tid
def getUUID(self):
return self._uuid
def lock(self):
assert not self._locked
self._locked = True
def isLocked(self): def store(self, oid, data_id, value_serial):
return self._locked
def addObject(self, oid, data_id, value_serial):
""" """
Add an object to the transaction Add an object to the transaction
""" """
assert oid not in self._checked_set, dump(oid) assert oid not in self.checked_set, dump(oid)
self._object_dict[oid] = oid, data_id, value_serial self.store_dict[oid] = oid, data_id, value_serial
def delObject(self, oid): def cancel(self, oid):
try: try:
return self._object_dict.pop(oid)[1] return self.store_dict.pop(oid)[1]
except KeyError: except KeyError:
self._checked_set.remove(oid) self.checked_set.remove(oid)
def getObject(self, oid):
return self._object_dict[oid]
def getObjectList(self):
return self._object_dict.values()
def getOIDList(self):
return self._object_dict.keys()
def getLockedOIDList(self):
return self._object_dict.keys() + list(self._checked_set)
class TransactionManager(object): class TransactionManager(object):
...@@ -145,7 +107,7 @@ class TransactionManager(object): ...@@ -145,7 +107,7 @@ class TransactionManager(object):
Return None if not found. Return None if not found.
""" """
try: try:
return self._transaction_dict[ttid].getObject(oid) return self._transaction_dict[ttid].store_dict[oid]
except KeyError: except KeyError:
return None return None
...@@ -166,7 +128,7 @@ class TransactionManager(object): ...@@ -166,7 +128,7 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
except KeyError: except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid)) raise ProtocolError("unknown ttid %s" % dump(ttid))
object_list = transaction.getObjectList() object_list = transaction.store_dict.itervalues()
if txn_info: if txn_info:
user, desc, ext, oid_list = txn_info user, desc, ext, oid_list = txn_info
txn_info = oid_list, user, desc, ext, False, ttid txn_info = oid_list, user, desc, ext, False, ttid
...@@ -185,21 +147,20 @@ class TransactionManager(object): ...@@ -185,21 +147,20 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
except KeyError: except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid)) raise ProtocolError("unknown ttid %s" % dump(ttid))
# remember that the transaction has been locked assert transaction.tid is None, dump(transaction.tid)
transaction.lock() assert ttid <= tid, (ttid, tid)
transaction.tid = tid
self._load_lock_dict.update( self._load_lock_dict.update(
dict.fromkeys(transaction.getOIDList(), ttid)) dict.fromkeys(transaction.store_dict, ttid))
# commit transaction and remember its definitive TID
if transaction.has_trans: if transaction.has_trans:
self._app.dm.lockTransaction(tid, ttid) self._app.dm.lockTransaction(tid, ttid)
transaction.setTID(tid)
def unlock(self, ttid): def unlock(self, ttid):
""" """
Unlock transaction Unlock transaction
""" """
try: try:
tid = self._transaction_dict[ttid].getTID() tid = self._transaction_dict[ttid].tid
except KeyError: except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid)) raise ProtocolError("unknown ttid %s" % dump(ttid))
logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid)) logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid))
...@@ -210,7 +171,7 @@ class TransactionManager(object): ...@@ -210,7 +171,7 @@ class TransactionManager(object):
def getFinalTID(self, ttid): def getFinalTID(self, ttid):
try: try:
return self._transaction_dict[ttid].getTID() return self._transaction_dict[ttid].tid
except KeyError: except KeyError:
return self._app.dm.getFinalTID(ttid) return self._app.dm.getFinalTID(ttid)
...@@ -233,7 +194,7 @@ class TransactionManager(object): ...@@ -233,7 +194,7 @@ class TransactionManager(object):
# drop the lock it held on this object, and drop object data for # drop the lock it held on this object, and drop object data for
# consistency. # consistency.
del self._store_lock_dict[oid] del self._store_lock_dict[oid]
data_id = self._transaction_dict[ttid].delObject(oid) data_id = self._transaction_dict[ttid].cancel(oid)
if data_id: if data_id:
self._app.dm.pruneData((data_id,)) self._app.dm.pruneData((data_id,))
# Give a chance to pending events to take that lock now. # Give a chance to pending events to take that lock now.
...@@ -245,7 +206,7 @@ class TransactionManager(object): ...@@ -245,7 +206,7 @@ class TransactionManager(object):
elif locking_tid == ttid: elif locking_tid == ttid:
# If previous store was an undo, next store must be based on # If previous store was an undo, next store must be based on
# undo target. # undo target.
previous_serial = self._transaction_dict[ttid].getObject(oid)[2] previous_serial = self._transaction_dict[ttid].store_dict[oid][2]
if previous_serial is None: if previous_serial is None:
# XXX: use some special serial when previous store was not # XXX: use some special serial when previous store was not
# an undo ? Maybe it should just not happen. # an undo ? Maybe it should just not happen.
...@@ -290,7 +251,7 @@ class TransactionManager(object): ...@@ -290,7 +251,7 @@ class TransactionManager(object):
except KeyError: except KeyError:
raise NotRegisteredError raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=True) self.lockObject(ttid, serial, oid, unlock=True)
transaction.addCheckedObject(oid) transaction.check(oid)
def storeObject(self, ttid, serial, oid, compression, checksum, data, def storeObject(self, ttid, serial, oid, compression, checksum, data,
value_serial, unlock=False): value_serial, unlock=False):
...@@ -307,7 +268,7 @@ class TransactionManager(object): ...@@ -307,7 +268,7 @@ class TransactionManager(object):
data_id = None data_id = None
else: else:
data_id = self._app.dm.holdData(checksum, data, compression) data_id = self._app.dm.holdData(checksum, data, compression)
transaction.addObject(oid, data_id, value_serial) transaction.store(oid, data_id, value_serial)
def abort(self, ttid, even_if_locked=False): def abort(self, ttid, even_if_locked=False):
""" """
...@@ -323,24 +284,28 @@ class TransactionManager(object): ...@@ -323,24 +284,28 @@ class TransactionManager(object):
return return
logging.debug('Abort TXN %s', dump(ttid)) logging.debug('Abort TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
has_load_lock = transaction.isLocked() locked = transaction.tid
# if the transaction is locked, ensure we can drop it # if the transaction is locked, ensure we can drop it
if has_load_lock: if locked:
if not even_if_locked: if not even_if_locked:
return return
else: else:
self._app.dm.abortTransaction(ttid) dm = self._app.dm
dm.abortTransaction(ttid)
dm.releaseData([x[1] for x in transaction.store_dict.itervalues()],
True)
# unlock any object # unlock any object
for oid in transaction.getLockedOIDList(): for oid in transaction.store_dict, transaction.checked_set:
if has_load_lock: for oid in oid:
if locked:
lock_ttid = self._load_lock_dict.pop(oid, None) lock_ttid = self._load_lock_dict.pop(oid, None)
assert lock_ttid in (ttid, None), 'Transaction %s tried to ' \ assert lock_ttid in (ttid, None), ('Transaction %s tried'
'release the lock on oid %s, but it was held by %s' % ( ' to release the lock on oid %s, but it was held by %s'
dump(ttid), dump(oid), dump(lock_ttid)) % (dump(ttid), dump(oid), dump(lock_ttid)))
write_locking_tid = self._store_lock_dict.pop(oid) write_locking_tid = self._store_lock_dict.pop(oid)
assert write_locking_tid == ttid, 'Inconsistent locking state: ' \ assert write_locking_tid == ttid, ('Inconsistent locking'
'aborting %s:%s but %s has the lock.' % (dump(ttid), dump(oid), ' state: aborting %s:%s but %s has the lock.'
dump(write_locking_tid)) % (dump(ttid), dump(oid), dump(write_locking_tid)))
# remove the transaction # remove the transaction
del self._transaction_dict[ttid] del self._transaction_dict[ttid]
# some locks were released, some pending locks may now succeed # some locks were released, some pending locks may now succeed
...@@ -352,37 +317,35 @@ class TransactionManager(object): ...@@ -352,37 +317,35 @@ class TransactionManager(object):
""" """
logging.debug('Abort for %s', uuid_str(uuid)) logging.debug('Abort for %s', uuid_str(uuid))
# abort any non-locked transaction of this node # abort any non-locked transaction of this node
for transaction in self._transaction_dict.values(): for ttid, transaction in self._transaction_dict.items():
if transaction.getUUID() == uuid: if transaction.uuid == uuid:
self.abort(transaction.getTTID()) self.abort(ttid)
def isLockedTid(self, tid): def isLockedTid(self, tid):
for t in self._transaction_dict.itervalues(): return any(None is not t.tid <= tid
if t.isLocked() and t.getTID() <= tid: for t in self._transaction_dict.itervalues())
return True
return False
def loadLocked(self, oid): def loadLocked(self, oid):
return oid in self._load_lock_dict return oid in self._load_lock_dict
def log(self): def log(self):
logging.info("Transactions:") logging.info("Transactions:")
for txn in self._transaction_dict.values(): for ttid, txn in self._transaction_dict.iteritems():
logging.info(' %r', txn) logging.info(' %s %r', dump(ttid), txn)
logging.info(' Read locks:') logging.info(' Read locks:')
for oid, ttid in self._load_lock_dict.items(): for oid, ttid in self._load_lock_dict.iteritems():
logging.info(' %r by %r', dump(oid), dump(ttid)) logging.info(' %r by %r', dump(oid), dump(ttid))
logging.info(' Write locks:') logging.info(' Write locks:')
for oid, ttid in self._store_lock_dict.items(): for oid, ttid in self._store_lock_dict.iteritems():
logging.info(' %r by %r', dump(oid), dump(ttid)) logging.info(' %r by %r', dump(oid), dump(ttid))
def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id): def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id):
lock_tid = self.getLockingTID(oid) lock_tid = self.getLockingTID(oid)
if lock_tid is not None: if lock_tid is not None:
transaction = self._transaction_dict[lock_tid] transaction = self._transaction_dict[lock_tid]
if transaction.getObject(oid)[2] == orig_serial: if transaction.store_dict[oid][2] == orig_serial:
if new_serial: if new_serial:
data_id = None data_id = None
else: else:
self._app.dm.holdData(data_id) self._app.dm.holdData(data_id)
transaction.addObject(oid, data_id, new_serial) transaction.store(oid, data_id, new_serial)
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -29,7 +29,7 @@ import MySQLdb ...@@ -29,7 +29,7 @@ import MySQLdb
import transaction import transaction
from functools import wraps from functools import wraps
from mock import Mock from .mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import cached_property from neo.lib.util import cached_property
...@@ -281,18 +281,6 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -281,18 +281,6 @@ class NeoUnitTestBase(NeoTestBase):
def getNextTID(self, ltid=None): def getNextTID(self, ltid=None):
return newTid(ltid) return newTid(ltid)
def getPTID(self, i=None):
""" Return an integer PTID """
if i is None:
return random.randint(1, 2**64)
return i
def getOID(self, i=None):
""" Return a 8-bytes OID """
if i is None:
return os.urandom(8)
return pack('!Q', i)
def getFakeConnector(self, descriptor=None): def getFakeConnector(self, descriptor=None):
return Mock({ return Mock({
'__repr__': 'FakeConnector', '__repr__': 'FakeConnector',
...@@ -321,18 +309,6 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -321,18 +309,6 @@ class NeoUnitTestBase(NeoTestBase):
""" Check if the ProtocolError exception was raised """ """ Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs) self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception was raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception was raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs): def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception was raised """ """ Check if the NotReadyError exception was raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs) self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
...@@ -341,36 +317,19 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -341,36 +317,19 @@ class NeoUnitTestBase(NeoTestBase):
""" Ensure the connection was aborted """ """ Ensure the connection was aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1) self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)
def checkNotAborted(self, conn):
""" Ensure the connection was not aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 0)
def checkClosed(self, conn): def checkClosed(self, conn):
""" Ensure the connection was closed """ """ Ensure the connection was closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 1) self.assertEqual(len(conn.mockGetNamedCalls('close')), 1)
def checkNotClosed(self, conn):
""" Ensure the connection was not closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 0)
def _checkNoPacketSend(self, conn, method_id): def _checkNoPacketSend(self, conn, method_id):
call_list = conn.mockGetNamedCalls(method_id) self.assertEqual([], conn.mockGetNamedCalls(method_id))
self.assertEqual(len(call_list), 0, call_list)
def checkNoPacketSent(self, conn, check_notify=True, check_answer=True, def checkNoPacketSent(self, conn):
check_ask=True):
""" check if no packet were sent """ """ check if no packet were sent """
if check_notify:
self._checkNoPacketSend(conn, 'notify') self._checkNoPacketSend(conn, 'notify')
if check_answer:
self._checkNoPacketSend(conn, 'answer') self._checkNoPacketSend(conn, 'answer')
if check_ask:
self._checkNoPacketSend(conn, 'ask') self._checkNoPacketSend(conn, 'ask')
def checkNoUUIDSet(self, conn):
""" ensure no UUID was set on the connection """
self.assertEqual(len(conn.mockGetNamedCalls('setUUID')), 0)
def checkUUIDSet(self, conn, uuid=None, check_intermediate=True): def checkUUIDSet(self, conn, uuid=None, check_intermediate=True):
""" ensure UUID was set on the connection """ """ ensure UUID was set on the connection """
calls = conn.mockGetNamedCalls('setUUID') calls = conn.mockGetNamedCalls('setUUID')
...@@ -384,151 +343,41 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -384,151 +343,41 @@ class NeoUnitTestBase(NeoTestBase):
# in check(Ask|Answer|Notify)Packet we return the packet so it can be used # in check(Ask|Answer|Notify)Packet we return the packet so it can be used
# in tests if more accurate checks are required # in tests if more accurate checks are required
def checkErrorPacket(self, conn, decode=False): def checkErrorPacket(self, conn):
""" Check if an error packet was answered """ """ Check if an error packet was answered """
calls = conn.mockGetNamedCalls("answer") calls = conn.mockGetNamedCalls("answer")
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), Packets.Error) self.assertEqual(type(packet), Packets.Error)
if decode:
return packet.decode()
return packet return packet
def checkAskPacket(self, conn, packet_type, decode=False): def checkAskPacket(self, conn, packet_type):
""" Check if an ask-packet with the right type is sent """ """ Check if an ask-packet with the right type is sent """
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet return packet
def checkAnswerPacket(self, conn, packet_type, decode=False): def checkAnswerPacket(self, conn, packet_type):
""" Check if an answer-packet with the right type is sent """ """ Check if an answer-packet with the right type is sent """
calls = conn.mockGetNamedCalls('answer') calls = conn.mockGetNamedCalls('answer')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet return packet
def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False): def checkNotifyPacket(self, conn, packet_type, packet_number=0):
""" Check if a notify-packet with the right type is sent """ """ Check if a notify-packet with the right type is sent """
calls = conn.mockGetNamedCalls('notify') calls = conn.mockGetNamedCalls('notify')
packet = calls.pop(packet_number).getParam(0) packet = calls.pop(packet_number).getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet return packet
def checkNotify(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.Notify, **kw)
def checkNotifyNodeInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation, **kw)
def checkSendPartitionTable(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.SendPartitionTable, **kw)
def checkStartOperation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.StartOperation, **kw)
def checkInvalidateObjects(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.InvalidateObjects, **kw)
def checkAbortTransaction(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.AbortTransaction, **kw)
def checkNotifyLastOID(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyLastOID, **kw)
def checkAnswerTransactionFinished(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionFinished, **kw)
def checkAnswerInformationLocked(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerInformationLocked, **kw)
def checkAskLockInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLockInformation, **kw)
def checkNotifyUnlockInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw)
def checkNotifyTransactionFinished(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)
def checkRequestIdentification(self, conn, **kw):
return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)
def checkAskPrimary(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskPrimary)
def checkAskUnfinishedTransactions(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskUnfinishedTransactions)
def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw)
def checkAskStoreObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreObject, **kw)
def checkAskStoreTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreTransaction, **kw)
def checkAskFinishTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskFinishTransaction, **kw)
def checkAskNewTid(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskBeginTransaction, **kw)
def checkAskLastIDs(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLastIDs, **kw)
def checkAcceptIdentification(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification, **kw)
def checkAnswerPrimary(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPrimary, **kw)
def checkAnswerLastIDs(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerLastIDs, **kw)
def checkAnswerUnfinishedTransactions(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerUnfinishedTransactions, **kw)
def checkAnswerObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObject, **kw)
def checkAnswerTransactionInformation(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionInformation, **kw)
def checkAnswerBeginTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction, **kw)
def checkAnswerTids(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)
def checkAnswerTidsFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)
def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
class Patch(object): class Patch(object):
""" """
...@@ -536,7 +385,7 @@ class Patch(object): ...@@ -536,7 +385,7 @@ class Patch(object):
Usage: Usage:
with Patch(someObject, attrToPatch=newValue, [otherAttr=...]) as patch: with Patch(someObject, attrToPatch=newValue) as patch:
[... code that runs with patches ...] [... code that runs with patches ...]
[... code that runs without patch ...] [... code that runs without patch ...]
...@@ -552,10 +401,29 @@ class Patch(object): ...@@ -552,10 +401,29 @@ class Patch(object):
For patched callables, the new one receives the original value as first For patched callables, the new one receives the original value as first
argument. argument.
Alternative usage:
@Patch(someObject)
def funcToPatch(orig, ...):
...
...
funcToPatch.revert()
The decorator applies the patch immediately.
""" """
applied = False applied = False
def __new__(cls, patched, **patch):
if patch:
return object.__new__(cls)
def patch(func):
self = cls(patched, **{func.__name__: func})
self.apply()
return self
return patch
def __init__(self, patched, **patch): def __init__(self, patched, **patch):
(name, patch), = patch.iteritems() (name, patch), = patch.iteritems()
self._patched = patched self._patched = patched
......
from __future__ import print_function
import sys import sys
import smtplib import smtplib
import optparse import optparse
...@@ -34,13 +34,13 @@ class BenchmarkRunner(object): ...@@ -34,13 +34,13 @@ class BenchmarkRunner(object):
parser.add_option('', '--repeat', type='int', default=1) parser.add_option('', '--repeat', type='int', default=1)
self.add_options(parser) self.add_options(parser)
# check common arguments # check common arguments
options, self._args = parser.parse_args() options, args = parser.parse_args()
if bool(options.mail_to) ^ bool(options.mail_from): if bool(options.mail_to) ^ bool(options.mail_from):
sys.exit('Need a sender and recipients to mail report') sys.exit('Need a sender and recipients to mail report')
mail_server = options.mail_server or MAIL_SERVER mail_server = options.mail_server or MAIL_SERVER
# check specifics arguments # check specifics arguments
self._config = AttributeDict() self._config = AttributeDict()
self._config.update(self.load_options(options, self._args)) self._config.update(self.load_options(options, args))
self._config.update( self._config.update(
title = options.title or self.__class__.__name__, title = options.title or self.__class__.__name__,
mail_from = options.mail_from, mail_from = options.mail_from,
...@@ -87,7 +87,7 @@ class BenchmarkRunner(object): ...@@ -87,7 +87,7 @@ class BenchmarkRunner(object):
try: try:
s.sendmail(self._config.mail_from, recipient, mail) s.sendmail(self._config.mail_from, recipient, mail)
except smtplib.SMTPRecipientsRefused: except smtplib.SMTPRecipientsRefused:
print "Mail for %s fails" % recipient print("Mail for %s fails" % recipient)
s.close() s.close()
def run(self): def run(self):
...@@ -95,9 +95,10 @@ class BenchmarkRunner(object): ...@@ -95,9 +95,10 @@ class BenchmarkRunner(object):
report = self.build_report(report) report = self.build_report(report)
if self._config.mail_to: if self._config.mail_to:
self.send_report(subject, report) self.send_report(subject, report)
print subject print(subject)
print if report:
print report print()
print(report, end='')
def was_successful(self): def was_successful(self):
return self._successful return self._successful
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
import unittest import unittest
from mock import Mock, ReturnValues from ..mock import Mock
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError from ZODB.POSException import StorageTransactionError, ConflictError
from .. import NeoUnitTestBase, buildUrlFromString from .. import NeoUnitTestBase, buildUrlFromString
from neo.client.app import Application from neo.client.app import Application
from neo.client.cache import test as testCache from neo.client.cache import test as testCache
...@@ -25,14 +24,6 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError ...@@ -25,14 +24,6 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.lib.protocol import NodeTypes, Packets, Errors, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, Errors, UUID_NAMESPACES
from neo.lib.util import makeChecksum from neo.lib.util import makeChecksum
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
def forget_queue(self, queue, flush_queue=True):
pass
def _getMasterConnection(self): def _getMasterConnection(self):
if self.master_conn is None: if self.master_conn is None:
self.last_tid = None self.last_tid = None
...@@ -43,11 +34,6 @@ def _getMasterConnection(self): ...@@ -43,11 +34,6 @@ def _getMasterConnection(self):
self.master_conn = Mock() self.master_conn = Mock()
return self.master_conn return self.master_conn
def getConnection(kw):
conn = Mock(kw)
conn.lock = threading.RLock()
return conn
def _ask(self, conn, packet, handler=None, **kw): def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None) self.setHandlerData(None)
conn.ask(packet, **kw) conn.ask(packet, **kw)
...@@ -82,6 +68,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -82,6 +68,9 @@ class ClientApplicationTests(NeoUnitTestBase):
# some helpers # some helpers
def checkAskObject(self, conn):
return self.checkAskPacket(conn, Packets.AskObject)
def _begin(self, app, txn, tid): def _begin(self, app, txn, tid):
txn_context = app._txn_container.new(txn) txn_context = app._txn_container.new(txn)
txn_context['ttid'] = tid txn_context['ttid'] = tid
...@@ -95,11 +84,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -95,11 +84,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.dispatcher = Mock({ }) app.dispatcher = Mock({ })
return app return app
def getConnectionPool(self, conn_list):
return Mock({
'iterateForObject': conn_list,
})
def makeOID(self, value=None): def makeOID(self, value=None):
from random import randint from random import randint
if value is None: if value is None:
...@@ -107,24 +91,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -107,24 +91,6 @@ class ClientApplicationTests(NeoUnitTestBase):
return '\00' * 7 + chr(value) return '\00' * 7 + chr(value)
makeTID = makeOID makeTID = makeOID
def getNodeCellConn(self, index=1, address=('127.0.0.1', 10000), uuid=None):
conn = getConnection({
'getAddress': address,
'__repr__': 'connection mock',
'getUUID': uuid,
})
node = Mock({
'__repr__': 'node%s' % index,
'__hash__': index,
'getConnection': conn,
})
cell = Mock({
'getAddress': 'FakeServer',
'getState': 'FakeState',
'getNode': node,
})
return (node, cell, conn)
def makeTransactionObject(self, user='u', description='d', _extension='e'): def makeTransactionObject(self, user='u', description='d', _extension='e'):
class Transaction(object): class Transaction(object):
pass pass
...@@ -134,22 +100,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -134,22 +100,8 @@ class ClientApplicationTests(NeoUnitTestBase):
txn._extension = _extension txn._extension = _extension
return txn return txn
def beginTransaction(self, app, tid):
packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0)
app.master_conn = Mock({ 'fakeReceived': packet, })
txn = self.makeTransactionObject()
app.tpc_begin(txn, tid=tid)
return txn
# common checks # common checks
def checkDispatcherRegisterCalled(self, app, conn):
calls = app.dispatcher.mockGetNamedCalls('register')
#self.assertEqual(len(calls), 1)
#self.assertEqual(calls[0].getParam(0), conn)
#self.assertTrue(isinstance(calls[0].getParam(2), Queue))
testCache = testCache testCache = testCache
def test_load(self): def test_load(self):
...@@ -207,47 +159,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -207,47 +159,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertEqual(len(cache._oid_dict[oid]), 2) self.assertEqual(len(cache._oid_dict[oid]), 2)
def test_tpc_begin(self):
app = self.getApp()
tid = self.makeTID()
txn = Mock()
# first, tid is supplied
self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0)
app.master_conn = Mock({
'getNextId': 1,
'fakeReceived': packet,
})
app.tpc_begin(transaction=txn, tid=tid)
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEqual(txn_context['ttid'], tid)
# next, the transaction already begin -> raise
self.assertRaises(StorageTransactionError, app.tpc_begin,
transaction=txn, tid=None)
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEqual(txn_context['ttid'], tid)
# start a transaction without tid
txn = Mock()
# no connection -> NEOStorageError (wait until connected to primary)
#self.assertRaises(NEOStorageError, app.tpc_begin, transaction=txn, tid=None)
# ask a tid to pmn
packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0)
app.master_conn = Mock({
'getNextId': 1,
'fakeReceived': packet,
})
app.tpc_begin(transaction=txn, tid=None)
self.checkAskNewTid(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn)
# check attributes
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEqual(txn_context['ttid'], tid)
def test_store1(self): def test_store1(self):
app = self.getApp() app = self.getApp()
oid = self.makeOID(11) oid = self.makeOID(11)
...@@ -265,111 +176,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -265,111 +176,6 @@ class ClientApplicationTests(NeoUnitTestBase):
calls = app.pt.mockGetNamedCalls('getCellList') calls = app.pt.mockGetNamedCalls('getCellList')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
def test_store2(self):
app = self.getApp()
oid = self.makeOID(11)
tid = self.makeTID()
txn = self.makeTransactionObject()
# build conflicting state
txn_context = self._begin(app, txn, tid)
packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid)
packet.setId(0)
storage_address = ('127.0.0.1', 10020)
node, cell, conn = self.getNodeCellConn(address=storage_address)
app.pt = Mock()
app.cp = self.getConnectionPool([(node, conn)])
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
data_dict = txn_context['data_dict']
data_dict[oid] = 'BEFORE'
app.store(oid, tid, '', None, txn)
txn_context['queue'].put((conn, packet, {}))
self.assertRaises(ConflictError, app.waitStoreResponses, txn_context,
failing_tryToResolveConflict)
self.assertTrue(oid not in data_dict)
self.assertEqual(txn_context['object_stored_counter_dict'][oid], {})
self.checkAskStoreObject(conn)
def test_store3(self):
app = self.getApp()
uuid = self.getStorageUUID()
oid = self.makeOID(11)
tid = self.makeTID()
txn = self.makeTransactionObject()
# case with no conflict
txn_context = self._begin(app, txn, tid)
packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid)
packet.setId(0)
storage_address = ('127.0.0.1', 10020)
node, cell, conn = self.getNodeCellConn(address=storage_address,
uuid=uuid)
app.cp = self.getConnectionPool([(node, conn)])
app.pt = Mock()
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn)
self.checkAskStoreObject(conn)
txn_context['queue'].put((conn, packet, {}))
app.waitStoreResponses(txn_context, None) # no conflict in this test
self.assertEqual(txn_context['object_stored_counter_dict'][oid],
{tid: {uuid}})
self.assertEqual(txn_context['cache_dict'][oid], 'DATA')
self.assertFalse(oid in txn_context['data_dict'])
self.assertFalse(oid in txn_context['conflict_serial_dict'])
def test_tpc_abort3(self):
""" check that abort is sent to all nodes involved in the transaction """
app = self.getApp()
# three partitions/storages: one per object/transaction
app.num_partitions = num_partitions = 3
app.num_replicas = 0
tid = self.makeTID(num_partitions) # on partition 0
oid1 = self.makeOID(num_partitions + 1) # on partition 1, conflicting
oid2 = self.makeOID(num_partitions + 2) # on partition 2
# storage nodes
address1 = ('127.0.0.1', 10000); uuid1 = self.getMasterUUID()
address2 = ('127.0.0.1', 10001); uuid2 = self.getStorageUUID()
address3 = ('127.0.0.1', 10002); uuid3 = self.getStorageUUID()
app.nm.createMaster(address=address1, uuid=uuid1)
app.nm.createStorage(address=address2, uuid=uuid2)
app.nm.createStorage(address=address3, uuid=uuid3)
# answer packets
packet1 = Packets.AnswerStoreTransaction(tid=tid)
packet2 = Packets.AnswerStoreObject(conflicting=1, oid=oid1, serial=tid)
packet3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid)
[p.setId(i) for p, i in zip([packet1, packet2, packet3], range(3))]
conn1 = getConnection({'__repr__': 'conn1', 'getAddress': address1,
'fakeReceived': packet1, 'getUUID': uuid1})
conn2 = getConnection({'__repr__': 'conn2', 'getAddress': address2,
'fakeReceived': packet2, 'getUUID': uuid2})
conn3 = getConnection({'__repr__': 'conn3', 'getAddress': address3,
'fakeReceived': packet3, 'getUUID': uuid3})
node1 = Mock({'__repr__': 'node1', '__hash__': 1, 'getConnection': conn1})
node2 = Mock({'__repr__': 'node2', '__hash__': 2, 'getConnection': conn2})
node3 = Mock({'__repr__': 'node3', '__hash__': 3, 'getConnection': conn3})
# fake environment
app.cp = Mock({'getConnForCell': ReturnValues(conn2, conn3, conn1)})
app.cp = Mock({
'getConnForNode': ReturnValues(conn2, conn3, conn1),
'iterateForObject': [(node2, conn2), (node3, conn3), (node1, conn1)],
})
app.master_conn = Mock({'__hash__': 0})
txn = self.makeTransactionObject()
txn_context = self._begin(app, txn, tid)
app.dispatcher = Dispatcher()
# conflict occurs on storage 2
app.store(oid1, tid, 'DATA', None, txn)
app.store(oid2, tid, 'DATA', None, txn)
queue = txn_context['queue']
queue.put((conn2, packet2, {}))
queue.put((conn3, packet3, {}))
# vote fails as the conflict is not resolved, nothing is sent to storage 3
self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict)
# abort must be sent to storage 1 and 2
app.tpc_abort(txn)
self.checkAbortTransaction(conn2)
self.checkAbortTransaction(conn3)
def test_undo1(self): def test_undo1(self):
# invalid transaction # invalid transaction
app = self.getApp() app = self.getApp()
...@@ -383,169 +189,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -383,169 +189,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
def _getAppForUndoTests(self, oid0, tid0, tid1, tid2):
app = self.getApp()
cell = Mock({
'getAddress': 'FakeServer',
'getState': 'FakeState',
})
app.pt = Mock({'getCellList': [cell]})
transaction_info = Packets.AnswerTransactionInformation(tid1, '', '',
'', False, (oid0, ))
transaction_info.setId(1)
conn = getConnection({
'getNextId': 1,
'fakeReceived': transaction_info,
'getAddress': ('127.0.0.1', 10020),
})
node = app.nm.createStorage(address=conn.getAddress())
app.cp = Mock({
'iterateForObject': [(node, conn)],
'getConnForCell': conn,
})
app.dispatcher = Dispatcher()
def load(oid, tid=None, before_tid=None):
self.assertEqual(oid, oid0)
return ({tid0: 'dummy', tid2: 'cdummy'}[tid], None, None)
app.load = load
store_marker = []
def _store(txn_context, oid, serial, data, data_serial=None,
unlock=False):
store_marker.append((oid, serial, data, data_serial))
app._store = _store
app.last_tid = self.getNextTID()
return app, conn, store_marker
def test_undoWithResolutionSuccess(self):
"""
Try undoing transaction tid1, which contains object oid.
Object oid previous revision before tid1 is tid0.
Transaction tid2 modified oid (and contains its data).
Undo is accepted, because conflict resolution succeeds.
"""
oid0 = self.makeOID(1)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)})
conn.ask = lambda p, queue=None, **kw: \
isinstance(p, Packets.AskObjectUndoSerial) and \
queue.put((conn, undo_serial, kw))
undo_serial.setId(2)
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
marker.append((oid, conflict_serial, serial, data, committedData))
return 'solved'
# The undo
txn = self.beginTransaction(app, tid=tid3)
app.undo(tid1, txn, tryToResolveConflict)
# Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mconflict_serial, tid2)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, 'dummy')
self.assertEqual(mcommittedData, 'cdummy')
moid, mserial, mdata, mdata_serial = store_marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mserial, tid2)
self.assertEqual(mdata, 'solved')
self.assertEqual(mdata_serial, None)
def test_undoWithResolutionFailure(self):
"""
Try undoing transaction tid1, which contains object oid.
Object oid previous revision before tid1 is tid0.
Transaction tid2 modified oid (and contains its data).
Undo is rejected with a raise, because conflict resolution fails.
"""
oid0 = self.makeOID(1)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)})
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError
# The undo
txn = self.beginTransaction(app, tid=tid3)
self.assertRaises(UndoError, app.undo, tid1, txn, tryToResolveConflict)
# Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mconflict_serial, tid2)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, 'dummy')
self.assertEqual(mcommittedData, 'cdummy')
self.assertEqual(len(store_marker), 0)
# Likewise, but conflict resolver raises a ConflictError.
# Still, exception raised by undo() must be UndoError.
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError
# The undo
self.assertRaises(UndoError, app.undo, tid1, txn, tryToResolveConflict)
# Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mconflict_serial, tid2)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, 'dummy')
self.assertEqual(mcommittedData, 'cdummy')
self.assertEqual(len(store_marker), 0)
def test_undo(self):
"""
Try undoing transaction tid1, which contains object oid.
Object oid previous revision before tid1 is tid0.
Undo is accepted, because tid1 is object's current revision.
"""
oid0 = self.makeOID(1)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
transaction_info = Packets.AnswerTransactionInformation(tid1, '', '',
'', False, (oid0, ))
transaction_info.setId(1)
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid1, tid0, True)})
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
# The undo
txn = self.beginTransaction(app, tid=tid3)
app.undo(tid1, txn, None) # no conflict resolution in this test
# Checking what happened
moid, mserial, mdata, mdata_serial = store_marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, None)
self.assertEqual(mdata_serial, tid0)
def test_connectToPrimaryNode(self): def test_connectToPrimaryNode(self):
# here we have three master nodes : # here we have three master nodes :
# the connection to the first will fail # the connection to the first will fail
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,12 +15,13 @@ ...@@ -15,12 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time, unittest import time, unittest
from mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.client.app import ConnectionPool from neo.client.app import ConnectionPool
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.client import pool from neo.client import pool
from neo.lib.util import p64
class ConnectionPoolTests(NeoUnitTestBase): class ConnectionPoolTests(NeoUnitTestBase):
...@@ -54,7 +55,7 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -54,7 +55,7 @@ class ConnectionPoolTests(NeoUnitTestBase):
def test_iterateForObject_noStorageAvailable(self): def test_iterateForObject_noStorageAvailable(self):
# no node available # no node available
oid = self.getOID(1) oid = p64(1)
app = Mock() app = Mock()
app.pt = Mock({'getCellList': []}) app.pt = Mock({'getCellList': []})
pool = ConnectionPool(app) pool = ConnectionPool(app)
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.client.handlers.master import PrimaryAnswersHandler from neo.client.handlers.master import PrimaryAnswersHandler
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
......
#
# Copyright (C) 2009-2016 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.client.handlers.storage import StorageAnswersHandler
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
class StorageAnswerHandlerTests(NeoUnitTestBase):
def setUp(self):
super(StorageAnswerHandlerTests, self).setUp()
self.app = Mock()
self.handler = StorageAnswersHandler(self.app)
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict):
app = Mock({
'getHandlerData': {
'object_stored_counter_dict': object_stored_counter_dict,
'conflict_serial_dict': conflict_serial_dict,
'resolved_conflict_serial_dict': resolved_conflict_serial_dict,
}
})
return StorageAnswersHandler(app)
def test_answerStoreObject_1(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid = self.getNextTID()
# conflict
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(conflict_serial_dict[oid], {tid})
self.assertEqual(object_stored_counter_dict[oid], {})
self.assertFalse(oid in resolved_conflict_serial_dict)
# object was already accepted by another storage, raise
handler = self._getAnswerStoreObjectHandler({oid: {tid: {1}}}, {}, {})
self.assertRaises(NEOStorageError, handler.answerStoreObject,
conn, 1, oid, tid)
def test_answerStoreObject_2(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid = self.getNextTID()
tid_2 = self.getNextTID()
# resolution-pending conflict
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {oid: {tid}}
resolved_conflict_serial_dict = {}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(conflict_serial_dict[oid], {tid})
self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {})
# object was already accepted by another storage, raise
handler = self._getAnswerStoreObjectHandler({oid: {tid: {1}}},
{oid: {tid}}, {})
self.assertRaises(NEOStorageError, handler.answerStoreObject,
conn, 1, oid, tid)
# detected conflict is different, don't raise
self._getAnswerStoreObjectHandler({oid: {}}, {oid: {tid}}, {},
).answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_3(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid = self.getNextTID()
tid_2 = self.getNextTID()
# already-resolved conflict
# This case happens if a storage is answering a store action for which
# any other storage already answered (with same conflict) and any other
# storage accepted the resolved object.
object_stored_counter_dict = {oid: {tid_2: 1}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {oid: {tid}}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertFalse(oid in conflict_serial_dict)
self.assertEqual(resolved_conflict_serial_dict[oid], {tid})
self.assertEqual(object_stored_counter_dict[oid], {tid_2: 1})
# detected conflict is different, don't raise
self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {},
{oid: {tid}}).answerStoreObject(conn, 1, oid, tid_2)
def test_tidNotFound(self):
conn = self.getFakeConnection()
self.assertRaises(NEOStorageNotFoundError, self.handler.tidNotFound,
conn, 'message')
if __name__ == '__main__':
unittest.main()
# #
# Copyright (C) 2011-2016 Nexedi SA # Copyright (C) 2011-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2014-2016 Nexedi SA # Copyright (C) 2014-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -424,6 +424,18 @@ class NEOCluster(object): ...@@ -424,6 +424,18 @@ class NEOCluster(object):
if not pdb.wait(test, MAX_START_TIME): if not pdb.wait(test, MAX_START_TIME):
raise AssertionError('Timeout when starting cluster') raise AssertionError('Timeout when starting cluster')
def startCluster(self):
# Even if the storage nodes are in the expected state, there may still
# be activity between them and the master, preventing the cluster to
# start.
def start(last_try):
try:
self.neoctl.startCluster()
except (NotReadyException, RuntimeError), e:
return False, e
return True, None
self.expectCondition(start)
def stop(self, clients=True): def stop(self, clients=True):
# Suspend all processes to kill before actually killing them, so that # Suspend all processes to kill before actually killing them, so that
# nodes don't log errors because they get disconnected from other nodes: # nodes don't log errors because they get disconnected from other nodes:
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -267,8 +267,8 @@ class ClientTests(NEOFunctionalTest): ...@@ -267,8 +267,8 @@ class ClientTests(NEOFunctionalTest):
db2, conn2 = self.neo.getZODBConnection() db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage st1, st2 = conn1._storage, conn2._storage
t1, t2 = transaction.Transaction(), transaction.Transaction() t1, t2 = transaction.Transaction(), transaction.Transaction()
t1.user = t2.user = 'user' t1.user = t2.user = u'user'
t1.description = t2.description = 'desc' t1.description = t2.description = u'desc'
oid = st1.new_oid() oid = st1.new_oid()
rev = '\0' * 8 rev = '\0' * 8
data = zodb_pickle(PObject()) data = zodb_pickle(PObject())
...@@ -311,8 +311,8 @@ class ClientTests(NEOFunctionalTest): ...@@ -311,8 +311,8 @@ class ClientTests(NEOFunctionalTest):
db2, conn2 = self.neo.getZODBConnection() db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage st1, st2 = conn1._storage, conn2._storage
t1, t2 = transaction.Transaction(), transaction.Transaction() t1, t2 = transaction.Transaction(), transaction.Transaction()
t1.user = t2.user = 'user' t1.user = t2.user = u'user'
t1.description = t2.description = 'desc' t1.description = t2.description = u'desc'
oid = st1.new_oid() oid = st1.new_oid()
rev = '\0' * 8 rev = '\0' * 8
data = zodb_pickle(PObject()) data = zodb_pickle(PObject())
...@@ -330,8 +330,8 @@ class ClientTests(NEOFunctionalTest): ...@@ -330,8 +330,8 @@ class ClientTests(NEOFunctionalTest):
db3, conn3 = self.neo.getZODBConnection() db3, conn3 = self.neo.getZODBConnection()
st3 = conn3._storage st3 = conn3._storage
t3 = transaction.Transaction() t3 = transaction.Transaction()
t3.user = 'user' t3.user = u'user'
t3.description = 'desc' t3.description = u'desc'
st3.tpc_begin(t3) st3.tpc_begin(t3)
# retrieve the last revision # retrieve the last revision
data, serial = st3.load(oid) data, serial = st3.load(oid)
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -31,7 +31,6 @@ class ClusterTests(NEOFunctionalTest): ...@@ -31,7 +31,6 @@ class ClusterTests(NEOFunctionalTest):
def testClusterStartup(self): def testClusterStartup(self):
neo = self.neo = NEOCluster(['test_neo1', 'test_neo2'], replicas=1, neo = self.neo = NEOCluster(['test_neo1', 'test_neo2'], replicas=1,
temp_dir=self.getTempDirectory()) temp_dir=self.getTempDirectory())
neoctl = neo.neoctl
neo.run() neo.run()
# Runing a new cluster doesn't exit Recovery state. # Runing a new cluster doesn't exit Recovery state.
s1, s2 = neo.getStorageProcessList() s1, s2 = neo.getStorageProcessList()
...@@ -40,7 +39,7 @@ class ClusterTests(NEOFunctionalTest): ...@@ -40,7 +39,7 @@ class ClusterTests(NEOFunctionalTest):
neo.expectClusterRecovering() neo.expectClusterRecovering()
# When allowing cluster to exit Recovery, it reaches Running state and # When allowing cluster to exit Recovery, it reaches Running state and
# all present storage nodes reach running state. # all present storage nodes reach running state.
neoctl.startCluster() neo.startCluster()
neo.expectRunning(s1) neo.expectRunning(s1)
neo.expectRunning(s2) neo.expectRunning(s2)
neo.expectClusterRunning() neo.expectClusterRunning()
...@@ -64,7 +63,7 @@ class ClusterTests(NEOFunctionalTest): ...@@ -64,7 +63,7 @@ class ClusterTests(NEOFunctionalTest):
neo.expectPending(s1) neo.expectPending(s1)
neo.expectUnknown(s2) neo.expectUnknown(s2)
neo.expectClusterRecovering() neo.expectClusterRecovering()
neoctl.startCluster() neo.startCluster()
neo.expectRunning(s1) neo.expectRunning(s1)
neo.expectUnknown(s2) neo.expectUnknown(s2)
neo.expectClusterRunning() neo.expectClusterRunning()
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,14 +14,12 @@ ...@@ -14,14 +14,12 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
import unittest import unittest
import transaction import transaction
from persistent import Persistent from persistent import Persistent
from . import NEOCluster, NEOFunctionalTest from . import NEOCluster, NEOFunctionalTest
from neo.lib.protocol import ClusterStates, NodeStates from neo.lib.protocol import ClusterStates, NodeStates
from ZODB.tests.StorageTestBase import zodb_pickle
class PObject(Persistent): class PObject(Persistent):
...@@ -421,47 +419,5 @@ class StorageTests(NEOFunctionalTest): ...@@ -421,47 +419,5 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0) self.neo.expectOudatedCells(number=0)
def testReplicationBlockedByUnfinished(self):
# start a cluster with 1 of 2 storages and a replica
(started, stopped) = self.__setup(storage_number=2, replicas=1,
pending_number=1, partitions=10)
self.neo.expectRunning(started[0])
self.neo.expectStorageNotKnown(stopped[0])
self.neo.expectOudatedCells(number=0)
self.neo.expectClusterRunning()
self.__populate()
self.neo.expectOudatedCells(number=0)
# start a transaction that will block the end of the replication
db, conn = self.neo.getZODBConnection()
st = conn._storage
t = transaction.Transaction()
t.user = 'user'
t.description = 'desc'
oid = st.new_oid()
rev = '\0' * 8
data = zodb_pickle(PObject(42))
st.tpc_begin(t)
st.store(oid, rev, data, '', t)
# start the outdated storage
stopped[0].start()
self.neo.expectPending(stopped[0])
self.neo.neoctl.enableStorageList([stopped[0].getUUID()])
self.neo.neoctl.tweakPartitionTable()
self.neo.expectRunning(stopped[0])
self.neo.expectClusterRunning()
self.neo.expectAssignedCells(started[0], 10)
self.neo.expectAssignedCells(stopped[0], 10)
# wait a bit, replication must not happen. This hack is required
# because we cannot gather informations directly from the storages
time.sleep(10)
self.neo.expectOudatedCells(number=10)
# finish the transaction, the replication must happen and finish
st.tpc_vote(t)
st.tpc_finish(t)
self.neo.expectOudatedCells(number=0, timeout=10)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.lib.protocol import NodeTypes, NodeStates, Packets from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -62,6 +63,9 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -62,6 +63,9 @@ class MasterClientHandlerTests(NeoUnitTestBase):
) )
return uuid return uuid
def checkAnswerBeginTransaction(self, conn):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction)
# Tests # Tests
def test_07_askBeginTransaction(self): def test_07_askBeginTransaction(self):
tid1 = self.getNextTID() tid1 = self.getNextTID()
...@@ -87,12 +91,12 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -87,12 +91,12 @@ class MasterClientHandlerTests(NeoUnitTestBase):
calls = tm.mockGetNamedCalls('begin') calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_node, None) calls[0].checkArgs(client_node, None)
args = self.checkAnswerBeginTransaction(conn, decode=True) packet = self.checkAnswerBeginTransaction(conn)
self.assertEqual(args, (tid1, )) self.assertEqual(packet.decode(), (tid1, ))
def test_08_askNewOIDs(self): def test_08_askNewOIDs(self):
service = self.service service = self.service
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2 = p64(1), p64(2)
self.app.tm.setLastOID(oid1) self.app.tm.setLastOID(oid1)
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
...@@ -136,7 +140,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -136,7 +140,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.setStorageReady(storage_uuid) self.app.setStorageReady(storage_uuid)
self.assertTrue(self.app.isStorageReady(storage_uuid)) self.assertTrue(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, ttid, (), ()) service.askFinishTransaction(conn, ttid, (), ())
self.checkAskLockInformation(storage_conn) self.checkAskPacket(storage_conn, Packets.AskLockInformation)
self.assertEqual(len(self.app.tm.registerForNotification(storage_uuid)), 1) self.assertEqual(len(self.app.tm.registerForNotification(storage_uuid)), 1)
txn = self.app.tm[ttid] txn = self.app.tm[ttid]
pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0] pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0]
...@@ -170,8 +174,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -170,8 +174,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack, ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0]
decode=True)[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn) self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id) self.assertEqual(self.app.packing[1], peer_id)
...@@ -183,8 +186,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -183,8 +186,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack, status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0]
decode=True)[0]
self.assertFalse(status) self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from neo.lib import protocol from neo.lib import protocol
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.election import ClientElectionHandler, \ from neo.master.handlers.election import ClientElectionHandler, \
ServerElectionHandler ServerElectionHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -48,6 +48,9 @@ class MasterClientElectionTestBase(NeoUnitTestBase): ...@@ -48,6 +48,9 @@ class MasterClientElectionTestBase(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
return (node, conn) return (node, conn)
def checkAcceptIdentification(self, conn):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification)
class MasterClientElectionTests(MasterClientElectionTestBase): class MasterClientElectionTests(MasterClientElectionTestBase):
def setUp(self): def setUp(self):
...@@ -91,7 +94,7 @@ class MasterClientElectionTests(MasterClientElectionTestBase): ...@@ -91,7 +94,7 @@ class MasterClientElectionTests(MasterClientElectionTestBase):
self.election.connectionCompleted(conn) self.election.connectionCompleted(conn)
self._checkUnconnected(node) self._checkUnconnected(node)
self.assertTrue(node.isUnknown()) self.assertTrue(node.isUnknown())
self.checkRequestIdentification(conn) self.checkAskPacket(conn, Packets.RequestIdentification)
def _setNegociating(self, node): def _setNegociating(self, node):
self._checkUnconnected(node) self._checkUnconnected(node)
...@@ -252,9 +255,8 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -252,9 +255,8 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.election.requestIdentification(conn, self.election.requestIdentification(conn,
NodeTypes.MASTER, *args) NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID()) self.checkUUIDSet(conn, node.getUUID())
args = self.checkAcceptIdentification(conn, decode=True)
(node_type, uuid, partitions, replicas, new_uuid, primary_uuid, (node_type, uuid, partitions, replicas, new_uuid, primary_uuid,
master_list) = args master_list) = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node.getUUID(), new_uuid) self.assertEqual(node.getUUID(), new_uuid)
self.assertNotEqual(node.getUUID(), uuid) self.assertNotEqual(node.getUUID(), uuid)
...@@ -290,7 +292,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -290,7 +292,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
None, None,
) )
node_type, uuid, partitions, replicas, _peer_uuid, primary, \ node_type, uuid, partitions, replicas, _peer_uuid, primary, \
master_list = self.checkAcceptIdentification(conn, decode=True) master_list = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node_type, NodeTypes.MASTER) self.assertEqual(node_type, NodeTypes.MASTER)
self.assertEqual(uuid, self.app.uuid) self.assertEqual(uuid, self.app.uuid)
self.assertEqual(partitions, self.app.pt.getPartitions()) self.assertEqual(partitions, self.app.pt.getPartitions())
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import unittest import unittest
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import Packets
from neo.master.app import Application from neo.master.app import Application
class MasterAppTests(NeoUnitTestBase): class MasterAppTests(NeoUnitTestBase):
...@@ -31,6 +32,9 @@ class MasterAppTests(NeoUnitTestBase): ...@@ -31,6 +32,9 @@ class MasterAppTests(NeoUnitTestBase):
self.app.close() self.app.close()
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
def checkNotifyNodeInformation(self, conn):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation)
def test_06_broadcastNodeInformation(self): def test_06_broadcastNodeInformation(self):
# defined some nodes to which data will be send # defined some nodes to which data will be send
master_uuid = self.getMasterUUID() master_uuid = self.getMasterUUID()
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, Packets from neo.lib.protocol import NodeTypes, Packets
from neo.master.handlers.storage import StorageServiceHandler from neo.master.handlers.storage import StorageServiceHandler
...@@ -71,10 +71,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -71,10 +71,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.checkNoPacketSent(client_conn) self.checkNoPacketSent(client_conn)
self.assertEqual(self.app.packing[2], {conn2.getUUID()}) self.assertEqual(self.app.packing[2], {conn2.getUUID()})
self.service.answerPack(conn2, False) self.service.answerPack(conn2, False)
status = self.checkAnswerPacket(client_conn, Packets.AnswerPack, packet = self.checkAnswerPacket(client_conn, Packets.AnswerPack)
decode=True)[0]
# TODO: verify packet peer id # TODO: verify packet peer id
self.assertTrue(status) self.assertTrue(packet.decode()[0])
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from struct import pack from struct import pack
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock, ReturnValues from ..mock import Mock, ReturnValues
from collections import deque
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_TID, INVALID_OID, Packets, LockState from neo.lib.util import p64
from neo.lib.protocol import INVALID_TID, Packets, LockState
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
...@@ -30,11 +30,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -30,11 +30,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getStorageConfiguration(master_number=1) config = self.getStorageConfiguration(master_number=1)
self.app = Application(config) self.app = Application(config)
self.app.transaction_dict = {}
self.app.store_lock_dict = {}
self.app.load_lock_dict = {}
self.app.event_queue = deque()
self.app.event_queue_dict = {}
self.app.tm = Mock({'__contains__': True}) self.app.tm = Mock({'__contains__': True})
# handler # handler
self.operation = ClientOperationHandler(self.app) self.operation = ClientOperationHandler(self.app)
...@@ -59,19 +54,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -59,19 +54,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askTransactionInformation(conn, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_24_askObject1(self):
# delayed response
conn = self._getConnection()
self.app.dm = Mock()
self.app.tm = Mock({'loadLocked': True})
self.app.load_lock_dict[INVALID_OID] = object()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID,
serial=INVALID_TID, tid=INVALID_TID)
self.assertEqual(len(self.app.event_queue), 1)
self.checkNoPacketSent(conn)
self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
def test_25_askTIDs1(self): def test_25_askTIDs1(self):
# invalid offsets => error # invalid offsets => error
app = self.app app = self.app
...@@ -91,7 +73,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -91,7 +73,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, [1, ]) calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn) self.checkAnswerPacket(conn, Packets.AnswerTIDs)
def test_26_askObjectHistory1(self): def test_26_askObjectHistory1(self):
# invalid offsets => error # invalid offsets => error
...@@ -108,8 +90,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -108,8 +90,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
ltid = self.getNextTID() ltid = self.getNextTID()
undone_tid = self.getNextTID() undone_tid = self.getNextTID()
# Keep 2 entries here, so we check findUndoTID is called only once. # Keep 2 entries here, so we check findUndoTID is called only once.
oid_list = [self.getOID(1), self.getOID(2)] oid_list = map(p64, (1, 2))
obj2_data = [] # Marker
self.app.tm = Mock({ self.app.tm = Mock({
'getObjectFromTransaction': None, 'getObjectFromTransaction': None,
}) })
...@@ -134,7 +115,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -134,7 +115,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
conn = self._getConnection() conn = self._getConnection()
self.operation.askHasLock(conn, tid_1, oid) self.operation.askHasLock(conn, tid_1, oid)
p_oid, p_status = self.checkAnswerPacket(conn, p_oid, p_status = self.checkAnswerPacket(conn,
Packets.AnswerHasLock, decode=True) Packets.AnswerHasLock).decode()
self.assertEqual(oid, p_oid) self.assertEqual(oid, p_oid)
self.assertEqual(status, p_status) self.assertEqual(status, p_status)
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from collections import deque from collections import deque
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.lib.protocol import CellStates from neo.lib.protocol import CellStates
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -103,20 +103,19 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -103,20 +103,19 @@ class StorageDBTests(NeoUnitTestBase):
def test_15_PTID(self): def test_15_PTID(self):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1)) self.checkConfigEntry(db.getPTID, db.setPTID, 1)
def test_getPartitionTable(self): def test_getPartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1)
uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID() uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
cell1 = (0, uuid1, CellStates.OUT_OF_DATE) cell1 = (0, uuid1, CellStates.OUT_OF_DATE)
cell2 = (1, uuid1, CellStates.UP_TO_DATE) cell2 = (1, uuid1, CellStates.UP_TO_DATE)
db.changePartitionTable(ptid, [cell1, cell2], 1) db.changePartitionTable(1, [cell1, cell2], 1)
result = db.getPartitionTable() result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2}) self.assertEqual(set(result), {cell1, cell2})
def getOIDs(self, count): def getOIDs(self, count):
return map(self.getOID, xrange(count)) return map(p64, xrange(count))
def getTIDs(self, count): def getTIDs(self, count):
tid_list = [self.getNextTID()] tid_list = [self.getNextTID()]
...@@ -198,7 +197,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -198,7 +197,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_setPartitionTable(self): def test_setPartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1) ptid = 1
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE cell2 = 1, uuid, CellStates.UP_TO_DATE
...@@ -220,7 +219,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -220,7 +219,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_changePartitionTable(self): def test_changePartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1) ptid = 1
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE cell2 = 1, uuid, CellStates.UP_TO_DATE
...@@ -301,7 +300,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -301,7 +300,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_deleteRange(self): def test_deleteRange(self):
np = 4 np = 4
self.setNumPartitions(np) self.setNumPartitions(np)
t1, t2, t3 = map(self.getOID, (1, 2, 3)) t1, t2, t3 = map(p64, (1, 2, 3))
oid_list = self.getOIDs(np * 2) oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3: for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list) txn, objs = self.getTransaction(oid_list)
...@@ -339,7 +338,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -339,7 +338,7 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getObjectHistory(self): def test_getObjectHistory(self):
oid = self.getOID(1) oid = p64(1)
tid1, tid2, tid3 = self.getTIDs(3) tid1, tid2, tid3 = self.getTIDs(3)
txn1, objs1 = self.getTransaction([oid]) txn1, objs1 = self.getTransaction([oid])
txn2, objs2 = self.getTransaction([oid]) txn2, objs2 = self.getTransaction([oid])
...@@ -362,7 +361,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -362,7 +361,7 @@ class StorageDBTests(NeoUnitTestBase):
def _storeTransactions(self, count): def _storeTransactions(self, count):
# use OID generator to know result of tid % N # use OID generator to know result of tid % N
tid_list = self.getOIDs(count) tid_list = self.getOIDs(count)
oid = self.getOID(1) oid = p64(1)
for tid in tid_list: for tid in tid_list:
txn, objs = self.getTransaction([oid]) txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn, False) self.db.storeTransaction(tid, objs, txn, False)
...@@ -446,7 +445,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -446,7 +445,7 @@ class StorageDBTests(NeoUnitTestBase):
tid3 = self.getNextTID() tid3 = self.getNextTID()
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid5 = self.getNextTID() tid5 = self.getNextTID()
oid1 = self.getOID(1) oid1 = p64(1)
foo = db.holdData("3" * 20, 'foo', 0) foo = db.holdData("3" * 20, 'foo', 0)
bar = db.holdData("4" * 20, 'bar', 0) bar = db.holdData("4" * 20, 'bar', 0)
db.releaseData((foo, bar)) db.releaseData((foo, bar))
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import unittest import unittest
from MySQLdb import OperationalError from MySQLdb import OperationalError
from mock import Mock from ..mock import Mock
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
from neo.lib.util import p64 from neo.lib.util import p64
from .. import DB_PREFIX, DB_SOCKET, DB_USER from .. import DB_PREFIX, DB_SOCKET, DB_USER
...@@ -96,9 +96,11 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -96,9 +96,11 @@ class StorageMySQLdbTests(StorageDBTests):
assert len(x) + EXTRA == self.db._max_allowed_packet assert len(x) + EXTRA == self.db._max_allowed_packet
self.assertRaises(DatabaseFailure, self.db.query, x + ' ') self.assertRaises(DatabaseFailure, self.db.query, x + ' ')
self.db.query(x) self.db.query(x)
# Reconnection cleared the cache of the config table,
# so fill it again with required values before we patch query().
self.db.getNumPartitions()
# Check MySQLDatabaseManager._max_allowed_packet # Check MySQLDatabaseManager._max_allowed_packet
query_list = [] query_list = []
query = self.db.query
self.db.query = lambda query: query_list.append(EXTRA + len(query)) self.db.query = lambda query: query_list.append(EXTRA + len(query))
self.assertEqual(2, max(len(self.db.escape(chr(x))) self.assertEqual(2, max(len(self.db.escape(chr(x)))
for x in xrange(256))) for x in xrange(256)))
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2010-2016 Nexedi SA # Copyright (C) 2010-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,21 +15,12 @@ ...@@ -15,21 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.transactions import Transaction, TransactionManager from neo.lib.util import p64
from neo.storage.transactions import TransactionManager
class TransactionTests(NeoUnitTestBase):
def testLock(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
self.assertFalse(txn.isLocked())
txn.lock()
self.assertTrue(txn.isLocked())
# disallow lock more than once
self.assertRaises(AssertionError, txn.lock)
class TransactionManagerTests(NeoUnitTestBase): class TransactionManagerTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
...@@ -46,7 +37,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -46,7 +37,7 @@ class TransactionManagerTests(NeoUnitTestBase):
def test_updateObjectDataForPack(self): def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID() ram_serial = self.getNextTID()
oid = self.getOID(1) oid = p64(1)
orig_serial = self.getNextTID() orig_serial = self.getNextTID()
uuid = self.getClientUUID() uuid = self.getClientUUID()
locking_serial = self.getNextTID() locking_serial = self.getNextTID()
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes, Packets
class BootstrapManagerTests(NeoUnitTestBase): class BootstrapManagerTests(NeoUnitTestBase):
...@@ -46,7 +46,7 @@ class BootstrapManagerTests(NeoUnitTestBase): ...@@ -46,7 +46,7 @@ class BootstrapManagerTests(NeoUnitTestBase):
conn = self.getFakeConnection(address=address) conn = self.getFakeConnection(address=address)
self.bootstrap.current = self.app.nm.createMaster(address=address) self.bootstrap.current = self.app.nm.createMaster(address=address)
self.bootstrap.connectionCompleted(conn) self.bootstrap.connectionCompleted(conn)
self.checkRequestIdentification(conn) self.checkAskPacket(conn, Packets.RequestIdentification)
def testHandleNotReady(self): def testHandleNotReady(self):
# the primary is not ready # the primary is not ready
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from time import time from time import time
from mock import Mock from .mock import Mock
from neo.lib import connection, logging from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ClientConnection, \ from neo.lib.connection import BaseConnection, ClientConnection, \
MTClientConnection, CRITICAL_TIMEOUT MTClientConnection, CRITICAL_TIMEOUT
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock from .mock import Mock
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import PacketMalformedError, UnexpectedPacketError, \ from neo.lib.protocol import PacketMalformedError, UnexpectedPacketError, \
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import shutil import shutil
import unittest import unittest
from mock import Mock from .mock import Mock
from neo.lib.protocol import NodeTypes, NodeStates from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.node import Node, MasterDB from neo.lib.node import Node, MasterDB
from . import NeoUnitTestBase, getTempDirectory from . import NeoUnitTestBase, getTempDirectory
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2006-2016 Nexedi SA # Copyright (C) 2006-2017 Nexedi SA
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2011-2016 Nexedi SA # Copyright (C) 2011-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -22,10 +22,9 @@ from collections import deque ...@@ -22,10 +22,9 @@ from collections import deque
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from contextlib import contextmanager from contextlib import contextmanager
from itertools import count from itertools import count
from functools import wraps from functools import partial, wraps
from thread import get_ident
from zlib import decompress from zlib import decompress
from mock import Mock from ..mock import Mock
import transaction, ZODB import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app import neo.client.app, neo.neoctl.app
...@@ -36,13 +35,14 @@ from neo.lib.connection import BaseConnection, \ ...@@ -36,13 +35,14 @@ from neo.lib.connection import BaseConnection, \
from neo.lib.connector import SocketConnector, ConnectorException from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0 BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]) LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
TIC_LOOP = xrange(1000)
# T1 T2 # T1 T2
...@@ -273,7 +273,7 @@ class TestSerialized(Serialized): # NOTE used only in .NeoCTL ...@@ -273,7 +273,7 @@ class TestSerialized(Serialized): # NOTE used only in .NeoCTL
def poll(self, timeout): def poll(self, timeout):
if timeout: if timeout:
for x in xrange(1000): for x in TIC_LOOP:
r = self._epoll.poll(0) r = self._epoll.poll(0)
if r: if r:
return r return r
...@@ -339,6 +339,7 @@ class ServerNode(Node): ...@@ -339,6 +339,7 @@ class ServerNode(Node):
master_nodes = kw.get('master_nodes', cluster.master_nodes) master_nodes = kw.get('master_nodes', cluster.master_nodes)
name = kw.get('name', cluster.name) name = kw.get('name', cluster.name)
port = address[1] port = address[1]
if address is not BIND:
self._node_list[port] = weakref.proxy(self) self._node_list[port] = weakref.proxy(self)
self._init_args = init_args = kw.copy() self._init_args = init_args = kw.copy()
init_args['cluster'] = cluster init_args['cluster'] = cluster
...@@ -386,7 +387,7 @@ class ServerNode(Node): ...@@ -386,7 +387,7 @@ class ServerNode(Node):
raise ConnectorException raise ConnectorException
def stop(self): def stop(self):
self.em.wakeup(True) self.em.wakeup(thread.exit)
class AdminApplication(ServerNode, neo.admin.app.Application): class AdminApplication(ServerNode, neo.admin.app.Application):
pass pass
...@@ -484,7 +485,7 @@ class ConnectionFilter(object): ...@@ -484,7 +485,7 @@ class ConnectionFilter(object):
filtered_count = 0 filtered_count = 0
filter_list = [] filter_list = []
filter_queue = weakref.WeakKeyDictionary() filter_queue = weakref.WeakKeyDictionary() # XXX: see the end of __new__
lock = threading.RLock() lock = threading.RLock()
_addPacket = Connection._addPacket _addPacket = Connection._addPacket
...@@ -500,13 +501,16 @@ class ConnectionFilter(object): ...@@ -500,13 +501,16 @@ class ConnectionFilter(object):
queue = cls.filter_queue[conn] queue = cls.filter_queue[conn]
except KeyError: except KeyError:
for self in cls.filter_list: for self in cls.filter_list:
if self(conn, packet): if self._test(conn, packet):
self.filtered_count += 1 self.filtered_count += 1
break break
else: else:
return cls._addPacket(conn, packet) return cls._addPacket(conn, packet)
cls.filter_queue[conn] = queue = deque() cls.filter_queue[conn] = queue = deque()
p = packet.__new__(packet.__class__) p = packet.__class__
logging.debug("queued %s#0x%04x for %s",
p.__name__, packet.getId(), conn)
p = packet.__new__(p)
p.__dict__.update(packet.__dict__) p.__dict__.update(packet.__dict__)
queue.append(p) queue.append(p)
Connection._addPacket = _addPacket Connection._addPacket = _addPacket
...@@ -517,10 +521,13 @@ class ConnectionFilter(object): ...@@ -517,10 +521,13 @@ class ConnectionFilter(object):
del cls.filter_list[-1:] del cls.filter_list[-1:]
if not cls.filter_list: if not cls.filter_list:
Connection._addPacket = cls._addPacket.im_func Connection._addPacket = cls._addPacket.im_func
# Retry even in case of exception, at least to avoid leaks in
# filter_queue. Sometimes, WeakKeyDictionary only does the job
# only an explicit call to gc.collect.
with cls.lock: with cls.lock:
cls._retry() cls._retry()
def __call__(self, conn, packet): def _test(self, conn, packet):
if not self.conn_list or conn in self.conn_list: if not self.conn_list or conn in self.conn_list:
for filter in self.filter_dict: for filter in self.filter_dict:
if filter(conn, packet): if filter(conn, packet):
...@@ -533,13 +540,16 @@ class ConnectionFilter(object): ...@@ -533,13 +540,16 @@ class ConnectionFilter(object):
while queue: while queue:
packet = queue.popleft() packet = queue.popleft()
for self in cls.filter_list: for self in cls.filter_list:
if self(conn, packet): if self._test(conn, packet):
queue.appendleft(packet) queue.appendleft(packet)
break break
else: else:
if conn.isClosed(): if conn.isClosed():
return return
cls._addPacket(conn, packet) # Use the thread that created the packet to reinject it,
# to avoid a race condition on Connector.queued.
conn.em.wakeup(lambda conn=conn, packet=packet:
conn.isClosed() or cls._addPacket(conn, packet))
continue continue
break break
else: else:
...@@ -566,6 +576,22 @@ class ConnectionFilter(object): ...@@ -566,6 +576,22 @@ class ConnectionFilter(object):
def __contains__(self, filter): def __contains__(self, filter):
return filter in self.filter_dict return filter in self.filter_dict
def byPacket(self, packet_type, *args):
patches = []
other = []
for x in args:
(patches if isinstance(x, Patch) else other).append(x)
def delay(conn, packet):
return isinstance(packet, packet_type) and False not in (
callback(conn) for callback in other)
self.add(delay, *patches)
return delay
def __getattr__(self, attr):
if attr.startswith('delay'):
return partial(self.byPacket, getattr(Packets, attr[5:]))
return self.__getattribute__(attr)
class NEOCluster(object): class NEOCluster(object):
SSL = None SSL = None
...@@ -576,9 +602,11 @@ class NEOCluster(object): ...@@ -576,9 +602,11 @@ class NEOCluster(object):
def _lock(blocking=True): def _lock(blocking=True):
if blocking: if blocking:
logging.info('<SimpleQueue>._lock.acquire()') logging.info('<SimpleQueue>._lock.acquire()')
while not lock(False): for i in TIC_LOOP:
Serialized.tic(step=1, quiet=True) if lock(False):
return True return True
Serialized.tic(step=1, quiet=True)
raise Exception("tic is looping forever")
return lock(False) return lock(False)
self._lock = _lock self._lock = _lock
_patches = ( _patches = (
...@@ -619,6 +647,8 @@ class NEOCluster(object): ...@@ -619,6 +647,8 @@ class NEOCluster(object):
patch.revert() patch.revert()
Serialized.stop() Serialized.stop()
started = False
def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None, def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
storage_count=None, db_list=None, clear_databases=True, storage_count=None, db_list=None, clear_databases=True,
...@@ -627,6 +657,7 @@ class NEOCluster(object): ...@@ -627,6 +657,7 @@ class NEOCluster(object):
self.name = 'neo_%s' % self._allocate('name', self.name = 'neo_%s' % self._allocate('name',
lambda: random.randint(0, 100)) lambda: random.randint(0, 100))
self.compress = compress self.compress = compress
self.num_partitions = partitions
master_list = [MasterApplication.newAddress() master_list = [MasterApplication.newAddress()
for _ in xrange(master_count)] for _ in xrange(master_count)]
self.master_nodes = ' '.join('%s:%s' % x for x in master_list) self.master_nodes = ' '.join('%s:%s' % x for x in master_list)
...@@ -670,7 +701,6 @@ class NEOCluster(object): ...@@ -670,7 +701,6 @@ class NEOCluster(object):
self.storage_list = [StorageApplication(getDatabase=db % x, **kw) self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
for x in db_list] for x in db_list]
self.admin_list = [AdminApplication(**kw)] self.admin_list = [AdminApplication(**kw)]
self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
def __repr__(self): def __repr__(self):
return "<%s(%s) at 0x%x>" % (self.__class__.__name__, return "<%s(%s) at 0x%x>" % (self.__class__.__name__,
...@@ -720,18 +750,16 @@ class NEOCluster(object): ...@@ -720,18 +750,16 @@ class NEOCluster(object):
return master return master
### ###
def reset(self, clear_database=False): def __enter__(self):
for node_type in 'master', 'storage', 'admin': return self
kw = {}
if node_type == 'storage': def __exit__(self, t, v, tb):
kw['clear_database'] = clear_database self.stop(None)
for node in getattr(self, node_type + '_list'):
node.resetNode(**kw)
self.neoctl.close()
self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
def start(self, storage_list=None, fast_startup=False): def start(self, storage_list=None, fast_startup=False):
self.started = True
self._patch() self._patch()
self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
for node_type in 'master', 'admin': for node_type in 'master', 'admin':
for node in getattr(self, node_type + '_list'): for node in getattr(self, node_type + '_list'):
node.start() node.start()
...@@ -750,13 +778,63 @@ class NEOCluster(object): ...@@ -750,13 +778,63 @@ class NEOCluster(object):
assert state in (ClusterStates.RUNNING, ClusterStates.BACKINGUP), state assert state in (ClusterStates.RUNNING, ClusterStates.BACKINGUP), state
self.enableStorageList(storage_list) self.enableStorageList(storage_list)
def newClient(self): def stop(self, clear_database=False, __print_exc=traceback.print_exc):
if self.started:
del self.started
logging.debug("stopping %s", self)
client = self.__dict__.get("client")
client is None or self.__dict__.pop("db", client).close()
node_list = self.admin_list + self.storage_list + self.master_list
for node in node_list:
node.stop()
try:
node_list.append(client.poll_thread)
except AttributeError: # client is None or thread is already stopped
pass
self.join(node_list)
self.neoctl.close()
del self.neoctl
logging.debug("stopped %s", self)
self._unpatch()
if clear_database is None:
try:
for node_type in 'admin', 'storage', 'master':
for node in getattr(self, node_type + '_list'):
node.close()
except:
__print_exc()
raise
else:
for node_type in 'master', 'storage', 'admin':
kw = {}
if node_type == 'storage':
kw['clear_database'] = clear_database
for node in getattr(self, node_type + '_list'):
node.resetNode(**kw)
def _newClient(self):
return ClientApplication(name=self.name, master_nodes=self.master_nodes, return ClientApplication(name=self.name, master_nodes=self.master_nodes,
compress=self.compress, ssl=self.SSL) compress=self.compress, ssl=self.SSL)
@contextmanager
def newClient(self, with_db=False):
x = self._newClient()
try:
t = x.poll_thread
closed = []
if with_db:
x = ZODB.DB(storage=self.getZODBStorage(client=x))
else:
# XXX: Do nothing if finally if the caller already closed it.
x.close = lambda: closed.append(x.__class__.close(x))
yield x
finally:
closed or x.close()
self.join((t,))
@cached_property @cached_property
def client(self): def client(self):
client = self.newClient() client = self._newClient()
# Make sure client won't be reused after it was closed. # Make sure client won't be reused after it was closed.
def close(): def close():
client = self.client client = self.client
...@@ -794,21 +872,6 @@ class NEOCluster(object): ...@@ -794,21 +872,6 @@ class NEOCluster(object):
Serialized.tic() Serialized.tic()
thread_list = [t for t in thread_list if t.is_alive()] thread_list = [t for t in thread_list if t.is_alive()]
def stop(self):
logging.debug("stopping %s", self)
client = self.__dict__.get("client")
client is None or self.__dict__.pop("db", client).close()
node_list = self.admin_list + self.storage_list + self.master_list
for node in node_list:
node.stop()
try:
node_list.append(client.poll_thread)
except AttributeError: # client is None or thread is already stopped
pass
self.join(node_list)
logging.debug("stopped %s", self)
self._unpatch()
def getNodeState(self, node): def getNodeState(self, node):
uuid = node.uuid uuid = node.uuid
for node in self.neoctl.getNodeList(node.node_type): for node in self.neoctl.getNodeList(node.node_type):
...@@ -850,19 +913,9 @@ class NEOCluster(object): ...@@ -850,19 +913,9 @@ class NEOCluster(object):
for o in oid_list: for o in oid_list:
tid_dict[o] = i tid_dict[o] = i
def getTransaction(self): def getTransaction(self, db=None):
txn = transaction.TransactionManager() txn = transaction.TransactionManager()
return txn, self.db.open(transaction_manager=txn) return txn, (self.db if db is None else db).open(txn)
def __del__(self, __print_exc=traceback.print_exc):
try:
self.neoctl.close()
for node_type in 'admin', 'storage', 'master':
for node in getattr(self, node_type + '_list'):
node.close()
except:
__print_exc()
raise
def extraCellSortKey(self, key): # XXX unused? def extraCellSortKey(self, key): # XXX unused?
return Patch(self.client.cp, getCellSortKey=lambda orig, cell: return Patch(self.client.cp, getCellSortKey=lambda orig, cell:
...@@ -897,9 +950,15 @@ class NEOCluster(object): ...@@ -897,9 +950,15 @@ class NEOCluster(object):
class NEOThreadedTest(NeoTestBase): class NEOThreadedTest(NeoTestBase):
__run_count = {}
def setupLog(self): def setupLog(self):
log_file = os.path.join(getTempDirectory(), self.id() + '.log') test_id = self.id()
logging.setup(log_file) i = self.__run_count.get(test_id, 0)
self.__run_count[test_id] = 1 + i
if i:
test_id += '-%s' % i
logging.setup(os.path.join(getTempDirectory(), test_id + '.log'))
return LoggerThreadName() return LoggerThreadName()
def _tearDown(self, success): def _tearDown(self, success):
...@@ -912,20 +971,17 @@ class NEOThreadedTest(NeoTestBase): ...@@ -912,20 +971,17 @@ class NEOThreadedTest(NeoTestBase):
tic = Serialized.tic tic = Serialized.tic
@contextmanager
def getLoopbackConnection(self): def getLoopbackConnection(self):
app = MasterApplication(getSSL=NEOCluster.SSL, app = MasterApplication(address=BIND,
getReplicas=0, getPartitions=1) getSSL=NEOCluster.SSL, getReplicas=0, getPartitions=1)
try:
handler = EventHandler(app) handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server) app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(), yield ClientConnection(app, handler, app.nm.createMaster(
uuid=app.uuid) address=app.listening_conn.getAddress(), uuid=app.uuid))
conn = ClientConnection.__new__(ClientConnection) finally:
def reset(): app.close()
conn.__dict__.clear()
conn.__init__(app, handler, node)
conn.reset = reset
reset()
return conn
def getUnpickler(self, conn): # XXX not used? def getUnpickler(self, conn): # XXX not used?
reader = conn._reader reader = conn._reader
...@@ -963,6 +1019,9 @@ class NEOThreadedTest(NeoTestBase): ...@@ -963,6 +1019,9 @@ class NEOThreadedTest(NeoTestBase):
with Patch(client, _getFinalTID=lambda *_: None): with Patch(client, _getFinalTID=lambda *_: None):
self.assertRaises(ConnectionClosed, txn.commit) self.assertRaises(ConnectionClosed, txn.commit)
def assertPartitionTable(self, cluster, stats):
self.assertEqual(stats, '|'.join(cluster.admin.pt.formatRows()))
def predictable_random(seed=None): def predictable_random(seed=None):
# Because we have 2 running threads when client works, we can't # Because we have 2 running threads when client works, we can't
...@@ -984,3 +1043,13 @@ def predictable_random(seed=None): ...@@ -984,3 +1043,13 @@ def predictable_random(seed=None):
= random = random
return wraps(wrapped)(wrapper) return wraps(wrapped)(wrapper)
return decorator return decorator
def with_cluster(start_cluster=True, **cluster_kw):
def decorator(wrapped):
def wrapper(self, *args, **kw):
with NEOCluster(**cluster_kw) as cluster:
if start_cluster:
cluster.start()
return wrapped(self, cluster, *args, **kw)
return wraps(wrapped)(wrapper)
return decorator
# #
# Copyright (C) 2011-2016 Nexedi SA # Copyright (C) 2011-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import os import os
import sys import sys
import threading import threading
import time
import transaction import transaction
import unittest import unittest
from thread import get_ident from thread import get_ident
...@@ -32,7 +33,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation ...@@ -32,7 +33,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID ZERO_OID, ZERO_TID
from .. import expectedFailure, Patch from .. import expectedFailure, Patch
from . import LockLock, NEOCluster, NEOThreadedTest from . import LockLock, NEOThreadedTest, with_cluster
from neo.lib.util import add64, makeChecksum, p64, u64 from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
...@@ -53,10 +54,9 @@ class PCounterWithResolution(PCounter): ...@@ -53,10 +54,9 @@ class PCounterWithResolution(PCounter):
class Test(NEOThreadedTest): class Test(NEOThreadedTest):
def testBasicStore(self): @with_cluster()
cluster = NEOCluster() def testBasicStore(self, cluster):
try: if 1:
cluster.start()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
data_info = {} data_info = {}
compressible = 'x' * 20 compressible = 'x' * 20
...@@ -106,13 +106,10 @@ class Test(NEOThreadedTest): ...@@ -106,13 +106,10 @@ class Test(NEOThreadedTest):
if big: if big:
self.assertFalse(cluster.storage.sqlCount('bigdata')) self.assertFalse(cluster.storage.sqlCount('bigdata'))
self.assertFalse(cluster.storage.sqlCount('data')) self.assertFalse(cluster.storage.sqlCount('data'))
finally:
cluster.stop()
def testDeleteObject(self): @with_cluster()
cluster = NEOCluster() def testDeleteObject(self, cluster):
try: if 1:
cluster.start()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
for clear_cache in 0, 1: for clear_cache in 0, 1:
for tst in 'a.', 'bcd.': for tst in 'a.', 'bcd.':
...@@ -131,13 +128,10 @@ class Test(NEOThreadedTest): ...@@ -131,13 +128,10 @@ class Test(NEOThreadedTest):
storage._cache.clear() storage._cache.clear()
self.assertRaises(POSException.POSKeyError, self.assertRaises(POSException.POSKeyError,
storage.load, oid, '') storage.load, oid, '')
finally:
cluster.stop()
def testCreationUndoneHistory(self): @with_cluster()
cluster = NEOCluster() def testCreationUndoneHistory(self, cluster):
try: if 1:
cluster.start()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
oid = storage.new_oid() oid = storage.new_oid()
txn = transaction.Transaction() txn = transaction.Transaction()
...@@ -155,18 +149,15 @@ class Test(NEOThreadedTest): ...@@ -155,18 +149,15 @@ class Test(NEOThreadedTest):
for x in storage.history(oid, 10): for x in storage.history(oid, 10):
self.assertEqual((x['tid'], x['size']), expected.pop()) self.assertEqual((x['tid'], x['size']), expected.pop())
self.assertFalse(expected) self.assertFalse(expected)
finally:
cluster.stop()
def testUndoConflict(self, conflict_during_store=False): @with_cluster()
def testUndoConflict(self, cluster, conflict_during_store=False):
def waitResponses(orig, *args): def waitResponses(orig, *args):
orig(*args) orig(*args)
p.revert() p.revert()
ob.value += 3 ob.value += 3
t.commit() t.commit()
cluster = NEOCluster() if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = ob = PCounterWithResolution() c.root()[0] = ob = PCounterWithResolution()
t.commit() t.commit()
...@@ -186,17 +177,14 @@ class Test(NEOThreadedTest): ...@@ -186,17 +177,14 @@ class Test(NEOThreadedTest):
undo.tpc_finish(txn) undo.tpc_finish(txn)
t.begin() t.begin()
self.assertEqual(ob.value, 3) self.assertEqual(ob.value, 3)
finally:
cluster.stop()
@expectedFailure(POSException.ConflictError) # TODO recheck @expectedFailure(POSException.ConflictError) # TODO recheck
def testUndoConflictDuringStore(self): def testUndoConflictDuringStore(self):
self.testUndoConflict(True) self.testUndoConflict(True)
def testStorageDataLock(self): @with_cluster()
cluster = NEOCluster() def testStorageDataLock(self, cluster):
try: if 1:
cluster.start()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
data_info = {} data_info = {}
...@@ -211,15 +199,21 @@ class Test(NEOThreadedTest): ...@@ -211,15 +199,21 @@ class Test(NEOThreadedTest):
data_info[key] = 0 data_info[key] = 0
storage.sync() storage.sync()
txn = [transaction.Transaction() for x in xrange(3)] txn = [transaction.Transaction() for x in xrange(4)]
for t in txn: for t in txn:
storage.tpc_begin(t) storage.tpc_begin(t)
storage.store(tid and oid or storage.new_oid(), storage.store(oid if tid else storage.new_oid(),
tid, data, '', t) tid, data, '', t)
tid = None tid = None
data_info[key] = 4
storage.sync()
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn.pop())
for t in txn: for t in txn:
storage.tpc_vote(t) storage.tpc_vote(t)
data_info[key] = 3 storage.sync()
data_info[key] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo()) self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn[1]) storage.tpc_abort(txn[1])
...@@ -236,13 +230,10 @@ class Test(NEOThreadedTest): ...@@ -236,13 +230,10 @@ class Test(NEOThreadedTest):
storage.sync() storage.sync()
data_info[key] -= 1 data_info[key] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo()) self.assertEqual(data_info, cluster.storage.getDataLockInfo())
finally:
cluster.stop()
def testDelayedUnlockInformation(self): @with_cluster(storage_count=1)
def testDelayedUnlockInformation(self, cluster):
except_list = [] except_list = []
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreObject(orig, tm, ttid, serial, oid, *args): def onStoreObject(orig, tm, ttid, serial, oid, *args):
if oid == resume_oid and delayUnlockInformation in m2s: if oid == resume_oid and delayUnlockInformation in m2s:
m2s.remove(delayUnlockInformation) m2s.remove(delayUnlockInformation)
...@@ -251,25 +242,22 @@ class Test(NEOThreadedTest): ...@@ -251,25 +242,22 @@ class Test(NEOThreadedTest):
except Exception, e: except Exception, e:
except_list.append(e.__class__) except_list.append(e.__class__)
raise raise
cluster = NEOCluster(storage_count=1) if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = ob = PCounter() c.root()[0] = ob = PCounter()
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.master.filterConnection(cluster.storage) as m2s:
resume_oid = None resume_oid = None
m2s.add(delayUnlockInformation, delayUnlockInformation = m2s.delayNotifyUnlockInformation(
Patch(TransactionManager, storeObject=onStoreObject)) Patch(TransactionManager, storeObject=onStoreObject))
t.commit() t.commit()
resume_oid = ob._p_oid resume_oid = ob._p_oid
ob._p_changed = 1 ob._p_changed = 1
t.commit() t.commit()
self.assertFalse(delayUnlockInformation in m2s) self.assertNotIn(delayUnlockInformation, m2s)
finally:
cluster.stop()
self.assertEqual(except_list, [DelayedError]) self.assertEqual(except_list, [DelayedError])
def _testDeadlockAvoidance(self, scenario): @with_cluster(storage_count=2, replicas=1)
def _testDeadlockAvoidance(self, cluster, scenario):
except_list = [] except_list = []
delay = threading.Event(), threading.Event() delay = threading.Event(), threading.Event()
ident = get_ident() ident = get_ident()
...@@ -301,9 +289,7 @@ class Test(NEOThreadedTest): ...@@ -301,9 +289,7 @@ class Test(NEOThreadedTest):
delay[c2].clear() delay[c2].clear()
delay[1-c2].set() delay[1-c2].set()
cluster = NEOCluster(storage_count=2, replicas=1) if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = ob = PCounterWithResolution() c.root()[0] = ob = PCounterWithResolution()
t.commit() t.commit()
...@@ -326,8 +312,6 @@ class Test(NEOThreadedTest): ...@@ -326,8 +312,6 @@ class Test(NEOThreadedTest):
t2.begin() t2.begin()
self.assertEqual(o1.value, 3) self.assertEqual(o1.value, 3)
self.assertEqual(o2.value, 3) self.assertEqual(o2.value, 3)
finally:
cluster.stop()
return except_list return except_list
def testDelayedStore(self): def testDelayedStore(self):
...@@ -349,11 +333,10 @@ class Test(NEOThreadedTest): ...@@ -349,11 +333,10 @@ class Test(NEOThreadedTest):
self.assertEqual(self._testDeadlockAvoidance([1, 3]), self.assertEqual(self._testDeadlockAvoidance([1, 3]),
[DelayedError, ConflictError, "???" ]) [DelayedError, ConflictError, "???" ])
def testConflictResolutionTriggered2(self): @with_cluster()
def testConflictResolutionTriggered2(self, cluster):
""" Check that conflict resolution works """ """ Check that conflict resolution works """
cluster = NEOCluster() if 1:
try:
cluster.start()
# create the initial object # create the initial object
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()['with_resolution'] = ob = PCounterWithResolution() c.root()['with_resolution'] = ob = PCounterWithResolution()
...@@ -384,7 +367,7 @@ class Test(NEOThreadedTest): ...@@ -384,7 +367,7 @@ class Test(NEOThreadedTest):
tid1 = o1._p_serial tid1 = o1._p_serial
resolved = [] resolved = []
last = (t2.get(), t3.get()).index last = lambda txn: txn._extension['last'] # BBB
def _handleConflicts(orig, txn_context, *args): def _handleConflicts(orig, txn_context, *args):
resolved.append(last(txn_context['txn'])) resolved.append(last(txn_context['txn']))
return orig(txn_context, *args) return orig(txn_context, *args)
...@@ -395,7 +378,8 @@ class Test(NEOThreadedTest): ...@@ -395,7 +378,8 @@ class Test(NEOThreadedTest):
with LockLock() as l3, Patch(cluster.client, tpc_vote=tpc_vote): with LockLock() as l3, Patch(cluster.client, tpc_vote=tpc_vote):
with LockLock() as l2: with LockLock() as l2:
tt = [] tt = []
for t, l in (t2, l2), (t3, l3): for i, t, l in (0, t2, l2), (1, t3, l3):
t.get().setExtendedInfo('last', i)
tt.append(self.newThread(t.commit)) tt.append(self.newThread(t.commit))
l() l()
tt.pop(0).join() tt.pop(0).join()
...@@ -421,15 +405,45 @@ class Test(NEOThreadedTest): ...@@ -421,15 +405,45 @@ class Test(NEOThreadedTest):
# check history # check history
self.assertEqual([x['tid'] for x in c1.db().history(oid, size=10)], self.assertEqual([x['tid'] for x in c1.db().history(oid, size=10)],
[tid3, tid2, tid1, tid0]) [tid3, tid2, tid1, tid0])
finally:
cluster.stop()
def test_notifyNodeInformation(self): @with_cluster()
def testDelayedLoad(self, cluster):
"""
Check that a storage node delays reads from the database,
when the requested data may still be in a temporary place.
"""
l = threading.Lock()
l.acquire()
idle = []
def askObject(orig, *args):
orig(*args)
idle.append(cluster.storage.em.isIdle())
l.release()
if 1:
t, c = cluster.getTransaction()
r = c.root()
r[''] = ''
with Patch(ClientOperationHandler, askObject=askObject):
with cluster.master.filterConnection(cluster.storage) as m2s:
m2s.delayNotifyUnlockInformation()
t.commit()
c.cacheMinimize()
cluster.client._cache.clear()
load = self.newThread(r._p_activate)
l.acquire()
l.acquire()
# The request from the client is processed again
# (upon reception on unlock notification from the master),
# once exactly, and now with success.
load.join()
self.assertEqual(idle, [1, 0])
self.assertIn('', r)
@with_cluster(replicas=1)
def test_notifyNodeInformation(self, cluster):
# translated from MasterNotificationsHandlerTests # translated from MasterNotificationsHandlerTests
# (neo.tests.client.testMasterHandler) # (neo.tests.client.testMasterHandler)
cluster = NEOCluster(replicas=1) if 1:
try:
cluster.start()
cluster.db # open DB cluster.db # open DB
s0, s1 = cluster.client.nm.getStorageList() s0, s1 = cluster.client.nm.getStorageList()
conn = s0.getConnection() conn = s0.getConnection()
...@@ -444,27 +458,21 @@ class Test(NEOThreadedTest): ...@@ -444,27 +458,21 @@ class Test(NEOThreadedTest):
# was called (even if it's useless in this case), # was called (even if it's useless in this case),
# but we would need an API to do that easily. # but we would need an API to do that easily.
self.assertFalse(cluster.client.dispatcher.registered(conn)) self.assertFalse(cluster.client.dispatcher.registered(conn))
finally:
cluster.stop()
def testRestartWithMissingStorage(self): @with_cluster(replicas=1, partitions=10)
def testRestartWithMissingStorage(self, cluster):
# translated from neo.tests.functional.testStorage.StorageTest # translated from neo.tests.functional.testStorage.StorageTest
cluster = NEOCluster(replicas=1, partitions=10)
s1, s2 = cluster.storage_list s1, s2 = cluster.storage_list
try: if 1:
cluster.start()
self.assertEqual([], cluster.getOutdatedCells()) self.assertEqual([], cluster.getOutdatedCells())
finally:
cluster.stop() cluster.stop()
# restart it with one storage only # restart it with one storage only
cluster.reset() if 1:
try:
cluster.start(storage_list=(s1,)) cluster.start(storage_list=(s1,))
self.assertEqual(NodeStates.UNKNOWN, cluster.getNodeState(s2)) self.assertEqual(NodeStates.UNKNOWN, cluster.getNodeState(s2))
finally:
cluster.stop()
def testRestartStoragesWithReplicas(self): @with_cluster(storage_count=2, partitions=2, replicas=1)
def testRestartStoragesWithReplicas(self, cluster):
""" """
Check that the master must discard its partition table when the Check that the master must discard its partition table when the
cluster is not operational anymore. Which means that it must go back cluster is not operational anymore. Which means that it must go back
...@@ -480,8 +488,7 @@ class Test(NEOThreadedTest): ...@@ -480,8 +488,7 @@ class Test(NEOThreadedTest):
orig() orig()
def stop(): def stop():
with cluster.master.filterConnection(s0) as m2s0: with cluster.master.filterConnection(s0) as m2s0:
m2s0.add(lambda conn, packet: m2s0.delayNotifyPartitionChanges()
isinstance(packet, Packets.NotifyPartitionChanges))
s1.stop() s1.stop()
cluster.join((s1,)) cluster.join((s1,))
self.assertEqual(getClusterState(), ClusterStates.RUNNING) self.assertEqual(getClusterState(), ClusterStates.RUNNING)
...@@ -492,9 +499,7 @@ class Test(NEOThreadedTest): ...@@ -492,9 +499,7 @@ class Test(NEOThreadedTest):
self.assertNotEqual(getClusterState(), ClusterStates.RUNNING) self.assertNotEqual(getClusterState(), ClusterStates.RUNNING)
s0.resetNode() s0.resetNode()
s1.resetNode() s1.resetNode()
cluster = NEOCluster(storage_count=2, partitions=2, replicas=1) if 1:
try:
cluster.start()
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
getClusterState = cluster.neoctl.getClusterState getClusterState = cluster.neoctl.getClusterState
if 1: if 1:
...@@ -517,13 +522,10 @@ class Test(NEOThreadedTest): ...@@ -517,13 +522,10 @@ class Test(NEOThreadedTest):
self.assertEqual(getClusterState(), ClusterStates.RUNNING) self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(cluster.getOutdatedCells(), self.assertEqual(cluster.getOutdatedCells(),
[(0, s0.uuid), (1, s0.uuid)]) [(0, s0.uuid), (1, s0.uuid)])
finally:
cluster.stop()
def testVerificationCommitUnfinishedTransactions(self): @with_cluster(partitions=2, storage_count=2)
def testVerificationCommitUnfinishedTransactions(self, cluster):
""" Verification step should commit locked transactions """ """ Verification step should commit locked transactions """
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onLockTransaction(storage, die=False): def onLockTransaction(storage, die=False):
def lock(orig, *args, **kw): def lock(orig, *args, **kw):
if die: if die:
...@@ -531,9 +533,7 @@ class Test(NEOThreadedTest): ...@@ -531,9 +533,7 @@ class Test(NEOThreadedTest):
orig(*args, **kw) orig(*args, **kw)
storage.master_conn.close() storage.master_conn.close()
return Patch(storage.tm, lock=lock) return Patch(storage.tm, lock=lock)
cluster = NEOCluster(partitions=2, storage_count=2) if 1:
try:
cluster.start()
s0, s1 = cluster.sortStorageList() s0, s1 = cluster.sortStorageList()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root() r = c.root()
...@@ -564,7 +564,7 @@ class Test(NEOThreadedTest): ...@@ -564,7 +564,7 @@ class Test(NEOThreadedTest):
self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3)) self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3))
r[2] = 'ok' r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s: with cluster.master.filterConnection(s0) as m2s:
m2s.add(delayUnlockInformation) m2s.delayNotifyUnlockInformation()
t.commit() t.commit()
x.value = 1 x.value = 1
# s0 will accept to store y (because it's not locked) but will # s0 will accept to store y (because it's not locked) but will
...@@ -574,9 +574,7 @@ class Test(NEOThreadedTest): ...@@ -574,9 +574,7 @@ class Test(NEOThreadedTest):
di0 = s0.getDataLockInfo() di0 = s0.getDataLockInfo()
with onLockTransaction(s1, die=True): with onLockTransaction(s1, die=True):
self.commitWithStorageFailure(cluster.client, t) self.commitWithStorageFailure(cluster.client, t)
finally:
cluster.stop() cluster.stop()
cluster.reset()
(k, v), = set(s0.getDataLockInfo().iteritems() (k, v), = set(s0.getDataLockInfo().iteritems()
).difference(di0.iteritems()) ).difference(di0.iteritems())
self.assertEqual(v, 1) self.assertEqual(v, 1)
...@@ -587,7 +585,7 @@ class Test(NEOThreadedTest): ...@@ -587,7 +585,7 @@ class Test(NEOThreadedTest):
k, = (k for k, v in di1.iteritems() if v == 1) k, = (k for k, v in di1.iteritems() if v == 1)
del di1[k] # x.value = 1 del di1[k] # x.value = 1
self.assertEqual(di1.values(), [0]) self.assertEqual(di1.values(), [0])
try: if 1:
cluster.start() cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root() r = c.root()
...@@ -596,19 +594,16 @@ class Test(NEOThreadedTest): ...@@ -596,19 +594,16 @@ class Test(NEOThreadedTest):
self.assertEqual(r[2], 'ok') self.assertEqual(r[2], 'ok')
self.assertEqual(di0, s0.getDataLockInfo()) self.assertEqual(di0, s0.getDataLockInfo())
self.assertEqual(di1, s1.getDataLockInfo()) self.assertEqual(di1, s1.getDataLockInfo())
finally:
cluster.stop()
def testVerificationWithNodesWithoutReadableCells(self): @with_cluster(replicas=1)
def testVerificationWithNodesWithoutReadableCells(self, cluster):
def onLockTransaction(storage, die_after): def onLockTransaction(storage, die_after):
def lock(orig, *args, **kw): def lock(orig, *args, **kw):
if die_after: if die_after:
orig(*args, **kw) orig(*args, **kw)
sys.exit() sys.exit()
return Patch(storage.tm, lock=lock) return Patch(storage.tm, lock=lock)
cluster = NEOCluster(replicas=1) if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = None c.root()[0] = None
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
...@@ -634,10 +629,9 @@ class Test(NEOThreadedTest): ...@@ -634,10 +629,9 @@ class Test(NEOThreadedTest):
self.assertEqual(sorted(c.root()), [1]) self.assertEqual(sorted(c.root()), [1])
self.tic() self.tic()
t0, t1 = c.db().storage.iterator() t0, t1 = c.db().storage.iterator()
finally:
cluster.stop()
def testDropUnfinishedData(self): @with_cluster(partitions=2, storage_count=2, replicas=1)
def testDropUnfinishedData(self, cluster):
def lock(orig, *args, **kw): def lock(orig, *args, **kw):
orig(*args, **kw) orig(*args, **kw)
storage.master_conn.close() storage.master_conn.close()
...@@ -646,9 +640,7 @@ class Test(NEOThreadedTest): ...@@ -646,9 +640,7 @@ class Test(NEOThreadedTest):
r.append(len(orig.__self__.getUnfinishedTIDDict())) r.append(len(orig.__self__.getUnfinishedTIDDict()))
orig() orig()
r.append(len(orig.__self__.getUnfinishedTIDDict())) r.append(len(orig.__self__.getUnfinishedTIDDict()))
cluster = NEOCluster(partitions=2, storage_count=2, replicas=1) if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()._p_changed = 1 c.root()._p_changed = 1
storage = cluster.storage_list[0] storage = cluster.storage_list[0]
...@@ -657,13 +649,10 @@ class Test(NEOThreadedTest): ...@@ -657,13 +649,10 @@ class Test(NEOThreadedTest):
t.commit() t.commit()
self.tic() self.tic()
self.assertEqual(r, [1, 0]) self.assertEqual(r, [1, 0])
finally:
cluster.stop()
def testStorageUpgrade1(self): @with_cluster()
cluster = NEOCluster() def testStorageUpgrade1(self, cluster):
try: if 1:
cluster.start()
storage = cluster.storage storage = cluster.storage
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
storage.dm.setConfiguration("version", None) storage.dm.setConfiguration("version", None)
...@@ -680,42 +669,32 @@ class Test(NEOThreadedTest): ...@@ -680,42 +669,32 @@ class Test(NEOThreadedTest):
with Patch(storage.tm, lock=lambda *_: sys.exit()): with Patch(storage.tm, lock=lambda *_: sys.exit()):
self.commitWithStorageFailure(cluster.client, t) self.commitWithStorageFailure(cluster.client, t)
self.assertRaises(DatabaseFailure, storage.resetNode) self.assertRaises(DatabaseFailure, storage.resetNode)
finally:
cluster.stop()
def testStorageReconnectDuringStore(self): @with_cluster(replicas=1)
cluster = NEOCluster(replicas=1) def testStorageReconnectDuringStore(self, cluster):
try: if 1:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = 'ok' c.root()[0] = 'ok'
cluster.client.cp.closeAll() cluster.client.cp.closeAll()
t.commit() # store request t.commit() # store request
finally:
cluster.stop()
def testStorageReconnectDuringTransactionLog(self): @with_cluster(storage_count=2, partitions=2)
cluster = NEOCluster(storage_count=2, partitions=2) def testStorageReconnectDuringTransactionLog(self, cluster):
try: if 1:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
cluster.client.cp.closeAll() cluster.client.cp.closeAll()
tid, (t1,) = cluster.client.transactionLog( tid, (t1,) = cluster.client.transactionLog(
ZERO_TID, c.db().lastTransaction(), 10) ZERO_TID, c.db().lastTransaction(), 10)
finally:
cluster.stop()
def testStorageReconnectDuringUndoLog(self): @with_cluster(storage_count=2, partitions=2)
cluster = NEOCluster(storage_count=2, partitions=2) def testStorageReconnectDuringUndoLog(self, cluster):
try: if 1:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
cluster.client.cp.closeAll() cluster.client.cp.closeAll()
t1, = cluster.client.undoLog(0, 10) t1, = cluster.client.undoLog(0, 10)
finally:
cluster.stop()
def testDropNodeThenRestartCluster(self): @with_cluster(storage_count=2, replicas=1)
def testDropNodeThenRestartCluster(self, cluster):
""" Start a cluster with more than one storage, down one, shutdown the """ Start a cluster with more than one storage, down one, shutdown the
cluster then restart it. The partition table recovered must not include cluster then restart it. The partition table recovered must not include
the dropped node """ the dropped node """
...@@ -724,10 +703,8 @@ class Test(NEOThreadedTest): ...@@ -724,10 +703,8 @@ class Test(NEOThreadedTest):
self.assertEqual(cluster.getNodeState(s2), NodeStates.RUNNING) self.assertEqual(cluster.getNodeState(s2), NodeStates.RUNNING)
# start with two storage / one replica # start with two storage / one replica
cluster = NEOCluster(storage_count=2, replicas=1)
s1, s2 = cluster.storage_list s1, s2 = cluster.storage_list
try: if 1:
cluster.start()
checkNodeState(NodeStates.RUNNING) checkNodeState(NodeStates.RUNNING)
self.assertEqual([], cluster.getOutdatedCells()) self.assertEqual([], cluster.getOutdatedCells())
# drop one # drop one
...@@ -737,39 +714,29 @@ class Test(NEOThreadedTest): ...@@ -737,39 +714,29 @@ class Test(NEOThreadedTest):
checkNodeState(None) checkNodeState(None)
self.assertEqual([], cluster.getOutdatedCells()) self.assertEqual([], cluster.getOutdatedCells())
# restart with s2 only # restart with s2 only
finally:
cluster.stop() cluster.stop()
cluster.reset() if 1:
try:
cluster.start(storage_list=[s2]) cluster.start(storage_list=[s2])
checkNodeState(None) checkNodeState(None)
# then restart it, it must be in pending state # then restart it, it must be in pending state
s1.start() s1.start()
self.tic() self.tic()
checkNodeState(NodeStates.PENDING) checkNodeState(NodeStates.PENDING)
finally:
cluster.stop()
def test2Clusters(self): # NOTE @with_cluster()
cluster1 = NEOCluster() @with_cluster()
cluster2 = NEOCluster() def test2Clusters(self, cluster1, cluster2): # NOTE
try: if 1:
cluster1.start()
cluster2.start()
t1, c1 = cluster1.getTransaction() t1, c1 = cluster1.getTransaction()
t2, c2 = cluster2.getTransaction() t2, c2 = cluster2.getTransaction()
c1.root()['1'] = c2.root()['2'] = '' c1.root()['1'] = c2.root()['2'] = ''
t1.commit() t1.commit()
t2.commit() t2.commit()
finally:
cluster1.stop()
cluster2.stop()
def testAbortStorage(self): @with_cluster(partitions=2, storage_count=2)
cluster = NEOCluster(partitions=2, storage_count=2) def testAbortStorage(self, cluster):
storage = cluster.storage_list[0] storage = cluster.storage_list[0]
try: if 1:
cluster.start()
# prevent storage to reconnect, in order to easily test # prevent storage to reconnect, in order to easily test
# that cluster becomes non-operational # that cluster becomes non-operational
with Patch(storage, connectToPrimary=sys.exit): with Patch(storage, connectToPrimary=sys.exit):
...@@ -783,17 +750,13 @@ class Test(NEOThreadedTest): ...@@ -783,17 +750,13 @@ class Test(NEOThreadedTest):
self.tic() self.tic()
self.assertEqual(cluster.neoctl.getClusterState(), self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
finally:
cluster.stop()
def testShutdown(self): @with_cluster(master_count=3, partitions=10, replicas=1, storage_count=3)
# BUG: Due to bugs in election, master nodes sometimes crash, or they # <- NOTE def testShutdown(self, cluster):
# NOTE vvv
# declare themselves primary too quickly. The consequence is # declare themselves primary too quickly. The consequence is
# often an endless tic loop. # often an endless tic loop.
cluster = NEOCluster(master_count=3, partitions=10, if 1:
replicas=1, storage_count=3)
try:
cluster.start()
# fill DB a little # fill DB a little
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[''] = '' c.root()[''] = ''
...@@ -804,9 +767,7 @@ class Test(NEOThreadedTest): ...@@ -804,9 +767,7 @@ class Test(NEOThreadedTest):
cluster.join(cluster.master_list cluster.join(cluster.master_list
+ cluster.storage_list + cluster.storage_list
+ cluster.admin_list) + cluster.admin_list)
finally: cluster.stop() # stop and reopen DB to check partition tables
cluster.stop()
cluster.reset() # reopen DB to check partition tables
dm = cluster.storage_list[0].dm dm = cluster.storage_list[0].dm
self.assertEqual(1, dm.getPTID()) self.assertEqual(1, dm.getPTID())
pt = list(dm.getPartitionTable()) pt = list(dm.getPartitionTable())
...@@ -817,14 +778,13 @@ class Test(NEOThreadedTest): ...@@ -817,14 +778,13 @@ class Test(NEOThreadedTest):
self.assertEqual(s.dm.getPTID(), 1) self.assertEqual(s.dm.getPTID(), 1)
self.assertEqual(list(s.dm.getPartitionTable()), pt) self.assertEqual(list(s.dm.getPartitionTable()), pt)
def testInternalInvalidation(self): @with_cluster()
def testInternalInvalidation(self, cluster):
def _handlePacket(orig, conn, packet, kw={}, handler=None): def _handlePacket(orig, conn, packet, kw={}, handler=None):
if type(packet) is Packets.AnswerTransactionFinished: if type(packet) is Packets.AnswerTransactionFinished:
ll() ll()
orig(conn, packet, kw, handler) orig(conn, packet, kw, handler)
cluster = NEOCluster() if 1:
try:
cluster.start()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter() c1.root()['x'] = x1 = PCounter()
t1.commit() t1.commit()
...@@ -839,14 +799,9 @@ class Test(NEOThreadedTest): ...@@ -839,14 +799,9 @@ class Test(NEOThreadedTest):
t2.begin() t2.begin()
t.join() t.join()
self.assertEqual(x2.value, 1) self.assertEqual(x2.value, 1)
finally:
cluster.stop()
def testExternalInvalidation(self): @with_cluster()
cluster = NEOCluster() def testExternalInvalidation(self, cluster):
try:
cluster.start()
cache = cluster.client._cache
# Initialize objects # Initialize objects
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter() c1.root()['x'] = x1 = PCounter()
...@@ -861,14 +816,14 @@ class Test(NEOThreadedTest): ...@@ -861,14 +816,14 @@ class Test(NEOThreadedTest):
# (at this time, we still have x=0 and y=1) # (at this time, we still have x=0 and y=1)
t2, c2 = cluster.getTransaction() t2, c2 = cluster.getTransaction()
# Copy y to x using a different Master-Client connection # Copy y to x using a different Master-Client connection
client = cluster.newClient() with cluster.newClient() as client:
cache = cluster.client._cache
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x # Delay invalidation for x
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.add(lambda conn, packet: m2c.delayInvalidateObjects()
isinstance(packet, Packets.InvalidateObjects))
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn, None)
# Change to x is committed. Testing connection must ask the # Change to x is committed. Testing connection must ask the
# storage node to return original value of x, even if we # storage node to return original value of x, even if we
...@@ -940,22 +895,15 @@ class Test(NEOThreadedTest): ...@@ -940,22 +895,15 @@ class Test(NEOThreadedTest):
self.assertFalse(invalidations(c1)) self.assertFalse(invalidations(c1))
self.assertEqual(x1.value, 1) self.assertEqual(x1.value, 1)
finally: @with_cluster(storage_count=2, partitions=2)
cluster.stop() def testReadVerifyingStorage(self, cluster):
if 1:
def testReadVerifyingStorage(self):
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounter() c1.root()['x'] = x = PCounter()
t1.commit() t1.commit()
# We need a second client for external invalidations. # We need a second client for external invalidations.
t2 = transaction.TransactionManager() with cluster.newClient(1) as db:
db = DB(storage=cluster.getZODBStorage(client=cluster.newClient())) t2, c2 = cluster.getTransaction(db)
try:
c2 = db.open(t2)
t2.begin()
r = c2.root() r = c2.root()
r['y'] = None r['y'] = None
r['x']._p_activate() r['x']._p_activate()
...@@ -968,8 +916,6 @@ class Test(NEOThreadedTest): ...@@ -968,8 +916,6 @@ class Test(NEOThreadedTest):
t2.commit() t2.commit()
for storage in cluster.storage_list: for storage in cluster.storage_list:
self.assertFalse(storage.tm._transaction_dict) self.assertFalse(storage.tm._transaction_dict)
finally:
db.close()
# Check we didn't get an invalidation, which would cause an # Check we didn't get an invalidation, which would cause an
# assertion failure in the cache. Connection does the same check in # assertion failure in the cache. Connection does the same check in
# _setstate_noncurrent so this could be also done by starting a # _setstate_noncurrent so this could be also done by starting a
...@@ -981,13 +927,10 @@ class Test(NEOThreadedTest): ...@@ -981,13 +927,10 @@ class Test(NEOThreadedTest):
self.assertEqual(map(u64, t1.oid_list), [0, 1]) self.assertEqual(map(u64, t1.oid_list), [0, 1])
# Check oid 1 is part of transaction metadata. # Check oid 1 is part of transaction metadata.
self.assertEqual(t2.oid_list, t1.oid_list) self.assertEqual(t2.oid_list, t1.oid_list)
finally:
cluster.stop()
def testClientReconnection(self): @with_cluster()
cluster = NEOCluster() def testClientReconnection(self, cluster):
try: if 1:
cluster.start()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter() c1.root()['x'] = x1 = PCounter()
c1.root()['y'] = y = PCounter() c1.root()['y'] = y = PCounter()
...@@ -1004,14 +947,11 @@ class Test(NEOThreadedTest): ...@@ -1004,14 +947,11 @@ class Test(NEOThreadedTest):
#self.tic() # NOTE works ok with tic() commented #self.tic() # NOTE works ok with tic() commented
# modify x with another client # modify x with another client
client = cluster.newClient() with cluster.newClient() as client:
try:
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn, None)
finally:
client.close()
#self.tic() # NOTE ----//---- #self.tic() # NOTE ----//----
# Check reconnection to the master and storage. # Check reconnection to the master and storage.
...@@ -1020,13 +960,10 @@ class Test(NEOThreadedTest): ...@@ -1020,13 +960,10 @@ class Test(NEOThreadedTest):
t1.begin() t1.begin()
self.assertEqual(x1._p_changed, None) self.assertEqual(x1._p_changed, None)
self.assertEqual(x1.value, 1) self.assertEqual(x1.value, 1)
finally:
cluster.stop()
def testInvalidTTID(self): @with_cluster()
cluster = NEOCluster() def testInvalidTTID(self, cluster):
try: if 1:
cluster.start()
client = cluster.client client = cluster.client
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(txn)
...@@ -1034,16 +971,13 @@ class Test(NEOThreadedTest): ...@@ -1034,16 +971,13 @@ class Test(NEOThreadedTest):
txn_context['ttid'] = add64(txn_context['ttid'], 1) txn_context['ttid'] = add64(txn_context['ttid'], 1)
self.assertRaises(POSException.StorageError, self.assertRaises(POSException.StorageError,
client.tpc_finish, txn, None) client.tpc_finish, txn, None)
finally:
cluster.stop()
def testStorageFailureDuringTpcFinish(self): @with_cluster()
def testStorageFailureDuringTpcFinish(self, cluster):
def answerTransactionFinished(conn, packet): def answerTransactionFinished(conn, packet):
if isinstance(packet, Packets.AnswerTransactionFinished): if isinstance(packet, Packets.AnswerTransactionFinished):
raise StoppedOperation raise StoppedOperation
cluster = NEOCluster() if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()['x'] = PCounter() c.root()['x'] = PCounter()
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
...@@ -1057,10 +991,9 @@ class Test(NEOThreadedTest): ...@@ -1057,10 +991,9 @@ class Test(NEOThreadedTest):
self.assertEqual(1, u64(c.root()['x']._p_oid)) self.assertEqual(1, u64(c.root()['x']._p_oid))
self.assertFalse(cluster.client.new_oid_list) self.assertFalse(cluster.client.new_oid_list)
self.assertEqual(2, u64(cluster.client.new_oid())) self.assertEqual(2, u64(cluster.client.new_oid()))
finally:
cluster.stop()
def testClientFailureDuringTpcFinish(self): @with_cluster()
def testClientFailureDuringTpcFinish(self, cluster):
""" """
Third scenario: Third scenario:
...@@ -1095,9 +1028,7 @@ class Test(NEOThreadedTest): ...@@ -1095,9 +1028,7 @@ class Test(NEOThreadedTest):
self.tic() self.tic()
s2m.remove(delayAnswerLockInformation) s2m.remove(delayAnswerLockInformation)
return conn return conn
cluster = NEOCluster() if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root() r = c.root()
r['x'] = PCounter() r['x'] = PCounter()
...@@ -1126,38 +1057,31 @@ class Test(NEOThreadedTest): ...@@ -1126,38 +1057,31 @@ class Test(NEOThreadedTest):
cluster.master.filterConnection(cluster.storage) as m2s: cluster.master.filterConnection(cluster.storage) as m2s:
s2m.add(delayAnswerLockInformation, Patch(cluster.client, s2m.add(delayAnswerLockInformation, Patch(cluster.client,
_connectToPrimaryNode=_connectToPrimaryNode)) _connectToPrimaryNode=_connectToPrimaryNode))
m2s.add(lambda conn, packet: m2s.delayNotifyUnlockInformation()
isinstance(packet, Packets.NotifyUnlockInformation))
t.commit() # the final TID is returned by the storage (tm) t.commit() # the final TID is returned by the storage (tm)
t.begin() t.begin()
self.assertEqual(r['x'].value, 2) self.assertEqual(r['x'].value, 2)
self.assertTrue(tid2 < r['x']._p_serial) self.assertTrue(tid2 < r['x']._p_serial)
finally:
cluster.stop()
def testMasterFailureBeforeVote(self): @with_cluster(storage_count=2, partitions=2)
def testMasterFailureBeforeVote(self, cluster):
def waitStoreResponses(orig, *args): def waitStoreResponses(orig, *args):
result = orig(*args) result = orig(*args)
m2c, = cluster.master.getConnectionList(orig.__self__) m2c, = cluster.master.getConnectionList(orig.__self__)
m2c.close() m2c.close()
self.tic() self.tic()
return result return result
cluster = NEOCluster(storage_count=2, partitions=2) if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()['x'] = PCounter() # 1 store() to each storage c.root()['x'] = PCounter() # 1 store() to each storage
with Patch(cluster.client, waitStoreResponses=waitStoreResponses): with Patch(cluster.client, waitStoreResponses=waitStoreResponses):
self.assertRaises(POSException.StorageError, t.commit) self.assertRaises(POSException.StorageError, t.commit)
self.assertEqual(cluster.neoctl.getClusterState(), self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
finally:
cluster.stop()
def testEmptyTransaction(self): @with_cluster()
cluster = NEOCluster() def testEmptyTransaction(self, cluster):
try: if 1:
cluster.start()
txn = transaction.Transaction() txn = transaction.Transaction()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
storage.tpc_begin(txn) storage.tpc_begin(txn)
...@@ -1166,58 +1090,40 @@ class Test(NEOThreadedTest): ...@@ -1166,58 +1090,40 @@ class Test(NEOThreadedTest):
t, = storage.iterator() t, = storage.iterator()
self.assertEqual(t.tid, serial) self.assertEqual(t.tid, serial)
self.assertFalse(t.oid_list) self.assertFalse(t.oid_list)
finally:
cluster.stop()
def testRecycledClientUUID(self): @with_cluster()
def delayNotifyInformation(conn, packet): def testRecycledClientUUID(self, cluster):
return isinstance(packet, Packets.NotifyNodeInformation)
def notReady(orig, *args): def notReady(orig, *args):
m2s.discard(delayNotifyInformation) m2s.discard(delayNotifyInformation)
return orig(*args) return orig(*args)
cluster = NEOCluster() if 1:
try:
cluster.start()
cluster.getTransaction() cluster.getTransaction()
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.master.filterConnection(cluster.storage) as m2s:
m2s.add(delayNotifyInformation) delayNotifyInformation = m2s.delayNotifyNodeInformation()
cluster.client.master_conn.close() cluster.client.master_conn.close()
client = cluster.newClient() with cluster.newClient() as client, Patch(
p = Patch(client.storage_bootstrap_handler, notReady=notReady) client.storage_bootstrap_handler, notReady=notReady):
try:
p.apply()
x = client.load(ZERO_TID) x = client.load(ZERO_TID)
finally:
del p
client.close()
self.assertNotIn(delayNotifyInformation, m2s) self.assertNotIn(delayNotifyInformation, m2s)
finally:
cluster.stop()
def testAutostart(self): @with_cluster(start_cluster=0, storage_count=3, autostart=3)
def startCluster(): def testAutostart(self, cluster):
def startCluster(orig):
getClusterState = cluster.neoctl.getClusterState getClusterState = cluster.neoctl.getClusterState
self.assertEqual(ClusterStates.RECOVERING, getClusterState()) self.assertEqual(ClusterStates.RECOVERING, getClusterState())
cluster.storage_list[2].start() cluster.storage_list[2].start()
self.tic() with Patch(cluster, startCluster=startCluster):
self.assertEqual(ClusterStates.RUNNING, getClusterState()) self.assertEqual(ClusterStates.RUNNING, getClusterState())
cluster = NEOCluster(storage_count=3, autostart=3)
try:
cluster.startCluster = startCluster
cluster.start(cluster.storage_list[:2]) cluster.start(cluster.storage_list[:2])
finally:
cluster.stop()
del cluster.startCluster
def testAbortVotedTransaction(self): @with_cluster(storage_count=2, partitions=2)
def testAbortVotedTransaction(self, cluster):
r = [] r = []
def tpc_finish(*args, **kw): def tpc_finish(*args, **kw):
for storage in cluster.storage_list: for storage in cluster.storage_list:
r.append(len(storage.dm.getUnfinishedTIDDict())) r.append(len(storage.dm.getUnfinishedTIDDict()))
raise NEOStorageError raise NEOStorageError
cluster = NEOCluster(storage_count=2, partitions=2) if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()['x'] = PCounter() c.root()['x'] = PCounter()
with Patch(cluster.client, tpc_finish=tpc_finish): with Patch(cluster.client, tpc_finish=tpc_finish):
...@@ -1228,17 +1134,11 @@ class Test(NEOThreadedTest): ...@@ -1228,17 +1134,11 @@ class Test(NEOThreadedTest):
self.assertFalse(storage.dm.getUnfinishedTIDDict()) self.assertFalse(storage.dm.getUnfinishedTIDDict())
t.begin() t.begin()
self.assertNotIn('x', c.root()) self.assertNotIn('x', c.root())
finally:
cluster.stop()
def testStorageLostDuringRecovery(self): @with_cluster(storage_count=2, partitions=2)
def testStorageLostDuringRecovery(self, cluster):
# Initialize a cluster. # Initialize a cluster.
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
finally:
cluster.stop() cluster.stop()
cluster.reset()
# Restart with a connection failure for the first AskPartitionTable. # Restart with a connection failure for the first AskPartitionTable.
# The master must not be stuck in RECOVERING state # The master must not be stuck in RECOVERING state
# or re-make the partition table. # or re-make the partition table.
...@@ -1247,16 +1147,15 @@ class Test(NEOThreadedTest): ...@@ -1247,16 +1147,15 @@ class Test(NEOThreadedTest):
def askPartitionTable(orig, self, conn): def askPartitionTable(orig, self, conn):
p.revert() p.revert()
conn.close() conn.close()
try: if 1:
with Patch(cluster.master.pt, make=make), \ with Patch(cluster.master.pt, make=make), \
Patch(InitializationHandler, Patch(InitializationHandler,
askPartitionTable=askPartitionTable) as p: askPartitionTable=askPartitionTable) as p:
cluster.start() cluster.start()
self.assertFalse(p.applied) self.assertFalse(p.applied)
finally:
cluster.stop()
def testTruncate(self): @with_cluster(replicas=1)
def testTruncate(self, cluster):
calls = [0, 0] calls = [0, 0]
def dieFirst(i): def dieFirst(i):
def f(orig, *args, **kw): def f(orig, *args, **kw):
...@@ -1265,9 +1164,7 @@ class Test(NEOThreadedTest): ...@@ -1265,9 +1164,7 @@ class Test(NEOThreadedTest):
sys.exit() sys.exit()
return orig(*args, **kw) return orig(*args, **kw)
return f return f
cluster = NEOCluster(replicas=1) if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root() r = c.root()
tids = [] tids = []
...@@ -1311,30 +1208,27 @@ class Test(NEOThreadedTest): ...@@ -1311,30 +1208,27 @@ class Test(NEOThreadedTest):
self.assertEqual(1, u64(c._storage.new_oid())) self.assertEqual(1, u64(c._storage.new_oid()))
for s in cluster.storage_list: for s in cluster.storage_list:
self.assertEqual(s.dm.getLastIDs()[0], truncate_tid) self.assertEqual(s.dm.getLastIDs()[0], truncate_tid)
finally:
cluster.stop()
def testConnectionTimeout(self): def testConnectionTimeout(self):
conn = self.getLoopbackConnection() with self.getLoopbackConnection() as conn:
conn.KEEP_ALIVE conn.KEEP_ALIVE
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
def onTimeout(orig): def onTimeout(orig):
conn.idle() conn.idle()
orig() orig()
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
with Patch(conn, onTimeout=onTimeout): with Patch(conn, onTimeout=onTimeout):
conn.em.poll(1) conn.em.poll(1)
self.assertFalse(conn.isClosed()) self.assertFalse(conn.isClosed())
def testClientDisconnectedFromMaster(self): @with_cluster()
def testClientDisconnectedFromMaster(self, cluster):
def disconnect(conn, packet): def disconnect(conn, packet):
if isinstance(packet, Packets.AskObject): if isinstance(packet, Packets.AskObject):
m2c.close() m2c.close()
#return True #return True
cluster = NEOCluster() if 1:
try:
cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
m2c, = cluster.master.getConnectionList(cluster.client) m2c, = cluster.master.getConnectionList(cluster.client)
cluster.client._cache.clear() cluster.client._cache.clear()
...@@ -1350,8 +1244,7 @@ class Test(NEOThreadedTest): ...@@ -1350,8 +1244,7 @@ class Test(NEOThreadedTest):
self.assertRaises(TransientError, getattr, c, "root") self.assertRaises(TransientError, getattr, c, "root")
uuid = cluster.client.uuid uuid = cluster.client.uuid
# Let's use a second client to steal the node id of the first one. # Let's use a second client to steal the node id of the first one.
client = cluster.newClient() with cluster.newClient() as client:
try:
client.sync() client.sync()
self.assertEqual(uuid, client.uuid) self.assertEqual(uuid, client.uuid)
# The client reconnects successfully to the master and storage, # The client reconnects successfully to the master and storage,
...@@ -1363,12 +1256,9 @@ class Test(NEOThreadedTest): ...@@ -1363,12 +1256,9 @@ class Test(NEOThreadedTest):
self.assertNotEqual(uuid, cluster.client.uuid) self.assertNotEqual(uuid, cluster.client.uuid)
# Second reconnection, for a successful load. # Second reconnection, for a successful load.
c.root c.root
finally:
client.close()
finally:
cluster.stop()
def testIdTimestamp(self): @with_cluster()
def testIdTimestamp(self, cluster):
""" """
Given a master M, a storage S, and 2 clients Ca and Cb. Given a master M, a storage S, and 2 clients Ca and Cb.
...@@ -1394,9 +1284,7 @@ class Test(NEOThreadedTest): ...@@ -1394,9 +1284,7 @@ class Test(NEOThreadedTest):
ll() ll()
def connectToStorage(client): def connectToStorage(client):
next(client.cp.iterateForObject(0)) next(client.cp.iterateForObject(0))
cluster = NEOCluster() if 1:
try:
cluster.start()
Ca = cluster.client Ca = cluster.client
Ca.pt # only connect to the master Ca.pt # only connect to the master
# In a separate thread, connect to the storage but suspend the # In a separate thread, connect to the storage but suspend the
...@@ -1408,18 +1296,72 @@ class Test(NEOThreadedTest): ...@@ -1408,18 +1296,72 @@ class Test(NEOThreadedTest):
s2c, = s2c s2c, = s2c
m2c, = cluster.master.getConnectionList(cluster.client) m2c, = cluster.master.getConnectionList(cluster.client)
m2c.close() m2c.close()
Cb = cluster.newClient() with cluster.newClient() as Cb:
try:
Cb.pt # only connect to the master Cb.pt # only connect to the master
del s2c.readable del s2c.readable
self.assertRaises(NEOPrimaryMasterLost, t.join) self.assertRaises(NEOPrimaryMasterLost, t.join)
self.assertTrue(s2c.isClosed()) self.assertTrue(s2c.isClosed())
connectToStorage(Cb) connectToStorage(Cb)
finally:
Cb.close()
finally:
cluster.stop()
@with_cluster(storage_count=2, partitions=2)
def testPruneOrphan(self, cluster):
if 1:
cluster.importZODB()(3)
bad = []
ok = []
def data_args(value):
return makeChecksum(value), value, 0
node_list = []
for i, s in enumerate(cluster.storage_list):
node_list.append(s.uuid)
if i:
s.dm.holdData(*data_args('boo'))
ok.append(s.getDataLockInfo())
for i in xrange(3 - i):
s.dm.storeData(*data_args('!' * i))
bad.append(s.getDataLockInfo())
s.dm.commit()
def check(dry_run, expected):
cluster.neoctl.repair(node_list, dry_run)
for e, s in zip(expected, cluster.storage_list):
while 1:
self.tic()
if s.dm._repairing is None:
break
time.sleep(.1)
self.assertEqual(e, s.getDataLockInfo())
check(1, bad)
check(0, ok)
check(1, ok)
@with_cluster(replicas=1)
def testLateConflictOnReplica(self, cluster):
"""
Already resolved conflict: check the case of a storage node that
reports a conflict after that this conflict was fully resolved with
another node.
"""
def answerStoreObject(orig, conn, conflicting, *args):
if not conflicting:
p.revert()
ll()
orig(conn, conflicting, *args)
if 1:
s0, s1 = cluster.storage_list
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounterWithResolution()
t1.commit()
x.value += 1
t2, c2 = cluster.getTransaction()
c2.root()['x'].value += 2
t2.commit()
with LockLock() as ll, s1.filterConnection(cluster.client) as f, \
Patch(cluster.client.storage_handler,
answerStoreObject=answerStoreObject) as p:
f.delayAnswerStoreObject()
t = self.newThread(t1.commit)
ll()
t.join()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# #
# Copyright (C) 2014-2016 Nexedi SA # Copyright (C) 2014-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,11 +14,10 @@ ...@@ -14,11 +14,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import deque
from cPickle import Pickler, Unpickler from cPickle import Pickler, Unpickler
from cStringIO import StringIO from cStringIO import StringIO
from itertools import islice, izip_longest from itertools import islice, izip_longest
import os, unittest import os, shutil, unittest
import neo, transaction, ZODB import neo, transaction, ZODB
from neo.lib import logging from neo.lib import logging
from neo.lib.util import u64 from neo.lib.util import u64
...@@ -129,18 +128,28 @@ class ImporterTests(NEOThreadedTest): ...@@ -129,18 +128,28 @@ class ImporterTests(NEOThreadedTest):
self.assertDictEqual(state, load()) self.assertDictEqual(state, load())
def test(self): def test(self):
# XXX: Using NEO source files as test data was a bad idea because
# the test breaks easily in case of massive changes in the code,
# or if there are many untracked files.
importer = [] importer = []
fs_dir = os.path.join(getTempDirectory(), self.id()) fs_dir = os.path.join(getTempDirectory(), self.id())
shutil.rmtree(fs_dir, 1) # for --loop
os.mkdir(fs_dir) os.mkdir(fs_dir)
src_root, = neo.__path__ src_root, = neo.__path__
fs_list = "root", "client", "master", "tests" fs_list = "root", "client", "master", "tests"
def not_pyc(name):
return not name.endswith(".pyc")
# We use 'hash' to skip roughly half of files.
# They'll be added after the migration has started.
def root_filter(name): def root_filter(name):
if not name.endswith(".pyc"): if not_pyc(name):
i = name.find(os.sep) i = name.find(os.sep)
return i < 0 or name[:i] not in fs_list return (i < 0 or name[:i] not in fs_list) and (
'.' not in name or hash(name) & 1)
def sub_filter(name): def sub_filter(name):
return lambda n: n[-4:] != '.pyc' and \ return lambda n: not_pyc(n) and (
n.split(os.sep, 1)[0] in (name, "scripts") hash(n) & 1 if '.' in n else
os.sep in n or n in (name, "scripts"))
conn_list = [] conn_list = []
iter_list = [] iter_list = []
# Setup several FileStorage databases. # Setup several FileStorage databases.
...@@ -172,8 +181,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -172,8 +181,7 @@ class ImporterTests(NEOThreadedTest):
c.db().close() c.db().close()
#del importer[0][1][importer.pop()[0]] #del importer[0][1][importer.pop()[0]]
# Start NEO cluster with transparent import of a multi-base ZODB. # Start NEO cluster with transparent import of a multi-base ZODB.
cluster = NEOCluster(compress=False, importer=importer) with NEOCluster(compress=False, importer=importer) as cluster:
try:
# Suspend import for a while, so that import # Suspend import for a while, so that import
# is finished in the middle of the below 'for' loop. # is finished in the middle of the below 'for' loop.
# Use a slightly different main loop for storage so that it # Use a slightly different main loop for storage so that it
...@@ -193,7 +201,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -193,7 +201,7 @@ class ImporterTests(NEOThreadedTest):
cluster.start() cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root()["neo"] r = c.root()["neo"]
# Test retrieving of an object from ZODB when next serial in NEO. # Test retrieving of an object from ZODB when next serial is in NEO.
r._p_changed = 1 r._p_changed = 1
t.commit() t.commit()
t.begin() t.begin()
...@@ -204,24 +212,25 @@ class ImporterTests(NEOThreadedTest): ...@@ -204,24 +212,25 @@ class ImporterTests(NEOThreadedTest):
self.assertRaisesRegexp(NotImplementedError, " getObjectHistory$", self.assertRaisesRegexp(NotImplementedError, " getObjectHistory$",
c.db().history, r._p_oid) c.db().history, r._p_oid)
i = r.walk() i = r.walk()
next(islice(i, 9, None)) next(islice(i, 4, None))
logging.info("start migration") logging.info("start migration")
dm.doOperation(cluster.storage) dm.doOperation(cluster.storage)
deque(i, maxlen=0) # Adjust if needed. Must remain > 0.
last_import = None assert 14 == sum(1 for i in i)
for i, r in enumerate(r.treeFromFs(src_root, 10)): last_import = -1
for i, r in enumerate(r.treeFromFs(src_root, 6, not_pyc)):
t.commit() t.commit()
if cluster.storage.dm._import: if cluster.storage.dm._import:
last_import = i last_import = i
self.tic() self.tic()
self.assertTrue(last_import and not cluster.storage.dm._import) # Same as above. We want last_import smaller enough compared to i
assert i / 3 < last_import < i - 2, (last_import, i)
self.assertFalse(cluster.storage.dm._import)
i = len(src_root) + 1 i = len(src_root) + 1
self.assertEqual(sorted(r.walk()), sorted( self.assertEqual(sorted(r.walk()), sorted(
(x[i:] or '.', sorted(y), sorted(z)) (x[i:] or '.', sorted(y), sorted(filter(not_pyc, z)))
for x, y, z in os.walk(src_root))) for x, y, z in os.walk(src_root)))
t.commit() t.commit()
finally:
cluster.stop()
if __name__ == "__main__": if __name__ == "__main__":
......
# #
# Copyright (C) 2012-2016 Nexedi SA # Copyright (C) 2012-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -36,7 +36,8 @@ from neo.lib.protocol import CellStates, ClusterStates, Packets, \ ...@@ -36,7 +36,8 @@ from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64 from neo.lib.util import p64
from .. import expectedFailure, Patch from .. import expectedFailure, Patch
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, predictable_random from . import ConnectionFilter, NEOCluster, NEOThreadedTest, \
predictable_random, with_cluster
# dump log to stderr # dump log to stderr
""" """
...@@ -48,19 +49,14 @@ getLogger().setLevel(INFO) ...@@ -48,19 +49,14 @@ getLogger().setLevel(INFO)
def backup_test(partitions=1, upstream_kw={}, backup_kw={}): def backup_test(partitions=1, upstream_kw={}, backup_kw={}):
def decorator(wrapped): def decorator(wrapped):
def wrapper(self): def wrapper(self):
upstream = NEOCluster(partitions, **upstream_kw) with NEOCluster(partitions, **upstream_kw) as upstream:
try:
upstream.start() upstream.start()
backup = NEOCluster(partitions, upstream=upstream, **backup_kw) with NEOCluster(partitions, upstream=upstream,
try: **backup_kw) as backup:
backup.start() backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic() self.tic()
wrapped(self, backup) wrapped(self, backup)
finally:
backup.stop()
finally:
upstream.stop()
return wraps(wrapped)(wrapper) return wraps(wrapped)(wrapper)
return decorator return decorator
...@@ -99,15 +95,18 @@ class ReplicationTests(NEOThreadedTest): ...@@ -99,15 +95,18 @@ class ReplicationTests(NEOThreadedTest):
np = 7 np = 7
nr = 2 nr = 2
check_dict = dict.fromkeys(xrange(np)) check_dict = dict.fromkeys(xrange(np))
upstream = NEOCluster(partitions=np, replicas=nr-1, storage_count=3) with NEOCluster(partitions=np, replicas=nr-1, storage_count=3
try: ) as upstream:
upstream.start() upstream.start()
importZODB = upstream.importZODB() importZODB = upstream.importZODB()
importZODB(3) importZODB(3)
backup = NEOCluster(partitions=np, replicas=nr-1, storage_count=5, def delaySecondary(conn, packet):
upstream=upstream) if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode()
return not upstream_name and all(source_dict.itervalues())
# U -> B propagation # U -> B propagation
try: with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream) as backup:
backup.start() backup.start()
# Initialize & catch up. # Initialize & catch up.
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
...@@ -122,10 +121,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -122,10 +121,7 @@ class ReplicationTests(NEOThreadedTest):
# Check that a backup cluster can be restarted. # Check that a backup cluster can be restarted.
# (U -> B propagation after restart) # (U -> B propagation after restart)
finally:
backup.stop() backup.stop()
backup.reset()
try:
backup.start() backup.start()
self.assertEqual(backup.neoctl.getClusterState(), self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.BACKINGUP) ClusterStates.BACKINGUP)
...@@ -144,7 +140,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -144,7 +140,7 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(np*nr, self.checkBackup(backup)) self.assertEqual(np*nr, self.checkBackup(backup))
self.assertEqual(backup.neoctl.getClusterState(), self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
finally:
backup.stop() backup.stop()
# U -> B propagation with Mb -> Sb' (secondary, Replicate from primary Sb) delayed # U -> B propagation with Mb -> Sb' (secondary, Replicate from primary Sb) delayed
...@@ -153,16 +149,10 @@ class ReplicationTests(NEOThreadedTest): ...@@ -153,16 +149,10 @@ class ReplicationTests(NEOThreadedTest):
from neo.master import handlers as mhandler from neo.master import handlers as mhandler
#dbmanager.X = 1 #dbmanager.X = 1
#mhandler.X = 1 #mhandler.X = 1
def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode()
#print 'REPLICATE tid: %r, upstream_name: %r, source_dict: %r' % \
# (tid, upstream_name, source_dict) # (tid, upstream_name, source_dict)
#return True #return True
#return not upstream_name and all(source_dict.itervalues()) #return not upstream_name and all(source_dict.itervalues())
return upstream_name != "" return upstream_name != ""
backup.reset()
try:
backup.start() backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic() self.tic()
...@@ -189,16 +179,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -189,16 +179,13 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.cluster_state, ClusterStates.RUNNING) self.assertEqual(backup.cluster_state, ClusterStates.RUNNING)
self.assertEqual(np*nr, self.checkBackup(backup, self.assertEqual(np*nr, self.checkBackup(backup,
max_tid=backup.last_tid)) max_tid=backup.last_tid))
self.assertEqual(backup.last_tid, u_last_tid0) # truncated after recovery
self.assertEqual(np*nr, self.checkBackup(backup, max_tid=backup.last_tid)) self.assertEqual(np*nr, self.checkBackup(backup, max_tid=backup.last_tid))
finally:
backup.stop() backup.stop()
dbmanager.X = 0 dbmanager.X = 0
mhandler.X = 0 mhandler.X = 0
# S -> Sb (AddObject) delayed XXX not only S -> Sb: also Sb -> Sb' # S -> Sb (AddObject) delayed XXX not only S -> Sb: also Sb -> Sb'
backup.reset()
try:
backup.start() backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic() self.tic()
...@@ -206,8 +193,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -206,8 +193,7 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.backup_tid, upstream.last_tid) # B caught-up with U self.assertEqual(backup.backup_tid, upstream.last_tid) # B caught-up with U
u_last_tid1 = upstream.last_tid u_last_tid1 = upstream.last_tid
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: conn.getUUID() is None and f.delayAddObject(lambda conn: conn.getUUID() is None)
isinstance(packet, Packets.AddObject))
while not f.filtered_count: while not f.filtered_count:
importZODB(1) importZODB(1)
self.tic() self.tic()
...@@ -222,10 +208,6 @@ class ReplicationTests(NEOThreadedTest): ...@@ -222,10 +208,6 @@ class ReplicationTests(NEOThreadedTest):
max_tid=backup.last_tid)) max_tid=backup.last_tid))
self.assertEqual(backup.last_tid, u_last_tid1) # truncated after recovery self.assertEqual(backup.last_tid, u_last_tid1) # truncated after recovery
self.assertEqual(np*nr, self.checkBackup(backup, max_tid=backup.last_tid)) self.assertEqual(np*nr, self.checkBackup(backup, max_tid=backup.last_tid))
finally:
backup.stop()
finally:
upstream.stop()
@predictable_random() @predictable_random()
def testBackupNodeLost(self): def testBackupNodeLost(self):
...@@ -251,16 +233,14 @@ class ReplicationTests(NEOThreadedTest): ...@@ -251,16 +233,14 @@ class ReplicationTests(NEOThreadedTest):
node_list.remove(txn.getNode()) node_list.remove(txn.getNode())
node_list[0].getConnection().close() # disconnect Mb from M node_list[0].getConnection().close() # disconnect Mb from M
return orig(txn) return orig(txn)
upstream = NEOCluster(partitions=np, replicas=0, storage_count=1) with NEOCluster(partitions=np, replicas=0, storage_count=1) as upstream:
try:
upstream.start() upstream.start()
importZODB = upstream.importZODB(random=random) importZODB = upstream.importZODB(random=random)
# Do not start with an empty DB so that 'primary_dict' below is not # Do not start with an empty DB so that 'primary_dict' below is not
# empty on the first iteration. # empty on the first iteration.
importZODB(1) importZODB(1)
backup = NEOCluster(partitions=np, replicas=2, storage_count=4, with NEOCluster(partitions=np, replicas=2, storage_count=4,
upstream=upstream) upstream=upstream) as backup:
try:
backup.start() backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic() self.tic()
...@@ -293,10 +273,6 @@ class ReplicationTests(NEOThreadedTest): ...@@ -293,10 +273,6 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.backup_tid, upstream.last_tid) self.assertEqual(backup.backup_tid, upstream.last_tid)
self.assertEqual(backup.last_tid, upstream.last_tid) self.assertEqual(backup.last_tid, upstream.last_tid)
self.assertEqual(np*3, self.checkBackup(backup)) self.assertEqual(np*3, self.checkBackup(backup))
finally:
backup.stop()
finally:
upstream.stop()
@backup_test() @backup_test()
def testBackupUpstreamMasterDead(self, backup): def testBackupUpstreamMasterDead(self, backup):
...@@ -331,8 +307,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -331,8 +307,7 @@ class ReplicationTests(NEOThreadedTest):
def testBackupUpstreamStorageDead(self, backup): def testBackupUpstreamStorageDead(self, backup):
upstream = backup.upstream upstream = backup.upstream
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.delayInvalidateObjects() # delay M -> Mb
isinstance(packet, Packets.InvalidateObjects)) # delay M -> Mb
upstream.importZODB()(1) upstream.importZODB()(1)
count = [0] count = [0]
def _connect(orig, conn): def _connect(orig, conn):
...@@ -361,38 +336,29 @@ class ReplicationTests(NEOThreadedTest): ...@@ -361,38 +336,29 @@ class ReplicationTests(NEOThreadedTest):
""" """
upstream = backup.upstream upstream = backup.upstream
with upstream.master.filterConnection(upstream.storage) as f: with upstream.master.filterConnection(upstream.storage) as f:
f.add(lambda conn, packet: f.delayNotifyUnlockInformation()
isinstance(packet, Packets.NotifyUnlockInformation))
upstream.importZODB()(1) upstream.importZODB()(1)
self.tic() self.tic()
self.tic() self.tic()
# TODO check tids # TODO check tids
self.assertEqual(1, self.checkBackup(backup)) self.assertEqual(1, self.checkBackup(backup))
def testBackupEarlyInvalidation(self): @with_cluster()
def testBackupEarlyInvalidation(self, upstream):
""" """
The backup master must ignore notification before being fully The backup master must ignore notifications before being fully
initialized. initialized.
""" """
upstream = NEOCluster() with NEOCluster(upstream=upstream) as backup:
try:
upstream.start()
backup = NEOCluster(upstream=upstream)
try:
backup.start() backup.start()
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.delayAskPartitionTable(lambda conn:
isinstance(packet, Packets.AskPartitionTable) and
isinstance(conn.getHandler(), BackupHandler)) isinstance(conn.getHandler(), BackupHandler))
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
upstream.importZODB()(1) upstream.importZODB()(1)
self.tic() self.tic()
self.tic() self.tic()
self.assertTrue(backup.master.isAlive()) self.assertTrue(backup.master.is_alive())
finally:
backup.stop()
finally:
upstream.stop()
@backup_test() @backup_test()
def testBackupTid(self, backup): def testBackupTid(self, backup):
...@@ -408,16 +374,15 @@ class ReplicationTests(NEOThreadedTest): ...@@ -408,16 +374,15 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(last_tid, backup.backup_tid) self.assertEqual(last_tid, backup.backup_tid)
backup.stop() backup.stop()
importZODB(1) importZODB(1)
backup.reset()
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.delayAskFetchTransactions()
isinstance(packet, Packets.AskFetchTransactions))
backup.start() backup.start()
self.assertEqual(last_tid, backup.backup_tid) self.assertEqual(last_tid, backup.backup_tid)
self.tic() self.tic()
self.assertEqual(1, self.checkBackup(backup)) self.assertEqual(1, self.checkBackup(backup))
def testSafeTweak(self): @with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3)
def testSafeTweak(self, cluster):
""" """
Check that tweak always tries to keep a minimum of (replicas + 1) Check that tweak always tries to keep a minimum of (replicas + 1)
readable cells, otherwise we have less/no redundancy as long as readable cells, otherwise we have less/no redundancy as long as
...@@ -426,9 +391,8 @@ class ReplicationTests(NEOThreadedTest): ...@@ -426,9 +391,8 @@ class ReplicationTests(NEOThreadedTest):
def changePartitionTable(orig, *args): def changePartitionTable(orig, *args):
orig(*args) orig(*args)
sys.exit() sys.exit()
cluster = NEOCluster(partitions=3, replicas=1, storage_count=3)
s0, s1, s2 = cluster.storage_list s0, s1, s2 = cluster.storage_list
try: if 1:
cluster.start([s0, s1]) cluster.start([s0, s1])
s2.start() s2.start()
self.tic() self.tic()
...@@ -441,10 +405,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -441,10 +405,9 @@ class ReplicationTests(NEOThreadedTest):
self.tic() self.tic()
expectedFailure(self.assertEqual)(cluster.neoctl.getClusterState(), expectedFailure(self.assertEqual)(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
finally:
cluster.stop()
def testReplicationAbortedBySource(self): @with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3)
def testReplicationAbortedBySource(self, cluster):
""" """
Check that a feeding node aborts replication when its partition is Check that a feeding node aborts replication when its partition is
dropped, and that the out-of-date node finishes to replicate from dropped, and that the out-of-date node finishes to replicate from
...@@ -469,11 +432,12 @@ class ReplicationTests(NEOThreadedTest): ...@@ -469,11 +432,12 @@ class ReplicationTests(NEOThreadedTest):
# default for performance reason # default for performance reason
orig.im_self.dropPartitions((offset,)) orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list) return orig(ptid, cell_list)
np = 3 np = cluster.num_partitions
cluster = NEOCluster(partitions=np, replicas=1, storage_count=3)
s0, s1, s2 = cluster.storage_list s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects: for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
try: if cluster.started:
cluster.stop(1)
if 1:
cluster.start([s0]) cluster.start([s0])
cluster.populate([range(np*2)] * np) cluster.populate([range(np*2)] * np)
s1.start() s1.start()
...@@ -491,20 +455,17 @@ class ReplicationTests(NEOThreadedTest): ...@@ -491,20 +455,17 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(1, connection_filter.filtered_count) self.assertEqual(1, connection_filter.filtered_count)
self.tic() self.tic()
self.checkPartitionReplicated(s1, s2, offset) self.checkPartitionReplicated(s1, s2, offset)
finally:
cluster.stop()
cluster.reset(True)
def testClientReadingDuringTweak(self): @with_cluster(start_cluster=0, partitions=2, storage_count=2)
def testClientReadingDuringTweak(self, cluster):
# XXX: Currently, the test passes because data of dropped cells are not # XXX: Currently, the test passes because data of dropped cells are not
# deleted while the cluster is operational: this is only done # deleted while the cluster is operational: this is only done
# during the RECOVERING phase. But we'll want to be able to free # during the RECOVERING phase. But we'll want to be able to free
# disk space without service interruption, and for this the client # disk space without service interruption, and for this the client
# may have to retry reading data from the new cells. If s0 deleted # may have to retry reading data from the new cells. If s0 deleted
# all data for partition 1, the test would fail with a POSKeyError. # all data for partition 1, the test would fail with a POSKeyError.
cluster = NEOCluster(partitions=2, storage_count=2)
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
try: if 1:
cluster.start([s0]) cluster.start([s0])
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
oid = p64(1) oid = p64(1)
...@@ -518,16 +479,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -518,16 +479,13 @@ class ReplicationTests(NEOThreadedTest):
cluster.neoctl.enableStorageList([s1.uuid]) cluster.neoctl.enableStorageList([s1.uuid])
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.add(lambda conn, packet: m2c.delayNotifyPartitionChanges()
isinstance(packet, Packets.NotifyPartitionChanges))
self.tic() self.tic()
self.assertEqual('foo', storage.load(oid)[0]) self.assertEqual('foo', storage.load(oid)[0])
finally:
cluster.stop()
def testResumingReplication(self): @with_cluster(start_cluster=0, replicas=1)
cluster = NEOCluster(replicas=1) def testResumingReplication(self, cluster):
try: if 1:
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
cluster.start(storage_list=(s0,)) cluster.start(storage_list=(s0,))
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
...@@ -552,10 +510,43 @@ class ReplicationTests(NEOThreadedTest): ...@@ -552,10 +510,43 @@ class ReplicationTests(NEOThreadedTest):
s0.stop() s0.stop()
cluster.join((s0,)) cluster.join((s0,))
t0, t1, t2 = c.db().storage.iterator() t0, t1, t2 = c.db().storage.iterator()
finally:
cluster.stop()
def testCheckReplicas(self): @with_cluster(start_cluster=0, replicas=1)
def testReplicationBlockedByUnfinished(self, cluster):
if 1:
s0, s1 = cluster.storage_list
cluster.start(storage_list=(s0,))
storage = cluster.getZODBStorage()
oid = storage.new_oid()
tid = None
expected = 'UO'
for n in 1, 0:
# On first iteration, the transaction will block replication
# until tpc_finish.
# We do a second iteration as a quick check that the cluster
# remains functional after such a scenario.
txn = transaction.Transaction()
storage.tpc_begin(txn)
tid = storage.store(oid, tid, 'foo', '', txn)
if n:
# Start the outdated storage.
s1.start()
self.tic()
cluster.enableStorageList((s1,))
cluster.neoctl.tweakPartitionTable()
self.tic()
self.assertPartitionTable(cluster, expected)
storage.tpc_vote(txn)
self.assertPartitionTable(cluster, expected)
tid = storage.tpc_finish(txn)
self.tic() # replication resumes and ends
expected = 'UU'
self.assertPartitionTable(cluster, expected)
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING)
@with_cluster(partitions=5, replicas=2, storage_count=3)
def testCheckReplicas(self, cluster):
from neo.storage import checker from neo.storage import checker
def corrupt(offset): def corrupt(offset):
s0, s1, s2 = (storage_dict[cell.getUUID()] s0, s1, s2 = (storage_dict[cell.getUUID()]
...@@ -570,14 +561,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -570,14 +561,11 @@ class ReplicationTests(NEOThreadedTest):
for cell in row[1] for cell in row[1]
if cell[1] == CellStates.CORRUPTED])) if cell[1] == CellStates.CORRUPTED]))
self.assertEqual(expected_state, cluster.neoctl.getClusterState()) self.assertEqual(expected_state, cluster.neoctl.getClusterState())
np = 5 np = cluster.num_partitions
tid_count = np * 3 tid_count = np * 3
corrupt_tid = tid_count // 2 corrupt_tid = tid_count // 2
check_dict = dict.fromkeys(xrange(np)) check_dict = dict.fromkeys(xrange(np))
cluster = NEOCluster(partitions=np, replicas=2, storage_count=3) with Patch(checker, CHECK_COUNT=2):
try:
checker.CHECK_COUNT = 2
cluster.start()
cluster.populate([range(np*2)] * tid_count) cluster.populate([range(np*2)] * tid_count)
storage_dict = {x.uuid: x for x in cluster.storage_list} storage_dict = {x.uuid: x for x in cluster.storage_list}
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None) cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
...@@ -597,9 +585,6 @@ class ReplicationTests(NEOThreadedTest): ...@@ -597,9 +585,6 @@ class ReplicationTests(NEOThreadedTest):
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None) cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
self.tic() self.tic()
check(ClusterStates.RECOVERING, 4) check(ClusterStates.RECOVERING, 4)
finally:
checker.CHECK_COUNT = CHECK_COUNT
cluster.stop()
@backup_test() @backup_test()
def testBackupReadOnlyAccess(self, backup): def testBackupReadOnlyAccess(self, backup):
...@@ -630,7 +615,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -630,7 +615,7 @@ class ReplicationTests(NEOThreadedTest):
# commit new data to U # commit new data to U
txn = transaction.Transaction() txn = transaction.Transaction()
txn.note('test transaction %i' % i) txn.note(u'test transaction %s' % i)
Z.tpc_begin(txn) Z.tpc_begin(txn)
oid = Z.new_oid() oid = Z.new_oid()
Z.store(oid, None, '%s-%i' % (oid, i), '', txn) Z.store(oid, None, '%s-%i' % (oid, i), '', txn)
......
# #
# Copyright (C) 2015-2016 Nexedi SA # Copyright (C) 2015-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -36,12 +36,8 @@ class SSLTests(SSLMixin, test.Test): ...@@ -36,12 +36,8 @@ class SSLTests(SSLMixin, test.Test):
testDeadlockAvoidance = None # XXX why this fails? testDeadlockAvoidance = None # XXX why this fails?
testUndoConflict = testUndoConflictDuringStore = None # XXX why this fails? testUndoConflict = testUndoConflictDuringStore = None # XXX why this fails?
def testAbortConnection(self): def testAbortConnection(self, after_handshake=1):
for after_handshake in 1, 0: with self.getLoopbackConnection() as conn:
try:
conn.reset()
except UnboundLocalError:
conn = self.getLoopbackConnection()
conn.ask(Packets.Ping()) conn.ask(Packets.Ping())
connector = conn.getConnector() connector = conn.getConnector()
del connector.connect_limit[connector.addr] del connector.connect_limit[connector.addr]
...@@ -58,6 +54,9 @@ class SSLTests(SSLMixin, test.Test): ...@@ -58,6 +54,9 @@ class SSLTests(SSLMixin, test.Test):
conn.em.poll(1) conn.em.poll(1)
self.assertIs(conn.getConnector(), None) self.assertIs(conn.getConnector(), None)
def testAbortConnectionBeforeHandshake(self):
self.testAbortConnection(0)
class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests): class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests):
# do not repeat slowest tests with SSL # do not repeat slowest tests with SSL
testBackupNodeLost = testBackupNormalCase = None # TODO recheck testBackupNodeLost = testBackupNormalCase = None # TODO recheck
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -42,9 +42,11 @@ class ZODBTestCase(TestCase): ...@@ -42,9 +42,11 @@ class ZODBTestCase(TestCase):
self.open() self.open()
def _tearDown(self, success): def _tearDown(self, success):
self._storage.cleanup()
try: try:
if functional:
self.neo.stop() self.neo.stop()
else:
self.neo.stop(None)
except Exception: except Exception:
if success: if success:
raise raise
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -14,14 +14,23 @@ ...@@ -14,14 +14,23 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
import unittest import unittest
from ZODB.tests.BasicStorage import BasicStorage from ZODB.tests.BasicStorage import BasicStorage
from ZODB.tests.StorageTestBase import StorageTestBase from ZODB.tests.StorageTestBase import StorageTestBase
from . import ZODBTestCase from . import ZODBTestCase
from .. import Patch, threaded
class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage): class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage):
pass
def check_checkCurrentSerialInTransaction(self):
x = time.time() + 10
def tic_loop():
while time.time() < x:
yield
with Patch(threaded, TIC_LOOP=tic_loop()):
super(BasicTests, self).check_checkCurrentSerialInTransaction()
if __name__ == "__main__": if __name__ == "__main__":
suite = unittest.makeSuite(BasicTests, 'check') suite = unittest.makeSuite(BasicTests, 'check')
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
......
# #
# Copyright (C) 2009-2016 Nexedi SA # Copyright (C) 2009-2017 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License # modify it under the terms of the GNU General Public License
...@@ -30,19 +30,6 @@ class NEOZODBTests(ZODBTestCase, testZODB.ZODBTests): ...@@ -30,19 +30,6 @@ class NEOZODBTests(ZODBTestCase, testZODB.ZODBTests):
self._db.close() self._db.close()
super(NEOZODBTests, self)._tearDown(success) super(NEOZODBTests, self)._tearDown(success)
def checkMultipleUndoInOneTransaction(self):
# XXX: Upstream test accesses a persistent object outside a transaction
# (it should call transaction.begin() after the last commit)
# so disable our Connection.afterCompletion optimization.
# This should really be discussed on zodb-dev ML.
from ZODB.Connection import Connection
afterCompletion = Connection.__dict__['afterCompletion']
try:
Connection.afterCompletion = Connection.__dict__['newTransaction']
super(NEOZODBTests, self).checkMultipleUndoInOneTransaction()
finally:
Connection.afterCompletion = afterCompletion
if __name__ == "__main__": if __name__ == "__main__":
suite = unittest.makeSuite(NEOZODBTests, 'check') suite = unittest.makeSuite(NEOZODBTests, 'check')
unittest.main(defaultTest='suite') unittest.main(defaultTest='suite')
...@@ -14,7 +14,8 @@ Topic :: Database ...@@ -14,7 +14,8 @@ Topic :: Database
Topic :: Software Development :: Libraries :: Python Modules Topic :: Software Development :: Libraries :: Python Modules
""" """
if not os.path.exists('mock.py'): mock = 'neo/tests/mock.py'
if not os.path.exists(mock):
import cStringIO, hashlib,subprocess, urllib, zipfile import cStringIO, hashlib,subprocess, urllib, zipfile
x = 'pythonmock-0.1.0.zip' x = 'pythonmock-0.1.0.zip'
try: try:
...@@ -25,7 +26,7 @@ if not os.path.exists('mock.py'): ...@@ -25,7 +26,7 @@ if not os.path.exists('mock.py'):
mock_py = zipfile.ZipFile(cStringIO.StringIO(x)).read('mock.py') mock_py = zipfile.ZipFile(cStringIO.StringIO(x)).read('mock.py')
if hashlib.md5(mock_py).hexdigest() != '79f42f390678e5195d9ce4ae43bd18ec': if hashlib.md5(mock_py).hexdigest() != '79f42f390678e5195d9ce4ae43bd18ec':
raise EnvironmentError("MD5 checksum mismatch downloading 'mock.py'") raise EnvironmentError("MD5 checksum mismatch downloading 'mock.py'")
open('mock.py', 'w').write(mock_py) open(mock, 'w').write(mock_py)
zodb_require = ['ZODB3>=3.10dev'] zodb_require = ['ZODB3>=3.10dev']
...@@ -58,7 +59,7 @@ else: ...@@ -58,7 +59,7 @@ else:
setup( setup(
name = 'neoppod', name = 'neoppod',
version = '1.7.0', version = '1.7.1',
description = __doc__.strip(), description = __doc__.strip(),
author = 'Nexedi SA', author = 'Nexedi SA',
author_email = 'neo-dev@erp5.org', author_email = 'neo-dev@erp5.org',
...@@ -69,7 +70,6 @@ setup( ...@@ -69,7 +70,6 @@ setup(
long_description = ".. contents::\n\n" + open('README.rst').read() long_description = ".. contents::\n\n" + open('README.rst').read()
+ "\n" + open('CHANGELOG.rst').read(), + "\n" + open('CHANGELOG.rst').read(),
packages = find_packages(), packages = find_packages(),
py_modules = ['mock'],
entry_points = { entry_points = {
'console_scripts': [ 'console_scripts': [
# XXX: we'd like not to generate scripts for unwanted features # XXX: we'd like not to generate scripts for unwanted features
......
#!/bin/sh -e #!/usr/bin/env python
for COV in coverage python-coverage import os, re, sys, shutil
do type $COV && break from coverage.cmdline import main, CmdOptionParser
done >/dev/null 2>&1 || exit
sys.argv.insert(1, 'html')
del CmdOptionParser.get_prog_name
$COV html "$@"
# https://bitbucket.org/ned/coveragepy/issues/474/javascript-in-html-captures-all-keys # https://bitbucket.org/ned/coveragepy/issues/474/javascript-in-html-captures-all-keys
sed -i " shutil_copyfile = shutil.copyfile
/assign_shortkeys *=/s/$/return;/ def copyfile(src, dst):
/^ *\.bind('keydown',/s,^,//, if os.path.basename(dst) == 'coverage_html.js':
" htmlcov/coverage_html.js with open(src) as f:
js = f.read()
js = re.sub(r"(assign_shortkeys.*\{)", r"\1return;", js)
js = re.sub(r"^( *\.bind\('keydown',)", r"//\1", js, flags=re.M)
with open(dst, 'w') as f:
f.write(js)
else:
shutil_copyfile(src, dst)
shutil.copyfile = copyfile
main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment