Commit 033ce168 authored by Tom Niget's avatar Tom Niget

Add type hints

parent 7f17df26
......@@ -140,7 +140,7 @@ def main():
if not r:
break
out += r
if srv.poll() != None or clt.poll() != None:
if srv.poll() is not None or clt.poll() is not None:
break
if srv.poll():
......
......@@ -15,6 +15,6 @@ setup(
license = 'GPLv2',
platforms = 'Linux',
packages = ['nemu'],
install_requires = ['unshare', 'six'],
install_requires = ['unshare', 'six', 'attrs'],
package_dir = {'': 'src'}
)
......@@ -41,7 +41,7 @@ class _Config(object):
except KeyError:
pass # User not found.
def _set_run_as(self, user):
def _set_run_as(self, user: str | int):
"""Setter for `run_as'."""
if str(user).isdigit():
uid = int(user)
......@@ -61,7 +61,7 @@ class _Config(object):
self._run_as = run_as
return run_as
def _get_run_as(self):
def _get_run_as(self) -> str:
"""Setter for `run_as'."""
return self._run_as
......
......@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]:
os.set_inheritable(b, True)
return a, b
def socket(*args, **kwargs) -> pysocket.socket:
s = pysocket.socket(*args, **kwargs)
s.set_inheritable(True)
return s
def socketpair(*args, **kwargs) -> tuple[pysocket.socket, pysocket.socket]:
a, b = pysocket.socketpair(*args, **kwargs)
a.set_inheritable(True)
b.set_inheritable(True)
return a, b
def fromfd(*args, **kwargs) -> pysocket.socket:
s = pysocket.fromfd(*args, **kwargs)
s.set_inheritable(True)
return s
def fdopen(*args, **kwargs) -> pysocket.socket:
s = os.fdopen(*args, **kwargs)
s.set_inheritable(True)
return s
\ No newline at end of file
......@@ -25,7 +25,7 @@ import subprocess
import sys
import syslog
from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
from typing import TypeVar, Callable
from typing import TypeVar, Callable, Optional
__all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"]
......@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"]
__all__ += ["error", "warning", "notice", "info", "debug"]
def find_bin(name, extra_path=None):
def find_bin(name: str, extra_path: Optional[list[str]] = None) -> Optional[str]:
"""Try hard to find the location of needed programs."""
search = []
if "PATH" in os.environ:
......@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None):
return None
def find_bin_or_die(name, extra_path=None):
def find_bin_or_die(name: str, extra_path: Optional[list[str]] = None) -> str:
"""Try hard to find the location of needed programs; raise on failure."""
res = find_bin(name, extra_path)
if not res:
......@@ -156,7 +156,7 @@ _log_syslog_opts = ()
_log_pid = os.getpid()
def set_log_level(level):
def set_log_level(level: int):
"Sets the log level for console messages, does not affect syslog logging."
global _log_level
assert level > LOG_ERR and level <= LOG_DEBUG
......@@ -191,7 +191,7 @@ def _init_log():
info("Syslog logging started")
def logger(priority, message):
def logger(priority: int, message: str):
"Print a log message in syslog, console or both."
if _log_use_syslog:
if os.getpid() != _log_pid:
......
This diff is collapsed.
This diff is collapsed.
......@@ -21,10 +21,13 @@ import os
import socket
import sys
import traceback
from typing import MutableMapping
import unshare
import weakref
import nemu.interface
import nemu.iproute
import nemu.protocol
import nemu.subprocess_
from nemu import compat
......@@ -33,10 +36,11 @@ from nemu.environ import *
__all__ = ['Node', 'get_nodes', 'import_if']
class Node(object):
_nodes = weakref.WeakValueDictionary()
_nodes: MutableMapping[int, "Node"] = weakref.WeakValueDictionary()
_nextnode = 0
_processes: MutableMapping[int, nemu.subprocess_.Subprocess]
@staticmethod
def get_nodes():
def get_nodes() -> list["Node"]:
s = sorted(list(Node._nodes.items()), key = lambda x: x[0])
return [x[1] for x in s]
......@@ -98,7 +102,7 @@ class Node(object):
return self._pid
# Subprocesses
def _add_subprocess(self, subprocess):
def _add_subprocess(self, subprocess: nemu.subprocess_.Subprocess):
self._processes[subprocess.pid] = subprocess
def Subprocess(self, *kargs, **kwargs):
......@@ -188,13 +192,13 @@ class Node(object):
r = self.route(*args, **kwargs)
return self._slave.del_route(r)
def get_routes(self):
def get_routes(self) -> list[route]:
return self._slave.get_route_data()
# Handle the creation of the child; parent gets (fd, pid), child creates and
# runs a Server(); never returns.
# Requires CAP_SYS_ADMIN privileges to run.
def _start_child(nonetns) -> (socket.socket, int):
def _start_child(nonetns: bool) -> (socket.socket, int):
# Create socket pair to communicate
(s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
# Spawn a child that will run in a loop
......
......@@ -21,7 +21,7 @@ import struct
from io import IOBase
def __check_socket(sock: socket.socket | IOBase):
def __check_socket(sock: socket.socket | IOBase) -> socket.socket:
if hasattr(sock, 'family') and sock.family != socket.AF_UNIX:
raise ValueError("Only AF_UNIX sockets are allowed")
......@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase):
return sock
def __check_fd(fd):
def __check_fd(fd) -> int:
try:
fd = fd.fileno()
except AttributeError:
......@@ -44,7 +44,7 @@ def __check_fd(fd):
return fd
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str]:
size = struct.calcsize("@i")
msg, ancdata, flags, addr = __check_socket(sock).recvmsg(msg_buf, socket.CMSG_SPACE(size))
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
......@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
return fd, msg.decode("utf-8")
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE"):
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE") -> int:
return __check_socket(sock).sendmsg(
[message],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))])
\ No newline at end of file
......@@ -29,6 +29,7 @@ import tempfile
import time
import traceback
from pickle import loads, dumps
from typing import Literal
import nemu.iproute
import nemu.subprocess_
......@@ -278,7 +279,7 @@ class Server(object):
self.reply(220, "Hello.");
while not self._closed:
cmd = self.readcmd()
if cmd == None:
if cmd is None:
continue
try:
cmd[0](cmd[1], *cmd[2])
......@@ -422,7 +423,7 @@ class Server(object):
else:
ret = nemu.subprocess_.wait(pid)
if ret != None:
if ret is not None:
self._children.remove(pid)
if pid in self._xauthfiles:
try:
......@@ -449,7 +450,7 @@ class Server(object):
self.reply(200, "Process signalled.")
def do_IF_LIST(self, cmdname, ifnr=None):
if ifnr == None:
if ifnr is None:
ifdata = nemu.iproute.get_if_data()[0]
else:
ifdata = nemu.iproute.get_if(ifnr)
......@@ -479,7 +480,7 @@ class Server(object):
def do_ADDR_LIST(self, cmdname, ifnr=None):
addrdata = nemu.iproute.get_addr_data()[0]
if ifnr != None:
if ifnr is not None:
addrdata = addrdata[ifnr]
self.reply(200, ["# Address data follows.",
_b64(dumps(addrdata, protocol=2))])
......@@ -652,7 +653,7 @@ class Client(object):
stdin/stdout/stderr can only be None or a open file descriptor.
See nemu.subprocess_.spawn for details."""
if executable == None:
if executable is None:
executable = argv[0]
params = ["PROC", "CRTE", _b64(executable)]
for i in argv:
......@@ -663,28 +664,28 @@ class Client(object):
# After this, if we get an error, we have to abort the PROC
try:
if user != None:
if user is not None:
self._send_cmd("PROC", "USER", _b64(user))
self._read_and_check_reply()
if cwd != None:
if cwd is not None:
self._send_cmd("PROC", "CWD", _b64(cwd))
self._read_and_check_reply()
if env != None:
if env is not None:
params = []
for k, v in env.items():
params.extend([_b64(k), _b64(v)])
self._send_cmd("PROC", "ENV", *params)
self._read_and_check_reply()
if stdin != None:
if stdin is not None:
os.set_inheritable(stdin, True)
self._send_fd("SIN", stdin)
if stdout != None:
if stdout is not None:
os.set_inheritable(stdout, True)
self._send_fd("SOUT", stdout)
if stderr != None:
if stderr is not None:
os.set_inheritable(stderr, True)
self._send_fd("SERR", stderr)
except:
......@@ -739,7 +740,7 @@ class Client(object):
cmd = ["IF", "SET", interface.index]
for k in interface.changeable_attributes:
v = getattr(interface, k)
if v != None:
if v is not None:
cmd += [k, str(v)]
self._send_cmd(*cmd)
......@@ -761,7 +762,7 @@ class Client(object):
data = self._read_and_check_reply()
return loads(_db64(data.partition("\n")[2]))
def add_addr(self, ifnr, address):
def add_addr(self, ifnr: int, address: nemu.iproute.address):
if hasattr(address, "broadcast") and address.broadcast:
self._send_cmd("ADDR", "ADD", ifnr, address.address,
address.prefix_len, address.broadcast)
......@@ -770,7 +771,7 @@ class Client(object):
address.prefix_len)
self._read_and_check_reply()
def del_addr(self, ifnr, address):
def del_addr(self, ifnr: int, address: nemu.iproute.address):
self._send_cmd("ADDR", "DEL", ifnr, address.address, address.prefix_len)
self._read_and_check_reply()
......@@ -785,14 +786,14 @@ class Client(object):
def del_route(self, route):
self._add_del_route("DEL", route)
def _add_del_route(self, action, route):
def _add_del_route(self, action: Literal["ADD", "DEL"], route: nemu.iproute.route):
args = ["ROUT", action, _b64(route.tipe), _b64(route.prefix),
route.prefix_len or 0, _b64(route.nexthop),
route.interface or 0, route.metric or 0]
self._send_cmd(*args)
self._read_and_check_reply()
def set_x11(self, protoname, hexkey):
def set_x11(self, protoname: str, hexkey: str) -> socket.socket:
# Returns a socket ready to accept() connections
self._send_cmd("X11", "SET", protoname, hexkey)
self._read_and_check_reply()
......@@ -823,7 +824,7 @@ class Client(object):
def _b64_OLD(text: str | bytes) -> str:
if text == None:
if text is None:
# easier this way
text = ''
if type(text) is str:
......@@ -833,11 +834,12 @@ def _b64_OLD(text: str | bytes) -> str:
else:
btext = text
if len(btext) == 0 or any(x for x in btext if x <= ord(" ") or
x > ord("z") or x == ord("=")):
x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(btext).decode("ascii")
else:
return text
def _b64(text) -> str:
if text is None:
# easier this way
......@@ -848,7 +850,7 @@ def _b64(text) -> str:
else:
enc = str(text).encode("utf-8")
if len(enc) == 0 or any(x for x in enc if x <= ord(" ") or
x > ord("z") or x == ord("=")):
x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(enc).decode("ascii")
else:
return enc.decode("utf-8")
......
......@@ -27,7 +27,10 @@ import signal
import sys
import time
import traceback
import typing
if typing.TYPE_CHECKING:
from nemu import Node
from nemu import compat
from nemu.environ import eintr_wrapper
......@@ -46,7 +49,7 @@ class Subprocess(object):
# FIXME
default_user = None
def __init__(self, node, argv: str | list[str], executable=None,
def __init__(self, node: "Node", argv: str | list[str], executable=None,
stdin=None, stdout=None, stderr=None,
shell=False, cwd=None, env=None, user=None):
self._slave = node._slave
......@@ -78,7 +81,7 @@ class Subprocess(object):
Exceptions occurred while trying to set up the environment or executing
the program are propagated to the parent."""
if user == None:
if user is None:
user = Subprocess.default_user
if isinstance(argv, str):
......@@ -106,20 +109,20 @@ class Subprocess(object):
def poll(self):
"""Checks status of program, returns exitcode or None if still running.
See Popen.poll."""
if self._returncode == None:
if self._returncode is None:
self._returncode = self._slave.poll(self._pid)
return self.returncode
def wait(self):
"""Waits for program to complete and returns the exitcode.
See Popen.wait"""
if self._returncode == None:
if self._returncode is None:
self._returncode = self._slave.wait(self._pid)
return self.returncode
def signal(self, sig=signal.SIGTERM):
"""Sends a signal to the process."""
if self._returncode == None:
if self._returncode is None:
self._slave.signal(self._pid, sig)
@property
......@@ -128,7 +131,7 @@ class Subprocess(object):
communicate, wait, or poll), returns the signal that killed the
program, if negative; otherwise, it is the exit code of the program.
"""
if self._returncode == None:
if self._returncode is None:
return None
if os.WIFSIGNALED(self._returncode):
return -os.WTERMSIG(self._returncode)
......@@ -140,12 +143,12 @@ class Subprocess(object):
self.destroy()
def destroy(self):
if self._returncode != None or self._pid == None:
if self._returncode is not None or self._pid is None:
return
self.signal()
now = time.time()
while time.time() - now < KILL_WAIT:
if self.poll() != None:
if self.poll() is not None:
return
time.sleep(0.1)
sys.stderr.write("WARNING: killing forcefully process %d.\n" %
......@@ -179,7 +182,7 @@ class Popen(Subprocess):
fdmap = {"stdin": stdin, "stdout": stdout, "stderr": stderr}
# if PIPE: all should be closed at the end
for k, v in fdmap.items():
if v == None:
if v is None:
continue
if v == PIPE:
r, w = compat.pipe()
......@@ -206,27 +209,29 @@ class Popen(Subprocess):
# Close pipes, they have been dup()ed to the child
for k, v in fdmap.items():
if getattr(self, k) != None:
if getattr(self, k) is not None:
eintr_wrapper(os.close, v)
def communicate(self, input: bytes = None) -> tuple[bytes, bytes]:
def communicate(self, input: bytes | str = None) -> tuple[bytes, bytes]:
"""See Popen.communicate."""
# FIXME: almost verbatim from stdlib version, need to be removed or
# something
if type(input) is str:
input = input.encode("utf-8")
wset = []
rset = []
err = None
out = None
if self.stdin != None:
if self.stdin is not None:
self.stdin.flush()
if input:
wset.append(self.stdin)
else:
self.stdin.close()
if self.stdout != None:
if self.stdout is not None:
rset.append(self.stdout)
out = []
if self.stderr != None:
if self.stderr is not None:
rset.append(self.stderr)
err = []
......@@ -253,9 +258,9 @@ class Popen(Subprocess):
else:
err.append(d)
if out != None:
if out is not None:
out = b''.join(out)
if err != None:
if err is not None:
err = b''.join(err)
self.wait()
return (out, err)
......@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
is not supported here. Also, the original descriptors are not closed.
"""
userfd = [stdin, stdout, stderr]
filtered_userfd = [x for x in userfd if x != None and x >= 0]
filtered_userfd = [x for x in userfd if x is not None and x >= 0]
for i in range(3):
if userfd[i] != None and not isinstance(userfd[i], int):
if userfd[i] is not None and not isinstance(userfd[i], int):
userfd[i] = userfd[i].fileno() # pragma: no cover
# Verify there is no clash
assert not (set([0, 1, 2]) & set(filtered_userfd))
if user != None:
if user is not None:
user, uid, gid = get_user(user)
home = pwd.getpwuid(uid)[5]
groups = [x[2] for x in grp.getgrall() if user in x[3]]
......@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
try:
# Set up stdio piping
for i in range(3):
if userfd[i] != None and userfd[i] >= 0:
if userfd[i] is not None and userfd[i] >= 0:
os.dup2(userfd[i], i)
if userfd[i] != i and userfd[i] not in userfd[0:i]:
eintr_wrapper(os.close, userfd[i]) # only in child!
......@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
# (it is necessary to kill the forked subprocesses)
os.setpgrp()
if user != None:
if user is not None:
# Change user
os.setgid(gid)
os.setgroups(groups)
os.setuid(uid)
if cwd != None:
if cwd is not None:
os.chdir(cwd)
if not argv:
argv = [executable]
if '/' in executable: # Should not search in PATH
if env != None:
if env is not None:
os.execve(executable, argv, env)
else:
os.execv(executable, argv)
else: # use PATH
if env != None:
if env is not None:
os.execvpe(executable, argv, env)
else:
os.execvp(executable, argv)
......
......@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase):
os.write(if1.fd, s)
if not s:
break
if subproc.poll() != None:
if subproc.poll() is not None:
break
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
......
......@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase):
def check_ok(self, cmd, func, args):
s1.write("%s\n" % cmd)
ccmd = " ".join(cmd.upper().split()[0:2])
if func == None:
if func is None:
self.assertEqual(srv.readcmd()[1:3], (ccmd, args))
else:
self.assertEqual(srv.readcmd(), (func, ccmd, args))
......
......@@ -15,7 +15,7 @@ def process_ipcmd(str: str):
match = re.search(r'^(\d+): ([^@\s]+)(?:@\S+)?: <(\S+)> mtu (\d+) '
r'qdisc (\S+)',
line)
if match != None:
if match is not None:
cur = match.group(2)
out[cur] = {
'idx': int(match.group(1)),
......@@ -27,14 +27,14 @@ def process_ipcmd(str: str):
out[cur]['up'] = 'UP' in out[cur]['flags']
continue
# Assume cur is defined
assert cur != None
assert cur is not None
match = re.search(r'^\s+link/\S*(?: ([0-9a-f:]+))?(?: |$)', line)
if match != None:
if match is not None:
out[cur]['lladdr'] = match.group(1)
continue
match = re.search(r'^\s+inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line)
if match != None:
if match is not None:
out[cur]['addr'].append({
'address': match.group(1),
'prefix_len': int(match.group(2)),
......@@ -43,7 +43,7 @@ def process_ipcmd(str: str):
continue
match = re.search(r'^\s+inet6 ([0-9a-f:]+)/(\d+)(?: |$)', line)
if match != None:
if match is not None:
out[cur]['addr'].append({
'address': match.group(1),
'prefix_len': int(match.group(2)),
......@@ -51,7 +51,7 @@ def process_ipcmd(str: str):
continue
match = re.search(r'^\s{4}', line)
assert match != None
assert match is not None
return out
def get_devs():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment