Commit f873e151 authored by Kirill Smelkov

Merge branch 'master' into t

* master: (23 commits)
  Make the identification of the primary master easier with 'neoctl print node'
  client: fix AttributeError when connected to a master that happens to be secondary
  qa: speed up SQLite tests by accessing DBs in unsafe mode (e.g. no sync)
  master: fix identification of unknown masters
  Better logging of connector errors
  qa: relax assertion in testDeadlockAvoidanceBeforeInvolvingAnotherNode
  client: fix possible data corruption after conflict resolutions with replicas
  qa: new --readable-tid runner option
  qa: make Patch work on functions
  qa: new --continue-on-success and --stop-on-success runner options
  doc: add advice about the number of master nodes to set up
  Improvements to --dynamic-master-list
  Make NodeManager.remove stricter
  Clean up neo.lib.protocol
  Use ProtocolError instead of Notify for unexpected answers, and drop Notify
  Rename node states: DOWN -> UNKNOWN, TEMPORARILY_DOWN -> DOWN
  Remove UNKNOWN node state
  Reimplement election (of the primary master)
  Use existing generic way to ignore AcceptIdentification on closed connections
  Remove BROKEN node state
  ...
parents 421fda44 09bc404f
Documentation
- Clarify the meaning of node states, and consider renaming them in the code.
  Ideas:
  TEMPORARILY_DOWN becomes UNAVAILABLE
  BROKEN is removed ?
- Clarify the use of each error code:
  - NOT_READY removed (connection kept open until ready)
  - Split PROTOCOL_ERROR (BAD IDENTIFICATION, ...)
...@@ -25,8 +21,6 @@
  This is mainly the case for:
  - Client rejected before the cluster is operational
  - Empty storages rejected during recovery process
  Masters involved in the election process should still reject any connection
  as the primary master is still unknown.
- Implement transaction garbage collection API (FEATURE)
  NEO packing implementation does not update transaction metadata when
  deleting object revisions. This inconsistency must be made possible to
...@@ -36,9 +30,9 @@
  The same code to ask/receive node list and partition table exists in too
  many places.
- Clarify handler methods to call when a connection is accepted from a
  listening connection and when remote node is identified
  (cf. neo/lib/bootstrap.py).
- Review PENDING/SHUTDOWN states, don't use notifyNodeInformation()
  to do a state-switch, use an exception-based mechanism ? (CODE)
- Review handler split (CODE)
  The current handler split is the result of small incremental changes. A
...@@ -71,7 +65,7 @@
  which is an issue on big databases. (SPEED)
- Pack segmentation & throttling (HIGH AVAILABILITY)
  In its current implementation, pack runs in one call on all storage nodes
  at the same time, which locks down the whole cluster. This task should
  be split in chunks and processed in "background" on storage nodes.
  Packing throttling should probably be at the lowest possible priority
  (below interactive use and below replication).
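The chunked pack sketched in this TODO item could be driven by a simple OID-range splitter like the one below. This is a minimal illustrative sketch, not NEO's pack API; the helper name and chunking scheme are assumptions.

```python
def iter_pack_chunks(min_oid, max_oid, chunk_size):
    """Split a pack over the OID range [min_oid, max_oid] into small
    chunks, so a storage node could process them in "background"
    between higher-priority work (interactive use, replication).

    Hypothetical helper; NEO's actual pack implementation differs.
    """
    oid = min_oid
    while oid <= max_oid:
        end = min(oid + chunk_size - 1, max_oid)
        yield oid, end          # inclusive sub-range to pack
        oid = end + 1

# Example: packing 10 OIDs in chunks of 4 yields 3 sub-ranges.
chunks = list(iter_pack_chunks(0, 9, 4))
```

Between two chunks, the storage node would be free to serve reads and replication, which is the throttling property the TODO item asks for.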
...@@ -122,9 +116,9 @@
Later
- Consider auto-generating cluster name upon initial startup (it might
  actually be a partition property).
- Consider ways to centralise the configuration file, or make the
  configuration updatable automatically on all nodes.
- Consider storing some metadata on master nodes (partition table [version],
  ...). This data should be treated non-authoritatively, as a way to lower
  the probability to use an outdated partition table.
......
# Note: Unless otherwise noted, all parameters in this configuration file
# must be identical for all nodes in a given cluster.
# This file is optional: parameters can be given at the command line.
# See SafeConfigParser at https://docs.python.org/2/library/configparser.html
# for more information about the syntax.
...@@ -17,9 +17,10 @@
cluster: test
# The list of master nodes
# Master nodes not in this list will be rejected by the cluster.
# This list should be identical for all nodes in a given cluster for
# maximum availability.
# With replicas, it is recommended to have 1 master node per machine
# (physical or not). Otherwise, 1 is enough (but more do not harm).
masters: 127.0.0.1:10000
# Partition table configuration
...@@ -58,7 +59,7 @@ partitions: 12
# - MySQL: [user[:password]@]database[unix_socket]
#   Database must be created manually.
# - SQLite: path
# engine: Optional parameter for MySQL.
#   Can be InnoDB (default), RocksDB or TokuDB.
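Putting the options above together, a minimal configuration might look like the following sketch. All values are illustrative examples, not defaults, and the section layout is an assumption based on the excerpts shown here.

```ini
# Illustrative sketch of a NEO configuration file (example values only).
[DEFAULT]
cluster: test
masters: 127.0.0.1:10000
partitions: 12
replicas: 1

[storage]
# MySQL backend: [user[:password]@]database[unix_socket]
database: neo:secret@neo1
engine: InnoDB
```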
# Admin node
......
...@@ -194,17 +194,17 @@ class Application(ThreadedApplication):
        self.nm.reset()
        if self.primary_master_node is not None:
            # If I know a primary master node, pinpoint it.
            node = self.primary_master_node
            self.primary_master_node = None
        else:
            # Otherwise, check one by one.
            master_list = self.nm.getMasterList()
            index = (index + 1) % len(master_list)
            node = master_list[index]
        # Connect to master
        conn = MTClientConnection(self,
            self.notifications_handler,
            node=node,
            dispatcher=self.dispatcher)
        p = Packets.RequestIdentification(
            NodeTypes.CLIENT, self.uuid, None, self.name, None)
...@@ -212,10 +212,8 @@ class Application(ThreadedApplication):
                ask(conn, p, handler=handler)
            except ConnectionClosed:
                fail_count += 1
            else:
                self.primary_master_node = node
                break
        else:
            raise NEOPrimaryMasterLost(
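The retry loop above either pinpoints a known primary or cycles round-robin through the configured masters. The selection logic alone can be reduced to a standalone sketch (simplified; `pick_master` is a hypothetical helper, not a NEO function):

```python
def pick_master(master_list, index, primary=None):
    """Return (new_index, node): pinpoint the known primary if any,
    else advance round-robin through the configured master list."""
    if primary is not None:
        return index, primary
    index = (index + 1) % len(master_list)
    return index, master_list[index]

# Round-robin over 3 masters when no primary is known yet:
masters = ['m0', 'm1', 'm2']
i, node = pick_master(masters, 0)   # advances to 'm1'
i, node = pick_master(masters, i)   # then 'm2'
i, node = pick_master(masters, i)   # wraps back to 'm0'
```

The modulo wrap-around is what lets the client keep probing all masters until one of them answers as (or points to) the primary.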
...@@ -444,8 +442,8 @@ class Application(ThreadedApplication):
        # Store object in tmp cache
        packet = Packets.AskStoreObject(oid, serial, compression,
            checksum, compressed_data, data_serial, ttid)
        txn_context.data_dict[oid] = data, serial, txn_context.write(
            self, packet, oid, oid=oid)
        while txn_context.data_size >= self._cache._max_size:
            self._waitAnyTransactionMessage(txn_context)
...@@ -462,13 +460,13 @@ class Application(ThreadedApplication):
        # This is also done atomically, to avoid race conditions
        # with PrimaryNotificationsHandler.notifyDeadlock
        try:
            oid, serial = pop_conflict()
        except KeyError:
            return
        try:
            data, old_serial, _ = data_dict.pop(oid)
        except KeyError:
            assert oid is None, (oid, serial)
            # The storage refused to let us take the object lock, to avoid a
            # possible deadlock. TID is actually used for some kind of
            # "locking priority": when a higher value has the lock,
...@@ -487,33 +485,32 @@ class Application(ThreadedApplication):
            self._askStorageForWrite(txn_context, uuid, packet)
        else:
            if data is CHECKED_SERIAL:
                raise ReadConflictError(oid=oid,
                    serials=(serial, old_serial))
            # TODO: data can be None if a conflict happens during undo
            if data:
                txn_context.data_size -= len(data)
            if self.last_tid < serial:
                self.sync()  # possible late invalidation (very rare)
            try:
                data = tryToResolveConflict(oid, serial, old_serial, data)
            except ConflictError:
                logging.info(
                    'Conflict resolution failed for %s@%s with %s',
                    dump(oid), dump(old_serial), dump(serial))
                # With recent ZODB, get_pickle_metadata (from ZODB.utils)
                # does not support empty values, so do not pass 'data'
                # in this case.
                raise ConflictError(oid=oid, serials=(serial, old_serial),
                    data=data or None)
            else:
                logging.info(
                    'Conflict resolution succeeded for %s@%s with %s',
                    dump(oid), dump(old_serial), dump(serial))
                # Mark this conflict as resolved
                resolved_dict[oid] = serial
                # Try to store again
                self._store(txn_context, oid, serial, data)

    def _askStorageForWrite(self, txn_context, uuid, packet):
        node = self.nm.getByUUID(uuid)
...@@ -929,7 +926,7 @@ class Application(ThreadedApplication):
        assert oid not in txn_context.cache_dict, oid
        assert oid not in txn_context.data_dict, oid
        packet = Packets.AskCheckCurrentSerial(ttid, oid, serial)
        txn_context.data_dict[oid] = CHECKED_SERIAL, serial, txn_context.write(
            self, packet, oid, 0, oid=oid)
        self._waitAnyTransactionMessage(txn_context, False)
...@@ -15,10 +15,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from neo.lib import logging
from neo.lib.exception import PrimaryElected
from neo.lib.handler import MTEventHandler
from neo.lib.pt import MTPartitionTable as PartitionTable
from neo.lib.protocol import NodeStates
from . import AnswerBaseHandler
from ..exception import NEOStorageError
...@@ -26,10 +26,6 @@ from ..exception import NEOStorageError
class PrimaryBootstrapHandler(AnswerBaseHandler):
    """ Bootstrap handler used when looking for the primary master """

    def answerPartitionTable(self, conn, ptid, row_list):
        assert row_list
        self.app.pt.load(ptid, row_list, self.app.nm)
...@@ -40,57 +36,14 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
class PrimaryNotificationsHandler(MTEventHandler):
    """ Handler that processes the notifications from the primary master """

    def notPrimaryMaster(self, *args):
        try:
            super(PrimaryNotificationsHandler, self).notPrimaryMaster(*args)
        except PrimaryElected, e:
            self.app.primary_master_node, = e.args

    def _acceptIdentification(self, node, num_partitions, num_replicas):
        self.app.pt = PartitionTable(num_partitions, num_replicas)

    def answerLastTransaction(self, conn, ltid):
        app = self.app
...@@ -189,7 +142,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
    def notifyNodeInformation(self, conn, timestamp, node_list):
        super(PrimaryNotificationsHandler, self).notifyNodeInformation(
            conn, timestamp, node_list)
        # XXX: 'update' automatically closes UNKNOWN nodes. Do we really want
        # to do the same thing for nodes in other non-running states ?
        getByUUID = self.app.nm.getByUUID
        for node in node_list:
...@@ -201,7 +154,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
    def notifyDeadlock(self, conn, ttid, locking_tid):
        for txn_context in self.app.txn_contexts():
            if txn_context.ttid == ttid:
                txn_context.conflict_dict[None] = locking_tid
                txn_context.wakeup(conn)
                break
......
...@@ -42,13 +42,8 @@ class StorageEventHandler(MTEventHandler):
        self.app.cp.removeConnection(node)
        super(StorageEventHandler, self).connectionFailed(conn)

    def _acceptIdentification(*args):
        pass

class StorageBootstrapHandler(AnswerBaseHandler):
    """ Handler used when connecting to a storage node """
...@@ -63,7 +58,7 @@ class StorageAnswersHandler(AnswerBaseHandler):
    def answerObject(self, conn, oid, *args):
        self.app.setHandlerData(args)

    def answerStoreObject(self, conn, conflict, oid):
        txn_context = self.app.getHandlerData()
        if conflict:
            # Conflicts can not be resolved now because 'conn' is locked.
...@@ -80,7 +75,7 @@ class StorageAnswersHandler(AnswerBaseHandler):
            # If this conflict is not already resolved, mark it for
            # resolution.
            if txn_context.resolved_dict.get(oid, '') < conflict:
                txn_context.conflict_dict[oid] = conflict
        else:
            txn_context.written(self.app, conn.getUUID(), oid)
...@@ -112,10 +107,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
        except KeyError:
            if resolved:
                # We should still be waiting for an answer from this node.
                assert conn.uuid in txn_context.data_dict[oid][2]
                return
        assert oid in txn_context.data_dict
        if serial <= txn_context.conflict_dict.get(oid, ''):
            # Another node already reported this conflict or a newer,
            # by answering to this rebase or to the previous store.
            return
...@@ -141,8 +136,8 @@ class StorageAnswersHandler(AnswerBaseHandler):
        if cached:
            assert cached == data
            txn_context.cache_size -= size
        txn_context.data_dict[oid] = data, serial, None
        txn_context.conflict_dict[oid] = conflict

    def answerStoreTransaction(self, conn):
        pass
......
...@@ -99,7 +99,6 @@ class ConnectionPool(object):
        return conn

    def removeConnection(self, node):
        self.connection_dict.pop(node.getUUID(), None)

    def closeAll(self):
......
...@@ -40,11 +40,11 @@ class Transaction(object):
        self.queue = SimpleQueue()
        self.txn = txn
        # data being stored
        self.data_dict = {}  # {oid: (value, serial, [node_id])}
        # data stored: this will go to the cache on tpc_finish
        self.cache_dict = {}  # {oid: value}
        # conflicts to resolve
        self.conflict_dict = {}  # {oid: serial}
        # resolved conflicts
        self.resolved_dict = {}  # {oid: serial}
        # Keys are node ids instead of Node objects because a node may
...@@ -98,7 +98,7 @@ class Transaction(object):
        # the data in self.data_dict until all nodes have answered so we remain
        # able to resolve conflicts.
        try:
            data, serial, uuid_list = self.data_dict[oid]
            uuid_list.remove(uuid)
        except KeyError:
            # 1. store to S1 and S2
......
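The new bookkeeping in this merge can be illustrated with plain dicts: `data_dict` now keeps the store serial alongside the data, and `conflict_dict` maps each oid directly to the conflicting serial instead of a `(base_serial, serial)` pair. The following is a standalone sketch of that invariant, not NEO's actual classes; `report_conflict` is a hypothetical helper mirroring the guard in `answerStoreObject`.

```python
# data being stored: {oid: (value, serial, [node_id])}
data_dict = {1: ('value-A', 'tid1', ['S1', 'S2'])}
conflict_dict = {}   # {oid: serial} -- conflicts to resolve
resolved_dict = {}   # {oid: serial} -- resolved conflicts

def report_conflict(oid, conflict):
    """Mark a conflict for resolution unless it was already resolved
    at this serial or a newer one (cf. answerStoreObject's guard)."""
    if resolved_dict.get(oid, '') < conflict:
        conflict_dict[oid] = conflict

# A storage node reports a conflict at 'tid2': it gets queued.
report_conflict(1, 'tid2')
assert conflict_dict == {1: 'tid2'}

# Once resolution succeeds at 'tid2', a duplicate report is ignored.
resolved_dict[1] = 'tid2'
conflict_dict.clear()
report_conflict(1, 'tid2')
assert not conflict_dict
```

The `''` default works because any real serial string compares greater than the empty string, so an unseen oid is always queued.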
...@@ -21,6 +21,7 @@ from .node import NodeManager
class BaseApplication(object):

    server = None
    ssl = None

    def __init__(self, ssl=None, dynamic_master_list=None):
......
...@@ -15,8 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from . import logging
from .exception import PrimaryElected
from .handler import EventHandler
from .protocol import Packets
from .connection import ClientConnection
...@@ -24,7 +25,6 @@ class BootstrapManager(EventHandler):
    """
    Manage the bootstrap stage: look up the primary master, then connect to it
    """

    def __init__(self, app, node_type, server=None):
        """
...@@ -32,85 +32,30 @@ class BootstrapManager(EventHandler):
        primary master node, connect to it, then return when the master node
        is ready.
        """
        self.server = server
        self.node_type = node_type
        self.num_replicas = None
        self.num_partitions = None
        app.nm.reset()

    uuid = property(lambda self: self.app.uuid)
    def connectionCompleted(self, conn):
        EventHandler.connectionCompleted(self, conn)
        self.current.setRunning()
        conn.ask(Packets.RequestIdentification(self.node_type, self.uuid,
            self.server, self.app.name, None))

    def connectionFailed(self, conn):
        EventHandler.connectionFailed(self, conn)
        self.current = None

    def connectionLost(self, conn, new_state):
        self.current = None

    def _acceptIdentification(self, node, num_partitions, num_replicas):
        assert self.current is node, (self.current, node)
        self.num_partitions = num_partitions
        self.num_replicas = num_replicas
    def getPrimaryConnection(self):
        """
...@@ -122,25 +67,26 @@ class BootstrapManager(EventHandler):
        poll = app.em.poll
        index = 0
        self.current = None
        # retry until identified to the primary
        while True:
            try:
                while self.current:
                    if self.current.isIdentified():
                        return (self.current, self.current.getConnection(),
                            self.num_partitions, self.num_replicas)
                    poll(1)
            except PrimaryElected, e:
                if self.current:
                    self.current.getConnection().close()
                self.current, = e.args
                index = app.nm.getMasterList().index(self.current)
            else:
                # select a master
                master_list = app.nm.getMasterList()
                index = (index + 1) % len(master_list)
                self.current = master_list[index]
            ClientConnection(app, self, self.current)
            # Note that the connection may be already closed. This happens when
            # the kernel reacts so quickly to a closed port that 'connect'
            # fails on the first call. In such case, poll(1) would deadlock
            # if there's no other connection to timeout.
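The control flow above replaces the old `accepted` flag with a `PrimaryElected` exception that interrupts the poll loop as soon as a master announces the elected primary. A toy, self-contained version of that pattern (Python 3 syntax; the event format is invented for illustration):

```python
class PrimaryElected(Exception):
    """Simplified stand-in for neo.lib.exception.PrimaryElected:
    carries the elected primary node as its single argument."""

def find_primary(events):
    """Scan fake poll events until one announces the elected primary,
    mirroring the try/except PrimaryElected pattern above."""
    current = None
    for event in events:
        try:
            # In NEO, poll(1) dispatches packets and a handler raises
            # PrimaryElected; here we raise directly from the event.
            if event[0] == 'elected':
                raise PrimaryElected(event[1])
        except PrimaryElected as e:
            current, = e.args   # unpack the single elected node
            break
    return current

primary = find_primary([('noise', None), ('elected', 'm1')])
```

Using an exception lets the election outcome unwind through whatever handler happens to be polling, without threading a flag through every layer.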
...@@ -90,9 +90,7 @@ class ConfigurationManager(object):
    def getMasters(self):
        """ Get the master node list except itself """
        return util.parseMasterList(self.__get('masters'))

    def getBind(self):
        """ Get the address to bind to """
......
...@@ -20,8 +20,7 @@ from time import time
from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets
from .util import dummy_read_buffer, ReadBuffer

CRITICAL_TIMEOUT = 30
...@@ -113,7 +112,8 @@ class HandlerSwitcher(object):
        self._is_handling = False

    def _handle(self, connection, packet): # NOTE incoming packet -> handle -> dispatch ...
        pending = self._pending
        assert len(pending) == 1 or pending[0][0], pending
        logging.packet(connection, packet, False)
        if connection.isClosed() and (connection.isAborted() or
                packet.ignoreOnClosedConnection()):
...@@ -122,29 +122,30 @@ class HandlerSwitcher(object): ...@@ -122,29 +122,30 @@ class HandlerSwitcher(object):
return return
if not packet.isResponse(): # notification if not packet.isResponse(): # notification
# XXX: If there are several handlers, which one to use ? # XXX: If there are several handlers, which one to use ?
self._pending[0][1].packetReceived(connection, packet) pending[0][1].packetReceived(connection, packet)
return return
msg_id = packet.getId() msg_id = packet.getId()
request_dict, handler = self._pending[0] request_dict, handler = pending[0]
# checkout the expected answer class # checkout the expected answer class
try: try:
klass, _, _, kw = request_dict.pop(msg_id) klass, _, _, kw = request_dict.pop(msg_id)
except KeyError: except KeyError:
klass = None klass = None
kw = {} kw = {}
try:
if klass and isinstance(packet, klass) or packet.isError(): if klass and isinstance(packet, klass) or packet.isError():
handler.packetReceived(connection, packet, kw) handler.packetReceived(connection, packet, kw)
else: else:
logging.error('Unexpected answer %r in %r', packet, connection) logging.error('Unexpected answer %r in %r', packet, connection)
if not connection.isClosed(): if not connection.isClosed():
notification = Packets.Notify('Unexpected answer: %r' % packet) connection.answer(Errors.ProtocolError(
connection.send(notification) 'Unexpected answer: %r' % packet))
connection.abort() connection.abort()
# handler.peerBroken(connection) finally:
# apply a pending handler if no more answers are pending # apply a pending handler if no more answers are pending
while len(self._pending) > 1 and not self._pending[0][0]: while len(pending) > 1 and not pending[0][0]:
del self._pending[0] del pending[0]
logging.debug('Apply handler %r on %r', self._pending[0][1], logging.debug('Apply handler %r on %r', pending[0][1],
connection) connection)
if msg_id == self._next_timeout_msg_id: if msg_id == self._next_timeout_msg_id:
self._updateNextTimeout() self._updateNextTimeout()
...@@ -258,10 +259,12 @@ class BaseConnection(object): ...@@ -258,10 +259,12 @@ class BaseConnection(object):
) )
def setHandler(self, handler): def setHandler(self, handler):
if self._handlers.setHandler(handler): changed = self._handlers.setHandler(handler)
logging.debug('Set handler %r on %r', handler, self) if changed:
logging.debug('Handler changed on %r', self)
else: else:
logging.debug('Delay handler %r on %r', handler, self) logging.debug('Delay handler %r on %r', handler, self)
return changed
def getUUID(self): def getUUID(self):
return None return None
...@@ -315,9 +318,11 @@ class ListeningConnection(BaseConnection): ...@@ -315,9 +318,11 @@ class ListeningConnection(BaseConnection):
if self._ssl: if self._ssl:
conn.connecting = True conn.connecting = True
connector.ssl(self._ssl, conn._connected) connector.ssl(self._ssl, conn._connected)
self.em.addWriter(conn) # Nothing to send as long as we haven't received a ClientHello
# message.
else: else:
conn._connected() conn._connected()
self.em.addWriter(conn) # for ENCODED_VERSION
def getAddress(self): def getAddress(self):
return self.connector.getAddress() return self.connector.getAddress()
...@@ -336,6 +341,7 @@ class Connection(BaseConnection): ...@@ -336,6 +341,7 @@ class Connection(BaseConnection):
server = False server = False
peer_id = None peer_id = None
_next_timeout = None _next_timeout = None
_parser_state = None
_timeout = 0 _timeout = 0
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
...@@ -347,7 +353,6 @@ class Connection(BaseConnection): ...@@ -347,7 +353,6 @@ class Connection(BaseConnection):
self.uuid = None self.uuid = None
self._queue = [] self._queue = []
self._on_close = None self._on_close = None
self._parser_state = ParserState()
def _getReprInfo(self): def _getReprInfo(self):
r, flags = super(Connection, self)._getReprInfo() r, flags = super(Connection, self)._getReprInfo()
...@@ -466,20 +471,59 @@ class Connection(BaseConnection): ...@@ -466,20 +471,59 @@ class Connection(BaseConnection):
except ConnectorException: except ConnectorException:
self._closure() self._closure()
def _parse(self):
read = self.read_buf.read
version = read(4)
if version is None:
return
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE,
PACKET_HEADER_FORMAT, Packets)
if version != ENCODED_VERSION:
logging.warning('Protocol version mismatch with %r', self)
raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size
unpack = PACKET_HEADER_FORMAT.unpack
def parse():
state = self._parser_state
if state is None:
header = read(header_size)
if header is None:
return
msg_id, msg_type, msg_len = unpack(header)
try:
packet_klass = Packets[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
else:
msg_id, packet_klass, msg_len = state
data = read(msg_len)
if data is None:
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet
self._parse = parse
return parse()
     def readable(self):
         """Called when self is readable."""
         # last known remote activity
         self._next_timeout = time() + self._timeout
-        read_buf = self.read_buf
         try:
             try:
-                if self.connector.receive(read_buf):
+                if self.connector.receive(self.read_buf):
                     self.em.addWriter(self)
             finally:
                 # A connector may read some data
                 # before raising ConnectorException
                 while 1:
-                    packet = Packets.parse(read_buf, self._parser_state)
+                    packet = self._parse()
                     if packet is None:
                         break
                     self._queue.append(packet)
@@ -501,7 +545,9 @@ class Connection(BaseConnection):
         Process a pending packet.
         """
         # check out packet and process it with current handler
-        self._handlers.handle(self, self._queue.pop(0))
-        self.updateTimeout()
+        try:
+            self._handlers.handle(self, self._queue.pop(0))
+        finally:
+            self.updateTimeout()
 
     def pending(self):
@@ -625,9 +671,7 @@ class ClientConnection(Connection):
             self.em.register(self)
             if connected:
                 self._maybeConnected()
-            # A client connection usually has a pending packet to send
-            # from the beginning. It would be too smart to detect when
-            # it's not required to poll for writing.
+            # There's always the protocol version to send.
             self.em.addWriter(self)
 
     def _delayedConnect(self):
......
@@ -19,6 +19,7 @@ import ssl
 import errno
 from time import time
 from . import logging
+from .protocol import ENCODED_VERSION
 
 # Global connector registry.
 # Fill by calling registerConnectorHandler.
@@ -58,7 +59,7 @@ class SocketConnector(object):
         s.setblocking(0)
         # disable Nagle algorithm to reduce latency
         s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-        self.queued = []
+        self.queued = [ENCODED_VERSION]
         return self
 
     def queue(self, data):
@@ -66,7 +67,10 @@ class SocketConnector(object):
         self.queued += data
         return was_empty
 
-    def _error(self, op, exc):
-        logging.debug("%s failed for %s: %s (%s)",
-                      op, self, errno.errorcode[exc.errno], exc.strerror)
+    def _error(self, op, exc=None):
+        if exc is None:
+            logging.debug('%r closed in %s', self, op)
+        else:
+            logging.debug("%s failed for %s: %s (%s)",
+                          op, self, errno.errorcode[exc.errno], exc.strerror)
         raise ConnectorException
@@ -151,8 +155,7 @@ class SocketConnector(object):
             if data:
                 read_buf.append(data)
                 return
-            logging.debug('%r closed in recv', self)
-            raise ConnectorException
+            self._error('recv')
 
     def send(self):
         # XXX: unefficient for big packets
@@ -240,10 +243,12 @@ def overlay_connector_class(cls):
 
 @overlay_connector_class
 class _SSL:
 
-    def _error(self, op, exc):
+    def _error(self, op, exc=None):
         if isinstance(exc, ssl.SSLError):
-            logging.debug("%s failed for %s: %s", op, self, exc)
-            raise ConnectorException
+            if not isinstance(exc, ssl.SSLEOFError):
+                logging.debug("%s failed for %s: %s", op, self, exc)
+                raise ConnectorException
+            exc = None
         SocketConnector._error(self, op, exc)
 
     def receive(self, read_buf):
@@ -258,7 +263,27 @@ class _SSL:
 
 @overlay_connector_class
 class _SSLHandshake(_SSL):
 
-    def receive(self, read_buf=None):
+    # WKRD: Unfortunately, SSL_do_handshake(3SSL) does not try to reject
+    #       non-SSL connections as soon as possible, by checking the first
+    #       byte. It even does nothing before receiving a full TLSPlaintext
+    #       frame (5 bytes).
+    #       The NEO protocol is such that a client connection is always the
+    #       first to send a packet, as soon as the connection is established,
+    #       and without waiting that the protocol versions are checked.
+    #       So in practice, non-SSL connection to SSL would never hang, but
+    #       there's another issue: such case results in WRONG_VERSION_NUMBER
+    #       instead of something like UNEXPECTED_RECORD, because the SSL
+    #       version is checked first.
+    #       For better logging, we try to detect non-SSL connections with
+    #       MSG_PEEK. This only works reliably on server side.
+    #       For SSL client connections, 2 things may prevent the workaround to
+    #       log that the remote node has not enabled SSL:
+    #       - non-SSL data received (or connection closed) before the first
+    #         call to 'recv' in 'do_handshake'
+    #       - the server connection detects a wrong protocol version before it
+    #         sent its one
+
+    def _handshake(self, read_buf=None):
         # ???Writer | send   | receive
         # -----------+--------+--------
         # want read  | remove | -
@@ -270,9 +295,10 @@ class _SSLHandshake(_SSL):
         except ssl.SSLWantWriteError:
             return read_buf is not None
         except socket.error, e:
-            self._error('SSL handshake', e)
+            self._error('send' if read_buf is None else 'recv', e)
         if not self.queued[0]:
             del self.queued[0]
+            del self.receive, self.send
             self.__class__ = self.SSLConnectorClass
             cipher, proto, bits = self.socket.cipher()
             logging.debug("SSL handshake done for %s: %s %s", self, cipher, bits)
@@ -284,7 +310,21 @@ class _SSLHandshake(_SSL):
             self.receive(read_buf)
         return self.queued
 
-    send = receive
+    def send(self, read_buf=None):
+        handshake = self.receive = self.send = self._handshake
+        return handshake(read_buf)
+
+    def receive(self, read_buf):
+        try:
+            content_type = self.socket._sock.recv(1, socket.MSG_PEEK)
+        except socket.error, e:
+            self._error('recv', e)
+        if content_type == '\26': # handshake
+            return self.send(read_buf)
+        if content_type:
+            logging.debug('Rejecting non-SSL %r', self)
+            raise ConnectorException
+        self._error('recv')
 
 class ConnectorException(Exception):
......
@@ -17,7 +17,7 @@
 class NeoException(Exception):
     pass
 
-class ElectionFailure(NeoException):
+class PrimaryElected(NeoException):
     pass
 
 class PrimaryFailure(NeoException):
......
@@ -19,9 +19,9 @@ from collections import deque
 from operator import itemgetter
 from . import logging
 from .connection import ConnectionClosed
-from .protocol import (
-    NodeStates, Packets, Errors, BackendNotImplemented,
-    BrokenNodeDisallowedError, NonReadableCell, NotReadyError,
-    PacketMalformedError, ProtocolError, UnexpectedPacketError)
+from .exception import PrimaryElected
+from .protocol import (NodeStates, NodeTypes, Packets, uuid_str,
+    Errors, BackendNotImplemented, NonReadableCell, NotReadyError,
+    PacketMalformedError, ProtocolError, UnexpectedPacketError)
 from .util import cached_property
@@ -59,7 +59,6 @@ class EventHandler(object):
         logging.error(message)
         conn.answer(Errors.ProtocolError(message))
         conn.abort()
-        # self.peerBroken(conn)
 
     def dispatch(self, conn, packet, kw={}):
         """This is a helper method to handle various packet types."""
@@ -80,11 +79,6 @@ class EventHandler(object):
         except PacketMalformedError, e:
             logging.error('malformed packet from %r: %s', conn, e)
             conn.close()
-            # self.peerBroken(conn)
-        except BrokenNodeDisallowedError:
-            if not conn.isClosed():
-                conn.answer(Errors.BrokenNode('go away'))
-                conn.abort()
         except NotReadyError, message:
             if not conn.isClosed():
                 if not message.args:
@@ -144,12 +138,7 @@ class EventHandler(object):
     def connectionClosed(self, conn):
         """Called when a connection is closed by the peer."""
         logging.debug('connection closed for %r', conn)
-        self.connectionLost(conn, NodeStates.TEMPORARILY_DOWN)
-
-    #def peerBroken(self, conn):
-    #    """Called when a peer is broken."""
-    #    logging.error('%r is broken', conn)
-    #    # NodeStates.BROKEN
+        self.connectionLost(conn, NodeStates.DOWN)
 
     def connectionLost(self, conn, new_state):
         """ this is a method to override in sub-handlers when there is no need
@@ -159,21 +148,41 @@ class EventHandler(object):
 
     # Packet handlers.
 
-    def acceptIdentification(self, conn, node_type, *args):
-        try:
-            acceptIdentification = self._acceptIdentification
-        except AttributeError:
-            raise UnexpectedPacketError('no handler found')
-        if conn.isClosed():
-            # acceptIdentification received on a closed (probably aborted,
-            # actually) connection. Reject any further packet as unexpected.
-            conn.setHandler(EventHandler(self.app))
-            return
-        node = self.app.nm.getByAddress(conn.getAddress())
+    def notPrimaryMaster(self, conn, primary, known_master_list):
+        nm = self.app.nm
+        for address in known_master_list:
+            nm.createMaster(address=address)
+        if primary is not None:
+            primary = known_master_list[primary]
+            assert primary != self.app.server
+            raise PrimaryElected(nm.getByAddress(primary))
+
+    def _acceptIdentification(*args):
+        pass
+
+    def acceptIdentification(self, conn, node_type, uuid,
+                             num_partitions, num_replicas, your_uuid):
+        app = self.app
+        node = app.nm.getByAddress(conn.getAddress())
         assert node.getConnection() is conn, (node.getConnection(), conn)
         if node.getType() == node_type:
+            if node_type == NodeTypes.MASTER:
+                other = app.nm.getByUUID(uuid)
+                if other is not None:
+                    other.setUUID(None)
+                node.setUUID(uuid)
+                node.setRunning()
+                if your_uuid is None:
+                    raise ProtocolError('No UUID supplied')
+                logging.info('connected to a primary master node')
+                if app.uuid != your_uuid:
+                    app.uuid = your_uuid
+                    logging.info('Got a new UUID: %s', uuid_str(your_uuid))
+                app.id_timestamp = None
+            elif node.getUUID() != uuid or app.uuid != your_uuid != None:
+                raise ProtocolError('invalid uuids')
             node.setIdentified()
-            acceptIdentification(node, *args)
+            self._acceptIdentification(node, num_partitions, num_replicas)
             return
         conn.close()
@@ -189,9 +198,6 @@ class EventHandler(object):
         # to test/maintain underlying connection.
         pass
 
-    def notify(self, conn, message):
-        logging.warning('notification from %r: %s', conn, message)
-
     def closeClient(self, conn):
         conn.server = False
         if not conn.client:
@@ -216,9 +222,6 @@ class EventHandler(object):
     def timeoutError(self, conn, message):
         logging.error('timeout error: %s', message)
 
-    def brokenNodeDisallowedError(self, conn, message):
-        raise RuntimeError, 'broken node disallowed error: %s' % (message,)
-
     def ack(self, conn, message):
         logging.debug("no error message: %s", message)
 
@@ -268,7 +271,6 @@ class AnswerBaseHandler(EventHandler):
     timeoutExpired = unexpectedInAnswerHandler
     connectionClosed = unexpectedInAnswerHandler
     packetReceived = unexpectedInAnswerHandler
-    peerBroken = unexpectedInAnswerHandler
     protocolError = unexpectedInAnswerHandler
 
     def acceptIdentification(*args):
@@ -318,6 +320,10 @@ class EventQueue(object):
         self._event_queue = []
         self._executing_event = -1
 
+    # Stable sort when 2 keys are equal.
+    # XXX: Is it really useful to keep events with same key ordered
+    #      chronologically ? The caller could use more specific keys. For
+    #      write-locks (by the storage node), the locking tid seems enough.
     sortQueuedEvents = (lambda key=itemgetter(0): lambda self:
         self._event_queue.sort(key=key))()
......
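The new comment on `sortQueuedEvents` relies on `list.sort` being guaranteed stable in Python: events whose keys compare equal keep their insertion (chronological) order after sorting. A quick illustration with made-up event payloads:

```python
from operator import itemgetter

# Events recorded as (key, payload), in arrival order; two events share key 2.
queue = [(2, 'second-arrived-first'), (1, 'first'), (2, 'second-arrived-last')]
queue.sort(key=itemgetter(0))   # stable: equal keys keep arrival order
assert queue == [(1, 'first'),
                 (2, 'second-arrived-first'),
                 (2, 'second-arrived-last')]
```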
@@ -14,9 +14,8 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+import errno, json, os
 from time import time
-from os.path import exists, getsize
-import json
 
 from . import attributeTracker, logging
 from .handler import DelayEvent, EventQueue
@@ -31,8 +30,7 @@ class Node(object):
     _identified = False
     id_timestamp = None
 
-    def __init__(self, manager, address=None, uuid=None,
-            state=NodeStates.UNKNOWN):
+    def __init__(self, manager, address=None, uuid=None, state=NodeStates.DOWN):
         self._state = state
         self._address = address
         self._uuid = uuid
@@ -64,7 +62,7 @@ class Node(object):
     def setState(self, new_state):
         if self._state == new_state:
             return
-        if new_state == NodeStates.DOWN:
+        if new_state == NodeStates.UNKNOWN:
             self._manager.remove(self)
             self._state = new_state
         else:
@@ -205,34 +203,51 @@ class MasterDB(object):
     """
 
     def __init__(self, path):
         self._path = path
-        try_load = exists(path) and getsize(path)
-        if try_load:
-            db = open(path, 'r')
-            init_set = map(tuple, json.load(db))
-        else:
-            db = open(path, 'w+')
-            init_set = []
-        self._set = set(init_set)
-        db.close()
-
-    def _save(self):
-        try:
-            db = open(self._path, 'w')
-        except IOError:
-            logging.warning('failed opening master database at %r '
-                'for writing, update skipped', self._path)
-        else:
-            json.dump(list(self._set), db)
-            db.close()
+        try:
+            with open(path) as db:
+                self._set = set(map(tuple, json.load(db)))
+        except IOError, e:
+            if e.errno != errno.ENOENT:
+                raise
+            self._set = set()
+            self._save(True)
+
+    def _save(self, raise_on_error=False):
+        tmp = self._path + '#neo#'
+        try:
+            with open(tmp, 'w') as db:
+                json.dump(list(self._set), db)
+            os.rename(tmp, self._path)
+        except EnvironmentError:
+            if raise_on_error:
+                raise
+            logging.exception('failed saving list of master nodes to %r',
+                              self._path)
+        finally:
+            try:
+                os.remove(tmp)
+            except OSError:
+                pass
 
-    def add(self, addr):
-        self._set.add(addr)
+    def remove(self, addr):
+        if addr in self._set:
+            self._set.remove(addr)
             self._save()
 
-    def discard(self, addr):
-        self._set.discard(addr)
+    def addremove(self, old, new):
+        assert old != new
+        if None is not new not in self._set:
+            self._set.add(new)
+        elif old not in self._set:
+            return
+        self._set.discard(old)
         self._save()
 
+    def __repr__(self):
+        return '<%s@%s: %s>' % (self.__class__.__name__, self._path,
+            ', '.join(sorted(('[%s]:%s' if ':' in x[0] else '%s:%s') % x
+                             for x in self._set)))
+
     def __iter__(self):
         return iter(self._set)
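`MasterDB._save` now writes to a temporary file and renames it over the target, so a crash mid-write can never leave a truncated master list; the `finally` clause cleans up the temporary file whenever the rename did not happen. The same write-then-rename pattern in isolation (function name and temp-file suffix are illustrative):

```python
import json, os

def atomic_json_dump(path, obj):
    """Write obj as JSON to path atomically: a concurrent reader sees
    either the old content or the new one, never a partial file."""
    tmp = path + '#tmp#'
    try:
        with open(tmp, 'w') as f:
            json.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # also survive power loss, not just a crash
        os.rename(tmp, path)      # atomic replacement on POSIX filesystems
    finally:
        try:
            os.remove(tmp)        # no-op after a successful rename
        except OSError:
            pass
```

`os.rename` is only atomic when source and destination are on the same filesystem, which holds here because the temporary file lives next to the target.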
@@ -271,19 +286,14 @@ class NodeManager(EventQueue):
         if node in self._node_set:
             logging.warning('adding a known node %r, ignoring', node)
             return
-        assert not node.isDown(), node
+        assert not node.isUnknown(), node
         self._node_set.add(node)
         self._updateAddress(node, None)
         self._updateUUID(node, None)
         self.__updateSet(self._type_dict, None, node.getType(), node)
         self.__updateSet(self._state_dict, None, node.getState(), node)
-        if node.isMaster() and self._master_db is not None:
-            self._master_db.add(node.getAddress())
 
     def remove(self, node):
-        if node not in self._node_set:
-            logging.warning('removing unknown node %r, ignoring', node)
-            return
         self._node_set.remove(node)
         # a node may have not be indexed by uuid or address, eg.:
         # - a client or admin node that don't have listening address
@@ -292,9 +302,8 @@ class NodeManager(EventQueue):
         self._uuid_dict.pop(node.getUUID(), None)
         self._state_dict[node.getState()].remove(node)
         self._type_dict[node.getType()].remove(node)
-        uuid = node.getUUID()
         if node.isMaster() and self._master_db is not None:
-            self._master_db.discard(node.getAddress())
+            self._master_db.remove(node.getAddress())
 
     def __update(self, index_dict, old_key, new_key, node):
         """ Update an index from old to new key """
@@ -309,7 +318,10 @@ class NodeManager(EventQueue):
         index_dict[new_key] = node
 
     def _updateAddress(self, node, old_address):
-        self.__update(self._address_dict, old_address, node.getAddress(), node)
+        address = node.getAddress()
+        self.__update(self._address_dict, old_address, address, node)
+        if node.isMaster() and self._master_db is not None:
+            self._master_db.addremove(old_address, address)
 
     def _updateUUID(self, node, old_uuid):
         self.__update(self._uuid_dict, old_uuid, node.getUUID(), node)
@@ -321,7 +333,7 @@ class NodeManager(EventQueue):
         set_dict.setdefault(new_key, set()).add(node)
 
     def _updateState(self, node, old_state):
-        assert not node.isDown(), node
+        assert not node.isUnknown(), node
         self.__updateSet(self._state_dict, old_state, node.getState(), node)
 
     def getList(self, node_filter=None):
@@ -414,7 +426,7 @@ class NodeManager(EventQueue):
     def update(self, app, timestamp, node_list):
         assert self._timestamp < timestamp, (self._timestamp, timestamp)
         self._timestamp = timestamp
-        node_set = self._node_set.copy() if app.id_timestamp is None else None
+        added_list = [] if app.id_timestamp is None else None
         for node_type, addr, uuid, state, id_timestamp in node_list:
             # This should be done here (although klass might not be used in this
             # iteration), as it raises if type is not valid.
@@ -423,24 +435,23 @@ class NodeManager(EventQueue):
             # lookup in current table
             node_by_uuid = self.getByUUID(uuid)
             node_by_addr = self.getByAddress(addr)
-            node = node_by_uuid or node_by_addr
+            node = node_by_addr or node_by_uuid
             log_args = node_type, uuid_str(uuid), addr, state, id_timestamp
             if node is None:
-                if state == NodeStates.DOWN:
-                    logging.debug('NOT creating node %s %s %s %s %s', *log_args)
-                    continue
+                assert state != NodeStates.UNKNOWN, (self._node_set,) + log_args
                 node = self._createNode(klass, address=addr, uuid=uuid,
                                         state=state)
                 logging.debug('creating node %r', node)
             else:
                 assert isinstance(node, klass), 'node %r is not ' \
                     'of expected type: %r' % (node, klass)
-                assert None in (node_by_uuid, node_by_addr) or \
-                    node_by_uuid is node_by_addr, \
-                    'Discrepancy between node_by_uuid (%r) and ' \
-                    'node_by_addr (%r)' % (node_by_uuid, node_by_addr)
-                if state == NodeStates.DOWN:
+                if None is not node_by_uuid is not node_by_addr is not None:
+                    assert added_list is not None, \
+                        'Discrepancy between node_by_uuid (%r) and ' \
+                        'node_by_addr (%r)' % (node_by_uuid, node_by_addr)
+                    node_by_uuid.setUUID(None)
+                if state == NodeStates.UNKNOWN:
                     logging.debug('dropping node %r (%r), found with %s '
                         '%s %s %s %s', node, node.isConnected(), *log_args)
                     if node.isConnected():
@@ -451,8 +462,9 @@ class NodeManager(EventQueue):
                         # reconnect to the master because they cleared their
                         # partition table upon disconnection.
                         node.getConnection().close()
-                    if app.uuid != uuid:
-                        app.pt.dropNode(node)
+                    if app.uuid != uuid: # XXX
+                        dropped = app.pt.dropNode(node)
+                        assert dropped, node
                     self.remove(node)
                     continue
                 logging.debug('updating node %r to %s %s %s %s %s',
@@ -463,11 +475,14 @@ class NodeManager(EventQueue):
             node.id_timestamp = id_timestamp
             if app.uuid == uuid:
                 app.id_timestamp = id_timestamp
-        if node_set:
+            if added_list is not None:
+                added_list.append(node)
+        if added_list is not None:
+            assert app.id_timestamp is not None
             # For the first notification, we receive a full list of nodes from
             # the master. Remove all unknown nodes from a previous connection.
-            for node in node_set - self._node_set:
-                app.pt.dropNode(node)
-                self.remove(node)
+            for node in self._node_set.difference(added_list):
+                if app.pt.dropNode(node):
+                    self.remove(node)
         self.log()
         self.executeQueuedEvents()
......
...@@ -14,24 +14,21 @@ ...@@ -14,24 +14,21 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket
import sys import sys
import traceback import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
PROTOCOL_VERSION = 12 # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and
# the high order byte 0 is different from TLS Handshake (0x16).
PROTOCOL_VERSION = 1
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION)
# Size restrictions. # Avoid memory errors on corrupted data.
MIN_PACKET_SIZE = 10
MAX_PACKET_SIZE = 0x4000000 MAX_PACKET_SIZE = 0x4000000
PACKET_HEADER_FORMAT = Struct('!LHL') PACKET_HEADER_FORMAT = Struct('!LHL')
# Check that header size is the expected value.
# If it is not, it means that struct module result is incompatible with
# "reference" platform (python 2.4 on x86-64).
assert PACKET_HEADER_FORMAT.size == 10, \
'Unsupported platform, packet header length = %i' % \
(PACKET_HEADER_FORMAT.size, )
RESPONSE_MASK = 0x8000 RESPONSE_MASK = 0x8000
class Enum(tuple): class Enum(tuple):
...@@ -70,7 +67,6 @@ def ErrorCodes(): ...@@ -70,7 +67,6 @@ def ErrorCodes():
TID_NOT_FOUND TID_NOT_FOUND
OID_DOES_NOT_EXIST OID_DOES_NOT_EXIST
PROTOCOL_ERROR PROTOCOL_ERROR
BROKEN_NODE
REPLICATION_ERROR REPLICATION_ERROR
CHECKING_ERROR CHECKING_ERROR
BACKEND_NOT_IMPLEMENTED BACKEND_NOT_IMPLEMENTED
...@@ -119,13 +115,10 @@ def NodeTypes(): ...@@ -119,13 +115,10 @@ def NodeTypes():
@Enum @Enum
def NodeStates(): def NodeStates():
RUNNING UNKNOWN
TEMPORARILY_DOWN
DOWN DOWN
BROKEN RUNNING
HIDDEN
PENDING PENDING
UNKNOWN
@Enum @Enum
def CellStates(): def CellStates():
...@@ -149,12 +142,9 @@ def CellStates(): ...@@ -149,12 +142,9 @@ def CellStates():
# used for logging # used for logging
node_state_prefix_dict = { node_state_prefix_dict = {
NodeStates.RUNNING: 'R', NodeStates.RUNNING: 'R',
NodeStates.TEMPORARILY_DOWN: 'T',
NodeStates.DOWN: 'D', NodeStates.DOWN: 'D',
NodeStates.BROKEN: 'B',
NodeStates.HIDDEN: 'H',
NodeStates.PENDING: 'P',
NodeStates.UNKNOWN: 'U', NodeStates.UNKNOWN: 'U',
NodeStates.PENDING: 'P',
} }
# used for logging # used for logging
...@@ -167,16 +157,12 @@ cell_state_prefix_dict = { ...@@ -167,16 +157,12 @@ cell_state_prefix_dict = {
} }
# Other constants. # Other constants.
INVALID_UUID = 0 INVALID_TID = \
INVALID_TID = '\xff' * 8
INVALID_OID = '\xff' * 8 INVALID_OID = '\xff' * 8
INVALID_PARTITION = 0xffffffff INVALID_PARTITION = 0xffffffff
INVALID_ADDRESS_TYPE = socket.AF_UNSPEC
ZERO_HASH = '\0' * 20 ZERO_HASH = '\0' * 20
ZERO_TID = '\0' * 8 ZERO_TID = \
ZERO_OID = '\0' * 8 ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID)
MAX_TID = '\x7f' + '\xff' * 7 # SQLite does not accept numbers above 2^63-1 MAX_TID = '\x7f' + '\xff' * 7 # SQLite does not accept numbers above 2^63-1
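As a sanity check on the TID constants above (a sketch, with byte strings written for Python 3): TIDs are 8-byte big-endian values, so plain byte comparison orders them numerically, and MAX_TID is exactly the 2^63-1 ceiling of a signed 64-bit SQLite INTEGER column:

```python
from struct import unpack

ZERO_TID = b'\x00' * 8
MAX_TID = b'\x7f' + b'\xff' * 7
INVALID_TID = b'\xff' * 8

# Byte-wise comparison of fixed-size big-endian strings matches
# numeric comparison.
assert ZERO_TID < MAX_TID < INVALID_TID
# 0x7fffffffffffffff is the largest value that fits in SQLite's
# signed 64-bit INTEGER storage class.
assert unpack('!Q', MAX_TID)[0] == 2 ** 63 - 1
```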
# High-order byte: # High-order byte:
...@@ -203,17 +189,14 @@ class ProtocolError(Exception): ...@@ -203,17 +189,14 @@ class ProtocolError(Exception):
""" Base class for protocol errors, close the connection """ """ Base class for protocol errors, close the connection """
class PacketMalformedError(ProtocolError): class PacketMalformedError(ProtocolError):
""" Close the connection and set the node as broken""" """Close the connection"""
class UnexpectedPacketError(ProtocolError): class UnexpectedPacketError(ProtocolError):
""" Close the connection and set the node as broken""" """Close the connection"""
class NotReadyError(ProtocolError): class NotReadyError(ProtocolError):
""" Just close the connection """ """ Just close the connection """
class BrokenNodeDisallowedError(ProtocolError):
""" Just close the connection """
class BackendNotImplemented(Exception): class BackendNotImplemented(Exception):
""" Method not implemented by backend storage """ """ Method not implemented by backend storage """
...@@ -279,8 +262,8 @@ class Packet(object): ...@@ -279,8 +262,8 @@ class Packet(object):
def encode(self): def encode(self):
""" Encode a packet as a string to send it over the network """ """ Encode a packet as a string to send it over the network """
content = self._body content = self._body
length = PACKET_HEADER_FORMAT.size + len(content) return (PACKET_HEADER_FORMAT.pack(self._id, self._code, len(content)),
return (PACKET_HEADER_FORMAT.pack(self._id, self._code, length), content) content)
def __len__(self): def __len__(self):
return PACKET_HEADER_FORMAT.size + len(self._body) return PACKET_HEADER_FORMAT.size + len(self._body)
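A hedged sketch of the wire framing implied by `encode` and `__len__` above, using the pre-change convention (visible on the old side) where the header's length field counts header plus body; the `frame` helper below is illustrative, not part of NEO:

```python
from struct import Struct

PACKET_HEADER_FORMAT = Struct('!LHL')  # message id, packet code, length

def frame(msg_id, code, body):
    # Mirrors the old Packet.encode: the length field covers
    # header + body.
    return PACKET_HEADER_FORMAT.pack(
        msg_id, code, PACKET_HEADER_FORMAT.size + len(body)) + body

data = frame(42, 7, b'hello')
msg_id, code, length = PACKET_HEADER_FORMAT.unpack(
    data[:PACKET_HEADER_FORMAT.size])
assert PACKET_HEADER_FORMAT.size == 10      # 4 + 2 + 4 bytes
assert (msg_id, code, length) == (42, 7, 15)
assert length == len(data)
```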
...@@ -562,19 +545,6 @@ class PPTID(PStructItemOrNone): ...@@ -562,19 +545,6 @@ class PPTID(PStructItemOrNone):
_fmt = '!Q' _fmt = '!Q'
_None = Struct(_fmt).pack(0) _None = Struct(_fmt).pack(0)
class PProtocol(PNumber):
"""
The protocol version definition
"""
def _encode(self, writer, version):
writer(self.pack(version))
def _decode(self, reader):
version = self.unpack(reader(self.size))
if version != (PROTOCOL_VERSION,):
raise ProtocolError('protocol version mismatch')
return version
class PChecksum(PItem): class PChecksum(PItem):
""" """
A hash (SHA1) A hash (SHA1)
...@@ -586,12 +556,14 @@ class PChecksum(PItem): ...@@ -586,12 +556,14 @@ class PChecksum(PItem):
def _decode(self, reader): def _decode(self, reader):
return reader(20) return reader(20)
class PUUID(PStructItemOrNone): class PSignedNull(PStructItemOrNone):
_fmt = '!l'
_None = Struct(_fmt).pack(0)
class PUUID(PSignedNull):
""" """
A UUID (node identifier, a 4-byte signed integer) A UUID (node identifier, a 4-byte signed integer)
""" """
_fmt = '!l'
_None = Struct(_fmt).pack(0)
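The `PSignedNull`/`PUUID` refactoring above encodes an optional value in a fixed-size field by reserving 0 as the None sentinel. A standalone sketch of that behaviour (helper names are illustrative, not NEO's API):

```python
from struct import Struct

FMT = Struct('!l')           # 4-byte signed big-endian, as in PSignedNull
NONE_SENTINEL = FMT.pack(0)  # value 0 is reserved to mean "no value"

def encode_uuid(uuid):
    return NONE_SENTINEL if uuid is None else FMT.pack(uuid)

def decode_uuid(data):
    value, = FMT.unpack(data)
    return None if value == 0 else value

assert decode_uuid(encode_uuid(None)) is None
assert decode_uuid(encode_uuid(-123)) == -123
# Either way the field occupies exactly 4 bytes on the wire.
assert len(encode_uuid(None)) == len(encode_uuid(-123)) == 4
```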
class PTID(PItem): class PTID(PItem):
""" """
...@@ -671,14 +643,6 @@ PFOidList = PList('oid_list', ...@@ -671,14 +643,6 @@ PFOidList = PList('oid_list',
# packets definition # packets definition
class Notify(Packet):
"""
General purpose notification (remote logging)
"""
_fmt = PStruct('notify',
PString('message'),
)
class Error(Packet): class Error(Packet):
""" """
Error is a special type of message, because this can be sent against Error is a special type of message, because this can be sent against
...@@ -709,7 +673,6 @@ class RequestIdentification(Packet): ...@@ -709,7 +673,6 @@ class RequestIdentification(Packet):
poll_thread = True poll_thread = True
_fmt = PStruct('request_identification', _fmt = PStruct('request_identification',
PProtocol('protocol_version'),
PFNodeType, PFNodeType,
PUUID('uuid'), PUUID('uuid'),
PAddress('address'), PAddress('address'),
...@@ -723,25 +686,8 @@ class RequestIdentification(Packet): ...@@ -723,25 +686,8 @@ class RequestIdentification(Packet):
PNumber('num_partitions'), PNumber('num_partitions'),
PNumber('num_replicas'), PNumber('num_replicas'),
PUUID('your_uuid'), PUUID('your_uuid'),
PAddress('primary'),
PList('known_master_list',
PStruct('master',
PAddress('address'),
PUUID('uuid'),
),
),
) )
def __init__(self, *args, **kw):
if args or kw:
# always announce current protocol version
args = list(args)
args.insert(0, PROTOCOL_VERSION)
super(RequestIdentification, self).__init__(*args, **kw)
def decode(self):
return super(RequestIdentification, self).decode()[1:]
class PrimaryMaster(Packet): class PrimaryMaster(Packet):
""" """
Ask current primary master's uuid. CTL -> A. Ask current primary master's uuid. CTL -> A.
...@@ -750,15 +696,16 @@ class PrimaryMaster(Packet): ...@@ -750,15 +696,16 @@ class PrimaryMaster(Packet):
PUUID('primary_uuid'), PUUID('primary_uuid'),
) )
class AnnouncePrimary(Packet): class NotPrimaryMaster(Packet):
""" """
Announce a primary master node election. PM -> SM. Send list of known master nodes. SM -> Any.
"""
class ReelectPrimary(Packet):
"""
Force a re-election of a primary master node. M -> M.
""" """
_fmt = PStruct('not_primary_master',
PSignedNull('primary'),
PList('known_master_list',
PAddress('address'),
),
)
class Recovery(Packet): class Recovery(Packet):
""" """
...@@ -1620,22 +1567,6 @@ def register(request, ignore_when_closed=None): ...@@ -1620,22 +1567,6 @@ def register(request, ignore_when_closed=None):
StaticRegistry[code] = answer StaticRegistry[code] = answer
return (request, answer) return (request, answer)
class ParserState(object):
"""
Parser internal state.
To be considered an opaque datatype outside of PacketRegistry.parse. To be considered an opaque datatype outside of PacketRegistry.parse.
"""
payload = None
def set(self, payload):
self.payload = payload
def get(self):
return self.payload
def clear(self):
self.payload = None
class Packets(dict): class Packets(dict):
""" """
Packet registry that checks packet code uniqueness and provides an index Packet registry that checks packet code uniqueness and provides an index
...@@ -1647,58 +1578,19 @@ class Packets(dict): ...@@ -1647,58 +1578,19 @@ class Packets(dict):
# this builds a "singleton" # this builds a "singleton"
return type('PacketRegistry', base, d)(StaticRegistry) return type('PacketRegistry', base, d)(StaticRegistry)
def parse(self, buf, state_container):
state = state_container.get()
if state is None:
header = buf.read(PACKET_HEADER_FORMAT.size)
if header is None:
return None
msg_id, msg_type, msg_len = PACKET_HEADER_FORMAT.unpack(header)
try:
packet_klass = self[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
if msg_len < MIN_PACKET_SIZE:
raise PacketMalformedError('message too small (%d)' % msg_len)
msg_len -= PACKET_HEADER_FORMAT.size
else:
msg_id, packet_klass, msg_len = state
data = buf.read(msg_len)
if data is None:
# Not enough.
if state is None:
state_container.set((msg_id, packet_klass, msg_len))
return None
if state:
state_container.clear()
packet = packet_klass()
packet.setContent(msg_id, data)
return packet
# notifications # notifications
Error = register( Error = register(
Error) Error)
RequestIdentification, AcceptIdentification = register( RequestIdentification, AcceptIdentification = register(
RequestIdentification) RequestIdentification, ignore_when_closed=True)
# The code of the RequestIdentification packet must never change so that
# two incompatible nodes can reject each other gracefully (i.e. by
# comparing protocol versions) instead of raising PacketMalformedError.
assert RequestIdentification._code == 1
Ping, Pong = register( Ping, Pong = register(
Ping) Ping)
CloseClient = register( CloseClient = register(
CloseClient) CloseClient)
Notify = register(
Notify)
AskPrimary, AnswerPrimary = register( AskPrimary, AnswerPrimary = register(
PrimaryMaster) PrimaryMaster)
AnnouncePrimary = register( NotPrimaryMaster = register(
AnnouncePrimary) NotPrimaryMaster)
ReelectPrimary = register(
ReelectPrimary)
NotifyNodeInformation = register( NotifyNodeInformation = register(
NotifyNodeInformation) NotifyNodeInformation)
AskRecovery, AnswerRecovery = register( AskRecovery, AnswerRecovery = register(
......
...@@ -168,7 +168,7 @@ class PartitionTable(object): ...@@ -168,7 +168,7 @@ class PartitionTable(object):
def _setCell(self, offset, node, state): def _setCell(self, offset, node, state):
if state == CellStates.DISCARDED: if state == CellStates.DISCARDED:
return self.removeCell(offset, node) return self.removeCell(offset, node)
if node.isBroken() or node.isDown(): if node.isUnknown():
raise PartitionTableException('Invalid node state') raise PartitionTableException('Invalid node state')
self.count_dict.setdefault(node, 0) self.count_dict.setdefault(node, 0)
...@@ -196,8 +196,10 @@ class PartitionTable(object): ...@@ -196,8 +196,10 @@ class PartitionTable(object):
break break
def dropNode(self, node): def dropNode(self, node):
count = self.count_dict.pop(node, None) count = self.count_dict.get(node)
assert not count, (node, count) if count == 0:
del self.count_dict[node]
return not count
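The new `dropNode` on the right side above replaces the old assertion with a contract: a node is removed from the cell accounting only when it owns no cells, and the return value reports whether the drop succeeded. A minimal standalone sketch of that logic (dict and node names are hypothetical):

```python
# count_dict maps a node to the number of partition cells it holds,
# as in PartitionTable.count_dict.
count_dict = {}

def drop_node(node):
    count = count_dict.get(node)
    if count == 0:
        del count_dict[node]
    # True when the node is gone (it had no cells, or was unknown);
    # False when it still holds cells and must not be dropped.
    return not count

count_dict['s1'] = 0
count_dict['s2'] = 3
assert drop_node('s1') is True    # no cells: removed
assert drop_node('s2') is False   # still referenced: kept
assert drop_node('s3') is True    # unknown node: nothing to do
assert count_dict == {'s2': 3}
```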
def load(self, ptid, row_list, nm): def load(self, ptid, row_list, nm):
""" """
......
...@@ -148,14 +148,9 @@ def parseNodeAddress(address, port_opt=None): ...@@ -148,14 +148,9 @@ def parseNodeAddress(address, port_opt=None):
# or return either raw host & port or getaddrinfo return value. # or return either raw host & port or getaddrinfo return value.
return socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0][4][:2] return socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0][4][:2]
def parseMasterList(masters, except_node=None): def parseMasterList(masters):
assert masters, 'At least one master must be defined' assert masters, 'At least one master must be defined'
master_node_list = [] return map(parseNodeAddress, masters.split())
for node in masters.split():
address = parseNodeAddress(node)
if address != except_node:
master_node_list.append(address)
return master_node_list
class ReadBuffer(object): class ReadBuffer(object):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys, weakref import sys
from collections import defaultdict from collections import defaultdict
from time import time from time import time
...@@ -25,7 +25,7 @@ from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID ...@@ -25,7 +25,7 @@ from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.connection import ListeningConnection, ClientConnection from neo.lib.connection import ListeningConnection, ClientConnection
from neo.lib.exception import ElectionFailure, PrimaryFailure, StoppedOperation from neo.lib.exception import PrimaryElected, PrimaryFailure, StoppedOperation
class StateChangedException(Exception): pass class StateChangedException(Exception): pass
...@@ -40,8 +40,7 @@ def monotonic_time(): ...@@ -40,8 +40,7 @@ def monotonic_time():
return now return now
from .backup_app import BackupApplication from .backup_app import BackupApplication
from .handlers import election, identification, secondary from .handlers import identification, administration, client, master, storage
from .handlers import administration, client, storage
from .pt import PartitionTable from .pt import PartitionTable
from .recovery import RecoveryManager from .recovery import RecoveryManager
from .transactions import TransactionManager from .transactions import TransactionManager
...@@ -58,6 +57,21 @@ class Application(BaseApplication): ...@@ -58,6 +57,21 @@ class Application(BaseApplication):
backup_app = None backup_app = None
truncate_tid = None truncate_tid = None
def uuid(self, uuid):
node = self.nm.getByUUID(uuid)
if node is not self._node:
if node:
node.setUUID(None)
if node.isConnected(True):
node.getConnection().close()
self._node.setUUID(uuid)
uuid = property(lambda self: self._node.getUUID(), uuid)
@property
def election(self):
if self.primary and self.cluster_state == ClusterStates.RECOVERING:
return self.primary
def __init__(self, config): def __init__(self, config):
super(Application, self).__init__( super(Application, self).__init__(
config.getSSL(), config.getDynamicMasterList()) config.getSSL(), config.getDynamicMasterList())
...@@ -71,6 +85,8 @@ class Application(BaseApplication): ...@@ -71,6 +85,8 @@ class Application(BaseApplication):
self.storage_starting_set = set() self.storage_starting_set = set()
for master_address in config.getMasters(): for master_address in config.getMasters():
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
self._node = self.nm.createMaster(address=self.server,
uuid=config.getUUID())
logging.debug('IP address is %s, port is %d', *self.server) logging.debug('IP address is %s, port is %d', *self.server)
...@@ -87,17 +103,7 @@ class Application(BaseApplication): ...@@ -87,17 +103,7 @@ class Application(BaseApplication):
logging.info('Name : %s', self.name) logging.info('Name : %s', self.name)
self.listening_conn = None self.listening_conn = None
self.primary = None
self.primary_master_node = None
self.cluster_state = None self.cluster_state = None
self.uuid = config.getUUID()
# election related data
self.unconnected_master_node_set = set()
self.negotiating_master_node_set = set()
self.master_address_dict = weakref.WeakKeyDictionary()
self._current_manager = None self._current_manager = None
# backup # backup
...@@ -111,7 +117,8 @@ class Application(BaseApplication): ...@@ -111,7 +117,8 @@ class Application(BaseApplication):
self.administration_handler = administration.AdministrationHandler( self.administration_handler = administration.AdministrationHandler(
self) self)
self.secondary_master_handler = secondary.SecondaryMasterHandler(self) self.election_handler = master.ElectionHandler(self)
self.secondary_handler = master.SecondaryHandler(self)
self.client_service_handler = client.ClientServiceHandler(self) self.client_service_handler = client.ClientServiceHandler(self)
self.client_ro_service_handler = client.ClientReadOnlyServiceHandler(self) self.client_ro_service_handler = client.ClientReadOnlyServiceHandler(self)
self.storage_service_handler = storage.StorageServiceHandler(self) self.storage_service_handler = storage.StorageServiceHandler(self)
...@@ -143,100 +150,12 @@ class Application(BaseApplication): ...@@ -143,100 +150,12 @@ class Application(BaseApplication):
raise raise
def _run(self): def _run(self):
"""Make sure that the status is sane and start a loop."""
# Make a listening port.
self.listening_conn = ListeningConnection(self, None, self.server) self.listening_conn = ListeningConnection(self, None, self.server)
while True:
# Start a normal operation.
while self.cluster_state != ClusterStates.STOPPING:
# (Re)elect a new primary master.
self.primary = not self.nm.getMasterList()
if not self.primary:
self.electPrimary()
try:
if self.primary:
self.playPrimaryRole() self.playPrimaryRole()
else:
self.playSecondaryRole() self.playSecondaryRole()
raise RuntimeError, 'should not reach here'
except (ElectionFailure, PrimaryFailure):
# Forget all connections.
for conn in self.em.getClientList():
conn.close()
def electPrimary(self):
"""Elect a primary master node.
The difficulty is that a master node must accept connections from
others while attempting to connect to other master nodes at the
same time. Note that storage nodes and client nodes may connect
to self as well as master nodes."""
logging.info('begin the election of a primary master')
client_handler = election.ClientElectionHandler(self)
self.unconnected_master_node_set.clear()
self.negotiating_master_node_set.clear()
self.master_address_dict.clear()
self.listening_conn.setHandler(election.ServerElectionHandler(self))
getByAddress = self.nm.getByAddress
while True:
# handle new connected masters
for node in self.nm.getMasterList():
node.setUnknown()
self.unconnected_master_node_set.add(node.getAddress())
# start the election process
self.primary = None
self.primary_master_node = None
try:
while (self.unconnected_master_node_set or
self.negotiating_master_node_set):
for addr in self.unconnected_master_node_set:
self.negotiating_master_node_set.add(addr)
ClientConnection(self, client_handler,
# XXX: Ugly, but the whole election code will be
# replaced soon
getByAddress(addr))
self.unconnected_master_node_set.clear()
self.em.poll(1)
except ElectionFailure, m:
# something goes wrong, clean then restart
logging.error('election failed: %s', m)
# Ask all connected nodes to reelect a single primary master.
for conn in self.em.getClientList():
conn.send(Packets.ReelectPrimary())
conn.abort()
# Wait until the connections are closed. def getNodeInformationDict(self, node_list):
self.primary = None
self.primary_master_node = None
# XXX: Since poll does not wake up anymore every second,
# the following time condition should be reviewed.
# See also playSecondaryRole.
t = time() + 10
while self.em.getClientList() and time() < t:
try:
self.em.poll(1)
except ElectionFailure:
pass
# Close all connections.
for conn in self.em.getClientList() + self.em.getServerList():
conn.close()
else:
# election succeed, stop the process
self.primary = self.primary is None
break
def broadcastNodesInformation(self, node_list, exclude=None):
"""
Broadcast changes for a set of nodes
Send only one packet per connection to reduce bandwidth
"""
node_dict = defaultdict(list) node_dict = defaultdict(list)
# group modified nodes by destination node type # group modified nodes by destination node type
for node in node_list: for node in node_list:
...@@ -251,7 +170,14 @@ class Application(BaseApplication): ...@@ -251,7 +170,14 @@ class Application(BaseApplication):
if node.isStorage(): if node.isStorage():
continue continue
node_dict[NodeTypes.MASTER].append(node_info) node_dict[NodeTypes.MASTER].append(node_info)
return node_dict
def broadcastNodesInformation(self, node_list, exclude=None):
"""
Broadcast changes for a set of nodes
Send only one packet per connection to reduce bandwidth
"""
node_dict = self.getNodeInformationDict(node_list)
now = monotonic_time() now = monotonic_time()
# send at most one non-empty notification packet per node # send at most one non-empty notification packet per node
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
...@@ -302,52 +228,26 @@ class Application(BaseApplication): ...@@ -302,52 +228,26 @@ class Application(BaseApplication):
def playPrimaryRole(self): def playPrimaryRole(self):
logging.info('play the primary role with %r', self.listening_conn) logging.info('play the primary role with %r', self.listening_conn)
self.master_address_dict.clear() self.primary_master = None
em = self.em for conn in self.em.getConnectionList():
packet = Packets.AnnouncePrimary()
for conn in em.getConnectionList():
if conn.isListening(): if conn.isListening():
conn.setHandler(identification.IdentificationHandler(self)) conn.setHandler(identification.IdentificationHandler(self))
else: else:
conn.send(packet) conn.close()
# Primary master should rather establish connections to all
# secondaries, rather than the other way around. This requires
# a bit more work when a new master joins a cluster but makes
# it easier to resolve UUID conflicts with minimal cluster
# impact, and ensure primary master uniqueness (primary masters
# become noisy, in that they actively try to maintain
# connections to all other master nodes, so duplicate
# primaries will eventually get in touch with each other and
# resolve the situation with a duel).
# TODO: only abort client connections, don't close server
# connections as we want to have them in the end. Secondary
# masters will reconnect nevertheless, but it's dirty.
# Currently, it's not trivial to preserve connected nodes,
# because of poor node status tracking during election.
# XXX: The above comment is partially wrong in that the primary
# master is now responsible of allocating node ids, and all
# other nodes must only create/update/remove nodes when
# processing node notification. We probably want to keep the
# current behaviour: having only server connections.
conn.abort()
# If I know any storage node, make sure that they are not in the # If I know any storage node, make sure that they are not in the
# running state, because they are not connected at this stage. # running state, because they are not connected at this stage.
for node in self.nm.getStorageList(): for node in self.nm.getStorageList():
if node.isRunning(): assert node.isDown(), node
node.setTemporarilyDown()
if self.uuid is None: if self.uuid is None:
self.uuid = self.getNewUUID(None, self.server, NodeTypes.MASTER) self.uuid = self.getNewUUID(None, self.server, NodeTypes.MASTER)
logging.info('My UUID: ' + uuid_str(self.uuid)) logging.info('My UUID: ' + uuid_str(self.uuid))
else: self._node.setRunning()
in_conflict = self.nm.getByUUID(self.uuid) self._node.id_timestamp = None
if in_conflict is not None: self.primary = monotonic_time()
logging.warning('UUID conflict at election exit with %r',
in_conflict)
in_conflict.setUUID(None)
# Do not restart automatically if ElectionFailure is raised, in order # Do not restart automatically if an election happens, in order
# to avoid a split of the database. For example, with 2 machines with # to avoid a split of the database. For example, with 2 machines with
# a master and a storage on each one and replicas=1, the secondary # a master and a storage on each one and replicas=1, the secondary
# master becomes primary in case of network failure between the 2 # master becomes primary in case of network failure between the 2
...@@ -393,41 +293,91 @@ class Application(BaseApplication): ...@@ -393,41 +293,91 @@ class Application(BaseApplication):
except StateChangedException, e: except StateChangedException, e:
assert e.args[0] == ClusterStates.STOPPING assert e.args[0] == ClusterStates.STOPPING
self.shutdown() self.shutdown()
except PrimaryElected, e:
self.primary_master, = e.args
def playSecondaryRole(self): def playSecondaryRole(self):
""" """
I play a secondary role, thus only wait for a primary master to fail. A master plays the secondary role when it is unlikely to win the
election (it lost against another master during identification,
or it was notified that another is the primary master).
Its only task is to try again to become the primary master when the
latter fails. When connected to the cluster, the only communication is
with the primary master, to stay informed about removed/added master
nodes, and exit if requested.
""" """
logging.info('play the secondary role with %r', self.listening_conn) logging.info('play the secondary role with %r', self.listening_conn)
self.primary = None
# Wait for an announcement. If this is too long, probably handler = master.PrimaryHandler(self)
# the primary master is down. # The connection to the probably-primary master can be in any state
# XXX: Same remark as in electPrimary. # depending on how we were informed. The only case in which it can not
t = time() + 10 # be reused is when we have pending requests.
while self.primary_master_node is None: if self.primary_master.isConnected(True):
self.em.poll(1) master_conn = self.primary_master.getConnection()
if t < time(): # When we find the primary during identification, we don't attach
# election timeout # the connection (a server one) to any node, and it will be closed
raise ElectionFailure("Election timeout") # in the below 'for' loop.
self.master_address_dict.clear() assert master_conn.isClient(), master_conn
try:
# Restart completely. Non-optimized # We want the handler to be effective immediately.
# but lower level code needs to be stabilized first. # If it's not possible, let's just reconnect.
if not master_conn.setHandler(handler):
master_conn.close()
assert False
except PrimaryFailure:
master_conn = None
else:
master_conn = None
for conn in self.em.getConnectionList(): for conn in self.em.getConnectionList():
if not conn.isListening(): if conn.isListening():
conn.close() conn.setHandler(
# Reconnect to primary master node.
self.nm.reset()
primary_handler = secondary.PrimaryHandler(self)
ClientConnection(self, primary_handler, self.primary_master_node)
# and another for the future incoming connections
self.listening_conn.setHandler(
identification.SecondaryIdentificationHandler(self)) identification.SecondaryIdentificationHandler(self))
elif conn is not master_conn:
conn.close()
failed = {self.server}
poll = self.em.poll
while True: while True:
self.em.poll(1) try:
if master_conn is None:
for node in self.nm.getMasterList():
node.setDown()
node = self.primary_master
failed.add(node.getAddress())
if not node.isConnected(True):
# On immediate connection failure,
# PrimaryFailure is raised.
ClientConnection(self, handler, node)
else:
master_conn = None
while True:
poll(1)
except PrimaryFailure:
if self.primary_master.isRunning():
# XXX: What's best to do here? Another option is to
# choose the RUNNING master node with the lowest
# election key (i.e. (id_timestamp, address) as
# defined in IdentificationHandler), and return if we
# have the lowest one.
failed = {self.server}
else:
# Since the last primary failure (or since we play the
# secondary role), do not try any node more than once.
for self.primary_master in self.nm.getMasterList():
if self.primary_master.getAddress() not in failed:
break
else:
# All known master nodes are either down or secondary.
# Let's play the primary role again.
break
except PrimaryElected, e:
node = self.primary_master
self.primary_master, = e.args
assert node is not self.primary_master, node
try:
node.getConnection().close()
except PrimaryFailure:
pass
def runManager(self, manager_klass): def runManager(self, manager_klass):
self._current_manager = manager_klass(self) self._current_manager = manager_klass(self)
...@@ -456,9 +406,14 @@ class Application(BaseApplication): ...@@ -456,9 +406,14 @@ class Application(BaseApplication):
# change handlers # change handlers
notification_packet = Packets.NotifyClusterInformation(state) notification_packet = Packets.NotifyClusterInformation(state)
for node in self.nm.getIdentifiedList(): for node in self.nm.getList():
if not node.isConnected(True):
continue
conn = node.getConnection() conn = node.getConnection()
if node.isIdentified():
conn.send(notification_packet) conn.send(notification_packet)
elif conn.isServer():
continue
if node.isClient(): if node.isClient():
if state == ClusterStates.RUNNING: if state == ClusterStates.RUNNING:
handler = self.client_service_handler handler = self.client_service_handler
...@@ -468,6 +423,11 @@ class Application(BaseApplication): ...@@ -468,6 +423,11 @@ class Application(BaseApplication):
if state != ClusterStates.STOPPING: if state != ClusterStates.STOPPING:
conn.abort() conn.abort()
continue continue
elif node.isMaster():
if state == ClusterStates.RECOVERING:
handler = self.election_handler
else:
handler = self.secondary_handler
elif node.isStorage() and storage_handler: elif node.isStorage() and storage_handler:
handler = storage_handler handler = storage_handler
else: else:
...@@ -485,7 +445,9 @@ class Application(BaseApplication): ...@@ -485,7 +445,9 @@ class Application(BaseApplication):
return uuid return uuid
hob = UUID_NAMESPACES[node_type] hob = UUID_NAMESPACES[node_type]
for uuid in xrange((hob << 24) + 1, hob + 0x10 << 24): for uuid in xrange((hob << 24) + 1, hob + 0x10 << 24):
if uuid != self.uuid and getByUUID(uuid) is None: node = getByUUID(uuid)
if node is None or None is not address == node.getAddress():
assert uuid != self.uuid
return uuid return uuid
raise RuntimeError raise RuntimeError
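The allocation loop in `getNewUUID` above scans a per-namespace window of node ids; note that `hob + 0x10 << 24` parses as `(hob + 0x10) << 24` because `+` binds tighter than `<<`. A small sketch of the resulting window (the helper is illustrative, not NEO's API):

```python
def uuid_range(hob):
    # Same bounds as the xrange in getNewUUID above: ids for one node
    # type live in a window selected by the namespace's high-order byte.
    return range((hob << 24) + 1, (hob + 0x10) << 24)

r = uuid_range(0x00)
assert r[0] == 1                    # 0 stays reserved (the old INVALID_UUID)
assert r[-1] == (0x10 << 24) - 1    # last id in this namespace's window
assert len(r) == (0x10 << 24) - 1
```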
...@@ -517,17 +479,19 @@ class Application(BaseApplication): ...@@ -517,17 +479,19 @@ class Application(BaseApplication):
logging.info("asking remaining nodes to shutdown") logging.info("asking remaining nodes to shutdown")
self.listening_conn.close() self.listening_conn.close()
handler = EventHandler(self) handler = EventHandler(self)
for node in self.nm.getConnectedList(): for node in self.nm.getList():
if not node.isConnected(True):
continue
conn = node.getConnection() conn = node.getConnection()
if node.isStorage():
conn.setHandler(handler) conn.setHandler(handler)
if not conn.connecting:
if node.isStorage():
conn.send(Packets.NotifyNodeInformation(monotonic_time(), (( conn.send(Packets.NotifyNodeInformation(monotonic_time(), ((
node.getType(), node.getAddress(), node.getUUID(), node.getType(), node.getAddress(), node.getUUID(),
NodeStates.TEMPORARILY_DOWN, None),))) NodeStates.DOWN, None),)))
conn.abort() if conn.pending():
elif conn.pending():
conn.abort() conn.abort()
else: continue
conn.close() conn.close()
while self.em.connection_dict: while self.em.connection_dict:
......
...@@ -18,9 +18,7 @@ from ..app import monotonic_time ...@@ -18,9 +18,7 @@ from ..app import monotonic_time
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets, from neo.lib.protocol import Packets
BrokenNodeDisallowedError, ProtocolError,
)
...@@ -30,41 +28,10 @@ class MasterHandler(EventHandler): ...@@ -30,41 +28,10 @@ class MasterHandler(EventHandler):
def connectionCompleted(self, conn, new=None): def connectionCompleted(self, conn, new=None):
if new is None: if new is None:
super(MasterHandler, self).connectionCompleted(conn) super(MasterHandler, self).connectionCompleted(conn)
elif new:
self._notifyNodeInformation(conn)
def requestIdentification(self, conn, node_type, uuid, address, name, _): def connectionLost(self, conn, new_state=None):
self.checkClusterName(name) if self.app.listening_conn: # if running
app = self.app self._connectionLost(conn)
node = app.nm.getByUUID(uuid)
if node:
if node_type is NodeTypes.MASTER and not (
None != address == node.getAddress()):
raise ProtocolError
if node.isBroken():
raise BrokenNodeDisallowedError
peer_uuid = self._setupNode(conn, node_type, uuid, address, node)
if app.primary:
primary_address = app.server
elif app.primary_master_node is not None:
primary_address = app.primary_master_node.getAddress()
else:
primary_address = None
known_master_list = [(app.server, app.uuid)]
for n in app.nm.getMasterList():
if n.isBroken():
continue
known_master_list.append((n.getAddress(), n.getUUID()))
conn.answer(Packets.AcceptIdentification(
NodeTypes.MASTER,
app.uuid,
app.pt.getPartitions(),
app.pt.getReplicas(),
peer_uuid,
primary_address,
known_master_list),
)
def askClusterState(self, conn): def askClusterState(self, conn):
state = self.app.getClusterState() state = self.app.getClusterState()
...@@ -94,11 +61,12 @@ class MasterHandler(EventHandler): ...@@ -94,11 +61,12 @@ class MasterHandler(EventHandler):
self.app.getLastTransaction())) self.app.getLastTransaction()))
def _notifyNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
nm = self.app.nm app = self.app
node_list = [] node = app.nm.getByUUID(conn.getUUID())
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list = app.nm.getList()
node_list.extend(n.asTuple() for n in nm.getClientList()) node_list.remove(node)
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list = ([node.asTuple()] # for id_timestamp
+ app.getNodeInformationDict(node_list)[node.getType()])
conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list)) conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list))
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
...@@ -106,15 +74,10 @@ class MasterHandler(EventHandler): ...@@ -106,15 +74,10 @@ class MasterHandler(EventHandler):
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
DISCONNECTED_STATE_DICT = {
NodeTypes.STORAGE: NodeStates.TEMPORARILY_DOWN,
}
class BaseServiceHandler(MasterHandler): class BaseServiceHandler(MasterHandler):
"""This class deals with events for a service phase.""" """This class deals with events for a service phase."""
def connectionCompleted(self, conn, new): def connectionCompleted(self, conn, new):
self._notifyNodeInformation(conn)
pt = self.app.pt pt = self.app.pt
conn.send(Packets.SendPartitionTable(pt.getID(), pt.getRowList())) conn.send(Packets.SendPartitionTable(pt.getID(), pt.getRowList()))
...@@ -125,21 +88,16 @@ class BaseServiceHandler(MasterHandler): ...@@ -125,21 +88,16 @@ class BaseServiceHandler(MasterHandler):
return # for example, when a storage is removed by an admin return # for example, when a storage is removed by an admin
assert node.isStorage(), node assert node.isStorage(), node
logging.info('storage node lost') logging.info('storage node lost')
if new_state != NodeStates.BROKEN: if node.isPending():
new_state = DISCONNECTED_STATE_DICT.get(node.getType(),
NodeStates.DOWN)
assert new_state in (NodeStates.TEMPORARILY_DOWN, NodeStates.DOWN,
NodeStates.BROKEN), new_state
assert node.getState() not in (NodeStates.TEMPORARILY_DOWN,
NodeStates.DOWN, NodeStates.BROKEN), (uuid_str(self.app.uuid),
node.whoSetState(), new_state)
was_pending = node.isPending()
node.setState(new_state)
if new_state != NodeStates.BROKEN and was_pending:
# was in pending state, so drop it from the node manager to forget # was in pending state, so drop it from the node manager to forget
# it and do not set in running state when it comes back # it and do not set in running state when it comes back
logging.info('drop a pending node from the node manager') logging.info('drop a pending node from the node manager')
app.nm.remove(node) node.setUnknown()
elif node.isDown():
# Already put in DOWN state by AdministrationHandler.setNodeState
return
else:
node.setDown()
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
if app.truncate_tid: if app.truncate_tid:
raise StoppedOperation raise StoppedOperation
......
@@ -34,8 +34,8 @@ CLUSTER_STATE_WORKFLOW = {
         ClusterStates.STARTING_BACKUP),
 }
 NODE_STATE_WORKFLOW = {
-    NodeTypes.MASTER: (NodeStates.UNKNOWN,),
-    NodeTypes.STORAGE: (NodeStates.UNKNOWN, NodeStates.DOWN),
+    NodeTypes.MASTER: (NodeStates.DOWN,),
+    NodeTypes.STORAGE: (NodeStates.DOWN, NodeStates.UNKNOWN),
 }

 class AdministrationHandler(MasterHandler):
@@ -43,6 +43,7 @@ class AdministrationHandler(MasterHandler):
     def connectionLost(self, conn, new_state):
         node = self.app.nm.getByUUID(conn.getUUID())
-        self.app.nm.remove(node)
+        if node is not None:
+            self.app.nm.remove(node)

     def setClusterState(self, conn, state):
@@ -95,7 +96,7 @@ class AdministrationHandler(MasterHandler):
         message = ('state changed' if state_changed else
                    'node already in %s state' % state)
         if node.isStorage():
-            keep = state == NodeStates.UNKNOWN
+            keep = state == NodeStates.DOWN
             try:
                 cell_list = app.pt.dropNodeList([node], keep)
             except PartitionTableException, e:
......
@@ -15,30 +15,21 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.

 from neo.lib.handler import DelayEvent
-from neo.lib.protocol import NodeStates, Packets, ProtocolError, MAX_TID, Errors
+from neo.lib.protocol import Packets, ProtocolError, MAX_TID, Errors
 from ..app import monotonic_time
 from . import MasterHandler

 class ClientServiceHandler(MasterHandler):
     """ Handler dedicated to client during service state """

-    def connectionLost(self, conn, new_state):
+    def _connectionLost(self, conn):
         # cancel its transactions and forgot the node
         app = self.app
-        if app.listening_conn: # if running
-            node = app.nm.getByUUID(conn.getUUID())
-            assert node is not None
-            app.tm.clientLost(node)
-            node.setState(NodeStates.DOWN)
-            app.broadcastNodesInformation([node])
-            app.nm.remove(node)
-
-    def _notifyNodeInformation(self, conn):
-        nm = self.app.nm
-        node_list = [nm.getByUUID(conn.getUUID()).asTuple()] # for id_timestamp
-        node_list.extend(n.asTuple() for n in nm.getMasterList())
-        node_list.extend(n.asTuple() for n in nm.getStorageList())
-        conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list))
+        node = app.nm.getByUUID(conn.getUUID())
+        assert node is not None, conn
+        app.tm.clientLost(node)
+        node.setUnknown()
+        app.broadcastNodesInformation([node])

     def askBeginTransaction(self, conn, tid):
         """
......
#
# Copyright (C) 2006-2017  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from neo.lib import logging
from neo.lib.protocol import uuid_str, NodeTypes, Packets
from neo.lib.protocol import NotReadyError
from neo.lib.exception import ElectionFailure
from neo.lib.handler import EventHandler
from . import MasterHandler


class BaseElectionHandler(EventHandler):

    def _notifyNodeInformation(self, conn):
        pass

    def reelectPrimary(self, conn):
        raise ElectionFailure, 'reelection requested'

    def announcePrimary(self, conn):
        app = self.app
        if app.primary:
            # I am also the primary... So restart the election.
            raise ElectionFailure, 'another primary arises'
        try:
            address = app.master_address_dict[conn]
            assert conn.isServer()
        except KeyError:
            address = conn.getAddress()
            assert conn.isClient()
        app.primary = False
        app.primary_master_node = node = app.nm.getByAddress(address)
        app.negotiating_master_node_set.clear()
        logging.info('%s is the primary', node)

    def elect(self, conn, peer_address):
        app = self.app
        if app.server < peer_address:
            app.primary = False
            if conn is not None:
                app.master_address_dict[conn] = peer_address
        app.negotiating_master_node_set.discard(peer_address)


class ClientElectionHandler(BaseElectionHandler):

    def notifyNodeInformation(self, conn, timestamp, node_list):
        # XXX: For the moment, do nothing because
        #      we'll close this connection and reconnect.
        pass

    def connectionFailed(self, conn):
        addr = conn.getAddress()
        node = self.app.nm.getByAddress(addr)
        assert node is not None, (uuid_str(self.app.uuid), addr)
        # node may still be in unknown state
        self.app.negotiating_master_node_set.discard(addr)
        super(ClientElectionHandler, self).connectionFailed(conn)

    def connectionCompleted(self, conn):
        app = self.app
        conn.ask(Packets.RequestIdentification(
            NodeTypes.MASTER,
            app.uuid,
            app.server,
            app.name,
            None,
        ))
        super(ClientElectionHandler, self).connectionCompleted(conn)

    def connectionLost(self, conn, new_state):
        # Retry connection. Either the node just died (and we will end up in
        # connectionFailed) or it just got elected (and we must not ignore
        # that node).
        addr = conn.getAddress()
        self.app.unconnected_master_node_set.add(addr)
        self.app.negotiating_master_node_set.discard(addr)

    def _acceptIdentification(self, node, peer_uuid, num_partitions,
            num_replicas, your_uuid, primary, known_master_list):
        app = self.app

        # Register new master nodes.
        for address, uuid in known_master_list:
            if app.server == address:
                # This is self.
                assert node.getAddress() != primary or uuid == your_uuid, (
                    uuid_str(uuid), uuid_str(your_uuid))
                continue
            n = app.nm.getByAddress(address)
            if n is None:
                n = app.nm.createMaster(address=address)

        if primary is not None:
            # The primary master is defined.
            if app.primary_master_node is not None \
                    and app.primary_master_node.getAddress() != primary:
                # There are multiple primary master nodes. This is
                # dangerous.
                raise ElectionFailure, 'multiple primary master nodes'
            primary_node = app.nm.getByAddress(primary)
            if primary_node is None:
                # I don't know such a node. Probably this information
                # is old. So ignore it.
                logging.warning('received an unknown primary node')
            else:
                # Whatever the situation is, I trust this master.
                app.primary = False
                app.primary_master_node = primary_node
                # Stop waiting for connections other than the primary
                # master's to complete, to exit the election phase ASAP.
                app.negotiating_master_node_set.clear()
                return

        self.elect(None, node.getAddress())


class ServerElectionHandler(BaseElectionHandler, MasterHandler):

    def _setupNode(self, conn, node_type, uuid, address, node):
        app = self.app
        if node_type != NodeTypes.MASTER:
            logging.info('reject a connection from a non-master')
            raise NotReadyError
        if node is None is app.nm.getByAddress(address):
            app.nm.createMaster(address=address)

        self.elect(conn, address)
        return uuid
@@ -15,26 +15,25 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.

 from neo.lib import logging
+from neo.lib.exception import PrimaryElected
+from neo.lib.handler import EventHandler
 from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \
-    NotReadyError, ProtocolError, uuid_str
+    NotReadyError, Packets, ProtocolError, uuid_str
 from ..app import monotonic_time
-from . import MasterHandler

-class IdentificationHandler(MasterHandler):
+class IdentificationHandler(EventHandler):

-    def requestIdentification(self, conn, *args, **kw):
-        super(IdentificationHandler, self).requestIdentification(conn, *args,
-            **kw)
-        handler = conn.getHandler()
-        assert not isinstance(handler, IdentificationHandler), handler
-        handler.connectionCompleted(conn, True)
-
-    def _setupNode(self, conn, node_type, uuid, address, node):
+    def requestIdentification(self, conn, node_type, uuid,
+                              address, name, id_timestamp):
         app = self.app
+        self.checkClusterName(name)
+        if address == app.server:
+            raise ProtocolError('address conflict')
+        node = app.nm.getByUUID(uuid)
         by_addr = address and app.nm.getByAddress(address)
         while 1:
             if by_addr:
-                if not by_addr.isConnected():
+                if not by_addr.isIdentified():
                     if node is by_addr:
                         break
                     if not node or uuid < 0:
@@ -43,12 +42,15 @@ class IdentificationHandler(MasterHandler):
                         node = by_addr
                         break
             elif node:
-                if node.isConnected():
+                if node.isIdentified():
                     if uuid < 0:
                         # The peer wants a temporary id that's already assigned.
                         # Let's give it another one.
                         node = uuid = None
                         break
                 else:
-                    node.setAddress(address)
+                    if node is app._node:
+                        node = None
+                    else:
+                        node.setAddress(address)
                     break
@@ -77,7 +79,14 @@ class IdentificationHandler(MasterHandler):
                 uuid is not None and node is not None)
             human_readable_node_type = ' storage (%s) ' % (state, )
         elif node_type == NodeTypes.MASTER:
-            handler = app.secondary_master_handler
+            if app.election:
+                if id_timestamp and \
+                        (id_timestamp, address) < (app.election, app.server):
+                    raise PrimaryElected(by_addr or
+                        app.nm.createMaster(address=address))
+                handler = app.election_handler
+            else:
+                handler = app.secondary_handler
             human_readable_node_type = ' master '
         elif node_type == NodeTypes.ADMIN:
             handler = app.administration_handler
@@ -94,22 +103,43 @@ class IdentificationHandler(MasterHandler):
         node.setUUID(uuid)
         node.id_timestamp = monotonic_time()
         node.setState(state)
-        node.setConnection(conn)
         conn.setHandler(handler)
+        node.setConnection(conn, not node.isIdentified())
         app.broadcastNodesInformation([node], node)
-        return uuid

-class SecondaryIdentificationHandler(MasterHandler):
+        conn.answer(Packets.AcceptIdentification(
+            NodeTypes.MASTER,
+            app.uuid,
+            app.pt.getPartitions(),
+            app.pt.getReplicas(),
+            uuid))
+        handler._notifyNodeInformation(conn)
+        handler.connectionCompleted(conn, True)

-    def announcePrimary(self, conn):
-        # If we received AnnouncePrimary on a client connection, we might have
-        # set this handler on server connection, and might receive
-        # AnnouncePrimary there too. As we cannot reach this without already
-        # handling a first AnnouncePrimary, we can safely ignore this one.
-        pass

-    def _setupNode(self, conn, node_type, uuid, address, node):
-        # Nothing to do, storage will disconnect when it receives our answer.
-        # Primary will do the checks.
-        return uuid
+class SecondaryIdentificationHandler(EventHandler):

+    def requestIdentification(self, conn, node_type, uuid,
+                              address, name, id_timestamp):
+        app = self.app
+        self.checkClusterName(name)
+        if address == app.server:
+            raise ProtocolError('address conflict')
+        primary = app.primary_master.getAddress()
+        if primary == address:
+            primary = None
+        elif not app.primary_master.isIdentified():
+            if node_type == NodeTypes.MASTER:
+                node = app.nm.createMaster(address=address)
+                if id_timestamp:
+                    conn.close()
+                    raise PrimaryElected(node)
+            primary = None
+        # For some cases, we rely on the fact that the remote will not retry
+        # immediately (see SocketConnector.CONNECT_LIMIT).
+        known_master_list = [node.getAddress()
+                             for node in app.nm.getMasterList()]
+        conn.send(Packets.NotPrimaryMaster(
+            primary and known_master_list.index(primary),
+            known_master_list))
+        conn.abort()
@@ -15,83 +15,81 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.

 import sys
-from ..app import monotonic_time
 from . import MasterHandler
-from neo.lib.handler import EventHandler
-from neo.lib.exception import ElectionFailure, PrimaryFailure
-from neo.lib.protocol import NodeStates, NodeTypes, Packets, uuid_str
-from neo.lib import logging
+from neo.lib.exception import PrimaryElected, PrimaryFailure
+from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets

-class SecondaryMasterHandler(MasterHandler):
-    """ Handler used by primary to handle secondary masters"""
+class SecondaryHandler(MasterHandler):
+    """Handler used by primary to handle secondary masters"""

-    def connectionLost(self, conn, new_state):
+    def _connectionLost(self, conn):
         app = self.app
-        if app.listening_conn: # if running
-            node = app.nm.getByUUID(conn.getUUID())
-            node.setDown()
-            app.broadcastNodesInformation([node])
+        node = app.nm.getByUUID(conn.getUUID())
+        node.setDown()
+        app.broadcastNodesInformation([node])

-    def announcePrimary(self, conn):
-        raise ElectionFailure, 'another primary arises'
-
-    def reelectPrimary(self, conn):
-        raise ElectionFailure, 'reelection requested'
-
-    def _notifyNodeInformation(self, conn):
-        node_list = [n.asTuple() for n in self.app.nm.getMasterList()]
-        conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list))

-class PrimaryHandler(EventHandler):
-    """ Handler used by secondaries to handle primary master"""
+class ElectionHandler(MasterHandler):
+    """Handler used by primary to handle secondary masters during election"""

-    def connectionLost(self, conn, new_state):
-        self.connectionFailed(conn)
+    def connectionCompleted(self, conn, new=None):
+        if new is None:
+            super(ElectionHandler, self).connectionCompleted(conn)
+            app = self.app
+            conn.ask(Packets.RequestIdentification(NodeTypes.MASTER,
+                app.uuid, app.server, app.name, app.election))

     def connectionFailed(self, conn):
-        self.app.primary_master_node.setDown()
-        if self.app.listening_conn: # if running
-            raise PrimaryFailure('primary master is dead')
+        super(ElectionHandler, self).connectionFailed(conn)
+        self.connectionLost(conn)

-    def connectionCompleted(self, conn):
-        app = self.app
-        addr = conn.getAddress()
-        node = app.nm.getByAddress(addr)
-        # connection successful, set it as running
-        node.setRunning()
-        conn.ask(Packets.RequestIdentification(
-            NodeTypes.MASTER,
-            app.uuid,
-            app.server,
-            app.name,
-            None,
-        ))
-        super(PrimaryHandler, self).connectionCompleted(conn)
+    def _acceptIdentification(self, node, *args):
+        raise PrimaryElected(node)

-    def reelectPrimary(self, conn):
-        raise ElectionFailure, 'reelection requested'
+    def _connectionLost(self, *args):
+        if self.app.primary: # not switching to secondary role
+            self.app._current_manager.try_secondary = True
+
+    def notPrimaryMaster(self, *args):
+        try:
+            super(ElectionHandler, self).notPrimaryMaster(*args)
+        except PrimaryElected, e:
+            # We keep playing the primary role when the peer does not
+            # know yet that we won election against the returned node.
+            if not e.args[0].isIdentified():
+                raise
+        # There may be new master nodes. Connect to them.
+        self.app._current_manager.try_secondary = True
+
+
+class PrimaryHandler(ElectionHandler):
+    """Handler used by secondaries to handle primary master"""
+
+    def _acceptIdentification(self, node, num_partitions, num_replicas):
+        assert self.app.primary_master is node, (self.app.primary_master, node)
+
+    def _connectionLost(self, conn):
+        node = self.app.primary_master
+        # node is None when switching to primary role
+        if node and not node.isConnected(True):
+            raise PrimaryFailure('primary master is dead')
+
+    def notPrimaryMaster(self, *args):
+        try:
+            super(ElectionHandler, self).notPrimaryMaster(*args)
+        except PrimaryElected, e:
+            if e.args[0] is not self.app.primary_master:
+                raise

     def notifyClusterInformation(self, conn, state):
-        self.app.cluster_state = state
+        if state == ClusterStates.STOPPING:
+            sys.exit()

     def notifyNodeInformation(self, conn, timestamp, node_list):
         super(PrimaryHandler, self).notifyNodeInformation(
             conn, timestamp, node_list)
         for node_type, _, uuid, state, _ in node_list:
             assert node_type == NodeTypes.MASTER, node_type
-            if uuid == self.app.uuid and state == NodeStates.UNKNOWN:
+            if uuid == self.app.uuid and state == NodeStates.DOWN:
                 sys.exit()
-
-    def _acceptIdentification(self, node, uuid, num_partitions,
-            num_replicas, your_uuid, primary, known_master_list):
-        app = self.app
-        if primary != app.primary_master_node.getAddress():
-            raise PrimaryFailure('unexpected primary uuid')
-        if your_uuid != app.uuid:
-            app.uuid = your_uuid
-            logging.info('My UUID: ' + uuid_str(your_uuid))
-        node.setUUID(uuid)
-        app.id_timestamp = None
@@ -15,6 +15,7 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.

 from neo.lib import logging
+from neo.lib.connection import ClientConnection
 from neo.lib.protocol import Packets, ProtocolError, ClusterStates, NodeStates
 from .app import monotonic_time
 from .handlers import MasterHandler
@@ -47,6 +48,7 @@ class RecoveryManager(MasterHandler):
         TID, and the last Partition Table ID from storage nodes, then get
         back the latest partition table or make a new table from scratch,
         if this is the first time.
+        A new primary master may also arise during this phase.
         """
         logging.info('begin the recovery of the status')
         app = self.app
@@ -54,9 +56,30 @@ class RecoveryManager(MasterHandler):
         app.changeClusterState(ClusterStates.RECOVERING)
         pt.clear()

+        self.try_secondary = True
+
         # collect the last partition table available
         poll = app.em.poll
         while 1:
+            if self.try_secondary:
+                # Keep trying to connect to all other known masters,
+                # to make sure there is a challenge between each pair
+                # of masters in the cluster. If we win, all connections
+                # opened here will be closed.
+                self.try_secondary = False
+                node_list = []
+                for node in app.nm.getMasterList():
+                    if not (node is app._node or node.isConnected(True)):
+                        # During recovery, master nodes are not put back in
+                        # DOWN state by handlers. This is done
+                        # entirely in this method (here and after this poll
+                        # loop), to minimize the notification packets.
+                        if not node.isDown():
+                            node.setDown()
+                            node_list.append(node)
+                        ClientConnection(app, app.election_handler, node)
+                if node_list:
+                    app.broadcastNodesInformation(node_list)
             poll(1)
             if pt.filled():
                 # A partition table exists, we are starting an existing
@@ -100,6 +123,17 @@ class RecoveryManager(MasterHandler):
         for node in node_list:
             assert node.isPending(), node
             node.setRunning()
+
+        for node in app.nm.getMasterList():
+            if not (node is app._node or node.isIdentified()):
+                if node.isConnected(True):
+                    node.getConnection().close()
+                    assert node.isDown(), node
+                elif not node.isDown():
+                    assert self.try_secondary, node
+                    node.setDown()
+                    node_list.append(node)
+
         app.broadcastNodesInformation(node_list)
         if pt.getID() is None:
......
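The hunks above decide the election by comparing, for each pair of masters, the time at which each one started electing, with the server address as tie-breaker (the `(id_timestamp, address) < (app.election, app.server)` test in `IdentificationHandler`). A minimal illustration of that ordering rule; the function name and tuple layout here are invented for the sketch and are not part of NEO's API:

```python
def loses_election(my_election, my_address, peer_timestamp, peer_address):
    """Return True if the peer wins: it started its election earlier,
    or at the same time but from a smaller address (total order, so
    every challenge between two masters has exactly one winner)."""
    return (peer_timestamp, peer_address) < (my_election, my_address)

# The peer that started first wins:
assert loses_election(10.0, ('10.0.0.2', 2100), 9.0, ('10.0.0.3', 2100))
# On a tie, the smaller address wins:
assert loses_election(10.0, ('10.0.0.2', 2100), 10.0, ('10.0.0.1', 2100))
# A later challenger does not displace the current winner:
assert not loses_election(10.0, ('10.0.0.2', 2100), 11.0, ('10.0.0.1', 2100))
```

Because tuple comparison is a strict total order, running this check on both ends of every master-to-master connection guarantees the cluster converges on a single primary.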
@@ -157,10 +157,10 @@ class NeoCTL(BaseApplication):
         return self.setClusterState(ClusterStates.VERIFYING)

     def killNode(self, node):
-        return self._setNodeState(node, NodeStates.UNKNOWN)
+        return self._setNodeState(node, NodeStates.DOWN)

     def dropNode(self, node):
-        return self._setNodeState(node, NodeStates.DOWN)
+        return self._setNodeState(node, NodeStates.UNKNOWN)

     def getPrimary(self):
         """
......
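The swap in `killNode`/`dropNode` above is not a behavior change: it follows the state renaming in this merge (old `TEMPORARILY_DOWN` becomes `DOWN`, old `DOWN` becomes `UNKNOWN`). Killing a node stops it but keeps it known to the cluster; dropping it forgets it entirely. A toy mapping to make the renamed semantics explicit; the enum and helper below are illustrative, not NEO code:

```python
from enum import Enum

class NodeStates(Enum):
    RUNNING = 'RUNNING'
    DOWN = 'DOWN'        # stopped but remembered (was TEMPORARILY_DOWN)
    UNKNOWN = 'UNKNOWN'  # forgotten by the cluster (was DOWN)

def target_state(command):
    # hypothetical helper mirroring killNode/dropNode after the rename
    return {'kill': NodeStates.DOWN, 'drop': NodeStates.UNKNOWN}[command]

assert target_state('kill') is NodeStates.DOWN
assert target_state('drop') is NodeStates.UNKNOWN
```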
@@ -42,7 +42,6 @@ from neo.tests.benchmark import BenchmarkRunner
 # each of them have to import its TestCase classes
 UNIT_TEST_MODULES = [
     # generic parts
-    'neo.tests.testBootstrap',
     'neo.tests.testConnection',
     'neo.tests.testHandler',
     'neo.tests.testNodes',
@@ -50,7 +49,6 @@ UNIT_TEST_MODULES = [
     'neo.tests.testPT',
     # master application
     'neo.tests.master.testClientHandler',
-    'neo.tests.master.testElectionHandler',
     'neo.tests.master.testMasterApp',
     'neo.tests.master.testMasterPT',
     'neo.tests.master.testRecovery',
@@ -61,7 +59,6 @@ UNIT_TEST_MODULES = [
     'neo.tests.storage.testMasterHandler',
     'neo.tests.storage.testStorageApp',
     'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
-    'neo.tests.storage.testIdentificationHandler',
     'neo.tests.storage.testTransactions',
     # client application
     'neo.tests.client.testClientApp',
@@ -99,13 +96,47 @@ ZODB_TEST_MODULES = [
 ]

+class StopOnSuccess(Exception):
+    pass

 class NeoTestRunner(unittest.TextTestResult):
     """ Custom result class to build report with statistics per module """

-    def __init__(self, title, verbosity):
+    _readable_tid = ()
+
+    def __init__(self, title, verbosity, stop_on_success, readable_tid):
         super(NeoTestRunner, self).__init__(
             _WritelnDecorator(sys.stderr), False, verbosity)
         self._title = title
+        self.stop_on_success = stop_on_success
+        if readable_tid:
+            from neo.lib import util
+            from neo.lib.util import dump, p64, u64
+            from neo.master.transactions import TransactionManager
+            def _nextTID(orig, tm, ttid=None, divisor=None):
+                n = self._next_tid
+                self._next_tid = n + 1
+                n = str(n).rjust(3, '-')
+                if ttid:
+                    t = u64('T%s%s-' % (n, ttid[1:4]))
+                    m = (u64(ttid) - t) % divisor
+                    assert m < 211, (p64(t), divisor)
+                    t = p64(t + m)
+                else:
+                    t = 'T%s----' % n
+                assert tm._last_tid < t, (tm._last_tid, t)
+                tm._last_tid = t
+                return t
+            self._readable_tid = (
+                Patch(self, 1, _next_tid=0),
+                Patch(TransactionManager, _nextTID=_nextTID),
+                Patch(util, 1, orig_dump=type(dump)(
+                    dump.__code__, dump.__globals__)),
+                Patch(dump, __code__=(lambda s:
+                    s if type(s) is str and s.startswith('T') else
+                    orig_dump(s)).__code__),
+            )
         self.modulesStats = {}
         self.failedImports = {}
         self.run_dict = defaultdict(int)
@@ -160,17 +191,29 @@ class NeoTestRunner(unittest.TextTestResult):

     def startTest(self, test):
         super(NeoTestRunner, self).startTest(test)
+        for patch in self._readable_tid:
+            patch.apply()
         self.run_dict[test.__class__.__module__] += 1
         self.start_time = time.time()

     def stopTest(self, test):
         self.time_dict[test.__class__.__module__] += \
             time.time() - self.start_time
+        for patch in self._readable_tid:
+            patch.revert()
         super(NeoTestRunner, self).stopTest(test)
+        if self.stop_on_success is not None:
+            count = self.getUnexpectedCount()
+            if (count < self.testsRun - len(self.skipped)
+                    if self.stop_on_success else count):
+                raise StopOnSuccess
+
+    def getUnexpectedCount(self):
+        return (len(self.errors) + len(self.failures)
+              + len(self.unexpectedSuccesses))

     def _buildSummary(self, add_status):
-        unexpected_count = len(self.errors) + len(self.failures) \
-            + len(self.unexpectedSuccesses)
+        unexpected_count = self.getUnexpectedCount()
         expected_count = len(self.expectedFailures)
         success = self.testsRun - unexpected_count - expected_count
         add_status('Directory', self.temp_directory)
@@ -219,6 +262,8 @@ class NeoTestRunner(unittest.TextTestResult):

     def buildReport(self, add_status):
         subject, summary = self._buildSummary(add_status)
+        if self.stop_on_success:
+            return subject, summary
         body = StringIO()
         body.write(summary)
         for test in self.unexpectedSuccesses:
@@ -243,6 +288,17 @@ class TestRunner(BenchmarkRunner):
             help='Repeat tests several times')
         parser.add_option('-f', '--functional', action='store_true',
             help='Functional tests')
+        parser.add_option('-s', '--stop-on-error', action='store_false',
+            dest='stop_on_success',
+            help='Continue as long as tests pass successfully.'
+                 ' It is usually combined with --loop, to check that tests'
+                 ' do not fail randomly.')
+        parser.add_option('-S', '--stop-on-success', action='store_true',
+            help='Opposite of --stop-on-error: stop as soon as a test'
+                 ' passes. Details about errors are not printed at exit.')
+        parser.add_option('-r', '--readable-tid', action='store_true',
+            help='Change master behaviour to generate readable TIDs for easier'
+                 ' debugging (rather than from current time).')
         parser.add_option('-u', '--unit', action='store_true',
             help='Unit & threaded tests')
         parser.add_option('-z', '--zodb', action='store_true',
@@ -292,6 +348,8 @@ Environment Variables:
             coverage = options.coverage,
             cov_unit = options.cov_unit,
             only = args,
+            stop_on_success = options.stop_on_success,
+            readable_tid = options.readable_tid,
         )

     def start(self):
@@ -300,7 +358,8 @@ Environment Variables:
             **({'max_size': None} if config.log else {}))
         only = config.only
         # run requested tests
-        runner = NeoTestRunner(config.title or 'Neo', config.verbosity)
+        runner = NeoTestRunner(config.title or 'Neo', config.verbosity,
+            config.stop_on_success, config.readable_tid)
         if config.cov_unit:
             from coverage import Coverage
             cov_dir = runner.temp_directory + '/coverage'
@@ -327,6 +386,8 @@ Environment Variables:
         except KeyboardInterrupt:
             config['mail_to'] = None
             traceback.print_exc()
+        except StopOnSuccess:
+            pass
         if config.coverage:
             coverage.stop()
             if coverage.neotestrunner:
@@ -335,7 +396,7 @@ Environment Variables:
         if runner.dots:
             print
         # build report
-        if only and not config.mail_to:
+        if (only or config.stop_on_success) and not config.mail_to:
             runner._buildSummary = lambda *args: (
                 runner.__class__._buildSummary(runner, *args)[0], '')
             self.build_report = str
@@ -343,6 +404,8 @@ Environment Variables:
         return runner.buildReport(self.add_status)

 def main(args=None):
+    from neo.storage.database.manager import DatabaseManager
+    DatabaseManager.UNSAFE = True
     runner = TestRunner()
     runner.run()
     return sys.exit(not runner.was_successful())
......
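The `-s`/`-S` pair added above is a tri-state flag: both options write to the same destination, which stays `None` when neither is given, so the runner can distinguish "default", "stop on first error", and "stop on first success". A minimal sketch of that optparse idiom, using nothing beyond the standard library (the option names match the diff; the rest of the runner is omitted):

```python
# Tri-state flag pair: two options sharing one destination.
# -s stores False; -S stores True (optparse derives dest='stop_on_success'
# from the long option name); the default remains None when neither is given.
from optparse import OptionParser

parser = OptionParser()
parser.add_option('-s', '--stop-on-error', action='store_false',
    dest='stop_on_success')
parser.add_option('-S', '--stop-on-success', action='store_true')

default, _ = parser.parse_args([])
stop_on_error, _ = parser.parse_args(['-s'])
stop_on_success, _ = parser.parse_args(['-S'])
```

This is why `start()` can test `config.stop_on_success` both for truth (`-S`) and for being explicitly set.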
@@ -28,8 +28,7 @@ from neo.lib.util import dump
 from neo.lib.bootstrap import BootstrapManager
 from .checker import Checker
 from .database import buildDatabaseManager
-from .handlers import identification, initialization
-from .handlers import master, hidden
+from .handlers import identification, initialization, master
 from .replicator import Replicator
 from .transactions import TransactionManager
@@ -170,10 +169,6 @@ class Application(BaseApplication):
         if self.master_node is None:
             # look for the primary master
             self.connectToPrimary()
-        # check my state
-        node = self.nm.getByUUID(self.uuid)
-        if node is not None and node.isHidden():
-            self.wait()
         self.checker = Checker(self)
         self.replicator = Replicator(self)
         self.tm = TransactionManager(self)
@@ -274,20 +269,6 @@ class Application(BaseApplication):
         if state == ClusterStates.STOPPING_BACKUP:
             self.replicator.stop()

-    def wait(self):
-        # change handler
-        logging.info("waiting in hidden state")
-        _poll = self._poll
-        handler = hidden.HiddenHandler(self)
-        for conn in self.em.getConnectionList():
-            conn.setHandler(handler)
-        node = self.nm.getByUUID(self.uuid)
-        while True:
-            _poll()
-            if not node.isHidden():
-                break
-
     def newTask(self, iterator):
         try:
             iterator.next()
...
@@ -55,6 +55,7 @@ class DatabaseManager(object):
     """This class only describes an interface for database managers."""

     ENGINES = ()
+    UNSAFE = False

     _deferred = 0
     _duplicating = _repairing = None
...
@@ -78,6 +78,10 @@ class SQLiteDatabaseManager(DatabaseManager):
     def _connect(self):
         logging.info('connecting to SQLite database %r', self.db)
         self.conn = sqlite3.connect(self.db, check_same_thread=False)
+        if self.UNSAFE:
+            q = self.query
+            q("PRAGMA synchronous = OFF")
+            q("PRAGMA journal_mode = MEMORY")
         self._config = {}

     def _commit(self):
@@ -108,6 +112,10 @@ class SQLiteDatabaseManager(DatabaseManager):
             raise

     def _setup(self):
+        # SQLite does support transactional Data Definition Language
+        # statements, but unfortunately, the built-in Python binding
+        # automatically commits between such statements. This anti-feature
+        # causes this method to be relatively slow; unit tests enable UNSAFE.
         self._config.clear()
         q = self.query
...
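The `UNSAFE` pragmas above trade durability for speed: `synchronous = OFF` skips fsync() on commit and `journal_mode = MEMORY` keeps the rollback journal out of the filesystem, which is fine for throwaway test databases but not for production data. A standalone sketch of the same setup, assuming only stdlib `sqlite3` (`connect_unsafe` is a hypothetical helper mirroring the diff's `_connect`):

```python
import os
import sqlite3
import tempfile

def connect_unsafe(path, unsafe=True):
    # Mirrors the diff's _connect: trade durability for speed in tests.
    conn = sqlite3.connect(path)
    if unsafe:
        conn.execute("PRAGMA synchronous = OFF")      # no fsync() on commit
        conn.execute("PRAGMA journal_mode = MEMORY")  # journal never hits disk
    return conn

path = os.path.join(tempfile.mkdtemp(), "test.db")
conn = connect_unsafe(path)
# Both settings are queryable, so the effect can be verified:
synchronous, = conn.execute("PRAGMA synchronous").fetchone()
journal_mode, = conn.execute("PRAGMA journal_mode").fetchone()
conn.close()
```

Data still survives a process crash with these settings, but not necessarily a power failure, hence gating them behind a flag that only the test runner flips.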
@@ -56,12 +56,9 @@ class BaseMasterHandler(BaseHandler):
         if uuid == self.app.uuid:
             # This is me, do what the master tell me
             logging.info("I was told I'm %s", state)
-            if state in (NodeStates.DOWN, NodeStates.TEMPORARILY_DOWN,
-                    NodeStates.BROKEN, NodeStates.UNKNOWN):
-                erase = state == NodeStates.DOWN
+            if state in (NodeStates.UNKNOWN, NodeStates.DOWN):
+                erase = state == NodeStates.UNKNOWN
                 self.app.shutdown(erase=erase)
-            elif state == NodeStates.HIDDEN:
-                raise StoppedOperation
         elif node_type == NodeTypes.CLIENT and state != NodeStates.RUNNING:
             logging.info('Notified of non-running client, abort (%s)',
                 uuid_str(uuid))
...
-#
-# Copyright (C) 2006-2017 Nexedi SA
-#
-# This program is free software; you can redistribute it and/or
-# modify it under the terms of the GNU General Public License
-# as published by the Free Software Foundation; either version 2
-# of the License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
-
-from . import BaseMasterHandler
-from neo.lib import logging
-from neo.lib.protocol import CellStates
-
-class HiddenHandler(BaseMasterHandler):
-    """This class implements a generic part of the event handlers."""
-
-    def notifyPartitionChanges(self, conn, ptid, cell_list):
-        """This is very similar to Send Partition Table, except that
-        the information is only about changes from the previous."""
-        app = self.app
-        if ptid <= app.pt.getID():
-            # Ignore this packet.
-            logging.debug('ignoring older partition changes')
-            return
-        # update partition table in memory and the database
-        app.pt.update(ptid, cell_list, app.nm)
-        app.dm.changePartitionTable(ptid, cell_list)
-        # Check changes for replications
-        for offset, uuid, state in cell_list:
-            if uuid == app.uuid and app.replicator is not None:
-                # If this is for myself, this can affect replications.
-                if state == CellStates.DISCARDED:
-                    app.replicator.removePartition(offset)
-                elif state == CellStates.OUT_OF_DATE:
-                    app.replicator.addPartition(offset)
-
-    def startOperation(self, conn):
-        self.app.operational = True
@@ -17,7 +17,7 @@
 from neo.lib import logging
 from neo.lib.handler import EventHandler
 from neo.lib.protocol import NodeTypes, NotReadyError, Packets
-from neo.lib.protocol import ProtocolError, BrokenNodeDisallowedError
+from neo.lib.protocol import ProtocolError
 from .storage import StorageOperationHandler
 from .client import ClientOperationHandler, ClientReadOnlyOperationHandler
@@ -47,8 +47,6 @@ class IdentificationHandler(EventHandler):
         if uuid == app.uuid:
             raise ProtocolError("uuid conflict or loopback connection")
         node = app.nm.getByUUID(uuid, id_timestamp)
-        if node.isBroken():
-            raise BrokenNodeDisallowedError
         # choose the handler according to the node type
         if node_type == NodeTypes.CLIENT:
             if app.dm.getBackupTID():
@@ -67,6 +65,5 @@ class IdentificationHandler(EventHandler):
         node.setConnection(conn, app.uuid < uuid)
         # accept the identification and trigger an event
         conn.answer(Packets.AcceptIdentification(NodeTypes.STORAGE, uuid and
-            app.uuid, app.pt.getPartitions(), app.pt.getReplicas(), uuid,
-            app.master_node.getAddress(), ()))
+            app.uuid, app.pt.getPartitions(), app.pt.getReplicas(), uuid))
         handler.connectionCompleted(conn)
@@ -18,8 +18,7 @@ import weakref
 from functools import wraps
 from neo.lib.connection import ConnectionClosed
 from neo.lib.handler import DelayEvent, EventHandler
-from neo.lib.protocol import Errors, NodeStates, Packets, ProtocolError, \
-    ZERO_HASH
+from neo.lib.protocol import Errors, Packets, ProtocolError, ZERO_HASH

 def checkConnectionIsReplicatorConnection(func):
     def wrapper(self, conn, *args, **kw):
@@ -53,7 +52,7 @@ class StorageOperationHandler(EventHandler):
             node = app.nm.getByUUID(uuid)
         else:
             node = app.nm.getByAddress(conn.getAddress())
-        node.setState(NodeStates.DOWN)
+        node.setUnknown()
         replicator = app.replicator
         if replicator.current_node is node:
             replicator.abort()
...
@@ -28,7 +28,10 @@ import weakref
 import MySQLdb
 import transaction

+from cStringIO import StringIO
+from cPickle import Unpickler
 from functools import wraps
+from inspect import isclass
 from .mock import Mock
 from neo.lib import debug, logging, protocol
 from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
@@ -39,6 +42,7 @@ from unittest.case import _ExpectedFailure, _UnexpectedSuccess
 try:
     from transaction.interfaces import IDataManager
     from ZODB.utils import newTid
+    from ZODB.ConflictResolution import PersistentReferenceFactory
 except ImportError:
     pass
@@ -309,10 +313,6 @@ class NeoUnitTestBase(NeoTestBase):
         """ Check if the ProtocolError exception was raised """
         self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)

-    def checkNotReadyErrorRaised(self, method, *args, **kwargs):
-        """ Check if the NotReadyError exception was raised """
-        self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
-
     def checkAborted(self, conn):
         """ Ensure the connection was aborted """
         self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)
@@ -330,16 +330,6 @@ class NeoUnitTestBase(NeoTestBase):
         self._checkNoPacketSend(conn, 'answer')
         self._checkNoPacketSend(conn, 'ask')

-    def checkUUIDSet(self, conn, uuid=None, check_intermediate=True):
-        """ ensure UUID was set on the connection """
-        calls = conn.mockGetNamedCalls('setUUID')
-        found_uuid = calls.pop().getParam(0)
-        if check_intermediate:
-            for call in calls:
-                self.assertEqual(found_uuid, call.getParam(0))
-        if uuid is not None:
-            self.assertEqual(found_uuid, uuid)
-
     # in check(Ask|Answer|Notify)Packet we return the packet so it can be used
     # in tests if more accurate checks are required
@@ -477,9 +467,12 @@ class Patch(object):
         self._patch = patch
         try:
             orig = patched.__dict__[name]
-            self._revert = lambda: setattr(patched, name, orig)
         except KeyError:
+            if new or isclass(patched):
-            self._revert = lambda: delattr(patched, name)
+                self._revert = lambda: delattr(patched, name)
+                return
+            orig = getattr(patched, name)
+        self._revert = lambda: setattr(patched, name, orig)

     def apply(self):
         assert not self.applied
@@ -502,5 +495,11 @@ class Patch(object):
         self.__del__()

+def unpickle_state(data):
+    unpickler = Unpickler(StringIO(data))
+    unpickler.persistent_load = PersistentReferenceFactory().persistent_load
+    unpickler.load() # skip the class tuple
+    return unpickler.load()
+
 __builtin__.pdb = lambda depth=0: \
     debug.getPdb().set_trace(sys._getframe(depth+1))
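The reworked `Patch.__init__` above distinguishes an attribute stored in the target's own `__dict__` from one merely inherited: reverting an inherited class attribute means deleting the override so lookup falls back to the base class, while a function's own attribute is restored with `setattr`. A simplified sketch of that revert rule (`patch` is an illustrative helper; it drops Patch's wrapper machinery and `new` handling, keeping only the class branch):

```python
# Sketch of the revert rule: restore with setattr when the attribute lives
# in the target's own __dict__, otherwise delete the override so attribute
# lookup falls back to the inherited value.
def patch(target, name, value, isclass_target):
    try:
        orig = target.__dict__[name]
    except KeyError:
        if isclass_target:
            setattr(target, name, value)
            return lambda: delattr(target, name)  # fall back to base class
        orig = getattr(target, name)
    setattr(target, name, value)
    return lambda: setattr(target, name, orig)

class Base:
    x = 1

class Sub(Base):
    pass

revert = patch(Sub, "x", 2, isclass_target=True)
patched_value = Sub.x    # the override is visible while patched
revert()
restored_value = Sub.x   # the inherited value is visible again
```

Using `setattr` in the inherited case would pin a copy of the base-class value onto `Sub`, silently shadowing later changes to `Base.x`; deleting the override avoids that.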
@@ -21,7 +21,6 @@ from .. import NeoUnitTestBase, buildUrlFromString
 from neo.client.app import Application
 from neo.client.cache import test as testCache
 from neo.client.exception import NEOStorageError
-from neo.lib.protocol import NodeTypes, UUID_NAMESPACES

 class ClientApplicationTests(NeoUnitTestBase):
@@ -97,63 +96,6 @@ class ClientApplicationTests(NeoUnitTestBase):
         # no packet sent
         self.checkNoPacketSent(app.master_conn)

-    def test_connectToPrimaryNode(self):
-        # here we have three master nodes :
-        # the connection to the first will fail
-        # the second will have changed
-        # the third will not be ready
-        # after the third, the partition table will be operational
-        # (as if it was connected to the primary master node)
-        # will raise IndexError at the third iteration
-        app = self.getApp('127.0.0.1:10010 127.0.0.1:10011')
-        # TODO: test more connection failure cases
-        # askLastTransaction
-        def _ask8(_):
-            pass
-        # Sixth packet : askPartitionTable succeeded
-        def _ask7(_):
-            app.pt = Mock({'operational': True})
-        # fifth packet : request node identification succeeded
-        def _ask6(conn):
-            app.master_conn = conn
-            app.uuid = 1 + (UUID_NAMESPACES[NodeTypes.CLIENT] << 24)
-            app.trying_master_node = app.primary_master_node = Mock({
-                'getAddress': ('127.0.0.1', 10011),
-                '__str__': 'Fake master node',
-            })
-        # third iteration : node not ready
-        def _ask4(_):
-            app.trying_master_node = None
-        # second iteration : master node changed
-        def _ask3(_):
-            app.primary_master_node = Mock({
-                'getAddress': ('127.0.0.1', 10010),
-                '__str__': 'Fake master node',
-            })
-        # first iteration : connection failed
-        def _ask2(_):
-            app.trying_master_node = None
-        # do nothing for the first call
-        # Case of an unknown primary_uuid (XXX: handler should probably raise,
-        # it's not normal for a node to inform of a primary uuid without
-        # telling us what its address is.)
-        def _ask1(_):
-            pass
-        ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask6, _ask7, _ask8]
-        def _ask_base(conn, _, handler=None):
-            ask_func_list.pop(0)(conn)
-            app.nm.getByAddress(conn.getAddress())._connection = None
-        app._ask = _ask_base
-        # fake environment
-        app.em.close()
-        app.em = Mock({'getConnectionList': []})
-        app.pt = Mock({ 'operational': False})
-        app.start = lambda: None
-        app.master_conn = app._connectToPrimaryNode()
-        self.assertFalse(ask_func_list)
-        self.assertTrue(app.master_conn is not None)
-        self.assertTrue(app.pt.operational())

 if __name__ == '__main__':
     unittest.main()
@@ -617,13 +617,8 @@ class NEOCluster(object):
         self.expectStorageState(process.getUUID(), NodeStates.PENDING,
                 *args, **kw)

-    def expectUnknown(self, process, *args, **kw):
-        self.expectStorageState(process.getUUID(), NodeStates.UNKNOWN,
-                *args, **kw)
-
-    def expectUnavailable(self, process, *args, **kw):
-        self.expectStorageState(process.getUUID(),
-                NodeStates.TEMPORARILY_DOWN, *args, **kw)
+    def expectDown(self, process, *args, **kw):
+        self.expectStorageState(process.getUUID(), NodeStates.DOWN, *args, **kw)

     def expectPrimary(self, uuid=None, *args, **kw):
         def callback(last_try):
@@ -686,8 +681,7 @@ class NEOCluster(object):
             return current_try, current_try
         self.expectCondition(callback, *args, **kw)

-    def expectStorageNotKnown(self, process, *args, **kw):
-        # /!\ Not Known != Unknown
+    def expectStorageUnknown(self, process, *args, **kw):
         process_uuid = process.getUUID()
         def expected_storage_not_known(last_try):
             for storage in self.getStorageList():
...
@@ -48,7 +48,7 @@ class ClusterTests(NEOFunctionalTest):
         neo.stop()
         neo.run(except_storages=(s2, ))
         neo.expectPending(s1)
-        neo.expectUnknown(s2)
+        neo.expectDown(s2)
         neo.expectClusterRecovering()
         # Starting missing storage allows cluster to exit Recovery without
         # neoctl action.
@@ -61,11 +61,11 @@ class ClusterTests(NEOFunctionalTest):
         neo.stop()
         neo.run(except_storages=(s2, ))
         neo.expectPending(s1)
-        neo.expectUnknown(s2)
+        neo.expectDown(s2)
         neo.expectClusterRecovering()
         neo.startCluster()
         neo.expectRunning(s1)
-        neo.expectUnknown(s2)
+        neo.expectDown(s2)
         neo.expectClusterRunning()

     def testClusterBreaks(self):
@@ -149,20 +149,20 @@ class ClusterTests(NEOFunctionalTest):
         )
         storages = self.neo.getStorageProcessList()
         self.neo.run(except_storages=storages)
-        self.neo.expectStorageNotKnown(storages[0])
-        self.neo.expectStorageNotKnown(storages[1])
+        self.neo.expectStorageUnknown(storages[0])
+        self.neo.expectStorageUnknown(storages[1])
         storages[0].start()
         self.neo.expectPending(storages[0])
-        self.neo.expectStorageNotKnown(storages[1])
+        self.neo.expectStorageUnknown(storages[1])
         storages[1].start()
         self.neo.expectPending(storages[0])
         self.neo.expectPending(storages[1])
         storages[0].stop()
-        self.neo.expectUnavailable(storages[0])
+        self.neo.expectDown(storages[0])
         self.neo.expectPending(storages[1])
         storages[1].stop()
-        self.neo.expectUnavailable(storages[0])
-        self.neo.expectUnavailable(storages[1])
+        self.neo.expectDown(storages[0])
+        self.neo.expectDown(storages[1])

 def test_suite():
     return unittest.makeSuite(ClusterTests)
...
@@ -59,7 +59,7 @@ class MasterTests(NEOFunctionalTest):
         self.assertEqual(len(killed_uuid_list), 1)
         uuid = killed_uuid_list[0]
         # Check the state of the primary we just killed
-        self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN))
+        self.neo.expectMasterState(uuid, (None, NodeStates.DOWN))
         # BUG: The following check expects neoctl to reconnect before
         #      the election finishes.
         self.assertEqual(self.neo.getPrimary(), None)
@@ -77,12 +77,13 @@ class MasterTests(NEOFunctionalTest):
         killed_uuid_list = self.neo.killSecondaryMaster()
         # Test sanity checks.
         self.assertEqual(len(killed_uuid_list), 1)
-        self.neo.expectMasterState(killed_uuid_list[0], None)
-        self.assertEqual(len(self.neo.getMasterList()), 2)
+        self.neo.expectMasterState(killed_uuid_list[0],
+            NodeStates.DOWN)
+        self.assertEqual(len(self.neo.getMasterList()), MASTER_NODE_COUNT)

         uuid, = self.neo.killPrimary()
         # Check the state of the primary we just killed
-        self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN))
+        self.neo.expectMasterState(uuid, NodeStates.DOWN)
         # Check that a primary master arose.
         self.neo.expectPrimary(timeout=10)
         # Check that the uuid really changed.
...
@@ -168,7 +168,7 @@ class StorageTests(NEOFunctionalTest):
         self.neo.neoctl.killNode(started[0].getUUID())
         # Cluster still operational. All cells of first storage should be
         # outdated.
-        self.neo.expectUnavailable(started[0])
+        self.neo.expectDown(started[0])
         self.neo.expectOudatedCells(2)
         self.neo.expectClusterRunning()
@@ -177,7 +177,7 @@
         started[1].stop()
         # Cluster not operational anymore. Only cells of second storage that
         # were shared with the third one should become outdated.
-        self.neo.expectUnavailable(started[1])
+        self.neo.expectDown(started[1])
         self.neo.expectClusterRecovering()
         self.neo.expectOudatedCells(3)
@@ -198,7 +198,7 @@
         # stop it, the cluster must switch to verification
         started[0].stop()
-        self.neo.expectUnavailable(started[0])
+        self.neo.expectDown(started[0])
         self.neo.expectClusterRecovering()
         # client must have been disconnected
         self.assertEqual(len(self.neo.getClientlist()), 0)
@@ -224,7 +224,7 @@
         # stop one storage, cluster must remains running
         started[0].stop()
-        self.neo.expectUnavailable(started[0])
+        self.neo.expectDown(started[0])
         self.neo.expectRunning(started[1])
         self.neo.expectRunning(started[2])
         self.neo.expectOudatedCells(number=10)
@@ -232,17 +232,17 @@
         # stop a second storage, cluster is still running
         started[1].stop()
-        self.neo.expectUnavailable(started[0])
-        self.neo.expectUnavailable(started[1])
+        self.neo.expectDown(started[0])
+        self.neo.expectDown(started[1])
         self.neo.expectRunning(started[2])
         self.neo.expectOudatedCells(number=20)
         self.neo.expectClusterRunning()

         # stop the last, cluster died
         started[2].stop()
-        self.neo.expectUnavailable(started[0])
-        self.neo.expectUnavailable(started[1])
-        self.neo.expectUnavailable(started[2])
+        self.neo.expectDown(started[0])
+        self.neo.expectDown(started[1])
+        self.neo.expectDown(started[2])
         self.neo.expectOudatedCells(number=20)
         self.neo.expectClusterRecovering()
@@ -312,7 +312,7 @@
         # kill one storage, it should be set as unavailable
         started[0].stop()
-        self.neo.expectUnavailable(started[0])
+        self.neo.expectDown(started[0])
         self.neo.expectRunning(started[1])
         # and the partition table must not change
         self.neo.expectAssignedCells(started[0], 10)
@@ -320,7 +320,7 @@
         # ask neoctl to drop it
         self.neo.neoctl.dropNode(started[0].getUUID())
-        self.neo.expectStorageNotKnown(started[0])
+        self.neo.expectStorageUnknown(started[0])
         self.neo.expectAssignedCells(started[0], 0)
         self.neo.expectAssignedCells(started[1], 10)
         self.assertRaises(RuntimeError, self.neo.neoctl.dropNode,
@@ -335,7 +335,7 @@
         (started, stopped) = self.__setup(storage_number=2, replicas=1,
             pending_number=1, partitions=10)
         self.neo.expectRunning(started[0])
-        self.neo.expectStorageNotKnown(stopped[0])
+        self.neo.expectStorageUnknown(stopped[0])
         self.neo.expectOudatedCells(number=0)

         # populate the cluster with some data
@@ -362,7 +362,7 @@
         # kill the first storage
         started[0].stop()
-        self.neo.expectUnavailable(started[0])
+        self.neo.expectDown(started[0])
         self.neo.expectOudatedCells(number=10)
         self.neo.expectAssignedCells(started[0], 10)
         self.neo.expectAssignedCells(stopped[0], 10)
@@ -371,7 +371,7 @@
         # drop it from partition table
         self.neo.neoctl.dropNode(started[0].getUUID())
-        self.neo.expectStorageNotKnown(started[0])
+        self.neo.expectStorageUnknown(started[0])
         self.neo.expectRunning(stopped[0])
         self.neo.expectAssignedCells(started[0], 0)
         self.neo.expectAssignedCells(stopped[0], 10)
@@ -395,12 +395,12 @@
         # drop the first then the second storage
         started[0].stop()
-        self.neo.expectUnavailable(started[0])
+        self.neo.expectDown(started[0])
         self.neo.expectRunning(started[1])
         self.neo.expectOudatedCells(number=10)
         started[1].stop()
-        self.neo.expectUnavailable(started[0])
-        self.neo.expectUnavailable(started[1])
+        self.neo.expectDown(started[0])
+        self.neo.expectDown(started[1])
         self.neo.expectOudatedCells(number=10)
         self.neo.expectClusterRecovering()
         # XXX: need to sync with storages first
@@ -409,7 +409,7 @@
         # restart the cluster with the first storage killed
         self.neo.run(except_storages=[started[1]])
         self.neo.expectPending(started[0])
-        self.neo.expectUnknown(started[1])
+        self.neo.expectDown(started[1])
         self.neo.expectClusterRecovering()
         # Cluster doesn't know there are outdated cells
         self.neo.expectOudatedCells(number=0)
...
#
# Copyright (C) 2009-2017  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import unittest
from ..mock import Mock
from neo.lib import protocol
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.app import Application
from neo.master.handlers.election import ClientElectionHandler, \
    ServerElectionHandler
from neo.lib.exception import ElectionFailure
from neo.lib.connection import ClientConnection


class MasterClientElectionTestBase(NeoUnitTestBase):

    def setUp(self):
        super(MasterClientElectionTestBase, self).setUp()
        self._master_port = 3001

    def identifyToMasterNode(self):
        node = self.app.nm.createMaster(uuid=self.getMasterUUID())
        node.setAddress((self.local_ip, self._master_port))
        self._master_port += 1
        conn = self.getFakeConnection(
            uuid=node.getUUID(),
            address=node.getAddress(),
        )
        node.setConnection(conn)
        return (node, conn)

    def checkAcceptIdentification(self, conn):
        return self.checkAnswerPacket(conn, Packets.AcceptIdentification)


class MasterClientElectionTests(MasterClientElectionTestBase):

    def setUp(self):
        super(MasterClientElectionTests, self).setUp()
        # create an application object
        config = self.getMasterConfiguration(master_number=1)
        self.app = Application(config)
        self.app.em.close()
        self.app.pt.clear()
        self.app.em = Mock()
        self.app.uuid = self.getMasterUUID()
        self.app.server = (self.local_ip, 10000)
        self.app.name = 'NEOCLUSTER'
        self.election = ClientElectionHandler(self.app)
        self.app.unconnected_master_node_set = set()
        self.app.negotiating_master_node_set = set()

    def _checkUnconnected(self, node):
        addr = node.getAddress()
        self.assertFalse(addr in self.app.negotiating_master_node_set)

    def test_connectionFailed(self):
        node, conn = self.identifyToMasterNode()
        self.assertTrue(node.isUnknown())
        self._checkUnconnected(node)
        self.election.connectionFailed(conn)
        self._checkUnconnected(node)
        self.assertTrue(node.isUnknown())

    def test_connectionCompleted(self):
        node, conn = self.identifyToMasterNode()
        self.assertTrue(node.isUnknown())
        self._checkUnconnected(node)
        self.election.connectionCompleted(conn)
        self._checkUnconnected(node)
        self.assertTrue(node.isUnknown())
        self.checkAskPacket(conn, Packets.RequestIdentification)

    def _setNegociating(self, node):
        self._checkUnconnected(node)
        addr = node.getAddress()
        self.app.negotiating_master_node_set.add(addr)

    def test_connectionClosed(self):
        node, conn = self.identifyToMasterNode()
        self._setNegociating(node)
        self.election.connectionClosed(conn)
        self.assertTrue(node.isUnknown())
        addr = node.getAddress()
        self.assertFalse(addr in self.app.negotiating_master_node_set)

    def test_acceptIdentification1(self):
        """ A non-master node accept identification """
        node, conn = self.identifyToMasterNode()
        args = (node.getUUID(), 0, 10, self.app.uuid, None,
            self._getMasterList())
        self.election.acceptIdentification(conn,
            NodeTypes.CLIENT, *args)
        self.assertFalse(node in self.app.negotiating_master_node_set)
        self.checkClosed(conn)

    def test_acceptIdentificationDoesNotKnowPrimary(self):
        master1, master1_conn = self.identifyToMasterNode()
        master1_uuid = master1.getUUID()
        self.election.acceptIdentification(
            master1_conn,
            NodeTypes.MASTER,
            master1_uuid,
            1,
            0,
            self.app.uuid,
            None,
            [(master1.getAddress(), master1_uuid)],
        )
        self.assertEqual(self.app.primary_master_node, None)

    def test_acceptIdentificationKnowsPrimary(self):
        master1, master1_conn = self.identifyToMasterNode()
        master1_uuid = master1.getUUID()
        primary1 = master1.getAddress()
        self.election.acceptIdentification(
            master1_conn,
            NodeTypes.MASTER,
            master1_uuid,
            1,
            0,
            self.app.uuid,
            primary1,
            [(master1.getAddress(), master1_uuid)],
        )
        self.assertNotEqual(self.app.primary_master_node, None)

    def test_acceptIdentificationMultiplePrimaries(self):
        master1, master1_conn = self.identifyToMasterNode()
        master2, master2_conn = self.identifyToMasterNode()
        master3, _ = self.identifyToMasterNode()
        master1_uuid = master1.getUUID()
        master2_uuid = master2.getUUID()
        master3_uuid = master3.getUUID()
        primary1 = master1.getAddress()
        primary3 = master3.getAddress()
        master1_address = master1.getAddress()
        master2_address = master2.getAddress()
        master3_address = master3.getAddress()
        self.election.acceptIdentification(
            master1_conn,
            NodeTypes.MASTER,
            master1_uuid,
            1,
            0,
            self.app.uuid,
            primary1,
            [(master1_address, master1_uuid)],
        )
        self.assertRaises(ElectionFailure, self.election.acceptIdentification,
            master2_conn,
            NodeTypes.MASTER,
            master2_uuid,
            1,
            0,
            self.app.uuid,
            primary3,
            [
                (master1_address, master1_uuid),
                (master2_address, master2_uuid),
                (master3_address, master3_uuid),
            ],
        )

    def test_acceptIdentification3(self):
        """ Identification accepted """
        node, conn = self.identifyToMasterNode()
        args = (node.getUUID(), 0, 10, self.app.uuid, None,
            self._getMasterList())
        self.election.acceptIdentification(conn, NodeTypes.MASTER, *args)
        self.checkUUIDSet(conn, node.getUUID())
        self.assertEqual(self.app.primary is False,
            self.app.server < node.getAddress())
        self.assertFalse(node in self.app.negotiating_master_node_set)

    def _getMasterList(self, with_node=None):
        master_list = self.app.nm.getMasterList()
        return [(x.getAddress(), x.getUUID()) for x in master_list]


class MasterServerElectionTests(MasterClientElectionTestBase):

    def setUp(self):
        super(MasterServerElectionTests, self).setUp()
        # create an application object
        config = self.getMasterConfiguration(master_number=1)
        self.app = Application(config)
        self.app.em.close()
        self.app.pt.clear()
        self.app.name = 'NEOCLUSTER'
        self.app.em = Mock()
        self.election = ServerElectionHandler(self.app)
        self.app.unconnected_master_node_set = set()
        self.app.negotiating_master_node_set = set()
        for node in self.app.nm.getMasterList():
            node.setState(NodeStates.RUNNING)
        # define some variable to simulate client and storage node
        self.client_address = (self.local_ip, 1000)
        self.storage_address = (self.local_ip, 2000)
        self.master_address = (self.local_ip, 3000)

    def test_requestIdentification1(self):
        """ A non-master node request identification """
        node, conn = self.identifyToMasterNode()
        args = node.getUUID(), node.getAddress(), self.app.name, None
        self.assertRaises(protocol.NotReadyError,
            self.election.requestIdentification,
            conn, NodeTypes.CLIENT, *args)

    def test_requestIdentification3(self):
        """ A broken master node request identification """
        node, conn = self.identifyToMasterNode()
        node.setBroken()
        args = node.getUUID(), node.getAddress(), self.app.name, None
        self.assertRaises(protocol.BrokenNodeDisallowedError,
            self.election.requestIdentification,
            conn, NodeTypes.MASTER, *args)

    def test_requestIdentification4(self):
        """ No conflict """
        node, conn = self.identifyToMasterNode()
        args = node.getUUID(), node.getAddress(), self.app.name, None
        self.election.requestIdentification(conn,
            NodeTypes.MASTER, *args)
        self.checkUUIDSet(conn, node.getUUID())
        (node_type, uuid, partitions, replicas, new_uuid, primary_uuid,
            master_list) = self.checkAcceptIdentification(conn).decode()
        self.assertEqual(node.getUUID(), new_uuid)
        self.assertNotEqual(node.getUUID(), uuid)

    def __getClient(self):
        uuid = self.getClientUUID()
        conn = self.getFakeConnection(uuid=uuid, address=self.client_address)
        self.app.nm.createClient(uuid=uuid, address=self.client_address)
        return conn

    def testRequestIdentification1(self):
        """ Check with a non-master node, must be refused """
        conn = self.__getClient()
        self.checkNotReadyErrorRaised(
            self.election.requestIdentification,
            conn,
            NodeTypes.CLIENT,
            conn.getUUID(),
            conn.getAddress(),
            self.app.name,
            None,
        )

    def _requestIdentification(self):
        conn = self.getFakeConnection()
        peer_uuid = self.getMasterUUID()
        address = (self.local_ip, 2001)
        self.election.requestIdentification(
            conn,
            NodeTypes.MASTER,
            peer_uuid,
            address,
            self.app.name,
            None,
        )
        node_type, uuid, partitions, replicas, _peer_uuid, primary, \
            master_list = self.checkAcceptIdentification(conn).decode()
        self.assertEqual(node_type, NodeTypes.MASTER)
        self.assertEqual(uuid, self.app.uuid)
        self.assertEqual(partitions, self.app.pt.getPartitions())
        self.assertEqual(replicas, self.app.pt.getReplicas())
        self.assertTrue(address in [x[0] for x in master_list])
        self.assertTrue(self.app.server in [x[0] for x in master_list])
        self.assertEqual(peer_uuid, _peer_uuid)
        return primary

    def testRequestIdentificationDoesNotKnowPrimary(self):
        self.app.primary = False
        self.app.primary_master_node = None
        self.assertEqual(self._requestIdentification(), None)

    def testRequestIdentificationKnowsPrimary(self):
        self.app.primary = False
        primary = (self.local_ip, 3000)
        self.app.primary_master_node = Mock({
            'getAddress': primary,
        })
        self.assertEqual(self._requestIdentification(), primary)

    def testRequestIdentificationIsPrimary(self):
        self.app.primary = True
        primary = self.app.server
        self.app.primary_master_node = Mock({
            'getAddress': primary,
        })
        self.assertEqual(self._requestIdentification(), primary)

    def test_reelectPrimary(self):
        node, conn = self.identifyToMasterNode()
        self.assertRaises(ElectionFailure, self.election.reelectPrimary, conn)


if __name__ == '__main__':
    unittest.main()
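test_acceptIdentification3 above asserts that a master gives up primacy exactly when its own server address sorts below the peer's address. A minimal sketch of that tie-break, inferred solely from the test's assertion (the function name is hypothetical, not NEO API):

```python
def keeps_primacy(my_address, peer_address):
    """Address-based election tie-break as checked by
    test_acceptIdentification3: after identifying to a peer master,
    a node stays a primary candidate only if its (host, port) tuple
    does not sort below the peer's."""
    # Tuples compare lexicographically: host first, then port.
    return not (my_address < peer_address)
```

With two masters, exactly one side of each pair keeps primacy, which is what makes the election converge.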
@@ -63,29 +63,24 @@ class MasterPartitionTableTests(NeoUnitTestBase):
         uuid4 = self.getStorageUUID()
         server4 = ("127.0.0.4", 19004)
         sn4 = self.createStorage(server4, uuid4)
-        uuid5 = self.getStorageUUID()
-        server5 = ("127.0.0.5", 19005)
-        sn5 = self.createStorage(server5, uuid5)
         # create partition table
-        num_partitions = 5
+        num_partitions = 4
         num_replicas = 3
         pt = PartitionTable(num_partitions, num_replicas)
         pt._setCell(0, sn1, CellStates.OUT_OF_DATE)
         sn1.setState(NodeStates.RUNNING)
         pt._setCell(1, sn2, CellStates.UP_TO_DATE)
-        sn2.setState(NodeStates.TEMPORARILY_DOWN)
+        sn2.setState(NodeStates.DOWN)
         pt._setCell(2, sn3, CellStates.UP_TO_DATE)
-        sn3.setState(NodeStates.DOWN)
+        sn3.setState(NodeStates.UNKNOWN)
         pt._setCell(3, sn4, CellStates.UP_TO_DATE)
-        sn4.setState(NodeStates.BROKEN)
-        pt._setCell(4, sn5, CellStates.UP_TO_DATE)
-        sn5.setState(NodeStates.RUNNING)
+        sn4.setState(NodeStates.RUNNING)
         # outdate nodes
         cells_outdated = pt.outdate()
-        self.assertEqual(len(cells_outdated), 3)
+        self.assertEqual(len(cells_outdated), 2)
         for offset, uuid, state in cells_outdated:
-            self.assertTrue(offset in (1, 2, 3))
-            self.assertTrue(uuid in (uuid2, uuid3, uuid4))
+            self.assertIn(offset, (1, 2))
+            self.assertIn(uuid, (uuid2, uuid3))
             self.assertEqual(state, CellStates.OUT_OF_DATE)
         # check each cell
         # part 1, already outdated
@@ -103,15 +98,10 @@ class MasterPartitionTableTests(NeoUnitTestBase):
         self.assertEqual(len(cells), 1)
         cell = cells[0]
         self.assertEqual(cell.getState(), CellStates.OUT_OF_DATE)
-        # part 4, already outdated
+        # part 4, remains running
         cells = pt.getCellList(3)
         self.assertEqual(len(cells), 1)
         cell = cells[0]
-        self.assertEqual(cell.getState(), CellStates.OUT_OF_DATE)
-        # part 5, remains running
-        cells = pt.getCellList(4)
-        self.assertEqual(len(cells), 1)
-        cell = cells[0]
         self.assertEqual(cell.getState(), CellStates.UP_TO_DATE)

     def test_15_dropNodeList(self):
@@ -156,7 +146,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
         uuid2 = self.getStorageUUID()
         server2 = ("127.0.0.2", 19001)
         sn2 = self.createStorage(server2, uuid2)
-        sn2.setState(NodeStates.TEMPORARILY_DOWN)
+        sn2.setState(NodeStates.DOWN)
         # add node without uuid
         server3 = ("127.0.0.3", 19001)
         sn3 = self.createStorage(server3, None, NodeStates.RUNNING)
@@ -94,7 +94,7 @@ class MasterRecoveryTests(NeoUnitTestBase):
         conn = self.getFakeConnection(uuid, self.storage_port)
         offset = 1000000
         self.assertFalse(self.app.pt.hasOffset(offset))
-        cell_list = [(offset, ((uuid, NodeStates.DOWN,),),)]
+        cell_list = [(offset, ((uuid, NodeStates.UNKNOWN,),),)]
         node.setPending()
         self.checkProtocolErrorRaised(recovery.answerPartitionTable, conn,
             2, cell_list)
#
# Copyright (C) 2009-2017  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import unittest
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, BrokenNodeDisallowedError
from neo.lib.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.identification import IdentificationHandler


class StorageIdentificationHandlerTests(NeoUnitTestBase):

    def setUp(self):
        NeoUnitTestBase.setUp(self)
        config = self.getStorageConfiguration(master_number=1)
        self.app = Application(config)
        self.app.name = 'NEO'
        self.app.operational = True
        self.app.pt = PartitionTable(4, 1)
        self.identification = IdentificationHandler(self.app)

    def _tearDown(self, success):
        self.app.close()
        del self.app
        super(StorageIdentificationHandlerTests, self)._tearDown(success)

    def test_requestIdentification3(self):
        """ broken nodes must be rejected """
        uuid = self.getClientUUID()
        conn = self.getFakeConnection(uuid=uuid)
        node = self.app.nm.createClient(uuid=uuid)
        node.setBroken()
        self.assertRaises(BrokenNodeDisallowedError,
            self.identification.requestIdentification,
            conn,
            NodeTypes.CLIENT,
            uuid,
            None,
            self.app.name,
            None,
        )


if __name__ == "__main__":
    unittest.main()
#
# Copyright (C) 2009-2017  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import unittest
from . import NeoUnitTestBase
from neo.storage.app import Application
from neo.lib.bootstrap import BootstrapManager
from neo.lib.protocol import NodeTypes, Packets


class BootstrapManagerTests(NeoUnitTestBase):

    def setUp(self):
        NeoUnitTestBase.setUp(self)
        self.prepareDatabase(number=1)
        # create an application object
        config = self.getStorageConfiguration()
        self.app = Application(config)
        self.bootstrap = BootstrapManager(self.app, NodeTypes.STORAGE)
        # define some variable to simulate client and storage node
        self.master_port = 10010
        self.storage_port = 10020
        self.num_partitions = 1009
        self.num_replicas = 2

    def _tearDown(self, success):
        self.app.close()
        del self.app
        super(BootstrapManagerTests, self)._tearDown(success)

    # Tests

    def testConnectionCompleted(self):
        address = ("127.0.0.1", self.master_port)
        conn = self.getFakeConnection(address=address)
        self.bootstrap.current = self.app.nm.createMaster(address=address)
        self.bootstrap.connectionCompleted(conn)
        self.checkAskPacket(conn, Packets.RequestIdentification)

    def testHandleNotReady(self):
        # the primary is not ready
        address = ("127.0.0.1", self.master_port)
        conn = self.getFakeConnection(address=address)
        self.bootstrap.current = self.app.nm.createMaster(address=address)
        self.bootstrap.notReady(conn, '')
        self.checkClosed(conn)
        self.checkNoPacketSent(conn)


if __name__ == "__main__":
    unittest.main()
@@ -21,7 +21,7 @@ from neo.lib import connection, logging
 from neo.lib.connection import BaseConnection, ClientConnection, \
     MTClientConnection, CRITICAL_TIMEOUT
 from neo.lib.handler import EventHandler
-from neo.lib.protocol import Packets
+from neo.lib.protocol import ENCODED_VERSION, Packets
 from . import NeoUnitTestBase, Patch
@@ -73,6 +73,7 @@ class ConnectionTests(NeoUnitTestBase):
         # don't accept any other packet without specifying a queue.
         self.handler = EventHandler(self.app)
         conn = self._makeClientConnection()
+        conn.read_buf.append(ENCODED_VERSION)
         use_case_list = (
             # (a) For a single packet sent at T,
@@ -19,7 +19,7 @@ from .mock import Mock
 from . import NeoUnitTestBase
 from neo.lib.handler import EventHandler
 from neo.lib.protocol import PacketMalformedError, UnexpectedPacketError, \
-    BrokenNodeDisallowedError, NotReadyError, ProtocolError
+    NotReadyError, ProtocolError

 class HandlerTests(NeoUnitTestBase):
@@ -60,14 +60,6 @@ class HandlerTests(NeoUnitTestBase):
         self.setFakeMethod(fake)
         self.handler.dispatch(conn, packet)
         self.checkClosed(conn)
-        # raise BrokenNodeDisallowedError
-        conn.mockCalledMethods = {}
-        def fake(c):
-            raise BrokenNodeDisallowedError
-        self.setFakeMethod(fake)
-        self.handler.dispatch(conn, packet)
-        self.checkErrorPacket(conn)
-        self.checkAborted(conn)
         # raise NotReadyError
         conn.mockCalledMethods = {}
         def fake(c):
@@ -35,7 +35,7 @@ class NodesTests(NeoUnitTestBase):
         address = ('127.0.0.1', 10000)
         uuid = self.getNewUUID(None)
         node = Node(self.nm, address=address, uuid=uuid)
-        self.assertEqual(node.getState(), NodeStates.UNKNOWN)
+        self.assertEqual(node.getState(), NodeStates.DOWN)
         self.assertEqual(node.getAddress(), address)
         self.assertEqual(node.getUUID(), uuid)
         self.assertTrue(time() - 1 < node.getLastStateChange() < time())
@@ -43,7 +43,7 @@ class NodesTests(NeoUnitTestBase):
     def testState(self):
         """ Check if the last changed time is updated when state is changed """
         node = Node(self.nm)
-        self.assertEqual(node.getState(), NodeStates.UNKNOWN)
+        self.assertEqual(node.getState(), NodeStates.DOWN)
         self.assertTrue(time() - 1 < node.getLastStateChange() < time())
         previous_time = node.getLastStateChange()
         node.setState(NodeStates.RUNNING)
@@ -156,15 +156,15 @@ class NodeManagerTests(NeoUnitTestBase):
         old_uuid = self.storage.getUUID()
         new_uuid = self.getStorageUUID()
         node_list = (
-            (NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.DOWN, None),
+            (NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.UNKNOWN, None),
             (NodeTypes.MASTER, new_address, self.master.getUUID(), NodeStates.RUNNING, None),
             (NodeTypes.STORAGE, self.storage.getAddress(), new_uuid,
                 NodeStates.RUNNING, None),
             (NodeTypes.ADMIN, self.admin.getAddress(), self.admin.getUUID(),
-                NodeStates.UNKNOWN, None),
+                NodeStates.DOWN, None),
         )
         app = Mock()
-        app.pt = Mock()
+        app.pt = Mock({'dropNode': True})
         # update manager content
         manager.update(app, time(), node_list)
         # - the client gets down
@@ -180,9 +180,9 @@ class NodeManagerTests(NeoUnitTestBase):
         new_storage = storage_list[0]
         self.assertNotEqual(new_storage.getUUID(), old_uuid)
         self.assertEqual(new_storage.getState(), NodeStates.RUNNING)
-        # admin is still here but in UNKNOWN state
+        # admin is still here but in DOWN state
         self.checkNodes([self.master, self.admin, new_storage])
-        self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN)
+        self.assertEqual(self.admin.getState(), NodeStates.DOWN)

 class MasterDBTests(NeoUnitTestBase):
@@ -195,7 +195,7 @@ class MasterDBTests(NeoUnitTestBase):
         temp_dir = getTempDirectory()
         directory = join(temp_dir, 'read_only')
         db_file = join(directory, 'not_created')
-        mkdir(directory, 0400)
+        mkdir(directory, 0500)
         try:
             self.assertRaises(IOError, MasterDB, db_file)
         finally:
@@ -212,17 +212,17 @@ class MasterDBTests(NeoUnitTestBase):
         try:
             db = MasterDB(db_file)
             self.assertTrue(exists(db_file), db_file)
-            chmod(db_file, 0400)
+            chmod(directory, 0500)
             address = ('example.com', 1024)
             # Must not raise
-            db.add(address)
+            db.addremove(None, address)
             # Value is stored
-            self.assertTrue(address in db, [x for x in db])
+            self.assertIn(address, db)
             # But not visible to a new db instance (write access restored so
             # it can be created)
-            chmod(db_file, 0600)
+            chmod(directory, 0700)
             db2 = MasterDB(db_file)
-            self.assertFalse(address in db2, [x for x in db2])
+            self.assertNotIn(address, db2)
         finally:
             shutil.rmtree(directory)
@@ -235,18 +235,21 @@ class MasterDBTests(NeoUnitTestBase):
             db = MasterDB(db_file)
             self.assertTrue(exists(db_file), db_file)
             address = ('example.com', 1024)
-            db.add(address)
+            db.addremove(None, address)
             address2 = ('example.org', 1024)
-            db.add(address2)
+            db.addremove(None, address2)
             # Values are visible to a new db instance
             db2 = MasterDB(db_file)
-            self.assertTrue(address in db2, [x for x in db2])
-            self.assertTrue(address2 in db2, [x for x in db2])
-            db.discard(address)
+            self.assertIn(address, db2)
+            self.assertIn(address2, db2)
+            db.addremove(address, None)
             # Create yet another instance (file is not supposed to be shared)
-            db3 = MasterDB(db_file)
-            self.assertFalse(address in db3, [x for x in db3])
-            self.assertTrue(address2 in db3, [x for x in db3])
+            db2 = MasterDB(db_file)
+            self.assertNotIn(address, db2)
+            self.assertIn(address2, db2)
+            db.remove(address2)
+            # and again, to test remove()
+            self.assertNotIn(address2, MasterDB(db_file))
         finally:
             shutil.rmtree(directory)
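The MasterDB hunks above replace the old add/discard calls with a combined addremove, called as `db.addremove(None, address)` to add and `db.addremove(address, None)` to remove. A toy stand-in illustrating only that call convention as the updated tests exercise it (the real MasterDB persists to a file; this class is a hypothetical in-memory sketch):

```python
class FakeMasterDB(object):
    """In-memory sketch of the addremove(old, new) call style used by
    the updated MasterDB tests. Not the real neo.master implementation,
    which additionally persists the set to disk."""

    def __init__(self):
        self._addresses = set()

    def addremove(self, old, new):
        # Drop `old` if given, then record `new` if given, in one call.
        if old is not None:
            self._addresses.discard(old)
        if new is not None:
            self._addresses.add(new)

    def remove(self, address):
        # Convenience wrapper, mirroring db.remove(address2) in the test.
        self.addremove(address, None)

    def __contains__(self, address):
        return address in self._addresses
```

Folding add and discard into one operation lets a single call both register a new master address and retire a stale one.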
@@ -34,7 +34,7 @@ class PartitionTableTests(NeoUnitTestBase):
         # check getter
         self.assertEqual(cell.getNode(), sn)
         self.assertEqual(cell.getState(), CellStates.OUT_OF_DATE)
-        self.assertEqual(cell.getNodeState(), NodeStates.UNKNOWN)
+        self.assertEqual(cell.getNodeState(), NodeStates.DOWN)
         self.assertEqual(cell.getUUID(), uuid)
         self.assertEqual(cell.getAddress(), server)
         # check state setter
@@ -104,18 +104,12 @@ class PartitionTableTests(NeoUnitTestBase):
         else:
             self.assertEqual(len(pt.partition_list[x]), 0)
-        # now add broken and down state, must not be taken into account
+        # now add down state, must not be taken into account
         pt._setCell(0, sn1, CellStates.DISCARDED)
         for x in xrange(num_partitions):
             self.assertEqual(len(pt.partition_list[x]), 0)
         self.assertEqual(pt.count_dict[sn1], 0)
-        sn1.setState(NodeStates.BROKEN)
-        self.assertRaises(PartitionTableException, pt._setCell,
-            0, sn1, CellStates.UP_TO_DATE)
-        for x in xrange(num_partitions):
-            self.assertEqual(len(pt.partition_list[x]), 0)
-        self.assertEqual(pt.count_dict[sn1], 0)
-        sn1.setState(NodeStates.DOWN)
+        sn1.setState(NodeStates.UNKNOWN)
         self.assertRaises(PartitionTableException, pt._setCell,
             0, sn1, CellStates.UP_TO_DATE)
         for x in xrange(num_partitions):
@@ -331,7 +325,7 @@ class PartitionTableTests(NeoUnitTestBase):
         self.assertFalse(pt.operational())
         # adding a node in all partition
         sn1 = createStorage()
-        sn1.setState(NodeStates.TEMPORARILY_DOWN)
+        sn1.setState(NodeStates.DOWN)
         for x in xrange(num_partitions):
             pt._setCell(x, sn1, CellStates.FEEDING)
         self.assertTrue(pt.filled())
...@@ -35,7 +35,7 @@ from neo.lib.connection import BaseConnection, \ ...@@ -35,7 +35,7 @@ from neo.lib.connection import BaseConnection, \
from neo.lib.connector import SocketConnector, ConnectorException from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets from neo.lib.protocol import ClusterStates, Enum, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER
...@@ -339,10 +339,10 @@ class ServerNode(Node): ...@@ -339,10 +339,10 @@ class ServerNode(Node):
if not address: if not address:
address = self.newAddress() address = self.newAddress()
if cluster is None: if cluster is None:
master_nodes = kw.get('master_nodes', ()) master_nodes = ()
name = kw.get('name', 'test') name = kw.get('name', 'test')
else: else:
master_nodes = kw.get('master_nodes', cluster.master_nodes) master_nodes = cluster.master_nodes
name = kw.get('name', cluster.name) name = kw.get('name', cluster.name)
port = address[1] port = address[1]
if address is not BIND: if address is not BIND:
...@@ -354,7 +354,7 @@ class ServerNode(Node): ...@@ -354,7 +354,7 @@ class ServerNode(Node):
self.daemon = True self.daemon = True
self.node_name = '%s_%u' % (self.node_type, port) self.node_name = '%s_%u' % (self.node_type, port)
kw.update(getCluster=name, getBind=address, kw.update(getCluster=name, getBind=address,
getMasters=master_nodes and parseMasterList(master_nodes, address)) getMasters=master_nodes and parseMasterList(master_nodes))
     # -> app.__init__() ; Mock serves as config
     super(ServerNode, self).__init__(Mock(kw))
@@ -782,27 +782,39 @@ class NEOCluster(object):
     def __exit__(self, t, v, tb):
         self.stop(None)

-    def start(self, storage_list=None, fast_startup=False):
+    def start(self, storage_list=None, master_list=None, recovering=False):
         self.started = True
         self._patch()
         self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
-        for node_type in 'master', 'admin':
-            for node in getattr(self, node_type + '_list'):
-                node.start()
+        for node in self.master_list if master_list is None else master_list:
+            node.start()
+        for node in self.admin_list:
+            node.start()
         Serialized.tic()
-        if fast_startup:
-            self.startCluster()
         if storage_list is None:
             storage_list = self.storage_list
         for node in storage_list:
             node.start()
         Serialized.tic()
-        if not fast_startup:
+        if recovering:
+            expected_state = ClusterStates.RECOVERING
+        else:
             self.startCluster()
             Serialized.tic()
+            expected_state = ClusterStates.RUNNING, ClusterStates.BACKINGUP
+        self.checkStarted(expected_state, storage_list)
+
+    def checkStarted(self, expected_state, storage_list=None):
+        if isinstance(expected_state, Enum.Item):
+            expected_state = expected_state,
         state = self.neoctl.getClusterState()
-        assert state in (ClusterStates.RUNNING, ClusterStates.BACKINGUP), state
-        self.enableStorageList(storage_list)
+        assert state in expected_state, state
+        expected_state = (NodeStates.PENDING
+            if state == ClusterStates.RECOVERING
+            else NodeStates.RUNNING)
+        for node in self.storage_list if storage_list is None else storage_list:
+            state = self.getNodeState(node)
+            assert state == expected_state, (node, state)

     def stop(self, clear_database=False, __print_exc=traceback.print_exc, **kw):
         if self.started:
...
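The new `checkStarted()` helper above accepts either a single expected cluster state or a tuple of acceptable states. A minimal standalone sketch of that scalar-or-tuple normalization idiom (names here are illustrative, not NEO's):

```python
# Sketch of the idiom used by checkStarted(): a caller may pass a single
# state or a tuple of acceptable states; a scalar is wrapped into a 1-tuple.
def check_state(state, expected):
    if not isinstance(expected, tuple):
        expected = expected,  # trailing comma: wrap into a 1-tuple
    assert state in expected, state

check_state("RUNNING", "RUNNING")
check_state("BACKINGUP", ("RUNNING", "BACKINGUP"))
```

The trailing-comma wrap is why `expected_state = expected_state,` in the diff is not a typo: it turns a lone `Enum.Item` into a 1-tuple so the later `in` test works uniformly.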
@@ -34,9 +34,9 @@ from neo.lib.connection import ConnectionClosed, \
 from neo.lib.exception import DatabaseFailure, StoppedOperation
 from neo.lib.handler import DelayEvent
 from neo.lib import logging
-from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
-    Packet, uuid_str, ZERO_OID, ZERO_TID
-from .. import expectedFailure, Patch, TransactionalResource
+from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
+    Packets, Packet, uuid_str, ZERO_OID, ZERO_TID)
+from .. import expectedFailure, unpickle_state, Patch, TransactionalResource
 from . import ClientApplication, ConnectionFilter, LockLock, NEOThreadedTest, \
     RandomConflictDict, ThreadId, with_cluster
 from neo.lib.util import add64, makeChecksum, p64, u64
@@ -552,7 +552,8 @@ class Test(NEOThreadedTest):
         # restart it with one storage only
         if 1:
             cluster.start(storage_list=(s1,))
-            self.assertEqual(NodeStates.UNKNOWN, cluster.getNodeState(s2))
+            self.assertEqual(NodeStates.DOWN,
+                             cluster.getNodeState(s2))

     @with_cluster(storage_count=2, partitions=2, replicas=1)
     def testRestartStoragesWithReplicas(self, cluster):
@@ -838,12 +839,6 @@ class Test(NEOThreadedTest):
     @with_cluster(master_count=3, partitions=10, replicas=1, storage_count=3)
     def testShutdown(self, cluster):
         # NOTE vvv
-        # BUG: Due to bugs in election, master nodes sometimes crash, or they
-        # declare themselves primary too quickly, but issues seem to be
-        # only reproducible with SSL enabled.
-        self._testShutdown(cluster)
-
-    def _testShutdown(self, cluster):
         def before_finish(_):
             # tell admin to shutdown the cluster
             cluster.neoctl.setClusterState(ClusterStates.STOPPING)
@@ -1226,13 +1221,10 @@ class Test(NEOThreadedTest):
     @with_cluster(start_cluster=0, storage_count=3, autostart=3)
     def testAutostart(self, cluster):
-        def startCluster(orig):
-            getClusterState = cluster.neoctl.getClusterState
-            self.assertEqual(ClusterStates.RECOVERING, getClusterState())
-            cluster.storage_list[2].start()
-        with Patch(cluster, startCluster=startCluster):
-            cluster.start(cluster.storage_list[:2])
-        self.assertEqual(ClusterStates.RUNNING, getClusterState())
+        cluster.start(cluster.storage_list[:2], recovering=True)
+        cluster.storage_list[2].start()
+        self.tic()
+        cluster.checkStarted(ClusterStates.RUNNING)

     @with_cluster(storage_count=2, partitions=2)
     def testAbortVotedTransaction(self, cluster):
@@ -1490,11 +1482,11 @@ class Test(NEOThreadedTest):
         reports a conflict after this conflict was fully resolved with
         another node.
         """
-        def answerStoreObject(orig, conn, conflict, oid, serial):
+        def answerStoreObject(orig, conn, conflict, oid):
             if not conflict:
                 p.revert()
                 ll()
-            orig(conn, conflict, oid, serial)
+            orig(conn, conflict, oid)
         if 1:
             s0, s1 = cluster.storage_list
             t1, c1 = cluster.getTransaction()
@@ -1984,6 +1976,35 @@ class Test(NEOThreadedTest):
     @with_cluster(replicas=1, partitions=4)
     def testNotifyReplicated(self, cluster):
+        """
+        Check replication while several concurrent transactions lead to
+        conflict resolutions and deadlock avoidances, and in particular the
+        handling of write-locks when the storage node is about to notify the
+        master that partitions are replicated.
+        Transactions are committed in the following order:
+        - t2
+        - t4, conflict on 'd'
+        - t1, deadlock on 'a'
+        - t3, deadlock on 'b', and 2 conflicts on 'a'
+        Special care is also taken for the change done by t3 on 'a', to check
+        that the client resolves conflicts with the correct oldSerial:
+        1. The initial store (a=8) is first delayed by t2.
+        2. It is then kept aside by the deadlock.
+        3. On s1, deadlock avoidance happens after t1 stores a=7 and the store
+           is delayed again. However, it's the contrary on s0, and a conflict
+           is reported to the client.
+        4. Second store (a=12) based on t2.
+        5. t1 finishes and s1 reports the conflict for the first store (with
+           t1). At that point, the base serial of this store is meaningless:
+           the client only has data for the last store (based on t2), and it's
+           its base serial that must be used. t3 writes 15 (and not 19!).
+        6. Conflicts for the second store are with t2 and they're ignored
+           because they're already resolved.
+        Note that this test method lacks code to enforce some events to happen
+        in the expected order. Sometimes, the above scenario is not reproduced
+        entirely, but it's so rare that there's no point in making the code
+        more complicated.
+        """
         s0, s1 = cluster.storage_list
         s1.stop()
         cluster.join((s1,))
@@ -2029,14 +2050,33 @@ class Test(NEOThreadedTest):
             yield 1
             self.tic()
             self.assertPartitionTable(cluster, 'UO|UU|UU|UU')
-        def t4_vote(*args, **kw):
+        def t4_d(*args, **kw):
             self.tic()
             self.assertPartitionTable(cluster, 'UU|UU|UU|UU')
-            yield 0
+            yield 2
+        # Delay the conflict for the second store of 'a' by t3.
+        delay_conflict = {s0.uuid: [1], s1.uuid: [1,0]}
+        def delayConflict(conn, packet):
+            app = conn.getHandler().app
+            if (isinstance(packet, Packets.AnswerStoreObject)
+                and packet.decode()[0]):
+                conn, = cluster.client.getConnectionList(app)
+                kw = conn._handlers._pending[0][0][packet._id][3]
+                return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop()
+        def writeA(orig, txn_context, oid, serial, data):
+            if u64(oid) == 1:
+                value = unpickle_state(data)['value']
+                if value > 12:
+                    f.remove(delayConflict)
+                elif value == 12:
+                    f.add(delayConflict)
+            return orig(txn_context, oid, serial, data)
+        ###
         with ConnectionFilter() as f, \
+             Patch(cluster.client, _store=writeA), \
              self.thread_switcher(threads,
                 (1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0,
-                 t1_rebase, 2, t3_b, 3, t4_vote),
+                 t1_rebase, 2, t3_b, 3, t4_d, 0, 2, 2),
                 ('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin', 2, 1, 1,
                  3, 3, 4, 4, 3, 1, 'RebaseTransaction', 'RebaseTransaction',
                  'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2
@@ -2128,8 +2168,13 @@ class Test(NEOThreadedTest):
         self.assertEqual([6, 9, 6], [r[x].value for x in 'abc'])
         self.assertEqual([2, 2], map(end.pop(1).count,
             ['RebaseTransaction', 'AnswerRebaseTransaction']))
-        self.assertEqual(end, {0: ['AnswerRebaseTransaction',
-            'StoreTransaction', 'VoteTransaction']})
+        # Rarely, there's an extra deadlock for t1:
+        # 0: ['AnswerRebaseTransaction', 'RebaseTransaction',
+        #     'RebaseTransaction', 'AnswerRebaseTransaction',
+        #     'AnswerRebaseTransaction', 2, 3, 1,
+        #     'StoreTransaction', 'VoteTransaction']
+        self.assertEqual(end.pop(0)[0], 'AnswerRebaseTransaction')
+        self.assertFalse(end)

     @with_cluster()
     def testDelayedStoreOrdering(self, cluster):
@@ -2221,6 +2266,75 @@ class Test(NEOThreadedTest):
     def testConflictAfterDeadlockWithSlowReplica2(self):
         self.testConflictAfterDeadlockWithSlowReplica1(True)

+    @with_cluster(start_cluster=0, master_count=3)
+    def testElection(self, cluster):
+        m0, m1, m2 = cluster.master_list
+        cluster.start(master_list=(m0,), recovering=True)
+        getClusterState = cluster.neoctl.getClusterState
+        m0.em.removeReader(m0.listening_conn)
+        m1.start()
+        self.tic()
+        m2.start()
+        self.tic()
+        self.assertTrue(m0.primary)
+        self.assertTrue(m1.primary)
+        self.assertFalse(m2.primary)
+        m0.em.addReader(m0.listening_conn)
+        with ConnectionFilter() as f:
+            f.delayAcceptIdentification()
+            self.tic()
+        self.tic()
+        self.assertTrue(m0.primary)
+        self.assertFalse(m1.primary)
+        self.assertFalse(m2.primary)
+        self.assertEqual(getClusterState(), ClusterStates.RECOVERING)
+        cluster.startCluster()
+        def stop(node):
+            node.stop()
+            cluster.join((node,))
+            node.resetNode()
+        stop(m1)
+        self.tic()
+        self.assertEqual(getClusterState(), ClusterStates.RUNNING)
+        self.assertTrue(m0.primary)
+        self.assertFalse(m2.primary)
+        stop(m0)
+        self.tic()
+        self.assertEqual(getClusterState(), ClusterStates.RUNNING)
+        self.assertTrue(m2.primary)
+        # Check for proper update of node ids on first NotifyNodeInformation.
+        stop(m2)
+        m0.start()
+        def update(orig, app, timestamp, node_list):
+            orig(app, timestamp, sorted(node_list, reverse=1))
+        with Patch(cluster.storage.nm, update=update):
+            with ConnectionFilter() as f:
+                f.add(lambda conn, packet:
+                    isinstance(packet, Packets.RequestIdentification)
+                    and packet.decode()[0] == NodeTypes.STORAGE)
+                self.tic()
+                m2.start()
+                self.tic()
+            self.tic()
+        self.assertEqual(getClusterState(), ClusterStates.RUNNING)
+        self.assertTrue(m0.primary)
+        self.assertFalse(m2.primary)
+
+    @with_cluster(start_cluster=0, master_count=2)
+    def testIdentifyUnknownMaster(self, cluster):
+        m0, m1 = cluster.master_list
+        cluster.master_nodes = ()
+        m0.resetNode()
+        cluster.start(master_list=(m0,))
+        m1.start()
+        self.tic()
+        self.assertEqual(cluster.neoctl.getClusterState(),
+                         ClusterStates.RUNNING)
+        self.assertTrue(m0.primary)
+        self.assertTrue(m0.is_alive())
+        self.assertFalse(m1.primary)
+        self.assertTrue(m1.is_alive())

 if __name__ == "__main__":
     unittest.main()
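Several hunks above wrap methods via `Patch(obj, name=wrapper)`, where the wrapper receives the original callable as its first argument (`orig`). A hedged, standalone sketch of that idiom (this is my reconstruction for illustration, not NEO's actual `Patch` implementation):

```python
from contextlib import contextmanager

# Illustrative re-implementation of the Patch idiom used in these tests:
# temporarily replace an attribute with a wrapper that receives the
# original callable as its first argument, restoring it on exit.
@contextmanager
def patch(obj, name, wrapper):
    orig = getattr(obj, name)
    setattr(obj, name, lambda *args, **kw: wrapper(orig, *args, **kw))
    try:
        yield
    finally:
        setattr(obj, name, orig)

class Counter:
    def add(self, x):
        return x + 1

c = Counter()
with patch(c, 'add', lambda orig, x: orig(x) * 10):
    assert c.add(2) == 30  # (2 + 1) * 10 while patched
assert c.add(2) == 3       # original behavior restored
```

This mirrors how e.g. `def update(orig, app, timestamp, node_list)` above intercepts `NodeManager.update` while still delegating to the original.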
@@ -15,9 +15,10 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.

 import unittest
+from neo.lib.connection import ClientConnection, ListeningConnection
 from neo.lib.protocol import Packets
-from .. import SSL
-from . import NEOCluster, with_cluster, test, testReplication
+from .. import Patch, SSL
+from . import NEOCluster, test, testReplication

 class SSLMixin:
@@ -36,14 +37,6 @@ class SSLTests(SSLMixin, test.Test):
     testDeadlockAvoidance = None # XXX why this fails?
     testUndoConflict = testUndoConflictDuringStore = None # XXX why this fails?

-    if 1:
-        testShutdownWithSeveralMasterNodes = unittest.skip("fails randomly")(
-            test.Test.testShutdown.__func__)
-
-        @with_cluster(partitions=10, replicas=1, storage_count=3)
-        def testShutdown(self, cluster):
-            self._testShutdown(cluster)
-
     def testAbortConnection(self, after_handshake=1):
         with self.getLoopbackConnection() as conn:
             conn.ask(Packets.Ping())
@@ -65,6 +58,18 @@ class SSLTests(SSLMixin, test.Test):
     def testAbortConnectionBeforeHandshake(self):
         self.testAbortConnection(0)

+    def testSSLVsNoSSL(self):
+        def __init__(orig, self, app, *args, **kw):
+            with Patch(app, ssl=None):
+                orig(self, app, *args, **kw)
+        for cls in (ListeningConnection, # SSL connecting to non-SSL
+                    ClientConnection,    # non-SSL connecting to SSL
+                    ):
+            with Patch(cls, __init__=__init__), \
+                 self.getLoopbackConnection() as conn:
+                while not conn.isClosed():
+                    conn.em.poll(1)

 class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests):
     # do not repeat slowest tests with SSL
     testBackupNodeLost = testBackupNormalCase = None # TODO recheck
...
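The `testSSLVsNoSSL` hunk above checks that a connection between an SSL and a non-SSL peer terminates instead of hanging. A standalone sketch of the same failure mode with plain Python sockets (the server, port, and variable names here are illustrative, not NEO's connection classes):

```python
import socket
import ssl
import threading

# A TLS client handshaking against a non-TLS server must fail cleanly.
server = socket.socket()
server.bind(('127.0.0.1', 0))
server.listen(1)
port = server.getsockname()[1]

def plain_server():
    conn, _ = server.accept()
    conn.sendall(b'this is not a TLS record')  # garbage to the TLS layer
    conn.close()

t = threading.Thread(target=plain_server)
t.start()

ctx = ssl.create_default_context()
ctx.check_hostname = False          # loopback demo: skip certificate checks
ctx.verify_mode = ssl.CERT_NONE
try:
    with socket.create_connection(('127.0.0.1', port)) as raw:
        with ctx.wrap_socket(raw):  # handshake happens here and fails
            pass
    handshake_failed = False
except ssl.SSLError:
    handshake_failed = True
t.join()
server.close()
print(handshake_failed)
```

The client's record layer rejects the non-TLS bytes, so `wrap_socket` raises `ssl.SSLError` rather than blocking, which is the behavior the NEO test asserts at its own protocol level.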