#! /usr/bin/env python ############################################################################## # # Yoshinori OKUJI <yo@nexedi.com> # # Copyright (C) 2004 Nexedi SARL # # This program is Free Software; you can redistribute it and/or # modify it under the terms of the GNU General Public License # as published by the Free Software Foundation; either version 2 # of the License, or (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. ?See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA ?02111-1307, USA. # ############################################################################## SBALANCE_VERSION = '4.0' import sys import getopt import socket import os import threading import time from select import select import re if not hasattr(socket, 'setdefaulttimeout'): raise RuntimeError, 'Your Python interpreter is too old. Please upgrade it.' class ClientInfo: pass class Balancer: def __init__(self, port, server_list, bind = '', connect_timeout = 5, select_timeout = None, debug = 0, foreground = 0, packet_dump = 0): """ Initialize the basic status. """ self.port = port self.server_list = server_list self.bind = bind self.connect_timeout = connect_timeout self.select_timeout = select_timeout self.debug = debug self.foreground = foreground self.packet_dump = packet_dump # Make a socket to listen. self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.bind((self.bind, self.port)) self.socket.listen(5) # Make shared information and a lock for it. self.lock = threading.Lock() self.next_server = 0 self.sticked_server_dict = {} self.disabled_server_dict = {} # Daemonize itself. if not self.foreground: self.daemonize() def daemonize(self): """ Make myself a daemon. """ pid = os.fork() if pid > 0: sys.exit() os.chdir('/') os.setsid() os.umask(0) pid = os.fork() if pid > 0: sys.exit() f = open('/dev/null', 'w+') os.dup2(f.fileno(), sys.stdin.fileno()) os.dup2(f.fileno(), sys.stdout.fileno()) os.dup2(f.fileno(), sys.stderr.fileno()) def run(self): try: # Make a thread for expiration of old sticky entries. if self.debug: print "Starting an expiring daemon thread" t = threading.Thread(target=Balancer.expire, args=(self,)) t.setDaemon(1) t.start() if self.debug: print "Beginning the main loop to accept clients" while 1: conn, addr = self.socket.accept() if self.debug: print "New connection from %s" % str(addr) t = threading.Thread(target=Balancer.handleClient, args=(self, conn, addr)) t.start() finally: self.socket.close() def expire(self): while 1: time.sleep(60) try: self.lock.acquire() cur_time = time.time() count_dict = {} for addr in self.server_list: if addr not in self.disabled_server_dict: count_dict[addr] = 0 expired_server_list = [] for key,value in self.sticked_server_dict.items(): if self.debug: print 'cur_time = %f, value.atime = %f' % (cur_time, value.atime) if cur_time > value.atime + 60 * 10: expired_server_list.append(key) else: if value.addr in count_dict: count_dict[value.addr] += 1 for key in expired_server_list: if self.debug: print "Expiring %s" % str(key) del self.sticked_server_dict[key] # Expire this entry. # Find the max and the min. if self.debug: print 'count_dict = %s, sticked_server_dict = %s, disabled_server_dict = %s' % (str(count_dict), str(self.sticked_server_dict), str(self.disabled_server_dict)) max = -1 min = len(self.sticked_server_dict) + 1 for addr,count in count_dict.items(): if count > max: max = count max_addr = addr if count < min: min = count min_addr = addr # If the max is significantly greater than the min, move some clients. if max > min + 1: num = max - min - 1 for key,value in self.sticked_server_dict.items(): if value.addr == max_addr: if self.debug: print "Moving %s from %s to %s" % (str(key), str(max_addr), str(min_addr)) value.addr = min_addr num -= 1 if num <= 0: break # Enable old entries in disabled servers. enabled_server_list = [] for addr,ctime in self.disabled_server_dict.items(): if cur_time > ctime + 60 * 3: enabled_server_list.append(addr) for addr in enabled_server_list: if self.debug: print 'Enabling %s again' % addr del self.disabled_server_dict[addr] finally: self.lock.release() def getSignature(self, s): """ Try to find out a signature. Depend on Zope and CookieCrumbler in CMFCore. """ if s[:3] == 'GET' or s[:3] == 'PUT' or s[:4] == 'PUSH': # This looks like a HTTP request. header_end = s.find('\r\n\r\n') if header_end < 0: return None s = s[:header_end] m = re.search('\r\nAuthorization:\s*(.+)', s, re.IGNORECASE) if m: return s[m.start(1):m.end(1)] m = re.search('\r\nCookie:.*__ac=\"(.+)\"', s, re.IGNORECASE) if m: return s[m.start(1):m.end(1)] return None def handleClient(self, conn, addr): """ Choose a server and do a proxy job. """ server_conn = None try: # Make a new socket. server_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_conn.settimeout(self.connect_timeout) # Need to read the first some bytes to get a signature. size = 4096 buf = "" while size > 0: iwtd, owtd, ewtd = select([conn], [], [], 0.1) if len(iwtd) == 0: break data = conn.recv(size) size -= len(data) buf += data signature = self.getSignature(buf) # Choose a server. try: self.lock.acquire() addr = None if signature is not None and signature in self.sticked_server_dict: addr = self.sticked_server_dict[signature].addr start_index = self.server_list.index(addr) else: addr = self.server_list[self.next_server] start_index = self.next_server self.next_server = start_index + 1 if self.next_server == len(self.server_list): self.next_server = 0 finally: self.lock.release() index = start_index while 1: # Check if this server is enabled. enabled = 1 try: self.lock.acquire() if addr in self.disabled_server_dict: enabled = 0 finally: self.lock.release() if enabled: try: host, port = addr.split(':') port = int(port) server_conn.connect((host, port)) break except: # Something wrong happened with this server. try: self.lock.acquire() if self.debug: print 'Disabling %s' % addr cur_time = time.time() self.disabled_server_dict[addr] = cur_time finally: self.lock.release() # Need to find the next server. index += 1 try: self.lock.acquire() if index >= len(self.server_list): index = 0 addr = self.server_list[index] finally: self.lock.release() if index == start_index: # No way. if self.debug: print 'No available server found.' return # Register this client if possible. if signature: try: self.lock.acquire() if self.debug: print 'Registering %s with %s' % (signature, addr) cur_time = time.time() if signature in self.sticked_server_dict: info = self.sticked_server_dict[signature] info.atime = cur_time info.addr = addr else: info = ClientInfo() info.atime = cur_time info.addr = addr self.sticked_server_dict[signature] = info finally: self.lock.release() # Now is the time to play. server_conn.settimeout(None) server_conn.sendall(buf) while 1: iwtd, owtd, ewtd = select([server_conn, conn], [], [], self.select_timeout) if len(iwtd) == 0: return if server_conn in iwtd: buf = server_conn.recv(4096) if len(buf) == 0: return conn.sendall(buf) if conn in iwtd: buf = conn.recv(4096) if len(buf) == 0: return server_conn.sendall(buf) finally: conn.close() if server_conn is not None: server_conn.close() def main(): kwd = {} try: opts, args = getopt.getopt(sys.argv[1:], "hvb:t:T:dfps", ["help", "version", "bind=", "connect-timeout=", "select-timeout=", "debug", "foreground", "packet-dump", "sticky"]) except getopt.GetoptError, msg: print msg print "Try ``sbalance --help'' for more information." sys.exit(2) for o, a in opts: if o in ("-v", "--version"): print "sbalance version %s" % SBALANCE_VERSION sys.exit() elif o in ("-h", "--help"): print '''Usage: sbalace [OPTION...] PORT HOST:[PORT]... Balance TCP/IP loads with distributed servers. -h, --help display this message and exit -v, --version print version information and exit -b, --bind=HOST accept connections only to a host instead of any -t, --connect-timeout=SEC specify the timeout for connect -T, --select-timeout=SEC specify the timeout for select -d, --debug output debugging information -f, --foreground run sbalance in foreground -p, --packet-dump dump packet contents -s, --sticky for backward compatibility PORT is the port number to listen to. You can specify any number of pairs of a host and a port. Report bugs to <yo@nexedi.com>.''' sys.exit() elif o in ("-b", "--bind"): kwd['bind'] = a elif o in ("-t", "--connect-timeout"): kwd['connect_timeout'] = int(a) elif o in ("-T", "--select-timeout"): kwd['select_timeout'] = int(a) elif o in ("-d", "--debug"): kwd['debug'] = 1 elif o in ("-f", "--foreground"): kwd['foreground'] = 1 elif o in ("-p", "--packet-dump"): kwd['packet_dump'] = 1 elif o in ("-s", "--stickey"): pass if len(args) < 2: print "Too few arguments." print "Try ``sbalance --help'' for more information." sys.exit(2) port = int(args[0]) server_list = [] for server in args[1:]: if server == '%' or server == '!': continue # For compatibility. i = server.find(':') if i < 0: addr = server + ':' + str(port) else: addr = server server_list.append(addr) if len(server_list) < 1: print "No server is specified." print "Try ``sbalance --help'' for more information." sys.exit(2) b = Balancer(port, server_list, **kwd) b.run() if __name__ == "__main__": main()