Commit 8c4c5df6 authored by Yoshinori Okuji's avatar Yoshinori Okuji

This is a significant rewrite of the client node, using multithreading-aware connections and epoll.

git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@156 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent a89322a4
......@@ -19,6 +19,16 @@ TODO
- Replication.
Requirements
- Python 2.4 or later
- ctypes http://python.net/crew/theller/ctypes/
- MySQLdb http://sourceforge.net/projects/mysql-python
- Zope 2.8 or later
Installation
1. In zope:
......
......@@ -28,22 +28,19 @@ class NEOStorage(BaseStorage.BaseStorage,
l = Lock()
self._txn_lock_acquire = l.acquire
self._txn_lock_release = l.release
# Create two queue for message between thread and dispatcher
# - message queue is for message that has to be send to other node
# through the dispatcher
# Create a queue for message between thread and dispatcher
# - request queue is for message receive from other node which have to
# be processed
message_queue = Queue()
request_queue = Queue()
# Create the event manager
em = EventManager()
# Create dispatcher thread
dispatcher = Dispatcher(em, message_queue, request_queue)
dispatcher = Dispatcher(em, request_queue)
dispatcher.setDaemon(True)
# Import here to prevent recursive import
from neo.client.app import Application
self.app = Application(master_nodes, name, em, dispatcher,
message_queue, request_queue)
request_queue)
# Connect to primary master node
dispatcher.connectToPrimaryMasterNode(self.app)
# Start dispatcher
......
import logging
import os
from time import time
from threading import Lock, local
from cPickle import dumps, loads
from zlib import compress, decompress
......@@ -9,7 +8,7 @@ from random import shuffle
from neo.client.mq import MQ
from neo.node import NodeManager, MasterNode, StorageNode
from neo.connection import ListeningConnection, ClientConnection
from neo.connection import MTClientConnection
from neo.protocol import Packet, INVALID_UUID, INVALID_TID, \
STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
TEMPORARILY_DOWN_STATE, \
......@@ -23,17 +22,17 @@ from neo.util import makeChecksum, dump
from ZODB.POSException import UndoError, StorageTransactionError, ConflictError
from ZODB.utils import p64, u64, oid_repr
class ConnectionManager(object):
"""This class manage a pool of connection to storage node."""
class ConnectionPool(object):
"""This class manages a pool of connections to storage nodes."""
def __init__(self, storage, pool_size=25):
self.storage = storage
def __init__(self, app, pool_size = 25):
self.app = app
self.pool_size = 0
self.max_pool_size = pool_size
self.connection_dict = {}
# define a lock in order to create one connection to
# a storage node at a time to avoid multiple connection
# to the same node
# Define a lock in order to create one connection to
# a storage node at a time to avoid multiple connections
# to the same node.
l = Lock()
self.connection_lock_acquire = l.acquire
self.connection_lock_release = l.release
......@@ -43,33 +42,45 @@ class ConnectionManager(object):
addr = node.getNode().getServer()
if addr is None:
return None
handler = ClientEventHandler(self.storage, self.storage.dispatcher)
conn = ClientConnection(self.storage.em, handler, addr)
handler = ClientEventHandler(self.app, self.app.dispatcher)
conn = MTClientConnection(self.app.em, handler, addr)
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE, self.storage.uuid, addr[0],
addr[1], self.storage.name)
self.storage.local_var.tmp_q = Queue(1)
self.storage.queue.put((self.storage.local_var.tmp_q, msg_id, conn, p), True)
self.storage.local_var.storage_node = None
self.storage._waitMessage()
if self.storage.storage_node == -1:
p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE,
self.app.uuid, addr[0],
addr[1], self.app.name)
conn.addPacket(p)
conn.expectMessage(msg_id)
self.app.dispatcher.register(conn, msg_id, self.app.getQueue())
self.app.local_var.storage_node = None
finally:
conn.unlock()
self.app._waitMessage(conn, msg_id)
if self.app.storage_node == -1:
# Connection failed, notify primary master node
logging.error('Connection to storage node %s failed' %(addr,))
conn = self.storage.master_conn
conn = self.app.master_conn
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
node_list = [(STORAGE_NODE_TYPE, addr[0], addr[1], node.getUUID(),
TEMPORARILY_DOWN_STATE),]
node_list = [(STORAGE_NODE_TYPE, addr[0], addr[1],
node.getUUID(), TEMPORARILY_DOWN_STATE),]
p.notifyNodeInformation(msg_id, node_list)
self.storage.queue.put((None, msg_id, conn, p), True)
conn.addPacket(p)
finally:
conn.unlock()
return None
logging.info('connected to storage node %s' %(addr,))
return conn
def _dropConnection(self,):
def _dropConnection(self):
"""Drop a connection."""
pass
raise NotImplementedError
def _createNodeConnection(self, node):
"""Create a connection to a given storage node."""
......@@ -81,14 +92,14 @@ class ConnectionManager(object):
return self.connection_dict[node.getUUID()]
if self.pool_size > self.max_pool_size:
# must drop some unused connections
self.dropConnection()
self._dropConnection()
conn = self._initNodeConnection(node)
if conn is None:
return None
# add node to node manager
if self.storage.nm.getNodeByServer(node.getServer()) is None:
if self.app.nm.getNodeByServer(node.getServer()) is None:
n = StorageNode(node.getServer())
self.storage.nm.add(n)
self.app.nm.add(n)
self.connection_dict[node.getUUID()] = conn
return conn
finally:
......@@ -113,8 +124,7 @@ class ConnectionManager(object):
class Application(ThreadingMixIn, object):
"""The client node application."""
def __init__(self, master_nodes, name, em, dispatcher, message_queue,
request_queue, **kw):
def __init__(self, master_nodes, name, em, dispatcher, request_queue, **kw):
logging.basicConfig(level = logging.DEBUG)
logging.debug('master node address are %s' %(master_nodes,))
# Internal Attributes common to all thread
......@@ -122,9 +132,8 @@ class Application(ThreadingMixIn, object):
self.em = em
self.dispatcher = dispatcher
self.nm = NodeManager()
self.cm = ConnectionManager(self)
self.cp = ConnectionPool(self)
self.pt = None
self.queue = message_queue
self.request_queue = request_queue
self.primary_master_node = None
self.master_node_list = master_nodes.split(' ')
......@@ -167,32 +176,35 @@ class Application(ThreadingMixIn, object):
break
self.uuid = uuid
def _waitMessage(self,block=1):
"""Wait for a message returned by dispatcher in queues."""
# First check if there are global messages and execute them
global_message = None
def getQueue(self):
return self.local_var.__dict__.setdefault('queue', Queue(5))
def _waitMessage(self, target_conn = None, msg_id = None):
"""Wait for a message returned by the dispatcher in queues."""
global_queue = self.request_queue
local_queue = self.getQueue()
while 1:
try:
global_message = self.request_queue.get_nowait()
conn, packet = global_queue.get_nowait()
conn.handler.dispatch(conn, packet)
except Empty:
if msg_id is None:
try:
conn, packet = local_queue.get_nowait()
except Empty:
break
if global_message is not None:
global_message[0].handler.dispatch(global_message[0], global_message[1])
# Next get messages we are waiting for
if not hasattr(self.local_var, 'tmp_q'):
return
message = None
if block:
message = self.local_var.tmp_q.get(True, None)
else:
# we don't want to block until we got a message
conn, packet = local_queue.get()
conn.lock()
try:
message = self.local_var.tmp_q.get_nowait()
except Empty:
pass
if message is not None:
message[0].handler.dispatch(message[0], message[1])
conn.handler.dispatch(conn, packet)
finally:
conn.unlock()
if target_conn is conn and msg_id == packet.getId():
break
def registerDB(self, db, limit):
self._db = db
......@@ -207,12 +219,18 @@ class Application(ThreadingMixIn, object):
# from asking too many time new oid one by one
# from master node
conn = self.master_conn
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askNewOIDs(msg_id, 25)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
self._waitMessage()
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
finally:
conn.unlock()
self._waitMessage(conn, msg_id)
if len(self.new_oid_list) <= 0:
raise NEOStorageError('new_oid failed')
return self.new_oid_list.pop()
......@@ -229,7 +247,7 @@ class Application(ThreadingMixIn, object):
finally:
self._cache_lock_release()
# history return serial, so use it
hist = self.history(oid, length=1, object_only=1)
hist = self.history(oid, length = 1, object_only = 1)
if len(hist) == 0:
raise NEOStorageNotFoundError()
if hist[0] != oid:
......@@ -237,7 +255,7 @@ class Application(ThreadingMixIn, object):
return hist[1][0][0]
def _load(self, oid, serial=INVALID_TID, tid=INVALID_TID, cache=0):
def _load(self, oid, serial = INVALID_TID, tid = INVALID_TID, cache = 0):
"""Internal method which manage load ,loadSerial and loadBefore."""
partition_id = u64(oid) % self.num_partitions
# Only used up to date node for retrieving object
......@@ -253,21 +271,27 @@ class Application(ThreadingMixIn, object):
for cell in cell_list:
logging.debug('trying to load %s from %s',
dump(oid), dump(cell.getUUID()))
conn = self.cm.getConnForNode(cell)
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askObject(msg_id, oid, serial, tid)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.asked_object = 0
finally:
conn.unlock()
# Wait for answer
self.local_var.asked_object = 0
# asked object retured value are :
# -1 : oid not found
# other : data
self._waitMessage()
self._waitMessage(conn, msg_id)
if self.local_var.asked_object == -1:
# OID not found
# XXX either try with another node, either raise error here
......@@ -275,16 +299,17 @@ class Application(ThreadingMixIn, object):
continue
# Check data
noid, start_serial, end_serial, compression, checksum, data = self.local_var.asked_object
noid, start_serial, end_serial, compression, checksum, data \
= self.local_var.asked_object
if noid != oid:
# Oops, try with next node
logging.error('got wrong oid %s instead of %s from node %s' \
%(noid, oid, storage_node.getServer()))
% (noid, oid, cell.getServer()))
continue
elif checksum != makeChecksum(data):
# Check checksum.
logging.error('wrong checksum from node %s for oid %s' \
%(storage_node.getServer(), oid))
% (cell.getServer(), oid))
continue
else:
# Everything looks alright.
......@@ -352,13 +377,19 @@ class Application(ThreadingMixIn, object):
if tid is None:
self.tid = None
conn = self.master_conn
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askNewTID(msg_id)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
finally:
conn.unlock()
# Wait for answer
self._waitMessage()
self._waitMessage(conn, msg_id)
if self.tid is None:
raise NEOStorageError('tpc_begin failed')
else:
......@@ -375,27 +406,34 @@ class Application(ThreadingMixIn, object):
dump(oid), dump(serial))
# Find which storage node to use
partition_id = u64(oid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, False)
if len(storage_node_list) == 0:
cell_list = self.pt.getCellList(partition_id, False)
if len(cell_list) == 0:
# FIXME must wait for cluster to be ready
raise NEOStorageError
# Store data on each node
ddata = dumps(data)
compressed_data = compress(ddata)
checksum = makeChecksum(compressed_data)
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node)
for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askStoreObject(msg_id, oid, serial, 1, checksum, compressed_data, self.tid)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
p.askStoreObject(msg_id, oid, serial, 1,
checksum, compressed_data, self.tid)
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.txn_object_stored = 0
finally:
conn.unlock()
# Check we don't get any conflict
self.txn_object_stored = 0
self._waitMessage()
self._waitMessage(conn, msg_id)
if self.txn_object_stored[0] == -1:
if self.txn_data_dict.has_key(oid):
# One storage already accept the object, is it normal ??
......@@ -423,22 +461,29 @@ class Application(ThreadingMixIn, object):
oid_list = self.txn_data_dict.keys()
# Store data on each node
partition_id = u64(self.tid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node)
cell_list = self.pt.getCellList(partition_id, True)
for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askStoreTransaction(msg_id, self.tid, user, desc, ext, oid_list)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
p.askStoreTransaction(msg_id, self.tid, user, desc, ext,
oid_list)
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.txn_voted == 0
self._waitMessage()
finally:
conn.unlock()
self._waitMessage(conn, msg_id)
if self.txn_voted != 1:
raise NEOStorageError('tpc_vote failed')
def _clear_txn(self):
"""Clear some transaction parameters."""
self.tid = None
......@@ -447,61 +492,80 @@ class Application(ThreadingMixIn, object):
self.txn_voted = 0
self.txn_finished = 0
def tpc_abort(self, transaction):
"""Abort current transaction."""
if transaction is not self.txn:
return
# Abort txn in node where objects were stored
aborted_node = {}
aborted_node_set = set()
for oid in self.txn_data_dict.iterkeys():
partition_id = u64(oid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list:
if not aborted_node.has_key(storage_node):
conn = self.cm.getConnForNode(storage_node)
cell_list = self.pt.getCellList(partition_id, True)
for cell in cell_list:
if cell.getNode() not in aborted_node_set:
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.abortTransaction(msg_id, self.tid)
self.queue.put((None, msg_id, conn, p), True)
aborted_node[storage_node] = 1
conn.addPacket(p)
finally:
conn.unlock()
aborted_node_set.add(cell.getNode())
# Abort in nodes where transaction was stored
partition_id = u64(self.tid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list:
if not aborted_node.has_key(storage_node):
conn = self.cm.getConnForNode(storage_node)
cell_list = self.pt.getCellList(partition_id, True)
for cell in cell_list:
if cell.getNode() not in aborted_node_set:
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.abortTransaction(msg_id, self.tid)
self.queue.put((None, msg_id, conn, p), True)
conn.addPacket(p)
finally:
conn.unlock()
self._clear_txn()
aborted_node_set.add(cell.getNode())
self._clear_txn()
def tpc_finish(self, transaction, f=None):
"""Finish current transaction."""
if self.txn is not transaction:
return
# Call function given by ZODB
if f is not None:
f(self.tid)
# Call finish on master
oid_list = self.txn_data_dict.keys()
conn = self.master_conn
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.finishTransaction(msg_id, oid_list, self.tid)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
conn.addPacket(p)
conn.expectMessage(msg_id, additional_timeout = 300)
self.dispatcher.register(conn, msg_id, self.getQueue())
finally:
conn.unlock()
# Wait for answer
self._waitMessage()
self._waitMessage(conn, msg_id)
if self.txn_finished != 1:
raise NEOStorageError('tpc_finish failed')
......@@ -522,25 +586,33 @@ class Application(ThreadingMixIn, object):
if transaction_id is not self.txn:
raise StorageTransactionError(self, transaction_id)
# First get transaction information from master node
# First get transaction information from a storage node.
partition_id = u64(transaction_id) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node)
cell_list = self.pt.getCellList(partition_id, True)
shuffle(cell_list)
for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askTransactionInformation(msg_id, transaction_id)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
# Wait for answer
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.txn_info = 0
self._waitMessage()
finally:
conn.unlock()
# Wait for answer
self._waitMessage(conn, msg_id)
if self.local_var.txn_info == -1:
# Tid not found, try with next node
continue
elif isinstance(self.local_var.txn_info, {}):
elif isinstance(self.local_var.txn_info, dict):
break
else:
raise NEOStorageError('undo failed')
......@@ -589,21 +661,26 @@ class Application(ThreadingMixIn, object):
# See FileStorage.py for explanation
last = first - last
# First get list of transaction from all storage node
# First get a list of transactions from all storage nodes.
storage_node_list = [x for x in self.pt.getNodeList() if x.getState() \
in (UP_TO_DATE_STATE, FEEDING_STATE)]
self.local_var.node_tids = {}
self.local_var.tmp_q = Queue(len(storage_node_list))
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node)
conn = self.cp.getConnForNode(storage_node)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askTIDs(msg_id, first, last)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
conn.addPacket(p)
finally:
conn.unlock()
# Wait for answer from all storages
# Wait for answers from all storages.
# FIXME this is a busy loop.
while True:
self._waitMessage()
if len(self.local_var.node_tids) == len(storage_node_list):
......@@ -620,23 +697,31 @@ class Application(ThreadingMixIn, object):
undo_info = []
for tid in ordered_tids:
partition_id = u64(tid) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node)
cell_list = self.pt.getCellList(partition_id, True)
shuffle(cell_list)
for cell in cell_list:
conn = self.cp.getConnForNode(storage_node)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askTransactionInformation(msg_id, tid)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
# Wait for answer
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.txn_info = 0
self._waitMessage()
finally:
conn.unlock()
# Wait for answer
self._waitMessage(conn, msg_id)
if self.local_var.txn_info == -1:
# TID not found, go on with next node
continue
elif isinstance(self.local_var.txn_info, {}):
elif isinstance(self.local_var.txn_info, dict):
break
# Filter result if needed
......@@ -648,7 +733,7 @@ class Application(ThreadingMixIn, object):
# Append to returned list
self.local_var.txn_info.pop("oids")
undo_info.append(self.local_var.txn_info)
if len(undo_info) >= last-first:
if len(undo_info) >= last - first:
break
return undo_info
......@@ -657,26 +742,35 @@ class Application(ThreadingMixIn, object):
def history(self, oid, version, length=1, filter=None, object_only=0):
# Get history informations for object first
partition_id = u64(oid) % self.num_partitions
storage_node_list = [x for x in self.pt.getCellList(partition_id, True) \
if x.getState() == UP_TO_DATE_STATE]
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node)
cell_list = self.pt.getCellList(partition_id, True)
shuffle(cell_list)
for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askObjectHistory(msg_id, oid, length)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.history = None
self._waitMessage()
finally:
conn.unlock()
self._waitMessage(conn, msg_id)
if self.local_var.history == -1:
# Not found, go on with next node
continue
if self.local_var.history[0] != oid:
# Got history for wrong oid
continue
if not isinstance(self.local_var.history, {}):
if not isinstance(self.local_var.history, dict):
raise NEOStorageError('history failed')
if object_only:
# Use by getSerial
......@@ -686,23 +780,32 @@ class Application(ThreadingMixIn, object):
history_list = []
for serial, size in self.local_var.hisory[1]:
partition_id = u64(serial) % self.num_partitions
storage_node_list = self.pt.getCellList(partition_id, True)
for storage_node in storage_node_list:
conn = self.cm.getConnForNode(storage_node)
cell_list = self.pt.getCellList(partition_id, True)
shuffle(cell_list)
for cell in cell_list:
conn = self.cp.getConnForNode(cell)
if conn is None:
continue
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askTransactionInformation(msg_id, serial)
self.local_var.tmp_q = Queue(1)
self.queue.put((self.local_var.tmp_q, msg_id, conn, p), True)
# Wait for answer
conn.addPacket(p)
conn.expectMessage(msg_id)
self.dispatcher.register(conn, msg_id, self.getQueue())
self.local_var.txn_info = None
self._waitMessage()
finally:
conn.unlock()
# Wait for answer
self._waitMessage(conn, msg_id)
if self.local_var.txn_info == -1:
# TID not found
continue
if isinstance(self.local_var.txn_info, {}):
if isinstance(self.local_var.txn_info, dict):
break
# create history dict
......
......@@ -2,7 +2,7 @@ from threading import Thread
from Queue import Empty, Queue
from neo.protocol import PING, Packet, CLIENT_NODE_TYPE, FINISH_TRANSACTION
from neo.connection import ClientConnection
from neo.connection import MTClientConnection
from neo.node import MasterNode
from time import time
......@@ -11,9 +11,8 @@ import logging
class Dispatcher(Thread):
"""Dispatcher class use to redirect request to thread."""
def __init__(self, em, message_queue, request_queue, **kw):
def __init__(self, em, request_queue, **kw):
Thread.__init__(self, **kw)
self._message_queue = message_queue
self._request_queue = request_queue
self.em = em
# Queue of received packet that have to be processed
......@@ -29,51 +28,42 @@ class Dispatcher(Thread):
# First check if we receive any new message from other node
m = None
try:
self.em.poll(0.02)
self.em.poll()
except KeyError:
# This happen when there is no connection
logging.error('Dispatcher, run, poll returned a KeyError')
while 1:
try:
conn, packet = self.message.get_nowait()
except Empty:
break
# Send message to waiting thread
key = "%s-%s" %(conn.getUUID(),packet.getId())
key = (conn.getUUID(), packet.getId())
#logging.info('dispatcher got packet %s' %(key,))
if self.message_table.has_key(key):
tmp_q = self.message_table.pop(key)
tmp_q.put((conn, packet), True)
if key in self.message_table:
queue = self.message_table.pop(key)
queue.put((conn, packet))
else:
#conn, packet = self.message
method_type = packet.getType()
if method_type == PING:
# must answer with no delay
conn.lock()
try:
conn.addPacket(Packet().pong(packet.getId()))
finally:
conn.unlock()
else:
# put message in request queue
self._request_queue.put((conn, packet), True)
self._request_queue.put((conn, packet))
# Then check if a client ask me to send a message
try:
m = self._message_queue.get_nowait()
if m is not None:
tmp_q, msg_id, conn, p = m
conn.addPacket(p)
if tmp_q is not None:
# We expect an answer
key = "%s-%s" %(conn.getUUID(), msg_id)
self.message_table[key] = tmp_q
# XXX this is a hack. Probably queued tasks themselves
# should specify the timeout values.
if p.getType() == FINISH_TRANSACTION:
# Finish Transaction may take a lot of time when
# many objects are committed at a time.
conn.expectMessage(msg_id, additional_timeout = 300)
else:
conn.expectMessage(msg_id)
except Empty:
continue
def register(self, conn, msg_id, queue):
"""Register an expectation for a reply. Thanks to GIL, it is
safe not to use a lock here."""
key = (conn.getUUID(), msg_id)
self.message_table[key] = queue
def connectToPrimaryMasterNode(self, app):
"""Connect to a primary master node.
......@@ -87,7 +77,7 @@ class Dispatcher(Thread):
master_index = 0
conn = None
# Make application execute remaining message if any
app._waitMessage(block=0)
app._waitMessage()
handler = ClientEventHandler(app, app.dispatcher)
while 1:
if app.pt is not None and app.pt.operational():
......@@ -101,18 +91,24 @@ class Dispatcher(Thread):
else:
addr, port = app.primary_master_node.getServer()
# Request Node Identification
conn = ClientConnection(app.em, handler, (addr, port))
conn = MTClientConnection(app.em, handler, (addr, port))
if app.nm.getNodeByServer((addr, port)) is None:
n = MasterNode(server = (addr, port))
app.nm.add(n)
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE, app.uuid,
'0.0.0.0', 0, app.name)
# Send message
conn.addPacket(p)
conn.expectMessage(msg_id)
app.local_var.tmp_q = Queue(1)
finally:
conn.unlock()
# Wait for answer
while 1:
try:
......@@ -126,6 +122,8 @@ class Dispatcher(Thread):
try:
conn, packet = self.message.get_nowait()
method_type = packet.getType()
conn.lock()
try:
if method_type == PING:
# Must answer with no delay
conn.addPacket(Packet().pong(packet.getId()))
......@@ -133,6 +131,8 @@ class Dispatcher(Thread):
else:
# Process message by handler
conn.handler.dispatch(conn, packet)
finally:
conn.unlock()
except Empty:
pass
......@@ -156,15 +156,6 @@ class Dispatcher(Thread):
# Connected to primary master node
break
# If nothing, check if we have new message to send
try:
m = self._message_queue.get_nowait()
if m is not None:
tmp_q, msg_id, conn, p = m
conn.addPacket(p)
except Empty:
continue
logging.info("connected to primary master node %s %d" %app.primary_master_node.getServer())
app.master_conn = conn
self.connecting_to_master_node = 0
import logging
from neo.handler import EventHandler
from neo.connection import ClientConnection
from neo.connection import MTClientConnection
from neo.protocol import Packet, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
INVALID_UUID, RUNNING_STATE, TEMPORARILY_DOWN_STATE, BROKEN_STATE
......@@ -74,7 +74,7 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection
app.cm.removeConnection(node)
app.cp.removeConnection(node)
EventHandler.connectionClosed(self, conn)
def timeoutExpired(self, conn):
......@@ -98,7 +98,7 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection
app.cm.removeConnection(node)
app.cp.removeConnection(node)
EventHandler.timeoutExpired(self, conn)
def peerBroken(self, conn):
......@@ -122,12 +122,12 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection
app.cm.removeConnection(node)
app.cp.removeConnection(node)
EventHandler.peerBroken(self, conn)
def handleNotReady(self, conn, packet, message):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
app.node_not_ready = 1
else:
......@@ -136,7 +136,7 @@ class ClientEventHandler(EventHandler):
def handleAcceptNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port,
num_partitions, num_replicas):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
node = app.nm.getNodeByServer(conn.getAddress())
# It can be eiter a master node or a storage node
......@@ -162,12 +162,18 @@ class ClientEventHandler(EventHandler):
app.pt = PartitionTable(num_partitions, num_replicas)
app.num_partitions = num_partitions
app.num_replicas = num_replicas
# Ask a primary master.
conn.lock()
try:
msg_id = conn.getNextId()
p = Packet()
p.askPrimaryMaster(msg_id)
# send message to dispatcher
app.queue.put((app.local_var.tmp_q, msg_id, conn, p), True)
conn.addPacket(p)
conn.expectMessage(msg_id)
app.dispatcher.register(conn, msg_id, app.getQueue())
finally:
conn.unlock()
elif node_type == STORAGE_NODE_TYPE:
app.storage_node = node
else:
......@@ -176,7 +182,7 @@ class ClientEventHandler(EventHandler):
# Master node handler
def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, known_master_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
......@@ -220,7 +226,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleSendPartitionTable(self, conn, packet, ptid, row_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
......@@ -249,7 +255,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleNotifyNodeInformation(self, conn, packet, node_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
......@@ -309,7 +315,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
nm = app.nm
pt = app.pt
......@@ -344,14 +350,14 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleAnswerNewTID(self, conn, packet, tid):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
app.tid = tid
else:
self.handleUnexpectedPacket(conn, packet)
def handleNotifyTransactionFinished(self, conn, packet, tid):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
if tid != app.tid:
app.txn_finished = -1
......@@ -361,7 +367,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleInvalidateObjects(self, conn, packet, oid_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
app._cache_lock_acquire()
try:
......@@ -379,7 +385,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleAnswerNewOIDs(self, conn, packet, oid_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
app.new_oid_list = oid_list
app.new_oid_list.reverse()
......@@ -387,7 +393,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleStopOperation(self, conn, packet):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
logging.critical("master node ask to stop operation")
else:
self.handleUnexpectedPacket(conn, packet)
......@@ -396,7 +402,7 @@ class ClientEventHandler(EventHandler):
# Storage node handler
def handleAnswerObject(self, conn, packet, oid, start_serial, end_serial, compression,
checksum, data):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
app.local_var.asked_object = (oid, start_serial, end_serial, compression,
checksum, data)
......@@ -404,7 +410,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleAnswerStoreObject(self, conn, packet, conflicting, oid, serial):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
if conflicting:
app.txn_object_stored = -1, serial
......@@ -414,14 +420,14 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleAnswerStoreTransaction(self, conn, packet, tid):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
app.txn_voted = 1
else:
self.handleUnexpectedPacket(conn, packet)
def handleAnswerTransactionInformation(self, conn, packet, tid, user, desc, oid_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
# transaction information are returned as a dict
info = {}
......@@ -435,7 +441,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleAnswerObjectHistory(self, conn, packet, oid, history_list):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
# history_list is a list of tuple (serial, size)
self.history = oid, history_list
......@@ -443,7 +449,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleOidNotFound(self, conn, packet, message):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
# This can happen either when :
# - loading an object
......@@ -454,7 +460,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet)
def handleTidNotFound(self, conn, packet, message):
if isinstance(conn, ClientConnection):
if isinstance(conn, MTClientConnection):
app = self.app
# This can happen when requiring txn informations
app.local_var.txn_info = -1
......
import socket
import errno
import logging
from threading import RLock
from neo.protocol import Packet, ProtocolError
from neo.event import IdleEvent
......@@ -46,6 +47,12 @@ class BaseConnection(object):
def getUUID(self):
return None
def acquire(self, block = 1):
return 1
def release(self):
pass
class ListeningConnection(BaseConnection):
"""A listen connection."""
def __init__(self, event_manager, handler, addr = None, **kw):
......@@ -305,3 +312,31 @@ class ClientConnection(Connection):
class ServerConnection(Connection):
"""A connection from a remote node to this node."""
pass
class MTClientConnection(ClientConnection):
"""A Multithread-safe version of ClientConnection."""
def __init__(self, *args, **kwargs):
super(MTClientConnection, self).__init__(*args, **kwargs)
lock = RLock()
self.acquire = lock.acquire
self.release = lock.release
def lock(self, blocking = 1):
return self.acquire(blocking = blocking)
def unlock(self):
self.release()
class MTServerConnection(ServerConnection):
"""A Multithread-safe version of ServerConnection."""
def __init__(self, *args, **kwargs):
super(MTClientConnection, self).__init__(*args, **kwargs)
lock = RLock()
self.acquire = lock.acquire
self.release = lock.release
def lock(self, blocking = 1):
return self.acquire(blocking = blocking)
def unlock(self):
self.release()
r"""This is an epoll(4) interface available in Linux 2.6. This requires
ctypes <http://python.net/crew/theller/ctypes/>."""
from ctypes import cdll, Union, Structure, \
c_void_p, c_int, byref
try:
from ctypes import c_uint32, c_uint64
except ImportError:
from ctypes import c_uint, c_ulonglong
c_uint32 = c_uint
c_uint64 = c_ulonglong
from os import close
from errno import EINTR
libc = cdll.LoadLibrary('libc.so.6')
epoll_create = libc.epoll_create
epoll_wait = libc.epoll_wait
epoll_ctl = libc.epoll_ctl
errno = c_int.in_dll(libc, 'errno')
EPOLLIN = 0x001
EPOLLPRI = 0x002
EPOLLOUT = 0x004
EPOLLRDNORM = 0x040
EPOLLRDBAND = 0x080
EPOLLWRNORM = 0x100
EPOLLWRBAND = 0x200
EPOLLMSG = 0x400
EPOLLERR = 0x008
EPOLLHUP = 0x010
EPOLLONESHOT = (1 << 30)
EPOLLET = (1 << 31)
EPOLL_CTL_ADD = 1
EPOLL_CTL_DEL = 2
EPOLL_CTL_MOD = 3
class epoll_data(Union):
_fields_ = [("ptr", c_void_p),
("fd", c_int),
("u32", c_uint32),
("u64", c_uint64)]
class epoll_event(Structure):
_fields_ = [("events", c_uint32),
("data", epoll_data)]
class Epoll(object):
efd = -1
def __init__(self):
self.efd = epoll_create(10)
if self.efd == -1:
raise OSError(errno.value, 'epoll_create failed')
self.maxevents = 1024 # XXX arbitrary
epoll_event_array = epoll_event * self.maxevents
self.events = epoll_event_array()
def poll(self, timeout = 1):
timeout *= 1000
timeout = int(timeout)
while 1:
n = epoll_wait(self.efd, byref(self.events), self.maxevents,
timeout)
if n == -1:
e = errno.value
if e == EINTR:
continue
else:
raise OSError(e, 'epoll_wait failed')
else:
readable_fd_list = []
writable_fd_list = []
for i in xrange(n):
ev = self.events[i]
if ev.events & (EPOLLIN | EPOLLERR | EPOLLHUP):
readable_fd_list.append(int(ev.data.fd))
elif ev.events & (EPOLLOUT | EPOLLERR | EPOLLHUP):
writable_fd_list.append(int(ev.data.fd))
return readable_fd_list, writable_fd_list
def register(self, fd):
ev = epoll_event()
ev.data.fd = fd
ret = epoll_ctl(self.efd, EPOLL_CTL_ADD, fd, byref(ev))
if ret == -1:
raise OSError(errno.value, 'epoll_ctl failed')
def modify(self, fd, readable, writable):
ev = epoll_event()
ev.data.fd = fd
events = 0
if readable:
events |= EPOLLIN
if writable:
events |= EPOLLOUT
ev.events = events
ret = epoll_ctl(self.efd, EPOLL_CTL_MOD, fd, byref(ev))
if ret == -1:
raise OSError(errno.value, 'epoll_ctl failed')
def unregister(self, fd):
ev = epoll_event()
ret = epoll_ctl(self.efd, EPOLL_CTL_DEL, fd, byref(ev))
if ret == -1:
raise OSError(errno.value, 'epoll_ctl failed')
def __del__(self):
if self.efd >= 0:
close(self.efd)
......@@ -3,6 +3,7 @@ from select import select
from time import time
from neo.protocol import Packet
from neo.epoll import Epoll
class IdleEvent(object):
"""This class represents an event called when a connection is waiting for
......@@ -28,26 +29,35 @@ class IdleEvent(object):
def __call__(self, t):
conn = self._conn
if t > self._critical_time:
conn.lock()
try:
logging.info('timeout with %s:%d', *(conn.getAddress()))
conn.getHandler().timeoutExpired(conn)
conn.close()
return True
finally:
conn.unlock()
elif t > self._time:
conn.lock()
try:
if self._additional_timeout > 5:
self._additional_timeout -= 5
conn.expectMessage(self._id, 5, self._additional_timeout)
# Start a keep-alive packet.
logging.info('sending a ping to %s:%d', *(conn.getAddress()))
logging.info('sending a ping to %s:%d',
*(conn.getAddress()))
msg_id = conn.getNextId()
conn.addPacket(Packet().ping(msg_id))
conn.expectMessage(msg_id, 5, 0)
else:
conn.expectMessage(self._id, self._additional_timeout, 0)
return True
finally:
conn.unlock()
return False
class EventManager(object):
"""This class manages connections and events."""
class SelectEventManager(object):
"""This class manages connections and events based on select(2)."""
def __init__(self):
self.connection_dict = {}
......@@ -71,13 +81,21 @@ class EventManager(object):
timeout)
for s in rlist:
conn = self.connection_dict[s]
conn.lock()
try:
conn.readable()
finally:
conn.unlock()
for s in wlist:
# This can fail, if a connection is closed in readable().
try:
conn = self.connection_dict[s]
conn.lock()
try:
conn.writable()
finally:
conn.unlock()
except KeyError:
pass
......@@ -120,3 +138,102 @@ class EventManager(object):
def removeWriter(self, conn):
self.writer_set.discard(conn.getSocket())
class EpollEventManager(object):
"""This class manages connections and events based on epoll(5)."""
def __init__(self):
self.connection_dict = {}
self.reader_set = set([])
self.writer_set = set([])
self.event_list = []
self.prev_time = time()
self.epoll = Epoll()
def getConnectionList(self):
return self.connection_dict.values()
def register(self, conn):
fd = conn.getSocket().fileno()
self.connection_dict[fd] = conn
self.epoll.register(fd)
def unregister(self, conn):
fd = conn.getSocket().fileno()
self.epoll.unregister(fd)
del self.connection_dict[fd]
def poll(self, timeout = 1):
rlist, wlist = self.epoll.poll(timeout)
for fd in rlist:
conn = self.connection_dict[fd]
conn.lock()
try:
conn.readable()
finally:
conn.unlock()
for fd in wlist:
# This can fail, if a connection is closed in readable().
try:
conn = self.connection_dict[fd]
conn.lock()
try:
conn.writable()
finally:
conn.unlock()
except KeyError:
pass
# Check idle events. Do not check them out too often, because this
# is somehow heavy.
event_list = self.event_list
if event_list:
t = time()
if t - self.prev_time >= 1:
self.prev_time = t
event_list.sort(key = lambda event: event.getTime())
while event_list:
event = event_list[0]
if event(t):
try:
event_list.remove(event)
except ValueError:
pass
else:
break
def addIdleEvent(self, event):
self.event_list.append(event)
def removeIdleEvent(self, event):
try:
self.event_list.remove(event)
except ValueError:
pass
def addReader(self, conn):
fd = conn.getSocket().fileno()
if fd not in self.reader_set:
self.reader_set.add(fd)
self.epoll.modify(fd, 1, fd in self.writer_set)
def removeReader(self, conn):
fd = conn.getSocket().fileno()
if fd in self.reader_set:
self.reader_set.remove(fd)
self.epoll.modify(fd, 0, fd in self.writer_set)
def addWriter(self, conn):
fd = conn.getSocket().fileno()
if fd not in self.writer_set:
self.writer_set.add(fd)
self.epoll.modify(fd, fd in self.reader_set, 1)
def removeWriter(self, conn):
fd = conn.getSocket().fileno()
if fd in self.writer_set:
self.writer_set.remove(fd)
self.epoll.modify(fd, fd in self.reader_set, 0)
# Default to EpollEventManager.
EventManager = EpollEventManager
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment