Commit 033ce168 authored by Tom Niget's avatar Tom Niget

Add type hints

parent 7f17df26
......@@ -140,7 +140,7 @@ def main():
if not r:
break
out += r
if srv.poll() != None or clt.poll() != None:
if srv.poll() is not None or clt.poll() is not None:
break
if srv.poll():
......
......@@ -15,6 +15,6 @@ setup(
license = 'GPLv2',
platforms = 'Linux',
packages = ['nemu'],
install_requires = ['unshare', 'six'],
install_requires = ['unshare', 'six', 'attrs'],
package_dir = {'': 'src'}
)
......@@ -41,7 +41,7 @@ class _Config(object):
except KeyError:
pass # User not found.
def _set_run_as(self, user):
def _set_run_as(self, user: str | int):
"""Setter for `run_as'."""
if str(user).isdigit():
uid = int(user)
......@@ -61,7 +61,7 @@ class _Config(object):
self._run_as = run_as
return run_as
def _get_run_as(self):
def _get_run_as(self) -> str:
"""Setter for `run_as'."""
return self._run_as
......
......@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]:
os.set_inheritable(b, True)
return a, b
def socket(*args, **kwargs) -> pysocket.socket:
s = pysocket.socket(*args, **kwargs)
s.set_inheritable(True)
return s
def socketpair(*args, **kwargs) -> tuple[pysocket.socket, pysocket.socket]:
a, b = pysocket.socketpair(*args, **kwargs)
a.set_inheritable(True)
b.set_inheritable(True)
return a, b
def fromfd(*args, **kwargs) -> pysocket.socket:
s = pysocket.fromfd(*args, **kwargs)
s.set_inheritable(True)
return s
def fdopen(*args, **kwargs) -> pysocket.socket:
s = os.fdopen(*args, **kwargs)
s.set_inheritable(True)
return s
\ No newline at end of file
......@@ -25,7 +25,7 @@ import subprocess
import sys
import syslog
from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
from typing import TypeVar, Callable
from typing import TypeVar, Callable, Optional
__all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"]
......@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"]
__all__ += ["error", "warning", "notice", "info", "debug"]
def find_bin(name, extra_path=None):
def find_bin(name: str, extra_path: Optional[list[str]] = None) -> Optional[str]:
"""Try hard to find the location of needed programs."""
search = []
if "PATH" in os.environ:
......@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None):
return None
def find_bin_or_die(name, extra_path=None):
def find_bin_or_die(name: str, extra_path: Optional[list[str]] = None) -> str:
"""Try hard to find the location of needed programs; raise on failure."""
res = find_bin(name, extra_path)
if not res:
......@@ -156,7 +156,7 @@ _log_syslog_opts = ()
_log_pid = os.getpid()
def set_log_level(level):
def set_log_level(level: int):
"Sets the log level for console messages, does not affect syslog logging."
global _log_level
assert level > LOG_ERR and level <= LOG_DEBUG
......@@ -191,7 +191,7 @@ def _init_log():
info("Syslog logging started")
def logger(priority, message):
def logger(priority: int, message: str):
"Print a log message in syslog, console or both."
if _log_use_syslog:
if os.getpid() != _log_pid:
......
......@@ -19,33 +19,36 @@
import os
import weakref
from typing import TypedDict
import nemu.iproute
from nemu.environ import *
__all__ = ['NodeInterface', 'P2PInterface', 'ImportedInterface',
'ImportedNodeInterface', 'Switch']
'ImportedNodeInterface', 'Switch']
class Interface(object):
"""Just a base class for the *Interface classes: assign names and handle
destruction."""
_nextid = 0
@staticmethod
def _gen_next_id():
def _gen_next_id() -> int:
n = Interface._nextid
Interface._nextid += 1
return n
@staticmethod
def _gen_if_name():
def _gen_if_name() -> str:
n = Interface._gen_next_id()
# Max 15 chars
return "NETNSif-%.4x%.3x" % (os.getpid() & 0xffff, n)
def __init__(self, index):
def __init__(self, index: int):
self._idx = index
debug("%s(0x%x).__init__(), index = %d" % (self.__class__.__name__,
id(self), index))
id(self), index))
def __del__(self):
debug("%s(0x%x).__del__()" % (self.__class__.__name__, id(self)))
......@@ -55,7 +58,7 @@ class Interface(object):
raise NotImplementedError
@property
def index(self):
def index(self) -> int:
"""Interface index as seen by the kernel."""
return self._idx
......@@ -65,20 +68,35 @@ class Interface(object):
control interfaces can be put into a Switch, for example."""
return None
class Ipv4Dict(TypedDict):
address: str
prefix_len: int
broadcast: str
family: str
class Ipv6Dict(TypedDict):
address: str
prefix_len: int
family: str
class NSInterface(Interface):
"""Add user-facing methods for interfaces that go into a netns."""
def __init__(self, node, index):
def __init__(self, node: "nemu.Node", index):
super(NSInterface, self).__init__(index)
self._slave = node._slave
# Disable auto-configuration
# you wish: need to take into account the nonetns mode; plus not
# touching some pre-existing ifaces
#node.system([SYSCTL_PATH, '-w', 'net.ipv6.conf.%s.autoconf=0' %
#self.name])
# node.system([SYSCTL_PATH, '-w', 'net.ipv6.conf.%s.autoconf=0' %
# self.name])
node._add_interface(self)
# some black magic to automatically get/set interface attributes
def __getattr__(self, name):
def __getattr__(self, name: str):
# If name starts with _, it must be a normal attr
if name[0] == '_':
return super(Interface, self).__getattribute__(name)
......@@ -92,57 +110,59 @@ class NSInterface(Interface):
iface = slave.get_if_data(self.index)
return getattr(iface, name)
def __setattr__(self, name, value):
if name[0] == '_': # forbid anything that doesn't start with a _
def __setattr__(self, name: str, value):
if name[0] == '_': # forbid anything that doesn't start with a _
super(Interface, self).__setattr__(name, value)
return
iface = nemu.iproute.interface(index = self.index)
iface = nemu.iproute.interface(index=self.index)
setattr(iface, name, value)
return self._slave.set_if(iface)
def add_v4_address(self, address, prefix_len, broadcast = None):
def add_v4_address(self, address: str, prefix_len: int, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
self._slave.add_addr(self.index, addr)
def add_v6_address(self, address, prefix_len):
def add_v6_address(self, address: str, prefix_len: int):
addr = nemu.iproute.ipv6address(address, prefix_len)
self._slave.add_addr(self.index, addr)
def del_v4_address(self, address, prefix_len, broadcast = None):
def del_v4_address(self, address: str, prefix_len: int, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
self._slave.del_addr(self.index, addr)
def del_v6_address(self, address, prefix_len):
def del_v6_address(self, address: str, prefix_len: int):
addr = nemu.iproute.ipv6address(address, prefix_len)
self._slave.del_addr(self.index, addr)
def get_addresses(self):
def get_addresses(self) -> list[Ipv4Dict | Ipv6Dict]:
addresses = self._slave.get_addr_data(self.index)
ret = []
for a in addresses:
if hasattr(a, 'broadcast'):
ret.append(dict(
address = a.address,
prefix_len = a.prefix_len,
broadcast = a.broadcast,
family = 'inet'))
address=a.address,
prefix_len=a.prefix_len,
broadcast=a.broadcast,
family='inet'))
else:
ret.append(dict(
address = a.address,
prefix_len = a.prefix_len,
family = 'inet6'))
address=a.address,
prefix_len=a.prefix_len,
family='inet6'))
return ret
class NodeInterface(NSInterface):
"""Class to create and handle a virtual interface inside a name space, it
can be connected to a Switch object with emulation of link
characteristics."""
def __init__(self, node):
def __init__(self, node: "nemu.Node"):
"""Create a new interface. `node' is the name space in which this
interface should be put."""
self._slave = None
if1 = nemu.iproute.interface(name = self._gen_if_name())
if2 = nemu.iproute.interface(name = self._gen_if_name())
if1 = nemu.iproute.interface(name=self._gen_if_name())
if2 = nemu.iproute.interface(name=self._gen_if_name())
ctl, ns = nemu.iproute.create_if_pair(if1, if2)
try:
nemu.iproute.change_netns(ns, node.pid)
......@@ -165,18 +185,20 @@ class NodeInterface(NSInterface):
self._slave.del_if(self.index)
self._slave = None
class P2PInterface(NSInterface):
"""Class to create and handle point-to-point interfaces between name
spaces, without using Switch objects. Those do not allow any kind of
traffic shaping.
As two interfaces need to be created, instead of using the class
constructor, use the P2PInterface.create_pair() static method."""
@staticmethod
def create_pair(node1, node2):
def create_pair(node1: "nemu.Node", node2: "nemu.Node"):
"""Create and return a pair of connected P2PInterface objects,
assigned to name spaces represented by `node1' and `node2'."""
if1 = nemu.iproute.interface(name = P2PInterface._gen_if_name())
if2 = nemu.iproute.interface(name = P2PInterface._gen_if_name())
if1 = nemu.iproute.interface(name=P2PInterface._gen_if_name())
if2 = nemu.iproute.interface(name=P2PInterface._gen_if_name())
pair = nemu.iproute.create_if_pair(if1, if2)
try:
nemu.iproute.change_netns(pair[0], node1.pid)
......@@ -206,6 +228,7 @@ class P2PInterface(NSInterface):
self._slave.del_if(self.index)
self._slave = None
class ImportedNodeInterface(NSInterface):
"""Class to handle already existing interfaces inside a name space:
real devices, tun devices, etc.
......@@ -213,7 +236,8 @@ class ImportedNodeInterface(NSInterface):
to be moved inside the name space.
On destruction, the interface will be restored to the original name space
and will try to restore the original state."""
def __init__(self, node, iface, migrate = True):
def __init__(self, node: "nemu.Node", iface, migrate=True):
self._slave = None
self._migrate = migrate
if self._migrate:
......@@ -230,7 +254,7 @@ class ImportedNodeInterface(NSInterface):
super(ImportedNodeInterface, self).__init__(node, iface.index)
def destroy(self): # override: restore as much as possible
def destroy(self): # override: restore as much as possible
if not self._slave:
return
debug("ImportedNodeInterface(0x%x).destroy()" % id(self))
......@@ -244,17 +268,19 @@ class ImportedNodeInterface(NSInterface):
nemu.iproute.set_if(self._original_state)
self._slave = None
class TapNodeInterface(NSInterface):
"""Class to create a tap interface inside a name space, it
can be connected to a Switch object with emulation of link
characteristics."""
def __init__(self, node, use_pi = False):
def __init__(self, node, use_pi=False):
"""Create a new tap interface. 'node' is the name space in which this
interface should be put."""
self._fd = None
self._slave = None
iface = nemu.iproute.interface(name = self._gen_if_name())
iface, self._fd = nemu.iproute.create_tap(iface, use_pi = use_pi)
iface = nemu.iproute.interface(name=self._gen_if_name())
iface, self._fd = nemu.iproute.create_tap(iface, use_pi=use_pi)
nemu.iproute.change_netns(iface.name, node.pid)
super(TapNodeInterface, self).__init__(node, iface.index)
......@@ -271,18 +297,20 @@ class TapNodeInterface(NSInterface):
except:
pass
class TunNodeInterface(NSInterface):
"""Class to create a tun interface inside a name space, it
can be connected to a Switch object with emulation of link
characteristics."""
def __init__(self, node, use_pi = False):
def __init__(self, node, use_pi=False):
"""Create a new tap interface. 'node' is the name space in which this
interface should be put."""
self._fd = None
self._slave = None
iface = nemu.iproute.interface(name = self._gen_if_name())
iface, self._fd = nemu.iproute.create_tap(iface, use_pi = use_pi,
tun = True)
iface = nemu.iproute.interface(name=self._gen_if_name())
iface, self._fd = nemu.iproute.create_tap(iface, use_pi=use_pi,
tun=True)
nemu.iproute.change_netns(iface.name, node.pid)
super(TunNodeInterface, self).__init__(node, iface.index)
......@@ -299,9 +327,11 @@ class TunNodeInterface(NSInterface):
except:
pass
class ExternalInterface(Interface):
"""Add user-facing methods for interfaces that run in the main
namespace."""
@property
def control(self):
# This is *the* control interface
......@@ -313,14 +343,14 @@ class ExternalInterface(Interface):
return getattr(iface, name)
def __setattr__(self, name, value):
if name[0] == '_': # forbid anything that doesn't start with a _
if name[0] == '_': # forbid anything that doesn't start with a _
super(ExternalInterface, self).__setattr__(name, value)
return
iface = nemu.iproute.interface(index = self.index)
iface = nemu.iproute.interface(index=self.index)
setattr(iface, name, value)
return nemu.iproute.set_if(iface)
def add_v4_address(self, address, prefix_len, broadcast = None):
def add_v4_address(self, address, prefix_len, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
nemu.iproute.add_addr(self.index, addr)
......@@ -328,7 +358,7 @@ class ExternalInterface(Interface):
addr = nemu.iproute.ipv6address(address, prefix_len)
nemu.iproute.add_addr(self.index, addr)
def del_v4_address(self, address, prefix_len, broadcast = None):
def del_v4_address(self, address, prefix_len, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
nemu.iproute.del_addr(self.index, addr)
......@@ -342,23 +372,26 @@ class ExternalInterface(Interface):
for a in addresses:
if hasattr(a, 'broadcast'):
ret.append(dict(
address = a.address,
prefix_len = a.prefix_len,
broadcast = a.broadcast,
family = 'inet'))
address=a.address,
prefix_len=a.prefix_len,
broadcast=a.broadcast,
family='inet'))
else:
ret.append(dict(
address = a.address,
prefix_len = a.prefix_len,
family = 'inet6'))
address=a.address,
prefix_len=a.prefix_len,
family='inet6'))
return ret
class SlaveInterface(ExternalInterface):
"""Class to handle the main-name-space-facing half of NodeInterface.
Does nothing, just avoids any destroy code."""
def destroy(self):
pass
class ImportedInterface(ExternalInterface):
"""Class to handle already existing interfaces. Analogous to
ImportedNodeInterface, this class only differs in that the interface is
......@@ -366,6 +399,7 @@ class ImportedInterface(ExternalInterface):
connected to Switch objects and not assigned to a name space. On
destruction, the code will try to restore the interface to the state it
was in before being imported into nemu."""
def __init__(self, iface):
self._original_state = None
iface = nemu.iproute.get_if(iface)
......@@ -373,12 +407,13 @@ class ImportedInterface(ExternalInterface):
super(ImportedInterface, self).__init__(iface.index)
# FIXME: register somewhere for destruction!
def destroy(self): # override: restore as much as possible
def destroy(self): # override: restore as much as possible
if self._original_state:
debug("ImportedInterface(0x%x).destroy()" % id(self))
nemu.iproute.set_if(self._original_state)
self._original_state = None
# Switch is just another interface type
class Switch(ExternalInterface):
......@@ -412,7 +447,7 @@ class Switch(ExternalInterface):
return getattr(iface, name)
def __setattr__(self, name, value):
if name[0] == '_': # forbid anything that doesn't start with a _
if name[0] == '_': # forbid anything that doesn't start with a _
super(Switch, self).__setattr__(name, value)
return
# Set ports
......@@ -421,7 +456,7 @@ class Switch(ExternalInterface):
if self._check_port(i.index):
setattr(i, name, value)
# Set bridge
iface = nemu.iproute.bridge(index = self.index)
iface = nemu.iproute.bridge(index=self.index)
setattr(iface, name, value)
nemu.iproute.set_bridge(iface)
......@@ -460,7 +495,7 @@ class Switch(ExternalInterface):
return True
# else
warning("Switch(0x%x): Port (index = %d) went away." % (id(self),
port_index))
port_index))
del self._ports[port_index]
return False
......@@ -472,12 +507,12 @@ class Switch(ExternalInterface):
self._apply_parameters({}, iface.control)
del self._ports[iface.control.index]
def set_parameters(self, bandwidth = None,
delay = None, delay_jitter = None,
delay_correlation = None, delay_distribution = None,
loss = None, loss_correlation = None,
dup = None, dup_correlation = None,
corrupt = None, corrupt_correlation = None):
def set_parameters(self, bandwidth=None,
delay=None, delay_jitter=None,
delay_correlation=None, delay_distribution=None,
loss=None, loss_correlation=None,
dup=None, dup_correlation=None,
corrupt=None, corrupt_correlation=None):
"""Set the parameters that control the link characteristics. For the
description of each, refer to netem documentation:
http://www.linuxfoundation.org/collaborate/workgroups/networking/netem
......@@ -492,13 +527,13 @@ class Switch(ExternalInterface):
`dup_correlation', `corrupt', and `corrupt_correlation' take a
percentage value in the form of a number between 0 and 1. (50% is
passed as 0.5)."""
parameters = dict(bandwidth = bandwidth,
delay = delay, delay_jitter = delay_jitter,
delay_correlation = delay_correlation,
delay_distribution = delay_distribution,
loss = loss, loss_correlation = loss_correlation,
dup = dup, dup_correlation = dup_correlation,
corrupt = corrupt, corrupt_correlation = corrupt_correlation)
parameters = dict(bandwidth=bandwidth,
delay=delay, delay_jitter=delay_jitter,
delay_correlation=delay_correlation,
delay_distribution=delay_distribution,
loss=loss, loss_correlation=loss_correlation,
dup=dup, dup_correlation=dup_correlation,
corrupt=corrupt, corrupt_correlation=corrupt_correlation)
try:
self._apply_parameters(parameters)
except:
......@@ -506,7 +541,6 @@ class Switch(ExternalInterface):
raise
self._parameters = parameters
def _apply_parameters(self, parameters, port = None):
def _apply_parameters(self, parameters, port=None):
for i in [port] if port else list(self._ports.values()):
nemu.iproute.set_tc(i.index, **parameters)
......@@ -24,7 +24,10 @@ import re
import socket
import struct
import sys
from typing import TypeVar, Callable, Literal
from attr import evolve
from attrs import define, setters, field
import six
from nemu.environ import *
......@@ -46,18 +49,21 @@ def _any_to_bool(any):
return any != ""
return bool(any)
def _positive(val):
v = int(val)
if v <= 0:
raise ValueError("Invalid value: %d" % v)
return v
def _non_empty_str(val):
if val == "":
return None
else:
return str(val)
def _fix_lladdr(addr):
foo = addr.lower()
if ":" in addr:
......@@ -77,106 +83,100 @@ def _fix_lladdr(addr):
# Glue
return ":".join(m.groups())
def _make_getter(attr, conv = lambda x: x):
def _make_getter(attr, conv=lambda x: x):
def getter(self):
return conv(getattr(self, attr))
return getter
def _make_setter(attr, conv = lambda x: x):
def _make_setter(attr, conv=lambda x: x):
def setter(self, value):
if value == None:
if value is None:
setattr(self, attr, None)
else:
setattr(self, attr, conv(value))
return setter
T = TypeVar("T")
U = TypeVar("U")
def _if_any(conv: Callable[[T], U]):
def c(val: T) -> U:
if val is None:
return None
else:
return conv(val)
return c
# classes for internal use
class interface(object):
@define(repr=False)
class interface:
"""Class for internal use. It is mostly a data container used to easily
pass information around; with some convenience methods."""
# information for other parts of the code
changeable_attributes = ["name", "mtu", "lladdr", "broadcast", "up",
"multicast", "arp"]
# Index should be read-only
index = property(_make_getter("_index"))
up = property(_make_getter("_up"), _make_setter("_up", _any_to_bool))
mtu = property(_make_getter("_mtu"), _make_setter("_mtu", _positive))
lladdr = property(_make_getter("_lladdr"),
_make_setter("_lladdr", _fix_lladdr))
arp = property(_make_getter("_arp"), _make_setter("_arp", _any_to_bool))
multicast = property(_make_getter("_mc"), _make_setter("_mc", _any_to_bool))
def __init__(self, index = None, name = None, up = None, mtu = None,
lladdr = None, broadcast = None, multicast = None, arp = None):
self._index = _positive(index) if index is not None else None
self.name = name
self.up = up
self.mtu = mtu
self.lladdr = lladdr
self.broadcast = broadcast
self.multicast = multicast
self.arp = arp
"multicast", "arp"]
index: int = field(default=None, converter=_if_any(_positive), on_setattr=setters.frozen)
name: str = field(default=None)
up: bool = field(default=None, converter=_if_any(_any_to_bool))
mtu: int = field(default=None, converter=_if_any(_positive))
lladdr: str = field(default=None, converter=_if_any(_fix_lladdr))
broadcast: str = field(default=None)
multicast: bool = field(default=None, converter=_if_any(_any_to_bool))
arp: bool = field(default=None, converter=_if_any(_any_to_bool))
def __repr__(self):
s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
s += "broadcast = %s, multicast = %s, arp = %s)"
return s % (self.__module__, self.__class__.__name__,
self.index.__repr__(), self.name.__repr__(),
self.up.__repr__(), self.mtu.__repr__(),
self.lladdr.__repr__(), self.broadcast.__repr__(),
self.multicast.__repr__(), self.arp.__repr__())
self.index.__repr__(), self.name.__repr__(),
self.up.__repr__(), self.mtu.__repr__(),
self.lladdr.__repr__(), self.broadcast.__repr__(),
self.multicast.__repr__(), self.arp.__repr__())
def __sub__(self, o):
"""Compare attributes and return a new object with just the attributes
that differ set (with the value they have in the first operand). The
index remains equal to the first operand."""
name = None if self.name == o.name else self.name
up = None if self.up == o.up else self.up
mtu = None if self.mtu == o.mtu else self.mtu
lladdr = None if self.lladdr == o.lladdr else self.lladdr
broadcast = None if self.broadcast == o.broadcast else self.broadcast
multicast = None if self.multicast == o.multicast else self.multicast
arp = None if self.arp == o.arp else self.arp
return self.__class__(self.index, name, up, mtu, lladdr, broadcast,
multicast, arp)
name = None if self.name == o.name else self.name
up = None if self.up == o.up else self.up
mtu = None if self.mtu == o.mtu else self.mtu
lladdr = None if self.lladdr == o.lladdr else self.lladdr
broadcast = None if self.broadcast == o.broadcast else self.broadcast
multicast = None if self.multicast == o.multicast else self.multicast
arp = None if self.arp == o.arp else self.arp
return interface(self.index, name, up, mtu, lladdr, broadcast,
multicast, arp)
def copy(self):
return copy.copy(self)
@define(repr=False)
class bridge(interface):
changeable_attributes = interface.changeable_attributes + ["stp",
"forward_delay", "hello_time", "ageing_time", "max_age"]
# Index should be read-only
stp = property(_make_getter("_stp"), _make_setter("_stp", _any_to_bool))
forward_delay = property(_make_getter("_forward_delay"),
_make_setter("_forward_delay", float))
hello_time = property(_make_getter("_hello_time"),
_make_setter("_hello_time", float))
ageing_time = property(_make_getter("_ageing_time"),
_make_setter("_ageing_time", float))
max_age = property(_make_getter("_max_age"),
_make_setter("_max_age", float))
"forward_delay", "hello_time", "ageing_time", "max_age"]
stp: bool = field(default=None, converter=_if_any(_any_to_bool))
forward_delay: float = field(default=None, converter=_if_any(float))
hello_time: float = field(default=None, converter=_if_any(float))
ageing_time: float = field(default=None, converter=_if_any(float))
max_age: float = field(default=None, converter=_if_any(float))
@classmethod
def upgrade(cls, iface, *kargs, **kwargs):
"""Upgrade a interface to a bridge."""
return cls(iface.index, iface.name, iface.up, iface.mtu, iface.lladdr,
iface.broadcast, iface.multicast, iface.arp, *kargs, **kwargs)
def __init__(self, index = None, name = None, up = None, mtu = None,
lladdr = None, broadcast = None, multicast = None, arp = None,
stp = None, forward_delay = None, hello_time = None,
ageing_time = None, max_age = None):
super(bridge, self).__init__(index, name, up, mtu, lladdr, broadcast,
multicast, arp)
self.stp = stp
self.forward_delay = forward_delay
self.hello_time = hello_time
self.ageing_time = ageing_time
self.max_age = max_age
iface.broadcast, iface.multicast, iface.arp, *kargs, **kwargs)
def __repr__(self):
s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
......@@ -184,33 +184,40 @@ class bridge(interface):
s += "forward_delay = %s, hello_time = %s, ageing_time = %s, "
s += "max_age = %s)"
return s % (self.__module__, self.__class__.__name__,
self.index.__repr__(), self.name.__repr__(),
self.up.__repr__(), self.mtu.__repr__(),
self.lladdr.__repr__(), self.broadcast.__repr__(),
self.multicast.__repr__(), self.arp.__repr__(),
self.stp.__repr__(), self.forward_delay.__repr__(),
self.hello_time.__repr__(), self.ageing_time.__repr__(),
self.max_age.__repr__())
self.index.__repr__(), self.name.__repr__(),
self.up.__repr__(), self.mtu.__repr__(),
self.lladdr.__repr__(), self.broadcast.__repr__(),
self.multicast.__repr__(), self.arp.__repr__(),
self.stp.__repr__(), self.forward_delay.__repr__(),
self.hello_time.__repr__(), self.ageing_time.__repr__(),
self.max_age.__repr__())
def __sub__(self, o):
r = super(bridge, self).__sub__(o)
r = bridge.upgrade(super().__sub__(o))
if type(o) == interface:
return r
r.stp = None if self.stp == o.stp else self.stp
r.hello_time = None if self.hello_time == o.hello_time else \
self.hello_time
r.stp = None if self.stp == o.stp else self.stp
r.hello_time = None if self.hello_time == o.hello_time else \
self.hello_time
r.forward_delay = None if self.forward_delay == o.forward_delay else \
self.forward_delay
r.ageing_time = None if self.ageing_time == o.ageing_time else \
self.ageing_time
r.max_age = None if self.max_age == o.max_age else self.max_age
self.forward_delay
r.ageing_time = None if self.ageing_time == o.ageing_time else \
self.ageing_time
r.max_age = None if self.max_age == o.max_age else self.max_age
return r
class address(object):
"""Class for internal use. It is mostly a data container used to easily
pass information around; with some convenience methods. __eq__ and
__hash__ are defined just to be able to easily find duplicated
addresses."""
def __init__(self, address: str, prefix_len: int, family: socket.AddressFamily):
self.address = address
self.prefix_len = int(prefix_len)
self.family = family
# broadcast is not taken into account for differentiating addresses
def __eq__(self, o):
if not isinstance(o, address):
......@@ -220,52 +227,50 @@ class address(object):
def __hash__(self):
h = (self.address.__hash__() ^ self.prefix_len.__hash__() ^
self.family.__hash__())
self.family.__hash__())
return h
class ipv4address(address):
def __init__(self, address, prefix_len, broadcast):
self.address = address
self.prefix_len = int(prefix_len)
def __init__(self, address: str, prefix_len: int, broadcast):
super().__init__(address, prefix_len, socket.AF_INET)
self.broadcast = broadcast
self.family = socket.AF_INET
def __repr__(self):
s = "%s.%s(address = %s, prefix_len = %d, broadcast = %s)"
return s % (self.__module__, self.__class__.__name__,
self.address.__repr__(), self.prefix_len,
self.broadcast.__repr__())
self.address.__repr__(), self.prefix_len,
self.broadcast.__repr__())
class ipv6address(address):
def __init__(self, address, prefix_len):
self.address = address
self.prefix_len = int(prefix_len)
self.family = socket.AF_INET6
def __init__(self, address: str, prefix_len: int):
super().__init__(address, prefix_len, socket.AF_INET6)
def __repr__(self):
s = "%s.%s(address = %s, prefix_len = %d)"
return s % (self.__module__, self.__class__.__name__,
self.address.__repr__(), self.prefix_len)
self.address.__repr__(), self.prefix_len)
class route(object):
tipes = ["unicast", "local", "broadcast", "multicast", "throw",
"unreachable", "prohibit", "blackhole", "nat"]
"unreachable", "prohibit", "blackhole", "nat"]
tipe = property(_make_getter("_tipe", tipes.__getitem__),
_make_setter("_tipe", tipes.index))
_make_setter("_tipe", tipes.index))
prefix = property(_make_getter("_prefix"),
_make_setter("_prefix", _non_empty_str))
_make_setter("_prefix", _non_empty_str))
prefix_len = property(_make_getter("_plen"),
lambda s, v: setattr(s, "_plen", int(v or 0)))
lambda s, v: setattr(s, "_plen", int(v or 0)))
nexthop = property(_make_getter("_nexthop"),
_make_setter("_nexthop", _non_empty_str))
_make_setter("_nexthop", _non_empty_str))
interface = property(_make_getter("_interface"),
_make_setter("_interface", _positive))
_make_setter("_interface", _positive))
metric = property(_make_getter("_metric"),
lambda s, v: setattr(s, "_metric", int(v or 0)))
lambda s, v: setattr(s, "_metric", int(v or 0)))
def __init__(self, tipe = "unicast", prefix = None, prefix_len = 0,
nexthop = None, interface = None, metric = 0):
def __init__(self, tipe="unicast", prefix=None, prefix_len=0,
nexthop=None, interface=None, metric=0):
self.tipe = tipe
self.prefix = prefix
self.prefix_len = prefix_len
......@@ -278,9 +283,9 @@ class route(object):
s = "%s.%s(tipe = %s, prefix = %s, prefix_len = %s, nexthop = %s, "
s += "interface = %s, metric = %s)"
return s % (self.__module__, self.__class__.__name__,
self.tipe.__repr__(), self.prefix.__repr__(),
self.prefix_len.__repr__(), self.nexthop.__repr__(),
self.interface.__repr__(), self.metric.__repr__())
self.tipe.__repr__(), self.prefix.__repr__(),
self.prefix_len.__repr__(), self.nexthop.__repr__(),
self.interface.__repr__(), self.metric.__repr__())
def __eq__(self, o):
if not isinstance(o, route):
......@@ -289,20 +294,22 @@ class route(object):
self.prefix_len == o.prefix_len and self.nexthop == o.nexthop
and self.interface == o.interface and self.metric == o.metric)
# helpers
def _get_if_name(iface):
def _get_if_name(iface: interface | int | str):
if isinstance(iface, interface):
if iface.name != None:
if iface.name is not None:
return iface.name
if isinstance(iface, str):
return iface
return get_if(iface).name
# XXX: ideally this should be replaced by netlink communication
# Interface handling
# FIXME: try to lower the amount of calls to retrieve data!!
def get_if_data():
def get_if_data() -> tuple[dict[int, interface], dict[str, interface]]:
"""Gets current interface information. Returns a tuple (byidx, bynam) in
which each element is a dictionary with the same data, but using different
keys: interface indexes and interface names.
......@@ -323,21 +330,22 @@ def get_if_data():
r'brd ([0-9a-f:]+))?', line)
flags = match.group(3).split(",")
i = interface(
index = match.group(1),
name = match.group(2),
up = "UP" in flags,
mtu = match.group(4),
lladdr = match.group(5),
arp = not ("NOARP" in flags),
broadcast = match.group(6),
multicast = "MULTICAST" in flags)
index=match.group(1),
name=match.group(2),
up="UP" in flags,
mtu=match.group(4),
lladdr=match.group(5),
arp=not ("NOARP" in flags),
broadcast=match.group(6),
multicast="MULTICAST" in flags)
byidx[idx] = bynam[i.name] = i
return byidx, bynam
def get_if(iface):
def get_if(iface: interface | int | str) -> interface:
ifdata = get_if_data()
if isinstance(iface, interface):
if iface.index != None:
if iface.index is not None:
return ifdata[0][iface.index]
else:
return ifdata[1][iface.name]
......@@ -345,7 +353,8 @@ def get_if(iface):
return ifdata[0][iface]
return ifdata[1][iface]
def create_if_pair(if1, if2):
def create_if_pair(if1: interface, if2: interface) -> tuple[interface, interface]:
assert if1.name and if2.name
cmd = [[], []]
......@@ -375,22 +384,24 @@ def create_if_pair(if1, if2):
interfaces = get_if_data()[1]
return interfaces[if1.name], interfaces[if2.name]
def del_if(iface):
ifname = _get_if_name(iface)
execute([IP_PATH, "link", "del", ifname])
def set_if(iface, recover = True):
def do_cmds(cmds, orig_iface):
def set_if(iface: interface, recover=True):
def do_cmds(cmds: list[list[str]], orig_iface: interface):
for c in cmds:
try:
execute(c)
except:
if recover:
set_if(orig_iface, recover = False) # rollback
set_if(orig_iface, recover=False) # rollback
raise
orig_iface = get_if(iface)
diff = iface - orig_iface # Only set what's needed
diff = iface - orig_iface # Only set what's needed
# Name goes first
if diff.name:
......@@ -410,26 +421,28 @@ def set_if(iface, recover = True):
# iface needs to be down
cmds.append(_ils + ["down"])
cmds.append(_ils + ["address", diff.lladdr])
if orig_iface.up and diff.up == None:
if orig_iface.up and diff.up is None:
# restore if it was up and it's not going to be set later
cmds.append(_ils + ["up"])
if diff.mtu:
cmds.append(_ils + ["mtu", str(diff.mtu)])
if diff.broadcast:
cmds.append(_ils + ["broadcast", diff.broadcast])
if diff.multicast != None:
if diff.multicast is not None:
cmds.append(_ils + ["multicast", "on" if diff.multicast else "off"])
if diff.arp != None:
if diff.arp is not None:
cmds.append(_ils + ["arp", "on" if diff.arp else "off"])
if diff.up != None:
if diff.up is not None:
cmds.append(_ils + ["up" if diff.up else "down"])
do_cmds(cmds, orig_iface)
def change_netns(iface, netns):
ifname = _get_if_name(iface)
execute([IP_PATH, "link", "set", "dev", ifname, "netns", str(netns)])
# Address handling
def get_addr_data():
......@@ -459,16 +472,16 @@ def get_addr_data():
match = re.search(r'^\s*inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line)
if match:
bynam[current].append(ipv4address(
address = match.group(1),
prefix_len = match.group(2),
broadcast = match.group(3)))
address=match.group(1),
prefix_len=match.group(2),
broadcast=match.group(3)))
continue
match = re.search(r'^\s*inet6 ([0-9a-f:]+)/(\d+)', line)
if match:
bynam[current].append(ipv6address(
address = match.group(1),
prefix_len = match.group(2)))
address=match.group(1),
prefix_len=match.group(2)))
continue
# Extra info, ignored.
......@@ -476,26 +489,29 @@ def get_addr_data():
return byidx, bynam
def add_addr(iface, address):
ifname = _get_if_name(iface)
addresses = get_addr_data()[1][ifname]
assert address not in addresses
cmd = [IP_PATH, "addr", "add", "dev", ifname, "local",
"%s/%d" % (address.address, int(address.prefix_len))]
"%s/%d" % (address.address, int(address.prefix_len))]
if hasattr(address, "broadcast"):
cmd += ["broadcast", address.broadcast if address.broadcast else "+"]
execute(cmd)
def del_addr(iface, address):
ifname = _get_if_name(iface)
addresses = get_addr_data()[1][ifname]
assert address in addresses
cmd = [IP_PATH, "addr", "del", "dev", ifname, "local",
"%s/%d" % (address.address, int(address.prefix_len))]
"%s/%d" % (address.address, int(address.prefix_len))]
execute(cmd)
# Bridge handling
def _sysfs_read_br(brname):
def readval(fname):
......@@ -509,12 +525,13 @@ def _sysfs_read_br(brname):
except:
return None
return dict(
stp = readval(p + "stp_state"),
forward_delay = float(readval(p + "forward_delay")) / 100,
hello_time = float(readval(p + "hello_time")) / 100,
ageing_time = float(readval(p + "ageing_time")) / 100,
max_age = float(readval(p + "max_age")) / 100,
ports = os.listdir(p2))
stp=readval(p + "stp_state"),
forward_delay=float(readval(p + "forward_delay")) / 100,
hello_time=float(readval(p + "hello_time")) / 100,
ageing_time=float(readval(p + "ageing_time")) / 100,
max_age=float(readval(p + "max_age")) / 100,
ports=os.listdir(p2))
def get_bridge_data():
# brctl stinks too much; it is better to directly use sysfs, it is
......@@ -525,24 +542,26 @@ def get_bridge_data():
ifdata = get_if_data()
for iface in ifdata[0].values():
brdata = _sysfs_read_br(iface.name)
if brdata == None:
if brdata is None:
continue
ports[iface.index] = [ifdata[1][x].index for x in brdata["ports"]]
del brdata["ports"]
bynam[iface.name] = byidx[iface.index] = \
bridge.upgrade(iface, **brdata)
bridge.upgrade(iface, **brdata)
return byidx, bynam, ports
def get_bridge(br):
iface = get_if(br)
brdata = _sysfs_read_br(iface.name)
#ports = [ifdata[1][x].index for x in brdata["ports"]]
# ports = [ifdata[1][x].index for x in brdata["ports"]]
del brdata["ports"]
return bridge.upgrade(iface, **brdata)
def create_bridge(br):
if isinstance(br, str):
br = interface(name = br)
br = interface(name=br)
assert br.name
execute([BRCTL_PATH, "addbr", br.name])
try:
......@@ -556,58 +575,64 @@ def create_bridge(br):
six.reraise(t, v, bt)
return get_if_data()[1][br.name]
def del_bridge(br):
brname = _get_if_name(br)
execute([BRCTL_PATH, "delbr", brname])
def set_bridge(br, recover = True):
def set_bridge(br, recover=True):
def saveval(fname, val):
f = open(fname, "w")
f.write(str(val))
f.close()
def do_cmds(basename, cmds, orig_br):
for n, v in cmds:
try:
saveval(basename + n, v)
except:
if recover:
set_bridge(orig_br, recover = False) # rollback
set_if(orig_br, recover = False) # rollback
set_bridge(orig_br, recover=False) # rollback
set_if(orig_br, recover=False) # rollback
raise
orig_br = get_bridge(br)
diff = br - orig_br # Only set what's needed
diff = br - orig_br # Only set what's needed
cmds = []
if diff.stp != None:
if diff.stp is not None:
cmds.append(("stp_state", int(diff.stp)))
if diff.forward_delay != None:
if diff.forward_delay is not None:
cmds.append(("forward_delay", int(diff.forward_delay)))
if diff.hello_time != None:
if diff.hello_time is not None:
cmds.append(("hello_time", int(diff.hello_time)))
if diff.ageing_time != None:
if diff.ageing_time is not None:
cmds.append(("ageing_time", int(diff.ageing_time)))
if diff.max_age != None:
if diff.max_age is not None:
cmds.append(("max_age", int(diff.max_age)))
set_if(diff)
name = diff.name if diff.name != None else orig_br.name
name = diff.name if diff.name is not None else orig_br.name
do_cmds("/sys/class/net/%s/bridge/" % name, cmds, orig_br)
def add_bridge_port(br, iface):
ifname = _get_if_name(iface)
brname = _get_if_name(br)
execute([BRCTL_PATH, "addif", brname, ifname])
def del_bridge_port(br, iface):
ifname = _get_if_name(iface)
brname = _get_if_name(br)
execute([BRCTL_PATH, "delif", brname, ifname])
# Routing
def get_all_route_data():
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
ipdata += backticks([IP_PATH, "-o", "-f", "inet6", "route", "list"])
ifdata = get_if_data()[1]
......@@ -616,8 +641,8 @@ def get_all_route_data():
if line == "":
continue
match = re.match(r'(?:(unicast|local|broadcast|multicast|throw|' +
r'unreachable|prohibit|blackhole|nat) )?' +
r'(\S+)(?: via (\S+))? dev (\S+).*(?: metric (\d+))?', line)
r'unreachable|prohibit|blackhole|nat) )?' +
r'(\S+)(?: via (\S+))? dev (\S+).*(?: metric (\d+))?', line)
if not match:
raise RuntimeError("Invalid output from `ip route': `%s'" % line)
tipe = match.group(1) or "unicast"
......@@ -633,26 +658,30 @@ def get_all_route_data():
prefix = match.group(1)
prefix_len = int(match.group(2) or 32)
ret.append(route(tipe, prefix, prefix_len, nexthop, interface.index,
metric))
metric))
return ret
def get_route_data():
def get_route_data() -> list[route]:
# filter out non-unicast routes
return [x for x in get_all_route_data() if x.tipe == "unicast"]
def add_route(route):
def add_route(route: route):
# Cannot really test this
#if route in get_all_route_data():
# if route in get_all_route_data():
# raise ValueError("Route already exists")
_add_del_route("add", route)
def del_route(route):
def del_route(route: route):
# Cannot really test this
#if route not in get_all_route_data():
# if route not in get_all_route_data():
# raise ValueError("Route does not exist")
_add_del_route("del", route)
def _add_del_route(action, route):
def _add_del_route(action: Literal["add", "del"], route: route):
cmd = [IP_PATH, "route", action]
if route.tipe != "unicast":
cmd += [route.tipe]
......@@ -666,6 +695,7 @@ def _add_del_route(action, route):
cmd += ["dev", _get_if_name(route.interface)]
execute(cmd)
# TC stuff
def get_tc_tree():
......@@ -676,13 +706,13 @@ def get_tc_tree():
if line == "":
continue
match = re.match(r'qdisc (\S+) ([0-9a-f]+):[0-9a-f]* dev (\S+) ' +
r'(?:parent ([0-9a-f]*):[0-9a-f]*|root)\s*(.*)', line)
r'(?:parent ([0-9a-f]*):[0-9a-f]*|root)\s*(.*)', line)
if not match:
raise RuntimeError("Invalid output from `tc qdisc': `%s'" % line)
qdisc = match.group(1)
handle = match.group(2)
iface = match.group(3)
parent = match.group(4) # or None
parent = match.group(4) # or None
extra = match.group(5)
if parent == "":
# XXX: Still not sure what is this, shows in newer kernels for wlan
......@@ -706,15 +736,19 @@ def get_tc_tree():
for h in data[data_node[0]]:
node["children"].append(gen_tree(data, h))
return node
tree[iface] = gen_tree(data[iface], data[iface][None][0])
return tree
_multipliers = {"M": 1000000, "K": 1000}
_dividers = {"m": 1000, "u": 1000000}
def _parse_netem_delay(line):
ret = {}
match = re.search(r'delay ([\d.]+)([mu]?)s(?: +([\d.]+)([mu]?)s)?' +
r'(?: *([\d.]+)%)?(?: *distribution (\S+))?', line)
r'(?: *([\d.]+)%)?(?: *distribution (\S+))?', line)
if not match:
return ret
......@@ -737,6 +771,7 @@ def _parse_netem_delay(line):
return ret
def _parse_netem_loss(line):
ret = {}
match = re.search(r'loss ([\d.]+)%(?: *([\d.]+)%)?', line)
......@@ -748,6 +783,7 @@ def _parse_netem_loss(line):
ret["loss_correlation"] = float(match.group(2)) / 100
return ret
def _parse_netem_dup(line):
ret = {}
match = re.search(r'duplicate ([\d.]+)%(?: *([\d.]+)%)?', line)
......@@ -759,6 +795,7 @@ def _parse_netem_dup(line):
ret["dup_correlation"] = float(match.group(2)) / 100
return ret
def _parse_netem_corrupt(line):
ret = {}
match = re.search(r'corrupt ([\d.]+)%(?: *([\d.]+)%)?', line)
......@@ -770,6 +807,7 @@ def _parse_netem_corrupt(line):
ret["corrupt_correlation"] = float(match.group(2)) / 100
return ret
def get_tc_data():
tree = get_tc_tree()
ifdata = get_if_data()
......@@ -802,7 +840,7 @@ def get_tc_data():
continue
tbf = node["extra"], node["handle"]
netem = node["children"][0]["extra"], \
node["children"][0]["handle"]
node["children"][0]["handle"]
if tbf:
ret[i]["qdiscs"]["tbf"] = tbf[1]
......@@ -823,22 +861,24 @@ def get_tc_data():
ret[i].update(_parse_netem_corrupt(netem[0]))
return ret, ifdata[0], ifdata[1]
def clear_tc(iface):
iface = get_if(iface)
tcdata = get_tc_data()[0]
if tcdata[iface.index] == None:
if tcdata[iface.index] is None:
return
# Any other case, we clean
execute([TC_PATH, "qdisc", "del", "dev", iface.name, "root"])
def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
delay_correlation = None, delay_distribution = None,
loss = None, loss_correlation = None,
dup = None, dup_correlation = None,
corrupt = None, corrupt_correlation = None):
def set_tc(iface, bandwidth=None, delay=None, delay_jitter=None,
delay_correlation=None, delay_distribution=None,
loss=None, loss_correlation=None,
dup=None, dup_correlation=None,
corrupt=None, corrupt_correlation=None):
use_netem = bool(delay or delay_jitter or delay_correlation or
delay_distribution or loss or loss_correlation or dup or
dup_correlation or corrupt or corrupt_correlation)
delay_distribution or loss or loss_correlation or dup or
dup_correlation or corrupt or corrupt_correlation)
iface = get_if(iface)
tcdata, ifdata = get_tc_data()[0:2]
......@@ -846,7 +886,7 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
if tcdata[iface.index] == 'foreign':
# Avoid the overhead of calling tc+ip again
commands.append([TC_PATH, "qdisc", "del", "dev", iface.name, "root"])
tcdata[iface.index] = {'qdiscs': []}
tcdata[iface.index] = {'qdiscs': []}
has_netem = 'netem' in tcdata[iface.index]['qdiscs']
has_tbf = 'tbf' in tcdata[iface.index]['qdiscs']
......@@ -862,20 +902,20 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
# Too much work to do better :)
if has_netem or has_tbf:
commands.append([TC_PATH, "qdisc", "del", "dev", iface.name,
"root"])
"root"])
cmd = "add"
if bandwidth:
rate = "%dbit" % int(bandwidth)
mtu = ifdata[iface.index].mtu
burst = max(mtu, int(bandwidth) // HZ)
limit = burst * 2 # FIXME?
limit = burst * 2 # FIXME?
handle = "1:"
if cmd == "change":
handle = "%d:" % int(tcdata[iface.index]["qdiscs"]["tbf"])
command = [TC_PATH, "qdisc", cmd, "dev", iface.name, "root", "handle",
handle, "tbf", "rate", rate, "limit", str(limit), "burst",
str(burst)]
handle, "tbf", "rate", rate, "limit", str(limit), "burst",
str(burst)]
commands.append(command)
if use_netem:
......@@ -920,16 +960,17 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
for c in commands:
execute(c)
def create_tap(iface, use_pi = False, tun = False):
def create_tap(iface, use_pi=False, tun=False):
"""Creates a tap/tun device and returns the associated file descriptor"""
if isinstance(iface, str):
iface = interface(name = iface)
iface = interface(name=iface)
assert iface.name
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca
if tun:
mode = IFF_TUN
else:
......@@ -952,4 +993,3 @@ def create_tap(iface, use_pi = False, tun = False):
raise
interfaces = get_if_data()[1]
return interfaces[iface.name], fd
......@@ -21,10 +21,13 @@ import os
import socket
import sys
import traceback
from typing import MutableMapping
import unshare
import weakref
import nemu.interface
import nemu.iproute
import nemu.protocol
import nemu.subprocess_
from nemu import compat
......@@ -33,10 +36,11 @@ from nemu.environ import *
__all__ = ['Node', 'get_nodes', 'import_if']
class Node(object):
_nodes = weakref.WeakValueDictionary()
_nodes: MutableMapping[int, "Node"] = weakref.WeakValueDictionary()
_nextnode = 0
_processes: MutableMapping[int, nemu.subprocess_.Subprocess]
@staticmethod
def get_nodes():
def get_nodes() -> list["Node"]:
s = sorted(list(Node._nodes.items()), key = lambda x: x[0])
return [x[1] for x in s]
......@@ -98,7 +102,7 @@ class Node(object):
return self._pid
# Subprocesses
def _add_subprocess(self, subprocess):
def _add_subprocess(self, subprocess: nemu.subprocess_.Subprocess):
self._processes[subprocess.pid] = subprocess
def Subprocess(self, *kargs, **kwargs):
......@@ -188,13 +192,13 @@ class Node(object):
r = self.route(*args, **kwargs)
return self._slave.del_route(r)
def get_routes(self):
def get_routes(self) -> list[route]:
return self._slave.get_route_data()
# Handle the creation of the child; parent gets (fd, pid), child creates and
# runs a Server(); never returns.
# Requires CAP_SYS_ADMIN privileges to run.
def _start_child(nonetns) -> (socket.socket, int):
def _start_child(nonetns: bool) -> (socket.socket, int):
# Create socket pair to communicate
(s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
# Spawn a child that will run in a loop
......
......@@ -21,7 +21,7 @@ import struct
from io import IOBase
def __check_socket(sock: socket.socket | IOBase):
def __check_socket(sock: socket.socket | IOBase) -> socket.socket:
if hasattr(sock, 'family') and sock.family != socket.AF_UNIX:
raise ValueError("Only AF_UNIX sockets are allowed")
......@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase):
return sock
def __check_fd(fd):
def __check_fd(fd) -> int:
try:
fd = fd.fileno()
except AttributeError:
......@@ -44,7 +44,7 @@ def __check_fd(fd):
return fd
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str]:
size = struct.calcsize("@i")
msg, ancdata, flags, addr = __check_socket(sock).recvmsg(msg_buf, socket.CMSG_SPACE(size))
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
......@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
return fd, msg.decode("utf-8")
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE"):
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE") -> int:
return __check_socket(sock).sendmsg(
[message],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))])
\ No newline at end of file
......@@ -29,6 +29,7 @@ import tempfile
import time
import traceback
from pickle import loads, dumps
from typing import Literal
import nemu.iproute
import nemu.subprocess_
......@@ -278,7 +279,7 @@ class Server(object):
self.reply(220, "Hello.");
while not self._closed:
cmd = self.readcmd()
if cmd == None:
if cmd is None:
continue
try:
cmd[0](cmd[1], *cmd[2])
......@@ -422,7 +423,7 @@ class Server(object):
else:
ret = nemu.subprocess_.wait(pid)
if ret != None:
if ret is not None:
self._children.remove(pid)
if pid in self._xauthfiles:
try:
......@@ -449,7 +450,7 @@ class Server(object):
self.reply(200, "Process signalled.")
def do_IF_LIST(self, cmdname, ifnr=None):
if ifnr == None:
if ifnr is None:
ifdata = nemu.iproute.get_if_data()[0]
else:
ifdata = nemu.iproute.get_if(ifnr)
......@@ -479,7 +480,7 @@ class Server(object):
def do_ADDR_LIST(self, cmdname, ifnr=None):
addrdata = nemu.iproute.get_addr_data()[0]
if ifnr != None:
if ifnr is not None:
addrdata = addrdata[ifnr]
self.reply(200, ["# Address data follows.",
_b64(dumps(addrdata, protocol=2))])
......@@ -652,7 +653,7 @@ class Client(object):
stdin/stdout/stderr can only be None or a open file descriptor.
See nemu.subprocess_.spawn for details."""
if executable == None:
if executable is None:
executable = argv[0]
params = ["PROC", "CRTE", _b64(executable)]
for i in argv:
......@@ -663,28 +664,28 @@ class Client(object):
# After this, if we get an error, we have to abort the PROC
try:
if user != None:
if user is not None:
self._send_cmd("PROC", "USER", _b64(user))
self._read_and_check_reply()
if cwd != None:
if cwd is not None:
self._send_cmd("PROC", "CWD", _b64(cwd))
self._read_and_check_reply()
if env != None:
if env is not None:
params = []
for k, v in env.items():
params.extend([_b64(k), _b64(v)])
self._send_cmd("PROC", "ENV", *params)
self._read_and_check_reply()
if stdin != None:
if stdin is not None:
os.set_inheritable(stdin, True)
self._send_fd("SIN", stdin)
if stdout != None:
if stdout is not None:
os.set_inheritable(stdout, True)
self._send_fd("SOUT", stdout)
if stderr != None:
if stderr is not None:
os.set_inheritable(stderr, True)
self._send_fd("SERR", stderr)
except:
......@@ -739,7 +740,7 @@ class Client(object):
cmd = ["IF", "SET", interface.index]
for k in interface.changeable_attributes:
v = getattr(interface, k)
if v != None:
if v is not None:
cmd += [k, str(v)]
self._send_cmd(*cmd)
......@@ -761,7 +762,7 @@ class Client(object):
data = self._read_and_check_reply()
return loads(_db64(data.partition("\n")[2]))
def add_addr(self, ifnr, address):
def add_addr(self, ifnr: int, address: nemu.iproute.address):
if hasattr(address, "broadcast") and address.broadcast:
self._send_cmd("ADDR", "ADD", ifnr, address.address,
address.prefix_len, address.broadcast)
......@@ -770,7 +771,7 @@ class Client(object):
address.prefix_len)
self._read_and_check_reply()
def del_addr(self, ifnr, address):
def del_addr(self, ifnr: int, address: nemu.iproute.address):
self._send_cmd("ADDR", "DEL", ifnr, address.address, address.prefix_len)
self._read_and_check_reply()
......@@ -785,14 +786,14 @@ class Client(object):
def del_route(self, route):
self._add_del_route("DEL", route)
def _add_del_route(self, action, route):
def _add_del_route(self, action: Literal["ADD", "DEL"], route: nemu.iproute.route):
args = ["ROUT", action, _b64(route.tipe), _b64(route.prefix),
route.prefix_len or 0, _b64(route.nexthop),
route.interface or 0, route.metric or 0]
self._send_cmd(*args)
self._read_and_check_reply()
def set_x11(self, protoname, hexkey):
def set_x11(self, protoname: str, hexkey: str) -> socket.socket:
# Returns a socket ready to accept() connections
self._send_cmd("X11", "SET", protoname, hexkey)
self._read_and_check_reply()
......@@ -823,7 +824,7 @@ class Client(object):
def _b64_OLD(text: str | bytes) -> str:
if text == None:
if text is None:
# easier this way
text = ''
if type(text) is str:
......@@ -833,11 +834,12 @@ def _b64_OLD(text: str | bytes) -> str:
else:
btext = text
if len(btext) == 0 or any(x for x in btext if x <= ord(" ") or
x > ord("z") or x == ord("=")):
x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(btext).decode("ascii")
else:
return text
def _b64(text) -> str:
if text is None:
# easier this way
......@@ -848,7 +850,7 @@ def _b64(text) -> str:
else:
enc = str(text).encode("utf-8")
if len(enc) == 0 or any(x for x in enc if x <= ord(" ") or
x > ord("z") or x == ord("=")):
x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(enc).decode("ascii")
else:
return enc.decode("utf-8")
......
......@@ -27,7 +27,10 @@ import signal
import sys
import time
import traceback
import typing
if typing.TYPE_CHECKING:
from nemu import Node
from nemu import compat
from nemu.environ import eintr_wrapper
......@@ -46,7 +49,7 @@ class Subprocess(object):
# FIXME
default_user = None
def __init__(self, node, argv: str | list[str], executable=None,
def __init__(self, node: "Node", argv: str | list[str], executable=None,
stdin=None, stdout=None, stderr=None,
shell=False, cwd=None, env=None, user=None):
self._slave = node._slave
......@@ -78,7 +81,7 @@ class Subprocess(object):
Exceptions occurred while trying to set up the environment or executing
the program are propagated to the parent."""
if user == None:
if user is None:
user = Subprocess.default_user
if isinstance(argv, str):
......@@ -106,20 +109,20 @@ class Subprocess(object):
def poll(self):
"""Checks status of program, returns exitcode or None if still running.
See Popen.poll."""
if self._returncode == None:
if self._returncode is None:
self._returncode = self._slave.poll(self._pid)
return self.returncode
def wait(self):
"""Waits for program to complete and returns the exitcode.
See Popen.wait"""
if self._returncode == None:
if self._returncode is None:
self._returncode = self._slave.wait(self._pid)
return self.returncode
def signal(self, sig=signal.SIGTERM):
"""Sends a signal to the process."""
if self._returncode == None:
if self._returncode is None:
self._slave.signal(self._pid, sig)
@property
......@@ -128,7 +131,7 @@ class Subprocess(object):
communicate, wait, or poll), returns the signal that killed the
program, if negative; otherwise, it is the exit code of the program.
"""
if self._returncode == None:
if self._returncode is None:
return None
if os.WIFSIGNALED(self._returncode):
return -os.WTERMSIG(self._returncode)
......@@ -140,12 +143,12 @@ class Subprocess(object):
self.destroy()
def destroy(self):
if self._returncode != None or self._pid == None:
if self._returncode is not None or self._pid is None:
return
self.signal()
now = time.time()
while time.time() - now < KILL_WAIT:
if self.poll() != None:
if self.poll() is not None:
return
time.sleep(0.1)
sys.stderr.write("WARNING: killing forcefully process %d.\n" %
......@@ -179,7 +182,7 @@ class Popen(Subprocess):
fdmap = {"stdin": stdin, "stdout": stdout, "stderr": stderr}
# if PIPE: all should be closed at the end
for k, v in fdmap.items():
if v == None:
if v is None:
continue
if v == PIPE:
r, w = compat.pipe()
......@@ -206,27 +209,29 @@ class Popen(Subprocess):
# Close pipes, they have been dup()ed to the child
for k, v in fdmap.items():
if getattr(self, k) != None:
if getattr(self, k) is not None:
eintr_wrapper(os.close, v)
def communicate(self, input: bytes = None) -> tuple[bytes, bytes]:
def communicate(self, input: bytes | str = None) -> tuple[bytes, bytes]:
"""See Popen.communicate."""
# FIXME: almost verbatim from stdlib version, need to be removed or
# something
if type(input) is str:
input = input.encode("utf-8")
wset = []
rset = []
err = None
out = None
if self.stdin != None:
if self.stdin is not None:
self.stdin.flush()
if input:
wset.append(self.stdin)
else:
self.stdin.close()
if self.stdout != None:
if self.stdout is not None:
rset.append(self.stdout)
out = []
if self.stderr != None:
if self.stderr is not None:
rset.append(self.stderr)
err = []
......@@ -253,9 +258,9 @@ class Popen(Subprocess):
else:
err.append(d)
if out != None:
if out is not None:
out = b''.join(out)
if err != None:
if err is not None:
err = b''.join(err)
self.wait()
return (out, err)
......@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
is not supported here. Also, the original descriptors are not closed.
"""
userfd = [stdin, stdout, stderr]
filtered_userfd = [x for x in userfd if x != None and x >= 0]
filtered_userfd = [x for x in userfd if x is not None and x >= 0]
for i in range(3):
if userfd[i] != None and not isinstance(userfd[i], int):
if userfd[i] is not None and not isinstance(userfd[i], int):
userfd[i] = userfd[i].fileno() # pragma: no cover
# Verify there is no clash
assert not (set([0, 1, 2]) & set(filtered_userfd))
if user != None:
if user is not None:
user, uid, gid = get_user(user)
home = pwd.getpwuid(uid)[5]
groups = [x[2] for x in grp.getgrall() if user in x[3]]
......@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
try:
# Set up stdio piping
for i in range(3):
if userfd[i] != None and userfd[i] >= 0:
if userfd[i] is not None and userfd[i] >= 0:
os.dup2(userfd[i], i)
if userfd[i] != i and userfd[i] not in userfd[0:i]:
eintr_wrapper(os.close, userfd[i]) # only in child!
......@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
# (it is necessary to kill the forked subprocesses)
os.setpgrp()
if user != None:
if user is not None:
# Change user
os.setgid(gid)
os.setgroups(groups)
os.setuid(uid)
if cwd != None:
if cwd is not None:
os.chdir(cwd)
if not argv:
argv = [executable]
if '/' in executable: # Should not search in PATH
if env != None:
if env is not None:
os.execve(executable, argv, env)
else:
os.execv(executable, argv)
else: # use PATH
if env != None:
if env is not None:
os.execvpe(executable, argv, env)
else:
os.execvp(executable, argv)
......
......@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase):
os.write(if1.fd, s)
if not s:
break
if subproc.poll() != None:
if subproc.poll() is not None:
break
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
......
......@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase):
def check_ok(self, cmd, func, args):
s1.write("%s\n" % cmd)
ccmd = " ".join(cmd.upper().split()[0:2])
if func == None:
if func is None:
self.assertEqual(srv.readcmd()[1:3], (ccmd, args))
else:
self.assertEqual(srv.readcmd(), (func, ccmd, args))
......
......@@ -15,7 +15,7 @@ def process_ipcmd(str: str):
match = re.search(r'^(\d+): ([^@\s]+)(?:@\S+)?: <(\S+)> mtu (\d+) '
r'qdisc (\S+)',
line)
if match != None:
if match is not None:
cur = match.group(2)
out[cur] = {
'idx': int(match.group(1)),
......@@ -27,14 +27,14 @@ def process_ipcmd(str: str):
out[cur]['up'] = 'UP' in out[cur]['flags']
continue
# Assume cur is defined
assert cur != None
assert cur is not None
match = re.search(r'^\s+link/\S*(?: ([0-9a-f:]+))?(?: |$)', line)
if match != None:
if match is not None:
out[cur]['lladdr'] = match.group(1)
continue
match = re.search(r'^\s+inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line)
if match != None:
if match is not None:
out[cur]['addr'].append({
'address': match.group(1),
'prefix_len': int(match.group(2)),
......@@ -43,7 +43,7 @@ def process_ipcmd(str: str):
continue
match = re.search(r'^\s+inet6 ([0-9a-f:]+)/(\d+)(?: |$)', line)
if match != None:
if match is not None:
out[cur]['addr'].append({
'address': match.group(1),
'prefix_len': int(match.group(2)),
......@@ -51,7 +51,7 @@ def process_ipcmd(str: str):
continue
match = re.search(r'^\s{4}', line)
assert match != None
assert match is not None
return out
def get_devs():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment