Commit dd785bb2 authored by Martín Ferrari's avatar Martín Ferrari

Better handling of the FDs; some basic tests

parent 08105d5b
......@@ -79,9 +79,13 @@ class Server(object):
if hasattr(fd, "readline"):
self._fd = fd
else:
# Since openfd insists on closing the fd on destruction, I need to
# dup()
if hasattr(fd, "fileno"):
fd = fd.fileno()
self._fd = os.fdopen(fd, "r+", 1)
nfd = os.dup(fd.fileno())
else:
nfd = os.dup(fd)
self._fd = os.fdopen(nfd, "r+", 1)
def reply(self, code, text):
"Send back a reply to the client; handle multiline messages"
......@@ -196,6 +200,9 @@ class Server(object):
try:
if args[i][0] == '=':
args[i] = base64.b64decode(args[i][1:])
if len(args[i]) == 0:
self.reply(500, "Invalid parameter: empty.")
return None
except TypeError:
self.reply(500, "Invalid parameter: not base-64 encoded.")
return None
......@@ -390,17 +397,20 @@ class Slave(object):
If fd and pid are specified, the slave process is not created; fd is
used as a control socket and pid is assumed to be the pid of the slave
process."""
if fd and pid:
# If fd is passed do not fork or anything
if hasattr(fd, "readline"):
pass # fd ok
# If fd is passed do not fork or anything
if not (fd and pid):
fd, pid = _start_child(debug)
# XXX: In some cases we do not call dup(); maybe this should be
# consistent?
if not hasattr(fd, "readline"):
# Since openfd insists on closing the fd on destruction, I need to
# dup()
if hasattr(fd, "fileno"):
nfd = os.dup(fd.fileno())
else:
if hasattr(fd, "fileno"):
fd = fd.fileno()
fd = os.fdopen(fd, "r+", 1)
else:
f, pid = _start_child(debug)
fd = os.fdopen(f.fileno(), "r+", 1)
nfd = os.dup(fd)
fd = os.fdopen(nfd, "r+", 1)
self._pid = pid
self._fd = fd
......
#!/usr/bin/env python
# vim:ts=4:sw=4:et:ai:sts=4
import netns.protocol
import os, socket, sys, unittest
class TestServer(unittest.TestCase):
def test_server_startup(self):
# Test the creation of the server object with different ways of passing
# the file descriptor; and check the banner.
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
(s2, s3) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
pid = os.fork()
if not pid:
s1.close()
srv = netns.protocol.Server(s0)
srv.run()
s3.close()
srv = netns.protocol.Server(s2.fileno())
srv.run()
os._exit(0)
s0.close()
s = os.fdopen(s1.fileno(), "r+", 1)
self.assertEquals(s.readline()[0:4], "220 ")
s.close()
s2.close()
s = os.fdopen(s3.fileno(), "r+", 1)
self.assertEquals(s.readline()[0:4], "220 ")
s.close()
pid, ret = os.waitpid(pid, 0)
self.assertEquals(ret, 0)
def test_basic_stuff(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
srv = netns.protocol.Server(s0)
s1 = s1.makefile("r+", 1)
def check_error(self, cmd, code = 500):
s1.write("%s\n" % cmd)
self.assertEquals(srv.readcmd(), None)
self.assertEquals(s1.readline()[0:4], "%d " % code)
def check_ok(self, cmd, func, args):
s1.write("%s\n" % cmd)
ccmd = " ".join(cmd.upper().split()[0:2])
if func == None:
self.assertEquals(srv.readcmd()[1:3], (ccmd, args))
else:
self.assertEquals(srv.readcmd(), (func, ccmd, args))
check_ok(self, "quit", srv.do_QUIT, [])
check_ok(self, " quit ", srv.do_QUIT, [])
# protocol error
check_error(self, "quit 1")
# Not allowed in normal mode
check_error(self, "proc sin")
check_error(self, "proc sout")
check_error(self, "proc serr")
check_error(self, "proc cwd")
check_error(self, "proc env")
check_error(self, "proc abrt")
check_error(self, "proc run")
# not implemented
#check_ok(self, "if list", srv.do_IF_LIST, [])
#check_ok(self, "if list 1", srv.do_IF_LIST, [1])
check_error(self, "if list")
check_error(self, "proc poll") # missing arg
check_error(self, "proc poll 1 2") # too many args
check_error(self, "proc poll a") # invalid type
check_ok(self, "proc crte 0 0 /bin/sh", srv.do_PROC_CRTE,
[0, 0, '/bin/sh'])
# Commands that would fail, but the parsing is correct
check_ok(self, "proc poll 0", None, [0])
check_ok(self, "proc wait 0", None, [0])
check_ok(self, "proc kill 0", None, [0])
check_error(self, "proc crte 0 0 =") # empty b64
check_error(self, "proc crte 0 0 =a") # invalid b64
# simulate proc mode
srv.commands = netns.protocol._proc_commands
check_error(self, "proc crte 0 0 foo")
check_error(self, "proc poll 0")
check_error(self, "proc wait 0")
check_error(self, "proc kill 0")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment