Commit 5ea9a3ef authored by Martín Ferrari's avatar Martín Ferrari

Node class: Keep track of processes, make unshare optional.

parent 8b0eca1e
...@@ -12,16 +12,17 @@ class Node(object): ...@@ -12,16 +12,17 @@ class Node(object):
s = sorted(Node._nodes.items(), key = lambda x: x[0]) s = sorted(Node._nodes.items(), key = lambda x: x[0])
return [ x[1] for x in s ] return [ x[1] for x in s ]
def __init__(self, debug = False): def __init__(self, debug = False, nonetns = False):
"""Create a new node in the emulation. Implemented as a separate """Create a new node in the emulation. Implemented as a separate
process in a new network name space. Requires root privileges to run. process in a new network name space. Requires root privileges to run.
If keepns is true, the network name space is not created and can be run If keepns is true, the network name space is not created and can be run
as a normal user, for testing. If debug is true, details of the as a normal user, for testing. If debug is true, details of the
communication protocol are printed on stderr.""" communication protocol are printed on stderr."""
fd, pid = _start_child(debug) fd, pid = _start_child(debug, nonetns)
self._pid = pid self._pid = pid
self._slave = netns.protocol.Client(fd, debug) self._slave = netns.protocol.Client(fd, debug)
self._processes = weakref.WeakValueDictionary()
Node._nodes[Node._nextnode] = self Node._nodes[Node._nextnode] = self
Node._nextnode += 1 Node._nextnode += 1
...@@ -29,10 +30,15 @@ class Node(object): ...@@ -29,10 +30,15 @@ class Node(object):
self.shutdown() self.shutdown()
def shutdown(self): def shutdown(self):
self._slave.shutdown() for p in self._processes:
p.destroy()
del self._processes
del self._pid del self._pid
self._slave.shutdown()
del self._slave del self._slave
def _add_subprocess(self, subprocess):
self._processes[subprocess.pid] = subprocess
@property @property
def pid(self): def pid(self):
return self._pid return self._pid
...@@ -43,7 +49,7 @@ class Node(object): ...@@ -43,7 +49,7 @@ class Node(object):
def add_default_route(self, nexthop, interface = None): def add_default_route(self, nexthop, interface = None):
return self.add_route('0.0.0.0', 0, nexthop, interface) return self.add_route('0.0.0.0', 0, nexthop, interface)
def start_process(self, args): def start_process(self, args):
return Process() return netns.subprocess.start_process(node, args)
def run_process(self, args): def run_process(self, args):
return ("", "") return ("", "")
def get_routes(self): def get_routes(self):
...@@ -60,15 +66,10 @@ class Interface(object): ...@@ -60,15 +66,10 @@ class Interface(object):
def add_v6_address(self, address, prefix_len): def add_v6_address(self, address, prefix_len):
pass pass
class Process(object):
def __init__(self):
self.pid = os.getpid()
self.valid = True
# Handle the creation of the child; parent gets (fd, pid), child creates and # Handle the creation of the child; parent gets (fd, pid), child creates and
# runs a Server(); never returns. # runs a Server(); never returns.
# Requires CAP_SYS_ADMIN privileges to run. # Requires CAP_SYS_ADMIN privileges to run.
def _start_child(debug = False): def _start_child(debug, nonetns):
# Create socket pair to communicate # Create socket pair to communicate
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
# Spawn a child that will run in a loop # Spawn a child that will run in a loop
...@@ -81,6 +82,7 @@ def _start_child(debug = False): ...@@ -81,6 +82,7 @@ def _start_child(debug = False):
try: try:
s0.close() s0.close()
srv = netns.protocol.Server(s1, debug) srv = netns.protocol.Server(s1, debug)
if not nonetns:
unshare.unshare(unshare.CLONE_NEWNET) unshare.unshare(unshare.CLONE_NEWNET)
srv.run() srv.run()
except BaseException, e: except BaseException, e:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment