Commit 8c4c5df6 authored by Yoshinori Okuji's avatar Yoshinori Okuji

This is a significant rewrite of the client node, using multithreading-aware connections and epoll.

git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@156 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent a89322a4
...@@ -19,6 +19,16 @@ TODO ...@@ -19,6 +19,16 @@ TODO
- Replication. - Replication.
Requirements
- Python 2.4 or later
- ctypes http://python.net/crew/theller/ctypes/
- MySQLdb http://sourceforge.net/projects/mysql-python
- Zope 2.8 or later
Installation Installation
1. In zope: 1. In zope:
...@@ -50,4 +60,4 @@ Installation ...@@ -50,4 +60,4 @@ Installation
from neo.client.NEOStorare import NEOStorage from neo.client.NEOStorare import NEOStorage
s = NEOStorage(master_addr="127.0.0.1", master_port=10010, name="main") s = NEOStorage(master_addr="127.0.0.1", master_port=10010, name="main")
... ...
\ No newline at end of file
...@@ -28,22 +28,19 @@ class NEOStorage(BaseStorage.BaseStorage, ...@@ -28,22 +28,19 @@ class NEOStorage(BaseStorage.BaseStorage,
l = Lock() l = Lock()
self._txn_lock_acquire = l.acquire self._txn_lock_acquire = l.acquire
self._txn_lock_release = l.release self._txn_lock_release = l.release
# Create two queue for message between thread and dispatcher # Create a queue for message between thread and dispatcher
# - message queue is for message that has to be send to other node
# through the dispatcher
# - request queue is for message receive from other node which have to # - request queue is for message receive from other node which have to
# be processed # be processed
message_queue = Queue()
request_queue = Queue() request_queue = Queue()
# Create the event manager # Create the event manager
em = EventManager() em = EventManager()
# Create dispatcher thread # Create dispatcher thread
dispatcher = Dispatcher(em, message_queue, request_queue) dispatcher = Dispatcher(em, request_queue)
dispatcher.setDaemon(True) dispatcher.setDaemon(True)
# Import here to prevent recursive import # Import here to prevent recursive import
from neo.client.app import Application from neo.client.app import Application
self.app = Application(master_nodes, name, em, dispatcher, self.app = Application(master_nodes, name, em, dispatcher,
message_queue, request_queue) request_queue)
# Connect to primary master node # Connect to primary master node
dispatcher.connectToPrimaryMasterNode(self.app) dispatcher.connectToPrimaryMasterNode(self.app)
# Start dispatcher # Start dispatcher
......
This diff is collapsed.
...@@ -2,7 +2,7 @@ from threading import Thread ...@@ -2,7 +2,7 @@ from threading import Thread
from Queue import Empty, Queue from Queue import Empty, Queue
from neo.protocol import PING, Packet, CLIENT_NODE_TYPE, FINISH_TRANSACTION from neo.protocol import PING, Packet, CLIENT_NODE_TYPE, FINISH_TRANSACTION
from neo.connection import ClientConnection from neo.connection import MTClientConnection
from neo.node import MasterNode from neo.node import MasterNode
from time import time from time import time
...@@ -11,9 +11,8 @@ import logging ...@@ -11,9 +11,8 @@ import logging
class Dispatcher(Thread): class Dispatcher(Thread):
"""Dispatcher class use to redirect request to thread.""" """Dispatcher class use to redirect request to thread."""
def __init__(self, em, message_queue, request_queue, **kw): def __init__(self, em, request_queue, **kw):
Thread.__init__(self, **kw) Thread.__init__(self, **kw)
self._message_queue = message_queue
self._request_queue = request_queue self._request_queue = request_queue
self.em = em self.em = em
# Queue of received packet that have to be processed # Queue of received packet that have to be processed
...@@ -29,51 +28,42 @@ class Dispatcher(Thread): ...@@ -29,51 +28,42 @@ class Dispatcher(Thread):
# First check if we receive any new message from other node # First check if we receive any new message from other node
m = None m = None
try: try:
self.em.poll(0.02) self.em.poll()
except KeyError: except KeyError:
# This happen when there is no connection # This happen when there is no connection
logging.error('Dispatcher, run, poll returned a KeyError') logging.error('Dispatcher, run, poll returned a KeyError')
while 1: while 1:
try: try:
conn, packet = self.message.get_nowait() conn, packet = self.message.get_nowait()
except Empty: except Empty:
break break
# Send message to waiting thread # Send message to waiting thread
key = "%s-%s" %(conn.getUUID(),packet.getId()) key = (conn.getUUID(), packet.getId())
#logging.info('dispatcher got packet %s' %(key,)) #logging.info('dispatcher got packet %s' %(key,))
if self.message_table.has_key(key): if key in self.message_table:
tmp_q = self.message_table.pop(key) queue = self.message_table.pop(key)
tmp_q.put((conn, packet), True) queue.put((conn, packet))
else: else:
#conn, packet = self.message #conn, packet = self.message
method_type = packet.getType() method_type = packet.getType()
if method_type == PING: if method_type == PING:
# must answer with no delay # must answer with no delay
conn.addPacket(Packet().pong(packet.getId())) conn.lock()
try:
conn.addPacket(Packet().pong(packet.getId()))
finally:
conn.unlock()
else: else:
# put message in request queue # put message in request queue
self._request_queue.put((conn, packet), True) self._request_queue.put((conn, packet))
# Then check if a client ask me to send a message def register(self, conn, msg_id, queue):
try: """Register an expectation for a reply. Thanks to GIL, it is
m = self._message_queue.get_nowait() safe not to use a lock here."""
if m is not None: key = (conn.getUUID(), msg_id)
tmp_q, msg_id, conn, p = m self.message_table[key] = queue
conn.addPacket(p)
if tmp_q is not None:
# We expect an answer
key = "%s-%s" %(conn.getUUID(), msg_id)
self.message_table[key] = tmp_q
# XXX this is a hack. Probably queued tasks themselves
# should specify the timeout values.
if p.getType() == FINISH_TRANSACTION:
# Finish Transaction may take a lot of time when
# many objects are committed at a time.
conn.expectMessage(msg_id, additional_timeout = 300)
else:
conn.expectMessage(msg_id)
except Empty:
continue
def connectToPrimaryMasterNode(self, app): def connectToPrimaryMasterNode(self, app):
"""Connect to a primary master node. """Connect to a primary master node.
...@@ -87,7 +77,7 @@ class Dispatcher(Thread): ...@@ -87,7 +77,7 @@ class Dispatcher(Thread):
master_index = 0 master_index = 0
conn = None conn = None
# Make application execute remaining message if any # Make application execute remaining message if any
app._waitMessage(block=0) app._waitMessage()
handler = ClientEventHandler(app, app.dispatcher) handler = ClientEventHandler(app, app.dispatcher)
while 1: while 1:
if app.pt is not None and app.pt.operational(): if app.pt is not None and app.pt.operational():
...@@ -101,18 +91,24 @@ class Dispatcher(Thread): ...@@ -101,18 +91,24 @@ class Dispatcher(Thread):
else: else:
addr, port = app.primary_master_node.getServer() addr, port = app.primary_master_node.getServer()
# Request Node Identification # Request Node Identification
conn = ClientConnection(app.em, handler, (addr, port)) conn = MTClientConnection(app.em, handler, (addr, port))
if app.nm.getNodeByServer((addr, port)) is None: if app.nm.getNodeByServer((addr, port)) is None:
n = MasterNode(server = (addr, port)) n = MasterNode(server = (addr, port))
app.nm.add(n) app.nm.add(n)
msg_id = conn.getNextId()
p = Packet() conn.lock()
p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE, app.uuid, try:
'0.0.0.0', 0, app.name) msg_id = conn.getNextId()
# Send message p = Packet()
conn.addPacket(p) p.requestNodeIdentification(msg_id, CLIENT_NODE_TYPE, app.uuid,
conn.expectMessage(msg_id) '0.0.0.0', 0, app.name)
app.local_var.tmp_q = Queue(1)
# Send message
conn.addPacket(p)
conn.expectMessage(msg_id)
finally:
conn.unlock()
# Wait for answer # Wait for answer
while 1: while 1:
try: try:
...@@ -124,15 +120,19 @@ class Dispatcher(Thread): ...@@ -124,15 +120,19 @@ class Dispatcher(Thread):
break break
# Check if we got a reply # Check if we got a reply
try: try:
conn, packet = self.message.get_nowait() conn, packet = self.message.get_nowait()
method_type = packet.getType() method_type = packet.getType()
if method_type == PING: conn.lock()
# Must answer with no delay try:
conn.addPacket(Packet().pong(packet.getId())) if method_type == PING:
break # Must answer with no delay
else: conn.addPacket(Packet().pong(packet.getId()))
# Process message by handler break
conn.handler.dispatch(conn, packet) else:
# Process message by handler
conn.handler.dispatch(conn, packet)
finally:
conn.unlock()
except Empty: except Empty:
pass pass
...@@ -156,15 +156,6 @@ class Dispatcher(Thread): ...@@ -156,15 +156,6 @@ class Dispatcher(Thread):
# Connected to primary master node # Connected to primary master node
break break
# If nothing, check if we have new message to send
try:
m = self._message_queue.get_nowait()
if m is not None:
tmp_q, msg_id, conn, p = m
conn.addPacket(p)
except Empty:
continue
logging.info("connected to primary master node %s %d" %app.primary_master_node.getServer()) logging.info("connected to primary master node %s %d" %app.primary_master_node.getServer())
app.master_conn = conn app.master_conn = conn
self.connecting_to_master_node = 0 self.connecting_to_master_node = 0
import logging import logging
from neo.handler import EventHandler from neo.handler import EventHandler
from neo.connection import ClientConnection from neo.connection import MTClientConnection
from neo.protocol import Packet, \ from neo.protocol import Packet, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \ MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
INVALID_UUID, RUNNING_STATE, TEMPORARILY_DOWN_STATE, BROKEN_STATE INVALID_UUID, RUNNING_STATE, TEMPORARILY_DOWN_STATE, BROKEN_STATE
...@@ -74,7 +74,7 @@ class ClientEventHandler(EventHandler): ...@@ -74,7 +74,7 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list) p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True) app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection # Remove from pool connection
app.cm.removeConnection(node) app.cp.removeConnection(node)
EventHandler.connectionClosed(self, conn) EventHandler.connectionClosed(self, conn)
def timeoutExpired(self, conn): def timeoutExpired(self, conn):
...@@ -98,7 +98,7 @@ class ClientEventHandler(EventHandler): ...@@ -98,7 +98,7 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list) p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True) app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection # Remove from pool connection
app.cm.removeConnection(node) app.cp.removeConnection(node)
EventHandler.timeoutExpired(self, conn) EventHandler.timeoutExpired(self, conn)
def peerBroken(self, conn): def peerBroken(self, conn):
...@@ -122,12 +122,12 @@ class ClientEventHandler(EventHandler): ...@@ -122,12 +122,12 @@ class ClientEventHandler(EventHandler):
p.notifyNodeInformation(msg_id, node_list) p.notifyNodeInformation(msg_id, node_list)
app.queue.put((None, msg_id, conn, p), True) app.queue.put((None, msg_id, conn, p), True)
# Remove from pool connection # Remove from pool connection
app.cm.removeConnection(node) app.cp.removeConnection(node)
EventHandler.peerBroken(self, conn) EventHandler.peerBroken(self, conn)
def handleNotReady(self, conn, packet, message): def handleNotReady(self, conn, packet, message):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.node_not_ready = 1 app.node_not_ready = 1
else: else:
...@@ -136,7 +136,7 @@ class ClientEventHandler(EventHandler): ...@@ -136,7 +136,7 @@ class ClientEventHandler(EventHandler):
def handleAcceptNodeIdentification(self, conn, packet, node_type, def handleAcceptNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, uuid, ip_address, port,
num_partitions, num_replicas): num_partitions, num_replicas):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
node = app.nm.getNodeByServer(conn.getAddress()) node = app.nm.getNodeByServer(conn.getAddress())
# It can be eiter a master node or a storage node # It can be eiter a master node or a storage node
...@@ -162,12 +162,18 @@ class ClientEventHandler(EventHandler): ...@@ -162,12 +162,18 @@ class ClientEventHandler(EventHandler):
app.pt = PartitionTable(num_partitions, num_replicas) app.pt = PartitionTable(num_partitions, num_replicas)
app.num_partitions = num_partitions app.num_partitions = num_partitions
app.num_replicas = num_replicas app.num_replicas = num_replicas
# Ask a primary master. # Ask a primary master.
msg_id = conn.getNextId() conn.lock()
p = Packet() try:
p.askPrimaryMaster(msg_id) msg_id = conn.getNextId()
# send message to dispatcher p = Packet()
app.queue.put((app.local_var.tmp_q, msg_id, conn, p), True) p.askPrimaryMaster(msg_id)
conn.addPacket(p)
conn.expectMessage(msg_id)
app.dispatcher.register(conn, msg_id, app.getQueue())
finally:
conn.unlock()
elif node_type == STORAGE_NODE_TYPE: elif node_type == STORAGE_NODE_TYPE:
app.storage_node = node app.storage_node = node
else: else:
...@@ -176,7 +182,7 @@ class ClientEventHandler(EventHandler): ...@@ -176,7 +182,7 @@ class ClientEventHandler(EventHandler):
# Master node handler # Master node handler
def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, known_master_list): def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, known_master_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None: if uuid is None:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -220,7 +226,7 @@ class ClientEventHandler(EventHandler): ...@@ -220,7 +226,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleSendPartitionTable(self, conn, packet, ptid, row_list): def handleSendPartitionTable(self, conn, packet, ptid, row_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None: if uuid is None:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -249,7 +255,7 @@ class ClientEventHandler(EventHandler): ...@@ -249,7 +255,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleNotifyNodeInformation(self, conn, packet, node_list): def handleNotifyNodeInformation(self, conn, packet, node_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None: if uuid is None:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -309,7 +315,7 @@ class ClientEventHandler(EventHandler): ...@@ -309,7 +315,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list): def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
nm = app.nm nm = app.nm
pt = app.pt pt = app.pt
...@@ -344,14 +350,14 @@ class ClientEventHandler(EventHandler): ...@@ -344,14 +350,14 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerNewTID(self, conn, packet, tid): def handleAnswerNewTID(self, conn, packet, tid):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.tid = tid app.tid = tid
else: else:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleNotifyTransactionFinished(self, conn, packet, tid): def handleNotifyTransactionFinished(self, conn, packet, tid):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
if tid != app.tid: if tid != app.tid:
app.txn_finished = -1 app.txn_finished = -1
...@@ -361,7 +367,7 @@ class ClientEventHandler(EventHandler): ...@@ -361,7 +367,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleInvalidateObjects(self, conn, packet, oid_list): def handleInvalidateObjects(self, conn, packet, oid_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
...@@ -379,7 +385,7 @@ class ClientEventHandler(EventHandler): ...@@ -379,7 +385,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerNewOIDs(self, conn, packet, oid_list): def handleAnswerNewOIDs(self, conn, packet, oid_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.new_oid_list = oid_list app.new_oid_list = oid_list
app.new_oid_list.reverse() app.new_oid_list.reverse()
...@@ -387,7 +393,7 @@ class ClientEventHandler(EventHandler): ...@@ -387,7 +393,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleStopOperation(self, conn, packet): def handleStopOperation(self, conn, packet):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
logging.critical("master node ask to stop operation") logging.critical("master node ask to stop operation")
else: else:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
...@@ -396,7 +402,7 @@ class ClientEventHandler(EventHandler): ...@@ -396,7 +402,7 @@ class ClientEventHandler(EventHandler):
# Storage node handler # Storage node handler
def handleAnswerObject(self, conn, packet, oid, start_serial, end_serial, compression, def handleAnswerObject(self, conn, packet, oid, start_serial, end_serial, compression,
checksum, data): checksum, data):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.local_var.asked_object = (oid, start_serial, end_serial, compression, app.local_var.asked_object = (oid, start_serial, end_serial, compression,
checksum, data) checksum, data)
...@@ -404,7 +410,7 @@ class ClientEventHandler(EventHandler): ...@@ -404,7 +410,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerStoreObject(self, conn, packet, conflicting, oid, serial): def handleAnswerStoreObject(self, conn, packet, conflicting, oid, serial):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
if conflicting: if conflicting:
app.txn_object_stored = -1, serial app.txn_object_stored = -1, serial
...@@ -414,14 +420,14 @@ class ClientEventHandler(EventHandler): ...@@ -414,14 +420,14 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerStoreTransaction(self, conn, packet, tid): def handleAnswerStoreTransaction(self, conn, packet, tid):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
app.txn_voted = 1 app.txn_voted = 1
else: else:
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerTransactionInformation(self, conn, packet, tid, user, desc, oid_list): def handleAnswerTransactionInformation(self, conn, packet, tid, user, desc, oid_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# transaction information are returned as a dict # transaction information are returned as a dict
info = {} info = {}
...@@ -435,7 +441,7 @@ class ClientEventHandler(EventHandler): ...@@ -435,7 +441,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerObjectHistory(self, conn, packet, oid, history_list): def handleAnswerObjectHistory(self, conn, packet, oid, history_list):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# history_list is a list of tuple (serial, size) # history_list is a list of tuple (serial, size)
self.history = oid, history_list self.history = oid, history_list
...@@ -443,7 +449,7 @@ class ClientEventHandler(EventHandler): ...@@ -443,7 +449,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleOidNotFound(self, conn, packet, message): def handleOidNotFound(self, conn, packet, message):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# This can happen either when : # This can happen either when :
# - loading an object # - loading an object
...@@ -454,7 +460,7 @@ class ClientEventHandler(EventHandler): ...@@ -454,7 +460,7 @@ class ClientEventHandler(EventHandler):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleTidNotFound(self, conn, packet, message): def handleTidNotFound(self, conn, packet, message):
if isinstance(conn, ClientConnection): if isinstance(conn, MTClientConnection):
app = self.app app = self.app
# This can happen when requiring txn informations # This can happen when requiring txn informations
app.local_var.txn_info = -1 app.local_var.txn_info = -1
......
import socket import socket
import errno import errno
import logging import logging
from threading import RLock
from neo.protocol import Packet, ProtocolError from neo.protocol import Packet, ProtocolError
from neo.event import IdleEvent from neo.event import IdleEvent
...@@ -46,6 +47,12 @@ class BaseConnection(object): ...@@ -46,6 +47,12 @@ class BaseConnection(object):
def getUUID(self): def getUUID(self):
return None return None
def acquire(self, block = 1):
return 1
def release(self):
pass
class ListeningConnection(BaseConnection): class ListeningConnection(BaseConnection):
"""A listen connection.""" """A listen connection."""
def __init__(self, event_manager, handler, addr = None, **kw): def __init__(self, event_manager, handler, addr = None, **kw):
...@@ -305,3 +312,31 @@ class ClientConnection(Connection): ...@@ -305,3 +312,31 @@ class ClientConnection(Connection):
class ServerConnection(Connection): class ServerConnection(Connection):
"""A connection from a remote node to this node.""" """A connection from a remote node to this node."""
pass pass
class MTClientConnection(ClientConnection):
"""A Multithread-safe version of ClientConnection."""
def __init__(self, *args, **kwargs):
super(MTClientConnection, self).__init__(*args, **kwargs)
lock = RLock()
self.acquire = lock.acquire
self.release = lock.release
def lock(self, blocking = 1):
return self.acquire(blocking = blocking)
def unlock(self):
self.release()
class MTServerConnection(ServerConnection):
"""A Multithread-safe version of ServerConnection."""
def __init__(self, *args, **kwargs):
super(MTClientConnection, self).__init__(*args, **kwargs)
lock = RLock()
self.acquire = lock.acquire
self.release = lock.release
def lock(self, blocking = 1):
return self.acquire(blocking = blocking)
def unlock(self):
self.release()
r"""This is an epoll(4) interface available in Linux 2.6. This requires
ctypes <http://python.net/crew/theller/ctypes/>."""
from ctypes import cdll, Union, Structure, \
c_void_p, c_int, byref
try:
from ctypes import c_uint32, c_uint64
except ImportError:
from ctypes import c_uint, c_ulonglong
c_uint32 = c_uint
c_uint64 = c_ulonglong
from os import close
from errno import EINTR
libc = cdll.LoadLibrary('libc.so.6')
epoll_create = libc.epoll_create
epoll_wait = libc.epoll_wait
epoll_ctl = libc.epoll_ctl
errno = c_int.in_dll(libc, 'errno')
EPOLLIN = 0x001
EPOLLPRI = 0x002
EPOLLOUT = 0x004
EPOLLRDNORM = 0x040
EPOLLRDBAND = 0x080
EPOLLWRNORM = 0x100
EPOLLWRBAND = 0x200
EPOLLMSG = 0x400
EPOLLERR = 0x008
EPOLLHUP = 0x010
EPOLLONESHOT = (1 << 30)
EPOLLET = (1 << 31)
EPOLL_CTL_ADD = 1
EPOLL_CTL_DEL = 2
EPOLL_CTL_MOD = 3
class epoll_data(Union):
_fields_ = [("ptr", c_void_p),
("fd", c_int),
("u32", c_uint32),
("u64", c_uint64)]
class epoll_event(Structure):
_fields_ = [("events", c_uint32),
("data", epoll_data)]
class Epoll(object):
efd = -1
def __init__(self):
self.efd = epoll_create(10)
if self.efd == -1:
raise OSError(errno.value, 'epoll_create failed')
self.maxevents = 1024 # XXX arbitrary
epoll_event_array = epoll_event * self.maxevents
self.events = epoll_event_array()
def poll(self, timeout = 1):
timeout *= 1000
timeout = int(timeout)
while 1:
n = epoll_wait(self.efd, byref(self.events), self.maxevents,
timeout)
if n == -1:
e = errno.value
if e == EINTR:
continue
else:
raise OSError(e, 'epoll_wait failed')
else:
readable_fd_list = []
writable_fd_list = []
for i in xrange(n):
ev = self.events[i]
if ev.events & (EPOLLIN | EPOLLERR | EPOLLHUP):
readable_fd_list.append(int(ev.data.fd))
elif ev.events & (EPOLLOUT | EPOLLERR | EPOLLHUP):
writable_fd_list.append(int(ev.data.fd))
return readable_fd_list, writable_fd_list
def register(self, fd):
ev = epoll_event()
ev.data.fd = fd
ret = epoll_ctl(self.efd, EPOLL_CTL_ADD, fd, byref(ev))
if ret == -1:
raise OSError(errno.value, 'epoll_ctl failed')
def modify(self, fd, readable, writable):
ev = epoll_event()
ev.data.fd = fd
events = 0
if readable:
events |= EPOLLIN
if writable:
events |= EPOLLOUT
ev.events = events
ret = epoll_ctl(self.efd, EPOLL_CTL_MOD, fd, byref(ev))
if ret == -1:
raise OSError(errno.value, 'epoll_ctl failed')
def unregister(self, fd):
ev = epoll_event()
ret = epoll_ctl(self.efd, EPOLL_CTL_DEL, fd, byref(ev))
if ret == -1:
raise OSError(errno.value, 'epoll_ctl failed')
def __del__(self):
if self.efd >= 0:
close(self.efd)
...@@ -3,6 +3,7 @@ from select import select ...@@ -3,6 +3,7 @@ from select import select
from time import time from time import time
from neo.protocol import Packet from neo.protocol import Packet
from neo.epoll import Epoll
class IdleEvent(object): class IdleEvent(object):
"""This class represents an event called when a connection is waiting for """This class represents an event called when a connection is waiting for
...@@ -28,26 +29,35 @@ class IdleEvent(object): ...@@ -28,26 +29,35 @@ class IdleEvent(object):
def __call__(self, t): def __call__(self, t):
conn = self._conn conn = self._conn
if t > self._critical_time: if t > self._critical_time:
logging.info('timeout with %s:%d', *(conn.getAddress())) conn.lock()
conn.getHandler().timeoutExpired(conn) try:
conn.close() logging.info('timeout with %s:%d', *(conn.getAddress()))
return True conn.getHandler().timeoutExpired(conn)
conn.close()
return True
finally:
conn.unlock()
elif t > self._time: elif t > self._time:
if self._additional_timeout > 5: conn.lock()
self._additional_timeout -= 5 try:
conn.expectMessage(self._id, 5, self._additional_timeout) if self._additional_timeout > 5:
# Start a keep-alive packet. self._additional_timeout -= 5
logging.info('sending a ping to %s:%d', *(conn.getAddress())) conn.expectMessage(self._id, 5, self._additional_timeout)
msg_id = conn.getNextId() # Start a keep-alive packet.
conn.addPacket(Packet().ping(msg_id)) logging.info('sending a ping to %s:%d',
conn.expectMessage(msg_id, 5, 0) *(conn.getAddress()))
else: msg_id = conn.getNextId()
conn.expectMessage(self._id, self._additional_timeout, 0) conn.addPacket(Packet().ping(msg_id))
return True conn.expectMessage(msg_id, 5, 0)
else:
conn.expectMessage(self._id, self._additional_timeout, 0)
return True
finally:
conn.unlock()
return False return False
class EventManager(object): class SelectEventManager(object):
"""This class manages connections and events.""" """This class manages connections and events based on select(2)."""
def __init__(self): def __init__(self):
self.connection_dict = {} self.connection_dict = {}
...@@ -71,13 +81,21 @@ class EventManager(object): ...@@ -71,13 +81,21 @@ class EventManager(object):
timeout) timeout)
for s in rlist: for s in rlist:
conn = self.connection_dict[s] conn = self.connection_dict[s]
conn.readable() conn.lock()
try:
conn.readable()
finally:
conn.unlock()
for s in wlist: for s in wlist:
# This can fail, if a connection is closed in readable(). # This can fail, if a connection is closed in readable().
try: try:
conn = self.connection_dict[s] conn = self.connection_dict[s]
conn.writable() conn.lock()
try:
conn.writable()
finally:
conn.unlock()
except KeyError: except KeyError:
pass pass
...@@ -120,3 +138,102 @@ class EventManager(object): ...@@ -120,3 +138,102 @@ class EventManager(object):
def removeWriter(self, conn): def removeWriter(self, conn):
self.writer_set.discard(conn.getSocket()) self.writer_set.discard(conn.getSocket())
class EpollEventManager(object):
"""This class manages connections and events based on epoll(5)."""
def __init__(self):
self.connection_dict = {}
self.reader_set = set([])
self.writer_set = set([])
self.event_list = []
self.prev_time = time()
self.epoll = Epoll()
def getConnectionList(self):
return self.connection_dict.values()
def register(self, conn):
fd = conn.getSocket().fileno()
self.connection_dict[fd] = conn
self.epoll.register(fd)
def unregister(self, conn):
fd = conn.getSocket().fileno()
self.epoll.unregister(fd)
del self.connection_dict[fd]
def poll(self, timeout = 1):
rlist, wlist = self.epoll.poll(timeout)
for fd in rlist:
conn = self.connection_dict[fd]
conn.lock()
try:
conn.readable()
finally:
conn.unlock()
for fd in wlist:
# This can fail, if a connection is closed in readable().
try:
conn = self.connection_dict[fd]
conn.lock()
try:
conn.writable()
finally:
conn.unlock()
except KeyError:
pass
# Check idle events. Do not check them out too often, because this
# is somehow heavy.
event_list = self.event_list
if event_list:
t = time()
if t - self.prev_time >= 1:
self.prev_time = t
event_list.sort(key = lambda event: event.getTime())
while event_list:
event = event_list[0]
if event(t):
try:
event_list.remove(event)
except ValueError:
pass
else:
break
def addIdleEvent(self, event):
self.event_list.append(event)
def removeIdleEvent(self, event):
try:
self.event_list.remove(event)
except ValueError:
pass
def addReader(self, conn):
fd = conn.getSocket().fileno()
if fd not in self.reader_set:
self.reader_set.add(fd)
self.epoll.modify(fd, 1, fd in self.writer_set)
def removeReader(self, conn):
fd = conn.getSocket().fileno()
if fd in self.reader_set:
self.reader_set.remove(fd)
self.epoll.modify(fd, 0, fd in self.writer_set)
def addWriter(self, conn):
fd = conn.getSocket().fileno()
if fd not in self.writer_set:
self.writer_set.add(fd)
self.epoll.modify(fd, fd in self.reader_set, 1)
def removeWriter(self, conn):
fd = conn.getSocket().fileno()
if fd in self.writer_set:
self.writer_set.remove(fd)
self.epoll.modify(fd, fd in self.reader_set, 0)
# Default to EpollEventManager.
EventManager = EpollEventManager
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment