Commit 8c4c5df6 authored by Yoshinori Okuji

This is a significant rewrite of the client node, using multithreading-aware connections and epoll.

git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@156 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent a89322a4
...@@ -19,6 +19,16 @@ TODO ...@@ -19,6 +19,16 @@ TODO
- Replication. - Replication.
Requirements
- Python 2.4 or later
- ctypes http://python.net/crew/theller/ctypes/
- MySQLdb http://sourceforge.net/projects/mysql-python
- Zope 2.8 or later
Installation Installation
1. In zope: 1. In zope:
......
...@@ -28,22 +28,19 @@ class NEOStorage(BaseStorage.BaseStorage, ...@@ -28,22 +28,19 @@ class NEOStorage(BaseStorage.BaseStorage,
l = Lock() l = Lock()
self._txn_lock_acquire = l.acquire self._txn_lock_acquire = l.acquire
self._txn_lock_release = l.release self._txn_lock_release = l.release
# Create two queue for message between thread and dispatcher # Create a queue for message between thread and dispatcher
# - message queue is for message that has to be send to other node
# through the dispatcher
# - request queue is for message receive from other node which have to # - request queue is for message receive from other node which have to
# be processed # be processed
message_queue = Queue()
request_queue = Queue() request_queue = Queue()
# Create the event manager # Create the event manager
em = EventManager() em = EventManager()
# Create dispatcher thread # Create dispatcher thread
dispatcher = Dispatcher(em, message_queue, request_queue) dispatcher = Dispatcher(em, request_queue)
dispatcher.setDaemon(True) dispatcher.setDaemon(True)
# Import here to prevent recursive import # Import here to prevent recursive import
from neo.client.app import Application from neo.client.app import Application
self.app = Application(master_nodes, name, em, dispatcher, self.app = Application(master_nodes, name, em, dispatcher,
message_queue, request_queue) request_queue)
# Connect to primary master node # Connect to primary master node
dispatcher.connectToPrimaryMasterNode(self.app) dispatcher.connectToPrimaryMasterNode(self.app)
# Start dispatcher # Start dispatcher
......
import logging import logging
import os import os
from time import time
from threading import Lock, local from threading import Lock, local
from cPickle import dumps, loads from cPickle import dumps, loads
from zlib import compress, decompress from zlib import compress, decompress
...@@ -9,7 +8,7 @@ from random import shuffle ...@@ -9,7 +8,7 @@ from random import shuffle
from neo.client.mq import MQ from neo.client.mq import MQ
from neo.node import NodeManager, MasterNode, StorageNode from neo.node import NodeManager, MasterNode, StorageNode
from neo.connection import ListeningConnection, ClientConnection from neo.connection import MTClientConnection
from neo.protocol import Packet, INVALID_UUID, INVALID_TID, \ from neo.protocol import Packet, INVALID_UUID, INVALID_TID, \
STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \ STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
TEMPORARILY_DOWN_STATE, \ TEMPORARILY_DOWN_STATE, \
...@@ -23,17 +22,17 @@ from neo.util import makeChecksum, dump ...@@ -23,17 +22,17 @@ from neo.util import makeChecksum, dump
from ZODB.POSException import UndoError, StorageTransactionError, ConflictError from ZODB.POSException import UndoError, StorageTransactionError, ConflictError
from ZODB.utils import p64, u64, oid_repr from ZODB.utils import p64, u64, oid_repr
class ConnectionManager(object): class ConnectionPool(object):
"""This class manage a pool of connection to storage node.""" """This class manages a pool of connections to storage nodes."""
def __init__(self, storage, pool_size=25): def __init__(self, app, pool_size = 25):
self.storage = storage self.app = app
self.pool_size = 0 self.pool_size = 0
self.max_pool_size = pool_size self.max_pool_size = pool_size
self.connection_dict = {} self.connection_dict = {}
# define a lock in order to create one connection to # Define a lock in order to create one connection to
# a storage node at a time to avoid multiple connection # a storage node at a time to avoid multiple connections
# to the same node # to the same node.
l = Lock() l = Lock()
self.connection_lock_acquire = l.acquire self.connection_lock_acquire = l.acquire
self.connection_lock_release = l.release self.connection_lock_release = l.release
...@@ -43,33 +42,45 @@ class ConnectionManager(object): ...@@ -43,33 +42,45 @@ class ConnectionManager(object):
addr = node.getNode().getServer() addr = node.getNode().getServer()
if addr is None: if addr is None:
return None return None
handler = ClientEventHandler(self.storage, self.storage.dispatcher) handler = ClientEventHandler(self.app, self.app.dispatcher)
conn = ClientConnection(self.storage.em, handler, addr) conn = MTClientConnection(self.app.em, handler, addr)
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE, self.storage.uuid, addr[0], p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE,
addr[1], self.storage.name) self.app.uuid, addr[0],
self.storage.local_var.tmp_q = Queue(1) addr[1], self.app.name)
self.storage.queue.put((self.storage.local_var.tmp_q, msg_id, conn, p), True) conn.addPacket(p)
self.storage.local_var.storage_node = None conn.expectMessage(msg_id)
self.storage._waitMessage() self.app.dispatcher.register(conn, msg_id, self.app.getQueue())
if self.storage.storage_node == -1: self.app.local_var.storage_node = None
finally:
conn.unlock()
self.app._waitMessage(conn, msg_id)
if self.app.storage_node == -1:
# Connection failed, notify primary master node # Connection failed, notify primary master node
logging.error('Connection to storage node %s failed' %(addr,)) logging.error('Connection to storage node %s failed' %(addr,))
conn = self.storage.master_conn conn = self.app.master_conn
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
node_list = [(STORAGE_NODE_TYPE, addr[0], addr[1], node.getUUID(), node_list = [(STORAGE_NODE_TYPE, addr[0], addr[1],
TEMPORARILY_DOWN_STATE),] node.getUUID(), TEMPORARILY_DOWN_STATE),]
p.notifyNodeInformation(msg_id, node_list) p.notifyNodeInformation(msg_id, node_list)
self.storage.queue.put((None, msg_id, conn, p), True) conn.addPacket(p)
finally:
conn.unlock()
return None return None
logging.info('connected to storage node %s' %(addr,)) logging.info('connected to storage node %s' %(addr,))
return conn return conn
def _dropConnection(self,): def _dropConnection(self):
"""Drop a connection.""" """Drop a connection."""
pass raise NotImplementedError
def _createNodeConnection(self, node): def _createNodeConnection(self, node):
"""Create a connection to a given storage node.""" """Create a connection to a given storage node."""
...@@ -81,14 +92,14 @@ class ConnectionManager(object): ...@@ -81,14 +92,14 @@ class ConnectionManager(object):
return self.connection_dict[node.getUUID()] return self.connection_dict[node.getUUID()]
if self.pool_size > self.max_pool_size: if self.pool_size > self.max_pool_size:
# must drop some unused connections # must drop some unused connections
self.dropConnection() self._dropConnection()
conn = self._initNodeConnection(node) conn = self._initNodeConnection(node)
if conn is None: if conn is None:
return None return None
# add node to node manager # add node to node manager
if self.storage.nm.getNodeByServer(node.getServer()) is None: if self.app.nm.getNodeByServer(node.getServer()) is None:
n = StorageNode(node.getServer()) n = StorageNode(node.getServer())
self.storage.nm.add(n) self.app.nm.add(n)
self.connection_dict[node.getUUID()] = conn self.connection_dict[node.getUUID()] = conn
return conn return conn
finally: finally:
...@@ -113,8 +124,7 @@ class ConnectionManager(object): ...@@ -113,8 +124,7 @@ class ConnectionManager(object):
class Application(ThreadingMixIn, object): class Application(ThreadingMixIn, object):
"""The client node application.""" """The client node application."""
def __init__(self, master_nodes, name, em, dispatcher, message_queue, def __init__(self, master_nodes, name, em, dispatcher, request_queue, **kw):
request_queue, **kw):
logging.basicConfig(level = logging.DEBUG) logging.basicConfig(level = logging.DEBUG)
logging.debug('master node address are %s' %(master_nodes,)) logging.debug('master node address are %s' %(master_nodes,))
# Internal Attributes common to all thread # Internal Attributes common to all thread
...@@ -122,9 +132,8 @@ class Application(ThreadingMixIn, object): ...@@ -122,9 +132,8 @@ class Application(ThreadingMixIn, object):
self.em = em self.em = em
self.dispatcher = dispatcher self.dispatcher = dispatcher
self.nm = NodeManager() self.nm = NodeManager()
self.cm = ConnectionManager(self) self.cp = ConnectionPool(self)
self.pt = None self.pt = None
self.queue = message_queue
self.request_queue = request_queue self.request_queue = request_queue
self.primary_master_node = None self.primary_master_node = None
self.master_node_list = master_nodes.split(' ') self.master_node_list = master_nodes.split(' ')
...@@ -167,32 +176,35 @@ class Application(ThreadingMixIn, object): ...@@ -167,32 +176,35 @@ class Application(ThreadingMixIn, object):
break break
self.uuid = uuid self.uuid = uuid
def _waitMessage(self,block=1): def getQueue(self):
"""Wait for a message returned by dispatcher in queues.""" return self.local_var.__dict__.setdefault('queue', Queue(5))
# First check if there are global messages and execute them
global_message = None def _waitMessage(self, target_conn = None, msg_id = None):
"""Wait for a message returned by the dispatcher in queues."""
global_queue = self.request_queue
local_queue = self.getQueue()
while 1: while 1:
try: try:
global_message = self.request_queue.get_nowait() conn, packet = global_queue.get_nowait()
conn.handler.dispatch(conn, packet)
except Empty:
if msg_id is None:
try:
conn, packet = local_queue.get_nowait()
except Empty: except Empty:
break break
if global_message is not None:
global_message[0].handler.dispatch(global_message[0], global_message[1])
# Next get messages we are waiting for
if not hasattr(self.local_var, 'tmp_q'):
return
message = None
if block:
message = self.local_var.tmp_q.get(True, None)
else: else:
# we don't want to block until we got a message conn, packet = local_queue.get()
conn.lock()
try: try:
message = self.local_var.tmp_q.get_nowait() conn.handler.dispatch(conn, packet)
except Empty: finally:
pass conn.unlock()
if message is not None:
message[0].handler.dispatch(message[0], message[1])
if target_conn is conn and msg_id == packet.getId():
break
def registerDB(self, db, limit): def registerDB(self, db, limit):
self._db = db self._db = db
...@@ -207,12 +219,18 @@ class Application(ThreadingMixIn, object): ...@@ -207,12 +219,18 @@ class Application(ThreadingMixIn, object):
# from asking too many time new oid one by one # from asking too many time new oid one by one
# from master node # from master node
conn = self.master_conn conn = self.master_conn
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askNewOIDs(msg_id, 25) p.askNewOIDs(msg_id, 25)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
self._waitMessage() self.dispatcher.register(conn, msg_id, self.getQueue())
finally:
conn.unlock()
self._waitMessage(conn, msg_id)
if len(self.new_oid_list) <= 0: if len(self.new_oid_list) <= 0:
raise NEOStorageError('new_oid failed') raise NEOStorageError('new_oid failed')
return self.new_oid_list.pop() return self.new_oid_list.pop()
...@@ -229,7 +247,7 @@ class Application(ThreadingMixIn, object): ...@@ -229,7 +247,7 @@ class Application(ThreadingMixIn, object):
finally: finally:
self._cache_lock_release() self._cache_lock_release()
# history return serial, so use it # history return serial, so use it
hist = self.history(oid, length=1, object_only=1) hist = self.history(oid, length = 1, object_only = 1)
if len(hist) == 0: if len(hist) == 0:
raise NEOStorageNotFoundError() raise NEOStorageNotFoundError()
if hist[0] != oid: if hist[0] != oid:
...@@ -237,7 +255,7 @@ class Application(ThreadingMixIn, object): ...@@ -237,7 +255,7 @@ class Application(ThreadingMixIn, object):
return hist[1][0][0] return hist[1][0][0]
def _load(self, oid, serial=INVALID_TID, tid=INVALID_TID, cache=0): def _load(self, oid, serial = INVALID_TID, tid = INVALID_TID, cache = 0):
"""Internal method which manage load ,loadSerial and loadBefore.""" """Internal method which manage load ,loadSerial and loadBefore."""
partition_id = u64(oid) % self.num_partitions partition_id = u64(oid) % self.num_partitions
# Only used up to date node for retrieving object # Only used up to date node for retrieving object
...@@ -253,21 +271,27 @@ class Application(ThreadingMixIn, object): ...@@ -253,21 +271,27 @@ class Application(ThreadingMixIn, object):
for cell in cell_list: for cell in cell_list:
logging.debug('trying to load %s from %s', logging.debug('trying to load %s from %s',
dump(oid), dump(cell.getUUID())) dump(oid), dump(cell.getUUID()))
conn = self.cm.getConnForNode(cell) conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askObject(msg_id, oid, serial, tid) p.askObject(msg_id, oid, serial, tid)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.asked_object = 0
finally:
conn.unlock()
# Wait for answer # Wait for answer
self.local_var.asked_object = 0
# asked object retured value are : # asked object retured value are :
# -1 : oid not found # -1 : oid not found
# other : data # other : data
self._waitMessage() self._waitMessage(conn, msg_id)
if self.local_var.asked_object == -1: if self.local_var.asked_object == -1:
# OID not found # OID not found
# XXX either try with another node, either raise error here # XXX either try with another node, either raise error here
...@@ -275,16 +299,17 @@ class Application(ThreadingMixIn, object): ...@@ -275,16 +299,17 @@ class Application(ThreadingMixIn, object):
continue continue
# Check data # Check data
noid, start_serial, end_serial, compression, checksum, data = self.local_var.asked_object noid, start_serial, end_serial, compression, checksum, data \
= self.local_var.asked_object
if noid != oid: if noid != oid:
# Oops, try with next node # Oops, try with next node
logging.error('got wrong oid %s instead of %s from node %s' \ logging.error('got wrong oid %s instead of %s from node %s' \
%(noid, oid, storage_node.getServer())) % (noid, oid, cell.getServer()))
continue continue
elif checksum != makeChecksum(data): elif checksum != makeChecksum(data):
# Check checksum. # Check checksum.
logging.error('wrong checksum from node %s for oid %s' \ logging.error('wrong checksum from node %s for oid %s' \
%(storage_node.getServer(), oid)) % (cell.getServer(), oid))
continue continue
else: else:
# Everything looks alright. # Everything looks alright.
...@@ -352,13 +377,19 @@ class Application(ThreadingMixIn, object): ...@@ -352,13 +377,19 @@ class Application(ThreadingMixIn, object):
if tid is None: if tid is None:
self.tid = None self.tid = None
conn = self.master_conn conn = self.master_conn
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askNewTID(msg_id) p.askNewTID(msg_id)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
finally:
conn.unlock()
# Wait for answer # Wait for answer
self._waitMessage() self._waitMessage(conn, msg_id)
if self.tid is None: if self.tid is None:
raise NEOStorageError('tpc_begin failed') raise NEOStorageError('tpc_begin failed')
else: else:
...@@ -375,27 +406,34 @@ class Application(ThreadingMixIn, object): ...@@ -375,27 +406,34 @@ class Application(ThreadingMixIn, object):
dump(oid), dump(serial)) dump(oid), dump(serial))
# Find which storage node to use # Find which storage node to use
partition_id = u64(oid) % self.num_partitions partition_id = u64(oid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, False) cell_list = self.pt.getCellList(partition_id, False)
if len(storage_node_list) == 0: if len(cell_list) == 0:
# FIXME must wait for cluster to be ready # FIXME must wait for cluster to be ready
raise NEOStorageError raise NEOStorageError
# Store data on each node # Store data on each node
ddata = dumps(data) ddata = dumps(data)
compressed_data = compress(ddata) compressed_data = compress(ddata)
checksum = makeChecksum(compressed_data) checksum = makeChecksum(compressed_data)
for storage_node in storage_node_list: for cell in cell_list:
conn = self.cm.getConnForNode(storage_node) conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askStoreObject(msg_id, oid, serial, 1, checksum, compressed_data, self.tid) p.askStoreObject(msg_id, oid, serial, 1,
self.local_var.tmp_q = Queue(1) checksum, compressed_data, self.tid)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.txn_object_stored = 0
finally:
conn.unlock()
# Check we don't get any conflict # Check we don't get any conflict
self.txn_object_stored = 0 self._waitMessage(conn, msg_id)
self._waitMessage()
if self.txn_object_stored[0] == -1: if self.txn_object_stored[0] == -1:
if self.txn_data_dict.has_key(oid): if self.txn_data_dict.has_key(oid):
# One storage already accept the object, is it normal ?? # One storage already accept the object, is it normal ??
...@@ -423,22 +461,29 @@ class Application(ThreadingMixIn, object): ...@@ -423,22 +461,29 @@ class Application(ThreadingMixIn, object):
oid_list = self.txn_data_dict.keys() oid_list = self.txn_data_dict.keys()
# Store data on each node # Store data on each node
partition_id = u64(self.tid) % self.num_partitions partition_id = u64(self.tid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True) cell_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list: for cell in cell_list:
conn = self.cm.getConnForNode(storage_node) conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askStoreTransaction(msg_id, self.tid, user, desc, ext, oid_list) p.askStoreTransaction(msg_id, self.tid, user, desc, ext,
self.local_var.tmp_q = Queue(1) oid_list)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.txn_voted == 0 self.txn_voted == 0
self._waitMessage() finally:
conn.unlock()
self._waitMessage(conn, msg_id)
if self.txn_voted != 1: if self.txn_voted != 1:
raise NEOStorageError('tpc_vote failed') raise NEOStorageError('tpc_vote failed')
def _clear_txn(self): def _clear_txn(self):
"""Clear some transaction parameters.""" """Clear some transaction parameters."""
self.tid = None self.tid = None
...@@ -447,61 +492,80 @@ class Application(ThreadingMixIn, object): ...@@ -447,61 +492,80 @@ class Application(ThreadingMixIn, object):
self.txn_voted = 0 self.txn_voted = 0
self.txn_finished = 0 self.txn_finished = 0
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
"""Abort current transaction.""" """Abort current transaction."""
if transaction is not self.txn: if transaction is not self.txn:
return return
# Abort txn in node where objects were stored # Abort txn in node where objects were stored
aborted_node = {} aborted_node_set = set()
for oid in self.txn_data_dict.iterkeys(): for oid in self.txn_data_dict.iterkeys():
partition_id = u64(oid) % self.num_partitions partition_id = u64(oid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True) cell_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list: for cell in cell_list:
if not aborted_node.has_key(storage_node): if cell.getNode() not in aborted_node_set:
conn = self.cm.getConnForNode(storage_node) conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.abortTransaction(msg_id, self.tid) p.abortTransaction(msg_id, self.tid)
self.queue.put((None, msg_id, conn, p), True) conn.addPacket(p)
aborted_node[storage_node] = 1 finally:
conn.unlock()
aborted_node_set.add(cell.getNode())
# Abort in nodes where transaction was stored # Abort in nodes where transaction was stored
partition_id = u64(self.tid) % self.num_partitions partition_id = u64(self.tid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True) cell_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list: for cell in cell_list:
if not aborted_node.has_key(storage_node): if cell.getNode() not in aborted_node_set:
conn = self.cm.getConnForNode(storage_node) conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.abortTransaction(msg_id, self.tid) p.abortTransaction(msg_id, self.tid)
self.queue.put((None, msg_id, conn, p), True) conn.addPacket(p)
finally:
conn.unlock()
self._clear_txn() aborted_node_set.add(cell.getNode())
self._clear_txn()
def tpc_finish(self, transaction, f=None): def tpc_finish(self, transaction, f=None):
"""Finish current transaction.""" """Finish current transaction."""
if self.txn is not transaction: if self.txn is not transaction:
return return
# Call function given by ZODB # Call function given by ZODB
if f is not None: if f is not None:
f(self.tid) f(self.tid)
# Call finish on master # Call finish on master
oid_list = self.txn_data_dict.keys() oid_list = self.txn_data_dict.keys()
conn = self.master_conn conn = self.master_conn
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.finishTransaction(msg_id, oid_list, self.tid) p.finishTransaction(msg_id, oid_list, self.tid)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id, additional_timeout = 300)
self.dispatcher.register(conn, msg_id, self.getQueue())
finally:
conn.unlock()
# Wait for answer # Wait for answer
self._waitMessage() self._waitMessage(conn, msg_id)
if self.txn_finished != 1: if self.txn_finished != 1:
raise NEOStorageError('tpc_finish failed') raise NEOStorageError('tpc_finish failed')
...@@ -522,25 +586,33 @@ class Application(ThreadingMixIn, object): ...@@ -522,25 +586,33 @@ class Application(ThreadingMixIn, object):
if transaction_id is not self.txn: if transaction_id is not self.txn:
raise StorageTransactionError(self, transaction_id) raise StorageTransactionError(self, transaction_id)
# First get transaction information from master node # First get transaction information from a storage node.
partition_id = u64(transaction_id) % self.num_partitions partition_id = u64(transaction_id) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True) cell_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list: shuffle(cell_list)
conn = self.cm.getConnForNode(storage_node) for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askTransactionInformation(msg_id, transaction_id) p.askTransactionInformation(msg_id, transaction_id)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
# Wait for answer self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.txn_info = 0 self.local_var.txn_info = 0
self._waitMessage() finally:
conn.unlock()
# Wait for answer
self._waitMessage(conn, msg_id)
if self.local_var.txn_info == -1: if self.local_var.txn_info == -1:
# Tid not found, try with next node # Tid not found, try with next node
continue continue
elif isinstance(self.local_var.txn_info, {}): elif isinstance(self.local_var.txn_info, dict):
break break
else: else:
raise NEOStorageError('undo failed') raise NEOStorageError('undo failed')
...@@ -589,21 +661,26 @@ class Application(ThreadingMixIn, object): ...@@ -589,21 +661,26 @@ class Application(ThreadingMixIn, object):
# See FileStorage.py for explanation # See FileStorage.py for explanation
last = first - last last = first - last
# First get list of transaction from all storage node # First get a list of transactions from all storage nodes.
storage_node_list = [x for x in self.pt.getNodeList() if x.getState() \ storage_node_list = [x for x in self.pt.getNodeList() if x.getState() \
in (UP_TO_DATE_STATE, FEEDING_STATE)] in (UP_TO_DATE_STATE, FEEDING_STATE)]
self.local_var.node_tids = {} self.local_var.node_tids = {}
self.local_var.tmp_q = Queue(len(storage_node_list))
for storage_node in storage_node_list: for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node) conn = self.cp.getConnForNode(storage_node)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askTIDs(msg_id, first, last) p.askTIDs(msg_id, first, last)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.addPacket(p)
finally:
conn.unlock()
# Wait for answer from all storages # Wait for answers from all storages.
# FIXME this is a busy loop.
while True: while True:
self._waitMessage() self._waitMessage()
if len(self.local_var.node_tids) == len(storage_node_list): if len(self.local_var.node_tids) == len(storage_node_list):
...@@ -620,23 +697,31 @@ class Application(ThreadingMixIn, object): ...@@ -620,23 +697,31 @@ class Application(ThreadingMixIn, object):
undo_info = [] undo_info = []
for tid in ordered_tids: for tid in ordered_tids:
partition_id = u64(tid) % self.num_partitions partition_id = u64(tid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True) cell_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list: shuffle(cell_list)
conn = self.cm.getConnForNode(storage_node) for cell in cell_list:
conn = self.cp.getConnForNode(storage_node)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askTransactionInformation(msg_id, tid) p.askTransactionInformation(msg_id, tid)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
# Wait for answer self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.txn_info = 0 self.local_var.txn_info = 0
self._waitMessage() finally:
conn.unlock()
# Wait for answer
self._waitMessage(conn, msg_id)
if self.local_var.txn_info == -1: if self.local_var.txn_info == -1:
# TID not found, go on with next node # TID not found, go on with next node
continue continue
elif isinstance(self.local_var.txn_info, {}): elif isinstance(self.local_var.txn_info, dict):
break break
# Filter result if needed # Filter result if needed
...@@ -648,7 +733,7 @@ class Application(ThreadingMixIn, object): ...@@ -648,7 +733,7 @@ class Application(ThreadingMixIn, object):
# Append to returned list # Append to returned list
self.local_var.txn_info.pop("oids") self.local_var.txn_info.pop("oids")
undo_info.append(self.local_var.txn_info) undo_info.append(self.local_var.txn_info)
if len(undo_info) >= last-first: if len(undo_info) >= last - first:
break break
return undo_info return undo_info
...@@ -657,26 +742,35 @@ class Application(ThreadingMixIn, object): ...@@ -657,26 +742,35 @@ class Application(ThreadingMixIn, object):
def history(self, oid, version, length=1, filter=None, object_only=0): def history(self, oid, version, length=1, filter=None, object_only=0):
# Get history informations for object first # Get history informations for object first
partition_id = u64(oid) % self.num_partitions partition_id = u64(oid) % self.num_partitions
storage_node_list = [x for x in self.pt.getCellList(partition_id, True) \ cell_list = self.pt.getCellList(partition_id, True)
if x.getState() == UP_TO_DATE_STATE] shuffle(cell_list)
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node) for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askObjectHistory(msg_id, oid, length) p.askObjectHistory(msg_id, oid, length)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.history = None self.local_var.history = None
self._waitMessage() finally:
conn.unlock()
self._waitMessage(conn, msg_id)
if self.local_var.history == -1: if self.local_var.history == -1:
# Not found, go on with next node # Not found, go on with next node
continue continue
if self.local_var.history[0] != oid: if self.local_var.history[0] != oid:
# Got history for wrong oid # Got history for wrong oid
continue continue
if not isinstance(self.local_var.history, {}):
if not isinstance(self.local_var.history, dict):
raise NEOStorageError('history failed') raise NEOStorageError('history failed')
if object_only: if object_only:
# Use by getSerial # Use by getSerial
...@@ -686,23 +780,32 @@ class Application(ThreadingMixIn, object): ...@@ -686,23 +780,32 @@ class Application(ThreadingMixIn, object):
history_list = [] history_list = []
for serial, size in self.local_var.hisory[1]: for serial, size in self.local_var.hisory[1]:
partition_id = u64(serial) % self.num_partitions partition_id = u64(serial) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True) cell_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list: shuffle(cell_list)
conn = self.cm.getConnForNode(storage_node)
for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None: if conn is None:
continue continue
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askTransactionInformation(msg_id, serial) p.askTransactionInformation(msg_id, serial)
self.local_var.tmp_q = Queue(1) conn.addPacket(p)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
# Wait for answer self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.txn_info = None self.local_var.txn_info = None
self._waitMessage() finally:
conn.unlock()
# Wait for answer
self._waitMessage(conn, msg_id)
if self.local_var.txn_info == -1: if self.local_var.txn_info == -1:
# TID not found # TID not found
continue continue
if isinstance(self.local_var.txn_info, {}): if isinstance(self.local_var.txn_info, dict):
break break
# create history dict # create history dict
......
...@@ -2,7 +2,7 @@ from threading import Thread ...@@ -2,7 +2,7 @@ from threading import Thread
from Queue import Empty, Queue from Queue import Empty, Queue
from neo.protocol import PING, Packet, CLIENT_NODE_TYPE, FINISH_TRANSACTION from neo.protocol import PING, Packet, CLIENT_NODE_TYPE, FINISH_TRANSACTION
from neo.connection import ClientConnection from neo.connection import MTClientConnection
from neo.node import MasterNode from neo.node import MasterNode
from time import time from time import time
...@@ -11,9 +11,8 @@ import logging ...@@ -11,9 +11,8 @@ import logging
class Dispatcher(Thread): class Dispatcher(Thread):
"""Dispatcher class use to redirect request to thread.""" """Dispatcher class use to redirect request to thread."""
def __init__(self, em, message_queue, request_queue, **kw): def __init__(self, em, request_queue, **kw):
Thread.__init__(self, **kw) Thread.__init__(self, **kw)
self._message_queue = message_queue
self._request_queue = request_queue self._request_queue = request_queue
self.em = em self.em = em
# Queue of received packet that have to be processed # Queue of received packet that have to be processed
...@@ -29,51 +28,42 @@ class Dispatcher(Thread): ...@@ -29,51 +28,42 @@ class Dispatcher(Thread):
# First check if we receive any new message from other node # First check if we receive any new message from other node
m = None m = None
try: try:
self.em.poll(0.02) self.em.poll()
except KeyError: except KeyError:
# This happen when there is no connection # This happen when there is no connection
logging.error('Dispatcher, run, poll returned a KeyError') logging.error('Dispatcher, run, poll returned a KeyError')
while 1: while 1:
try: try:
conn, packet = self.message.get_nowait() conn, packet = self.message.get_nowait()
except Empty: except Empty:
break break
# Send message to waiting thread # Send message to waiting thread
key = "%s-%s" %(conn.getUUID(),packet.getId()) key = (conn.getUUID(), packet.getId())
#logging.info('dispatcher got packet %s' %(key,)) #logging.info('dispatcher got packet %s' %(key,))
if self.message_table.has_key(key): if key in self.message_table:
tmp_q = self.message_table.pop(key) queue = self.message_table.pop(key)
tmp_q.put((conn, packet), True) queue.put((conn, packet))
else: else:
#conn, packet = self.message #conn, packet = self.message
method_type = packet.getType() method_type = packet.getType()
if method_type == PING: if method_type == PING:
# must answer with no delay # must answer with no delay
conn.lock()
try:
conn.addPacket(Packet().pong(packet.getId())) conn.addPacket(Packet().pong(packet.getId()))
finally:
conn.unlock()
else: else:
# put message in request queue # put message in request queue
self._request_queue.put((conn, packet), True) self._request_queue.put((conn, packet))
# Then check if a client ask me to send a message def register(self, conn, msg_id, queue):
try: """Register an expectation for a reply. Thanks to GIL, it is
m = self._message_queue.get_nowait() safe not to use a lock here."""
if m is not None: key = (conn.getUUID(), msg_id)
tmp_q, msg_id, conn, p = m self.message_table[key] = queue
conn.addPacket(p)
if tmp_q is not None:
# We expect an answer
key = "%s-%s" %(conn.getUUID(), msg_id)
self.message_table[key] = tmp_q
# XXX this is a hack. Probably queued tasks themselves
# should specify the timeout values.
if p.getType() == FINISH_TRANSACTION:
# Finish Transaction may take a lot of time when
# many objects are committed at a time.
conn.expectMessage(msg_id, additional_timeout = 300)
else:
conn.expectMessage(msg_id)
except Empty:
continue
def connectToPrimaryMasterNode(self, app): def connectToPrimaryMasterNode(self, app):
"""Connect to a primary master node. """Connect to a primary master node.
...@@ -87,7 +77,7 @@ class Dispatcher(Thread): ...@@ -87,7 +77,7 @@ class Dispatcher(Thread):
master_index = 0 master_index = 0
conn = None conn = None
# Make application execute remaining message if any # Make application execute remaining message if any
app._waitMessage(block=0) app._waitMessage()
handler = ClientEventHandler(app, app.dispatcher) handler = ClientEventHandler(app, app.dispatcher)
while 1: while 1:
if app.pt is not None and app.pt.operational(): if app.pt is not None and app.pt.operational():
...@@ -101,18 +91,24 @@ class Dispatcher(Thread): ...@@ -101,18 +91,24 @@ class Dispatcher(Thread):
else: else:
addr, port = app.primary_master_node.getServer() addr, port = app.primary_master_node.getServer()
# Request Node Identification # Request Node Identification
conn = ClientConnection(app.em, handler, (addr, port)) conn = MTClientConnection(app.em, handler, (addr, port))
if app.nm.getNodeByServer((addr, port)) is None: if app.nm.getNodeByServer((addr, port)) is None:
n = MasterNode(server = (addr, port)) n = MasterNode(server = (addr, port))
app.nm.add(n) app.nm.add(n)
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE, app.uuid, p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE, app.uuid,
'0.0.0.0', 0, app.name) '0.0.0.0', 0, app.name)
# Send message # Send message
conn.addPacket(p) conn.addPacket(p)
conn.expectMessage(msg_id) conn.expectMessage(msg_id)
app.local_var.tmp_q = Queue(1) finally:
conn.unlock()
# Wait for answer # Wait for answer
while 1: while 1:
try: try:
...@@ -126,6 +122,8 @@ class Dispatcher(Thread): ...@@ -126,6 +122,8 @@ class Dispatcher(Thread):
try: try:
conn, packet = self.message.get_nowait() conn, packet = self.message.get_nowait()
method_type = packet.getType() method_type = packet.getType()
conn.lock()
try:
if method_type == PING: if method_type == PING:
# Must answer with no delay # Must answer with no delay
conn.addPacket(Packet().pong(packet.getId())) conn.addPacket(Packet().pong(packet.getId()))
...@@ -133,6 +131,8 @@ class Dispatcher(Thread): ...@@ -133,6 +131,8 @@ class Dispatcher(Thread):
else: else:
# Process message by handler # Process message by handler
conn.handler.dispatch(conn, packet) conn.handler.dispatch(conn, packet)
finally:
conn.unlock()
except Empty: except Empty:
pass pass
...@@ -156,15 +156,6 @@ class Dispatcher(Thread): ...@@ -156,15 +156,6 @@ class Dispatcher(Thread):
# Connected to primary master node # Connected to primary master node
break break
# If nothing, check if we have new message to send
try:
m = self._message_queue.get_nowait()
if m is not None:
tmp_q, msg_id, conn, p = m
conn.addPacket(p)
except Empty:
continue
logging.info("connected to primary master node %s %d" %app.primary_master_node.getServer()) logging.info("connected to primary master node %s %d" %app.primary_master_node.getServer())
app.master_conn = conn app.master_conn = conn
self.connecting_to_master_node = 0 self.connecting_to_master_node = 0
import logging import logging
from neo.handler import EventHandler from neo.handler import EventHandler
from neo.connection import ClientConnection from neo.connection import MTClientConnection
from neo.protocol import Packet, \ from neo.protocol import Packet, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \ MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
INVALID_UUID, RUNNING_STATE, TEMPORARILY_DOWN_STATE, BROKEN_STATE INVALID_UUID, RUNNING_STATE, TEMPORARILY_DOWN_STATE, BROKEN_STATE
...@@ -74,7 +74,7 @@ class ClientEventHandler(EventHandler): ...@@ -74,7 +74,7 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list) p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True) app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection # Remove from pool connection
app.cm.removeConnection(node) app.cp.removeConnection(node)
EventHandler.connectionClosed(self, conn) EventHandler.connectionClosed(self, conn)
def timeoutExpired(self, conn): def timeoutExpired(self, conn):
...@@ -98,7 +98,7 @@ class ClientEventHandler(EventHandler): ...@@ -98,7 +98,7 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list) p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True) app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection # Remove from pool connection
app.cm.removeConnection(node) app.cp.removeConnection(node)
EventHandler.timeoutExpired(self, conn) EventHandler.timeoutExpired(self, conn)
def peerBroken(self, conn): def peerBroken(self, conn):
...@@ -122,12 +122,12 @@ class ClientEventHandler(EventHandler): ...@@ -122,12 +122,12 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list) p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True) app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection # Remove from pool connection
app.cm.removeConnection(node) app.cp.removeConnection(node)
EventHandler.peerBroken(self, conn) EventHandler.peerBroken(self, conn)
def handleNotReady(self, conn, packet, message): def handleNotReady(self, conn, packet, message):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.node_not_ready = 1 app.node_not_ready = 1
else: else:
...@@ -136,7 +136,7 @@ class ClientEventHandler(EventHandler): ...@@ -136,7 +136,7 @@ class ClientEventHandler(EventHandler):
def handleAcceptNodeIdentification(self, conn, packet, node_type, def handleAcceptNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, uuid, ip_address, port,
num_partitions, num_replicas): num_partitions, num_replicas):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
node = app.nm.getNodeByServer(conn.getAddress()) node = app.nm.getNodeByServer(conn.getAddress())
# It can be eiter a master node or a storage node # It can be eiter a master node or a storage node
...@@ -162,12 +162,18 @@ class ClientEventHandler(EventHandler): ...@@ -162,12 +162,18 @@ class ClientEventHandler(EventHandler):
app.pt = PartitionTable(num_partitions, num_replicas) app.pt = PartitionTable(num_partitions, num_replicas)
app.num_partitions = num_partitions app.num_partitions = num_partitions
app.num_replicas = num_replicas app.num_replicas = num_replicas
# Ask a primary master. # Ask a primary master.
conn.lock()
try:
msg_id = conn.getNextId() msg_id = conn.getNextId()
p = Packet() p = Packet()
p.askPrimaryMaster(msg_id) p.askPrimaryMaster(msg_id)
# send message to dispatcher conn.addPacket(p)
app.queue.put((app.local_var.tmp_q, msg_id, conn, p), True) conn.expectMessage(msg_id)
app.dispatcher.register(conn, msg_id, app.getQueue())
finally:
conn.unlock()
elif node_type == STORAGE_NODE_TYPE: elif node_type == STORAGE_NODE_TYPE:
app.storage_node = node app.storage_node = node
else: else:
...@@ -176,7 +182,7 @@ class ClientEventHandler(EventHandler): ...@@ -176,7 +182,7 @@ class ClientEventHandler(EventHandler):
# Master node handler # Master node handler
def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, known_master_list): def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, known_master_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None: if uuid is None:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -220,7 +226,7 @@ class ClientEventHandler(EventHandler): ...@@ -220,7 +226,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleSendPartitionTable(self, conn, packet, ptid, row_list): def handleSendPartitionTable(self, conn, packet, ptid, row_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None: if uuid is None:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -249,7 +255,7 @@ class ClientEventHandler(EventHandler): ...@@ -249,7 +255,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleNotifyNodeInformation(self, conn, packet, node_list): def handleNotifyNodeInformation(self, conn, packet, node_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None: if uuid is None:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -309,7 +315,7 @@ class ClientEventHandler(EventHandler): ...@@ -309,7 +315,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list): def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
nm = app.nm nm = app.nm
pt = app.pt pt = app.pt
...@@ -344,14 +350,14 @@ class ClientEventHandler(EventHandler): ...@@ -344,14 +350,14 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerNewTID(self, conn, packet, tid): def handleAnswerNewTID(self, conn, packet, tid):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.tid = tid app.tid = tid
else: else:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleNotifyTransactionFinished(self, conn, packet, tid): def handleNotifyTransactionFinished(self, conn, packet, tid):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
if tid != app.tid: if tid != app.tid:
app.txn_finished = -1 app.txn_finished = -1
...@@ -361,7 +367,7 @@ class ClientEventHandler(EventHandler): ...@@ -361,7 +367,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleInvalidateObjects(self, conn, packet, oid_list): def handleInvalidateObjects(self, conn, packet, oid_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
...@@ -379,7 +385,7 @@ class ClientEventHandler(EventHandler): ...@@ -379,7 +385,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerNewOIDs(self, conn, packet, oid_list): def handleAnswerNewOIDs(self, conn, packet, oid_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.new_oid_list = oid_list app.new_oid_list = oid_list
app.new_oid_list.reverse() app.new_oid_list.reverse()
...@@ -387,7 +393,7 @@ class ClientEventHandler(EventHandler): ...@@ -387,7 +393,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleStopOperation(self, conn, packet): def handleStopOperation(self, conn, packet):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
logging.critical("master node ask to stop operation") logging.critical("master node ask to stop operation")
else: else:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -396,7 +402,7 @@ class ClientEventHandler(EventHandler): ...@@ -396,7 +402,7 @@ class ClientEventHandler(EventHandler):
# Storage node handler # Storage node handler
def handleAnswerObject(self, conn, packet, oid, start_serial, end_serial, compression, def handleAnswerObject(self, conn, packet, oid, start_serial, end_serial, compression,
checksum, data): checksum, data):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.local_var.asked_object = (oid, start_serial, end_serial, compression, app.local_var.asked_object = (oid, start_serial, end_serial, compression,
checksum, data) checksum, data)
...@@ -404,7 +410,7 @@ class ClientEventHandler(EventHandler): ...@@ -404,7 +410,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerStoreObject(self, conn, packet, conflicting, oid, serial): def handleAnswerStoreObject(self, conn, packet, conflicting, oid, serial):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
if conflicting: if conflicting:
app.txn_object_stored = -1, serial app.txn_object_stored = -1, serial
...@@ -414,14 +420,14 @@ class ClientEventHandler(EventHandler): ...@@ -414,14 +420,14 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerStoreTransaction(self, conn, packet, tid): def handleAnswerStoreTransaction(self, conn, packet, tid):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.txn_voted = 1 app.txn_voted = 1
else: else:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerTransactionInformation(self, conn, packet, tid, user, desc, oid_list): def handleAnswerTransactionInformation(self, conn, packet, tid, user, desc, oid_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# transaction information are returned as a dict # transaction information are returned as a dict
info = {} info = {}
...@@ -435,7 +441,7 @@ class ClientEventHandler(EventHandler): ...@@ -435,7 +441,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerObjectHistory(self, conn, packet, oid, history_list): def handleAnswerObjectHistory(self, conn, packet, oid, history_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# history_list is a list of tuple (serial, size) # history_list is a list of tuple (serial, size)
self.history = oid, history_list self.history = oid, history_list
...@@ -443,7 +449,7 @@ class ClientEventHandler(EventHandler): ...@@ -443,7 +449,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleOidNotFound(self, conn, packet, message): def handleOidNotFound(self, conn, packet, message):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# This can happen either when : # This can happen either when :
# - loading an object # - loading an object
...@@ -454,7 +460,7 @@ class ClientEventHandler(EventHandler): ...@@ -454,7 +460,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleTidNotFound(self, conn, packet, message): def handleTidNotFound(self, conn, packet, message):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# This can happen when requiring txn informations # This can happen when requiring txn informations
app.local_var.txn_info = -1 app.local_var.txn_info = -1
......
import socket import socket
import errno import errno
import logging import logging
from threading import RLock
from neo.protocol import Packet, ProtocolError from neo.protocol import Packet, ProtocolError
from neo.event import IdleEvent from neo.event import IdleEvent
...@@ -46,6 +47,12 @@ class BaseConnection(object): ...@@ -46,6 +47,12 @@ class BaseConnection(object):
def getUUID(self): def getUUID(self):
return None return None
def acquire(self, block = 1):
return 1
def release(self):
pass
class ListeningConnection(BaseConnection): class ListeningConnection(BaseConnection):
"""A listen connection.""" """A listen connection."""
def __init__(self, event_manager, handler, addr = None, **kw): def __init__(self, event_manager, handler, addr = None, **kw):
...@@ -305,3 +312,31 @@ class ClientConnection(Connection): ...@@ -305,3 +312,31 @@ class ClientConnection(Connection):
class ServerConnection(Connection): class ServerConnection(Connection):
"""A connection from a remote node to this node.""" """A connection from a remote node to this node."""
pass pass
class MTClientConnection(ClientConnection):
    """Client connection that can be shared between threads.

    Each instance owns a re-entrant lock.  Callers are expected to wrap
    every sequence of operations on the connection in lock()/unlock().
    """

    def __init__(self, *args, **kwargs):
        super(MTClientConnection, self).__init__(*args, **kwargs)
        # Expose the RLock's bound methods directly as instance attributes
        # named acquire/release.
        rlock = RLock()
        self.acquire, self.release = rlock.acquire, rlock.release

    def lock(self, blocking = 1):
        """Take the connection lock and return what acquire() returns.

        With a false `blocking`, the call does not wait for the lock.
        """
        acquired = self.acquire(blocking = blocking)
        return acquired

    def unlock(self):
        """Give back one level of the connection lock."""
        self.release()
class MTServerConnection(ServerConnection):
    """A Multithread-safe version of ServerConnection.

    Each instance owns a re-entrant lock.  Callers are expected to wrap
    every sequence of operations on the connection in lock()/unlock().
    """

    def __init__(self, *args, **kwargs):
        # super() must name this class: passing another, unrelated class
        # together with this instance raises TypeError at construction.
        super(MTServerConnection, self).__init__(*args, **kwargs)
        lock = RLock()
        self.acquire = lock.acquire
        self.release = lock.release

    def lock(self, blocking = 1):
        """Take the connection lock and return what acquire() returns.

        With a false `blocking`, the call does not wait for the lock.
        """
        return self.acquire(blocking = blocking)

    def unlock(self):
        """Give back one level of the connection lock."""
        self.release()
r"""This is an epoll(4) interface available in Linux 2.6. This requires
ctypes <http://python.net/crew/theller/ctypes/>."""
import platform
from ctypes import cdll, Union, Structure, \
    c_void_p, c_int, byref
try:
    from ctypes import c_uint32, c_uint64
except ImportError:
    from ctypes import c_uint, c_ulonglong
    c_uint32 = c_uint
    c_uint64 = c_ulonglong
from os import close
from errno import EINTR
# Bind the epoll entry points straight from the C library through ctypes.
libc = cdll.LoadLibrary('libc.so.6')
epoll_create = libc.epoll_create
epoll_wait = libc.epoll_wait
epoll_ctl = libc.epoll_ctl
# The C library's errno, read as `errno.value` after a call returns -1.
# NOTE(review): errno is normally per-thread in glibc; confirm that this
# symbol lookup reflects the calling thread's value on the target platform.
errno = c_int.in_dll(libc, 'errno')
# Event bits stored in / read from epoll_event.events.
EPOLLIN = 0x001
EPOLLPRI = 0x002
EPOLLOUT = 0x004
EPOLLRDNORM = 0x040
EPOLLRDBAND = 0x080
EPOLLWRNORM = 0x100
EPOLLWRBAND = 0x200
EPOLLMSG = 0x400
EPOLLERR = 0x008
EPOLLHUP = 0x010
EPOLLONESHOT = (1 << 30)
EPOLLET = (1 << 31)
# Operation codes passed as the second argument of epoll_ctl.
EPOLL_CTL_ADD = 1
EPOLL_CTL_DEL = 2
EPOLL_CTL_MOD = 3
# The C declaration of struct epoll_event is packed on x86 so that its
# layout is the same for 32-bit and 64-bit programs: the 64-bit data union
# follows the 32-bit event mask with no padding.  Other architectures use
# natural alignment, so the packing is applied on x86-family machines only.
_EPOLL_PACKED = platform.machine().lower() in (
    'x86_64', 'amd64', 'i386', 'i486', 'i586', 'i686')


class epoll_data(Union):
    """ctypes image of the C `epoll_data_t` union (user data of an event)."""
    _fields_ = [("ptr", c_void_p),
                ("fd", c_int),
                ("u32", c_uint32),
                ("u64", c_uint64)]


class epoll_event(Structure):
    """ctypes image of the C `struct epoll_event`.

    `events` is a bit mask of EPOLL* flags and `data` is the user data that
    was supplied when the descriptor was registered.
    """
    # _pack_ must be set before _fields_ for ctypes to honour it.
    if _EPOLL_PACKED:
        _pack_ = 1
    _fields_ = [("events", c_uint32),
                ("data", epoll_data)]
class Epoll(object):
    """Object wrapper around a Linux epoll descriptor, driven through ctypes.

    Descriptors are added with register(), their interest set is changed
    with modify(), and poll() reports which of them are ready.
    """

    # Class-level default so that __del__ is safe even when __init__ failed
    # before an epoll descriptor was obtained.
    efd = -1

    def __init__(self):
        """Create the epoll descriptor and the reusable event buffer.

        Raises OSError if epoll_create fails.
        """
        self.efd = epoll_create(10)
        if self.efd == -1:
            raise OSError(errno.value, 'epoll_create failed')

        self.maxevents = 1024 # XXX arbitrary
        epoll_event_array = epoll_event * self.maxevents
        self.events = epoll_event_array()

    def poll(self, timeout = 1):
        """Wait for events and return (readable_fd_list, writable_fd_list).

        `timeout` is in seconds; None means wait without a time limit.
        A descriptor reporting an error or a hang-up is listed as readable
        so that the read path notices the failure.  A descriptor that is
        ready for both reading and writing appears in both lists.

        Raises OSError if epoll_wait fails for a reason other than EINTR.
        """
        if timeout is None:
            # epoll_wait blocks indefinitely on a timeout of -1.
            timeout = -1
        else:
            # epoll_wait takes milliseconds.
            timeout = int(timeout * 1000)

        while 1:
            n = epoll_wait(self.efd, byref(self.events), self.maxevents,
                           timeout)
            if n == -1:
                e = errno.value
                if e == EINTR:
                    # Interrupted by a signal: simply wait again.
                    continue
                raise OSError(e, 'epoll_wait failed')

            readable_fd_list = []
            writable_fd_list = []
            for i in xrange(n):
                ev = self.events[i]
                fd = int(ev.data.fd)
                # The two tests are independent on purpose: with an
                # if/elif chain a descriptor that is constantly readable
                # would never be reported as writable.
                if ev.events & (EPOLLIN | EPOLLERR | EPOLLHUP):
                    readable_fd_list.append(fd)
                if ev.events & EPOLLOUT:
                    writable_fd_list.append(fd)
            return readable_fd_list, writable_fd_list

    def register(self, fd):
        """Add `fd` to the epoll set with an empty interest mask.

        Raises OSError if epoll_ctl fails.
        """
        ev = epoll_event()
        ev.data.fd = fd
        ret = epoll_ctl(self.efd, EPOLL_CTL_ADD, fd, byref(ev))
        if ret == -1:
            raise OSError(errno.value, 'epoll_ctl failed')

    def modify(self, fd, readable, writable):
        """Set the interest mask of `fd` from the two boolean flags.

        Raises OSError if epoll_ctl fails.
        """
        ev = epoll_event()
        ev.data.fd = fd
        events = 0
        if readable:
            events |= EPOLLIN
        if writable:
            events |= EPOLLOUT
        ev.events = events
        ret = epoll_ctl(self.efd, EPOLL_CTL_MOD, fd, byref(ev))
        if ret == -1:
            raise OSError(errno.value, 'epoll_ctl failed')

    def unregister(self, fd):
        """Remove `fd` from the epoll set.

        Raises OSError if epoll_ctl fails.
        """
        # An event structure is passed even though the operation is a
        # deletion, so that the call never receives a NULL pointer.
        ev = epoll_event()
        ret = epoll_ctl(self.efd, EPOLL_CTL_DEL, fd, byref(ev))
        if ret == -1:
            raise OSError(errno.value, 'epoll_ctl failed')

    def __del__(self):
        """Close the epoll descriptor, if one was successfully created."""
        if self.efd >= 0:
            close(self.efd)
...@@ -3,6 +3,7 @@ from select import select ...@@ -3,6 +3,7 @@ from select import select
from time import time from time import time
from neo.protocol import Packet from neo.protocol import Packet
from neo.epoll import Epoll
class IdleEvent(object): class IdleEvent(object):
"""This class represents an event called when a connection is waiting for """This class represents an event called when a connection is waiting for
...@@ -28,26 +29,35 @@ class IdleEvent(object): ...@@ -28,26 +29,35 @@ class IdleEvent(object):
def __call__(self, t): def __call__(self, t):
conn = self._conn conn = self._conn
if t > self._critical_time: if t > self._critical_time:
conn.lock()
try:
logging.info('timeout with %s:%d', *(conn.getAddress())) logging.info('timeout with %s:%d', *(conn.getAddress()))
conn.getHandler().timeoutExpired(conn) conn.getHandler().timeoutExpired(conn)
conn.close() conn.close()
return True return True
finally:
conn.unlock()
elif t > self._time: elif t > self._time:
conn.lock()
try:
if self._additional_timeout > 5: if self._additional_timeout > 5:
self._additional_timeout -= 5 self._additional_timeout -= 5
conn.expectMessage(self._id, 5, self._additional_timeout) conn.expectMessage(self._id, 5, self._additional_timeout)
# Start a keep-alive packet. # Start a keep-alive packet.
logging.info('sending a ping to %s:%d', *(conn.getAddress())) logging.info('sending a ping to %s:%d',
*(conn.getAddress()))
msg_id = conn.getNextId() msg_id = conn.getNextId()
conn.addPacket(Packet().ping(msg_id)) conn.addPacket(Packet().ping(msg_id))
conn.expectMessage(msg_id, 5, 0) conn.expectMessage(msg_id, 5, 0)
else: else:
conn.expectMessage(self._id, self._additional_timeout, 0) conn.expectMessage(self._id, self._additional_timeout, 0)
return True return True
finally:
conn.unlock()
return False return False
class EventManager(object): class SelectEventManager(object):
"""This class manages connections and events.""" """This class manages connections and events based on select(2)."""
def __init__(self): def __init__(self):
self.connection_dict = {} self.connection_dict = {}
...@@ -71,13 +81,21 @@ class EventManager(object): ...@@ -71,13 +81,21 @@ class EventManager(object):
timeout) timeout)
for s in rlist: for s in rlist:
conn = self.connection_dict[s] conn = self.connection_dict[s]
conn.lock()
try:
conn.readable() conn.readable()
finally:
conn.unlock()
for s in wlist: for s in wlist:
# This can fail, if a connection is closed in readable(). # This can fail, if a connection is closed in readable().
try: try:
conn = self.connection_dict[s] conn = self.connection_dict[s]
conn.lock()
try:
conn.writable() conn.writable()
finally:
conn.unlock()
except KeyError: except KeyError:
pass pass
...@@ -120,3 +138,102 @@ class EventManager(object): ...@@ -120,3 +138,102 @@ class EventManager(object):
def removeWriter(self, conn): def removeWriter(self, conn):
self.writer_set.discard(conn.getSocket()) self.writer_set.discard(conn.getSocket())
class EpollEventManager(object):
    """This class manages connections and events based on epoll(7).

    Connections are indexed by the file descriptor of their socket.
    reader_set and writer_set remember which descriptors are currently
    watched for reading and writing, so that the epoll interest mask is
    only updated when it actually changes.
    """

    def __init__(self):
        self.connection_dict = {}
        self.reader_set = set([])
        self.writer_set = set([])
        self.event_list = []
        self.prev_time = time()
        self.epoll = Epoll()

    def getConnectionList(self):
        """Return the registered connections."""
        return self.connection_dict.values()

    def register(self, conn):
        """Start managing `conn`, keyed by its socket's file descriptor."""
        fd = conn.getSocket().fileno()
        self.connection_dict[fd] = conn
        self.epoll.register(fd)

    def unregister(self, conn):
        """Stop managing `conn`.

        Raises KeyError if the connection was not registered.
        """
        fd = conn.getSocket().fileno()
        self.epoll.unregister(fd)
        del self.connection_dict[fd]
        # Forget the interest state as well: descriptor numbers are
        # reused, and a stale entry would make addReader/addWriter skip
        # the epoll update for the next connection that gets this number.
        self.reader_set.discard(fd)
        self.writer_set.discard(fd)

    def _notify(self, fd, method_name):
        """Call conn.<method_name>() under the lock of the connection
        registered for `fd`; do nothing if no connection is registered."""
        try:
            conn = self.connection_dict[fd]
        except KeyError:
            # The connection was closed while an earlier event of the
            # same batch was being processed.
            return
        conn.lock()
        try:
            getattr(conn, method_name)()
        finally:
            conn.unlock()

    def poll(self, timeout = 1):
        """Wait up to `timeout` seconds for I/O, dispatch it to the
        connections, then run the idle events that are due."""
        rlist, wlist = self.epoll.poll(timeout)
        for fd in rlist:
            self._notify(fd, 'readable')
        for fd in wlist:
            self._notify(fd, 'writable')

        # Check idle events. Do not check them out too often, because this
        # is somehow heavy.
        event_list = self.event_list
        if event_list:
            t = time()
            if t - self.prev_time >= 1:
                self.prev_time = t
                event_list.sort(key = lambda event: event.getTime())
                while event_list:
                    event = event_list[0]
                    if event(t):
                        # The event may already have removed itself from
                        # the list while it was running.
                        try:
                            event_list.remove(event)
                        except ValueError:
                            pass
                    else:
                        # The list is sorted by time, so no later event
                        # can be due either.
                        break

    def addIdleEvent(self, event):
        """Schedule `event` to be checked by poll()."""
        self.event_list.append(event)

    def removeIdleEvent(self, event):
        """Cancel `event`; unknown events are ignored."""
        try:
            self.event_list.remove(event)
        except ValueError:
            pass

    def addReader(self, conn):
        """Start watching `conn` for readability."""
        fd = conn.getSocket().fileno()
        if fd not in self.reader_set:
            self.reader_set.add(fd)
            self.epoll.modify(fd, 1, fd in self.writer_set)

    def removeReader(self, conn):
        """Stop watching `conn` for readability."""
        fd = conn.getSocket().fileno()
        if fd in self.reader_set:
            self.reader_set.remove(fd)
            self.epoll.modify(fd, 0, fd in self.writer_set)

    def addWriter(self, conn):
        """Start watching `conn` for writability."""
        fd = conn.getSocket().fileno()
        if fd not in self.writer_set:
            self.writer_set.add(fd)
            self.epoll.modify(fd, fd in self.reader_set, 1)

    def removeWriter(self, conn):
        """Stop watching `conn` for writability."""
        fd = conn.getSocket().fileno()
        if fd in self.writer_set:
            self.writer_set.remove(fd)
            self.epoll.modify(fd, fd in self.reader_set, 0)


# Default to EpollEventManager.
EventManager = EpollEventManager
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment