From 467b9d97ce6fcb22048622d04214203c56cc2614 Mon Sep 17 00:00:00 2001 From: Jean-Paul Smets <jp@nexedi.com> Date: Thu, 1 Apr 2004 17:14:15 +0000 Subject: [PATCH] *** empty log message *** git-svn-id: https://svn.erp5.org/repos/public/erp5/trunk@641 20353a03-c40f-0410-a6d1-a30d3c3de9de --- product/CMFActivity/sbalance.py | 370 ++++++++++++++++++++++++++++++++ 1 file changed, 370 insertions(+) create mode 100755 product/CMFActivity/sbalance.py diff --git a/product/CMFActivity/sbalance.py b/product/CMFActivity/sbalance.py new file mode 100755 index 0000000000..6a0154063b --- /dev/null +++ b/product/CMFActivity/sbalance.py @@ -0,0 +1,370 @@ +#! /usr/bin/env python + +############################################################################## +# +# Yoshinori OKUJI <yo@nexedi.com> +# +# Copyright (C) 2004 Nexedi SARL +# +# This program is Free Software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. ?See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA ?02111-1307, USA. +# +############################################################################## + +SBALANCE_VERSION = '4.0' + +import sys +import getopt +import socket +import os +import threading +import time +from select import select +import re + +if not hasattr(socket, 'setdefaulttimeout'): + raise RuntimeError, 'Your Python interpreter is too old. Please upgrade it.' + +class ClientInfo: pass + +class Balancer: + def __init__(self, port, server_list, bind = '', connect_timeout = 5, select_timeout = None, + debug = 0, foreground = 0, packet_dump = 0): + """ + Initialize the basic status. + """ + self.port = port + self.server_list = server_list + self.bind = bind + self.connect_timeout = connect_timeout + self.select_timeout = select_timeout + self.debug = debug + self.foreground = foreground + self.packet_dump = packet_dump + + # Make a socket to listen. + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.bind((self.bind, self.port)) + self.socket.listen(5) + + # Make shared information and a lock for it. + self.lock = threading.Lock() + self.next_server = 0 + self.sticked_server_dict = {} + self.disabled_server_dict = {} + + # Daemonize itself. + if not self.foreground: + self.daemonize() + + def daemonize(self): + """ + Make myself a daemon. + """ + pid = os.fork() + if pid > 0: + sys.exit() + os.chdir('/') + os.setsid() + os.umask(0) + pid = os.fork() + if pid > 0: + sys.exit() + f = open('/dev/null', 'w+') + os.dup2(f.fileno(), sys.stdin.fileno()) + os.dup2(f.fileno(), sys.stdout.fileno()) + os.dup2(f.fileno(), sys.stderr.fileno()) + + def run(self): + try: + # Make a thread for expiration of old sticky entries. + if self.debug: + print "Starting an expiring daemon thread" + t = threading.Thread(target=Balancer.expire, args=(self,)) + t.setDaemon(1) + t.start() + + if self.debug: + print "Beginning the mail loop to accept clients" + while 1: + conn, addr = self.socket.accept() + if self.debug: + print "New connection from %s" % str(addr) + t = threading.Thread(target=Balancer.handleClient, args=(self, conn, addr)) + t.start() + finally: + self.socket.close() + + def expire(self): + while 1: + time.sleep(60) + try: + self.lock.acquire() + cur_time = time.clock() + count_dict = {} + expired_server_list = [] + for key,value in self.sticked_server_dict.items(): + if cur_time > value.atime + 60 * 10: + expired_server_list.append(key) + else: + if value.addr in count_dict: + count_dict[value.addr] += 1 + else: + count_dict[value.addr] = 1 + for key in expired_server_list: + if self.debug: + print "Expiring %s" % str(key) + del self.sticked_server_dict[key] # Expire this entry. + # Find the max and the min. + max = -1 + min = len(self.sticked_server_dict) + 1 + for addr,count in count_dict.items(): + if count > max: + max = count + max_addr = addr + if count < min: + min = count + min_addr = addr + # If the max is significantly greater than the min, move some clients. + if max > min + 1: + num = max - min + for key,value in self.sticked_server_dict.items(): + if value.addr == max_addr: + if self.debug: + print "Moving %s from %s to %s" % (str(key), str(max_addr), str(min_addr)) + value.addr = min_addr + num -= 1 + if num <= 0: + break + # Enable old entries in disabled servers. + enabled_server_list = [] + for addr,ctime in self.disabled_server_dict.items(): + if cur_time > ctime + 60 * 3: + enabled_server_list.append(addr) + for addr in enabled_server_list: + if self.debug: + print 'Enabling %s again' % addr + del self.disabled_server_dict[addr] + finally: + self.lock.release() + + def getSignature(self, s): + """ + Try to find out a signature. Depend on Zope and CookieCrumbler in CMFCore. + """ + if s[:3] == 'GET' or s[:3] == 'PUT' or s[:4] == 'PUSH': + # This looks like a HTTP request. + header_end = s.find('\r\n\r\n') + if header_end < 0: + return None + s = s[:header_end] + m = re.search('\r\nAuthorization:\s*(.+)', s, re.IGNORECASE) + if m: + return s[m.start(1):m.end(1)] + m = re.search('\r\nCookie:.*__ac=\"(.+)\"', s, re.IGNORECASE) + if m: + return s[m.start(1):m.end(1)] + return None + + def handleClient(self, conn, addr): + """ + Choose a server and do a proxy job. + """ + server_conn = None + try: + # Make a new socket. + server_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_conn.settimeout(self.connect_timeout) + + # Need to read the first some bytes to get a signature. + size = 4096 + buf = "" + while size > 0: + iwtd, owtd, ewtd = select([conn], [], [], 0.1) + if len(iwtd) == 0: break + data = conn.recv(size) + size -= len(data) + buf += data + signature = self.getSignature(buf) + + # Choose a server. + try: + self.lock.acquire() + addr = None + if signature is not None and signature in self.sticked_server_dict: + addr = self.sticked_server_dict[signature].addr + start_index = self.server_list.index(addr) + else: + addr = self.server_list[self.next_server] + start_index = self.next_server + self.next_server = start_index + 1 + if self.next_server == len(self.server_list): + self.next_server = 0 + finally: + self.lock.release() + + index = start_index + while 1: + # Check if this server is enabled. + enabled = 1 + try: + self.lock.acquire() + if addr in self.disabled_server_dict: + enabled = 0 + finally: + self.lock.release() + + if enabled: + try: + host, port = addr.split(':') + port = int(port) + server_conn.connect((host, port)) + break + except: + # Something wrong happened with this server. + try: + self.lock.acquire() + if self.debug: + print 'Disabling %s' % addr + cur_time = time.clock() + self.disabled_server_dict[addr] = cur_time + finally: + self.lock.release() + + # Need to find the next server. + index += 1 + try: + self.lock.acquire() + if index >= len(self.server_list): + index = 0 + addr = self.server_list[index] + finally: + self.lock.release() + + if index == start_index: + # No way. + if self.debug: + print 'No available server found.' + return + + # Register this client if possible. + if signature: + try: + self.lock.acquire() + if self.debug: + print 'Registering %s with %s' % (signature, addr) + cur_time = time.clock() + if signature in self.sticked_server_dict: + info = self.sticked_server_dict[signature] + info.atime = cur_time + info.addr = addr + else: + info = ClientInfo() + info.atime = cur_time + info.addr = addr + self.sticked_server_dict[signature] = info + finally: + self.lock.release() + + # Now is the time to play. + server_conn.settimeout(None) + server_conn.sendall(buf) + while 1: + iwtd, owtd, ewtd = select([server_conn, conn], [], [], self.select_timeout) + if len(iwtd) == 0: + return + if server_conn in iwtd: + buf = server_conn.recv(4096) + if len(buf) == 0: return + conn.sendall(buf) + if conn in iwtd: + buf = conn.recv(4096) + if len(buf) == 0: return + server_conn.sendall(buf) + finally: + conn.close() + if server_conn is not None: server_conn.close() + +def main(): + kwd = {} + try: + opts, args = getopt.getopt(sys.argv[1:], "hvb:t:T:dfps", + ["help", "version", "bind=", "connect-timeout=", "select-timeout=", "debug", "foreground", "packet-dump", "sticky"]) + except getopt.GetoptError, msg: + print msg + print "Try ``sbalance --help'' for more information." + sys.exit(2) + for o, a in opts: + if o in ("-v", "--version"): + print "sbalance version %s" % SBALANCE_VERSION + sys.exit() + elif o in ("-h", "--help"): + print '''Usage: sbalace [OPTION...] PORT HOST:[PORT]... +Balance TCP/IP loads with distributed servers. + + -h, --help display this message and exit + -v, --version print version information and exit + -b, --bind=HOST accept connections only to a host instead of any + -t, --connect-timeout=SEC specify the timeout for connect + -T, --select-timeout=SEC specify the timeout for select + -d, --debug output debugging information + -f, --foreground run sbalance in foreground + -p, --packet-dump dump packet contents + -s, --sticky for backward compatibility + +PORT is the port number to listen to. You can specify any number of +pairs of a host and a port. + +Report bugs to <yo@nexedi.com>.''' + sys.exit() + elif o in ("-b", "--bind"): + kwd['bind'] = a + elif o in ("-t", "--connect-timeout"): + kwd['connect_timeout'] = int(a) + elif o in ("-T", "--select-timeout"): + kwd['select_timeout'] = int(a) + elif o in ("-d", "--debug"): + kwd['debug'] = 1 + elif o in ("-f", "--foreground"): + kwd['foreground'] = 1 + elif o in ("-p", "--packet-dump"): + kwd['packet_dump'] = 1 + elif o in ("-s", "--stickey"): + pass + + if len(args) < 2: + print "Too few arguments." + print "Try ``sbalance --help'' for more information." + sys.exit(2) + + port = int(args[0]) + server_list = [] + for server in args[1:]: + if server == '%' or server == '!': continue # For compatibility. + i = server.find(':') + if i < 0: + addr = server + ':' + str(port) + else: + addr = server + server_list.append(addr) + if len(server_list) < 1: + print "No server is specified." + print "Try ``sbalance --help'' for more information." + sys.exit(2) + + b = Balancer(port, server_list, **kwd) + b.run() + +if __name__ == "__main__": + main() -- 2.30.9