Commit 033ce168 authored by Tom Niget's avatar Tom Niget

Add type hints

parent 7f17df26
...@@ -140,7 +140,7 @@ def main(): ...@@ -140,7 +140,7 @@ def main():
if not r: if not r:
break break
out += r out += r
if srv.poll() != None or clt.poll() != None: if srv.poll() is not None or clt.poll() is not None:
break break
if srv.poll(): if srv.poll():
......
...@@ -15,6 +15,6 @@ setup( ...@@ -15,6 +15,6 @@ setup(
license = 'GPLv2', license = 'GPLv2',
platforms = 'Linux', platforms = 'Linux',
packages = ['nemu'], packages = ['nemu'],
install_requires = ['unshare', 'six'], install_requires = ['unshare', 'six', 'attrs'],
package_dir = {'': 'src'} package_dir = {'': 'src'}
) )
...@@ -41,7 +41,7 @@ class _Config(object): ...@@ -41,7 +41,7 @@ class _Config(object):
except KeyError: except KeyError:
pass # User not found. pass # User not found.
def _set_run_as(self, user): def _set_run_as(self, user: str | int):
"""Setter for `run_as'.""" """Setter for `run_as'."""
if str(user).isdigit(): if str(user).isdigit():
uid = int(user) uid = int(user)
...@@ -61,7 +61,7 @@ class _Config(object): ...@@ -61,7 +61,7 @@ class _Config(object):
self._run_as = run_as self._run_as = run_as
return run_as return run_as
def _get_run_as(self): def _get_run_as(self) -> str:
"""Setter for `run_as'.""" """Setter for `run_as'."""
return self._run_as return self._run_as
......
...@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]: ...@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]:
os.set_inheritable(b, True) os.set_inheritable(b, True)
return a, b return a, b
def socket(*args, **kwargs) -> pysocket.socket: def socket(*args, **kwargs) -> pysocket.socket:
s = pysocket.socket(*args, **kwargs) s = pysocket.socket(*args, **kwargs)
s.set_inheritable(True) s.set_inheritable(True)
return s return s
def socketpair(*args, **kwargs) -> tuple[pysocket.socket, pysocket.socket]: def socketpair(*args, **kwargs) -> tuple[pysocket.socket, pysocket.socket]:
a, b = pysocket.socketpair(*args, **kwargs) a, b = pysocket.socketpair(*args, **kwargs)
a.set_inheritable(True) a.set_inheritable(True)
b.set_inheritable(True) b.set_inheritable(True)
return a, b return a, b
def fromfd(*args, **kwargs) -> pysocket.socket: def fromfd(*args, **kwargs) -> pysocket.socket:
s = pysocket.fromfd(*args, **kwargs) s = pysocket.fromfd(*args, **kwargs)
s.set_inheritable(True) s.set_inheritable(True)
return s return s
def fdopen(*args, **kwargs) -> pysocket.socket:
s = os.fdopen(*args, **kwargs)
s.set_inheritable(True)
return s
\ No newline at end of file
...@@ -25,7 +25,7 @@ import subprocess ...@@ -25,7 +25,7 @@ import subprocess
import sys import sys
import syslog import syslog
from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
from typing import TypeVar, Callable from typing import TypeVar, Callable, Optional
__all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"] __all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"]
...@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"] ...@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"]
__all__ += ["error", "warning", "notice", "info", "debug"] __all__ += ["error", "warning", "notice", "info", "debug"]
def find_bin(name, extra_path=None): def find_bin(name: str, extra_path: Optional[list[str]] = None) -> Optional[str]:
"""Try hard to find the location of needed programs.""" """Try hard to find the location of needed programs."""
search = [] search = []
if "PATH" in os.environ: if "PATH" in os.environ:
...@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None): ...@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None):
return None return None
def find_bin_or_die(name, extra_path=None): def find_bin_or_die(name: str, extra_path: Optional[list[str]] = None) -> str:
"""Try hard to find the location of needed programs; raise on failure.""" """Try hard to find the location of needed programs; raise on failure."""
res = find_bin(name, extra_path) res = find_bin(name, extra_path)
if not res: if not res:
...@@ -156,7 +156,7 @@ _log_syslog_opts = () ...@@ -156,7 +156,7 @@ _log_syslog_opts = ()
_log_pid = os.getpid() _log_pid = os.getpid()
def set_log_level(level): def set_log_level(level: int):
"Sets the log level for console messages, does not affect syslog logging." "Sets the log level for console messages, does not affect syslog logging."
global _log_level global _log_level
assert level > LOG_ERR and level <= LOG_DEBUG assert level > LOG_ERR and level <= LOG_DEBUG
...@@ -191,7 +191,7 @@ def _init_log(): ...@@ -191,7 +191,7 @@ def _init_log():
info("Syslog logging started") info("Syslog logging started")
def logger(priority, message): def logger(priority: int, message: str):
"Print a log message in syslog, console or both." "Print a log message in syslog, console or both."
if _log_use_syslog: if _log_use_syslog:
if os.getpid() != _log_pid: if os.getpid() != _log_pid:
......
...@@ -19,30 +19,33 @@ ...@@ -19,30 +19,33 @@
import os import os
import weakref import weakref
from typing import TypedDict
import nemu.iproute import nemu.iproute
from nemu.environ import * from nemu.environ import *
__all__ = ['NodeInterface', 'P2PInterface', 'ImportedInterface', __all__ = ['NodeInterface', 'P2PInterface', 'ImportedInterface',
'ImportedNodeInterface', 'Switch'] 'ImportedNodeInterface', 'Switch']
class Interface(object): class Interface(object):
"""Just a base class for the *Interface classes: assign names and handle """Just a base class for the *Interface classes: assign names and handle
destruction.""" destruction."""
_nextid = 0 _nextid = 0
@staticmethod @staticmethod
def _gen_next_id(): def _gen_next_id() -> int:
n = Interface._nextid n = Interface._nextid
Interface._nextid += 1 Interface._nextid += 1
return n return n
@staticmethod @staticmethod
def _gen_if_name(): def _gen_if_name() -> str:
n = Interface._gen_next_id() n = Interface._gen_next_id()
# Max 15 chars # Max 15 chars
return "NETNSif-%.4x%.3x" % (os.getpid() & 0xffff, n) return "NETNSif-%.4x%.3x" % (os.getpid() & 0xffff, n)
def __init__(self, index): def __init__(self, index: int):
self._idx = index self._idx = index
debug("%s(0x%x).__init__(), index = %d" % (self.__class__.__name__, debug("%s(0x%x).__init__(), index = %d" % (self.__class__.__name__,
id(self), index)) id(self), index))
...@@ -55,7 +58,7 @@ class Interface(object): ...@@ -55,7 +58,7 @@ class Interface(object):
raise NotImplementedError raise NotImplementedError
@property @property
def index(self): def index(self) -> int:
"""Interface index as seen by the kernel.""" """Interface index as seen by the kernel."""
return self._idx return self._idx
...@@ -65,20 +68,35 @@ class Interface(object): ...@@ -65,20 +68,35 @@ class Interface(object):
control interfaces can be put into a Switch, for example.""" control interfaces can be put into a Switch, for example."""
return None return None
class Ipv4Dict(TypedDict):
address: str
prefix_len: int
broadcast: str
family: str
class Ipv6Dict(TypedDict):
address: str
prefix_len: int
family: str
class NSInterface(Interface): class NSInterface(Interface):
"""Add user-facing methods for interfaces that go into a netns.""" """Add user-facing methods for interfaces that go into a netns."""
def __init__(self, node, index):
def __init__(self, node: "nemu.Node", index):
super(NSInterface, self).__init__(index) super(NSInterface, self).__init__(index)
self._slave = node._slave self._slave = node._slave
# Disable auto-configuration # Disable auto-configuration
# you wish: need to take into account the nonetns mode; plus not # you wish: need to take into account the nonetns mode; plus not
# touching some pre-existing ifaces # touching some pre-existing ifaces
#node.system([SYSCTL_PATH, '-w', 'net.ipv6.conf.%s.autoconf=0' % # node.system([SYSCTL_PATH, '-w', 'net.ipv6.conf.%s.autoconf=0' %
#self.name]) # self.name])
node._add_interface(self) node._add_interface(self)
# some black magic to automatically get/set interface attributes # some black magic to automatically get/set interface attributes
def __getattr__(self, name): def __getattr__(self, name: str):
# If name starts with _, it must be a normal attr # If name starts with _, it must be a normal attr
if name[0] == '_': if name[0] == '_':
return super(Interface, self).__getattribute__(name) return super(Interface, self).__getattribute__(name)
...@@ -92,57 +110,59 @@ class NSInterface(Interface): ...@@ -92,57 +110,59 @@ class NSInterface(Interface):
iface = slave.get_if_data(self.index) iface = slave.get_if_data(self.index)
return getattr(iface, name) return getattr(iface, name)
def __setattr__(self, name, value): def __setattr__(self, name: str, value):
if name[0] == '_': # forbid anything that doesn't start with a _ if name[0] == '_': # forbid anything that doesn't start with a _
super(Interface, self).__setattr__(name, value) super(Interface, self).__setattr__(name, value)
return return
iface = nemu.iproute.interface(index = self.index) iface = nemu.iproute.interface(index=self.index)
setattr(iface, name, value) setattr(iface, name, value)
return self._slave.set_if(iface) return self._slave.set_if(iface)
def add_v4_address(self, address, prefix_len, broadcast = None): def add_v4_address(self, address: str, prefix_len: int, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast) addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
self._slave.add_addr(self.index, addr) self._slave.add_addr(self.index, addr)
def add_v6_address(self, address, prefix_len): def add_v6_address(self, address: str, prefix_len: int):
addr = nemu.iproute.ipv6address(address, prefix_len) addr = nemu.iproute.ipv6address(address, prefix_len)
self._slave.add_addr(self.index, addr) self._slave.add_addr(self.index, addr)
def del_v4_address(self, address, prefix_len, broadcast = None): def del_v4_address(self, address: str, prefix_len: int, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast) addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
self._slave.del_addr(self.index, addr) self._slave.del_addr(self.index, addr)
def del_v6_address(self, address, prefix_len): def del_v6_address(self, address: str, prefix_len: int):
addr = nemu.iproute.ipv6address(address, prefix_len) addr = nemu.iproute.ipv6address(address, prefix_len)
self._slave.del_addr(self.index, addr) self._slave.del_addr(self.index, addr)
def get_addresses(self): def get_addresses(self) -> list[Ipv4Dict | Ipv6Dict]:
addresses = self._slave.get_addr_data(self.index) addresses = self._slave.get_addr_data(self.index)
ret = [] ret = []
for a in addresses: for a in addresses:
if hasattr(a, 'broadcast'): if hasattr(a, 'broadcast'):
ret.append(dict( ret.append(dict(
address = a.address, address=a.address,
prefix_len = a.prefix_len, prefix_len=a.prefix_len,
broadcast = a.broadcast, broadcast=a.broadcast,
family = 'inet')) family='inet'))
else: else:
ret.append(dict( ret.append(dict(
address = a.address, address=a.address,
prefix_len = a.prefix_len, prefix_len=a.prefix_len,
family = 'inet6')) family='inet6'))
return ret return ret
class NodeInterface(NSInterface): class NodeInterface(NSInterface):
"""Class to create and handle a virtual interface inside a name space, it """Class to create and handle a virtual interface inside a name space, it
can be connected to a Switch object with emulation of link can be connected to a Switch object with emulation of link
characteristics.""" characteristics."""
def __init__(self, node):
def __init__(self, node: "nemu.Node"):
"""Create a new interface. `node' is the name space in which this """Create a new interface. `node' is the name space in which this
interface should be put.""" interface should be put."""
self._slave = None self._slave = None
if1 = nemu.iproute.interface(name = self._gen_if_name()) if1 = nemu.iproute.interface(name=self._gen_if_name())
if2 = nemu.iproute.interface(name = self._gen_if_name()) if2 = nemu.iproute.interface(name=self._gen_if_name())
ctl, ns = nemu.iproute.create_if_pair(if1, if2) ctl, ns = nemu.iproute.create_if_pair(if1, if2)
try: try:
nemu.iproute.change_netns(ns, node.pid) nemu.iproute.change_netns(ns, node.pid)
...@@ -165,18 +185,20 @@ class NodeInterface(NSInterface): ...@@ -165,18 +185,20 @@ class NodeInterface(NSInterface):
self._slave.del_if(self.index) self._slave.del_if(self.index)
self._slave = None self._slave = None
class P2PInterface(NSInterface): class P2PInterface(NSInterface):
"""Class to create and handle point-to-point interfaces between name """Class to create and handle point-to-point interfaces between name
spaces, without using Switch objects. Those do not allow any kind of spaces, without using Switch objects. Those do not allow any kind of
traffic shaping. traffic shaping.
As two interfaces need to be created, instead of using the class As two interfaces need to be created, instead of using the class
constructor, use the P2PInterface.create_pair() static method.""" constructor, use the P2PInterface.create_pair() static method."""
@staticmethod @staticmethod
def create_pair(node1, node2): def create_pair(node1: "nemu.Node", node2: "nemu.Node"):
"""Create and return a pair of connected P2PInterface objects, """Create and return a pair of connected P2PInterface objects,
assigned to name spaces represented by `node1' and `node2'.""" assigned to name spaces represented by `node1' and `node2'."""
if1 = nemu.iproute.interface(name = P2PInterface._gen_if_name()) if1 = nemu.iproute.interface(name=P2PInterface._gen_if_name())
if2 = nemu.iproute.interface(name = P2PInterface._gen_if_name()) if2 = nemu.iproute.interface(name=P2PInterface._gen_if_name())
pair = nemu.iproute.create_if_pair(if1, if2) pair = nemu.iproute.create_if_pair(if1, if2)
try: try:
nemu.iproute.change_netns(pair[0], node1.pid) nemu.iproute.change_netns(pair[0], node1.pid)
...@@ -206,6 +228,7 @@ class P2PInterface(NSInterface): ...@@ -206,6 +228,7 @@ class P2PInterface(NSInterface):
self._slave.del_if(self.index) self._slave.del_if(self.index)
self._slave = None self._slave = None
class ImportedNodeInterface(NSInterface): class ImportedNodeInterface(NSInterface):
"""Class to handle already existing interfaces inside a name space: """Class to handle already existing interfaces inside a name space:
real devices, tun devices, etc. real devices, tun devices, etc.
...@@ -213,7 +236,8 @@ class ImportedNodeInterface(NSInterface): ...@@ -213,7 +236,8 @@ class ImportedNodeInterface(NSInterface):
to be moved inside the name space. to be moved inside the name space.
On destruction, the interface will be restored to the original name space On destruction, the interface will be restored to the original name space
and will try to restore the original state.""" and will try to restore the original state."""
def __init__(self, node, iface, migrate = True):
def __init__(self, node: "nemu.Node", iface, migrate=True):
self._slave = None self._slave = None
self._migrate = migrate self._migrate = migrate
if self._migrate: if self._migrate:
...@@ -244,17 +268,19 @@ class ImportedNodeInterface(NSInterface): ...@@ -244,17 +268,19 @@ class ImportedNodeInterface(NSInterface):
nemu.iproute.set_if(self._original_state) nemu.iproute.set_if(self._original_state)
self._slave = None self._slave = None
class TapNodeInterface(NSInterface): class TapNodeInterface(NSInterface):
"""Class to create a tap interface inside a name space, it """Class to create a tap interface inside a name space, it
can be connected to a Switch object with emulation of link can be connected to a Switch object with emulation of link
characteristics.""" characteristics."""
def __init__(self, node, use_pi = False):
def __init__(self, node, use_pi=False):
"""Create a new tap interface. 'node' is the name space in which this """Create a new tap interface. 'node' is the name space in which this
interface should be put.""" interface should be put."""
self._fd = None self._fd = None
self._slave = None self._slave = None
iface = nemu.iproute.interface(name = self._gen_if_name()) iface = nemu.iproute.interface(name=self._gen_if_name())
iface, self._fd = nemu.iproute.create_tap(iface, use_pi = use_pi) iface, self._fd = nemu.iproute.create_tap(iface, use_pi=use_pi)
nemu.iproute.change_netns(iface.name, node.pid) nemu.iproute.change_netns(iface.name, node.pid)
super(TapNodeInterface, self).__init__(node, iface.index) super(TapNodeInterface, self).__init__(node, iface.index)
...@@ -271,18 +297,20 @@ class TapNodeInterface(NSInterface): ...@@ -271,18 +297,20 @@ class TapNodeInterface(NSInterface):
except: except:
pass pass
class TunNodeInterface(NSInterface): class TunNodeInterface(NSInterface):
"""Class to create a tun interface inside a name space, it """Class to create a tun interface inside a name space, it
can be connected to a Switch object with emulation of link can be connected to a Switch object with emulation of link
characteristics.""" characteristics."""
def __init__(self, node, use_pi = False):
def __init__(self, node, use_pi=False):
"""Create a new tap interface. 'node' is the name space in which this """Create a new tap interface. 'node' is the name space in which this
interface should be put.""" interface should be put."""
self._fd = None self._fd = None
self._slave = None self._slave = None
iface = nemu.iproute.interface(name = self._gen_if_name()) iface = nemu.iproute.interface(name=self._gen_if_name())
iface, self._fd = nemu.iproute.create_tap(iface, use_pi = use_pi, iface, self._fd = nemu.iproute.create_tap(iface, use_pi=use_pi,
tun = True) tun=True)
nemu.iproute.change_netns(iface.name, node.pid) nemu.iproute.change_netns(iface.name, node.pid)
super(TunNodeInterface, self).__init__(node, iface.index) super(TunNodeInterface, self).__init__(node, iface.index)
...@@ -299,9 +327,11 @@ class TunNodeInterface(NSInterface): ...@@ -299,9 +327,11 @@ class TunNodeInterface(NSInterface):
except: except:
pass pass
class ExternalInterface(Interface): class ExternalInterface(Interface):
"""Add user-facing methods for interfaces that run in the main """Add user-facing methods for interfaces that run in the main
namespace.""" namespace."""
@property @property
def control(self): def control(self):
# This is *the* control interface # This is *the* control interface
...@@ -316,11 +346,11 @@ class ExternalInterface(Interface): ...@@ -316,11 +346,11 @@ class ExternalInterface(Interface):
if name[0] == '_': # forbid anything that doesn't start with a _ if name[0] == '_': # forbid anything that doesn't start with a _
super(ExternalInterface, self).__setattr__(name, value) super(ExternalInterface, self).__setattr__(name, value)
return return
iface = nemu.iproute.interface(index = self.index) iface = nemu.iproute.interface(index=self.index)
setattr(iface, name, value) setattr(iface, name, value)
return nemu.iproute.set_if(iface) return nemu.iproute.set_if(iface)
def add_v4_address(self, address, prefix_len, broadcast = None): def add_v4_address(self, address, prefix_len, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast) addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
nemu.iproute.add_addr(self.index, addr) nemu.iproute.add_addr(self.index, addr)
...@@ -328,7 +358,7 @@ class ExternalInterface(Interface): ...@@ -328,7 +358,7 @@ class ExternalInterface(Interface):
addr = nemu.iproute.ipv6address(address, prefix_len) addr = nemu.iproute.ipv6address(address, prefix_len)
nemu.iproute.add_addr(self.index, addr) nemu.iproute.add_addr(self.index, addr)
def del_v4_address(self, address, prefix_len, broadcast = None): def del_v4_address(self, address, prefix_len, broadcast=None):
addr = nemu.iproute.ipv4address(address, prefix_len, broadcast) addr = nemu.iproute.ipv4address(address, prefix_len, broadcast)
nemu.iproute.del_addr(self.index, addr) nemu.iproute.del_addr(self.index, addr)
...@@ -342,23 +372,26 @@ class ExternalInterface(Interface): ...@@ -342,23 +372,26 @@ class ExternalInterface(Interface):
for a in addresses: for a in addresses:
if hasattr(a, 'broadcast'): if hasattr(a, 'broadcast'):
ret.append(dict( ret.append(dict(
address = a.address, address=a.address,
prefix_len = a.prefix_len, prefix_len=a.prefix_len,
broadcast = a.broadcast, broadcast=a.broadcast,
family = 'inet')) family='inet'))
else: else:
ret.append(dict( ret.append(dict(
address = a.address, address=a.address,
prefix_len = a.prefix_len, prefix_len=a.prefix_len,
family = 'inet6')) family='inet6'))
return ret return ret
class SlaveInterface(ExternalInterface): class SlaveInterface(ExternalInterface):
"""Class to handle the main-name-space-facing half of NodeInterface. """Class to handle the main-name-space-facing half of NodeInterface.
Does nothing, just avoids any destroy code.""" Does nothing, just avoids any destroy code."""
def destroy(self): def destroy(self):
pass pass
class ImportedInterface(ExternalInterface): class ImportedInterface(ExternalInterface):
"""Class to handle already existing interfaces. Analogous to """Class to handle already existing interfaces. Analogous to
ImportedNodeInterface, this class only differs in that the interface is ImportedNodeInterface, this class only differs in that the interface is
...@@ -366,6 +399,7 @@ class ImportedInterface(ExternalInterface): ...@@ -366,6 +399,7 @@ class ImportedInterface(ExternalInterface):
connected to Switch objects and not assigned to a name space. On connected to Switch objects and not assigned to a name space. On
destruction, the code will try to restore the interface to the state it destruction, the code will try to restore the interface to the state it
was in before being imported into nemu.""" was in before being imported into nemu."""
def __init__(self, iface): def __init__(self, iface):
self._original_state = None self._original_state = None
iface = nemu.iproute.get_if(iface) iface = nemu.iproute.get_if(iface)
...@@ -379,6 +413,7 @@ class ImportedInterface(ExternalInterface): ...@@ -379,6 +413,7 @@ class ImportedInterface(ExternalInterface):
nemu.iproute.set_if(self._original_state) nemu.iproute.set_if(self._original_state)
self._original_state = None self._original_state = None
# Switch is just another interface type # Switch is just another interface type
class Switch(ExternalInterface): class Switch(ExternalInterface):
...@@ -421,7 +456,7 @@ class Switch(ExternalInterface): ...@@ -421,7 +456,7 @@ class Switch(ExternalInterface):
if self._check_port(i.index): if self._check_port(i.index):
setattr(i, name, value) setattr(i, name, value)
# Set bridge # Set bridge
iface = nemu.iproute.bridge(index = self.index) iface = nemu.iproute.bridge(index=self.index)
setattr(iface, name, value) setattr(iface, name, value)
nemu.iproute.set_bridge(iface) nemu.iproute.set_bridge(iface)
...@@ -472,12 +507,12 @@ class Switch(ExternalInterface): ...@@ -472,12 +507,12 @@ class Switch(ExternalInterface):
self._apply_parameters({}, iface.control) self._apply_parameters({}, iface.control)
del self._ports[iface.control.index] del self._ports[iface.control.index]
def set_parameters(self, bandwidth = None, def set_parameters(self, bandwidth=None,
delay = None, delay_jitter = None, delay=None, delay_jitter=None,
delay_correlation = None, delay_distribution = None, delay_correlation=None, delay_distribution=None,
loss = None, loss_correlation = None, loss=None, loss_correlation=None,
dup = None, dup_correlation = None, dup=None, dup_correlation=None,
corrupt = None, corrupt_correlation = None): corrupt=None, corrupt_correlation=None):
"""Set the parameters that control the link characteristics. For the """Set the parameters that control the link characteristics. For the
description of each, refer to netem documentation: description of each, refer to netem documentation:
http://www.linuxfoundation.org/collaborate/workgroups/networking/netem http://www.linuxfoundation.org/collaborate/workgroups/networking/netem
...@@ -492,13 +527,13 @@ class Switch(ExternalInterface): ...@@ -492,13 +527,13 @@ class Switch(ExternalInterface):
`dup_correlation', `corrupt', and `corrupt_correlation' take a `dup_correlation', `corrupt', and `corrupt_correlation' take a
percentage value in the form of a number between 0 and 1. (50% is percentage value in the form of a number between 0 and 1. (50% is
passed as 0.5).""" passed as 0.5)."""
parameters = dict(bandwidth = bandwidth, parameters = dict(bandwidth=bandwidth,
delay = delay, delay_jitter = delay_jitter, delay=delay, delay_jitter=delay_jitter,
delay_correlation = delay_correlation, delay_correlation=delay_correlation,
delay_distribution = delay_distribution, delay_distribution=delay_distribution,
loss = loss, loss_correlation = loss_correlation, loss=loss, loss_correlation=loss_correlation,
dup = dup, dup_correlation = dup_correlation, dup=dup, dup_correlation=dup_correlation,
corrupt = corrupt, corrupt_correlation = corrupt_correlation) corrupt=corrupt, corrupt_correlation=corrupt_correlation)
try: try:
self._apply_parameters(parameters) self._apply_parameters(parameters)
except: except:
...@@ -506,7 +541,6 @@ class Switch(ExternalInterface): ...@@ -506,7 +541,6 @@ class Switch(ExternalInterface):
raise raise
self._parameters = parameters self._parameters = parameters
def _apply_parameters(self, parameters, port = None): def _apply_parameters(self, parameters, port=None):
for i in [port] if port else list(self._ports.values()): for i in [port] if port else list(self._ports.values()):
nemu.iproute.set_tc(i.index, **parameters) nemu.iproute.set_tc(i.index, **parameters)
...@@ -24,7 +24,10 @@ import re ...@@ -24,7 +24,10 @@ import re
import socket import socket
import struct import struct
import sys import sys
from typing import TypeVar, Callable, Literal
from attr import evolve
from attrs import define, setters, field
import six import six
from nemu.environ import * from nemu.environ import *
...@@ -46,18 +49,21 @@ def _any_to_bool(any): ...@@ -46,18 +49,21 @@ def _any_to_bool(any):
return any != "" return any != ""
return bool(any) return bool(any)
def _positive(val): def _positive(val):
v = int(val) v = int(val)
if v <= 0: if v <= 0:
raise ValueError("Invalid value: %d" % v) raise ValueError("Invalid value: %d" % v)
return v return v
def _non_empty_str(val): def _non_empty_str(val):
if val == "": if val == "":
return None return None
else: else:
return str(val) return str(val)
def _fix_lladdr(addr): def _fix_lladdr(addr):
foo = addr.lower() foo = addr.lower()
if ":" in addr: if ":" in addr:
...@@ -77,21 +83,41 @@ def _fix_lladdr(addr): ...@@ -77,21 +83,41 @@ def _fix_lladdr(addr):
# Glue # Glue
return ":".join(m.groups()) return ":".join(m.groups())
def _make_getter(attr, conv = lambda x: x):
def _make_getter(attr, conv=lambda x: x):
def getter(self): def getter(self):
return conv(getattr(self, attr)) return conv(getattr(self, attr))
return getter return getter
def _make_setter(attr, conv = lambda x: x):
def _make_setter(attr, conv=lambda x: x):
def setter(self, value): def setter(self, value):
if value == None: if value is None:
setattr(self, attr, None) setattr(self, attr, None)
else: else:
setattr(self, attr, conv(value)) setattr(self, attr, conv(value))
return setter return setter
T = TypeVar("T")
U = TypeVar("U")
def _if_any(conv: Callable[[T], U]):
def c(val: T) -> U:
if val is None:
return None
else:
return conv(val)
return c
# classes for internal use # classes for internal use
class interface(object): @define(repr=False)
class interface:
"""Class for internal use. It is mostly a data container used to easily """Class for internal use. It is mostly a data container used to easily
pass information around; with some convenience methods.""" pass information around; with some convenience methods."""
...@@ -99,25 +125,14 @@ class interface(object): ...@@ -99,25 +125,14 @@ class interface(object):
changeable_attributes = ["name", "mtu", "lladdr", "broadcast", "up", changeable_attributes = ["name", "mtu", "lladdr", "broadcast", "up",
"multicast", "arp"] "multicast", "arp"]
# Index should be read-only index: int = field(default=None, converter=_if_any(_positive), on_setattr=setters.frozen)
index = property(_make_getter("_index")) name: str = field(default=None)
up = property(_make_getter("_up"), _make_setter("_up", _any_to_bool)) up: bool = field(default=None, converter=_if_any(_any_to_bool))
mtu = property(_make_getter("_mtu"), _make_setter("_mtu", _positive)) mtu: int = field(default=None, converter=_if_any(_positive))
lladdr = property(_make_getter("_lladdr"), lladdr: str = field(default=None, converter=_if_any(_fix_lladdr))
_make_setter("_lladdr", _fix_lladdr)) broadcast: str = field(default=None)
arp = property(_make_getter("_arp"), _make_setter("_arp", _any_to_bool)) multicast: bool = field(default=None, converter=_if_any(_any_to_bool))
multicast = property(_make_getter("_mc"), _make_setter("_mc", _any_to_bool)) arp: bool = field(default=None, converter=_if_any(_any_to_bool))
def __init__(self, index = None, name = None, up = None, mtu = None,
lladdr = None, broadcast = None, multicast = None, arp = None):
self._index = _positive(index) if index is not None else None
self.name = name
self.up = up
self.mtu = mtu
self.lladdr = lladdr
self.broadcast = broadcast
self.multicast = multicast
self.arp = arp
def __repr__(self): def __repr__(self):
s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, " s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
...@@ -139,26 +154,23 @@ class interface(object): ...@@ -139,26 +154,23 @@ class interface(object):
broadcast = None if self.broadcast == o.broadcast else self.broadcast broadcast = None if self.broadcast == o.broadcast else self.broadcast
multicast = None if self.multicast == o.multicast else self.multicast multicast = None if self.multicast == o.multicast else self.multicast
arp = None if self.arp == o.arp else self.arp arp = None if self.arp == o.arp else self.arp
return self.__class__(self.index, name, up, mtu, lladdr, broadcast, return interface(self.index, name, up, mtu, lladdr, broadcast,
multicast, arp) multicast, arp)
def copy(self): def copy(self):
return copy.copy(self) return copy.copy(self)
@define(repr=False)
class bridge(interface): class bridge(interface):
changeable_attributes = interface.changeable_attributes + ["stp", changeable_attributes = interface.changeable_attributes + ["stp",
"forward_delay", "hello_time", "ageing_time", "max_age"] "forward_delay", "hello_time", "ageing_time", "max_age"]
# Index should be read-only stp: bool = field(default=None, converter=_if_any(_any_to_bool))
stp = property(_make_getter("_stp"), _make_setter("_stp", _any_to_bool)) forward_delay: float = field(default=None, converter=_if_any(float))
forward_delay = property(_make_getter("_forward_delay"), hello_time: float = field(default=None, converter=_if_any(float))
_make_setter("_forward_delay", float)) ageing_time: float = field(default=None, converter=_if_any(float))
hello_time = property(_make_getter("_hello_time"), max_age: float = field(default=None, converter=_if_any(float))
_make_setter("_hello_time", float))
ageing_time = property(_make_getter("_ageing_time"),
_make_setter("_ageing_time", float))
max_age = property(_make_getter("_max_age"),
_make_setter("_max_age", float))
@classmethod @classmethod
def upgrade(cls, iface, *kargs, **kwargs): def upgrade(cls, iface, *kargs, **kwargs):
...@@ -166,18 +178,6 @@ class bridge(interface): ...@@ -166,18 +178,6 @@ class bridge(interface):
return cls(iface.index, iface.name, iface.up, iface.mtu, iface.lladdr, return cls(iface.index, iface.name, iface.up, iface.mtu, iface.lladdr,
iface.broadcast, iface.multicast, iface.arp, *kargs, **kwargs) iface.broadcast, iface.multicast, iface.arp, *kargs, **kwargs)
def __init__(self, index = None, name = None, up = None, mtu = None,
lladdr = None, broadcast = None, multicast = None, arp = None,
stp = None, forward_delay = None, hello_time = None,
ageing_time = None, max_age = None):
super(bridge, self).__init__(index, name, up, mtu, lladdr, broadcast,
multicast, arp)
self.stp = stp
self.forward_delay = forward_delay
self.hello_time = hello_time
self.ageing_time = ageing_time
self.max_age = max_age
def __repr__(self): def __repr__(self):
s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, " s = "%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
s += "broadcast = %s, multicast = %s, arp = %s, stp = %s, " s += "broadcast = %s, multicast = %s, arp = %s, stp = %s, "
...@@ -193,7 +193,7 @@ class bridge(interface): ...@@ -193,7 +193,7 @@ class bridge(interface):
self.max_age.__repr__()) self.max_age.__repr__())
def __sub__(self, o): def __sub__(self, o):
r = super(bridge, self).__sub__(o) r = bridge.upgrade(super().__sub__(o))
if type(o) == interface: if type(o) == interface:
return r return r
r.stp = None if self.stp == o.stp else self.stp r.stp = None if self.stp == o.stp else self.stp
...@@ -206,11 +206,18 @@ class bridge(interface): ...@@ -206,11 +206,18 @@ class bridge(interface):
r.max_age = None if self.max_age == o.max_age else self.max_age r.max_age = None if self.max_age == o.max_age else self.max_age
return r return r
class address(object): class address(object):
"""Class for internal use. It is mostly a data container used to easily """Class for internal use. It is mostly a data container used to easily
pass information around; with some convenience methods. __eq__ and pass information around; with some convenience methods. __eq__ and
__hash__ are defined just to be able to easily find duplicated __hash__ are defined just to be able to easily find duplicated
addresses.""" addresses."""
def __init__(self, address: str, prefix_len: int, family: socket.AddressFamily):
self.address = address
self.prefix_len = int(prefix_len)
self.family = family
# broadcast is not taken into account for differentiating addresses # broadcast is not taken into account for differentiating addresses
def __eq__(self, o): def __eq__(self, o):
if not isinstance(o, address): if not isinstance(o, address):
...@@ -223,12 +230,11 @@ class address(object): ...@@ -223,12 +230,11 @@ class address(object):
self.family.__hash__()) self.family.__hash__())
return h return h
class ipv4address(address): class ipv4address(address):
def __init__(self, address, prefix_len, broadcast): def __init__(self, address: str, prefix_len: int, broadcast):
self.address = address super().__init__(address, prefix_len, socket.AF_INET)
self.prefix_len = int(prefix_len)
self.broadcast = broadcast self.broadcast = broadcast
self.family = socket.AF_INET
def __repr__(self): def __repr__(self):
s = "%s.%s(address = %s, prefix_len = %d, broadcast = %s)" s = "%s.%s(address = %s, prefix_len = %d, broadcast = %s)"
...@@ -236,11 +242,10 @@ class ipv4address(address): ...@@ -236,11 +242,10 @@ class ipv4address(address):
self.address.__repr__(), self.prefix_len, self.address.__repr__(), self.prefix_len,
self.broadcast.__repr__()) self.broadcast.__repr__())
class ipv6address(address): class ipv6address(address):
def __init__(self, address, prefix_len): def __init__(self, address: str, prefix_len: int):
self.address = address super().__init__(address, prefix_len, socket.AF_INET6)
self.prefix_len = int(prefix_len)
self.family = socket.AF_INET6
def __repr__(self): def __repr__(self):
s = "%s.%s(address = %s, prefix_len = %d)" s = "%s.%s(address = %s, prefix_len = %d)"
...@@ -264,8 +269,8 @@ class route(object): ...@@ -264,8 +269,8 @@ class route(object):
metric = property(_make_getter("_metric"), metric = property(_make_getter("_metric"),
lambda s, v: setattr(s, "_metric", int(v or 0))) lambda s, v: setattr(s, "_metric", int(v or 0)))
def __init__(self, tipe = "unicast", prefix = None, prefix_len = 0, def __init__(self, tipe="unicast", prefix=None, prefix_len=0,
nexthop = None, interface = None, metric = 0): nexthop=None, interface=None, metric=0):
self.tipe = tipe self.tipe = tipe
self.prefix = prefix self.prefix = prefix
self.prefix_len = prefix_len self.prefix_len = prefix_len
...@@ -289,20 +294,22 @@ class route(object): ...@@ -289,20 +294,22 @@ class route(object):
self.prefix_len == o.prefix_len and self.nexthop == o.nexthop self.prefix_len == o.prefix_len and self.nexthop == o.nexthop
and self.interface == o.interface and self.metric == o.metric) and self.interface == o.interface and self.metric == o.metric)
# helpers # helpers
def _get_if_name(iface): def _get_if_name(iface: interface | int | str):
if isinstance(iface, interface): if isinstance(iface, interface):
if iface.name != None: if iface.name is not None:
return iface.name return iface.name
if isinstance(iface, str): if isinstance(iface, str):
return iface return iface
return get_if(iface).name return get_if(iface).name
# XXX: ideally this should be replaced by netlink communication # XXX: ideally this should be replaced by netlink communication
# Interface handling # Interface handling
# FIXME: try to lower the amount of calls to retrieve data!! # FIXME: try to lower the amount of calls to retrieve data!!
def get_if_data(): def get_if_data() -> tuple[dict[int, interface], dict[str, interface]]:
"""Gets current interface information. Returns a tuple (byidx, bynam) in """Gets current interface information. Returns a tuple (byidx, bynam) in
which each element is a dictionary with the same data, but using different which each element is a dictionary with the same data, but using different
keys: interface indexes and interface names. keys: interface indexes and interface names.
...@@ -323,21 +330,22 @@ def get_if_data(): ...@@ -323,21 +330,22 @@ def get_if_data():
r'brd ([0-9a-f:]+))?', line) r'brd ([0-9a-f:]+))?', line)
flags = match.group(3).split(",") flags = match.group(3).split(",")
i = interface( i = interface(
index = match.group(1), index=match.group(1),
name = match.group(2), name=match.group(2),
up = "UP" in flags, up="UP" in flags,
mtu = match.group(4), mtu=match.group(4),
lladdr = match.group(5), lladdr=match.group(5),
arp = not ("NOARP" in flags), arp=not ("NOARP" in flags),
broadcast = match.group(6), broadcast=match.group(6),
multicast = "MULTICAST" in flags) multicast="MULTICAST" in flags)
byidx[idx] = bynam[i.name] = i byidx[idx] = bynam[i.name] = i
return byidx, bynam return byidx, bynam
def get_if(iface):
def get_if(iface: interface | int | str) -> interface:
ifdata = get_if_data() ifdata = get_if_data()
if isinstance(iface, interface): if isinstance(iface, interface):
if iface.index != None: if iface.index is not None:
return ifdata[0][iface.index] return ifdata[0][iface.index]
else: else:
return ifdata[1][iface.name] return ifdata[1][iface.name]
...@@ -345,7 +353,8 @@ def get_if(iface): ...@@ -345,7 +353,8 @@ def get_if(iface):
return ifdata[0][iface] return ifdata[0][iface]
return ifdata[1][iface] return ifdata[1][iface]
def create_if_pair(if1, if2):
def create_if_pair(if1: interface, if2: interface) -> tuple[interface, interface]:
assert if1.name and if2.name assert if1.name and if2.name
cmd = [[], []] cmd = [[], []]
...@@ -375,18 +384,20 @@ def create_if_pair(if1, if2): ...@@ -375,18 +384,20 @@ def create_if_pair(if1, if2):
interfaces = get_if_data()[1] interfaces = get_if_data()[1]
return interfaces[if1.name], interfaces[if2.name] return interfaces[if1.name], interfaces[if2.name]
def del_if(iface): def del_if(iface):
ifname = _get_if_name(iface) ifname = _get_if_name(iface)
execute([IP_PATH, "link", "del", ifname]) execute([IP_PATH, "link", "del", ifname])
def set_if(iface, recover = True):
def do_cmds(cmds, orig_iface): def set_if(iface: interface, recover=True):
def do_cmds(cmds: list[list[str]], orig_iface: interface):
for c in cmds: for c in cmds:
try: try:
execute(c) execute(c)
except: except:
if recover: if recover:
set_if(orig_iface, recover = False) # rollback set_if(orig_iface, recover=False) # rollback
raise raise
orig_iface = get_if(iface) orig_iface = get_if(iface)
...@@ -410,26 +421,28 @@ def set_if(iface, recover = True): ...@@ -410,26 +421,28 @@ def set_if(iface, recover = True):
# iface needs to be down # iface needs to be down
cmds.append(_ils + ["down"]) cmds.append(_ils + ["down"])
cmds.append(_ils + ["address", diff.lladdr]) cmds.append(_ils + ["address", diff.lladdr])
if orig_iface.up and diff.up == None: if orig_iface.up and diff.up is None:
# restore if it was up and it's not going to be set later # restore if it was up and it's not going to be set later
cmds.append(_ils + ["up"]) cmds.append(_ils + ["up"])
if diff.mtu: if diff.mtu:
cmds.append(_ils + ["mtu", str(diff.mtu)]) cmds.append(_ils + ["mtu", str(diff.mtu)])
if diff.broadcast: if diff.broadcast:
cmds.append(_ils + ["broadcast", diff.broadcast]) cmds.append(_ils + ["broadcast", diff.broadcast])
if diff.multicast != None: if diff.multicast is not None:
cmds.append(_ils + ["multicast", "on" if diff.multicast else "off"]) cmds.append(_ils + ["multicast", "on" if diff.multicast else "off"])
if diff.arp != None: if diff.arp is not None:
cmds.append(_ils + ["arp", "on" if diff.arp else "off"]) cmds.append(_ils + ["arp", "on" if diff.arp else "off"])
if diff.up != None: if diff.up is not None:
cmds.append(_ils + ["up" if diff.up else "down"]) cmds.append(_ils + ["up" if diff.up else "down"])
do_cmds(cmds, orig_iface) do_cmds(cmds, orig_iface)
def change_netns(iface, netns): def change_netns(iface, netns):
ifname = _get_if_name(iface) ifname = _get_if_name(iface)
execute([IP_PATH, "link", "set", "dev", ifname, "netns", str(netns)]) execute([IP_PATH, "link", "set", "dev", ifname, "netns", str(netns)])
# Address handling # Address handling
def get_addr_data(): def get_addr_data():
...@@ -459,16 +472,16 @@ def get_addr_data(): ...@@ -459,16 +472,16 @@ def get_addr_data():
match = re.search(r'^\s*inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line) match = re.search(r'^\s*inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line)
if match: if match:
bynam[current].append(ipv4address( bynam[current].append(ipv4address(
address = match.group(1), address=match.group(1),
prefix_len = match.group(2), prefix_len=match.group(2),
broadcast = match.group(3))) broadcast=match.group(3)))
continue continue
match = re.search(r'^\s*inet6 ([0-9a-f:]+)/(\d+)', line) match = re.search(r'^\s*inet6 ([0-9a-f:]+)/(\d+)', line)
if match: if match:
bynam[current].append(ipv6address( bynam[current].append(ipv6address(
address = match.group(1), address=match.group(1),
prefix_len = match.group(2))) prefix_len=match.group(2)))
continue continue
# Extra info, ignored. # Extra info, ignored.
...@@ -476,6 +489,7 @@ def get_addr_data(): ...@@ -476,6 +489,7 @@ def get_addr_data():
return byidx, bynam return byidx, bynam
def add_addr(iface, address): def add_addr(iface, address):
ifname = _get_if_name(iface) ifname = _get_if_name(iface)
addresses = get_addr_data()[1][ifname] addresses = get_addr_data()[1][ifname]
...@@ -487,6 +501,7 @@ def add_addr(iface, address): ...@@ -487,6 +501,7 @@ def add_addr(iface, address):
cmd += ["broadcast", address.broadcast if address.broadcast else "+"] cmd += ["broadcast", address.broadcast if address.broadcast else "+"]
execute(cmd) execute(cmd)
def del_addr(iface, address): def del_addr(iface, address):
ifname = _get_if_name(iface) ifname = _get_if_name(iface)
addresses = get_addr_data()[1][ifname] addresses = get_addr_data()[1][ifname]
...@@ -496,6 +511,7 @@ def del_addr(iface, address): ...@@ -496,6 +511,7 @@ def del_addr(iface, address):
"%s/%d" % (address.address, int(address.prefix_len))] "%s/%d" % (address.address, int(address.prefix_len))]
execute(cmd) execute(cmd)
# Bridge handling # Bridge handling
def _sysfs_read_br(brname): def _sysfs_read_br(brname):
def readval(fname): def readval(fname):
...@@ -509,12 +525,13 @@ def _sysfs_read_br(brname): ...@@ -509,12 +525,13 @@ def _sysfs_read_br(brname):
except: except:
return None return None
return dict( return dict(
stp = readval(p + "stp_state"), stp=readval(p + "stp_state"),
forward_delay = float(readval(p + "forward_delay")) / 100, forward_delay=float(readval(p + "forward_delay")) / 100,
hello_time = float(readval(p + "hello_time")) / 100, hello_time=float(readval(p + "hello_time")) / 100,
ageing_time = float(readval(p + "ageing_time")) / 100, ageing_time=float(readval(p + "ageing_time")) / 100,
max_age = float(readval(p + "max_age")) / 100, max_age=float(readval(p + "max_age")) / 100,
ports = os.listdir(p2)) ports=os.listdir(p2))
def get_bridge_data(): def get_bridge_data():
# brctl stinks too much; it is better to directly use sysfs, it is # brctl stinks too much; it is better to directly use sysfs, it is
...@@ -525,7 +542,7 @@ def get_bridge_data(): ...@@ -525,7 +542,7 @@ def get_bridge_data():
ifdata = get_if_data() ifdata = get_if_data()
for iface in ifdata[0].values(): for iface in ifdata[0].values():
brdata = _sysfs_read_br(iface.name) brdata = _sysfs_read_br(iface.name)
if brdata == None: if brdata is None:
continue continue
ports[iface.index] = [ifdata[1][x].index for x in brdata["ports"]] ports[iface.index] = [ifdata[1][x].index for x in brdata["ports"]]
del brdata["ports"] del brdata["ports"]
...@@ -533,16 +550,18 @@ def get_bridge_data(): ...@@ -533,16 +550,18 @@ def get_bridge_data():
bridge.upgrade(iface, **brdata) bridge.upgrade(iface, **brdata)
return byidx, bynam, ports return byidx, bynam, ports
def get_bridge(br): def get_bridge(br):
iface = get_if(br) iface = get_if(br)
brdata = _sysfs_read_br(iface.name) brdata = _sysfs_read_br(iface.name)
#ports = [ifdata[1][x].index for x in brdata["ports"]] # ports = [ifdata[1][x].index for x in brdata["ports"]]
del brdata["ports"] del brdata["ports"]
return bridge.upgrade(iface, **brdata) return bridge.upgrade(iface, **brdata)
def create_bridge(br): def create_bridge(br):
if isinstance(br, str): if isinstance(br, str):
br = interface(name = br) br = interface(name=br)
assert br.name assert br.name
execute([BRCTL_PATH, "addbr", br.name]) execute([BRCTL_PATH, "addbr", br.name])
try: try:
...@@ -556,54 +575,60 @@ def create_bridge(br): ...@@ -556,54 +575,60 @@ def create_bridge(br):
six.reraise(t, v, bt) six.reraise(t, v, bt)
return get_if_data()[1][br.name] return get_if_data()[1][br.name]
def del_bridge(br): def del_bridge(br):
brname = _get_if_name(br) brname = _get_if_name(br)
execute([BRCTL_PATH, "delbr", brname]) execute([BRCTL_PATH, "delbr", brname])
def set_bridge(br, recover = True):
def set_bridge(br, recover=True):
def saveval(fname, val): def saveval(fname, val):
f = open(fname, "w") f = open(fname, "w")
f.write(str(val)) f.write(str(val))
f.close() f.close()
def do_cmds(basename, cmds, orig_br): def do_cmds(basename, cmds, orig_br):
for n, v in cmds: for n, v in cmds:
try: try:
saveval(basename + n, v) saveval(basename + n, v)
except: except:
if recover: if recover:
set_bridge(orig_br, recover = False) # rollback set_bridge(orig_br, recover=False) # rollback
set_if(orig_br, recover = False) # rollback set_if(orig_br, recover=False) # rollback
raise raise
orig_br = get_bridge(br) orig_br = get_bridge(br)
diff = br - orig_br # Only set what's needed diff = br - orig_br # Only set what's needed
cmds = [] cmds = []
if diff.stp != None: if diff.stp is not None:
cmds.append(("stp_state", int(diff.stp))) cmds.append(("stp_state", int(diff.stp)))
if diff.forward_delay != None: if diff.forward_delay is not None:
cmds.append(("forward_delay", int(diff.forward_delay))) cmds.append(("forward_delay", int(diff.forward_delay)))
if diff.hello_time != None: if diff.hello_time is not None:
cmds.append(("hello_time", int(diff.hello_time))) cmds.append(("hello_time", int(diff.hello_time)))
if diff.ageing_time != None: if diff.ageing_time is not None:
cmds.append(("ageing_time", int(diff.ageing_time))) cmds.append(("ageing_time", int(diff.ageing_time)))
if diff.max_age != None: if diff.max_age is not None:
cmds.append(("max_age", int(diff.max_age))) cmds.append(("max_age", int(diff.max_age)))
set_if(diff) set_if(diff)
name = diff.name if diff.name != None else orig_br.name name = diff.name if diff.name is not None else orig_br.name
do_cmds("/sys/class/net/%s/bridge/" % name, cmds, orig_br) do_cmds("/sys/class/net/%s/bridge/" % name, cmds, orig_br)
def add_bridge_port(br, iface): def add_bridge_port(br, iface):
ifname = _get_if_name(iface) ifname = _get_if_name(iface)
brname = _get_if_name(br) brname = _get_if_name(br)
execute([BRCTL_PATH, "addif", brname, ifname]) execute([BRCTL_PATH, "addif", brname, ifname])
def del_bridge_port(br, iface): def del_bridge_port(br, iface):
ifname = _get_if_name(iface) ifname = _get_if_name(iface)
brname = _get_if_name(br) brname = _get_if_name(br)
execute([BRCTL_PATH, "delif", brname, ifname]) execute([BRCTL_PATH, "delif", brname, ifname])
# Routing # Routing
def get_all_route_data(): def get_all_route_data():
...@@ -636,23 +661,27 @@ def get_all_route_data(): ...@@ -636,23 +661,27 @@ def get_all_route_data():
metric)) metric))
return ret return ret
def get_route_data():
def get_route_data() -> list[route]:
# filter out non-unicast routes # filter out non-unicast routes
return [x for x in get_all_route_data() if x.tipe == "unicast"] return [x for x in get_all_route_data() if x.tipe == "unicast"]
def add_route(route):
def add_route(route: route):
# Cannot really test this # Cannot really test this
#if route in get_all_route_data(): # if route in get_all_route_data():
# raise ValueError("Route already exists") # raise ValueError("Route already exists")
_add_del_route("add", route) _add_del_route("add", route)
def del_route(route):
def del_route(route: route):
# Cannot really test this # Cannot really test this
#if route not in get_all_route_data(): # if route not in get_all_route_data():
# raise ValueError("Route does not exist") # raise ValueError("Route does not exist")
_add_del_route("del", route) _add_del_route("del", route)
def _add_del_route(action, route):
def _add_del_route(action: Literal["add", "del"], route: route):
cmd = [IP_PATH, "route", action] cmd = [IP_PATH, "route", action]
if route.tipe != "unicast": if route.tipe != "unicast":
cmd += [route.tipe] cmd += [route.tipe]
...@@ -666,6 +695,7 @@ def _add_del_route(action, route): ...@@ -666,6 +695,7 @@ def _add_del_route(action, route):
cmd += ["dev", _get_if_name(route.interface)] cmd += ["dev", _get_if_name(route.interface)]
execute(cmd) execute(cmd)
# TC stuff # TC stuff
def get_tc_tree(): def get_tc_tree():
...@@ -706,11 +736,15 @@ def get_tc_tree(): ...@@ -706,11 +736,15 @@ def get_tc_tree():
for h in data[data_node[0]]: for h in data[data_node[0]]:
node["children"].append(gen_tree(data, h)) node["children"].append(gen_tree(data, h))
return node return node
tree[iface] = gen_tree(data[iface], data[iface][None][0]) tree[iface] = gen_tree(data[iface], data[iface][None][0])
return tree return tree
_multipliers = {"M": 1000000, "K": 1000} _multipliers = {"M": 1000000, "K": 1000}
_dividers = {"m": 1000, "u": 1000000} _dividers = {"m": 1000, "u": 1000000}
def _parse_netem_delay(line): def _parse_netem_delay(line):
ret = {} ret = {}
match = re.search(r'delay ([\d.]+)([mu]?)s(?: +([\d.]+)([mu]?)s)?' + match = re.search(r'delay ([\d.]+)([mu]?)s(?: +([\d.]+)([mu]?)s)?' +
...@@ -737,6 +771,7 @@ def _parse_netem_delay(line): ...@@ -737,6 +771,7 @@ def _parse_netem_delay(line):
return ret return ret
def _parse_netem_loss(line): def _parse_netem_loss(line):
ret = {} ret = {}
match = re.search(r'loss ([\d.]+)%(?: *([\d.]+)%)?', line) match = re.search(r'loss ([\d.]+)%(?: *([\d.]+)%)?', line)
...@@ -748,6 +783,7 @@ def _parse_netem_loss(line): ...@@ -748,6 +783,7 @@ def _parse_netem_loss(line):
ret["loss_correlation"] = float(match.group(2)) / 100 ret["loss_correlation"] = float(match.group(2)) / 100
return ret return ret
def _parse_netem_dup(line): def _parse_netem_dup(line):
ret = {} ret = {}
match = re.search(r'duplicate ([\d.]+)%(?: *([\d.]+)%)?', line) match = re.search(r'duplicate ([\d.]+)%(?: *([\d.]+)%)?', line)
...@@ -759,6 +795,7 @@ def _parse_netem_dup(line): ...@@ -759,6 +795,7 @@ def _parse_netem_dup(line):
ret["dup_correlation"] = float(match.group(2)) / 100 ret["dup_correlation"] = float(match.group(2)) / 100
return ret return ret
def _parse_netem_corrupt(line): def _parse_netem_corrupt(line):
ret = {} ret = {}
match = re.search(r'corrupt ([\d.]+)%(?: *([\d.]+)%)?', line) match = re.search(r'corrupt ([\d.]+)%(?: *([\d.]+)%)?', line)
...@@ -770,6 +807,7 @@ def _parse_netem_corrupt(line): ...@@ -770,6 +807,7 @@ def _parse_netem_corrupt(line):
ret["corrupt_correlation"] = float(match.group(2)) / 100 ret["corrupt_correlation"] = float(match.group(2)) / 100
return ret return ret
def get_tc_data(): def get_tc_data():
tree = get_tc_tree() tree = get_tc_tree()
ifdata = get_if_data() ifdata = get_if_data()
...@@ -823,19 +861,21 @@ def get_tc_data(): ...@@ -823,19 +861,21 @@ def get_tc_data():
ret[i].update(_parse_netem_corrupt(netem[0])) ret[i].update(_parse_netem_corrupt(netem[0]))
return ret, ifdata[0], ifdata[1] return ret, ifdata[0], ifdata[1]
def clear_tc(iface): def clear_tc(iface):
iface = get_if(iface) iface = get_if(iface)
tcdata = get_tc_data()[0] tcdata = get_tc_data()[0]
if tcdata[iface.index] == None: if tcdata[iface.index] is None:
return return
# Any other case, we clean # Any other case, we clean
execute([TC_PATH, "qdisc", "del", "dev", iface.name, "root"]) execute([TC_PATH, "qdisc", "del", "dev", iface.name, "root"])
def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
delay_correlation = None, delay_distribution = None, def set_tc(iface, bandwidth=None, delay=None, delay_jitter=None,
loss = None, loss_correlation = None, delay_correlation=None, delay_distribution=None,
dup = None, dup_correlation = None, loss=None, loss_correlation=None,
corrupt = None, corrupt_correlation = None): dup=None, dup_correlation=None,
corrupt=None, corrupt_correlation=None):
use_netem = bool(delay or delay_jitter or delay_correlation or use_netem = bool(delay or delay_jitter or delay_correlation or
delay_distribution or loss or loss_correlation or dup or delay_distribution or loss or loss_correlation or dup or
dup_correlation or corrupt or corrupt_correlation) dup_correlation or corrupt or corrupt_correlation)
...@@ -920,10 +960,11 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None, ...@@ -920,10 +960,11 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
for c in commands: for c in commands:
execute(c) execute(c)
def create_tap(iface, use_pi = False, tun = False):
def create_tap(iface, use_pi=False, tun=False):
"""Creates a tap/tun device and returns the associated file descriptor""" """Creates a tap/tun device and returns the associated file descriptor"""
if isinstance(iface, str): if isinstance(iface, str):
iface = interface(name = iface) iface = interface(name=iface)
assert iface.name assert iface.name
IFF_TUN = 0x0001 IFF_TUN = 0x0001
...@@ -952,4 +993,3 @@ def create_tap(iface, use_pi = False, tun = False): ...@@ -952,4 +993,3 @@ def create_tap(iface, use_pi = False, tun = False):
raise raise
interfaces = get_if_data()[1] interfaces = get_if_data()[1]
return interfaces[iface.name], fd return interfaces[iface.name], fd
...@@ -21,10 +21,13 @@ import os ...@@ -21,10 +21,13 @@ import os
import socket import socket
import sys import sys
import traceback import traceback
from typing import MutableMapping
import unshare import unshare
import weakref import weakref
import nemu.interface import nemu.interface
import nemu.iproute
import nemu.protocol import nemu.protocol
import nemu.subprocess_ import nemu.subprocess_
from nemu import compat from nemu import compat
...@@ -33,10 +36,11 @@ from nemu.environ import * ...@@ -33,10 +36,11 @@ from nemu.environ import *
__all__ = ['Node', 'get_nodes', 'import_if'] __all__ = ['Node', 'get_nodes', 'import_if']
class Node(object): class Node(object):
_nodes = weakref.WeakValueDictionary() _nodes: MutableMapping[int, "Node"] = weakref.WeakValueDictionary()
_nextnode = 0 _nextnode = 0
_processes: MutableMapping[int, nemu.subprocess_.Subprocess]
@staticmethod @staticmethod
def get_nodes(): def get_nodes() -> list["Node"]:
s = sorted(list(Node._nodes.items()), key = lambda x: x[0]) s = sorted(list(Node._nodes.items()), key = lambda x: x[0])
return [x[1] for x in s] return [x[1] for x in s]
...@@ -98,7 +102,7 @@ class Node(object): ...@@ -98,7 +102,7 @@ class Node(object):
return self._pid return self._pid
# Subprocesses # Subprocesses
def _add_subprocess(self, subprocess): def _add_subprocess(self, subprocess: nemu.subprocess_.Subprocess):
self._processes[subprocess.pid] = subprocess self._processes[subprocess.pid] = subprocess
def Subprocess(self, *kargs, **kwargs): def Subprocess(self, *kargs, **kwargs):
...@@ -188,13 +192,13 @@ class Node(object): ...@@ -188,13 +192,13 @@ class Node(object):
r = self.route(*args, **kwargs) r = self.route(*args, **kwargs)
return self._slave.del_route(r) return self._slave.del_route(r)
def get_routes(self): def get_routes(self) -> list[route]:
return self._slave.get_route_data() return self._slave.get_route_data()
# Handle the creation of the child; parent gets (fd, pid), child creates and # Handle the creation of the child; parent gets (fd, pid), child creates and
# runs a Server(); never returns. # runs a Server(); never returns.
# Requires CAP_SYS_ADMIN privileges to run. # Requires CAP_SYS_ADMIN privileges to run.
def _start_child(nonetns) -> (socket.socket, int): def _start_child(nonetns: bool) -> (socket.socket, int):
# Create socket pair to communicate # Create socket pair to communicate
(s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
# Spawn a child that will run in a loop # Spawn a child that will run in a loop
......
...@@ -21,7 +21,7 @@ import struct ...@@ -21,7 +21,7 @@ import struct
from io import IOBase from io import IOBase
def __check_socket(sock: socket.socket | IOBase): def __check_socket(sock: socket.socket | IOBase) -> socket.socket:
if hasattr(sock, 'family') and sock.family != socket.AF_UNIX: if hasattr(sock, 'family') and sock.family != socket.AF_UNIX:
raise ValueError("Only AF_UNIX sockets are allowed") raise ValueError("Only AF_UNIX sockets are allowed")
...@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase): ...@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase):
return sock return sock
def __check_fd(fd): def __check_fd(fd) -> int:
try: try:
fd = fd.fileno() fd = fd.fileno()
except AttributeError: except AttributeError:
...@@ -44,7 +44,7 @@ def __check_fd(fd): ...@@ -44,7 +44,7 @@ def __check_fd(fd):
return fd return fd
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096): def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str]:
size = struct.calcsize("@i") size = struct.calcsize("@i")
msg, ancdata, flags, addr = __check_socket(sock).recvmsg(msg_buf, socket.CMSG_SPACE(size)) msg, ancdata, flags, addr = __check_socket(sock).recvmsg(msg_buf, socket.CMSG_SPACE(size))
cmsg_level, cmsg_type, cmsg_data = ancdata[0] cmsg_level, cmsg_type, cmsg_data = ancdata[0]
...@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096): ...@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
return fd, msg.decode("utf-8") return fd, msg.decode("utf-8")
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE"): def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE") -> int:
return __check_socket(sock).sendmsg( return __check_socket(sock).sendmsg(
[message], [message],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))]) [(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))])
\ No newline at end of file
...@@ -29,6 +29,7 @@ import tempfile ...@@ -29,6 +29,7 @@ import tempfile
import time import time
import traceback import traceback
from pickle import loads, dumps from pickle import loads, dumps
from typing import Literal
import nemu.iproute import nemu.iproute
import nemu.subprocess_ import nemu.subprocess_
...@@ -278,7 +279,7 @@ class Server(object): ...@@ -278,7 +279,7 @@ class Server(object):
self.reply(220, "Hello."); self.reply(220, "Hello.");
while not self._closed: while not self._closed:
cmd = self.readcmd() cmd = self.readcmd()
if cmd == None: if cmd is None:
continue continue
try: try:
cmd[0](cmd[1], *cmd[2]) cmd[0](cmd[1], *cmd[2])
...@@ -422,7 +423,7 @@ class Server(object): ...@@ -422,7 +423,7 @@ class Server(object):
else: else:
ret = nemu.subprocess_.wait(pid) ret = nemu.subprocess_.wait(pid)
if ret != None: if ret is not None:
self._children.remove(pid) self._children.remove(pid)
if pid in self._xauthfiles: if pid in self._xauthfiles:
try: try:
...@@ -449,7 +450,7 @@ class Server(object): ...@@ -449,7 +450,7 @@ class Server(object):
self.reply(200, "Process signalled.") self.reply(200, "Process signalled.")
def do_IF_LIST(self, cmdname, ifnr=None): def do_IF_LIST(self, cmdname, ifnr=None):
if ifnr == None: if ifnr is None:
ifdata = nemu.iproute.get_if_data()[0] ifdata = nemu.iproute.get_if_data()[0]
else: else:
ifdata = nemu.iproute.get_if(ifnr) ifdata = nemu.iproute.get_if(ifnr)
...@@ -479,7 +480,7 @@ class Server(object): ...@@ -479,7 +480,7 @@ class Server(object):
def do_ADDR_LIST(self, cmdname, ifnr=None): def do_ADDR_LIST(self, cmdname, ifnr=None):
addrdata = nemu.iproute.get_addr_data()[0] addrdata = nemu.iproute.get_addr_data()[0]
if ifnr != None: if ifnr is not None:
addrdata = addrdata[ifnr] addrdata = addrdata[ifnr]
self.reply(200, ["# Address data follows.", self.reply(200, ["# Address data follows.",
_b64(dumps(addrdata, protocol=2))]) _b64(dumps(addrdata, protocol=2))])
...@@ -652,7 +653,7 @@ class Client(object): ...@@ -652,7 +653,7 @@ class Client(object):
stdin/stdout/stderr can only be None or a open file descriptor. stdin/stdout/stderr can only be None or a open file descriptor.
See nemu.subprocess_.spawn for details.""" See nemu.subprocess_.spawn for details."""
if executable == None: if executable is None:
executable = argv[0] executable = argv[0]
params = ["PROC", "CRTE", _b64(executable)] params = ["PROC", "CRTE", _b64(executable)]
for i in argv: for i in argv:
...@@ -663,28 +664,28 @@ class Client(object): ...@@ -663,28 +664,28 @@ class Client(object):
# After this, if we get an error, we have to abort the PROC # After this, if we get an error, we have to abort the PROC
try: try:
if user != None: if user is not None:
self._send_cmd("PROC", "USER", _b64(user)) self._send_cmd("PROC", "USER", _b64(user))
self._read_and_check_reply() self._read_and_check_reply()
if cwd != None: if cwd is not None:
self._send_cmd("PROC", "CWD", _b64(cwd)) self._send_cmd("PROC", "CWD", _b64(cwd))
self._read_and_check_reply() self._read_and_check_reply()
if env != None: if env is not None:
params = [] params = []
for k, v in env.items(): for k, v in env.items():
params.extend([_b64(k), _b64(v)]) params.extend([_b64(k), _b64(v)])
self._send_cmd("PROC", "ENV", *params) self._send_cmd("PROC", "ENV", *params)
self._read_and_check_reply() self._read_and_check_reply()
if stdin != None: if stdin is not None:
os.set_inheritable(stdin, True) os.set_inheritable(stdin, True)
self._send_fd("SIN", stdin) self._send_fd("SIN", stdin)
if stdout != None: if stdout is not None:
os.set_inheritable(stdout, True) os.set_inheritable(stdout, True)
self._send_fd("SOUT", stdout) self._send_fd("SOUT", stdout)
if stderr != None: if stderr is not None:
os.set_inheritable(stderr, True) os.set_inheritable(stderr, True)
self._send_fd("SERR", stderr) self._send_fd("SERR", stderr)
except: except:
...@@ -739,7 +740,7 @@ class Client(object): ...@@ -739,7 +740,7 @@ class Client(object):
cmd = ["IF", "SET", interface.index] cmd = ["IF", "SET", interface.index]
for k in interface.changeable_attributes: for k in interface.changeable_attributes:
v = getattr(interface, k) v = getattr(interface, k)
if v != None: if v is not None:
cmd += [k, str(v)] cmd += [k, str(v)]
self._send_cmd(*cmd) self._send_cmd(*cmd)
...@@ -761,7 +762,7 @@ class Client(object): ...@@ -761,7 +762,7 @@ class Client(object):
data = self._read_and_check_reply() data = self._read_and_check_reply()
return loads(_db64(data.partition("\n")[2])) return loads(_db64(data.partition("\n")[2]))
def add_addr(self, ifnr, address): def add_addr(self, ifnr: int, address: nemu.iproute.address):
if hasattr(address, "broadcast") and address.broadcast: if hasattr(address, "broadcast") and address.broadcast:
self._send_cmd("ADDR", "ADD", ifnr, address.address, self._send_cmd("ADDR", "ADD", ifnr, address.address,
address.prefix_len, address.broadcast) address.prefix_len, address.broadcast)
...@@ -770,7 +771,7 @@ class Client(object): ...@@ -770,7 +771,7 @@ class Client(object):
address.prefix_len) address.prefix_len)
self._read_and_check_reply() self._read_and_check_reply()
def del_addr(self, ifnr, address): def del_addr(self, ifnr: int, address: nemu.iproute.address):
self._send_cmd("ADDR", "DEL", ifnr, address.address, address.prefix_len) self._send_cmd("ADDR", "DEL", ifnr, address.address, address.prefix_len)
self._read_and_check_reply() self._read_and_check_reply()
...@@ -785,14 +786,14 @@ class Client(object): ...@@ -785,14 +786,14 @@ class Client(object):
def del_route(self, route): def del_route(self, route):
self._add_del_route("DEL", route) self._add_del_route("DEL", route)
def _add_del_route(self, action, route): def _add_del_route(self, action: Literal["ADD", "DEL"], route: nemu.iproute.route):
args = ["ROUT", action, _b64(route.tipe), _b64(route.prefix), args = ["ROUT", action, _b64(route.tipe), _b64(route.prefix),
route.prefix_len or 0, _b64(route.nexthop), route.prefix_len or 0, _b64(route.nexthop),
route.interface or 0, route.metric or 0] route.interface or 0, route.metric or 0]
self._send_cmd(*args) self._send_cmd(*args)
self._read_and_check_reply() self._read_and_check_reply()
def set_x11(self, protoname, hexkey): def set_x11(self, protoname: str, hexkey: str) -> socket.socket:
# Returns a socket ready to accept() connections # Returns a socket ready to accept() connections
self._send_cmd("X11", "SET", protoname, hexkey) self._send_cmd("X11", "SET", protoname, hexkey)
self._read_and_check_reply() self._read_and_check_reply()
...@@ -823,7 +824,7 @@ class Client(object): ...@@ -823,7 +824,7 @@ class Client(object):
def _b64_OLD(text: str | bytes) -> str: def _b64_OLD(text: str | bytes) -> str:
if text == None: if text is None:
# easier this way # easier this way
text = '' text = ''
if type(text) is str: if type(text) is str:
...@@ -838,6 +839,7 @@ def _b64_OLD(text: str | bytes) -> str: ...@@ -838,6 +839,7 @@ def _b64_OLD(text: str | bytes) -> str:
else: else:
return text return text
def _b64(text) -> str: def _b64(text) -> str:
if text is None: if text is None:
# easier this way # easier this way
......
...@@ -27,7 +27,10 @@ import signal ...@@ -27,7 +27,10 @@ import signal
import sys import sys
import time import time
import traceback import traceback
import typing
if typing.TYPE_CHECKING:
from nemu import Node
from nemu import compat from nemu import compat
from nemu.environ import eintr_wrapper from nemu.environ import eintr_wrapper
...@@ -46,7 +49,7 @@ class Subprocess(object): ...@@ -46,7 +49,7 @@ class Subprocess(object):
# FIXME # FIXME
default_user = None default_user = None
def __init__(self, node, argv: str | list[str], executable=None, def __init__(self, node: "Node", argv: str | list[str], executable=None,
stdin=None, stdout=None, stderr=None, stdin=None, stdout=None, stderr=None,
shell=False, cwd=None, env=None, user=None): shell=False, cwd=None, env=None, user=None):
self._slave = node._slave self._slave = node._slave
...@@ -78,7 +81,7 @@ class Subprocess(object): ...@@ -78,7 +81,7 @@ class Subprocess(object):
Exceptions occurred while trying to set up the environment or executing Exceptions occurred while trying to set up the environment or executing
the program are propagated to the parent.""" the program are propagated to the parent."""
if user == None: if user is None:
user = Subprocess.default_user user = Subprocess.default_user
if isinstance(argv, str): if isinstance(argv, str):
...@@ -106,20 +109,20 @@ class Subprocess(object): ...@@ -106,20 +109,20 @@ class Subprocess(object):
def poll(self): def poll(self):
"""Checks status of program, returns exitcode or None if still running. """Checks status of program, returns exitcode or None if still running.
See Popen.poll.""" See Popen.poll."""
if self._returncode == None: if self._returncode is None:
self._returncode = self._slave.poll(self._pid) self._returncode = self._slave.poll(self._pid)
return self.returncode return self.returncode
def wait(self): def wait(self):
"""Waits for program to complete and returns the exitcode. """Waits for program to complete and returns the exitcode.
See Popen.wait""" See Popen.wait"""
if self._returncode == None: if self._returncode is None:
self._returncode = self._slave.wait(self._pid) self._returncode = self._slave.wait(self._pid)
return self.returncode return self.returncode
def signal(self, sig=signal.SIGTERM): def signal(self, sig=signal.SIGTERM):
"""Sends a signal to the process.""" """Sends a signal to the process."""
if self._returncode == None: if self._returncode is None:
self._slave.signal(self._pid, sig) self._slave.signal(self._pid, sig)
@property @property
...@@ -128,7 +131,7 @@ class Subprocess(object): ...@@ -128,7 +131,7 @@ class Subprocess(object):
communicate, wait, or poll), returns the signal that killed the communicate, wait, or poll), returns the signal that killed the
program, if negative; otherwise, it is the exit code of the program. program, if negative; otherwise, it is the exit code of the program.
""" """
if self._returncode == None: if self._returncode is None:
return None return None
if os.WIFSIGNALED(self._returncode): if os.WIFSIGNALED(self._returncode):
return -os.WTERMSIG(self._returncode) return -os.WTERMSIG(self._returncode)
...@@ -140,12 +143,12 @@ class Subprocess(object): ...@@ -140,12 +143,12 @@ class Subprocess(object):
self.destroy() self.destroy()
def destroy(self): def destroy(self):
if self._returncode != None or self._pid == None: if self._returncode is not None or self._pid is None:
return return
self.signal() self.signal()
now = time.time() now = time.time()
while time.time() - now < KILL_WAIT: while time.time() - now < KILL_WAIT:
if self.poll() != None: if self.poll() is not None:
return return
time.sleep(0.1) time.sleep(0.1)
sys.stderr.write("WARNING: killing forcefully process %d.\n" % sys.stderr.write("WARNING: killing forcefully process %d.\n" %
...@@ -179,7 +182,7 @@ class Popen(Subprocess): ...@@ -179,7 +182,7 @@ class Popen(Subprocess):
fdmap = {"stdin": stdin, "stdout": stdout, "stderr": stderr} fdmap = {"stdin": stdin, "stdout": stdout, "stderr": stderr}
# if PIPE: all should be closed at the end # if PIPE: all should be closed at the end
for k, v in fdmap.items(): for k, v in fdmap.items():
if v == None: if v is None:
continue continue
if v == PIPE: if v == PIPE:
r, w = compat.pipe() r, w = compat.pipe()
...@@ -206,27 +209,29 @@ class Popen(Subprocess): ...@@ -206,27 +209,29 @@ class Popen(Subprocess):
# Close pipes, they have been dup()ed to the child # Close pipes, they have been dup()ed to the child
for k, v in fdmap.items(): for k, v in fdmap.items():
if getattr(self, k) != None: if getattr(self, k) is not None:
eintr_wrapper(os.close, v) eintr_wrapper(os.close, v)
def communicate(self, input: bytes = None) -> tuple[bytes, bytes]: def communicate(self, input: bytes | str = None) -> tuple[bytes, bytes]:
"""See Popen.communicate.""" """See Popen.communicate."""
# FIXME: almost verbatim from stdlib version, need to be removed or # FIXME: almost verbatim from stdlib version, need to be removed or
# something # something
if type(input) is str:
input = input.encode("utf-8")
wset = [] wset = []
rset = [] rset = []
err = None err = None
out = None out = None
if self.stdin != None: if self.stdin is not None:
self.stdin.flush() self.stdin.flush()
if input: if input:
wset.append(self.stdin) wset.append(self.stdin)
else: else:
self.stdin.close() self.stdin.close()
if self.stdout != None: if self.stdout is not None:
rset.append(self.stdout) rset.append(self.stdout)
out = [] out = []
if self.stderr != None: if self.stderr is not None:
rset.append(self.stderr) rset.append(self.stderr)
err = [] err = []
...@@ -253,9 +258,9 @@ class Popen(Subprocess): ...@@ -253,9 +258,9 @@ class Popen(Subprocess):
else: else:
err.append(d) err.append(d)
if out != None: if out is not None:
out = b''.join(out) out = b''.join(out)
if err != None: if err is not None:
err = b''.join(err) err = b''.join(err)
self.wait() self.wait()
return (out, err) return (out, err)
...@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False, ...@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
is not supported here. Also, the original descriptors are not closed. is not supported here. Also, the original descriptors are not closed.
""" """
userfd = [stdin, stdout, stderr] userfd = [stdin, stdout, stderr]
filtered_userfd = [x for x in userfd if x != None and x >= 0] filtered_userfd = [x for x in userfd if x is not None and x >= 0]
for i in range(3): for i in range(3):
if userfd[i] != None and not isinstance(userfd[i], int): if userfd[i] is not None and not isinstance(userfd[i], int):
userfd[i] = userfd[i].fileno() # pragma: no cover userfd[i] = userfd[i].fileno() # pragma: no cover
# Verify there is no clash # Verify there is no clash
assert not (set([0, 1, 2]) & set(filtered_userfd)) assert not (set([0, 1, 2]) & set(filtered_userfd))
if user != None: if user is not None:
user, uid, gid = get_user(user) user, uid, gid = get_user(user)
home = pwd.getpwuid(uid)[5] home = pwd.getpwuid(uid)[5]
groups = [x[2] for x in grp.getgrall() if user in x[3]] groups = [x[2] for x in grp.getgrall() if user in x[3]]
...@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False, ...@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
try: try:
# Set up stdio piping # Set up stdio piping
for i in range(3): for i in range(3):
if userfd[i] != None and userfd[i] >= 0: if userfd[i] is not None and userfd[i] >= 0:
os.dup2(userfd[i], i) os.dup2(userfd[i], i)
if userfd[i] != i and userfd[i] not in userfd[0:i]: if userfd[i] != i and userfd[i] not in userfd[0:i]:
eintr_wrapper(os.close, userfd[i]) # only in child! eintr_wrapper(os.close, userfd[i]) # only in child!
...@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False, ...@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
# (it is necessary to kill the forked subprocesses) # (it is necessary to kill the forked subprocesses)
os.setpgrp() os.setpgrp()
if user != None: if user is not None:
# Change user # Change user
os.setgid(gid) os.setgid(gid)
os.setgroups(groups) os.setgroups(groups)
os.setuid(uid) os.setuid(uid)
if cwd != None: if cwd is not None:
os.chdir(cwd) os.chdir(cwd)
if not argv: if not argv:
argv = [executable] argv = [executable]
if '/' in executable: # Should not search in PATH if '/' in executable: # Should not search in PATH
if env != None: if env is not None:
os.execve(executable, argv, env) os.execve(executable, argv, env)
else: else:
os.execv(executable, argv) os.execv(executable, argv)
else: # use PATH else: # use PATH
if env != None: if env is not None:
os.execvpe(executable, argv, env) os.execvpe(executable, argv, env)
else: else:
os.execvp(executable, argv) os.execvp(executable, argv)
......
...@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase): ...@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase):
os.write(if1.fd, s) os.write(if1.fd, s)
if not s: if not s:
break break
if subproc.poll() != None: if subproc.poll() is not None:
break break
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
......
...@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase): ...@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase):
def check_ok(self, cmd, func, args): def check_ok(self, cmd, func, args):
s1.write("%s\n" % cmd) s1.write("%s\n" % cmd)
ccmd = " ".join(cmd.upper().split()[0:2]) ccmd = " ".join(cmd.upper().split()[0:2])
if func == None: if func is None:
self.assertEqual(srv.readcmd()[1:3], (ccmd, args)) self.assertEqual(srv.readcmd()[1:3], (ccmd, args))
else: else:
self.assertEqual(srv.readcmd(), (func, ccmd, args)) self.assertEqual(srv.readcmd(), (func, ccmd, args))
......
...@@ -15,7 +15,7 @@ def process_ipcmd(str: str): ...@@ -15,7 +15,7 @@ def process_ipcmd(str: str):
match = re.search(r'^(\d+): ([^@\s]+)(?:@\S+)?: <(\S+)> mtu (\d+) ' match = re.search(r'^(\d+): ([^@\s]+)(?:@\S+)?: <(\S+)> mtu (\d+) '
r'qdisc (\S+)', r'qdisc (\S+)',
line) line)
if match != None: if match is not None:
cur = match.group(2) cur = match.group(2)
out[cur] = { out[cur] = {
'idx': int(match.group(1)), 'idx': int(match.group(1)),
...@@ -27,14 +27,14 @@ def process_ipcmd(str: str): ...@@ -27,14 +27,14 @@ def process_ipcmd(str: str):
out[cur]['up'] = 'UP' in out[cur]['flags'] out[cur]['up'] = 'UP' in out[cur]['flags']
continue continue
# Assume cur is defined # Assume cur is defined
assert cur != None assert cur is not None
match = re.search(r'^\s+link/\S*(?: ([0-9a-f:]+))?(?: |$)', line) match = re.search(r'^\s+link/\S*(?: ([0-9a-f:]+))?(?: |$)', line)
if match != None: if match is not None:
out[cur]['lladdr'] = match.group(1) out[cur]['lladdr'] = match.group(1)
continue continue
match = re.search(r'^\s+inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line) match = re.search(r'^\s+inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line)
if match != None: if match is not None:
out[cur]['addr'].append({ out[cur]['addr'].append({
'address': match.group(1), 'address': match.group(1),
'prefix_len': int(match.group(2)), 'prefix_len': int(match.group(2)),
...@@ -43,7 +43,7 @@ def process_ipcmd(str: str): ...@@ -43,7 +43,7 @@ def process_ipcmd(str: str):
continue continue
match = re.search(r'^\s+inet6 ([0-9a-f:]+)/(\d+)(?: |$)', line) match = re.search(r'^\s+inet6 ([0-9a-f:]+)/(\d+)(?: |$)', line)
if match != None: if match is not None:
out[cur]['addr'].append({ out[cur]['addr'].append({
'address': match.group(1), 'address': match.group(1),
'prefix_len': int(match.group(2)), 'prefix_len': int(match.group(2)),
...@@ -51,7 +51,7 @@ def process_ipcmd(str: str): ...@@ -51,7 +51,7 @@ def process_ipcmd(str: str):
continue continue
match = re.search(r'^\s{4}', line) match = re.search(r'^\s{4}', line)
assert match != None assert match is not None
return out return out
def get_devs(): def get_devs():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment