Commit 1a81ac39 authored by Martín Ferrari's avatar Martín Ferrari

- Separate addresses from interfaces: less trouble

- Make Server and Client take two unidirectional fd's, so stdio can be used.
- Implement IF LIST.
parent 35df66e0
...@@ -19,8 +19,7 @@ class interface(object): ...@@ -19,8 +19,7 @@ class interface(object):
multicast = "MULTICAST" in flags) multicast = "MULTICAST" in flags)
def __init__(self, index = None, name = None, up = None, mtu = None, def __init__(self, index = None, name = None, up = None, mtu = None,
lladdr = None, broadcast = None, multicast = None, arp = None, lladdr = None, broadcast = None, multicast = None, arp = None):
addresses = None):
self.index = int(index) if index else None self.index = int(index) if index else None
self.name = name self.name = name
self.up = up self.up = up
...@@ -29,30 +28,15 @@ class interface(object): ...@@ -29,30 +28,15 @@ class interface(object):
self.broadcast = broadcast self.broadcast = broadcast
self.multicast = multicast self.multicast = multicast
self.arp = arp self.arp = arp
self.addresses = addresses if addresses else []
def _set_addresses(self, value):
if value == None:
self._addresses = None
return
assert len(value) == len(set(value))
self._addresses = list(value)
def _get_addresses(self):
if self._addresses != None:
return list(self._addresses) # Copy, to make this inmutable
addresses = property(_get_addresses, _set_addresses)
def __repr__(self): def __repr__(self):
s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, " s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
s += "broadcast = %s, multicast = %s, arp = %s, addresses = %s)" s += "broadcast = %s, multicast = %s, arp = %s)"
return s % (self.__module__, self.__class__.__name__, return s % (self.__module__, self.__class__.__name__,
self.index.__repr__(), self.name.__repr__(), self.index.__repr__(), self.name.__repr__(),
self.up.__repr__(), self.mtu.__repr__(), self.up.__repr__(), self.mtu.__repr__(),
self.lladdr.__repr__(), self.broadcast.__repr__(), self.lladdr.__repr__(), self.broadcast.__repr__(),
self.multicast.__repr__(), self.arp.__repr__(), self.multicast.__repr__(), self.arp.__repr__())
self.addresses.__repr__())
def __sub__(self, o): def __sub__(self, o):
"""Compare attributes and return a new object with just the attributes """Compare attributes and return a new object with just the attributes
...@@ -65,18 +49,10 @@ class interface(object): ...@@ -65,18 +49,10 @@ class interface(object):
broadcast = None if self.broadcast == o.broadcast else self.broadcast broadcast = None if self.broadcast == o.broadcast else self.broadcast
multicast = None if self.multicast == o.multicast else self.multicast multicast = None if self.multicast == o.multicast else self.multicast
arp = None if self.arp == o.arp else self.arp arp = None if self.arp == o.arp else self.arp
addresses = None if self.addresses == o.addresses else self.addresses
return self.__class__(self.index, name, up, mtu, lladdr, broadcast, return self.__class__(self.index, name, up, mtu, lladdr, broadcast,
multicast, arp, addresses) multicast, arp)
class address(object): class address(object):
@property
def address(self): return self._address
@property
def prefix_len(self): return self._prefix_len
@property
def family(self): return self._family
@classmethod @classmethod
def parse_ip(cls, line): def parse_ip(cls, line):
match = re.search(r'^inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line) match = re.search(r'^inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line)
...@@ -93,16 +69,27 @@ class address(object): ...@@ -93,16 +69,27 @@ class address(object):
prefix_len = match.group(2)) prefix_len = match.group(2))
raise RuntimeError("Problems parsing ip command output") raise RuntimeError("Problems parsing ip command output")
def __eq__(self, o):
if not isinstance(o, address):
return False
return (self.family == o.family and self.address == o.address and
self.prefix_len == o.prefix_len and
self.broadcast == o.broadcast)
def __hash__(self):
h = (self.address.__hash__() ^ self.prefix_len.__hash__() ^
self.family.__hash__())
if hasattr(self, 'broadcast'):
h ^= self.broadcast.__hash__()
return h
class ipv4address(address): class ipv4address(address):
def __init__(self, address, prefix_len, broadcast): def __init__(self, address, prefix_len, broadcast):
self._address = address self.address = address
self._prefix_len = int(prefix_len) self.prefix_len = int(prefix_len)
self._broadcast = broadcast self.broadcast = broadcast
self._family = socket.AF_INET self.family = socket.AF_INET
@property
def broadcast(self): return self._broadcast
def __repr__(self): def __repr__(self):
s = "%s.%s(address = %s, prefix_len = %d, broadcast = %s)" s = "%s.%s(address = %s, prefix_len = %d, broadcast = %s)"
...@@ -110,70 +97,64 @@ class ipv4address(address): ...@@ -110,70 +97,64 @@ class ipv4address(address):
self.address.__repr__(), self.prefix_len, self.address.__repr__(), self.prefix_len,
self.broadcast.__repr__()) self.broadcast.__repr__())
def __eq__(self, o):
if not isinstance(o, address):
return False
return (self.address == o.address and
self.prefix_len == o.prefix_len and
self.broadcast == o.broadcast)
def __hash__(self):
return (self._address.__hash__() ^ self._prefix_len.__hash__() ^
self._family.__hash__()) ^ self._broadcast.__hash__()
class ipv6address(address): class ipv6address(address):
def __init__(self, address, prefix_len): def __init__(self, address, prefix_len):
self._address = address self.address = address
self._prefix_len = int(prefix_len) self.prefix_len = int(prefix_len)
self._family = socket.AF_INET6 self.family = socket.AF_INET6
def __repr__(self): def __repr__(self):
s = "%s.%s(address = %s, prefix_len = %d)" s = "%s.%s(address = %s, prefix_len = %d)"
return s % (self.__module__, self.__class__.__name__, return s % (self.__module__, self.__class__.__name__,
self.address.__repr__(), self.prefix_len) self.address.__repr__(), self.prefix_len)
def __eq__(self, o):
if not isinstance(o, address):
return False
return (self.address == o.address and self.prefix_len == o.prefix_len)
def __hash__(self):
return (self._address.__hash__() ^ self._prefix_len.__hash__() ^
self._family.__hash__())
# XXX: ideally this should be replaced by netlink communication # XXX: ideally this should be replaced by netlink communication
def get_if_data(): def get_if_data():
"""Gets current interface and addresses information. Returns a tuple """Gets current interface information. Returns a tuple (byidx, bynam) in
(byidx, bynam) in which each element is a dictionary with the same data, which each element is a dictionary with the same data, but using different
but using different keys: interface indexes and interface names. keys: interface indexes and interface names.
In each dictionary, values are interface objects. In each dictionary, values are interface objects.
""" """
ipcmd = subprocess.Popen(["ip", "-o", "addr", "list"], ipcmd = subprocess.Popen(["ip", "-o", "link", "list"],
stdout = subprocess.PIPE) stdout = subprocess.PIPE)
ipdata = ipcmd.communicate()[0] ipdata = ipcmd.communicate()[0]
assert ipcmd.wait() == 0 assert ipcmd.wait() == 0
curidx = None
byidx = {} byidx = {}
bynam = {} bynam = {}
for line in ipdata.split("\n"): for line in ipdata.split("\n"):
if line == "": if line == "":
continue continue
match = re.search(r'^(\d+):\s+(.*)', line) match = re.search(r'^(\d+):\s+(.*)', line)
if curidx != int(match.group(1)): idx = int(match.group(1))
curidx = int(match.group(1)) i = interface.parse_ip(line)
i = interface.parse_ip(line) byidx[idx] = bynam[i.name] = i
byidx[curidx] = bynam[i.name] = i return byidx, bynam
continue
# Assume curidx is defined def get_addr_data():
assert curidx != None ipcmd = subprocess.Popen(["ip", "-o", "addr", "list"],
stdout = subprocess.PIPE)
ipdata = ipcmd.communicate()[0]
assert ipcmd.wait() == 0
match = re.search(("^%s: %s" % (curidx, byidx[curidx].name)) + byidx = {}
r'\s+(.*)$', line) bynam = {}
line = match.group(1) for line in ipdata.split("\n"):
byidx[curidx].addresses += [address.parse_ip(line)] if line == "":
continue
match = re.search(r'^(\d+):\s+(\S+?)(:?)\s+(.*)', line)
if not match:
raise RuntimeError("Invalid `ip' command output")
idx = int(match.group(1))
name = match.group(2)
if match.group(3):
continue # link info
if name not in bynam:
assert idx not in byidx
bynam[name] = byidx[idx] = []
bynam[name].append(address.parse_ip(match.group(4)))
return byidx, bynam return byidx, bynam
def create_if_pair(if1, if2): def create_if_pair(if1, if2):
......
...@@ -21,7 +21,7 @@ class Node(object): ...@@ -21,7 +21,7 @@ class Node(object):
communication protocol are printed on stderr.""" communication protocol are printed on stderr."""
fd, pid = _start_child(debug, nonetns) fd, pid = _start_child(debug, nonetns)
self._pid = pid self._pid = pid
self._slave = netns.protocol.Client(fd, debug) self._slave = netns.protocol.Client(fd, fd, debug)
self._processes = weakref.WeakValueDictionary() self._processes = weakref.WeakValueDictionary()
Node._nodes[Node._nextnode] = self Node._nodes[Node._nextnode] = self
Node._nextnode += 1 Node._nextnode += 1
...@@ -77,7 +77,7 @@ def _start_child(debug, nonetns): ...@@ -77,7 +77,7 @@ def _start_child(debug, nonetns):
# FIXME: clean up signal handers, atexit functions, etc. # FIXME: clean up signal handers, atexit functions, etc.
try: try:
s0.close() s0.close()
srv = netns.protocol.Server(s1, debug) srv = netns.protocol.Server(s1, s1, debug)
if not nonetns: if not nonetns:
unshare.unshare(unshare.CLONE_NEWNET) unshare.unshare(unshare.CLONE_NEWNET)
srv.run() srv.run()
......
...@@ -7,7 +7,7 @@ try: ...@@ -7,7 +7,7 @@ try:
except ImportError: except ImportError:
from yaml import Loader, Dumper from yaml import Loader, Dumper
import base64, os, passfd, re, signal, socket, sys, traceback, unshare, yaml import base64, os, passfd, re, signal, socket, sys, traceback, unshare, yaml
import netns.subprocess_ import netns.subprocess_, netns.iproute
# ============================================================================ # ============================================================================
# Server-side protocol implementation # Server-side protocol implementation
...@@ -65,7 +65,7 @@ class Server(object): ...@@ -65,7 +65,7 @@ class Server(object):
"""Class that implements the communication protocol and dispatches calls to """Class that implements the communication protocol and dispatches calls to
the required functions. Also works as the main loop for the slave the required functions. Also works as the main loop for the slave
process.""" process."""
def __init__(self, fd, debug = False): def __init__(self, rfd, wfd, debug = False):
# Dictionary of valid commands # Dictionary of valid commands
self.commands = _proto_commands self.commands = _proto_commands
# Flag to stop the server # Flag to stop the server
...@@ -77,16 +77,8 @@ class Server(object): ...@@ -77,16 +77,8 @@ class Server(object):
# Buffer and flag for PROC mode # Buffer and flag for PROC mode
self._proc = None self._proc = None
if hasattr(fd, "readline"): self._rfd = _get_file(rfd, "r")
self._fd = fd self._wfd = _get_file(wfd, "w")
else:
# Since openfd insists on closing the fd on destruction, I need to
# dup()
if hasattr(fd, "fileno"):
nfd = os.dup(fd.fileno())
else:
nfd = os.dup(fd)
self._fd = os.fdopen(nfd, "r+", 1)
def reply(self, code, text): def reply(self, code, text):
"Send back a reply to the client; handle multiline messages" "Send back a reply to the client; handle multiline messages"
...@@ -98,19 +90,19 @@ class Server(object): ...@@ -98,19 +90,19 @@ class Server(object):
clean.extend(i.splitlines()) clean.extend(i.splitlines())
for i in range(len(clean) - 1): for i in range(len(clean) - 1):
s = str(code) + "-" + clean[i] + "\n" s = str(code) + "-" + clean[i] + "\n"
self._fd.write(s) self._wfd.write(s)
if self.debug: if self.debug:
sys.stderr.write("<ans> %s" % s) sys.stderr.write("<ans> %s" % s)
s = str(code) + " " + clean[-1] + "\n" s = str(code) + " " + clean[-1] + "\n"
self._fd.write(s) self._wfd.write(s)
if self.debug: if self.debug:
sys.stderr.write("<ans> %s" % s) sys.stderr.write("<ans> %s" % s)
return return
def readline(self): def readline(self):
"Read a line from the socket and detect connection break-up." "Read a line from the socket and detect connection break-up."
line = self._fd.readline() line = self._rfd.readline()
if not line: if not line:
self.closed = True self.closed = True
return None return None
...@@ -124,7 +116,7 @@ class Server(object): ...@@ -124,7 +116,7 @@ class Server(object):
res = "" res = ""
while True: while True:
line = self._fd.readline() line = self._rfd.readline()
if not line: if not line:
self.closed = True self.closed = True
return None return None
...@@ -228,7 +220,8 @@ class Server(object): ...@@ -228,7 +220,8 @@ class Server(object):
continue continue
cmd[0](cmd[1], *cmd[2]) cmd[0](cmd[1], *cmd[2])
try: try:
self._fd.close() self._rfd.close()
self._wfd.close()
except: except:
pass pass
# FIXME: cleanup # FIXME: cleanup
...@@ -279,7 +272,7 @@ class Server(object): ...@@ -279,7 +272,7 @@ class Server(object):
cmdname) cmdname)
try: try:
fd, payload = passfd.recvfd(self._fd, len(cmdname) + 1) fd, payload = passfd.recvfd(self._rfd, len(cmdname) + 1)
except (IOError, BaseException), e: # FIXME except (IOError, BaseException), e: # FIXME
self.reply(500, "Error receiving FD: %s" % str(e)) self.reply(500, "Error receiving FD: %s" % str(e))
return return
...@@ -297,7 +290,7 @@ class Server(object): ...@@ -297,7 +290,7 @@ class Server(object):
def do_PROC_RUN(self, cmdname): def do_PROC_RUN(self, cmdname):
try: try:
# self._proc['close_fds'] = True # forced self._proc['close_fds'] = True # forced
chld = netns.subprocess_.spawn(**self._proc) chld = netns.subprocess_.spawn(**self._proc)
except: except:
(t, v, tb) = sys.exc_info() (t, v, tb) = sys.exc_info()
...@@ -353,7 +346,13 @@ class Server(object): ...@@ -353,7 +346,13 @@ class Server(object):
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
self.reply(200, "Process signalled.") self.reply(200, "Process signalled.")
# def do_IF_LIST(self, cmdname, ifnr = None): def do_IF_LIST(self, cmdname, ifnr = None):
ifdata = netns.iproute.get_if_data()[0]
if ifnr != None:
ifdata = ifdata[ifnr]
self.reply(200, ["# Interface data follows."] +
yaml.dump(ifdata).split("\n"))
# def do_IF_SET(self, cmdname, ifnr, key, val): # def do_IF_SET(self, cmdname, ifnr, key, val):
# def do_IF_RTRN(self, cmdname, ifnr, netns): # def do_IF_RTRN(self, cmdname, ifnr, netns):
# def do_ADDR_LIST(self, cmdname, ifnr = None): # def do_ADDR_LIST(self, cmdname, ifnr = None):
...@@ -370,32 +369,22 @@ class Server(object): ...@@ -370,32 +369,22 @@ class Server(object):
class Client(object): class Client(object):
"""Client-side implementation of the communication protocol. Acts as a RPC """Client-side implementation of the communication protocol. Acts as a RPC
service.""" service."""
def __init__(self, fd, debug = False): def __init__(self, rfd, wfd, debug = False):
# XXX: In some cases we do not call dup(); maybe this should be self._rfd = _get_file(rfd, "r")
# consistent? self._wfd = _get_file(wfd, "w")
if not hasattr(fd, "readline"):
# Since openfd insists on closing the fd on destruction, I need to
# dup()
if hasattr(fd, "fileno"):
nfd = os.dup(fd.fileno())
else:
nfd = os.dup(fd)
fd = os.fdopen(nfd, "r+", 1)
self._fd = fd
# Wait for slave to send banner # Wait for slave to send banner
self._read_and_check_reply() self._read_and_check_reply()
def _send_cmd(self, *args): def _send_cmd(self, *args):
s = " ".join(map(str, args)) + "\n" s = " ".join(map(str, args)) + "\n"
self._fd.write(s) self._wfd.write(s)
def _read_reply(self): def _read_reply(self):
"""Reads a (possibly multi-line) response from the server. Returns a """Reads a (possibly multi-line) response from the server. Returns a
tuple containing (code, text)""" tuple containing (code, text)"""
text = [] text = []
while True: while True:
line = self._fd.readline().rstrip() line = self._rfd.readline().rstrip()
if not line: if not line:
raise RuntimeError("Protocol error, empty line received") raise RuntimeError("Protocol error, empty line received")
...@@ -428,10 +417,10 @@ class Client(object): ...@@ -428,10 +417,10 @@ class Client(object):
self._send_cmd("PROC", name) self._send_cmd("PROC", name)
self._read_and_check_reply(3) self._read_and_check_reply(3)
try: try:
passfd.sendfd(self._fd, fd, "PROC " + name) passfd.sendfd(self._wfd, fd, "PROC " + name)
except: except:
# need to fill the buffer on the other side, nevertheless # need to fill the buffer on the other side, nevertheless
self._fd.write("=" * (len(name) + 5)) self._wfd.write("=" * (len(name) + 5) + "\n")
# And also read the expected error # And also read the expected error
self._read_and_check_reply(5) self._read_and_check_reply(5)
raise raise
...@@ -526,3 +515,14 @@ def _b64(text): ...@@ -526,3 +515,14 @@ def _b64(text):
else: else:
return text return text
def _get_file(fd, mode):
# XXX: In some cases we do not call dup(); maybe this should be consistent?
if hasattr(fd, "readline"):
return fd
# Since openfd insists on closing the fd on destruction, I need to dup()
if hasattr(fd, "fileno"):
nfd = os.dup(fd.fileno())
else:
nfd = os.dup(fd)
return os.fdopen(nfd, mode, 1)
...@@ -13,11 +13,11 @@ class TestServer(unittest.TestCase): ...@@ -13,11 +13,11 @@ class TestServer(unittest.TestCase):
pid = os.fork() pid = os.fork()
if not pid: if not pid:
s1.close() s1.close()
srv = netns.protocol.Server(s0) srv = netns.protocol.Server(s0, s0)
srv.run() srv.run()
s3.close() s3.close()
srv = netns.protocol.Server(s2.fileno()) srv = netns.protocol.Server(s2.fileno(), s2.fileno())
srv.run() srv.run()
os._exit(0) os._exit(0)
...@@ -39,10 +39,10 @@ class TestServer(unittest.TestCase): ...@@ -39,10 +39,10 @@ class TestServer(unittest.TestCase):
pid = os.fork() pid = os.fork()
if not pid: if not pid:
s1.close() s1.close()
srv = netns.protocol.Server(s0) srv = netns.protocol.Server(s0, s0)
srv.run() srv.run()
os._exit(0) os._exit(0)
cli = netns.protocol.Client(s1) cli = netns.protocol.Client(s1, s1)
s0.close() s0.close()
# make PROC SIN fail # make PROC SIN fail
...@@ -57,7 +57,7 @@ class TestServer(unittest.TestCase): ...@@ -57,7 +57,7 @@ class TestServer(unittest.TestCase):
def test_basic_stuff(self): def test_basic_stuff(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
srv = netns.protocol.Server(s0) srv = netns.protocol.Server(s0, s0)
s1 = s1.makefile("r+", 1) s1 = s1.makefile("r+", 1)
def check_error(self, cmd, code = 500): def check_error(self, cmd, code = 500):
...@@ -87,10 +87,8 @@ class TestServer(unittest.TestCase): ...@@ -87,10 +87,8 @@ class TestServer(unittest.TestCase):
check_error(self, "proc abrt") check_error(self, "proc abrt")
check_error(self, "proc run") check_error(self, "proc run")
# not implemented check_ok(self, "if list", srv.do_IF_LIST, [])
#check_ok(self, "if list", srv.do_IF_LIST, []) check_ok(self, "if list 1", srv.do_IF_LIST, [1])
#check_ok(self, "if list 1", srv.do_IF_LIST, [1])
check_error(self, "if list")
check_error(self, "proc poll") # missing arg check_error(self, "proc poll") # missing arg
check_error(self, "proc poll 1 2") # too many args check_error(self, "proc poll 1 2") # too many args
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment