Commit d61a34c0 authored by Kirill Smelkov's avatar Kirill Smelkov

.

parents f5cf9484 42fd89bc
......@@ -6,5 +6,5 @@
/build/
/dist/
/htmlcov/
/mock.py
/neo/tests/mock.py
/neoppod.egg-info/
Change History
==============
1.7.1 (2017-01-18)
------------------
- Replication:
- Fixed possibly wrong knowledge of cells' backup_tid when resuming backup.
In such case, 'neoctl print ids' gave false impression that the backup
cluster was up-to-date. This also resulted in an inconsistent database
when leaving backup mode before that the issue resolved by itself.
- Storage nodes now select the partition which is furthest behind. Previous
criterion was such that in case of high upstream activity, the backup could
even be stuck looping on a subset of partitions.
- Fixed replication of unfinished imported transactions.
- Fixed abort before vote, to free the storage space used by the transaction.
A new 'prune_orphan' neoctl command was added to delete unreferenced raw data
in the database.
- Removed short storage option -R to reset the db.
Help is reworded to clarify that --reset exits once done.
- The application receiving buffer size has been increased.
This speeds up transfer of big packets.
- The master raised AttributeError at exit during recovery.
- At startup, the importer storage backend connected twice to the destination
database.
1.7.0 (2016-12-19)
------------------
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -71,6 +71,7 @@ class AdminEventHandler(EventHandler):
setNodeState = forward_ask(Packets.SetNodeState)
checkReplicas = forward_ask(Packets.CheckReplicas)
truncate = forward_ask(Packets.Truncate)
repair = forward_ask(Packets.Repair)
class MasterEventHandler(EventHandler):
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -191,11 +191,6 @@ class Storage(BaseStorage.BaseStorage,
# seems used only by FileStorage
raise NotImplementedError
def cleanup(self):
# Used in unit tests to remove local database files.
# We have no such thing, so make this method a no-op.
pass
def close(self):
# WARNING: This does not handle the case where an app is shared by
# several Storage instances, but this is something that only
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -544,6 +544,8 @@ class Application(ThreadedApplication):
# A later serial has already been resolved, skip.
resolved_serial_set.update(conflict_serial_set)
continue
if self.last_tid < conflict_serial:
self.sync() # possible late invalidation (very rare)
try:
new_data = tryToResolveConflict(oid, conflict_serial,
serial, data)
......
#
# Copyright (C) 2011-2016 Nexedi SA
# Copyright (C) 2011-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2015-2016 Nexedi SA
# Copyright (C) 2015-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -145,7 +145,7 @@ class SocketConnector(object):
def receive(self, read_buf):
try:
data = self.socket.recv(4096)
data = self.socket.recv(65536)
except socket.error, e:
self._error('recv', e)
if data:
......@@ -155,6 +155,7 @@ class SocketConnector(object):
raise ConnectorException
def send(self):
# XXX: unefficient for big packets
msg = ''.join(self.queued)
if msg:
try:
......
#
# Copyright (C) 2010-2016 Nexedi SA
# Copyright (C) 2010-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os, thread
import os
from time import time
from select import epoll, EPOLLIN, EPOLLOUT, EPOLLERR, EPOLLHUP
from errno import EAGAIN, EEXIST, EINTR, ENOENT
......@@ -35,7 +35,6 @@ class EpollEventManager(object):
"""This class manages connections and events based on epoll(5)."""
_timeout = None
_trigger_exit = False
def __init__(self):
self.connection_dict = {}
......@@ -43,6 +42,7 @@ class EpollEventManager(object):
self.writer_set = set()
self.epoll = epoll()
self._pending_processing = []
self._trigger_list = []
self._trigger_fd, w = os.pipe()
os.close(w)
self._trigger_lock = Lock()
......@@ -231,9 +231,12 @@ class EpollEventManager(object):
if fd == self._trigger_fd:
with self._trigger_lock:
self.epoll.unregister(fd)
if self._trigger_exit:
del self._trigger_exit
thread.exit()
action_list = self._trigger_list
try:
while action_list:
action_list.pop(0)()
finally:
del action_list[:]
continue
if conn.readable():
self._addPendingConnection(conn)
......@@ -253,9 +256,9 @@ class EpollEventManager(object):
def setTimeout(self, *args):
self._timeout, self._on_timeout = args
def wakeup(self, exit=False):
def wakeup(self, *actions):
with self._trigger_lock:
self._trigger_exit |= exit
self._trigger_list += actions
try:
self.epoll.register(self._trigger_fd)
except IOError, e:
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2015-2016 Nexedi SA
# Copyright (C) 2015-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2015-2016 Nexedi SA
# Copyright (C) 2015-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -20,7 +20,7 @@ import traceback
from cStringIO import StringIO
from struct import Struct
PROTOCOL_VERSION = 8
PROTOCOL_VERSION = 9
# Size restrictions.
MIN_PACKET_SIZE = 10
......@@ -237,14 +237,10 @@ class Packet(object):
_id = None
poll_thread = False
def __init__(self, *args, **kw):
def __init__(self, *args):
assert self._code is not None, "Packet class not registered"
if args or kw:
args = list(args)
if args:
buf = StringIO()
# load named arguments
for item in self._fmt._items[len(args):]:
args.append(kw.get(item._name))
self._fmt.encode(buf.write, args)
self._body = buf.getvalue()
else:
......@@ -1176,6 +1172,25 @@ class SetClusterState(Packet):
_answer = Error
class Repair(Packet):
"""
Ask storage nodes to repair their databases. ctl -> A -> M
"""
_flags = map(PBoolean, ('dry_run',
# 'prune_orphan' (commented because it's the only option for the moment)
))
_fmt = PStruct('repair',
PFUUIDList,
*_flags)
_answer = Error
class RepairOne(Packet):
"""
See Repair. M -> S
"""
_fmt = PStruct('repair', *Repair._flags)
class ClusterInformation(Packet):
"""
Notify information about the cluster
......@@ -1685,6 +1700,10 @@ class Packets(dict):
TweakPartitionTable, ignore_when_closed=False)
SetClusterState = register(
SetClusterState, ignore_when_closed=False)
Repair = register(
Repair)
NotifyRepair = register(
RepairOne)
NotifyClusterInformation = register(
ClusterInformation)
AskClusterState, AnswerClusterState = register(
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -258,30 +258,34 @@ class PartitionTable(object):
partition on the line (here, line length is 11 to keep the docstring
width under 80 column).
"""
node_list = sorted(self.count_dict)
result = ['pt: node %u: %s, %s' % (i, uuid_str(node.getUUID()),
protocol.node_state_prefix_dict[node.getState()])
for i, node in enumerate(node_list)]
for i, node in enumerate(sorted(self.count_dict))]
append = result.append
line = []
max_line_len = 20 # XXX: hardcoded number of partitions per line
cell_state_dict = protocol.cell_state_prefix_dict
prefix = 0
prefix_len = int(math.ceil(math.log10(self.np)))
for offset, row in enumerate(self.partition_list):
for offset, row in enumerate(self.formatRows()):
if len(line) == max_line_len:
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
line = []
prefix = offset
line.append(row)
if line:
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
return result
def formatRows(self):
node_list = sorted(self.count_dict)
cell_state_dict = protocol.cell_state_prefix_dict
for row in self.partition_list:
if row is None:
line.append('X' * len(node_list))
yield 'X' * len(node_list)
else:
cell_dict = {x.getNode(): cell_state_dict[x.getState()]
for x in row}
line.append(''.join(cell_dict.get(x, '.') for x in node_list))
if line:
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
return result
yield ''.join(cell_dict.get(x, '.') for x in node_list)
def operational(self):
if not self.filled():
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading, weakref
import thread, threading, weakref
from . import logging
from .app import BaseApplication
from .connection import ConnectionClosed
......@@ -69,7 +69,7 @@ class ThreadedApplication(BaseApplication):
conn.close()
# Stop polling thread
logging.debug('Stopping %s', self.poll_thread)
self.em.wakeup(True)
self.em.wakeup(thread.exit)
else:
super(ThreadedApplication, self).close()
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
# -*- coding: utf-8 -*-
#
# Copyright (C) 2012-2016 Nexedi SA
# Copyright (C) 2012-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -64,6 +64,9 @@ class AdministrationHandler(MasterHandler):
for node in storage_list:
assert node.isPending(), node
if node.getConnection().isPending():
# XXX: It's wrong to use ProtocolError here. We must reply
# less aggressively because the admin has no way to
# know that there's still pending activity.
raise ProtocolError('Cannot exit recovery now: node %r is '
'entering cluster' % (node, ))
app._startup_allowed = True
......@@ -147,6 +150,19 @@ class AdministrationHandler(MasterHandler):
logging.warning('No node added')
conn.answer(Errors.Ack('No node added'))
def repair(self, conn, uuid_list, *args):
getByUUID = self.app.nm.getByUUID
node_list = []
for uuid in uuid_list:
node = getByUUID(uuid)
if node is None or not (node.isStorage() and node.isIdentified()):
raise ProtocolError("invalid storage node %s" % uuid_str(uuid))
node_list.append(node)
repair = Packets.NotifyRepair(*args)
for node in node_list:
node.notify(repair)
conn.answer(Errors.Ack(''))
def tweakPartitionTable(self, conn, uuid_list):
app = self.app
state = app.getClusterState()
......
#
# Copyright (C) 2012-2016 Nexedi SA
# Copyright (C) 2012-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -53,7 +53,10 @@ class StorageServiceHandler(BaseServiceHandler):
last_tid = app.pt.getBackupTid(min)
pending_list = ()
else:
last_tid = app.tm.getLastTID()
# This can't be app.tm.getLastTID() for imported transactions,
# because outdated cells must at least wait that they're locked
# at source side. For normal transactions, it would not matter.
last_tid = app.getLastTransaction()
pending_list = app.tm.registerForNotification(conn.getUUID())
p = Packets.AnswerUnfinishedTransactions(last_tid, pending_list)
conn.answer(p)
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -137,7 +137,7 @@ class RecoveryManager(MasterHandler):
logging.warning("Waiting for %r to come back."
" No other node has version %s of the partition table.",
node, self.target_ptid)
if node.getState() == new_state:
if node is None or node.getState() == new_state:
return
node.setState(new_state)
# broadcast to all so that admin nodes gets informed
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -36,6 +36,7 @@ action_dict = {
'tweak': 'tweakPartitionTable',
'drop': 'dropNode',
'kill': 'killNode',
'prune_orphan': 'pruneOrphan',
'truncate': 'truncate',
}
......@@ -146,20 +147,20 @@ class TerminalNeoCTL(object):
assert len(params) == 0
return self.neoctl.startCluster()
def _getStorageList(self, params):
if len(params) == 1 and params[0] == 'all':
node_list = self.neoctl.getNodeList(NodeTypes.STORAGE)
return [node[2] for node in node_list]
return map(self.asNode, params)
def enableStorageList(self, params):
"""
Enable cluster to make use of pending storages.
Parameters: all
node [node [...]]
node: if "all", add all pending storage nodes.
Parameters: node [node [...]]
node: if "all", add all pending storage nodes,
otherwise, the list of storage nodes to enable.
"""
if len(params) == 1 and params[0] == 'all':
node_list = self.neoctl.getNodeList(NodeTypes.STORAGE)
uuid_list = [node[2] for node in node_list]
else:
uuid_list = map(self.asNode, params)
return self.neoctl.enableStorageList(uuid_list)
return self.neoctl.enableStorageList(self._getStorageList(params))
def tweakPartitionTable(self, params):
"""
......@@ -189,6 +190,20 @@ class TerminalNeoCTL(object):
"""
return uuid_str(self.neoctl.getPrimary())
def pruneOrphan(self, params):
"""
Fix database by deleting unreferenced raw data
This can take a long time.
Parameters: dry_run node [node [...]]
dry_run: 0 or 1
node: if "all", ask all connected storage nodes to repair,
otherwise, only the given list of storage nodes.
"""
dry_run = "01".index(params.pop(0))
return self.neoctl.repair(self._getStorageList(params), dry_run)
def truncate(self, params):
"""
Truncate the database at the given tid.
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -172,6 +172,12 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response)
return response[1]
def repair(self, *args):
response = self.__ask(Packets.Repair(*args))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
raise RuntimeError(response)
return response[2]
def truncate(self, tid):
response = self.__ask(Packets.Truncate(tid))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
......
......@@ -2,7 +2,7 @@
#
# neoadmin - run an administrator node of NEO
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
......@@ -2,7 +2,7 @@
#
# neoadmin - run an administrator node of NEO
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
......@@ -2,7 +2,7 @@
#
# neolog - read a NEO log
#
# Copyright (C) 2012-2016 Nexedi SA
# Copyright (C) 2012-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
......@@ -2,7 +2,7 @@
#
# neomaster - run a master node of NEO
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
......@@ -2,7 +2,7 @@
#
# neomaster - run a master node of NEO
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
......@@ -2,7 +2,7 @@
#
# neostorage - run a storage node of NEO
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -24,14 +24,14 @@ from neo.lib.config import getServerOptionParser, ConfigurationManager
parser = getServerOptionParser()
parser.add_option('-u', '--uuid', help='specify an UUID to use for this ' \
'process. Previously assigned UUID takes precedence (ie ' \
'you should always use -R with this switch)')
parser.add_option('-R', '--reset', action = 'store_true',
help = 'remove an existing database if any')
'you should always use --reset with this switch)')
parser.add_option('-a', '--adapter', help = 'database adapter to use')
parser.add_option('-d', '--database', help = 'database connections string')
parser.add_option('-e', '--engine', help = 'database engine')
parser.add_option('-w', '--wait', help='seconds to wait for backend to be '
'available, before erroring-out (-1 = infinite)', type='float', default=0)
parser.add_option('--reset', action='store_true',
help='remove an existing database if any, and exit')
defaults = dict(
bind = '127.0.0.1',
......
#!/usr/bin/env python
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -23,6 +23,7 @@ import os
import re
from collections import Counter, defaultdict
from cStringIO import StringIO
from fnmatch import fnmatchcase
from unittest.runner import _WritelnDecorator
if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]):
......@@ -32,7 +33,8 @@ if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]):
coverage.neotestrunner = []
coverage.start()
from neo.tests import getTempDirectory, __dict__ as neo_tests__dict__
from neo.tests import getTempDirectory, NeoTestBase, Patch, \
__dict__ as neo_tests__dict__
from neo.tests.benchmark import BenchmarkRunner
# list of test modules
......@@ -64,7 +66,6 @@ UNIT_TEST_MODULES = [
# client application
'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler',
'neo.tests.client.testStorageHandler',
'neo.tests.client.testConnectionPool',
# light functional tests
'neo.tests.threaded.test',
......@@ -114,17 +115,33 @@ class NeoTestRunner(unittest.TextTestResult):
def wasSuccessful(self):
return not (self.failures or self.errors or self.unexpectedSuccesses)
def run(self, name, modules):
print '\n', name
def run(self, name, modules, only):
suite = unittest.TestSuite()
loader = unittest.defaultTestLoader
loader = unittest.TestLoader()
if only:
exclude = only[0] == '!'
test_only = only[exclude + 1:]
only = only[exclude]
if test_only:
def getTestCaseNames(testCaseClass):
tests = loader.__class__.getTestCaseNames(
loader, testCaseClass)
x = testCaseClass.__name__ + '.'
return [t for t in tests
if exclude != any(fnmatchcase(x + t, o)
for o in test_only)]
loader.getTestCaseNames = getTestCaseNames
if not only:
only = '*'
else:
print '\n', name
for test_module in modules:
# load prefix if supplied
if isinstance(test_module, tuple):
test_module, prefix = test_module
loader.testMethodPrefix = prefix
else:
loader.testMethodPrefix = 'test'
test_module, loader.testMethodPrefix = test_module
if only and not (exclude and test_only or
exclude != fnmatchcase(test_module, only)):
continue
try:
test_module = __import__(test_module, globals(), locals(), ['*'])
except ImportError, err:
......@@ -135,7 +152,11 @@ class NeoTestRunner(unittest.TextTestResult):
# NOTE it is also possible to run individual tests via `python -m unittest ...`
if 1 or test_module.__name__ == 'neo.tests.functional.testStorage':
suite.addTests(loader.loadTestsFromModule(test_module))
try:
suite.run(self)
finally:
# Workaround weird behaviour of Python.
self._previousTestClass = None
def startTest(self, test):
super(NeoTestRunner, self).startTest(test)
......@@ -203,7 +224,8 @@ class NeoTestRunner(unittest.TextTestResult):
for test in self.unexpectedSuccesses:
body.write("UNEXPECTED SUCCESS: %s\n" % self.getDescription(test))
self.stream = _WritelnDecorator(body)
self.printErrors()
self.printErrorList('ERROR', self.errors)
self.printErrorList('FAIL', self.failures)
return subject, body.getvalue()
class TestRunner(BenchmarkRunner):
......@@ -211,6 +233,11 @@ class TestRunner(BenchmarkRunner):
def add_options(self, parser):
parser.add_option('-c', '--coverage', action='store_true',
help='Enable coverage')
parser.add_option('-C', '--cov-unit', action='store_true',
help='Same as -c but output 1 file per test,'
' in the temporary test directory')
parser.add_option('-l', '--loop', type='int', default=1,
help='Repeat tests several times')
parser.add_option('-f', '--functional', action='store_true',
help='Functional tests')
parser.add_option('-u', '--unit', action='store_true',
......@@ -219,7 +246,12 @@ class TestRunner(BenchmarkRunner):
help='ZODB test suite running on a NEO')
parser.add_option('-v', '--verbose', action='store_true',
help='Verbose output')
parser.usage += " [[!] module [test...]]"
parser.format_epilog = lambda _: """
Positional:
Filter by given module/test. These arguments are shell patterns.
This implies -ufz if none of this option is passed.
Environment Variables:
NEO_TESTS_ADAPTER Default is SQLite for threaded clusters,
MySQL otherwise.
......@@ -241,27 +273,51 @@ Environment Variables:
""" % neo_tests__dict__
def load_options(self, options, args):
if not (options.unit or options.functional or options.zodb or args):
if options.coverage and options.cov_unit:
sys.exit('-c conflicts with -C')
if not (options.unit or options.functional or options.zodb):
if not args:
sys.exit('Nothing to run, please give one of -f, -u, -z')
options.unit = options.functional = options.zodb = True
return dict(
loop = options.loop,
unit = options.unit,
functional = options.functional,
zodb = options.zodb,
verbosity = 2 if options.verbose else 1,
coverage = options.coverage,
cov_unit = options.cov_unit,
only = args,
)
def start(self):
config = self._config
only = config.only
# run requested tests
runner = NeoTestRunner(config.title or 'Neo', config.verbosity)
if config.cov_unit:
from coverage import Coverage
cov_dir = runner.temp_directory + '/coverage'
os.mkdir(cov_dir)
@Patch(NeoTestBase)
def setUp(orig, self):
orig(self)
self.__coverage = Coverage('%s/%s' % (cov_dir, self.id()))
self.__coverage.start()
@Patch(NeoTestBase)
def _tearDown(orig, self, success):
self.__coverage.stop()
self.__coverage.save()
del self.__coverage
orig(self, success)
try:
for _ in xrange(config.loop):
if config.unit:
runner.run('Unit tests', UNIT_TEST_MODULES)
runner.run('Unit tests', UNIT_TEST_MODULES, only)
if config.functional:
runner.run('Functional tests', FUNC_TEST_MODULES)
runner.run('Functional tests', FUNC_TEST_MODULES, only)
if config.zodb:
runner.run('ZODB tests', ZODB_TEST_MODULES)
runner.run('ZODB tests', ZODB_TEST_MODULES, only)
except KeyboardInterrupt:
config['mail_to'] = None
traceback.print_exc()
......@@ -270,7 +326,13 @@ Environment Variables:
if coverage.neotestrunner:
coverage.combine(coverage.neotestrunner)
coverage.save()
if runner.dots:
print
# build report
if only and not config.mail_to:
runner._buildSummary = lambda *args: (
runner.__class__._buildSummary(runner, *args)[0], '')
self.build_report = str
self._successful = runner.wasSuccessful()
return runner.buildReport(self.add_status)
......
#!/usr/bin/env python
#
# Copyright (C) 2011-2016 Nexedi SA
# Copyright (C) 2011-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2012-2016 Nexedi SA
# Copyright (C) 2012-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2014-2016 Nexedi SA
# Copyright (C) 2014-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -281,7 +281,6 @@ class ImporterDatabaseManager(DatabaseManager):
def __init__(self, *args, **kw):
super(ImporterDatabaseManager, self).__init__(*args, **kw)
self.db._connect()
implements(self, """_getNextTID checkSerialRange checkTIDRange
deleteObject deleteTransaction dropPartitions getLastTID
getReplicationObjectList getTIDList nonempty""".split())
......@@ -305,10 +304,13 @@ class ImporterDatabaseManager(DatabaseManager):
getPartitionTable changePartitionTable
getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction
storeData _pruneData deferCommit
storeData getOrphanList _pruneData deferCommit
""".split():
setattr(self, x, getattr(self.db, x))
def _connect(self):
pass
def commit(self):
self.db.commit()
self._last_commit = time.time()
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,7 +14,9 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure
......@@ -54,6 +56,9 @@ class DatabaseManager(object):
ENGINES = ()
_deferred = 0
_duplicating = _repairing = None
def __init__(self, database, engine=None, wait=0):
"""
Initialize the object.
......@@ -64,22 +69,42 @@ class DatabaseManager(object):
% (engine, self.ENGINES))
self._engine = engine
self._wait = wait
self._deferred = 0
self._parse(database)
self._connect()
def __getattr__(self, attr):
if attr == "_getPartition":
np = self.getNumPartitions()
value = lambda x: x % np
else:
elif self._duplicating is None:
return self.__getattribute__(attr)
else:
value = getattr(self._duplicating, attr)
setattr(self, attr, value)
return value
@contextmanager
def _duplicate(self):
cls = self.__class__
db = cls.__new__(cls)
db._duplicating = self
try:
db._connect()
finally:
del db._duplicating
try:
yield db
finally:
db.close()
@abstract
def _parse(self, database):
"""Called during instantiation, to process database parameter."""
@abstract
def _connect(self):
"""Connect to the database"""
def setup(self, reset=0):
"""Set up a database, discarding existing data first if reset is True
"""
......@@ -415,6 +440,15 @@ class DatabaseManager(object):
is always the case at tpc_vote.
"""
@abstract
def getOrphanList(self):
"""Return the list of data id that is not referenced by the obj table
This is a repair method, and it's usually expensive.
There was a bug that did not free data of transactions that were
aborted before vote. This method is used to reclaim the wasted space.
"""
@abstract
def _pruneData(self, data_id_list):
"""To be overridden by the backend to delete any unreferenced data
......@@ -423,6 +457,8 @@ class DatabaseManager(object):
- not in self._uncommitted_data
- and not referenced by a fully-committed object (storage should have
an index or a refcount of all data ids of all objects)
The returned value is the number of deleted rows from the data table.
"""
@abstract
......@@ -588,6 +624,37 @@ class DatabaseManager(object):
self._setTruncateTID(None)
self.commit()
def repair(self, weak_app, dry_run):
t = self._repairing
if t and t.is_alive():
logging.error('already repairing')
return
def repair():
l = threading.Lock()
l.acquire()
def finalize():
try:
if data_id_list and not dry_run:
self.commit()
logging.info("repair: deleted %s orphan records",
self._pruneData(data_id_list))
self.commit()
finally:
l.release()
try:
with self._duplicate() as db:
data_id_list = db.getOrphanList()
logging.info("repair: found %s records that may be orphan",
len(data_id_list))
weak_app().em.wakeup(finalize)
l.acquire()
finally:
del self._repairing
logging.info("repair: done")
t = self._repairing = threading.Thread(target=repair)
t.daemon = 1
t.start()
@abstract
def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information,
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -56,12 +56,6 @@ class MySQLDatabaseManager(DatabaseManager):
_max_allowed_packet = 32769 * 1024
def __init__(self, *args, **kw):
super(MySQLDatabaseManager, self).__init__(*args, **kw)
self.conn = None
self._config = {}
self._connect()
def _parse(self, database):
""" Get the database credentials (username, password, database) """
# expected pattern : [user[:password]@]database[(~|.|/)unix_socket]
......@@ -93,6 +87,7 @@ class MySQLDatabaseManager(DatabaseManager):
logging.exception('Connection to MySQL failed, retrying.')
time.sleep(1)
self._active = 0
self._config = {}
conn = self.conn
conn.autocommit(False)
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
......@@ -475,6 +470,11 @@ class MySQLDatabaseManager(DatabaseManager):
_structLL = struct.Struct(">LL")
_unpackLL = _structLL.unpack
def getOrphanList(self):
return [x for x, in self.query(
"SELECT id FROM data LEFT JOIN obj ON (id=data_id)"
" WHERE data_id IS NULL")]
def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data)
if data_id_list:
......@@ -495,6 +495,8 @@ class MySQLDatabaseManager(DatabaseManager):
if bigid_list:
q("DELETE FROM bigdata WHERE id IN (%s)"
% ",".join(map(str, bigid_list)))
return len(id_list)
return 0
def _bigData(self, value):
bigdata_id, length = self._unpackLL(value)
......@@ -582,11 +584,8 @@ class MySQLDatabaseManager(DatabaseManager):
def abortTransaction(self, ttid):
ttid = util.u64(ttid)
q = self.query
sql = " FROM tobj WHERE tid=%s" % ttid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("DELETE" + sql)
q("DELETE FROM tobj WHERE tid=%s" % ttid)
q("DELETE FROM ttrans WHERE ttid=%s" % ttid)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
......
#
# Copyright (C) 2012-2016 Nexedi SA
# Copyright (C) 2012-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -69,11 +69,6 @@ class SQLiteDatabaseManager(DatabaseManager):
VERSION = 1
def __init__(self, *args, **kw):
super(SQLiteDatabaseManager, self).__init__(*args, **kw)
self._config = {}
self._connect()
def _parse(self, database):
self.db = os.path.expanduser(database)
......@@ -83,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self):
logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False)
self._config = {}
def _commit(self):
retry_if_locked(self.conn.commit)
......@@ -376,6 +372,11 @@ class SQLiteDatabaseManager(DatabaseManager):
packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext), u64(ttid)))
def getOrphanList(self):
return [x for x, in self.query(
"SELECT id FROM data LEFT JOIN obj ON (id=data_id)"
" WHERE data_id IS NULL")]
def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data)
if data_id_list:
......@@ -385,6 +386,8 @@ class SQLiteDatabaseManager(DatabaseManager):
% ",".join(map(str, data_id_list))))
q("DELETE FROM data WHERE id IN (%s)"
% ",".join(map(str, data_id_list)))
return len(data_id_list)
return 0
def storeData(self, checksum, data, compression,
_dup=unique_constraint_message("data", "hash", "compression")):
......@@ -439,11 +442,8 @@ class SQLiteDatabaseManager(DatabaseManager):
def abortTransaction(self, ttid):
args = util.u64(ttid),
q = self.query
sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("DELETE" + sql, args)
q("DELETE FROM tobj WHERE tid=?", args)
q("DELETE FROM ttrans WHERE ttid=?", args)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
......
#
# Copyright (C) 2010-2016 Nexedi SA
# Copyright (C) 2010-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import weakref
from neo.lib import logging
from neo.lib.handler import EventHandler
from neo.lib.exception import PrimaryFailure, StoppedOperation
......@@ -59,3 +60,7 @@ class BaseMasterHandler(EventHandler):
def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
def notifyRepair(self, conn, *args):
app = self.app
app.dm.repair(weakref.ref(app), *args)
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -77,7 +77,7 @@ class ClientOperationHandler(EventHandler):
checksum, data, data_serial, unlock)
except ConflictError, err:
# resolvable or not
conn.answer(Packets.AnswerStoreObject(1, oid, err.getTID()))
conn.answer(Packets.AnswerStoreObject(1, oid, err.tid))
except DelayedError:
# locked by a previous transaction, retry later
# If we are unlocking, we want queueEvent to raise
......@@ -194,8 +194,7 @@ class ClientOperationHandler(EventHandler):
self.app.tm.checkCurrentSerial(ttid, serial, oid)
except ConflictError, err:
# resolvable or not
conn.answer(Packets.AnswerCheckCurrentSerial(1, oid,
err.getTID()))
conn.answer(Packets.AnswerCheckCurrentSerial(1, oid, err.tid))
except DelayedError:
# locked by a previous transaction, retry later
try:
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -184,6 +184,10 @@ class StorageOperationHandler(EventHandler):
if app.tm.isLockedTid(max_tid):
# Wow, backup cluster is fast. Requested transactions are still in
# ttrans/ttobj so wait a little.
# This can also happen for internal replication, when
# NotifyTransactionFinished(M->S) + AskFetchTransactions(S->S)
# is faster than
# NotifyUnlockInformation(M->S)
app.queueEvent(self.askFetchTransactions, conn,
(partition, length, min_tid, max_tid, tid_list))
return
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -92,6 +92,7 @@ class Replicator(object):
def setUnfinishedTIDList(self, max_tid, ttid_list, offset_list):
"""This is a callback from MasterOperationHandler."""
assert self.ttid_set.issubset(ttid_list), (self.ttid_set, ttid_list)
if ttid_list:
self.ttid_set.update(ttid_list)
max_ttid = max(ttid_list)
......
#
# Copyright (C) 2010-2016 Nexedi SA
# Copyright (C) 2010-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -27,10 +27,7 @@ class ConflictError(Exception):
def __init__(self, tid):
Exception.__init__(self)
self._tid = tid
def getTID(self):
return self._tid
self.tid = tid
class DelayedError(Exception):
......@@ -47,76 +44,41 @@ class Transaction(object):
"""
Container for a pending transaction
"""
_tid = None
tid = None
has_trans = False
def __init__(self, uuid, ttid):
self._uuid = uuid
self._ttid = ttid
self._object_dict = {}
self._locked = False
self._birth = time()
self._checked_set = set()
self.uuid = uuid
# Consider using lists.
self.store_dict = {}
self.checked_set = set()
def __repr__(self):
return "<%s(ttid=%r, tid=%r, uuid=%r, locked=%r, age=%.2fs) at 0x%x>" \
return "<%s(tid=%r, uuid=%r, age=%.2fs) at 0x%x>" \
% (self.__class__.__name__,
dump(self._ttid),
dump(self._tid),
uuid_str(self._uuid),
self.isLocked(),
dump(self.tid),
uuid_str(self.uuid),
time() - self._birth,
id(self))
def addCheckedObject(self, oid):
assert oid not in self._object_dict, dump(oid)
self._checked_set.add(oid)
def getTTID(self):
return self._ttid
def setTID(self, tid):
assert self._tid is None, dump(self._tid)
assert tid is not None
self._tid = tid
def getTID(self):
return self._tid
def getUUID(self):
return self._uuid
def lock(self):
assert not self._locked
self._locked = True
def check(self, oid):
assert oid not in self.store_dict, dump(oid)
assert oid not in self.checked_set, dump(oid)
self.checked_set.add(oid)
def isLocked(self):
return self._locked
def addObject(self, oid, data_id, value_serial):
def store(self, oid, data_id, value_serial):
"""
Add an object to the transaction
"""
assert oid not in self._checked_set, dump(oid)
self._object_dict[oid] = oid, data_id, value_serial
assert oid not in self.checked_set, dump(oid)
self.store_dict[oid] = oid, data_id, value_serial
def delObject(self, oid):
def cancel(self, oid):
try:
return self._object_dict.pop(oid)[1]
return self.store_dict.pop(oid)[1]
except KeyError:
self._checked_set.remove(oid)
def getObject(self, oid):
return self._object_dict[oid]
def getObjectList(self):
return self._object_dict.values()
def getOIDList(self):
return self._object_dict.keys()
def getLockedOIDList(self):
return self._object_dict.keys() + list(self._checked_set)
self.checked_set.remove(oid)
class TransactionManager(object):
......@@ -145,7 +107,7 @@ class TransactionManager(object):
Return None if not found.
"""
try:
return self._transaction_dict[ttid].getObject(oid)
return self._transaction_dict[ttid].store_dict[oid]
except KeyError:
return None
......@@ -166,7 +128,7 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid]
except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid))
object_list = transaction.getObjectList()
object_list = transaction.store_dict.itervalues()
if txn_info:
user, desc, ext, oid_list = txn_info
txn_info = oid_list, user, desc, ext, False, ttid
......@@ -185,21 +147,20 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid]
except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid))
# remember that the transaction has been locked
transaction.lock()
assert transaction.tid is None, dump(transaction.tid)
assert ttid <= tid, (ttid, tid)
transaction.tid = tid
self._load_lock_dict.update(
dict.fromkeys(transaction.getOIDList(), ttid))
# commit transaction and remember its definitive TID
dict.fromkeys(transaction.store_dict, ttid))
if transaction.has_trans:
self._app.dm.lockTransaction(tid, ttid)
transaction.setTID(tid)
def unlock(self, ttid):
"""
Unlock transaction
"""
try:
tid = self._transaction_dict[ttid].getTID()
tid = self._transaction_dict[ttid].tid
except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid))
logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid))
......@@ -210,7 +171,7 @@ class TransactionManager(object):
def getFinalTID(self, ttid):
try:
return self._transaction_dict[ttid].getTID()
return self._transaction_dict[ttid].tid
except KeyError:
return self._app.dm.getFinalTID(ttid)
......@@ -233,7 +194,7 @@ class TransactionManager(object):
# drop the lock it held on this object, and drop object data for
# consistency.
del self._store_lock_dict[oid]
data_id = self._transaction_dict[ttid].delObject(oid)
data_id = self._transaction_dict[ttid].cancel(oid)
if data_id:
self._app.dm.pruneData((data_id,))
# Give a chance to pending events to take that lock now.
......@@ -245,7 +206,7 @@ class TransactionManager(object):
elif locking_tid == ttid:
# If previous store was an undo, next store must be based on
# undo target.
previous_serial = self._transaction_dict[ttid].getObject(oid)[2]
previous_serial = self._transaction_dict[ttid].store_dict[oid][2]
if previous_serial is None:
# XXX: use some special serial when previous store was not
# an undo ? Maybe it should just not happen.
......@@ -290,7 +251,7 @@ class TransactionManager(object):
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=True)
transaction.addCheckedObject(oid)
transaction.check(oid)
def storeObject(self, ttid, serial, oid, compression, checksum, data,
value_serial, unlock=False):
......@@ -307,7 +268,7 @@ class TransactionManager(object):
data_id = None
else:
data_id = self._app.dm.holdData(checksum, data, compression)
transaction.addObject(oid, data_id, value_serial)
transaction.store(oid, data_id, value_serial)
def abort(self, ttid, even_if_locked=False):
"""
......@@ -323,24 +284,28 @@ class TransactionManager(object):
return
logging.debug('Abort TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid]
has_load_lock = transaction.isLocked()
locked = transaction.tid
# if the transaction is locked, ensure we can drop it
if has_load_lock:
if locked:
if not even_if_locked:
return
else:
self._app.dm.abortTransaction(ttid)
dm = self._app.dm
dm.abortTransaction(ttid)
dm.releaseData([x[1] for x in transaction.store_dict.itervalues()],
True)
# unlock any object
for oid in transaction.getLockedOIDList():
if has_load_lock:
for oid in transaction.store_dict, transaction.checked_set:
for oid in oid:
if locked:
lock_ttid = self._load_lock_dict.pop(oid, None)
assert lock_ttid in (ttid, None), 'Transaction %s tried to ' \
'release the lock on oid %s, but it was held by %s' % (
dump(ttid), dump(oid), dump(lock_ttid))
assert lock_ttid in (ttid, None), ('Transaction %s tried'
' to release the lock on oid %s, but it was held by %s'
% (dump(ttid), dump(oid), dump(lock_ttid)))
write_locking_tid = self._store_lock_dict.pop(oid)
assert write_locking_tid == ttid, 'Inconsistent locking state: ' \
'aborting %s:%s but %s has the lock.' % (dump(ttid), dump(oid),
dump(write_locking_tid))
assert write_locking_tid == ttid, ('Inconsistent locking'
' state: aborting %s:%s but %s has the lock.'
% (dump(ttid), dump(oid), dump(write_locking_tid)))
# remove the transaction
del self._transaction_dict[ttid]
# some locks were released, some pending locks may now succeed
......@@ -352,37 +317,35 @@ class TransactionManager(object):
"""
logging.debug('Abort for %s', uuid_str(uuid))
# abort any non-locked transaction of this node
for transaction in self._transaction_dict.values():
if transaction.getUUID() == uuid:
self.abort(transaction.getTTID())
for ttid, transaction in self._transaction_dict.items():
if transaction.uuid == uuid:
self.abort(ttid)
def isLockedTid(self, tid):
for t in self._transaction_dict.itervalues():
if t.isLocked() and t.getTID() <= tid:
return True
return False
return any(None is not t.tid <= tid
for t in self._transaction_dict.itervalues())
def loadLocked(self, oid):
return oid in self._load_lock_dict
def log(self):
logging.info("Transactions:")
for txn in self._transaction_dict.values():
logging.info(' %r', txn)
for ttid, txn in self._transaction_dict.iteritems():
logging.info(' %s %r', dump(ttid), txn)
logging.info(' Read locks:')
for oid, ttid in self._load_lock_dict.items():
for oid, ttid in self._load_lock_dict.iteritems():
logging.info(' %r by %r', dump(oid), dump(ttid))
logging.info(' Write locks:')
for oid, ttid in self._store_lock_dict.items():
for oid, ttid in self._store_lock_dict.iteritems():
logging.info(' %r by %r', dump(oid), dump(ttid))
def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id):
lock_tid = self.getLockingTID(oid)
if lock_tid is not None:
transaction = self._transaction_dict[lock_tid]
if transaction.getObject(oid)[2] == orig_serial:
if transaction.store_dict[oid][2] == orig_serial:
if new_serial:
data_id = None
else:
self._app.dm.holdData(data_id)
transaction.addObject(oid, data_id, new_serial)
transaction.store(oid, data_id, new_serial)
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -29,7 +29,7 @@ import MySQLdb
import transaction
from functools import wraps
from mock import Mock
from .mock import Mock
from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import cached_property
......@@ -281,18 +281,6 @@ class NeoUnitTestBase(NeoTestBase):
def getNextTID(self, ltid=None):
return newTid(ltid)
def getPTID(self, i=None):
""" Return an integer PTID """
if i is None:
return random.randint(1, 2**64)
return i
def getOID(self, i=None):
""" Return a 8-bytes OID """
if i is None:
return os.urandom(8)
return pack('!Q', i)
def getFakeConnector(self, descriptor=None):
return Mock({
'__repr__': 'FakeConnector',
......@@ -321,18 +309,6 @@ class NeoUnitTestBase(NeoTestBase):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception was raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception was raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception was raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
......@@ -341,36 +317,19 @@ class NeoUnitTestBase(NeoTestBase):
""" Ensure the connection was aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)
def checkNotAborted(self, conn):
""" Ensure the connection was not aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 0)
def checkClosed(self, conn):
""" Ensure the connection was closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 1)
def checkNotClosed(self, conn):
""" Ensure the connection was not closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 0)
def _checkNoPacketSend(self, conn, method_id):
call_list = conn.mockGetNamedCalls(method_id)
self.assertEqual(len(call_list), 0, call_list)
self.assertEqual([], conn.mockGetNamedCalls(method_id))
def checkNoPacketSent(self, conn, check_notify=True, check_answer=True,
check_ask=True):
def checkNoPacketSent(self, conn):
""" check if no packet were sent """
if check_notify:
self._checkNoPacketSend(conn, 'notify')
if check_answer:
self._checkNoPacketSend(conn, 'answer')
if check_ask:
self._checkNoPacketSend(conn, 'ask')
def checkNoUUIDSet(self, conn):
""" ensure no UUID was set on the connection """
self.assertEqual(len(conn.mockGetNamedCalls('setUUID')), 0)
def checkUUIDSet(self, conn, uuid=None, check_intermediate=True):
""" ensure UUID was set on the connection """
calls = conn.mockGetNamedCalls('setUUID')
......@@ -384,151 +343,41 @@ class NeoUnitTestBase(NeoTestBase):
# in check(Ask|Answer|Notify)Packet we return the packet so it can be used
# in tests if more accurate checks are required
def checkErrorPacket(self, conn, decode=False):
def checkErrorPacket(self, conn):
""" Check if an error packet was answered """
calls = conn.mockGetNamedCalls("answer")
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), Packets.Error)
if decode:
return packet.decode()
return packet
def checkAskPacket(self, conn, packet_type, decode=False):
def checkAskPacket(self, conn, packet_type):
""" Check if an ask-packet with the right type is sent """
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
def checkAnswerPacket(self, conn, packet_type, decode=False):
def checkAnswerPacket(self, conn, packet_type):
""" Check if an answer-packet with the right type is sent """
calls = conn.mockGetNamedCalls('answer')
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False):
def checkNotifyPacket(self, conn, packet_type, packet_number=0):
""" Check if a notify-packet with the right type is sent """
calls = conn.mockGetNamedCalls('notify')
packet = calls.pop(packet_number).getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
def checkNotify(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.Notify, **kw)
def checkNotifyNodeInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation, **kw)
def checkSendPartitionTable(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.SendPartitionTable, **kw)
def checkStartOperation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.StartOperation, **kw)
def checkInvalidateObjects(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.InvalidateObjects, **kw)
def checkAbortTransaction(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.AbortTransaction, **kw)
def checkNotifyLastOID(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyLastOID, **kw)
def checkAnswerTransactionFinished(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionFinished, **kw)
def checkAnswerInformationLocked(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerInformationLocked, **kw)
def checkAskLockInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLockInformation, **kw)
def checkNotifyUnlockInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw)
def checkNotifyTransactionFinished(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)
def checkRequestIdentification(self, conn, **kw):
return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)
def checkAskPrimary(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskPrimary)
def checkAskUnfinishedTransactions(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskUnfinishedTransactions)
def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw)
def checkAskStoreObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreObject, **kw)
def checkAskStoreTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreTransaction, **kw)
def checkAskFinishTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskFinishTransaction, **kw)
def checkAskNewTid(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskBeginTransaction, **kw)
def checkAskLastIDs(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLastIDs, **kw)
def checkAcceptIdentification(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification, **kw)
def checkAnswerPrimary(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPrimary, **kw)
def checkAnswerLastIDs(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerLastIDs, **kw)
def checkAnswerUnfinishedTransactions(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerUnfinishedTransactions, **kw)
def checkAnswerObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObject, **kw)
def checkAnswerTransactionInformation(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionInformation, **kw)
def checkAnswerBeginTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction, **kw)
def checkAnswerTids(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)
def checkAnswerTidsFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)
def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
class Patch(object):
"""
......@@ -536,7 +385,7 @@ class Patch(object):
Usage:
with Patch(someObject, attrToPatch=newValue, [otherAttr=...]) as patch:
with Patch(someObject, attrToPatch=newValue) as patch:
[... code that runs with patches ...]
[... code that runs without patch ...]
......@@ -552,10 +401,29 @@ class Patch(object):
For patched callables, the new one receives the original value as first
argument.
Alternative usage:
@Patch(someObject)
def funcToPatch(orig, ...):
...
...
funcToPatch.revert()
The decorator applies the patch immediately.
"""
applied = False
def __new__(cls, patched, **patch):
if patch:
return object.__new__(cls)
def patch(func):
self = cls(patched, **{func.__name__: func})
self.apply()
return self
return patch
def __init__(self, patched, **patch):
(name, patch), = patch.iteritems()
self._patched = patched
......
from __future__ import print_function
import sys
import smtplib
import optparse
......@@ -34,13 +34,13 @@ class BenchmarkRunner(object):
parser.add_option('', '--repeat', type='int', default=1)
self.add_options(parser)
# check common arguments
options, self._args = parser.parse_args()
options, args = parser.parse_args()
if bool(options.mail_to) ^ bool(options.mail_from):
sys.exit('Need a sender and recipients to mail report')
mail_server = options.mail_server or MAIL_SERVER
# check specifics arguments
self._config = AttributeDict()
self._config.update(self.load_options(options, self._args))
self._config.update(self.load_options(options, args))
self._config.update(
title = options.title or self.__class__.__name__,
mail_from = options.mail_from,
......@@ -87,7 +87,7 @@ class BenchmarkRunner(object):
try:
s.sendmail(self._config.mail_from, recipient, mail)
except smtplib.SMTPRecipientsRefused:
print "Mail for %s fails" % recipient
print("Mail for %s fails" % recipient)
s.close()
def run(self):
......@@ -95,9 +95,10 @@ class BenchmarkRunner(object):
report = self.build_report(report)
if self._config.mail_to:
self.send_report(subject, report)
print subject
print
print report
print(subject)
if report:
print()
print(report, end='')
def was_successful(self):
return self._successful
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,10 +14,9 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
import unittest
from mock import Mock, ReturnValues
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
from ..mock import Mock
from ZODB.POSException import StorageTransactionError, ConflictError
from .. import NeoUnitTestBase, buildUrlFromString
from neo.client.app import Application
from neo.client.cache import test as testCache
......@@ -25,14 +24,6 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.lib.protocol import NodeTypes, Packets, Errors, UUID_NAMESPACES
from neo.lib.util import makeChecksum
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
def forget_queue(self, queue, flush_queue=True):
pass
def _getMasterConnection(self):
if self.master_conn is None:
self.last_tid = None
......@@ -43,11 +34,6 @@ def _getMasterConnection(self):
self.master_conn = Mock()
return self.master_conn
def getConnection(kw):
conn = Mock(kw)
conn.lock = threading.RLock()
return conn
def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None)
conn.ask(packet, **kw)
......@@ -82,6 +68,9 @@ class ClientApplicationTests(NeoUnitTestBase):
# some helpers
def checkAskObject(self, conn):
return self.checkAskPacket(conn, Packets.AskObject)
def _begin(self, app, txn, tid):
txn_context = app._txn_container.new(txn)
txn_context['ttid'] = tid
......@@ -95,11 +84,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.dispatcher = Mock({ })
return app
def getConnectionPool(self, conn_list):
return Mock({
'iterateForObject': conn_list,
})
def makeOID(self, value=None):
from random import randint
if value is None:
......@@ -107,24 +91,6 @@ class ClientApplicationTests(NeoUnitTestBase):
return '\00' * 7 + chr(value)
makeTID = makeOID
def getNodeCellConn(self, index=1, address=('127.0.0.1', 10000), uuid=None):
conn = getConnection({
'getAddress': address,
'__repr__': 'connection mock',
'getUUID': uuid,
})
node = Mock({
'__repr__': 'node%s' % index,
'__hash__': index,
'getConnection': conn,
})
cell = Mock({
'getAddress': 'FakeServer',
'getState': 'FakeState',
'getNode': node,
})
return (node, cell, conn)
def makeTransactionObject(self, user='u', description='d', _extension='e'):
class Transaction(object):
pass
......@@ -134,22 +100,8 @@ class ClientApplicationTests(NeoUnitTestBase):
txn._extension = _extension
return txn
def beginTransaction(self, app, tid):
packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0)
app.master_conn = Mock({ 'fakeReceived': packet, })
txn = self.makeTransactionObject()
app.tpc_begin(txn, tid=tid)
return txn
# common checks
def checkDispatcherRegisterCalled(self, app, conn):
calls = app.dispatcher.mockGetNamedCalls('register')
#self.assertEqual(len(calls), 1)
#self.assertEqual(calls[0].getParam(0), conn)
#self.assertTrue(isinstance(calls[0].getParam(2), Queue))
testCache = testCache
def test_load(self):
......@@ -207,47 +159,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkAskObject(conn)
self.assertEqual(len(cache._oid_dict[oid]), 2)
def test_tpc_begin(self):
app = self.getApp()
tid = self.makeTID()
txn = Mock()
# first, tid is supplied
self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0)
app.master_conn = Mock({
'getNextId': 1,
'fakeReceived': packet,
})
app.tpc_begin(transaction=txn, tid=tid)
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEqual(txn_context['ttid'], tid)
# next, the transaction already begin -> raise
self.assertRaises(StorageTransactionError, app.tpc_begin,
transaction=txn, tid=None)
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEqual(txn_context['ttid'], tid)
# start a transaction without tid
txn = Mock()
# no connection -> NEOStorageError (wait until connected to primary)
#self.assertRaises(NEOStorageError, app.tpc_begin, transaction=txn, tid=None)
# ask a tid to pmn
packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0)
app.master_conn = Mock({
'getNextId': 1,
'fakeReceived': packet,
})
app.tpc_begin(transaction=txn, tid=None)
self.checkAskNewTid(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn)
# check attributes
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEqual(txn_context['ttid'], tid)
def test_store1(self):
app = self.getApp()
oid = self.makeOID(11)
......@@ -265,111 +176,6 @@ class ClientApplicationTests(NeoUnitTestBase):
calls = app.pt.mockGetNamedCalls('getCellList')
self.assertEqual(len(calls), 1)
def test_store2(self):
app = self.getApp()
oid = self.makeOID(11)
tid = self.makeTID()
txn = self.makeTransactionObject()
# build conflicting state
txn_context = self._begin(app, txn, tid)
packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid)
packet.setId(0)
storage_address = ('127.0.0.1', 10020)
node, cell, conn = self.getNodeCellConn(address=storage_address)
app.pt = Mock()
app.cp = self.getConnectionPool([(node, conn)])
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
data_dict = txn_context['data_dict']
data_dict[oid] = 'BEFORE'
app.store(oid, tid, '', None, txn)
txn_context['queue'].put((conn, packet, {}))
self.assertRaises(ConflictError, app.waitStoreResponses, txn_context,
failing_tryToResolveConflict)
self.assertTrue(oid not in data_dict)
self.assertEqual(txn_context['object_stored_counter_dict'][oid], {})
self.checkAskStoreObject(conn)
def test_store3(self):
app = self.getApp()
uuid = self.getStorageUUID()
oid = self.makeOID(11)
tid = self.makeTID()
txn = self.makeTransactionObject()
# case with no conflict
txn_context = self._begin(app, txn, tid)
packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid)
packet.setId(0)
storage_address = ('127.0.0.1', 10020)
node, cell, conn = self.getNodeCellConn(address=storage_address,
uuid=uuid)
app.cp = self.getConnectionPool([(node, conn)])
app.pt = Mock()
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn)
self.checkAskStoreObject(conn)
txn_context['queue'].put((conn, packet, {}))
app.waitStoreResponses(txn_context, None) # no conflict in this test
self.assertEqual(txn_context['object_stored_counter_dict'][oid],
{tid: {uuid}})
self.assertEqual(txn_context['cache_dict'][oid], 'DATA')
self.assertFalse(oid in txn_context['data_dict'])
self.assertFalse(oid in txn_context['conflict_serial_dict'])
def test_tpc_abort3(self):
""" check that abort is sent to all nodes involved in the transaction """
app = self.getApp()
# three partitions/storages: one per object/transaction
app.num_partitions = num_partitions = 3
app.num_replicas = 0
tid = self.makeTID(num_partitions) # on partition 0
oid1 = self.makeOID(num_partitions + 1) # on partition 1, conflicting
oid2 = self.makeOID(num_partitions + 2) # on partition 2
# storage nodes
address1 = ('127.0.0.1', 10000); uuid1 = self.getMasterUUID()
address2 = ('127.0.0.1', 10001); uuid2 = self.getStorageUUID()
address3 = ('127.0.0.1', 10002); uuid3 = self.getStorageUUID()
app.nm.createMaster(address=address1, uuid=uuid1)
app.nm.createStorage(address=address2, uuid=uuid2)
app.nm.createStorage(address=address3, uuid=uuid3)
# answer packets
packet1 = Packets.AnswerStoreTransaction(tid=tid)
packet2 = Packets.AnswerStoreObject(conflicting=1, oid=oid1, serial=tid)
packet3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid)
[p.setId(i) for p, i in zip([packet1, packet2, packet3], range(3))]
conn1 = getConnection({'__repr__': 'conn1', 'getAddress': address1,
'fakeReceived': packet1, 'getUUID': uuid1})
conn2 = getConnection({'__repr__': 'conn2', 'getAddress': address2,
'fakeReceived': packet2, 'getUUID': uuid2})
conn3 = getConnection({'__repr__': 'conn3', 'getAddress': address3,
'fakeReceived': packet3, 'getUUID': uuid3})
node1 = Mock({'__repr__': 'node1', '__hash__': 1, 'getConnection': conn1})
node2 = Mock({'__repr__': 'node2', '__hash__': 2, 'getConnection': conn2})
node3 = Mock({'__repr__': 'node3', '__hash__': 3, 'getConnection': conn3})
# fake environment
app.cp = Mock({'getConnForCell': ReturnValues(conn2, conn3, conn1)})
app.cp = Mock({
'getConnForNode': ReturnValues(conn2, conn3, conn1),
'iterateForObject': [(node2, conn2), (node3, conn3), (node1, conn1)],
})
app.master_conn = Mock({'__hash__': 0})
txn = self.makeTransactionObject()
txn_context = self._begin(app, txn, tid)
app.dispatcher = Dispatcher()
# conflict occurs on storage 2
app.store(oid1, tid, 'DATA', None, txn)
app.store(oid2, tid, 'DATA', None, txn)
queue = txn_context['queue']
queue.put((conn2, packet2, {}))
queue.put((conn3, packet3, {}))
# vote fails as the conflict is not resolved, nothing is sent to storage 3
self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict)
# abort must be sent to storage 1 and 2
app.tpc_abort(txn)
self.checkAbortTransaction(conn2)
self.checkAbortTransaction(conn3)
def test_undo1(self):
# invalid transaction
app = self.getApp()
......@@ -383,169 +189,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn)
def _getAppForUndoTests(self, oid0, tid0, tid1, tid2):
app = self.getApp()
cell = Mock({
'getAddress': 'FakeServer',
'getState': 'FakeState',
})
app.pt = Mock({'getCellList': [cell]})
transaction_info = Packets.AnswerTransactionInformation(tid1, '', '',
'', False, (oid0, ))
transaction_info.setId(1)
conn = getConnection({
'getNextId': 1,
'fakeReceived': transaction_info,
'getAddress': ('127.0.0.1', 10020),
})
node = app.nm.createStorage(address=conn.getAddress())
app.cp = Mock({
'iterateForObject': [(node, conn)],
'getConnForCell': conn,
})
app.dispatcher = Dispatcher()
def load(oid, tid=None, before_tid=None):
self.assertEqual(oid, oid0)
return ({tid0: 'dummy', tid2: 'cdummy'}[tid], None, None)
app.load = load
store_marker = []
def _store(txn_context, oid, serial, data, data_serial=None,
unlock=False):
store_marker.append((oid, serial, data, data_serial))
app._store = _store
app.last_tid = self.getNextTID()
return app, conn, store_marker
def test_undoWithResolutionSuccess(self):
"""
Try undoing transaction tid1, which contains object oid.
Object oid previous revision before tid1 is tid0.
Transaction tid2 modified oid (and contains its data).
Undo is accepted, because conflict resolution succeeds.
"""
oid0 = self.makeOID(1)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)})
conn.ask = lambda p, queue=None, **kw: \
isinstance(p, Packets.AskObjectUndoSerial) and \
queue.put((conn, undo_serial, kw))
undo_serial.setId(2)
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
marker.append((oid, conflict_serial, serial, data, committedData))
return 'solved'
# The undo
txn = self.beginTransaction(app, tid=tid3)
app.undo(tid1, txn, tryToResolveConflict)
# Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mconflict_serial, tid2)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, 'dummy')
self.assertEqual(mcommittedData, 'cdummy')
moid, mserial, mdata, mdata_serial = store_marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mserial, tid2)
self.assertEqual(mdata, 'solved')
self.assertEqual(mdata_serial, None)
def test_undoWithResolutionFailure(self):
"""
Try undoing transaction tid1, which contains object oid.
Object oid previous revision before tid1 is tid0.
Transaction tid2 modified oid (and contains its data).
Undo is rejected with a raise, because conflict resolution fails.
"""
oid0 = self.makeOID(1)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)})
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError
# The undo
txn = self.beginTransaction(app, tid=tid3)
self.assertRaises(UndoError, app.undo, tid1, txn, tryToResolveConflict)
# Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mconflict_serial, tid2)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, 'dummy')
self.assertEqual(mcommittedData, 'cdummy')
self.assertEqual(len(store_marker), 0)
# Likewise, but conflict resolver raises a ConflictError.
# Still, exception raised by undo() must be UndoError.
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError
# The undo
self.assertRaises(UndoError, app.undo, tid1, txn, tryToResolveConflict)
# Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mconflict_serial, tid2)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, 'dummy')
self.assertEqual(mcommittedData, 'cdummy')
self.assertEqual(len(store_marker), 0)
def test_undo(self):
"""
Try undoing transaction tid1, which contains object oid.
Object oid previous revision before tid1 is tid0.
Undo is accepted, because tid1 is object's current revision.
"""
oid0 = self.makeOID(1)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
transaction_info = Packets.AnswerTransactionInformation(tid1, '', '',
'', False, (oid0, ))
transaction_info.setId(1)
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid1, tid0, True)})
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
# The undo
txn = self.beginTransaction(app, tid=tid3)
app.undo(tid1, txn, None) # no conflict resolution in this test
# Checking what happened
moid, mserial, mdata, mdata_serial = store_marker[0]
self.assertEqual(moid, oid0)
self.assertEqual(mserial, tid1)
self.assertEqual(mdata, None)
self.assertEqual(mdata_serial, tid0)
def test_connectToPrimaryNode(self):
# here we have three master nodes :
# the connection to the first will fail
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,12 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import time, unittest
from mock import Mock
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.client.app import ConnectionPool
from neo.client.exception import NEOStorageError
from neo.client import pool
from neo.lib.util import p64
class ConnectionPoolTests(NeoUnitTestBase):
......@@ -54,7 +55,7 @@ class ConnectionPoolTests(NeoUnitTestBase):
def test_iterateForObject_noStorageAvailable(self):
# no node available
oid = self.getOID(1)
oid = p64(1)
app = Mock()
app.pt = Mock({'getCellList': []})
pool = ConnectionPool(app)
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.client.handlers.master import PrimaryAnswersHandler
from neo.client.exception import NEOStorageError
......
#
# Copyright (C) 2009-2016 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.client.handlers.storage import StorageAnswersHandler
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
class StorageAnswerHandlerTests(NeoUnitTestBase):
def setUp(self):
super(StorageAnswerHandlerTests, self).setUp()
self.app = Mock()
self.handler = StorageAnswersHandler(self.app)
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict):
app = Mock({
'getHandlerData': {
'object_stored_counter_dict': object_stored_counter_dict,
'conflict_serial_dict': conflict_serial_dict,
'resolved_conflict_serial_dict': resolved_conflict_serial_dict,
}
})
return StorageAnswersHandler(app)
def test_answerStoreObject_1(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid = self.getNextTID()
# conflict
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(conflict_serial_dict[oid], {tid})
self.assertEqual(object_stored_counter_dict[oid], {})
self.assertFalse(oid in resolved_conflict_serial_dict)
# object was already accepted by another storage, raise
handler = self._getAnswerStoreObjectHandler({oid: {tid: {1}}}, {}, {})
self.assertRaises(NEOStorageError, handler.answerStoreObject,
conn, 1, oid, tid)
def test_answerStoreObject_2(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid = self.getNextTID()
tid_2 = self.getNextTID()
# resolution-pending conflict
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {oid: {tid}}
resolved_conflict_serial_dict = {}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(conflict_serial_dict[oid], {tid})
self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {})
# object was already accepted by another storage, raise
handler = self._getAnswerStoreObjectHandler({oid: {tid: {1}}},
{oid: {tid}}, {})
self.assertRaises(NEOStorageError, handler.answerStoreObject,
conn, 1, oid, tid)
# detected conflict is different, don't raise
self._getAnswerStoreObjectHandler({oid: {}}, {oid: {tid}}, {},
).answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_3(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid = self.getNextTID()
tid_2 = self.getNextTID()
# already-resolved conflict
# This case happens if a storage is answering a store action for which
# any other storage already answered (with same conflict) and any other
# storage accepted the resolved object.
object_stored_counter_dict = {oid: {tid_2: 1}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {oid: {tid}}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertFalse(oid in conflict_serial_dict)
self.assertEqual(resolved_conflict_serial_dict[oid], {tid})
self.assertEqual(object_stored_counter_dict[oid], {tid_2: 1})
# detected conflict is different, don't raise
self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {},
{oid: {tid}}).answerStoreObject(conn, 1, oid, tid_2)
def test_tidNotFound(self):
conn = self.getFakeConnection()
self.assertRaises(NEOStorageNotFoundError, self.handler.tidNotFound,
conn, 'message')
if __name__ == '__main__':
unittest.main()
#
# Copyright (C) 2011-2016 Nexedi SA
# Copyright (C) 2011-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2014-2016 Nexedi SA
# Copyright (C) 2014-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -424,6 +424,18 @@ class NEOCluster(object):
if not pdb.wait(test, MAX_START_TIME):
raise AssertionError('Timeout when starting cluster')
def startCluster(self):
# Even if the storage nodes are in the expected state, there may still
# be activity between them and the master, preventing the cluster to
# start.
def start(last_try):
try:
self.neoctl.startCluster()
except (NotReadyException, RuntimeError), e:
return False, e
return True, None
self.expectCondition(start)
def stop(self, clients=True):
# Suspend all processes to kill before actually killing them, so that
# nodes don't log errors because they get disconnected from other nodes:
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -267,8 +267,8 @@ class ClientTests(NEOFunctionalTest):
db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage
t1, t2 = transaction.Transaction(), transaction.Transaction()
t1.user = t2.user = 'user'
t1.description = t2.description = 'desc'
t1.user = t2.user = u'user'
t1.description = t2.description = u'desc'
oid = st1.new_oid()
rev = '\0' * 8
data = zodb_pickle(PObject())
......@@ -311,8 +311,8 @@ class ClientTests(NEOFunctionalTest):
db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage
t1, t2 = transaction.Transaction(), transaction.Transaction()
t1.user = t2.user = 'user'
t1.description = t2.description = 'desc'
t1.user = t2.user = u'user'
t1.description = t2.description = u'desc'
oid = st1.new_oid()
rev = '\0' * 8
data = zodb_pickle(PObject())
......@@ -330,8 +330,8 @@ class ClientTests(NEOFunctionalTest):
db3, conn3 = self.neo.getZODBConnection()
st3 = conn3._storage
t3 = transaction.Transaction()
t3.user = 'user'
t3.description = 'desc'
t3.user = u'user'
t3.description = u'desc'
st3.tpc_begin(t3)
# retrieve the last revision
data, serial = st3.load(oid)
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -31,7 +31,6 @@ class ClusterTests(NEOFunctionalTest):
def testClusterStartup(self):
neo = self.neo = NEOCluster(['test_neo1', 'test_neo2'], replicas=1,
temp_dir=self.getTempDirectory())
neoctl = neo.neoctl
neo.run()
# Runing a new cluster doesn't exit Recovery state.
s1, s2 = neo.getStorageProcessList()
......@@ -40,7 +39,7 @@ class ClusterTests(NEOFunctionalTest):
neo.expectClusterRecovering()
# When allowing cluster to exit Recovery, it reaches Running state and
# all present storage nodes reach running state.
neoctl.startCluster()
neo.startCluster()
neo.expectRunning(s1)
neo.expectRunning(s2)
neo.expectClusterRunning()
......@@ -64,7 +63,7 @@ class ClusterTests(NEOFunctionalTest):
neo.expectPending(s1)
neo.expectUnknown(s2)
neo.expectClusterRecovering()
neoctl.startCluster()
neo.startCluster()
neo.expectRunning(s1)
neo.expectUnknown(s2)
neo.expectClusterRunning()
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,14 +14,12 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
import unittest
import transaction
from persistent import Persistent
from . import NEOCluster, NEOFunctionalTest
from neo.lib.protocol import ClusterStates, NodeStates
from ZODB.tests.StorageTestBase import zodb_pickle
class PObject(Persistent):
......@@ -421,47 +419,5 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0)
def testReplicationBlockedByUnfinished(self):
# start a cluster with 1 of 2 storages and a replica
(started, stopped) = self.__setup(storage_number=2, replicas=1,
pending_number=1, partitions=10)
self.neo.expectRunning(started[0])
self.neo.expectStorageNotKnown(stopped[0])
self.neo.expectOudatedCells(number=0)
self.neo.expectClusterRunning()
self.__populate()
self.neo.expectOudatedCells(number=0)
# start a transaction that will block the end of the replication
db, conn = self.neo.getZODBConnection()
st = conn._storage
t = transaction.Transaction()
t.user = 'user'
t.description = 'desc'
oid = st.new_oid()
rev = '\0' * 8
data = zodb_pickle(PObject(42))
st.tpc_begin(t)
st.store(oid, rev, data, '', t)
# start the outdated storage
stopped[0].start()
self.neo.expectPending(stopped[0])
self.neo.neoctl.enableStorageList([stopped[0].getUUID()])
self.neo.neoctl.tweakPartitionTable()
self.neo.expectRunning(stopped[0])
self.neo.expectClusterRunning()
self.neo.expectAssignedCells(started[0], 10)
self.neo.expectAssignedCells(stopped[0], 10)
# wait a bit, replication must not happen. This hack is required
# because we cannot gather informations directly from the storages
time.sleep(10)
self.neo.expectOudatedCells(number=10)
# finish the transaction, the replication must happen and finish
st.tpc_vote(t)
st.tpc_finish(t)
self.neo.expectOudatedCells(number=0, timeout=10)
if __name__ == "__main__":
unittest.main()
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,8 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application
......@@ -62,6 +63,9 @@ class MasterClientHandlerTests(NeoUnitTestBase):
)
return uuid
def checkAnswerBeginTransaction(self, conn):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction)
# Tests
def test_07_askBeginTransaction(self):
tid1 = self.getNextTID()
......@@ -87,12 +91,12 @@ class MasterClientHandlerTests(NeoUnitTestBase):
calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_node, None)
args = self.checkAnswerBeginTransaction(conn, decode=True)
self.assertEqual(args, (tid1, ))
packet = self.checkAnswerBeginTransaction(conn)
self.assertEqual(packet.decode(), (tid1, ))
def test_08_askNewOIDs(self):
service = self.service
oid1, oid2 = self.getOID(1), self.getOID(2)
oid1, oid2 = p64(1), p64(2)
self.app.tm.setLastOID(oid1)
# client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
......@@ -136,7 +140,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.setStorageReady(storage_uuid)
self.assertTrue(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, ttid, (), ())
self.checkAskLockInformation(storage_conn)
self.checkAskPacket(storage_conn, Packets.AskLockInformation)
self.assertEqual(len(self.app.tm.registerForNotification(storage_uuid)), 1)
txn = self.app.tm[ttid]
pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0]
......@@ -170,8 +174,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack,
decode=True)[0]
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0]
self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id)
......@@ -183,8 +186,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack,
decode=True)[0]
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0]
self.assertFalse(status)
if __name__ == '__main__':
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,10 +15,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from neo.lib import protocol
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.election import ClientElectionHandler, \
ServerElectionHandler
from neo.master.app import Application
......@@ -48,6 +48,9 @@ class MasterClientElectionTestBase(NeoUnitTestBase):
node.setConnection(conn)
return (node, conn)
def checkAcceptIdentification(self, conn):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification)
class MasterClientElectionTests(MasterClientElectionTestBase):
def setUp(self):
......@@ -91,7 +94,7 @@ class MasterClientElectionTests(MasterClientElectionTestBase):
self.election.connectionCompleted(conn)
self._checkUnconnected(node)
self.assertTrue(node.isUnknown())
self.checkRequestIdentification(conn)
self.checkAskPacket(conn, Packets.RequestIdentification)
def _setNegociating(self, node):
self._checkUnconnected(node)
......@@ -252,9 +255,8 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.election.requestIdentification(conn,
NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID())
args = self.checkAcceptIdentification(conn, decode=True)
(node_type, uuid, partitions, replicas, new_uuid, primary_uuid,
master_list) = args
master_list) = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node.getUUID(), new_uuid)
self.assertNotEqual(node.getUUID(), uuid)
......@@ -290,7 +292,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
None,
)
node_type, uuid, partitions, replicas, _peer_uuid, primary, \
master_list = self.checkAcceptIdentification(conn, decode=True)
master_list = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node_type, NodeTypes.MASTER)
self.assertEqual(uuid, self.app.uuid)
self.assertEqual(partitions, self.app.pt.getPartitions())
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -16,6 +16,7 @@
import unittest
from .. import NeoUnitTestBase
from neo.lib.protocol import Packets
from neo.master.app import Application
class MasterAppTests(NeoUnitTestBase):
......@@ -31,6 +32,9 @@ class MasterAppTests(NeoUnitTestBase):
self.app.close()
NeoUnitTestBase._tearDown(self, success)
def checkNotifyNodeInformation(self, conn):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation)
def test_06_broadcastNodeInformation(self):
# defined some nodes to which data will be send
master_uuid = self.getMasterUUID()
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, Packets
from neo.master.handlers.storage import StorageServiceHandler
......@@ -71,10 +71,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.checkNoPacketSent(client_conn)
self.assertEqual(self.app.packing[2], {conn2.getUUID()})
self.service.answerPack(conn2, False)
status = self.checkAnswerPacket(client_conn, Packets.AnswerPack,
decode=True)[0]
packet = self.checkAnswerPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id
self.assertTrue(status)
self.assertTrue(packet.decode()[0])
self.assertEqual(self.app.packing, None)
if __name__ == '__main__':
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from struct import pack
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,12 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock, ReturnValues
from collections import deque
from ..mock import Mock, ReturnValues
from .. import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_TID, INVALID_OID, Packets, LockState
from neo.lib.util import p64
from neo.lib.protocol import INVALID_TID, Packets, LockState
class StorageClientHandlerTests(NeoUnitTestBase):
......@@ -30,11 +30,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.app.transaction_dict = {}
self.app.store_lock_dict = {}
self.app.load_lock_dict = {}
self.app.event_queue = deque()
self.app.event_queue_dict = {}
self.app.tm = Mock({'__contains__': True})
# handler
self.operation = ClientOperationHandler(self.app)
......@@ -59,19 +54,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn)
def test_24_askObject1(self):
# delayed response
conn = self._getConnection()
self.app.dm = Mock()
self.app.tm = Mock({'loadLocked': True})
self.app.load_lock_dict[INVALID_OID] = object()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID,
serial=INVALID_TID, tid=INVALID_TID)
self.assertEqual(len(self.app.event_queue), 1)
self.checkNoPacketSent(conn)
self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
def test_25_askTIDs1(self):
# invalid offsets => error
app = self.app
......@@ -91,7 +73,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn)
self.checkAnswerPacket(conn, Packets.AnswerTIDs)
def test_26_askObjectHistory1(self):
# invalid offsets => error
......@@ -108,8 +90,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
ltid = self.getNextTID()
undone_tid = self.getNextTID()
# Keep 2 entries here, so we check findUndoTID is called only once.
oid_list = [self.getOID(1), self.getOID(2)]
obj2_data = [] # Marker
oid_list = map(p64, (1, 2))
self.app.tm = Mock({
'getObjectFromTransaction': None,
})
......@@ -134,7 +115,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
conn = self._getConnection()
self.operation.askHasLock(conn, tid_1, oid)
p_oid, p_status = self.checkAnswerPacket(conn,
Packets.AnswerHasLock, decode=True)
Packets.AnswerHasLock).decode()
self.assertEqual(oid, p_oid)
self.assertEqual(status, p_status)
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from collections import deque
from .. import NeoUnitTestBase
from neo.storage.app import Application
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.storage.app import Application
from neo.lib.protocol import CellStates
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -103,20 +103,19 @@ class StorageDBTests(NeoUnitTestBase):
def test_15_PTID(self):
db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1))
self.checkConfigEntry(db.getPTID, db.setPTID, 1)
def test_getPartitionTable(self):
db = self.getDB()
ptid = self.getPTID(1)
uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
cell1 = (0, uuid1, CellStates.OUT_OF_DATE)
cell2 = (1, uuid1, CellStates.UP_TO_DATE)
db.changePartitionTable(ptid, [cell1, cell2], 1)
db.changePartitionTable(1, [cell1, cell2], 1)
result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2})
def getOIDs(self, count):
return map(self.getOID, xrange(count))
return map(p64, xrange(count))
def getTIDs(self, count):
tid_list = [self.getNextTID()]
......@@ -198,7 +197,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_setPartitionTable(self):
db = self.getDB()
ptid = self.getPTID(1)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
......@@ -220,7 +219,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_changePartitionTable(self):
db = self.getDB()
ptid = self.getPTID(1)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
......@@ -301,7 +300,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_deleteRange(self):
np = 4
self.setNumPartitions(np)
t1, t2, t3 = map(self.getOID, (1, 2, 3))
t1, t2, t3 = map(p64, (1, 2, 3))
oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list)
......@@ -339,7 +338,7 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getObjectHistory(self):
oid = self.getOID(1)
oid = p64(1)
tid1, tid2, tid3 = self.getTIDs(3)
txn1, objs1 = self.getTransaction([oid])
txn2, objs2 = self.getTransaction([oid])
......@@ -362,7 +361,7 @@ class StorageDBTests(NeoUnitTestBase):
def _storeTransactions(self, count):
# use OID generator to know result of tid % N
tid_list = self.getOIDs(count)
oid = self.getOID(1)
oid = p64(1)
for tid in tid_list:
txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn, False)
......@@ -446,7 +445,7 @@ class StorageDBTests(NeoUnitTestBase):
tid3 = self.getNextTID()
tid4 = self.getNextTID()
tid5 = self.getNextTID()
oid1 = self.getOID(1)
oid1 = p64(1)
foo = db.holdData("3" * 20, 'foo', 0)
bar = db.holdData("4" * 20, 'bar', 0)
db.releaseData((foo, bar))
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -16,7 +16,7 @@
import unittest
from MySQLdb import OperationalError
from mock import Mock
from ..mock import Mock
from neo.lib.exception import DatabaseFailure
from neo.lib.util import p64
from .. import DB_PREFIX, DB_SOCKET, DB_USER
......@@ -96,9 +96,11 @@ class StorageMySQLdbTests(StorageDBTests):
assert len(x) + EXTRA == self.db._max_allowed_packet
self.assertRaises(DatabaseFailure, self.db.query, x + ' ')
self.db.query(x)
# Reconnection cleared the cache of the config table,
# so fill it again with required values before we patch query().
self.db.getNumPartitions()
# Check MySQLDatabaseManager._max_allowed_packet
query_list = []
query = self.db.query
self.db.query = lambda query: query_list.append(EXTRA + len(query))
self.assertEqual(2, max(len(self.db.escape(chr(x)))
for x in xrange(256)))
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2010-2016 Nexedi SA
# Copyright (C) 2010-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,21 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.storage.transactions import Transaction, TransactionManager
from neo.lib.util import p64
from neo.storage.transactions import TransactionManager
class TransactionTests(NeoUnitTestBase):
def testLock(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
self.assertFalse(txn.isLocked())
txn.lock()
self.assertTrue(txn.isLocked())
# disallow lock more than once
self.assertRaises(AssertionError, txn.lock)
class TransactionManagerTests(NeoUnitTestBase):
def setUp(self):
......@@ -46,7 +37,7 @@ class TransactionManagerTests(NeoUnitTestBase):
def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID()
oid = self.getOID(1)
oid = p64(1)
orig_serial = self.getNextTID()
uuid = self.getClientUUID()
locking_serial = self.getNextTID()
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -18,7 +18,7 @@ import unittest
from . import NeoUnitTestBase
from neo.storage.app import Application
from neo.lib.bootstrap import BootstrapManager
from neo.lib.protocol import NodeTypes
from neo.lib.protocol import NodeTypes, Packets
class BootstrapManagerTests(NeoUnitTestBase):
......@@ -46,7 +46,7 @@ class BootstrapManagerTests(NeoUnitTestBase):
conn = self.getFakeConnection(address=address)
self.bootstrap.current = self.app.nm.createMaster(address=address)
self.bootstrap.connectionCompleted(conn)
self.checkRequestIdentification(conn)
self.checkAskPacket(conn, Packets.RequestIdentification)
def testHandleNotReady(self):
# the primary is not ready
......
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -16,7 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from time import time
from mock import Mock
from .mock import Mock
from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ClientConnection, \
MTClientConnection, CRITICAL_TIMEOUT
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from .mock import Mock
from . import NeoUnitTestBase
from neo.lib.handler import EventHandler
from neo.lib.protocol import PacketMalformedError, UnexpectedPacketError, \
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -16,7 +16,7 @@
import shutil
import unittest
from mock import Mock
from .mock import Mock
from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.node import Node, MasterDB
from . import NeoUnitTestBase, getTempDirectory
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2006-2016 Nexedi SA
# Copyright (C) 2006-2017 Nexedi SA
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2011-2016 Nexedi SA
# Copyright (C) 2011-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -22,10 +22,9 @@ from collections import deque
from ConfigParser import SafeConfigParser
from contextlib import contextmanager
from itertools import count
from functools import wraps
from thread import get_ident
from functools import partial, wraps
from zlib import decompress
from mock import Mock
from ..mock import Mock
import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app
......@@ -36,13 +35,14 @@ from neo.lib.connection import BaseConnection, \
from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
TIC_LOOP = xrange(1000)
# T1 T2
......@@ -273,7 +273,7 @@ class TestSerialized(Serialized): # NOTE used only in .NeoCTL
def poll(self, timeout):
if timeout:
for x in xrange(1000):
for x in TIC_LOOP:
r = self._epoll.poll(0)
if r:
return r
......@@ -339,6 +339,7 @@ class ServerNode(Node):
master_nodes = kw.get('master_nodes', cluster.master_nodes)
name = kw.get('name', cluster.name)
port = address[1]
if address is not BIND:
self._node_list[port] = weakref.proxy(self)
self._init_args = init_args = kw.copy()
init_args['cluster'] = cluster
......@@ -386,7 +387,7 @@ class ServerNode(Node):
raise ConnectorException
def stop(self):
self.em.wakeup(True)
self.em.wakeup(thread.exit)
class AdminApplication(ServerNode, neo.admin.app.Application):
pass
......@@ -484,7 +485,7 @@ class ConnectionFilter(object):
filtered_count = 0
filter_list = []
filter_queue = weakref.WeakKeyDictionary()
filter_queue = weakref.WeakKeyDictionary() # XXX: see the end of __new__
lock = threading.RLock()
_addPacket = Connection._addPacket
......@@ -500,13 +501,16 @@ class ConnectionFilter(object):
queue = cls.filter_queue[conn]
except KeyError:
for self in cls.filter_list:
if self(conn, packet):
if self._test(conn, packet):
self.filtered_count += 1
break
else:
return cls._addPacket(conn, packet)
cls.filter_queue[conn] = queue = deque()
p = packet.__new__(packet.__class__)
p = packet.__class__
logging.debug("queued %s#0x%04x for %s",
p.__name__, packet.getId(), conn)
p = packet.__new__(p)
p.__dict__.update(packet.__dict__)
queue.append(p)
Connection._addPacket = _addPacket
......@@ -517,10 +521,13 @@ class ConnectionFilter(object):
del cls.filter_list[-1:]
if not cls.filter_list:
Connection._addPacket = cls._addPacket.im_func
# Retry even in case of exception, at least to avoid leaks in
# filter_queue. Sometimes, WeakKeyDictionary only does the job
# only an explicit call to gc.collect.
with cls.lock:
cls._retry()
def __call__(self, conn, packet):
def _test(self, conn, packet):
if not self.conn_list or conn in self.conn_list:
for filter in self.filter_dict:
if filter(conn, packet):
......@@ -533,13 +540,16 @@ class ConnectionFilter(object):
while queue:
packet = queue.popleft()
for self in cls.filter_list:
if self(conn, packet):
if self._test(conn, packet):
queue.appendleft(packet)
break
else:
if conn.isClosed():
return
cls._addPacket(conn, packet)
# Use the thread that created the packet to reinject it,
# to avoid a race condition on Connector.queued.
conn.em.wakeup(lambda conn=conn, packet=packet:
conn.isClosed() or cls._addPacket(conn, packet))
continue
break
else:
......@@ -566,6 +576,22 @@ class ConnectionFilter(object):
def __contains__(self, filter):
return filter in self.filter_dict
def byPacket(self, packet_type, *args):
patches = []
other = []
for x in args:
(patches if isinstance(x, Patch) else other).append(x)
def delay(conn, packet):
return isinstance(packet, packet_type) and False not in (
callback(conn) for callback in other)
self.add(delay, *patches)
return delay
def __getattr__(self, attr):
if attr.startswith('delay'):
return partial(self.byPacket, getattr(Packets, attr[5:]))
return self.__getattribute__(attr)
class NEOCluster(object):
SSL = None
......@@ -576,9 +602,11 @@ class NEOCluster(object):
def _lock(blocking=True):
if blocking:
logging.info('<SimpleQueue>._lock.acquire()')
while not lock(False):
Serialized.tic(step=1, quiet=True)
for i in TIC_LOOP:
if lock(False):
return True
Serialized.tic(step=1, quiet=True)
raise Exception("tic is looping forever")
return lock(False)
self._lock = _lock
_patches = (
......@@ -619,6 +647,8 @@ class NEOCluster(object):
patch.revert()
Serialized.stop()
started = False
def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
storage_count=None, db_list=None, clear_databases=True,
......@@ -627,6 +657,7 @@ class NEOCluster(object):
self.name = 'neo_%s' % self._allocate('name',
lambda: random.randint(0, 100))
self.compress = compress
self.num_partitions = partitions
master_list = [MasterApplication.newAddress()
for _ in xrange(master_count)]
self.master_nodes = ' '.join('%s:%s' % x for x in master_list)
......@@ -670,7 +701,6 @@ class NEOCluster(object):
self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
for x in db_list]
self.admin_list = [AdminApplication(**kw)]
self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
def __repr__(self):
return "<%s(%s) at 0x%x>" % (self.__class__.__name__,
......@@ -720,18 +750,16 @@ class NEOCluster(object):
return master
###
def reset(self, clear_database=False):
for node_type in 'master', 'storage', 'admin':
kw = {}
if node_type == 'storage':
kw['clear_database'] = clear_database
for node in getattr(self, node_type + '_list'):
node.resetNode(**kw)
self.neoctl.close()
self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
def __enter__(self):
return self
def __exit__(self, t, v, tb):
self.stop(None)
def start(self, storage_list=None, fast_startup=False):
self.started = True
self._patch()
self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
for node_type in 'master', 'admin':
for node in getattr(self, node_type + '_list'):
node.start()
......@@ -750,13 +778,63 @@ class NEOCluster(object):
assert state in (ClusterStates.RUNNING, ClusterStates.BACKINGUP), state
self.enableStorageList(storage_list)
def newClient(self):
def stop(self, clear_database=False, __print_exc=traceback.print_exc):
if self.started:
del self.started
logging.debug("stopping %s", self)
client = self.__dict__.get("client")
client is None or self.__dict__.pop("db", client).close()
node_list = self.admin_list + self.storage_list + self.master_list
for node in node_list:
node.stop()
try:
node_list.append(client.poll_thread)
except AttributeError: # client is None or thread is already stopped
pass
self.join(node_list)
self.neoctl.close()
del self.neoctl
logging.debug("stopped %s", self)
self._unpatch()
if clear_database is None:
try:
for node_type in 'admin', 'storage', 'master':
for node in getattr(self, node_type + '_list'):
node.close()
except:
__print_exc()
raise
else:
for node_type in 'master', 'storage', 'admin':
kw = {}
if node_type == 'storage':
kw['clear_database'] = clear_database
for node in getattr(self, node_type + '_list'):
node.resetNode(**kw)
def _newClient(self):
return ClientApplication(name=self.name, master_nodes=self.master_nodes,
compress=self.compress, ssl=self.SSL)
@contextmanager
def newClient(self, with_db=False):
x = self._newClient()
try:
t = x.poll_thread
closed = []
if with_db:
x = ZODB.DB(storage=self.getZODBStorage(client=x))
else:
# XXX: Do nothing if finally if the caller already closed it.
x.close = lambda: closed.append(x.__class__.close(x))
yield x
finally:
closed or x.close()
self.join((t,))
@cached_property
def client(self):
client = self.newClient()
client = self._newClient()
# Make sure client won't be reused after it was closed.
def close():
client = self.client
......@@ -794,21 +872,6 @@ class NEOCluster(object):
Serialized.tic()
thread_list = [t for t in thread_list if t.is_alive()]
def stop(self):
logging.debug("stopping %s", self)
client = self.__dict__.get("client")
client is None or self.__dict__.pop("db", client).close()
node_list = self.admin_list + self.storage_list + self.master_list
for node in node_list:
node.stop()
try:
node_list.append(client.poll_thread)
except AttributeError: # client is None or thread is already stopped
pass
self.join(node_list)
logging.debug("stopped %s", self)
self._unpatch()
def getNodeState(self, node):
uuid = node.uuid
for node in self.neoctl.getNodeList(node.node_type):
......@@ -850,19 +913,9 @@ class NEOCluster(object):
for o in oid_list:
tid_dict[o] = i
def getTransaction(self):
def getTransaction(self, db=None):
txn = transaction.TransactionManager()
return txn, self.db.open(transaction_manager=txn)
def __del__(self, __print_exc=traceback.print_exc):
try:
self.neoctl.close()
for node_type in 'admin', 'storage', 'master':
for node in getattr(self, node_type + '_list'):
node.close()
except:
__print_exc()
raise
return txn, (self.db if db is None else db).open(txn)
def extraCellSortKey(self, key): # XXX unused?
return Patch(self.client.cp, getCellSortKey=lambda orig, cell:
......@@ -897,9 +950,15 @@ class NEOCluster(object):
class NEOThreadedTest(NeoTestBase):
__run_count = {}
def setupLog(self):
log_file = os.path.join(getTempDirectory(), self.id() + '.log')
logging.setup(log_file)
test_id = self.id()
i = self.__run_count.get(test_id, 0)
self.__run_count[test_id] = 1 + i
if i:
test_id += '-%s' % i
logging.setup(os.path.join(getTempDirectory(), test_id + '.log'))
return LoggerThreadName()
def _tearDown(self, success):
......@@ -912,20 +971,17 @@ class NEOThreadedTest(NeoTestBase):
tic = Serialized.tic
@contextmanager
def getLoopbackConnection(self):
app = MasterApplication(getSSL=NEOCluster.SSL,
getReplicas=0, getPartitions=1)
app = MasterApplication(address=BIND,
getSSL=NEOCluster.SSL, getReplicas=0, getPartitions=1)
try:
handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(),
uuid=app.uuid)
conn = ClientConnection.__new__(ClientConnection)
def reset():
conn.__dict__.clear()
conn.__init__(app, handler, node)
conn.reset = reset
reset()
return conn
yield ClientConnection(app, handler, app.nm.createMaster(
address=app.listening_conn.getAddress(), uuid=app.uuid))
finally:
app.close()
def getUnpickler(self, conn): # XXX not used?
reader = conn._reader
......@@ -963,6 +1019,9 @@ class NEOThreadedTest(NeoTestBase):
with Patch(client, _getFinalTID=lambda *_: None):
self.assertRaises(ConnectionClosed, txn.commit)
def assertPartitionTable(self, cluster, stats):
self.assertEqual(stats, '|'.join(cluster.admin.pt.formatRows()))
def predictable_random(seed=None):
# Because we have 2 running threads when client works, we can't
......@@ -984,3 +1043,13 @@ def predictable_random(seed=None):
= random
return wraps(wrapped)(wrapper)
return decorator
def with_cluster(start_cluster=True, **cluster_kw):
def decorator(wrapped):
def wrapper(self, *args, **kw):
with NEOCluster(**cluster_kw) as cluster:
if start_cluster:
cluster.start()
return wrapped(self, cluster, *args, **kw)
return wraps(wrapped)(wrapper)
return decorator
#
# Copyright (C) 2011-2016 Nexedi SA
# Copyright (C) 2011-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -17,6 +17,7 @@
import os
import sys
import threading
import time
import transaction
import unittest
from thread import get_ident
......@@ -32,7 +33,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID
from .. import expectedFailure, Patch
from . import LockLock, NEOCluster, NEOThreadedTest
from . import LockLock, NEOThreadedTest, with_cluster
from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
......@@ -53,10 +54,9 @@ class PCounterWithResolution(PCounter):
class Test(NEOThreadedTest):
def testBasicStore(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testBasicStore(self, cluster):
if 1:
storage = cluster.getZODBStorage()
data_info = {}
compressible = 'x' * 20
......@@ -106,13 +106,10 @@ class Test(NEOThreadedTest):
if big:
self.assertFalse(cluster.storage.sqlCount('bigdata'))
self.assertFalse(cluster.storage.sqlCount('data'))
finally:
cluster.stop()
def testDeleteObject(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testDeleteObject(self, cluster):
if 1:
storage = cluster.getZODBStorage()
for clear_cache in 0, 1:
for tst in 'a.', 'bcd.':
......@@ -131,13 +128,10 @@ class Test(NEOThreadedTest):
storage._cache.clear()
self.assertRaises(POSException.POSKeyError,
storage.load, oid, '')
finally:
cluster.stop()
def testCreationUndoneHistory(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testCreationUndoneHistory(self, cluster):
if 1:
storage = cluster.getZODBStorage()
oid = storage.new_oid()
txn = transaction.Transaction()
......@@ -155,18 +149,15 @@ class Test(NEOThreadedTest):
for x in storage.history(oid, 10):
self.assertEqual((x['tid'], x['size']), expected.pop())
self.assertFalse(expected)
finally:
cluster.stop()
def testUndoConflict(self, conflict_during_store=False):
@with_cluster()
def testUndoConflict(self, cluster, conflict_during_store=False):
def waitResponses(orig, *args):
orig(*args)
p.revert()
ob.value += 3
t.commit()
cluster = NEOCluster()
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()[0] = ob = PCounterWithResolution()
t.commit()
......@@ -186,17 +177,14 @@ class Test(NEOThreadedTest):
undo.tpc_finish(txn)
t.begin()
self.assertEqual(ob.value, 3)
finally:
cluster.stop()
@expectedFailure(POSException.ConflictError) # TODO recheck
def testUndoConflictDuringStore(self):
self.testUndoConflict(True)
def testStorageDataLock(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testStorageDataLock(self, cluster):
if 1:
storage = cluster.getZODBStorage()
data_info = {}
......@@ -211,15 +199,21 @@ class Test(NEOThreadedTest):
data_info[key] = 0
storage.sync()
txn = [transaction.Transaction() for x in xrange(3)]
txn = [transaction.Transaction() for x in xrange(4)]
for t in txn:
storage.tpc_begin(t)
storage.store(tid and oid or storage.new_oid(),
storage.store(oid if tid else storage.new_oid(),
tid, data, '', t)
tid = None
data_info[key] = 4
storage.sync()
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn.pop())
for t in txn:
storage.tpc_vote(t)
data_info[key] = 3
storage.sync()
data_info[key] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn[1])
......@@ -236,13 +230,10 @@ class Test(NEOThreadedTest):
storage.sync()
data_info[key] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
finally:
cluster.stop()
def testDelayedUnlockInformation(self):
@with_cluster(storage_count=1)
def testDelayedUnlockInformation(self, cluster):
except_list = []
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreObject(orig, tm, ttid, serial, oid, *args):
if oid == resume_oid and delayUnlockInformation in m2s:
m2s.remove(delayUnlockInformation)
......@@ -251,25 +242,22 @@ class Test(NEOThreadedTest):
except Exception, e:
except_list.append(e.__class__)
raise
cluster = NEOCluster(storage_count=1)
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()[0] = ob = PCounter()
with cluster.master.filterConnection(cluster.storage) as m2s:
resume_oid = None
m2s.add(delayUnlockInformation,
delayUnlockInformation = m2s.delayNotifyUnlockInformation(
Patch(TransactionManager, storeObject=onStoreObject))
t.commit()
resume_oid = ob._p_oid
ob._p_changed = 1
t.commit()
self.assertFalse(delayUnlockInformation in m2s)
finally:
cluster.stop()
self.assertNotIn(delayUnlockInformation, m2s)
self.assertEqual(except_list, [DelayedError])
def _testDeadlockAvoidance(self, scenario):
@with_cluster(storage_count=2, replicas=1)
def _testDeadlockAvoidance(self, cluster, scenario):
except_list = []
delay = threading.Event(), threading.Event()
ident = get_ident()
......@@ -301,9 +289,7 @@ class Test(NEOThreadedTest):
delay[c2].clear()
delay[1-c2].set()
cluster = NEOCluster(storage_count=2, replicas=1)
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()[0] = ob = PCounterWithResolution()
t.commit()
......@@ -326,8 +312,6 @@ class Test(NEOThreadedTest):
t2.begin()
self.assertEqual(o1.value, 3)
self.assertEqual(o2.value, 3)
finally:
cluster.stop()
return except_list
def testDelayedStore(self):
......@@ -349,11 +333,10 @@ class Test(NEOThreadedTest):
self.assertEqual(self._testDeadlockAvoidance([1, 3]),
[DelayedError, ConflictError, "???" ])
def testConflictResolutionTriggered2(self):
@with_cluster()
def testConflictResolutionTriggered2(self, cluster):
""" Check that conflict resolution works """
cluster = NEOCluster()
try:
cluster.start()
if 1:
# create the initial object
t, c = cluster.getTransaction()
c.root()['with_resolution'] = ob = PCounterWithResolution()
......@@ -384,7 +367,7 @@ class Test(NEOThreadedTest):
tid1 = o1._p_serial
resolved = []
last = (t2.get(), t3.get()).index
last = lambda txn: txn._extension['last'] # BBB
def _handleConflicts(orig, txn_context, *args):
resolved.append(last(txn_context['txn']))
return orig(txn_context, *args)
......@@ -395,7 +378,8 @@ class Test(NEOThreadedTest):
with LockLock() as l3, Patch(cluster.client, tpc_vote=tpc_vote):
with LockLock() as l2:
tt = []
for t, l in (t2, l2), (t3, l3):
for i, t, l in (0, t2, l2), (1, t3, l3):
t.get().setExtendedInfo('last', i)
tt.append(self.newThread(t.commit))
l()
tt.pop(0).join()
......@@ -421,15 +405,45 @@ class Test(NEOThreadedTest):
# check history
self.assertEqual([x['tid'] for x in c1.db().history(oid, size=10)],
[tid3, tid2, tid1, tid0])
finally:
cluster.stop()
def test_notifyNodeInformation(self):
@with_cluster()
def testDelayedLoad(self, cluster):
"""
Check that a storage node delays reads from the database,
when the requested data may still be in a temporary place.
"""
l = threading.Lock()
l.acquire()
idle = []
def askObject(orig, *args):
orig(*args)
idle.append(cluster.storage.em.isIdle())
l.release()
if 1:
t, c = cluster.getTransaction()
r = c.root()
r[''] = ''
with Patch(ClientOperationHandler, askObject=askObject):
with cluster.master.filterConnection(cluster.storage) as m2s:
m2s.delayNotifyUnlockInformation()
t.commit()
c.cacheMinimize()
cluster.client._cache.clear()
load = self.newThread(r._p_activate)
l.acquire()
l.acquire()
# The request from the client is processed again
# (upon reception on unlock notification from the master),
# once exactly, and now with success.
load.join()
self.assertEqual(idle, [1, 0])
self.assertIn('', r)
@with_cluster(replicas=1)
def test_notifyNodeInformation(self, cluster):
# translated from MasterNotificationsHandlerTests
# (neo.tests.client.testMasterHandler)
cluster = NEOCluster(replicas=1)
try:
cluster.start()
if 1:
cluster.db # open DB
s0, s1 = cluster.client.nm.getStorageList()
conn = s0.getConnection()
......@@ -444,27 +458,21 @@ class Test(NEOThreadedTest):
# was called (even if it's useless in this case),
# but we would need an API to do that easily.
self.assertFalse(cluster.client.dispatcher.registered(conn))
finally:
cluster.stop()
def testRestartWithMissingStorage(self):
@with_cluster(replicas=1, partitions=10)
def testRestartWithMissingStorage(self, cluster):
# translated from neo.tests.functional.testStorage.StorageTest
cluster = NEOCluster(replicas=1, partitions=10)
s1, s2 = cluster.storage_list
try:
cluster.start()
if 1:
self.assertEqual([], cluster.getOutdatedCells())
finally:
cluster.stop()
# restart it with one storage only
cluster.reset()
try:
if 1:
cluster.start(storage_list=(s1,))
self.assertEqual(NodeStates.UNKNOWN, cluster.getNodeState(s2))
finally:
cluster.stop()
def testRestartStoragesWithReplicas(self):
@with_cluster(storage_count=2, partitions=2, replicas=1)
def testRestartStoragesWithReplicas(self, cluster):
"""
Check that the master must discard its partition table when the
cluster is not operational anymore. Which means that it must go back
......@@ -480,8 +488,7 @@ class Test(NEOThreadedTest):
orig()
def stop():
with cluster.master.filterConnection(s0) as m2s0:
m2s0.add(lambda conn, packet:
isinstance(packet, Packets.NotifyPartitionChanges))
m2s0.delayNotifyPartitionChanges()
s1.stop()
cluster.join((s1,))
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
......@@ -492,9 +499,7 @@ class Test(NEOThreadedTest):
self.assertNotEqual(getClusterState(), ClusterStates.RUNNING)
s0.resetNode()
s1.resetNode()
cluster = NEOCluster(storage_count=2, partitions=2, replicas=1)
try:
cluster.start()
if 1:
s0, s1 = cluster.storage_list
getClusterState = cluster.neoctl.getClusterState
if 1:
......@@ -517,13 +522,10 @@ class Test(NEOThreadedTest):
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(cluster.getOutdatedCells(),
[(0, s0.uuid), (1, s0.uuid)])
finally:
cluster.stop()
def testVerificationCommitUnfinishedTransactions(self):
@with_cluster(partitions=2, storage_count=2)
def testVerificationCommitUnfinishedTransactions(self, cluster):
""" Verification step should commit locked transactions """
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onLockTransaction(storage, die=False):
def lock(orig, *args, **kw):
if die:
......@@ -531,9 +533,7 @@ class Test(NEOThreadedTest):
orig(*args, **kw)
storage.master_conn.close()
return Patch(storage.tm, lock=lock)
cluster = NEOCluster(partitions=2, storage_count=2)
try:
cluster.start()
if 1:
s0, s1 = cluster.sortStorageList()
t, c = cluster.getTransaction()
r = c.root()
......@@ -564,7 +564,7 @@ class Test(NEOThreadedTest):
self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3))
r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s:
m2s.add(delayUnlockInformation)
m2s.delayNotifyUnlockInformation()
t.commit()
x.value = 1
# s0 will accept to store y (because it's not locked) but will
......@@ -574,9 +574,7 @@ class Test(NEOThreadedTest):
di0 = s0.getDataLockInfo()
with onLockTransaction(s1, die=True):
self.commitWithStorageFailure(cluster.client, t)
finally:
cluster.stop()
cluster.reset()
(k, v), = set(s0.getDataLockInfo().iteritems()
).difference(di0.iteritems())
self.assertEqual(v, 1)
......@@ -587,7 +585,7 @@ class Test(NEOThreadedTest):
k, = (k for k, v in di1.iteritems() if v == 1)
del di1[k] # x.value = 1
self.assertEqual(di1.values(), [0])
try:
if 1:
cluster.start()
t, c = cluster.getTransaction()
r = c.root()
......@@ -596,19 +594,16 @@ class Test(NEOThreadedTest):
self.assertEqual(r[2], 'ok')
self.assertEqual(di0, s0.getDataLockInfo())
self.assertEqual(di1, s1.getDataLockInfo())
finally:
cluster.stop()
def testVerificationWithNodesWithoutReadableCells(self):
@with_cluster(replicas=1)
def testVerificationWithNodesWithoutReadableCells(self, cluster):
def onLockTransaction(storage, die_after):
def lock(orig, *args, **kw):
if die_after:
orig(*args, **kw)
sys.exit()
return Patch(storage.tm, lock=lock)
cluster = NEOCluster(replicas=1)
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()[0] = None
s0, s1 = cluster.storage_list
......@@ -634,10 +629,9 @@ class Test(NEOThreadedTest):
self.assertEqual(sorted(c.root()), [1])
self.tic()
t0, t1 = c.db().storage.iterator()
finally:
cluster.stop()
def testDropUnfinishedData(self):
@with_cluster(partitions=2, storage_count=2, replicas=1)
def testDropUnfinishedData(self, cluster):
def lock(orig, *args, **kw):
orig(*args, **kw)
storage.master_conn.close()
......@@ -646,9 +640,7 @@ class Test(NEOThreadedTest):
r.append(len(orig.__self__.getUnfinishedTIDDict()))
orig()
r.append(len(orig.__self__.getUnfinishedTIDDict()))
cluster = NEOCluster(partitions=2, storage_count=2, replicas=1)
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()._p_changed = 1
storage = cluster.storage_list[0]
......@@ -657,13 +649,10 @@ class Test(NEOThreadedTest):
t.commit()
self.tic()
self.assertEqual(r, [1, 0])
finally:
cluster.stop()
def testStorageUpgrade1(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testStorageUpgrade1(self, cluster):
if 1:
storage = cluster.storage
t, c = cluster.getTransaction()
storage.dm.setConfiguration("version", None)
......@@ -680,42 +669,32 @@ class Test(NEOThreadedTest):
with Patch(storage.tm, lock=lambda *_: sys.exit()):
self.commitWithStorageFailure(cluster.client, t)
self.assertRaises(DatabaseFailure, storage.resetNode)
finally:
cluster.stop()
def testStorageReconnectDuringStore(self):
cluster = NEOCluster(replicas=1)
try:
cluster.start()
@with_cluster(replicas=1)
def testStorageReconnectDuringStore(self, cluster):
if 1:
t, c = cluster.getTransaction()
c.root()[0] = 'ok'
cluster.client.cp.closeAll()
t.commit() # store request
finally:
cluster.stop()
def testStorageReconnectDuringTransactionLog(self):
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
@with_cluster(storage_count=2, partitions=2)
def testStorageReconnectDuringTransactionLog(self, cluster):
if 1:
t, c = cluster.getTransaction()
cluster.client.cp.closeAll()
tid, (t1,) = cluster.client.transactionLog(
ZERO_TID, c.db().lastTransaction(), 10)
finally:
cluster.stop()
def testStorageReconnectDuringUndoLog(self):
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
@with_cluster(storage_count=2, partitions=2)
def testStorageReconnectDuringUndoLog(self, cluster):
if 1:
t, c = cluster.getTransaction()
cluster.client.cp.closeAll()
t1, = cluster.client.undoLog(0, 10)
finally:
cluster.stop()
def testDropNodeThenRestartCluster(self):
@with_cluster(storage_count=2, replicas=1)
def testDropNodeThenRestartCluster(self, cluster):
""" Start a cluster with more than one storage, down one, shutdown the
cluster then restart it. The partition table recovered must not include
the dropped node """
......@@ -724,10 +703,8 @@ class Test(NEOThreadedTest):
self.assertEqual(cluster.getNodeState(s2), NodeStates.RUNNING)
# start with two storage / one replica
cluster = NEOCluster(storage_count=2, replicas=1)
s1, s2 = cluster.storage_list
try:
cluster.start()
if 1:
checkNodeState(NodeStates.RUNNING)
self.assertEqual([], cluster.getOutdatedCells())
# drop one
......@@ -737,39 +714,29 @@ class Test(NEOThreadedTest):
checkNodeState(None)
self.assertEqual([], cluster.getOutdatedCells())
# restart with s2 only
finally:
cluster.stop()
cluster.reset()
try:
if 1:
cluster.start(storage_list=[s2])
checkNodeState(None)
# then restart it, it must be in pending state
s1.start()
self.tic()
checkNodeState(NodeStates.PENDING)
finally:
cluster.stop()
def test2Clusters(self): # NOTE
cluster1 = NEOCluster()
cluster2 = NEOCluster()
try:
cluster1.start()
cluster2.start()
@with_cluster()
@with_cluster()
def test2Clusters(self, cluster1, cluster2): # NOTE
if 1:
t1, c1 = cluster1.getTransaction()
t2, c2 = cluster2.getTransaction()
c1.root()['1'] = c2.root()['2'] = ''
t1.commit()
t2.commit()
finally:
cluster1.stop()
cluster2.stop()
def testAbortStorage(self):
cluster = NEOCluster(partitions=2, storage_count=2)
@with_cluster(partitions=2, storage_count=2)
def testAbortStorage(self, cluster):
storage = cluster.storage_list[0]
try:
cluster.start()
if 1:
# prevent storage to reconnect, in order to easily test
# that cluster becomes non-operational
with Patch(storage, connectToPrimary=sys.exit):
......@@ -783,17 +750,13 @@ class Test(NEOThreadedTest):
self.tic()
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally:
cluster.stop()
def testShutdown(self):
# BUG: Due to bugs in election, master nodes sometimes crash, or they # <- NOTE
@with_cluster(master_count=3, partitions=10, replicas=1, storage_count=3)
def testShutdown(self, cluster):
# NOTE vvv
# declare themselves primary too quickly. The consequence is
# often an endless tic loop.
cluster = NEOCluster(master_count=3, partitions=10,
replicas=1, storage_count=3)
try:
cluster.start()
if 1:
# fill DB a little
t, c = cluster.getTransaction()
c.root()[''] = ''
......@@ -804,9 +767,7 @@ class Test(NEOThreadedTest):
cluster.join(cluster.master_list
+ cluster.storage_list
+ cluster.admin_list)
finally:
cluster.stop()
cluster.reset() # reopen DB to check partition tables
cluster.stop() # stop and reopen DB to check partition tables
dm = cluster.storage_list[0].dm
self.assertEqual(1, dm.getPTID())
pt = list(dm.getPartitionTable())
......@@ -817,14 +778,13 @@ class Test(NEOThreadedTest):
self.assertEqual(s.dm.getPTID(), 1)
self.assertEqual(list(s.dm.getPartitionTable()), pt)
def testInternalInvalidation(self):
@with_cluster()
def testInternalInvalidation(self, cluster):
def _handlePacket(orig, conn, packet, kw={}, handler=None):
if type(packet) is Packets.AnswerTransactionFinished:
ll()
orig(conn, packet, kw, handler)
cluster = NEOCluster()
try:
cluster.start()
if 1:
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter()
t1.commit()
......@@ -839,14 +799,9 @@ class Test(NEOThreadedTest):
t2.begin()
t.join()
self.assertEqual(x2.value, 1)
finally:
cluster.stop()
def testExternalInvalidation(self):
cluster = NEOCluster()
try:
cluster.start()
cache = cluster.client._cache
@with_cluster()
def testExternalInvalidation(self, cluster):
# Initialize objects
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter()
......@@ -861,14 +816,14 @@ class Test(NEOThreadedTest):
# (at this time, we still have x=0 and y=1)
t2, c2 = cluster.getTransaction()
# Copy y to x using a different Master-Client connection
client = cluster.newClient()
with cluster.newClient() as client:
cache = cluster.client._cache
txn = transaction.Transaction()
client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x
with cluster.master.filterConnection(cluster.client) as m2c:
m2c.add(lambda conn, packet:
isinstance(packet, Packets.InvalidateObjects))
m2c.delayInvalidateObjects()
tid = client.tpc_finish(txn, None)
# Change to x is committed. Testing connection must ask the
# storage node to return original value of x, even if we
......@@ -940,22 +895,15 @@ class Test(NEOThreadedTest):
self.assertFalse(invalidations(c1))
self.assertEqual(x1.value, 1)
finally:
cluster.stop()
def testReadVerifyingStorage(self):
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
@with_cluster(storage_count=2, partitions=2)
def testReadVerifyingStorage(self, cluster):
if 1:
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounter()
t1.commit()
# We need a second client for external invalidations.
t2 = transaction.TransactionManager()
db = DB(storage=cluster.getZODBStorage(client=cluster.newClient()))
try:
c2 = db.open(t2)
t2.begin()
with cluster.newClient(1) as db:
t2, c2 = cluster.getTransaction(db)
r = c2.root()
r['y'] = None
r['x']._p_activate()
......@@ -968,8 +916,6 @@ class Test(NEOThreadedTest):
t2.commit()
for storage in cluster.storage_list:
self.assertFalse(storage.tm._transaction_dict)
finally:
db.close()
# Check we didn't get an invalidation, which would cause an
# assertion failure in the cache. Connection does the same check in
# _setstate_noncurrent so this could be also done by starting a
......@@ -981,13 +927,10 @@ class Test(NEOThreadedTest):
self.assertEqual(map(u64, t1.oid_list), [0, 1])
# Check oid 1 is part of transaction metadata.
self.assertEqual(t2.oid_list, t1.oid_list)
finally:
cluster.stop()
def testClientReconnection(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testClientReconnection(self, cluster):
if 1:
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter()
c1.root()['y'] = y = PCounter()
......@@ -1004,14 +947,11 @@ class Test(NEOThreadedTest):
#self.tic() # NOTE works ok with tic() commented
# modify x with another client
client = cluster.newClient()
try:
with cluster.newClient() as client:
txn = transaction.Transaction()
client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn)
tid = client.tpc_finish(txn, None)
finally:
client.close()
#self.tic() # NOTE ----//----
# Check reconnection to the master and storage.
......@@ -1020,13 +960,10 @@ class Test(NEOThreadedTest):
t1.begin()
self.assertEqual(x1._p_changed, None)
self.assertEqual(x1.value, 1)
finally:
cluster.stop()
def testInvalidTTID(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testInvalidTTID(self, cluster):
if 1:
client = cluster.client
txn = transaction.Transaction()
client.tpc_begin(txn)
......@@ -1034,16 +971,13 @@ class Test(NEOThreadedTest):
txn_context['ttid'] = add64(txn_context['ttid'], 1)
self.assertRaises(POSException.StorageError,
client.tpc_finish, txn, None)
finally:
cluster.stop()
def testStorageFailureDuringTpcFinish(self):
@with_cluster()
def testStorageFailureDuringTpcFinish(self, cluster):
def answerTransactionFinished(conn, packet):
if isinstance(packet, Packets.AnswerTransactionFinished):
raise StoppedOperation
cluster = NEOCluster()
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()['x'] = PCounter()
with cluster.master.filterConnection(cluster.client) as m2c:
......@@ -1057,10 +991,9 @@ class Test(NEOThreadedTest):
self.assertEqual(1, u64(c.root()['x']._p_oid))
self.assertFalse(cluster.client.new_oid_list)
self.assertEqual(2, u64(cluster.client.new_oid()))
finally:
cluster.stop()
def testClientFailureDuringTpcFinish(self):
@with_cluster()
def testClientFailureDuringTpcFinish(self, cluster):
"""
Third scenario:
......@@ -1095,9 +1028,7 @@ class Test(NEOThreadedTest):
self.tic()
s2m.remove(delayAnswerLockInformation)
return conn
cluster = NEOCluster()
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
r = c.root()
r['x'] = PCounter()
......@@ -1126,38 +1057,31 @@ class Test(NEOThreadedTest):
cluster.master.filterConnection(cluster.storage) as m2s:
s2m.add(delayAnswerLockInformation, Patch(cluster.client,
_connectToPrimaryNode=_connectToPrimaryNode))
m2s.add(lambda conn, packet:
isinstance(packet, Packets.NotifyUnlockInformation))
m2s.delayNotifyUnlockInformation()
t.commit() # the final TID is returned by the storage (tm)
t.begin()
self.assertEqual(r['x'].value, 2)
self.assertTrue(tid2 < r['x']._p_serial)
finally:
cluster.stop()
def testMasterFailureBeforeVote(self):
@with_cluster(storage_count=2, partitions=2)
def testMasterFailureBeforeVote(self, cluster):
def waitStoreResponses(orig, *args):
result = orig(*args)
m2c, = cluster.master.getConnectionList(orig.__self__)
m2c.close()
self.tic()
return result
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()['x'] = PCounter() # 1 store() to each storage
with Patch(cluster.client, waitStoreResponses=waitStoreResponses):
self.assertRaises(POSException.StorageError, t.commit)
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally:
cluster.stop()
def testEmptyTransaction(self):
cluster = NEOCluster()
try:
cluster.start()
@with_cluster()
def testEmptyTransaction(self, cluster):
if 1:
txn = transaction.Transaction()
storage = cluster.getZODBStorage()
storage.tpc_begin(txn)
......@@ -1166,58 +1090,40 @@ class Test(NEOThreadedTest):
t, = storage.iterator()
self.assertEqual(t.tid, serial)
self.assertFalse(t.oid_list)
finally:
cluster.stop()
def testRecycledClientUUID(self):
def delayNotifyInformation(conn, packet):
return isinstance(packet, Packets.NotifyNodeInformation)
@with_cluster()
def testRecycledClientUUID(self, cluster):
def notReady(orig, *args):
m2s.discard(delayNotifyInformation)
return orig(*args)
cluster = NEOCluster()
try:
cluster.start()
if 1:
cluster.getTransaction()
with cluster.master.filterConnection(cluster.storage) as m2s:
m2s.add(delayNotifyInformation)
delayNotifyInformation = m2s.delayNotifyNodeInformation()
cluster.client.master_conn.close()
client = cluster.newClient()
p = Patch(client.storage_bootstrap_handler, notReady=notReady)
try:
p.apply()
with cluster.newClient() as client, Patch(
client.storage_bootstrap_handler, notReady=notReady):
x = client.load(ZERO_TID)
finally:
del p
client.close()
self.assertNotIn(delayNotifyInformation, m2s)
finally:
cluster.stop()
def testAutostart(self):
def startCluster():
@with_cluster(start_cluster=0, storage_count=3, autostart=3)
def testAutostart(self, cluster):
def startCluster(orig):
getClusterState = cluster.neoctl.getClusterState
self.assertEqual(ClusterStates.RECOVERING, getClusterState())
cluster.storage_list[2].start()
self.tic()
with Patch(cluster, startCluster=startCluster):
self.assertEqual(ClusterStates.RUNNING, getClusterState())
cluster = NEOCluster(storage_count=3, autostart=3)
try:
cluster.startCluster = startCluster
cluster.start(cluster.storage_list[:2])
finally:
cluster.stop()
del cluster.startCluster
def testAbortVotedTransaction(self):
@with_cluster(storage_count=2, partitions=2)
def testAbortVotedTransaction(self, cluster):
r = []
def tpc_finish(*args, **kw):
for storage in cluster.storage_list:
r.append(len(storage.dm.getUnfinishedTIDDict()))
raise NEOStorageError
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
c.root()['x'] = PCounter()
with Patch(cluster.client, tpc_finish=tpc_finish):
......@@ -1228,17 +1134,11 @@ class Test(NEOThreadedTest):
self.assertFalse(storage.dm.getUnfinishedTIDDict())
t.begin()
self.assertNotIn('x', c.root())
finally:
cluster.stop()
def testStorageLostDuringRecovery(self):
@with_cluster(storage_count=2, partitions=2)
def testStorageLostDuringRecovery(self, cluster):
# Initialize a cluster.
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
finally:
cluster.stop()
cluster.reset()
# Restart with a connection failure for the first AskPartitionTable.
# The master must not be stuck in RECOVERING state
# or re-make the partition table.
......@@ -1247,16 +1147,15 @@ class Test(NEOThreadedTest):
def askPartitionTable(orig, self, conn):
p.revert()
conn.close()
try:
if 1:
with Patch(cluster.master.pt, make=make), \
Patch(InitializationHandler,
askPartitionTable=askPartitionTable) as p:
cluster.start()
self.assertFalse(p.applied)
finally:
cluster.stop()
def testTruncate(self):
@with_cluster(replicas=1)
def testTruncate(self, cluster):
calls = [0, 0]
def dieFirst(i):
def f(orig, *args, **kw):
......@@ -1265,9 +1164,7 @@ class Test(NEOThreadedTest):
sys.exit()
return orig(*args, **kw)
return f
cluster = NEOCluster(replicas=1)
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
r = c.root()
tids = []
......@@ -1311,30 +1208,27 @@ class Test(NEOThreadedTest):
self.assertEqual(1, u64(c._storage.new_oid()))
for s in cluster.storage_list:
self.assertEqual(s.dm.getLastIDs()[0], truncate_tid)
finally:
cluster.stop()
def testConnectionTimeout(self):
conn = self.getLoopbackConnection()
with self.getLoopbackConnection() as conn:
conn.KEEP_ALIVE
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
def onTimeout(orig):
conn.idle()
orig()
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
with Patch(conn, onTimeout=onTimeout):
conn.em.poll(1)
self.assertFalse(conn.isClosed())
def testClientDisconnectedFromMaster(self):
@with_cluster()
def testClientDisconnectedFromMaster(self, cluster):
def disconnect(conn, packet):
if isinstance(packet, Packets.AskObject):
m2c.close()
#return True
cluster = NEOCluster()
try:
cluster.start()
if 1:
t, c = cluster.getTransaction()
m2c, = cluster.master.getConnectionList(cluster.client)
cluster.client._cache.clear()
......@@ -1350,8 +1244,7 @@ class Test(NEOThreadedTest):
self.assertRaises(TransientError, getattr, c, "root")
uuid = cluster.client.uuid
# Let's use a second client to steal the node id of the first one.
client = cluster.newClient()
try:
with cluster.newClient() as client:
client.sync()
self.assertEqual(uuid, client.uuid)
# The client reconnects successfully to the master and storage,
......@@ -1363,12 +1256,9 @@ class Test(NEOThreadedTest):
self.assertNotEqual(uuid, cluster.client.uuid)
# Second reconnection, for a successful load.
c.root
finally:
client.close()
finally:
cluster.stop()
def testIdTimestamp(self):
@with_cluster()
def testIdTimestamp(self, cluster):
"""
Given a master M, a storage S, and 2 clients Ca and Cb.
......@@ -1394,9 +1284,7 @@ class Test(NEOThreadedTest):
ll()
def connectToStorage(client):
next(client.cp.iterateForObject(0))
cluster = NEOCluster()
try:
cluster.start()
if 1:
Ca = cluster.client
Ca.pt # only connect to the master
# In a separate thread, connect to the storage but suspend the
......@@ -1408,18 +1296,72 @@ class Test(NEOThreadedTest):
s2c, = s2c
m2c, = cluster.master.getConnectionList(cluster.client)
m2c.close()
Cb = cluster.newClient()
try:
with cluster.newClient() as Cb:
Cb.pt # only connect to the master
del s2c.readable
self.assertRaises(NEOPrimaryMasterLost, t.join)
self.assertTrue(s2c.isClosed())
connectToStorage(Cb)
finally:
Cb.close()
finally:
cluster.stop()
@with_cluster(storage_count=2, partitions=2)
def testPruneOrphan(self, cluster):
if 1:
cluster.importZODB()(3)
bad = []
ok = []
def data_args(value):
return makeChecksum(value), value, 0
node_list = []
for i, s in enumerate(cluster.storage_list):
node_list.append(s.uuid)
if i:
s.dm.holdData(*data_args('boo'))
ok.append(s.getDataLockInfo())
for i in xrange(3 - i):
s.dm.storeData(*data_args('!' * i))
bad.append(s.getDataLockInfo())
s.dm.commit()
def check(dry_run, expected):
cluster.neoctl.repair(node_list, dry_run)
for e, s in zip(expected, cluster.storage_list):
while 1:
self.tic()
if s.dm._repairing is None:
break
time.sleep(.1)
self.assertEqual(e, s.getDataLockInfo())
check(1, bad)
check(0, ok)
check(1, ok)
@with_cluster(replicas=1)
def testLateConflictOnReplica(self, cluster):
"""
Already resolved conflict: check the case of a storage node that
reports a conflict after that this conflict was fully resolved with
another node.
"""
def answerStoreObject(orig, conn, conflicting, *args):
if not conflicting:
p.revert()
ll()
orig(conn, conflicting, *args)
if 1:
s0, s1 = cluster.storage_list
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounterWithResolution()
t1.commit()
x.value += 1
t2, c2 = cluster.getTransaction()
c2.root()['x'].value += 2
t2.commit()
with LockLock() as ll, s1.filterConnection(cluster.client) as f, \
Patch(cluster.client.storage_handler,
answerStoreObject=answerStoreObject) as p:
f.delayAnswerStoreObject()
t = self.newThread(t1.commit)
ll()
t.join()
if __name__ == "__main__":
unittest.main()
#
# Copyright (C) 2014-2016 Nexedi SA
# Copyright (C) 2014-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,11 +14,10 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import deque
from cPickle import Pickler, Unpickler
from cStringIO import StringIO
from itertools import islice, izip_longest
import os, unittest
import os, shutil, unittest
import neo, transaction, ZODB
from neo.lib import logging
from neo.lib.util import u64
......@@ -129,18 +128,28 @@ class ImporterTests(NEOThreadedTest):
self.assertDictEqual(state, load())
def test(self):
# XXX: Using NEO source files as test data was a bad idea because
# the test breaks easily in case of massive changes in the code,
# or if there are many untracked files.
importer = []
fs_dir = os.path.join(getTempDirectory(), self.id())
shutil.rmtree(fs_dir, 1) # for --loop
os.mkdir(fs_dir)
src_root, = neo.__path__
fs_list = "root", "client", "master", "tests"
def not_pyc(name):
return not name.endswith(".pyc")
# We use 'hash' to skip roughly half of files.
# They'll be added after the migration has started.
def root_filter(name):
if not name.endswith(".pyc"):
if not_pyc(name):
i = name.find(os.sep)
return i < 0 or name[:i] not in fs_list
return (i < 0 or name[:i] not in fs_list) and (
'.' not in name or hash(name) & 1)
def sub_filter(name):
return lambda n: n[-4:] != '.pyc' and \
n.split(os.sep, 1)[0] in (name, "scripts")
return lambda n: not_pyc(n) and (
hash(n) & 1 if '.' in n else
os.sep in n or n in (name, "scripts"))
conn_list = []
iter_list = []
# Setup several FileStorage databases.
......@@ -172,8 +181,7 @@ class ImporterTests(NEOThreadedTest):
c.db().close()
#del importer[0][1][importer.pop()[0]]
# Start NEO cluster with transparent import of a multi-base ZODB.
cluster = NEOCluster(compress=False, importer=importer)
try:
with NEOCluster(compress=False, importer=importer) as cluster:
# Suspend import for a while, so that import
# is finished in the middle of the below 'for' loop.
# Use a slightly different main loop for storage so that it
......@@ -193,7 +201,7 @@ class ImporterTests(NEOThreadedTest):
cluster.start()
t, c = cluster.getTransaction()
r = c.root()["neo"]
# Test retrieving of an object from ZODB when next serial in NEO.
# Test retrieving of an object from ZODB when next serial is in NEO.
r._p_changed = 1
t.commit()
t.begin()
......@@ -204,24 +212,25 @@ class ImporterTests(NEOThreadedTest):
self.assertRaisesRegexp(NotImplementedError, " getObjectHistory$",
c.db().history, r._p_oid)
i = r.walk()
next(islice(i, 9, None))
next(islice(i, 4, None))
logging.info("start migration")
dm.doOperation(cluster.storage)
deque(i, maxlen=0)
last_import = None
for i, r in enumerate(r.treeFromFs(src_root, 10)):
# Adjust if needed. Must remain > 0.
assert 14 == sum(1 for i in i)
last_import = -1
for i, r in enumerate(r.treeFromFs(src_root, 6, not_pyc)):
t.commit()
if cluster.storage.dm._import:
last_import = i
self.tic()
self.assertTrue(last_import and not cluster.storage.dm._import)
# Same as above. We want last_import smaller enough compared to i
assert i / 3 < last_import < i - 2, (last_import, i)
self.assertFalse(cluster.storage.dm._import)
i = len(src_root) + 1
self.assertEqual(sorted(r.walk()), sorted(
(x[i:] or '.', sorted(y), sorted(z))
(x[i:] or '.', sorted(y), sorted(filter(not_pyc, z)))
for x, y, z in os.walk(src_root)))
t.commit()
finally:
cluster.stop()
if __name__ == "__main__":
......
#
# Copyright (C) 2012-2016 Nexedi SA
# Copyright (C) 2012-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -36,7 +36,8 @@ from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64
from .. import expectedFailure, Patch
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, predictable_random
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, \
predictable_random, with_cluster
# dump log to stderr
"""
......@@ -48,19 +49,14 @@ getLogger().setLevel(INFO)
def backup_test(partitions=1, upstream_kw={}, backup_kw={}):
def decorator(wrapped):
def wrapper(self):
upstream = NEOCluster(partitions, **upstream_kw)
try:
with NEOCluster(partitions, **upstream_kw) as upstream:
upstream.start()
backup = NEOCluster(partitions, upstream=upstream, **backup_kw)
try:
with NEOCluster(partitions, upstream=upstream,
**backup_kw) as backup:
backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic()
wrapped(self, backup)
finally:
backup.stop()
finally:
upstream.stop()
return wraps(wrapped)(wrapper)
return decorator
......@@ -99,15 +95,18 @@ class ReplicationTests(NEOThreadedTest):
np = 7
nr = 2
check_dict = dict.fromkeys(xrange(np))
upstream = NEOCluster(partitions=np, replicas=nr-1, storage_count=3)
try:
with NEOCluster(partitions=np, replicas=nr-1, storage_count=3
) as upstream:
upstream.start()
importZODB = upstream.importZODB()
importZODB(3)
backup = NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream)
def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode()
return not upstream_name and all(source_dict.itervalues())
# U -> B propagation
try:
with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream) as backup:
backup.start()
# Initialize & catch up.
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
......@@ -122,10 +121,7 @@ class ReplicationTests(NEOThreadedTest):
# Check that a backup cluster can be restarted.
# (U -> B propagation after restart)
finally:
backup.stop()
backup.reset()
try:
backup.start()
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.BACKINGUP)
......@@ -144,7 +140,7 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(np*nr, self.checkBackup(backup))
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally:
backup.stop()
# U -> B propagation with Mb -> Sb' (secondary, Replicate from primary Sb) delayed
......@@ -153,16 +149,10 @@ class ReplicationTests(NEOThreadedTest):
from neo.master import handlers as mhandler
#dbmanager.X = 1
#mhandler.X = 1
def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode()
#print 'REPLICATE tid: %r, upstream_name: %r, source_dict: %r' % \
# (tid, upstream_name, source_dict)
#return True
#return not upstream_name and all(source_dict.itervalues())
return upstream_name != ""
backup.reset()
try:
backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic()
......@@ -189,16 +179,13 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.cluster_state, ClusterStates.RUNNING)
self.assertEqual(np*nr, self.checkBackup(backup,
max_tid=backup.last_tid))
self.assertEqual(backup.last_tid, u_last_tid0) # truncated after recovery
self.assertEqual(np*nr, self.checkBackup(backup, max_tid=backup.last_tid))
finally:
backup.stop()
dbmanager.X = 0
mhandler.X = 0
# S -> Sb (AddObject) delayed XXX not only S -> Sb: also Sb -> Sb'
backup.reset()
try:
backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic()
......@@ -206,8 +193,7 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.backup_tid, upstream.last_tid) # B caught-up with U
u_last_tid1 = upstream.last_tid
with ConnectionFilter() as f:
f.add(lambda conn, packet: conn.getUUID() is None and
isinstance(packet, Packets.AddObject))
f.delayAddObject(lambda conn: conn.getUUID() is None)
while not f.filtered_count:
importZODB(1)
self.tic()
......@@ -222,10 +208,6 @@ class ReplicationTests(NEOThreadedTest):
max_tid=backup.last_tid))
self.assertEqual(backup.last_tid, u_last_tid1) # truncated after recovery
self.assertEqual(np*nr, self.checkBackup(backup, max_tid=backup.last_tid))
finally:
backup.stop()
finally:
upstream.stop()
@predictable_random()
def testBackupNodeLost(self):
......@@ -251,16 +233,14 @@ class ReplicationTests(NEOThreadedTest):
node_list.remove(txn.getNode())
node_list[0].getConnection().close() # disconnect Mb from M
return orig(txn)
upstream = NEOCluster(partitions=np, replicas=0, storage_count=1)
try:
with NEOCluster(partitions=np, replicas=0, storage_count=1) as upstream:
upstream.start()
importZODB = upstream.importZODB(random=random)
# Do not start with an empty DB so that 'primary_dict' below is not
# empty on the first iteration.
importZODB(1)
backup = NEOCluster(partitions=np, replicas=2, storage_count=4,
upstream=upstream)
try:
with NEOCluster(partitions=np, replicas=2, storage_count=4,
upstream=upstream) as backup:
backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic()
......@@ -293,10 +273,6 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.backup_tid, upstream.last_tid)
self.assertEqual(backup.last_tid, upstream.last_tid)
self.assertEqual(np*3, self.checkBackup(backup))
finally:
backup.stop()
finally:
upstream.stop()
@backup_test()
def testBackupUpstreamMasterDead(self, backup):
......@@ -331,8 +307,7 @@ class ReplicationTests(NEOThreadedTest):
def testBackupUpstreamStorageDead(self, backup):
upstream = backup.upstream
with ConnectionFilter() as f:
f.add(lambda conn, packet:
isinstance(packet, Packets.InvalidateObjects)) # delay M -> Mb
f.delayInvalidateObjects() # delay M -> Mb
upstream.importZODB()(1)
count = [0]
def _connect(orig, conn):
......@@ -361,38 +336,29 @@ class ReplicationTests(NEOThreadedTest):
"""
upstream = backup.upstream
with upstream.master.filterConnection(upstream.storage) as f:
f.add(lambda conn, packet:
isinstance(packet, Packets.NotifyUnlockInformation))
f.delayNotifyUnlockInformation()
upstream.importZODB()(1)
self.tic()
self.tic()
# TODO check tids
self.assertEqual(1, self.checkBackup(backup))
def testBackupEarlyInvalidation(self):
@with_cluster()
def testBackupEarlyInvalidation(self, upstream):
"""
The backup master must ignore notification before being fully
The backup master must ignore notifications before being fully
initialized.
"""
upstream = NEOCluster()
try:
upstream.start()
backup = NEOCluster(upstream=upstream)
try:
with NEOCluster(upstream=upstream) as backup:
backup.start()
with ConnectionFilter() as f:
f.add(lambda conn, packet:
isinstance(packet, Packets.AskPartitionTable) and
f.delayAskPartitionTable(lambda conn:
isinstance(conn.getHandler(), BackupHandler))
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
upstream.importZODB()(1)
self.tic()
self.tic()
self.assertTrue(backup.master.isAlive())
finally:
backup.stop()
finally:
upstream.stop()
self.assertTrue(backup.master.is_alive())
@backup_test()
def testBackupTid(self, backup):
......@@ -408,16 +374,15 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(last_tid, backup.backup_tid)
backup.stop()
importZODB(1)
backup.reset()
with ConnectionFilter() as f:
f.add(lambda conn, packet:
isinstance(packet, Packets.AskFetchTransactions))
f.delayAskFetchTransactions()
backup.start()
self.assertEqual(last_tid, backup.backup_tid)
self.tic()
self.assertEqual(1, self.checkBackup(backup))
def testSafeTweak(self):
@with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3)
def testSafeTweak(self, cluster):
"""
Check that tweak always tries to keep a minimum of (replicas + 1)
readable cells, otherwise we have less/no redundancy as long as
......@@ -426,9 +391,8 @@ class ReplicationTests(NEOThreadedTest):
def changePartitionTable(orig, *args):
orig(*args)
sys.exit()
cluster = NEOCluster(partitions=3, replicas=1, storage_count=3)
s0, s1, s2 = cluster.storage_list
try:
if 1:
cluster.start([s0, s1])
s2.start()
self.tic()
......@@ -441,10 +405,9 @@ class ReplicationTests(NEOThreadedTest):
self.tic()
expectedFailure(self.assertEqual)(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally:
cluster.stop()
def testReplicationAbortedBySource(self):
@with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3)
def testReplicationAbortedBySource(self, cluster):
"""
Check that a feeding node aborts replication when its partition is
dropped, and that the out-of-date node finishes to replicate from
......@@ -469,11 +432,12 @@ class ReplicationTests(NEOThreadedTest):
# default for performance reason
orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list)
np = 3
cluster = NEOCluster(partitions=np, replicas=1, storage_count=3)
np = cluster.num_partitions
s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
try:
if cluster.started:
cluster.stop(1)
if 1:
cluster.start([s0])
cluster.populate([range(np*2)] * np)
s1.start()
......@@ -491,20 +455,17 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(1, connection_filter.filtered_count)
self.tic()
self.checkPartitionReplicated(s1, s2, offset)
finally:
cluster.stop()
cluster.reset(True)
def testClientReadingDuringTweak(self):
@with_cluster(start_cluster=0, partitions=2, storage_count=2)
def testClientReadingDuringTweak(self, cluster):
# XXX: Currently, the test passes because data of dropped cells are not
# deleted while the cluster is operational: this is only done
# during the RECOVERING phase. But we'll want to be able to free
# disk space without service interruption, and for this the client
# may have to retry reading data from the new cells. If s0 deleted
# all data for partition 1, the test would fail with a POSKeyError.
cluster = NEOCluster(partitions=2, storage_count=2)
s0, s1 = cluster.storage_list
try:
if 1:
cluster.start([s0])
storage = cluster.getZODBStorage()
oid = p64(1)
......@@ -518,16 +479,13 @@ class ReplicationTests(NEOThreadedTest):
cluster.neoctl.enableStorageList([s1.uuid])
cluster.neoctl.tweakPartitionTable()
with cluster.master.filterConnection(cluster.client) as m2c:
m2c.add(lambda conn, packet:
isinstance(packet, Packets.NotifyPartitionChanges))
m2c.delayNotifyPartitionChanges()
self.tic()
self.assertEqual('foo', storage.load(oid)[0])
finally:
cluster.stop()
def testResumingReplication(self):
cluster = NEOCluster(replicas=1)
try:
@with_cluster(start_cluster=0, replicas=1)
def testResumingReplication(self, cluster):
if 1:
s0, s1 = cluster.storage_list
cluster.start(storage_list=(s0,))
t, c = cluster.getTransaction()
......@@ -552,10 +510,43 @@ class ReplicationTests(NEOThreadedTest):
s0.stop()
cluster.join((s0,))
t0, t1, t2 = c.db().storage.iterator()
finally:
cluster.stop()
def testCheckReplicas(self):
@with_cluster(start_cluster=0, replicas=1)
def testReplicationBlockedByUnfinished(self, cluster):
if 1:
s0, s1 = cluster.storage_list
cluster.start(storage_list=(s0,))
storage = cluster.getZODBStorage()
oid = storage.new_oid()
tid = None
expected = 'UO'
for n in 1, 0:
# On first iteration, the transaction will block replication
# until tpc_finish.
# We do a second iteration as a quick check that the cluster
# remains functional after such a scenario.
txn = transaction.Transaction()
storage.tpc_begin(txn)
tid = storage.store(oid, tid, 'foo', '', txn)
if n:
# Start the outdated storage.
s1.start()
self.tic()
cluster.enableStorageList((s1,))
cluster.neoctl.tweakPartitionTable()
self.tic()
self.assertPartitionTable(cluster, expected)
storage.tpc_vote(txn)
self.assertPartitionTable(cluster, expected)
tid = storage.tpc_finish(txn)
self.tic() # replication resumes and ends
expected = 'UU'
self.assertPartitionTable(cluster, expected)
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING)
@with_cluster(partitions=5, replicas=2, storage_count=3)
def testCheckReplicas(self, cluster):
from neo.storage import checker
def corrupt(offset):
s0, s1, s2 = (storage_dict[cell.getUUID()]
......@@ -570,14 +561,11 @@ class ReplicationTests(NEOThreadedTest):
for cell in row[1]
if cell[1] == CellStates.CORRUPTED]))
self.assertEqual(expected_state, cluster.neoctl.getClusterState())
np = 5
np = cluster.num_partitions
tid_count = np * 3
corrupt_tid = tid_count // 2
check_dict = dict.fromkeys(xrange(np))
cluster = NEOCluster(partitions=np, replicas=2, storage_count=3)
try:
checker.CHECK_COUNT = 2
cluster.start()
with Patch(checker, CHECK_COUNT=2):
cluster.populate([range(np*2)] * tid_count)
storage_dict = {x.uuid: x for x in cluster.storage_list}
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
......@@ -597,9 +585,6 @@ class ReplicationTests(NEOThreadedTest):
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
self.tic()
check(ClusterStates.RECOVERING, 4)
finally:
checker.CHECK_COUNT = CHECK_COUNT
cluster.stop()
@backup_test()
def testBackupReadOnlyAccess(self, backup):
......@@ -630,7 +615,7 @@ class ReplicationTests(NEOThreadedTest):
# commit new data to U
txn = transaction.Transaction()
txn.note('test transaction %i' % i)
txn.note(u'test transaction %s' % i)
Z.tpc_begin(txn)
oid = Z.new_oid()
Z.store(oid, None, '%s-%i' % (oid, i), '', txn)
......
#
# Copyright (C) 2015-2016 Nexedi SA
# Copyright (C) 2015-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -36,12 +36,8 @@ class SSLTests(SSLMixin, test.Test):
testDeadlockAvoidance = None # XXX why this fails?
testUndoConflict = testUndoConflictDuringStore = None # XXX why this fails?
def testAbortConnection(self):
for after_handshake in 1, 0:
try:
conn.reset()
except UnboundLocalError:
conn = self.getLoopbackConnection()
def testAbortConnection(self, after_handshake=1):
with self.getLoopbackConnection() as conn:
conn.ask(Packets.Ping())
connector = conn.getConnector()
del connector.connect_limit[connector.addr]
......@@ -58,6 +54,9 @@ class SSLTests(SSLMixin, test.Test):
conn.em.poll(1)
self.assertIs(conn.getConnector(), None)
def testAbortConnectionBeforeHandshake(self):
self.testAbortConnection(0)
class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests):
# do not repeat slowest tests with SSL
testBackupNodeLost = testBackupNormalCase = None # TODO recheck
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -42,9 +42,11 @@ class ZODBTestCase(TestCase):
self.open()
def _tearDown(self, success):
self._storage.cleanup()
try:
if functional:
self.neo.stop()
else:
self.neo.stop(None)
except Exception:
if success:
raise
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -14,14 +14,23 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
import unittest
from ZODB.tests.BasicStorage import BasicStorage
from ZODB.tests.StorageTestBase import StorageTestBase
from . import ZODBTestCase
from .. import Patch, threaded
class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage):
pass
def check_checkCurrentSerialInTransaction(self):
x = time.time() + 10
def tic_loop():
while time.time() < x:
yield
with Patch(threaded, TIC_LOOP=tic_loop()):
super(BasicTests, self).check_checkCurrentSerialInTransaction()
if __name__ == "__main__":
suite = unittest.makeSuite(BasicTests, 'check')
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......
#
# Copyright (C) 2009-2016 Nexedi SA
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -30,19 +30,6 @@ class NEOZODBTests(ZODBTestCase, testZODB.ZODBTests):
self._db.close()
super(NEOZODBTests, self)._tearDown(success)
def checkMultipleUndoInOneTransaction(self):
# XXX: Upstream test accesses a persistent object outside a transaction
# (it should call transaction.begin() after the last commit)
# so disable our Connection.afterCompletion optimization.
# This should really be discussed on zodb-dev ML.
from ZODB.Connection import Connection
afterCompletion = Connection.__dict__['afterCompletion']
try:
Connection.afterCompletion = Connection.__dict__['newTransaction']
super(NEOZODBTests, self).checkMultipleUndoInOneTransaction()
finally:
Connection.afterCompletion = afterCompletion
if __name__ == "__main__":
suite = unittest.makeSuite(NEOZODBTests, 'check')
unittest.main(defaultTest='suite')
......@@ -14,7 +14,8 @@ Topic :: Database
Topic :: Software Development :: Libraries :: Python Modules
"""
if not os.path.exists('mock.py'):
mock = 'neo/tests/mock.py'
if not os.path.exists(mock):
import cStringIO, hashlib,subprocess, urllib, zipfile
x = 'pythonmock-0.1.0.zip'
try:
......@@ -25,7 +26,7 @@ if not os.path.exists('mock.py'):
mock_py = zipfile.ZipFile(cStringIO.StringIO(x)).read('mock.py')
if hashlib.md5(mock_py).hexdigest() != '79f42f390678e5195d9ce4ae43bd18ec':
raise EnvironmentError("MD5 checksum mismatch downloading 'mock.py'")
open('mock.py', 'w').write(mock_py)
open(mock, 'w').write(mock_py)
zodb_require = ['ZODB3>=3.10dev']
......@@ -58,7 +59,7 @@ else:
setup(
name = 'neoppod',
version = '1.7.0',
version = '1.7.1',
description = __doc__.strip(),
author = 'Nexedi SA',
author_email = 'neo-dev@erp5.org',
......@@ -69,7 +70,6 @@ setup(
long_description = ".. contents::\n\n" + open('README.rst').read()
+ "\n" + open('CHANGELOG.rst').read(),
packages = find_packages(),
py_modules = ['mock'],
entry_points = {
'console_scripts': [
# XXX: we'd like not to generate scripts for unwanted features
......
#!/bin/sh -e
for COV in coverage python-coverage
do type $COV && break
done >/dev/null 2>&1 || exit
#!/usr/bin/env python
import os, re, sys, shutil
from coverage.cmdline import main, CmdOptionParser
sys.argv.insert(1, 'html')
del CmdOptionParser.get_prog_name
$COV html "$@"
# https://bitbucket.org/ned/coveragepy/issues/474/javascript-in-html-captures-all-keys
sed -i "
/assign_shortkeys *=/s/$/return;/
/^ *\.bind('keydown',/s,^,//,
" htmlcov/coverage_html.js
shutil_copyfile = shutil.copyfile
def copyfile(src, dst):
if os.path.basename(dst) == 'coverage_html.js':
with open(src) as f:
js = f.read()
js = re.sub(r"(assign_shortkeys.*\{)", r"\1return;", js)
js = re.sub(r"^( *\.bind\('keydown',)", r"//\1", js, flags=re.M)
with open(dst, 'w') as f:
f.write(js)
else:
shutil_copyfile(src, dst)
shutil.copyfile = copyfile
main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment