Commit 033ce168 authored by Tom Niget's avatar Tom Niget

Add type hints

parent 7f17df26
...@@ -140,7 +140,7 @@ def main(): ...@@ -140,7 +140,7 @@ def main():
if not r: if not r:
break break
out += r out += r
if srv.poll() != None or clt.poll() != None: if srv.poll() is not None or clt.poll() is not None:
break break
if srv.poll(): if srv.poll():
......
...@@ -15,6 +15,6 @@ setup( ...@@ -15,6 +15,6 @@ setup(
license = 'GPLv2', license = 'GPLv2',
platforms = 'Linux', platforms = 'Linux',
packages = ['nemu'], packages = ['nemu'],
install_requires = ['unshare', 'six'], install_requires = ['unshare', 'six', 'attrs'],
package_dir = {'': 'src'} package_dir = {'': 'src'}
) )
...@@ -41,7 +41,7 @@ class _Config(object): ...@@ -41,7 +41,7 @@ class _Config(object):
except KeyError: except KeyError:
pass # User not found. pass # User not found.
def _set_run_as(self, user): def _set_run_as(self, user: str | int):
"""Setter for `run_as'.""" """Setter for `run_as'."""
if str(user).isdigit(): if str(user).isdigit():
uid = int(user) uid = int(user)
...@@ -61,7 +61,7 @@ class _Config(object): ...@@ -61,7 +61,7 @@ class _Config(object):
self._run_as = run_as self._run_as = run_as
return run_as return run_as
def _get_run_as(self): def _get_run_as(self) -> str:
"""Setter for `run_as'.""" """Setter for `run_as'."""
return self._run_as return self._run_as
......
...@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]: ...@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]:
os.set_inheritable(b, True) os.set_inheritable(b, True)
return a, b return a, b
def socket(*args, **kwargs) -> pysocket.socket: def socket(*args, **kwargs) -> pysocket.socket:
s = pysocket.socket(*args, **kwargs) s = pysocket.socket(*args, **kwargs)
s.set_inheritable(True) s.set_inheritable(True)
return s return s
def socketpair(*args, **kwargs) -> tuple[pysocket.socket, pysocket.socket]: def socketpair(*args, **kwargs) -> tuple[pysocket.socket, pysocket.socket]:
a, b = pysocket.socketpair(*args, **kwargs) a, b = pysocket.socketpair(*args, **kwargs)
a.set_inheritable(True) a.set_inheritable(True)
b.set_inheritable(True) b.set_inheritable(True)
return a, b return a, b
def fromfd(*args, **kwargs) -> pysocket.socket: def fromfd(*args, **kwargs) -> pysocket.socket:
s = pysocket.fromfd(*args, **kwargs) s = pysocket.fromfd(*args, **kwargs)
s.set_inheritable(True) s.set_inheritable(True)
return s return s
def fdopen(*args, **kwargs) -> pysocket.socket:
s = os.fdopen(*args, **kwargs)
s.set_inheritable(True)
return s
\ No newline at end of file
...@@ -25,7 +25,7 @@ import subprocess ...@@ -25,7 +25,7 @@ import subprocess
import sys import sys
import syslog import syslog
from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
from typing import TypeVar, Callable from typing import TypeVar, Callable, Optional
__all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"] __all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"]
...@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"] ...@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"]
__all__ += ["error", "warning", "notice", "info", "debug"] __all__ += ["error", "warning", "notice", "info", "debug"]
def find_bin(name, extra_path=None): def find_bin(name: str, extra_path: Optional[list[str]] = None) -> Optional[str]:
"""Try hard to find the location of needed programs.""" """Try hard to find the location of needed programs."""
search = [] search = []
if "PATH" in os.environ: if "PATH" in os.environ:
...@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None): ...@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None):
return None return None
def find_bin_or_die(name, extra_path=None): def find_bin_or_die(name: str, extra_path: Optional[list[str]] = None) -> str:
"""Try hard to find the location of needed programs; raise on failure.""" """Try hard to find the location of needed programs; raise on failure."""
res = find_bin(name, extra_path) res = find_bin(name, extra_path)
if not res: if not res:
...@@ -156,7 +156,7 @@ _log_syslog_opts = () ...@@ -156,7 +156,7 @@ _log_syslog_opts = ()
_log_pid = os.getpid() _log_pid = os.getpid()
def set_log_level(level): def set_log_level(level: int):
"Sets the log level for console messages, does not affect syslog logging." "Sets the log level for console messages, does not affect syslog logging."
global _log_level global _log_level
assert level > LOG_ERR and level <= LOG_DEBUG assert level > LOG_ERR and level <= LOG_DEBUG
...@@ -191,7 +191,7 @@ def _init_log(): ...@@ -191,7 +191,7 @@ def _init_log():
info("Syslog logging started") info("Syslog logging started")
def logger(priority, message): def logger(priority: int, message: str):
"Print a log message in syslog, console or both." "Print a log message in syslog, console or both."
if _log_use_syslog: if _log_use_syslog:
if os.getpid() != _log_pid: if os.getpid() != _log_pid:
......
This diff is collapsed.
This diff is collapsed.
...@@ -21,10 +21,13 @@ import os ...@@ -21,10 +21,13 @@ import os
import socket import socket
import sys import sys
import traceback import traceback
from typing import MutableMapping
import unshare import unshare
import weakref import weakref
import nemu.interface import nemu.interface
import nemu.iproute
import nemu.protocol import nemu.protocol
import nemu.subprocess_ import nemu.subprocess_
from nemu import compat from nemu import compat
...@@ -33,10 +36,11 @@ from nemu.environ import * ...@@ -33,10 +36,11 @@ from nemu.environ import *
__all__ = ['Node', 'get_nodes', 'import_if'] __all__ = ['Node', 'get_nodes', 'import_if']
class Node(object): class Node(object):
_nodes = weakref.WeakValueDictionary() _nodes: MutableMapping[int, "Node"] = weakref.WeakValueDictionary()
_nextnode = 0 _nextnode = 0
_processes: MutableMapping[int, nemu.subprocess_.Subprocess]
@staticmethod @staticmethod
def get_nodes(): def get_nodes() -> list["Node"]:
s = sorted(list(Node._nodes.items()), key = lambda x: x[0]) s = sorted(list(Node._nodes.items()), key = lambda x: x[0])
return [x[1] for x in s] return [x[1] for x in s]
...@@ -98,7 +102,7 @@ class Node(object): ...@@ -98,7 +102,7 @@ class Node(object):
return self._pid return self._pid
# Subprocesses # Subprocesses
def _add_subprocess(self, subprocess): def _add_subprocess(self, subprocess: nemu.subprocess_.Subprocess):
self._processes[subprocess.pid] = subprocess self._processes[subprocess.pid] = subprocess
def Subprocess(self, *kargs, **kwargs): def Subprocess(self, *kargs, **kwargs):
...@@ -188,13 +192,13 @@ class Node(object): ...@@ -188,13 +192,13 @@ class Node(object):
r = self.route(*args, **kwargs) r = self.route(*args, **kwargs)
return self._slave.del_route(r) return self._slave.del_route(r)
def get_routes(self): def get_routes(self) -> list[route]:
return self._slave.get_route_data() return self._slave.get_route_data()
# Handle the creation of the child; parent gets (fd, pid), child creates and # Handle the creation of the child; parent gets (fd, pid), child creates and
# runs a Server(); never returns. # runs a Server(); never returns.
# Requires CAP_SYS_ADMIN privileges to run. # Requires CAP_SYS_ADMIN privileges to run.
def _start_child(nonetns) -> (socket.socket, int): def _start_child(nonetns: bool) -> (socket.socket, int):
# Create socket pair to communicate # Create socket pair to communicate
(s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
# Spawn a child that will run in a loop # Spawn a child that will run in a loop
......
...@@ -21,7 +21,7 @@ import struct ...@@ -21,7 +21,7 @@ import struct
from io import IOBase from io import IOBase
def __check_socket(sock: socket.socket | IOBase): def __check_socket(sock: socket.socket | IOBase) -> socket.socket:
if hasattr(sock, 'family') and sock.family != socket.AF_UNIX: if hasattr(sock, 'family') and sock.family != socket.AF_UNIX:
raise ValueError("Only AF_UNIX sockets are allowed") raise ValueError("Only AF_UNIX sockets are allowed")
...@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase): ...@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase):
return sock return sock
def __check_fd(fd): def __check_fd(fd) -> int:
try: try:
fd = fd.fileno() fd = fd.fileno()
except AttributeError: except AttributeError:
...@@ -44,7 +44,7 @@ def __check_fd(fd): ...@@ -44,7 +44,7 @@ def __check_fd(fd):
return fd return fd
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096): def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str]:
size = struct.calcsize("@i") size = struct.calcsize("@i")
msg, ancdata, flags, addr = __check_socket(sock).recvmsg(msg_buf, socket.CMSG_SPACE(size)) msg, ancdata, flags, addr = __check_socket(sock).recvmsg(msg_buf, socket.CMSG_SPACE(size))
cmsg_level, cmsg_type, cmsg_data = ancdata[0] cmsg_level, cmsg_type, cmsg_data = ancdata[0]
...@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096): ...@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
return fd, msg.decode("utf-8") return fd, msg.decode("utf-8")
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE"): def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE") -> int:
return __check_socket(sock).sendmsg( return __check_socket(sock).sendmsg(
[message], [message],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))]) [(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))])
\ No newline at end of file
...@@ -29,6 +29,7 @@ import tempfile ...@@ -29,6 +29,7 @@ import tempfile
import time import time
import traceback import traceback
from pickle import loads, dumps from pickle import loads, dumps
from typing import Literal
import nemu.iproute import nemu.iproute
import nemu.subprocess_ import nemu.subprocess_
...@@ -278,7 +279,7 @@ class Server(object): ...@@ -278,7 +279,7 @@ class Server(object):
self.reply(220, "Hello."); self.reply(220, "Hello.");
while not self._closed: while not self._closed:
cmd = self.readcmd() cmd = self.readcmd()
if cmd == None: if cmd is None:
continue continue
try: try:
cmd[0](cmd[1], *cmd[2]) cmd[0](cmd[1], *cmd[2])
...@@ -422,7 +423,7 @@ class Server(object): ...@@ -422,7 +423,7 @@ class Server(object):
else: else:
ret = nemu.subprocess_.wait(pid) ret = nemu.subprocess_.wait(pid)
if ret != None: if ret is not None:
self._children.remove(pid) self._children.remove(pid)
if pid in self._xauthfiles: if pid in self._xauthfiles:
try: try:
...@@ -449,7 +450,7 @@ class Server(object): ...@@ -449,7 +450,7 @@ class Server(object):
self.reply(200, "Process signalled.") self.reply(200, "Process signalled.")
def do_IF_LIST(self, cmdname, ifnr=None): def do_IF_LIST(self, cmdname, ifnr=None):
if ifnr == None: if ifnr is None:
ifdata = nemu.iproute.get_if_data()[0] ifdata = nemu.iproute.get_if_data()[0]
else: else:
ifdata = nemu.iproute.get_if(ifnr) ifdata = nemu.iproute.get_if(ifnr)
...@@ -479,7 +480,7 @@ class Server(object): ...@@ -479,7 +480,7 @@ class Server(object):
def do_ADDR_LIST(self, cmdname, ifnr=None): def do_ADDR_LIST(self, cmdname, ifnr=None):
addrdata = nemu.iproute.get_addr_data()[0] addrdata = nemu.iproute.get_addr_data()[0]
if ifnr != None: if ifnr is not None:
addrdata = addrdata[ifnr] addrdata = addrdata[ifnr]
self.reply(200, ["# Address data follows.", self.reply(200, ["# Address data follows.",
_b64(dumps(addrdata, protocol=2))]) _b64(dumps(addrdata, protocol=2))])
...@@ -652,7 +653,7 @@ class Client(object): ...@@ -652,7 +653,7 @@ class Client(object):
stdin/stdout/stderr can only be None or a open file descriptor. stdin/stdout/stderr can only be None or a open file descriptor.
See nemu.subprocess_.spawn for details.""" See nemu.subprocess_.spawn for details."""
if executable == None: if executable is None:
executable = argv[0] executable = argv[0]
params = ["PROC", "CRTE", _b64(executable)] params = ["PROC", "CRTE", _b64(executable)]
for i in argv: for i in argv:
...@@ -663,28 +664,28 @@ class Client(object): ...@@ -663,28 +664,28 @@ class Client(object):
# After this, if we get an error, we have to abort the PROC # After this, if we get an error, we have to abort the PROC
try: try:
if user != None: if user is not None:
self._send_cmd("PROC", "USER", _b64(user)) self._send_cmd("PROC", "USER", _b64(user))
self._read_and_check_reply() self._read_and_check_reply()
if cwd != None: if cwd is not None:
self._send_cmd("PROC", "CWD", _b64(cwd)) self._send_cmd("PROC", "CWD", _b64(cwd))
self._read_and_check_reply() self._read_and_check_reply()
if env != None: if env is not None:
params = [] params = []
for k, v in env.items(): for k, v in env.items():
params.extend([_b64(k), _b64(v)]) params.extend([_b64(k), _b64(v)])
self._send_cmd("PROC", "ENV", *params) self._send_cmd("PROC", "ENV", *params)
self._read_and_check_reply() self._read_and_check_reply()
if stdin != None: if stdin is not None:
os.set_inheritable(stdin, True) os.set_inheritable(stdin, True)
self._send_fd("SIN", stdin) self._send_fd("SIN", stdin)
if stdout != None: if stdout is not None:
os.set_inheritable(stdout, True) os.set_inheritable(stdout, True)
self._send_fd("SOUT", stdout) self._send_fd("SOUT", stdout)
if stderr != None: if stderr is not None:
os.set_inheritable(stderr, True) os.set_inheritable(stderr, True)
self._send_fd("SERR", stderr) self._send_fd("SERR", stderr)
except: except:
...@@ -739,7 +740,7 @@ class Client(object): ...@@ -739,7 +740,7 @@ class Client(object):
cmd = ["IF", "SET", interface.index] cmd = ["IF", "SET", interface.index]
for k in interface.changeable_attributes: for k in interface.changeable_attributes:
v = getattr(interface, k) v = getattr(interface, k)
if v != None: if v is not None:
cmd += [k, str(v)] cmd += [k, str(v)]
self._send_cmd(*cmd) self._send_cmd(*cmd)
...@@ -761,7 +762,7 @@ class Client(object): ...@@ -761,7 +762,7 @@ class Client(object):
data = self._read_and_check_reply() data = self._read_and_check_reply()
return loads(_db64(data.partition("\n")[2])) return loads(_db64(data.partition("\n")[2]))
def add_addr(self, ifnr, address): def add_addr(self, ifnr: int, address: nemu.iproute.address):
if hasattr(address, "broadcast") and address.broadcast: if hasattr(address, "broadcast") and address.broadcast:
self._send_cmd("ADDR", "ADD", ifnr, address.address, self._send_cmd("ADDR", "ADD", ifnr, address.address,
address.prefix_len, address.broadcast) address.prefix_len, address.broadcast)
...@@ -770,7 +771,7 @@ class Client(object): ...@@ -770,7 +771,7 @@ class Client(object):
address.prefix_len) address.prefix_len)
self._read_and_check_reply() self._read_and_check_reply()
def del_addr(self, ifnr, address): def del_addr(self, ifnr: int, address: nemu.iproute.address):
self._send_cmd("ADDR", "DEL", ifnr, address.address, address.prefix_len) self._send_cmd("ADDR", "DEL", ifnr, address.address, address.prefix_len)
self._read_and_check_reply() self._read_and_check_reply()
...@@ -785,14 +786,14 @@ class Client(object): ...@@ -785,14 +786,14 @@ class Client(object):
def del_route(self, route): def del_route(self, route):
self._add_del_route("DEL", route) self._add_del_route("DEL", route)
def _add_del_route(self, action, route): def _add_del_route(self, action: Literal["ADD", "DEL"], route: nemu.iproute.route):
args = ["ROUT", action, _b64(route.tipe), _b64(route.prefix), args = ["ROUT", action, _b64(route.tipe), _b64(route.prefix),
route.prefix_len or 0, _b64(route.nexthop), route.prefix_len or 0, _b64(route.nexthop),
route.interface or 0, route.metric or 0] route.interface or 0, route.metric or 0]
self._send_cmd(*args) self._send_cmd(*args)
self._read_and_check_reply() self._read_and_check_reply()
def set_x11(self, protoname, hexkey): def set_x11(self, protoname: str, hexkey: str) -> socket.socket:
# Returns a socket ready to accept() connections # Returns a socket ready to accept() connections
self._send_cmd("X11", "SET", protoname, hexkey) self._send_cmd("X11", "SET", protoname, hexkey)
self._read_and_check_reply() self._read_and_check_reply()
...@@ -823,7 +824,7 @@ class Client(object): ...@@ -823,7 +824,7 @@ class Client(object):
def _b64_OLD(text: str | bytes) -> str: def _b64_OLD(text: str | bytes) -> str:
if text == None: if text is None:
# easier this way # easier this way
text = '' text = ''
if type(text) is str: if type(text) is str:
...@@ -833,11 +834,12 @@ def _b64_OLD(text: str | bytes) -> str: ...@@ -833,11 +834,12 @@ def _b64_OLD(text: str | bytes) -> str:
else: else:
btext = text btext = text
if len(btext) == 0 or any(x for x in btext if x <= ord(" ") or if len(btext) == 0 or any(x for x in btext if x <= ord(" ") or
x > ord("z") or x == ord("=")): x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(btext).decode("ascii") return "=" + base64.b64encode(btext).decode("ascii")
else: else:
return text return text
def _b64(text) -> str: def _b64(text) -> str:
if text is None: if text is None:
# easier this way # easier this way
...@@ -848,7 +850,7 @@ def _b64(text) -> str: ...@@ -848,7 +850,7 @@ def _b64(text) -> str:
else: else:
enc = str(text).encode("utf-8") enc = str(text).encode("utf-8")
if len(enc) == 0 or any(x for x in enc if x <= ord(" ") or if len(enc) == 0 or any(x for x in enc if x <= ord(" ") or
x > ord("z") or x == ord("=")): x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(enc).decode("ascii") return "=" + base64.b64encode(enc).decode("ascii")
else: else:
return enc.decode("utf-8") return enc.decode("utf-8")
......
...@@ -27,7 +27,10 @@ import signal ...@@ -27,7 +27,10 @@ import signal
import sys import sys
import time import time
import traceback import traceback
import typing
if typing.TYPE_CHECKING:
from nemu import Node
from nemu import compat from nemu import compat
from nemu.environ import eintr_wrapper from nemu.environ import eintr_wrapper
...@@ -46,7 +49,7 @@ class Subprocess(object): ...@@ -46,7 +49,7 @@ class Subprocess(object):
# FIXME # FIXME
default_user = None default_user = None
def __init__(self, node, argv: str | list[str], executable=None, def __init__(self, node: "Node", argv: str | list[str], executable=None,
stdin=None, stdout=None, stderr=None, stdin=None, stdout=None, stderr=None,
shell=False, cwd=None, env=None, user=None): shell=False, cwd=None, env=None, user=None):
self._slave = node._slave self._slave = node._slave
...@@ -78,7 +81,7 @@ class Subprocess(object): ...@@ -78,7 +81,7 @@ class Subprocess(object):
Exceptions occurred while trying to set up the environment or executing Exceptions occurred while trying to set up the environment or executing
the program are propagated to the parent.""" the program are propagated to the parent."""
if user == None: if user is None:
user = Subprocess.default_user user = Subprocess.default_user
if isinstance(argv, str): if isinstance(argv, str):
...@@ -106,20 +109,20 @@ class Subprocess(object): ...@@ -106,20 +109,20 @@ class Subprocess(object):
def poll(self): def poll(self):
"""Checks status of program, returns exitcode or None if still running. """Checks status of program, returns exitcode or None if still running.
See Popen.poll.""" See Popen.poll."""
if self._returncode == None: if self._returncode is None:
self._returncode = self._slave.poll(self._pid) self._returncode = self._slave.poll(self._pid)
return self.returncode return self.returncode
def wait(self): def wait(self):
"""Waits for program to complete and returns the exitcode. """Waits for program to complete and returns the exitcode.
See Popen.wait""" See Popen.wait"""
if self._returncode == None: if self._returncode is None:
self._returncode = self._slave.wait(self._pid) self._returncode = self._slave.wait(self._pid)
return self.returncode return self.returncode
def signal(self, sig=signal.SIGTERM): def signal(self, sig=signal.SIGTERM):
"""Sends a signal to the process.""" """Sends a signal to the process."""
if self._returncode == None: if self._returncode is None:
self._slave.signal(self._pid, sig) self._slave.signal(self._pid, sig)
@property @property
...@@ -128,7 +131,7 @@ class Subprocess(object): ...@@ -128,7 +131,7 @@ class Subprocess(object):
communicate, wait, or poll), returns the signal that killed the communicate, wait, or poll), returns the signal that killed the
program, if negative; otherwise, it is the exit code of the program. program, if negative; otherwise, it is the exit code of the program.
""" """
if self._returncode == None: if self._returncode is None:
return None return None
if os.WIFSIGNALED(self._returncode): if os.WIFSIGNALED(self._returncode):
return -os.WTERMSIG(self._returncode) return -os.WTERMSIG(self._returncode)
...@@ -140,12 +143,12 @@ class Subprocess(object): ...@@ -140,12 +143,12 @@ class Subprocess(object):
self.destroy() self.destroy()
def destroy(self): def destroy(self):
if self._returncode != None or self._pid == None: if self._returncode is not None or self._pid is None:
return return
self.signal() self.signal()
now = time.time() now = time.time()
while time.time() - now < KILL_WAIT: while time.time() - now < KILL_WAIT:
if self.poll() != None: if self.poll() is not None:
return return
time.sleep(0.1) time.sleep(0.1)
sys.stderr.write("WARNING: killing forcefully process %d.\n" % sys.stderr.write("WARNING: killing forcefully process %d.\n" %
...@@ -179,7 +182,7 @@ class Popen(Subprocess): ...@@ -179,7 +182,7 @@ class Popen(Subprocess):
fdmap = {"stdin": stdin, "stdout": stdout, "stderr": stderr} fdmap = {"stdin": stdin, "stdout": stdout, "stderr": stderr}
# if PIPE: all should be closed at the end # if PIPE: all should be closed at the end
for k, v in fdmap.items(): for k, v in fdmap.items():
if v == None: if v is None:
continue continue
if v == PIPE: if v == PIPE:
r, w = compat.pipe() r, w = compat.pipe()
...@@ -206,27 +209,29 @@ class Popen(Subprocess): ...@@ -206,27 +209,29 @@ class Popen(Subprocess):
# Close pipes, they have been dup()ed to the child # Close pipes, they have been dup()ed to the child
for k, v in fdmap.items(): for k, v in fdmap.items():
if getattr(self, k) != None: if getattr(self, k) is not None:
eintr_wrapper(os.close, v) eintr_wrapper(os.close, v)
def communicate(self, input: bytes = None) -> tuple[bytes, bytes]: def communicate(self, input: bytes | str = None) -> tuple[bytes, bytes]:
"""See Popen.communicate.""" """See Popen.communicate."""
# FIXME: almost verbatim from stdlib version, need to be removed or # FIXME: almost verbatim from stdlib version, need to be removed or
# something # something
if type(input) is str:
input = input.encode("utf-8")
wset = [] wset = []
rset = [] rset = []
err = None err = None
out = None out = None
if self.stdin != None: if self.stdin is not None:
self.stdin.flush() self.stdin.flush()
if input: if input:
wset.append(self.stdin) wset.append(self.stdin)
else: else:
self.stdin.close() self.stdin.close()
if self.stdout != None: if self.stdout is not None:
rset.append(self.stdout) rset.append(self.stdout)
out = [] out = []
if self.stderr != None: if self.stderr is not None:
rset.append(self.stderr) rset.append(self.stderr)
err = [] err = []
...@@ -253,9 +258,9 @@ class Popen(Subprocess): ...@@ -253,9 +258,9 @@ class Popen(Subprocess):
else: else:
err.append(d) err.append(d)
if out != None: if out is not None:
out = b''.join(out) out = b''.join(out)
if err != None: if err is not None:
err = b''.join(err) err = b''.join(err)
self.wait() self.wait()
return (out, err) return (out, err)
...@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False, ...@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
is not supported here. Also, the original descriptors are not closed. is not supported here. Also, the original descriptors are not closed.
""" """
userfd = [stdin, stdout, stderr] userfd = [stdin, stdout, stderr]
filtered_userfd = [x for x in userfd if x != None and x >= 0] filtered_userfd = [x for x in userfd if x is not None and x >= 0]
for i in range(3): for i in range(3):
if userfd[i] != None and not isinstance(userfd[i], int): if userfd[i] is not None and not isinstance(userfd[i], int):
userfd[i] = userfd[i].fileno() # pragma: no cover userfd[i] = userfd[i].fileno() # pragma: no cover
# Verify there is no clash # Verify there is no clash
assert not (set([0, 1, 2]) & set(filtered_userfd)) assert not (set([0, 1, 2]) & set(filtered_userfd))
if user != None: if user is not None:
user, uid, gid = get_user(user) user, uid, gid = get_user(user)
home = pwd.getpwuid(uid)[5] home = pwd.getpwuid(uid)[5]
groups = [x[2] for x in grp.getgrall() if user in x[3]] groups = [x[2] for x in grp.getgrall() if user in x[3]]
...@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False, ...@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
try: try:
# Set up stdio piping # Set up stdio piping
for i in range(3): for i in range(3):
if userfd[i] != None and userfd[i] >= 0: if userfd[i] is not None and userfd[i] >= 0:
os.dup2(userfd[i], i) os.dup2(userfd[i], i)
if userfd[i] != i and userfd[i] not in userfd[0:i]: if userfd[i] != i and userfd[i] not in userfd[0:i]:
eintr_wrapper(os.close, userfd[i]) # only in child! eintr_wrapper(os.close, userfd[i]) # only in child!
...@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False, ...@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
# (it is necessary to kill the forked subprocesses) # (it is necessary to kill the forked subprocesses)
os.setpgrp() os.setpgrp()
if user != None: if user is not None:
# Change user # Change user
os.setgid(gid) os.setgid(gid)
os.setgroups(groups) os.setgroups(groups)
os.setuid(uid) os.setuid(uid)
if cwd != None: if cwd is not None:
os.chdir(cwd) os.chdir(cwd)
if not argv: if not argv:
argv = [executable] argv = [executable]
if '/' in executable: # Should not search in PATH if '/' in executable: # Should not search in PATH
if env != None: if env is not None:
os.execve(executable, argv, env) os.execve(executable, argv, env)
else: else:
os.execv(executable, argv) os.execv(executable, argv)
else: # use PATH else: # use PATH
if env != None: if env is not None:
os.execvpe(executable, argv, env) os.execvpe(executable, argv, env)
else: else:
os.execvp(executable, argv) os.execvp(executable, argv)
......
...@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase): ...@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase):
os.write(if1.fd, s) os.write(if1.fd, s)
if not s: if not s:
break break
if subproc.poll() != None: if subproc.poll() is not None:
break break
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
......
...@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase): ...@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase):
def check_ok(self, cmd, func, args): def check_ok(self, cmd, func, args):
s1.write("%s\n" % cmd) s1.write("%s\n" % cmd)
ccmd = " ".join(cmd.upper().split()[0:2]) ccmd = " ".join(cmd.upper().split()[0:2])
if func == None: if func is None:
self.assertEqual(srv.readcmd()[1:3], (ccmd, args)) self.assertEqual(srv.readcmd()[1:3], (ccmd, args))
else: else:
self.assertEqual(srv.readcmd(), (func, ccmd, args)) self.assertEqual(srv.readcmd(), (func, ccmd, args))
......
...@@ -15,7 +15,7 @@ def process_ipcmd(str: str): ...@@ -15,7 +15,7 @@ def process_ipcmd(str: str):
match = re.search(r'^(\d+): ([^@\s]+)(?:@\S+)?: <(\S+)> mtu (\d+) ' match = re.search(r'^(\d+): ([^@\s]+)(?:@\S+)?: <(\S+)> mtu (\d+) '
r'qdisc (\S+)', r'qdisc (\S+)',
line) line)
if match != None: if match is not None:
cur = match.group(2) cur = match.group(2)
out[cur] = { out[cur] = {
'idx': int(match.group(1)), 'idx': int(match.group(1)),
...@@ -27,14 +27,14 @@ def process_ipcmd(str: str): ...@@ -27,14 +27,14 @@ def process_ipcmd(str: str):
out[cur]['up'] = 'UP' in out[cur]['flags'] out[cur]['up'] = 'UP' in out[cur]['flags']
continue continue
# Assume cur is defined # Assume cur is defined
assert cur != None assert cur is not None
match = re.search(r'^\s+link/\S*(?: ([0-9a-f:]+))?(?: |$)', line) match = re.search(r'^\s+link/\S*(?: ([0-9a-f:]+))?(?: |$)', line)
if match != None: if match is not None:
out[cur]['lladdr'] = match.group(1) out[cur]['lladdr'] = match.group(1)
continue continue
match = re.search(r'^\s+inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line) match = re.search(r'^\s+inet ([0-9.]+)/(\d+)(?: brd ([0-9.]+))?', line)
if match != None: if match is not None:
out[cur]['addr'].append({ out[cur]['addr'].append({
'address': match.group(1), 'address': match.group(1),
'prefix_len': int(match.group(2)), 'prefix_len': int(match.group(2)),
...@@ -43,7 +43,7 @@ def process_ipcmd(str: str): ...@@ -43,7 +43,7 @@ def process_ipcmd(str: str):
continue continue
match = re.search(r'^\s+inet6 ([0-9a-f:]+)/(\d+)(?: |$)', line) match = re.search(r'^\s+inet6 ([0-9a-f:]+)/(\d+)(?: |$)', line)
if match != None: if match is not None:
out[cur]['addr'].append({ out[cur]['addr'].append({
'address': match.group(1), 'address': match.group(1),
'prefix_len': int(match.group(2)), 'prefix_len': int(match.group(2)),
...@@ -51,7 +51,7 @@ def process_ipcmd(str: str): ...@@ -51,7 +51,7 @@ def process_ipcmd(str: str):
continue continue
match = re.search(r'^\s{4}', line) match = re.search(r'^\s{4}', line)
assert match != None assert match is not None
return out return out
def get_devs(): def get_devs():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment