Commit ba619ec3 authored by Tom Niget's avatar Tom Niget

nemu works in python 3

parent f8914e80
import os
import socket as pysocket
def pipe() -> tuple[int, int]:
a, b = os.pipe2(0)
os.set_inheritable(a, True)
os.set_inheritable(b, True)
return a, b
def socket(*args, **kwargs) -> pysocket.socket:
s = pysocket.socket(*args, **kwargs)
s.set_inheritable(True)
return s
def socketpair(*args, **kwargs) -> tuple[pysocket.socket, pysocket.socket]:
a, b = pysocket.socketpair(*args, **kwargs)
a.set_inheritable(True)
b.set_inheritable(True)
return a, b
def fromfd(*args, **kwargs) -> pysocket.socket:
s = pysocket.fromfd(*args, **kwargs)
s.set_inheritable(True)
return s
def fdopen(*args, **kwargs) -> pysocket.socket:
s = os.fdopen(*args, **kwargs)
s.set_inheritable(True)
return s
\ No newline at end of file
...@@ -25,8 +25,12 @@ import subprocess ...@@ -25,8 +25,12 @@ import subprocess
import sys import sys
import syslog import syslog
from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
from typing import TypeVar, Callable
__all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"] __all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"]
from nemu import compat
__all__ += ["TCPDUMP_PATH", "NETPERF_PATH", "XAUTH_PATH", "XDPYINFO_PATH"] __all__ += ["TCPDUMP_PATH", "NETPERF_PATH", "XAUTH_PATH", "XDPYINFO_PATH"]
__all__ += ["execute", "backticks", "eintr_wrapper"] __all__ += ["execute", "backticks", "eintr_wrapper"]
__all__ += ["find_listen_port"] __all__ += ["find_listen_port"]
...@@ -35,7 +39,7 @@ __all__ += ["set_log_level", "logger"] ...@@ -35,7 +39,7 @@ __all__ += ["set_log_level", "logger"]
__all__ += ["error", "warning", "notice", "info", "debug"] __all__ += ["error", "warning", "notice", "info", "debug"]
def find_bin(name, extra_path = None): def find_bin(name, extra_path=None):
"""Try hard to find the location of needed programs.""" """Try hard to find the location of needed programs."""
search = [] search = []
if "PATH" in os.environ: if "PATH" in os.environ:
...@@ -52,13 +56,15 @@ def find_bin(name, extra_path = None): ...@@ -52,13 +56,15 @@ def find_bin(name, extra_path = None):
return path return path
return None return None
def find_bin_or_die(name, extra_path = None):
def find_bin_or_die(name, extra_path=None):
"""Try hard to find the location of needed programs; raise on failure.""" """Try hard to find the location of needed programs; raise on failure."""
res = find_bin(name, extra_path) res = find_bin(name, extra_path)
if not res: if not res:
raise RuntimeError("Cannot find `%s', impossible to continue." % name) raise RuntimeError("Cannot find `%s', impossible to continue." % name)
return res return res
IP_PATH = find_bin_or_die("ip") IP_PATH = find_bin_or_die("ip")
TC_PATH = find_bin_or_die("tc") TC_PATH = find_bin_or_die("tc")
BRCTL_PATH = find_bin_or_die("brctl") BRCTL_PATH = find_bin_or_die("brctl")
...@@ -80,19 +86,21 @@ except: ...@@ -80,19 +86,21 @@ except:
raise RuntimeError("Sysfs does not seem to be mounted, impossible to " + raise RuntimeError("Sysfs does not seem to be mounted, impossible to " +
"continue.") "continue.")
def execute(cmd):
def execute(cmd: list[str]):
"""Execute a command, if the return value is non-zero, raise an exception. """Execute a command, if the return value is non-zero, raise an exception.
Raises: Raises:
RuntimeError: the command was unsuccessful (return code != 0). RuntimeError: the command was unsuccessful (return code != 0).
""" """
debug("execute(%s)" % cmd) debug("execute(%s)" % cmd)
proc = subprocess.Popen(cmd, stdout = subprocess.DEVNULL, stderr = subprocess.PIPE) proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
_, err = proc.communicate() _, err = proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err)) raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
def backticks(cmd):
def backticks(cmd: list[str]) -> str:
"""Execute a command and capture its output. """Execute a command and capture its output.
If the return value is non-zero, raise an exception. If the return value is non-zero, raise an exception.
...@@ -102,14 +110,18 @@ def backticks(cmd): ...@@ -102,14 +110,18 @@ def backticks(cmd):
RuntimeError: the command was unsuccessful (return code != 0). RuntimeError: the command was unsuccessful (return code != 0).
""" """
debug("backticks(%s)" % cmd) debug("backticks(%s)" % cmd)
proc = subprocess.Popen(cmd, stdout = subprocess.PIPE, proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr = subprocess.PIPE) stderr=subprocess.PIPE)
out, err = proc.communicate() out, err = proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err)) raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
return out.decode("utf-8") return out.decode("utf-8")
def eintr_wrapper(func, *args):
T = TypeVar("T")
def eintr_wrapper(func: Callable[..., T], *args) -> T:
"Wraps some callable with a loop that retries on EINTR." "Wraps some callable with a loop that retries on EINTR."
while True: while True:
try: try:
...@@ -123,9 +135,10 @@ def eintr_wrapper(func, *args): ...@@ -123,9 +135,10 @@ def eintr_wrapper(func, *args):
continue continue
raise raise
def find_listen_port(family = socket.AF_INET, type = socket.SOCK_STREAM,
proto = 0, addr = "127.0.0.1", min_port = 1, max_port = 65535): def find_listen_port(family=socket.AF_INET, type=socket.SOCK_STREAM,
sock = socket.socket(family, type, proto) proto=0, addr="127.0.0.1", min_port=1, max_port=65535):
sock = compat.socket(family, type, proto)
for port in range(min_port, max_port + 1): for port in range(min_port, max_port + 1):
try: try:
sock.bind((addr, port)) sock.bind((addr, port))
...@@ -134,44 +147,50 @@ def find_listen_port(family = socket.AF_INET, type = socket.SOCK_STREAM, ...@@ -134,44 +147,50 @@ def find_listen_port(family = socket.AF_INET, type = socket.SOCK_STREAM,
pass pass
raise RuntimeError("Cannot find an usable port in the range specified") raise RuntimeError("Cannot find an usable port in the range specified")
# Logging # Logging
_log_level = LOG_DEBUG _log_level = LOG_WARNING
_log_use_syslog = False _log_use_syslog = False
_log_stream = sys.stderr _log_stream = sys.stderr
_log_syslog_opts = () _log_syslog_opts = ()
_log_pid = os.getpid() _log_pid = os.getpid()
def set_log_level(level): def set_log_level(level):
"Sets the log level for console messages, does not affect syslog logging." "Sets the log level for console messages, does not affect syslog logging."
global _log_level global _log_level
assert level > LOG_ERR and level <= LOG_DEBUG assert level > LOG_ERR and level <= LOG_DEBUG
_log_level = level _log_level = level
def set_log_output(stream): def set_log_output(stream):
"Redirect console messages to the provided stream." "Redirect console messages to the provided stream."
global _log_stream global _log_stream
assert hasattr(stream, "write") and hasattr(stream, "flush") assert hasattr(stream, "write") and hasattr(stream, "flush")
_log_stream = stream _log_stream = stream
def log_use_syslog(use = True, ident = None, logopt = 0,
facility = syslog.LOG_USER): def log_use_syslog(use=True, ident=None, logopt=0,
facility=syslog.LOG_USER):
"Enable or disable the use of syslog for logging messages." "Enable or disable the use of syslog for logging messages."
global _log_use_syslog, _log_syslog_opts global _log_use_syslog, _log_syslog_opts
_log_syslog_opts = (ident, logopt, facility) _log_syslog_opts = (ident, logopt, facility)
_log_use_syslog = use _log_use_syslog = use
_init_log() _init_log()
def _init_log(): def _init_log():
if not _log_use_syslog: if not _log_use_syslog:
syslog.closelog() syslog.closelog()
return return
(ident, logopt, facility) = _log_syslog_opts (ident, logopt, facility) = _log_syslog_opts
if not ident: if not ident:
#ident = os.path.basename(sys.argv[0]) # ident = os.path.basename(sys.argv[0])
ident = "nemu" ident = "nemu"
syslog.openlog("%s[%d]" % (ident, os.getpid()), logopt, facility) syslog.openlog("%s[%d]" % (ident, os.getpid()), logopt, facility)
info("Syslog logging started") info("Syslog logging started")
def logger(priority, message): def logger(priority, message):
"Print a log message in syslog, console or both." "Print a log message in syslog, console or both."
if _log_use_syslog: if _log_use_syslog:
...@@ -186,17 +205,27 @@ def logger(priority, message): ...@@ -186,17 +205,27 @@ def logger(priority, message):
"[%d] %s\n" % (os.getpid(), message.rstrip())) "[%d] %s\n" % (os.getpid(), message.rstrip()))
_log_stream.flush() _log_stream.flush()
def error(message): def error(message):
logger(LOG_ERR, message) logger(LOG_ERR, message)
def warning(message): def warning(message):
logger(LOG_WARNING, message) logger(LOG_WARNING, message)
def notice(message): def notice(message):
logger(LOG_NOTICE, message) logger(LOG_NOTICE, message)
def info(message): def info(message):
logger(LOG_INFO, message) logger(LOG_INFO, message)
def debug(message): def debug(message):
logger(LOG_DEBUG, message) logger(LOG_DEBUG, message)
def _custom_hook(tipe, value, traceback): # pragma: no cover def _custom_hook(tipe, value, traceback): # pragma: no cover
"""Custom exception hook, to print nested exceptions information.""" """Custom exception hook, to print nested exceptions information."""
if hasattr(value, "child_traceback"): if hasattr(value, "child_traceback"):
...@@ -205,5 +234,5 @@ def _custom_hook(tipe, value, traceback): # pragma: no cover ...@@ -205,5 +234,5 @@ def _custom_hook(tipe, value, traceback): # pragma: no cover
sys.stderr.write(value.child_traceback + ("-" * 70) + "\n") sys.stderr.write(value.child_traceback + ("-" * 70) + "\n")
sys.__excepthook__(tipe, value, traceback) sys.__excepthook__(tipe, value, traceback)
sys.excepthook = _custom_hook
sys.excepthook = _custom_hook
...@@ -40,7 +40,7 @@ class Interface(object): ...@@ -40,7 +40,7 @@ class Interface(object):
def _gen_if_name(): def _gen_if_name():
n = Interface._gen_next_id() n = Interface._gen_next_id()
# Max 15 chars # Max 15 chars
return "NETNSif-%.4x%.3x" % (os.getpid() % 0xffff, n) return "NETNSif-%.4x%.3x" % (os.getpid() & 0xffff, n)
def __init__(self, index): def __init__(self, index):
self._idx = index self._idx = index
...@@ -386,7 +386,7 @@ class Switch(ExternalInterface): ...@@ -386,7 +386,7 @@ class Switch(ExternalInterface):
def _gen_br_name(): def _gen_br_name():
n = Switch._gen_next_id() n = Switch._gen_next_id()
# Max 15 chars # Max 15 chars
return "NETNSbr-%.4x%.3x" % (os.getpid() % 0xffff, n) return "NETNSbr-%.4x%.3x" % (os.getpid() & 0xffff, n)
def __init__(self, **args): def __init__(self, **args):
"""Creates a new Switch object, which models a linux bridge device. """Creates a new Switch object, which models a linux bridge device.
......
...@@ -27,6 +27,7 @@ import weakref ...@@ -27,6 +27,7 @@ import weakref
import nemu.interface import nemu.interface
import nemu.protocol import nemu.protocol
import nemu.subprocess_ import nemu.subprocess_
from nemu import compat
from nemu.environ import * from nemu.environ import *
__all__ = ['Node', 'get_nodes', 'import_if'] __all__ = ['Node', 'get_nodes', 'import_if']
...@@ -195,7 +196,7 @@ class Node(object): ...@@ -195,7 +196,7 @@ class Node(object):
# Requires CAP_SYS_ADMIN privileges to run. # Requires CAP_SYS_ADMIN privileges to run.
def _start_child(nonetns) -> (socket.socket, int): def _start_child(nonetns) -> (socket.socket, int):
# Create socket pair to communicate # Create socket pair to communicate
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
# Spawn a child that will run in a loop # Spawn a child that will run in a loop
pid = os.fork() pid = os.fork()
if pid: if pid:
......
...@@ -33,6 +33,7 @@ from pickle import loads, dumps ...@@ -33,6 +33,7 @@ from pickle import loads, dumps
import nemu.iproute import nemu.iproute
import nemu.subprocess_ import nemu.subprocess_
from nemu import compat
from nemu.environ import * from nemu.environ import *
# ============================================================================ # ============================================================================
...@@ -47,8 +48,8 @@ from nemu.environ import * ...@@ -47,8 +48,8 @@ from nemu.environ import *
# The format string is a chain of "s" for string and "i" for integer # The format string is a chain of "s" for string and "i" for integer
_proto_commands = { _proto_commands = {
"QUIT": { None: ("", "") }, "QUIT": {None: ("", "")},
"HELP": { None: ("", "") }, "HELP": {None: ("", "")},
"X11": { "X11": {
"SET": ("ss", ""), "SET": ("ss", ""),
"SOCK": ("", "") "SOCK": ("", "")
...@@ -75,11 +76,11 @@ _proto_commands = { ...@@ -75,11 +76,11 @@ _proto_commands = {
"WAIT": ("i", ""), "WAIT": ("i", ""),
"KILL": ("i", "i") "KILL": ("i", "i")
}, },
} }
# Commands valid only after PROC CRTE # Commands valid only after PROC CRTE
_proc_commands = { _proc_commands = {
"HELP": { None: ("", "") }, "HELP": {None: ("", "")},
"QUIT": { None: ("", "") }, "QUIT": {None: ("", "")},
"PROC": { "PROC": {
"USER": ("b", ""), "USER": ("b", ""),
"CWD": ("b", ""), "CWD": ("b", ""),
...@@ -90,14 +91,16 @@ _proc_commands = { ...@@ -90,14 +91,16 @@ _proc_commands = {
"RUN": ("", ""), "RUN": ("", ""),
"ABRT": ("", ""), "ABRT": ("", ""),
} }
} }
KILL_WAIT = 3 # seconds KILL_WAIT = 3 # seconds
class Server(object): class Server(object):
"""Class that implements the communication protocol and dispatches calls """Class that implements the communication protocol and dispatches calls
to the required functions. Also works as the main loop for the slave to the required functions. Also works as the main loop for the slave
process.""" process."""
def __init__(self, rfd: socket.socket, wfd: socket.socket): def __init__(self, rfd: socket.socket, wfd: socket.socket):
debug("Server(0x%x).__init__()" % id(self)) debug("Server(0x%x).__init__()" % id(self))
# Dictionary of valid commands # Dictionary of valid commands
...@@ -160,7 +163,7 @@ class Server(object): ...@@ -160,7 +163,7 @@ class Server(object):
def reply(self, code, text): def reply(self, code, text):
"Send back a reply to the client; handle multiline messages" "Send back a reply to the client; handle multiline messages"
if type(text) != list: if type(text) != list:
text = [ text ] text = [text]
clean = [] clean = []
# Split lines with embedded \n # Split lines with embedded \n
for i in text: for i in text:
...@@ -212,7 +215,7 @@ class Server(object): ...@@ -212,7 +215,7 @@ class Server(object):
cmd2 = None cmd2 = None
subcommands = self._commands[cmd1] subcommands = self._commands[cmd1]
if list(subcommands.keys()) != [ None ]: if list(subcommands.keys()) != [None]:
if len(args) < 1: if len(args) < 1:
self.reply(500, "Incomplete command.") self.reply(500, "Incomplete command.")
return None return None
...@@ -285,7 +288,7 @@ class Server(object): ...@@ -285,7 +288,7 @@ class Server(object):
v.child_traceback = "".join( v.child_traceback = "".join(
traceback.format_exception(t, v, tb)) traceback.format_exception(t, v, tb))
self.reply(550, ["# Exception data follows:", self.reply(550, ["# Exception data follows:",
_b64(dumps(v, protocol = 2))]) _b64(dumps(v, protocol=2))])
try: try:
self._rfd.close() self._rfd.close()
self._wfd.close() self._wfd.close()
...@@ -312,7 +315,7 @@ class Server(object): ...@@ -312,7 +315,7 @@ class Server(object):
self._closed = True self._closed = True
def do_PROC_CRTE(self, cmdname, executable, *argv): def do_PROC_CRTE(self, cmdname, executable, *argv):
self._proc = { 'executable': executable, 'argv': argv } self._proc = {'executable': executable, 'argv': argv}
self._commands = _proc_commands self._commands = _proc_commands
self.reply(200, "Entering PROC mode.") self.reply(200, "Entering PROC mode.")
...@@ -330,7 +333,7 @@ class Server(object): ...@@ -330,7 +333,7 @@ class Server(object):
"Invalid number of arguments for PROC ENV: must be even.") "Invalid number of arguments for PROC ENV: must be even.")
return return
self._proc['env'] = {} self._proc['env'] = {}
for i in range(len(env)//2): for i in range(len(env) // 2):
self._proc['env'][env[i * 2]] = env[i * 2 + 1] self._proc['env'][env[i * 2]] = env[i * 2 + 1]
self.reply(200, "%d environment definition(s) read." % (len(env) // 2)) self.reply(200, "%d environment definition(s) read." % (len(env) // 2))
...@@ -446,13 +449,13 @@ class Server(object): ...@@ -446,13 +449,13 @@ class Server(object):
os.kill(-pid, signal.SIGTERM) os.kill(-pid, signal.SIGTERM)
self.reply(200, "Process signalled.") self.reply(200, "Process signalled.")
def do_IF_LIST(self, cmdname, ifnr = None): def do_IF_LIST(self, cmdname, ifnr=None):
if ifnr == None: if ifnr == None:
ifdata = nemu.iproute.get_if_data()[0] ifdata = nemu.iproute.get_if_data()[0]
else: else:
ifdata = nemu.iproute.get_if(ifnr) ifdata = nemu.iproute.get_if(ifnr)
self.reply(200, ["# Interface data follows.", self.reply(200, ["# Interface data follows.",
_b64(dumps(ifdata, protocol = 2))]) _b64(dumps(ifdata, protocol=2))])
def do_IF_SET(self, cmdname, ifnr, *args): def do_IF_SET(self, cmdname, ifnr, *args):
if len(args) % 2: if len(args) % 2:
...@@ -475,14 +478,14 @@ class Server(object): ...@@ -475,14 +478,14 @@ class Server(object):
nemu.iproute.del_if(ifnr) nemu.iproute.del_if(ifnr)
self.reply(200, "Done.") self.reply(200, "Done.")
def do_ADDR_LIST(self, cmdname, ifnr = None): def do_ADDR_LIST(self, cmdname, ifnr=None):
addrdata = nemu.iproute.get_addr_data()[0] addrdata = nemu.iproute.get_addr_data()[0]
if ifnr != None: if ifnr != None:
addrdata = addrdata[ifnr] addrdata = addrdata[ifnr]
self.reply(200, ["# Address data follows.", self.reply(200, ["# Address data follows.",
_b64(dumps(addrdata, protocol = 2))]) _b64(dumps(addrdata, protocol=2))])
def do_ADDR_ADD(self, cmdname, ifnr, address, prefixlen, broadcast = None): def do_ADDR_ADD(self, cmdname, ifnr, address, prefixlen, broadcast=None):
if address.find(":") < 0: # crude, I know if address.find(":") < 0: # crude, I know
a = nemu.iproute.ipv4address(address, prefixlen, broadcast) a = nemu.iproute.ipv4address(address, prefixlen, broadcast)
else: else:
...@@ -501,7 +504,7 @@ class Server(object): ...@@ -501,7 +504,7 @@ class Server(object):
def do_ROUT_LIST(self, cmdname): def do_ROUT_LIST(self, cmdname):
rdata = nemu.iproute.get_route_data() rdata = nemu.iproute.get_route_data()
self.reply(200, ["# Routing data follows.", self.reply(200, ["# Routing data follows.",
_b64(dumps(rdata, protocol = 2))]) _b64(dumps(rdata, protocol=2))])
def do_ROUT_ADD(self, cmdname, tipe, prefix, prefixlen, nexthop, ifnr, def do_ROUT_ADD(self, cmdname, tipe, prefix, prefixlen, nexthop, ifnr,
metric): metric):
...@@ -521,7 +524,7 @@ class Server(object): ...@@ -521,7 +524,7 @@ class Server(object):
return return
skt, port = None, None skt, port = None, None
try: try:
skt, port = find_listen_port(min_port = 6010, max_port = 6099) skt, port = find_listen_port(min_port=6010, max_port=6099)
except: except:
self.reply(500, "Cannot allocate a port for X forwarding.") self.reply(500, "Cannot allocate a port for X forwarding.")
return return
...@@ -548,6 +551,7 @@ class Server(object): ...@@ -548,6 +551,7 @@ class Server(object):
self._xsock = None self._xsock = None
self.reply(200, "Will set up X forwarding.") self.reply(200, "Will set up X forwarding.")
# ============================================================================ # ============================================================================
# #
# Client-side protocol implementation. # Client-side protocol implementation.
...@@ -555,6 +559,7 @@ class Server(object): ...@@ -555,6 +559,7 @@ class Server(object):
class Client(object): class Client(object):
"""Client-side implementation of the communication protocol. Acts as a RPC """Client-side implementation of the communication protocol. Acts as a RPC
service.""" service."""
def __init__(self, rfd: socket.socket, wfd: socket.socket): def __init__(self, rfd: socket.socket, wfd: socket.socket):
debug("Client(0x%x).__init__()" % id(self)) debug("Client(0x%x).__init__()" % id(self))
self._rfd_socket = rfd self._rfd_socket = rfd
...@@ -569,7 +574,7 @@ class Client(object): ...@@ -569,7 +574,7 @@ class Client(object):
debug("Client(0x%x).__del__()" % id(self)) debug("Client(0x%x).__del__()" % id(self))
self.shutdown() self.shutdown()
def _send_cmd(self, *args: str): def _send_cmd(self, *args: str | int):
if not self._wfd: if not self._wfd:
raise RuntimeError("Client already shut down.") raise RuntimeError("Client already shut down.")
s = " ".join(map(str, args)) + "\n" s = " ".join(map(str, args)) + "\n"
...@@ -595,7 +600,7 @@ class Client(object): ...@@ -595,7 +600,7 @@ class Client(object):
break break
return (int(status), "\n".join(text)) return (int(status), "\n".join(text))
def _read_and_check_reply(self, expected = 2): def _read_and_check_reply(self, expected=2):
"""Reads a response and raises an exception if the first digit of the """Reads a response and raises an exception if the first digit of the
code is not the expected value. If expected is not specified, it code is not the expected value. If expected is not specified, it
defaults to 2.""" defaults to 2."""
...@@ -620,7 +625,7 @@ class Client(object): ...@@ -620,7 +625,7 @@ class Client(object):
self._rfd_socket.close() self._rfd_socket.close()
self._rfd = None self._rfd = None
self._wfd.close() self._wfd.close()
self._rfd_socket.close() self._wfd_socket.close()
self._wfd = None self._wfd = None
if self._forwarder: if self._forwarder:
os.kill(self._forwarder, signal.SIGTERM) os.kill(self._forwarder, signal.SIGTERM)
...@@ -640,9 +645,9 @@ class Client(object): ...@@ -640,9 +645,9 @@ class Client(object):
raise raise
self._read_and_check_reply() self._read_and_check_reply()
def spawn(self, argv, executable = None, def spawn(self, argv, executable=None,
stdin = None, stdout = None, stderr = None, stdin=None, stdout=None, stderr=None,
cwd = None, env = None, user = None): cwd=None, env=None, user=None):
"""Start a subprocess in the slave; the interface resembles """Start a subprocess in the slave; the interface resembles
subprocess.Popen, but with less functionality. In particular subprocess.Popen, but with less functionality. In particular
stdin/stdout/stderr can only be None or a open file descriptor. stdin/stdout/stderr can only be None or a open file descriptor.
...@@ -675,10 +680,13 @@ class Client(object): ...@@ -675,10 +680,13 @@ class Client(object):
self._read_and_check_reply() self._read_and_check_reply()
if stdin != None: if stdin != None:
os.set_inheritable(stdin, True)
self._send_fd("SIN", stdin) self._send_fd("SIN", stdin)
if stdout != None: if stdout != None:
os.set_inheritable(stdout, True)
self._send_fd("SOUT", stdout) self._send_fd("SOUT", stdout)
if stderr != None: if stderr != None:
os.set_inheritable(stderr, True)
self._send_fd("SERR", stderr) self._send_fd("SERR", stderr)
except: except:
self._send_cmd("PROC", "ABRT") self._send_cmd("PROC", "ABRT")
...@@ -711,16 +719,16 @@ class Client(object): ...@@ -711,16 +719,16 @@ class Client(object):
exitcode = int(text.split()[0]) exitcode = int(text.split()[0])
return exitcode return exitcode
def signal(self, pid, sig = signal.SIGTERM): def signal(self, pid, sig=signal.SIGTERM):
"""Equivalent to Popen.send_signal(). Sends a signal to the child """Equivalent to Popen.send_signal(). Sends a signal to the child
process; signal defaults to SIGTERM.""" process; signal defaults to SIGTERM."""
if sig: if sig:
self._send_cmd("PROC", "KILL", pid, sig) self._send_cmd("PROC", "KILL", pid, int(sig))
else: else:
self._send_cmd("PROC", "KILL", pid) self._send_cmd("PROC", "KILL", pid)
self._read_and_check_reply() self._read_and_check_reply()
def get_if_data(self, ifnr = None): def get_if_data(self, ifnr=None):
if ifnr: if ifnr:
self._send_cmd("IF", "LIST", ifnr) self._send_cmd("IF", "LIST", ifnr)
else: else:
...@@ -746,7 +754,7 @@ class Client(object): ...@@ -746,7 +754,7 @@ class Client(object):
self._send_cmd("IF", "RTRN", ifnr, netns) self._send_cmd("IF", "RTRN", ifnr, netns)
self._read_and_check_reply() self._read_and_check_reply()
def get_addr_data(self, ifnr = None): def get_addr_data(self, ifnr=None):
if ifnr: if ifnr:
self._send_cmd("ADDR", "LIST", ifnr) self._send_cmd("ADDR", "LIST", ifnr)
else: else:
...@@ -793,7 +801,7 @@ class Client(object): ...@@ -793,7 +801,7 @@ class Client(object):
self._send_cmd("X11", "SOCK") self._send_cmd("X11", "SOCK")
fd, payload = passfd.recvfd(self._rfd, 1) fd, payload = passfd.recvfd(self._rfd, 1)
self._read_and_check_reply() self._read_and_check_reply()
skt = socket.fromfd(fd, socket.AF_INET, socket.SOCK_DGRAM) skt = compat.fromfd(fd, socket.AF_INET, socket.SOCK_DGRAM)
os.close(fd) # fromfd dup()'s os.close(fd) # fromfd dup()'s
return skt return skt
...@@ -814,25 +822,45 @@ class Client(object): ...@@ -814,25 +822,45 @@ class Client(object):
server = self.set_x11(protoname, hexkey) server = self.set_x11(protoname, hexkey)
self._forwarder = _spawn_x11_forwarder(server, sock, addr) self._forwarder = _spawn_x11_forwarder(server, sock, addr)
def _b64(text: str | bytes) -> str:
def _b64_OLD(text: str | bytes) -> str:
if text == None: if text == None:
# easier this way # easier this way
text = '' text = ''
if type(text) is str: if type(text) is str:
btext = text.encode("utf-8") btext = text.encode("utf-8")
elif type(text) is bytes:
btext = text
else: else:
btext = text btext = text
if len(text) == 0 or any(x for x in btext if x <= ord(" ") or if len(btext) == 0 or any(x for x in btext if x <= ord(" ") or
x > ord("z") or x == ord("=")): x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(btext).decode("ascii") return "=" + base64.b64encode(btext).decode("ascii")
else: else:
return text return text
def _b64(text) -> str:
if text is None:
# easier this way
return "="
if type(text) is bytes:
enc = text
else:
enc = str(text).encode("utf-8")
if len(enc) == 0 or any(x for x in enc if x <= ord(" ") or
x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(enc).decode("ascii")
else:
return enc.decode("utf-8")
def _db64(text: str) -> bytes: def _db64(text: str) -> bytes:
if not text or text[0] != '=': if not text or text[0] != '=':
return text.encode("utf-8") return text.encode("utf-8")
return base64.b64decode(text[1:]) return base64.b64decode(text[1:])
def _get_file(fd, mode): def _get_file(fd, mode):
# Since fdopen insists on closing the fd on destruction, I need to dup() # Since fdopen insists on closing the fd on destruction, I need to dup()
if hasattr(fd, "fileno"): if hasattr(fd, "fileno"):
...@@ -841,6 +869,7 @@ def _get_file(fd, mode): ...@@ -841,6 +869,7 @@ def _get_file(fd, mode):
nfd = os.dup(fd) nfd = os.dup(fd)
return os.fdopen(nfd, mode, 1) return os.fdopen(nfd, mode, 1)
def _parse_display(): def _parse_display():
if "DISPLAY" not in os.environ: if "DISPLAY" not in os.environ:
return None return None
...@@ -863,6 +892,7 @@ def _parse_display(): ...@@ -863,6 +892,7 @@ def _parse_display():
xauthdpy = "unix:%s" % number xauthdpy = "unix:%s" % number
return xauthdpy, sock, addr return xauthdpy, sock, addr
def _spawn_x11_forwarder(server, xsock, xaddr): def _spawn_x11_forwarder(server, xsock, xaddr):
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.listen(10) # arbitrary server.listen(10) # arbitrary
...@@ -876,6 +906,7 @@ def _spawn_x11_forwarder(server, xsock, xaddr): ...@@ -876,6 +906,7 @@ def _spawn_x11_forwarder(server, xsock, xaddr):
traceback.print_exc(file=sys.stderr) traceback.print_exc(file=sys.stderr)
os._exit(1) os._exit(1)
def _x11_forwarder(server, xsock, xaddr): def _x11_forwarder(server, xsock, xaddr):
def clean(idx, toread, fd): def clean(idx, toread, fd):
# silently discards any buffer! # silently discards any buffer!
...@@ -896,14 +927,14 @@ def _x11_forwarder(server, xsock, xaddr): ...@@ -896,14 +927,14 @@ def _x11_forwarder(server, xsock, xaddr):
if fd2 in toread: if fd2 in toread:
toread.remove(fd2) toread.remove(fd2)
toread = set([server]) toread = {server}
idx = {} idx = {}
while(True): while True:
towrite = [x["wr"] for x in idx.values() if x["buf"]] towrite = [x["wr"] for x in idx.values() if x["buf"]]
(rr, wr, er) = select.select(toread, towrite, []) (rr, wr, er) = select.select(toread, towrite, [])
if server in rr: if server in rr:
xconn = socket.socket(*xsock) xconn = compat.socket(*xsock)
xconn.connect(xaddr) xconn.connect(xaddr)
client, addr = server.accept() client, addr = server.accept()
toread.add(client) toread.add(client)
......
...@@ -28,24 +28,27 @@ import sys ...@@ -28,24 +28,27 @@ import sys
import time import time
import traceback import traceback
from nemu import compat
from nemu.environ import eintr_wrapper from nemu.environ import eintr_wrapper
__all__ = [ 'PIPE', 'STDOUT', 'Popen', 'Subprocess', 'spawn', 'wait', 'poll', __all__ = ['PIPE', 'STDOUT', 'Popen', 'Subprocess', 'spawn', 'wait', 'poll',
'get_user', 'system', 'backticks', 'backticks_raise' ] 'get_user', 'system', 'backticks', 'backticks_raise']
# User-facing interfaces # User-facing interfaces
KILL_WAIT = 3 # seconds KILL_WAIT = 3 # seconds
class Subprocess(object): class Subprocess(object):
"""Class that allows the execution of programs inside a nemu Node. This is """Class that allows the execution of programs inside a nemu Node. This is
the base class for all process operations, Popen provides a more high level the base class for all process operations, Popen provides a more high level
interface.""" interface."""
# FIXME # FIXME
default_user = None default_user = None
def __init__(self, node, argv, executable = None,
stdin = None, stdout = None, stderr = None, def __init__(self, node, argv: str | list[str], executable=None,
shell = False, cwd = None, env = None, user = None): stdin=None, stdout=None, stderr=None,
shell=False, cwd=None, env=None, user=None):
self._slave = node._slave self._slave = node._slave
"""Forks and execs a program, with stdio redirection and user """Forks and execs a program, with stdio redirection and user
switching. switching.
...@@ -79,9 +82,9 @@ class Subprocess(object): ...@@ -79,9 +82,9 @@ class Subprocess(object):
user = Subprocess.default_user user = Subprocess.default_user
if isinstance(argv, str): if isinstance(argv, str):
argv = [ argv ] argv = [argv]
if shell: if shell:
argv = [ '/bin/sh', '-c' ] + argv argv = ['/bin/sh', '-c'] + argv
# Initialize attributes that would be used by the destructor if spawn # Initialize attributes that would be used by the destructor if spawn
# fails # fails
...@@ -89,9 +92,9 @@ class Subprocess(object): ...@@ -89,9 +92,9 @@ class Subprocess(object):
# confusingly enough, to go to the function at the top of this file, # confusingly enough, to go to the function at the top of this file,
# I need to call it thru the communications protocol: remember that # I need to call it thru the communications protocol: remember that
# happens in another process! # happens in another process!
self._pid = self._slave.spawn(argv, executable = executable, self._pid = self._slave.spawn(argv, executable=executable,
stdin = stdin, stdout = stdout, stderr = stderr, stdin=stdin, stdout=stdout, stderr=stderr,
cwd = cwd, env = env, user = user) cwd=cwd, env=env, user=user)
node._add_subprocess(self) node._add_subprocess(self)
...@@ -114,7 +117,7 @@ class Subprocess(object): ...@@ -114,7 +117,7 @@ class Subprocess(object):
self._returncode = self._slave.wait(self._pid) self._returncode = self._slave.wait(self._pid)
return self.returncode return self.returncode
def signal(self, sig = signal.SIGTERM): def signal(self, sig=signal.SIGTERM):
"""Sends a signal to the process.""" """Sends a signal to the process."""
if self._returncode == None: if self._returncode == None:
self._slave.signal(self._pid, sig) self._slave.signal(self._pid, sig)
...@@ -135,6 +138,7 @@ class Subprocess(object): ...@@ -135,6 +138,7 @@ class Subprocess(object):
def __del__(self): def __del__(self):
self.destroy() self.destroy()
def destroy(self): def destroy(self):
if self._returncode != None or self._pid == None: if self._returncode != None or self._pid == None:
return return
...@@ -149,15 +153,19 @@ class Subprocess(object): ...@@ -149,15 +153,19 @@ class Subprocess(object):
self.signal(signal.SIGKILL) self.signal(signal.SIGKILL)
self.wait() self.wait()
PIPE = -1 PIPE = -1
STDOUT = -2 STDOUT = -2
DEVNULL = -3
class Popen(Subprocess): class Popen(Subprocess):
"""Higher-level interface for executing processes, that tries to emulate """Higher-level interface for executing processes, that tries to emulate
the stdlib's subprocess.Popen as much as possible.""" the stdlib's subprocess.Popen as much as possible."""
def __init__(self, node, argv, executable = None, def __init__(self, node, argv, executable=None,
stdin = None, stdout = None, stderr = None, bufsize = 0, stdin=None, stdout=None, stderr=None, bufsize=0,
shell = False, cwd = None, env = None, user = None): shell=False, cwd=None, env=None, user=None):
"""As in Subprocess, `node' specifies the nemu Node to run in. """As in Subprocess, `node' specifies the nemu Node to run in.
The `stdin', `stdout', and `stderr' parameters also accept the special The `stdin', `stdout', and `stderr' parameters also accept the special
...@@ -168,30 +176,33 @@ class Popen(Subprocess): ...@@ -168,30 +176,33 @@ class Popen(Subprocess):
self.stdin = self.stdout = self.stderr = None self.stdin = self.stdout = self.stderr = None
self._pid = self._returncode = None self._pid = self._returncode = None
fdmap = { "stdin": stdin, "stdout": stdout, "stderr": stderr } fdmap = {"stdin": stdin, "stdout": stdout, "stderr": stderr}
# if PIPE: all should be closed at the end # if PIPE: all should be closed at the end
for k, v in fdmap.items(): for k, v in fdmap.items():
if v == None: if v == None:
continue continue
if v == PIPE: if v == PIPE:
r, w = os.pipe() r, w = compat.pipe()
if k == "stdin": if k == "stdin":
self.stdin = os.fdopen(w, 'wb', bufsize) self.stdin = os.fdopen(w, 'wb', bufsize)
fdmap[k] = r fdmap[k] = r
else: else:
setattr(self, k, os.fdopen(r, 'rb', bufsize)) setattr(self, k, os.fdopen(r, 'rb', bufsize))
fdmap[k] = w fdmap[k] = w
elif v == DEVNULL:
fdmap[k] = os.open(os.devnull, os.O_RDWR)
elif isinstance(v, int): elif isinstance(v, int):
pass pass
else: else:
fdmap[k] = v.fileno() fdmap[k] = v.fileno()
os.set_inheritable(fdmap[k], True)
if stderr == STDOUT: if stderr == STDOUT:
fdmap['stderr'] = fdmap['stdout'] fdmap['stderr'] = fdmap['stdout']
super(Popen, self).__init__(node, argv, executable = executable, super(Popen, self).__init__(node, argv, executable=executable,
stdin = fdmap['stdin'], stdout = fdmap['stdout'], stdin=fdmap['stdin'], stdout=fdmap['stdout'],
stderr = fdmap['stderr'], stderr=fdmap['stderr'],
shell = shell, cwd = cwd, env = env, user = user) shell=shell, cwd=cwd, env=env, user=user)
# Close pipes, they have been dup()ed to the child # Close pipes, they have been dup()ed to the child
for k, v in fdmap.items(): for k, v in fdmap.items():
...@@ -224,8 +235,8 @@ class Popen(Subprocess): ...@@ -224,8 +235,8 @@ class Popen(Subprocess):
r, w, x = select.select(rset, wset, []) r, w, x = select.select(rset, wset, [])
if self.stdin in w: if self.stdin in w:
wrote = os.write(self.stdin.fileno(), wrote = os.write(self.stdin.fileno(),
#buffer(input, offset, select.PIPE_BUF)) # buffer(input, offset, select.PIPE_BUF))
input[offset:offset+512]) # XXX: py2.7 input[offset:offset + 512]) # XXX: py2.7
offset += wrote offset += wrote
if offset >= len(input): if offset >= len(input):
self.stdin.close() self.stdin.close()
...@@ -233,7 +244,7 @@ class Popen(Subprocess): ...@@ -233,7 +244,7 @@ class Popen(Subprocess):
for i in self.stdout, self.stderr: for i in self.stdout, self.stderr:
if i in r: if i in r:
d = os.read(i.fileno(), 1024) # No need for eintr wrapper d = os.read(i.fileno(), 1024) # No need for eintr wrapper
if d == "": if d == b"":
i.close() i.close()
rset.remove(i) rset.remove(i)
else: else:
...@@ -249,38 +260,42 @@ class Popen(Subprocess): ...@@ -249,38 +260,42 @@ class Popen(Subprocess):
self.wait() self.wait()
return (out, err) return (out, err)
def system(node, args): def system(node, args):
"""Emulates system() function, if `args' is an string, it uses `/bin/sh' to """Emulates system() function, if `args' is an string, it uses `/bin/sh' to
exexecute it, otherwise is interpreted as the argv array to call execve.""" exexecute it, otherwise is interpreted as the argv array to call execve."""
shell = isinstance(args, str) shell = isinstance(args, str)
return Popen(node, args, shell = shell).wait() return Popen(node, args, shell=shell).wait()
def backticks(node, args): def backticks(node, args):
"""Emulates shell backticks, if `args' is an string, it uses `/bin/sh' to """Emulates shell backticks, if `args' is an string, it uses `/bin/sh' to
exexecute it, otherwise is interpreted as the argv array to call execve.""" exexecute it, otherwise is interpreted as the argv array to call execve."""
shell = isinstance(args, str) shell = isinstance(args, str)
return Popen(node, args, shell = shell, stdout = PIPE).communicate()[0] return Popen(node, args, shell=shell, stdout=PIPE).communicate()[0].decode("utf-8")
def backticks_raise(node, args):
def backticks_raise(node, args: str | list[str]) -> str:
"""Emulates shell backticks, if `args' is an string, it uses `/bin/sh' to """Emulates shell backticks, if `args' is an string, it uses `/bin/sh' to
exexecute it, otherwise is interpreted as the argv array to call execve. exexecute it, otherwise is interpreted as the argv array to call execve.
Raises an RuntimeError if the return value is not 0.""" Raises an RuntimeError if the return value is not 0."""
shell = isinstance(args, str) shell = isinstance(args, str)
p = Popen(node, args, shell = shell, stdout = PIPE) p = Popen(node, args, shell=shell, stdout=PIPE)
out = p.communicate()[0] out = p.communicate()[0]
ret = p.returncode ret = p.returncode
if ret > 0: if ret > 0:
raise RuntimeError("Command failed with return code %d." % ret) raise RuntimeError("Command failed with return code %d." % ret)
if ret < 0: if ret < 0:
raise RuntimeError("Command killed by signal %d." % -ret) raise RuntimeError("Command killed by signal %d." % -ret)
return out return out.decode("utf-8")
# ======================================================================= # =======================================================================
# #
# Server-side code, called from nemu.protocol.Server # Server-side code, called from nemu.protocol.Server
def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
stdin = None, stdout = None, stderr = None, user = None): stdin=None, stdout=None, stderr=None, user=None):
"""Internal function that performs all the dirty work for Subprocess, Popen """Internal function that performs all the dirty work for Subprocess, Popen
and friends. This is executed in the slave process, directly from the and friends. This is executed in the slave process, directly from the
protocol.Server class. protocol.Server class.
...@@ -315,7 +330,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, ...@@ -315,7 +330,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
env['HOME'] = home env['HOME'] = home
env['USER'] = user env['USER'] = user
(r, w) = os.pipe() (r, w) = compat.pipe()
pid = os.fork() pid = os.fork()
if pid == 0: # pragma: no cover if pid == 0: # pragma: no cover
# coverage doesn't seem to understand fork # coverage doesn't seem to understand fork
...@@ -355,7 +370,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, ...@@ -355,7 +370,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
if cwd != None: if cwd != None:
os.chdir(cwd) os.chdir(cwd)
if not argv: if not argv:
argv = [ executable ] argv = [executable]
if '/' in executable: # Should not search in PATH if '/' in executable: # Should not search in PATH
if env != None: if env != None:
os.execve(executable, argv, env) os.execve(executable, argv, env)
...@@ -376,7 +391,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, ...@@ -376,7 +391,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
traceback.format_exception(t, v, tb)) traceback.format_exception(t, v, tb))
eintr_wrapper(os.write, w, pickle.dumps(v)) eintr_wrapper(os.write, w, pickle.dumps(v))
eintr_wrapper(os.close, w) eintr_wrapper(os.close, w)
#traceback.print_exc() # traceback.print_exc()
except: except:
traceback.print_exc() traceback.print_exc()
os._exit(1) os._exit(1)
...@@ -387,7 +402,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, ...@@ -387,7 +402,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
s = b"" s = b""
while True: while True:
s1 = eintr_wrapper(os.read, r, 4096) s1 = eintr_wrapper(os.read, r, 4096)
if s1 == "": if s1 == b"":
break break
s += s1 s += s1
eintr_wrapper(os.close, r) eintr_wrapper(os.close, r)
...@@ -399,9 +414,10 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, ...@@ -399,9 +414,10 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
eintr_wrapper(os.waitpid, pid, 0) eintr_wrapper(os.waitpid, pid, 0)
exc = pickle.loads(s) exc = pickle.loads(s)
# XXX: sys.excepthook # XXX: sys.excepthook
#print exc.child_traceback # print exc.child_traceback
raise exc raise exc
def poll(pid): def poll(pid):
"""Check if the process already died. Returns the exit code or None if """Check if the process already died. Returns the exit code or None if
the process is still alive.""" the process is still alive."""
...@@ -410,10 +426,12 @@ def poll(pid): ...@@ -410,10 +426,12 @@ def poll(pid):
return r[1] return r[1]
return None return None
def wait(pid): def wait(pid):
"""Wait for process to die and return the exit code.""" """Wait for process to die and return the exit code."""
return eintr_wrapper(os.waitpid, pid, 0)[1] return eintr_wrapper(os.waitpid, pid, 0)[1]
def get_user(user): def get_user(user):
"Take either an username or an uid, and return a tuple (user, uid, gid)." "Take either an username or an uid, and return a tuple (user, uid, gid)."
if str(user).isdigit(): if str(user).isdigit():
...@@ -430,11 +448,10 @@ def get_user(user): ...@@ -430,11 +448,10 @@ def get_user(user):
gid = pwd.getpwuid(uid)[3] gid = pwd.getpwuid(uid)[3]
return user, uid, gid return user, uid, gid
# internal stuff, do not look! # internal stuff, do not look!
try: try:
MAXFD = os.sysconf("SC_OPEN_MAX") MAXFD = os.sysconf("SC_OPEN_MAX")
except: # pragma: no cover except: # pragma: no cover
MAXFD = 256 MAXFD = 256
...@@ -6,14 +6,16 @@ import nemu.protocol ...@@ -6,14 +6,16 @@ import nemu.protocol
import os, socket, sys, threading, unittest import os, socket, sys, threading, unittest
import test_util import test_util
from nemu import compat
class TestServer(unittest.TestCase): class TestServer(unittest.TestCase):
@test_util.skip("python 3 can't makefile a socket in r+")
def test_server_startup(self): def test_server_startup(self):
# Test the creation of the server object with different ways of passing # Test the creation of the server object with different ways of passing
# the file descriptor; and check the banner. # the file descriptor; and check the banner.
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
(s2, s3) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s2, s3) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
def test_help(fd): def test_help(fd):
fd.write("HELP\n") fd.write("HELP\n")
...@@ -34,13 +36,13 @@ class TestServer(unittest.TestCase): ...@@ -34,13 +36,13 @@ class TestServer(unittest.TestCase):
t = threading.Thread(target = run_server) t = threading.Thread(target = run_server)
t.start() t.start()
s = os.fdopen(s1.fileno(), "r+", 1) s = os.fdopen(s1.fileno(), "r", 1)
self.assertEqual(s.readline()[0:4], "220 ") self.assertEqual(s.readline()[0:4], "220 ")
test_help(s) test_help(s)
s.close() s.close()
s0.close() s0.close()
s = os.fdopen(s3.fileno(), "r+", 1) s = os.fdopen(s3.fileno(), "r", 1)
self.assertEqual(s.readline()[0:4], "220 ") self.assertEqual(s.readline()[0:4], "220 ")
test_help(s) test_help(s)
s.close() s.close()
...@@ -48,7 +50,7 @@ class TestServer(unittest.TestCase): ...@@ -48,7 +50,7 @@ class TestServer(unittest.TestCase):
t.join() t.join()
def test_server_clean(self): def test_server_clean(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
def run_server(): def run_server():
nemu.protocol.Server(s0, s0).run() nemu.protocol.Server(s0, s0).run()
...@@ -57,7 +59,8 @@ class TestServer(unittest.TestCase): ...@@ -57,7 +59,8 @@ class TestServer(unittest.TestCase):
cli = nemu.protocol.Client(s1, s1) cli = nemu.protocol.Client(s1, s1)
argv = [ '/bin/sh', '-c', 'yes' ] argv = [ '/bin/sh', '-c', 'yes' ]
pid = cli.spawn(argv, stdout = subprocess.DEVNULL) nullfd = open("/dev/null", "wb")
pid = cli.spawn(argv, stdout = nullfd.fileno())
self.assertTrue(os.path.exists("/proc/%d" % pid)) self.assertTrue(os.path.exists("/proc/%d" % pid))
# try to exit while there are still processes running # try to exit while there are still processes running
cli.shutdown() cli.shutdown()
...@@ -68,7 +71,7 @@ class TestServer(unittest.TestCase): ...@@ -68,7 +71,7 @@ class TestServer(unittest.TestCase):
self.assertFalse(os.path.exists("/proc/%d" % pid)) self.assertFalse(os.path.exists("/proc/%d" % pid))
def test_spawn_recovery(self): def test_spawn_recovery(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
def run_server(): def run_server():
nemu.protocol.Server(s0, s0).run() nemu.protocol.Server(s0, s0).run()
...@@ -93,7 +96,7 @@ class TestServer(unittest.TestCase): ...@@ -93,7 +96,7 @@ class TestServer(unittest.TestCase):
@test_util.skip("python 3 can't makefile a socket in r+") @test_util.skip("python 3 can't makefile a socket in r+")
def test_basic_stuff(self): def test_basic_stuff(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
srv = nemu.protocol.Server(s0, s0) srv = nemu.protocol.Server(s0, s0)
s1 = s1.makefile("r+", 1) s1 = s1.makefile("r+", 1)
......
...@@ -6,6 +6,9 @@ import nemu, test_util ...@@ -6,6 +6,9 @@ import nemu, test_util
import nemu.subprocess_ as sp import nemu.subprocess_ as sp
import grp, os, pwd, signal, socket, sys, time, unittest import grp, os, pwd, signal, socket, sys, time, unittest
from nemu import compat
def _stat(path): def _stat(path):
try: try:
return os.stat(path) return os.stat(path)
...@@ -104,7 +107,7 @@ class TestSubprocess(unittest.TestCase): ...@@ -104,7 +107,7 @@ class TestSubprocess(unittest.TestCase):
# uses a default search path # uses a default search path
self.assertRaises(OSError, sp.spawn, 'sleep', env = {'PATH': ''}) self.assertRaises(OSError, sp.spawn, 'sleep', env = {'PATH': ''})
r, w = os.pipe() r, w = compat.pipe()
p = sp.spawn('/bin/echo', ['echo', 'hello world'], stdout = w) p = sp.spawn('/bin/echo', ['echo', 'hello world'], stdout = w)
os.close(w) os.close(w)
self.assertEqual(_readall(r), b"hello world\n") self.assertEqual(_readall(r), b"hello world\n")
...@@ -120,8 +123,8 @@ class TestSubprocess(unittest.TestCase): ...@@ -120,8 +123,8 @@ class TestSubprocess(unittest.TestCase):
# It cannot be wait()ed again. # It cannot be wait()ed again.
self.assertRaises(OSError, sp.wait, p) self.assertRaises(OSError, sp.wait, p)
r0, w0 = os.pipe() r0, w0 = compat.pipe()
r1, w1 = os.pipe() r1, w1 = compat.pipe()
p = sp.spawn('/bin/cat', stdout = w0, stdin = r1, close_fds = [r0, w1]) p = sp.spawn('/bin/cat', stdout = w0, stdin = r1, close_fds = [r0, w1])
os.close(w0) os.close(w0)
os.close(r1) os.close(r1)
...@@ -140,25 +143,25 @@ class TestSubprocess(unittest.TestCase): ...@@ -140,25 +143,25 @@ class TestSubprocess(unittest.TestCase):
self.assertRaises(ValueError, node.Subprocess, self.assertRaises(ValueError, node.Subprocess,
['/bin/sleep', '1000'], user = self.nouid) ['/bin/sleep', '1000'], user = self.nouid)
# Invalid CWD: it is a file # Invalid CWD: it is a file
self.assertRaises(OSError, node.Subprocess, self.assertRaises(NotADirectoryError, node.Subprocess,
'/bin/sleep', cwd = '/bin/sleep') '/bin/sleep', cwd = '/bin/sleep')
# Invalid CWD: does not exist # Invalid CWD: does not exist
self.assertRaises(OSError, node.Subprocess, self.assertRaises(FileNotFoundError, node.Subprocess,
'/bin/sleep', cwd = self.nofile) '/bin/sleep', cwd = self.nofile)
# Exec failure # Exec failure
self.assertRaises(OSError, node.Subprocess, self.nofile) self.assertRaises(FileNotFoundError, node.Subprocess, self.nofile)
# Test that the environment is cleared: sleep should not be found # Test that the environment is cleared: sleep should not be found
self.assertRaises(OSError, node.Subprocess, self.assertRaises(FileNotFoundError, node.Subprocess,
'sleep', env = {'PATH': ''}) 'sleep', env = {'PATH': ''})
# Argv # Argv
self.assertRaises(OSError, node.Subprocess, 'true; false') self.assertRaises(FileNotFoundError, node.Subprocess, 'true; false')
self.assertEqual(node.Subprocess('true').wait(), 0) self.assertEqual(node.Subprocess('true').wait(), 0)
self.assertEqual(node.Subprocess('true; false', shell = True).wait(), self.assertEqual(node.Subprocess('true; false', shell = True).wait(),
1) 1)
# Piping # Piping
r, w = os.pipe() r, w = compat.pipe()
p = node.Subprocess(['echo', 'hello world'], stdout = w) p = node.Subprocess(['echo', 'hello world'], stdout = w)
os.close(w) os.close(w)
self.assertEqual(_readall(r), b"hello world\n") self.assertEqual(_readall(r), b"hello world\n")
...@@ -166,7 +169,7 @@ class TestSubprocess(unittest.TestCase): ...@@ -166,7 +169,7 @@ class TestSubprocess(unittest.TestCase):
p.wait() p.wait()
# cwd # cwd
r, w = os.pipe() r, w = compat.pipe()
p = node.Subprocess('/bin/pwd', stdout = w, cwd = "/") p = node.Subprocess('/bin/pwd', stdout = w, cwd = "/")
os.close(w) os.close(w)
self.assertEqual(_readall(r), b"/\n") self.assertEqual(_readall(r), b"/\n")
...@@ -194,7 +197,7 @@ class TestSubprocess(unittest.TestCase): ...@@ -194,7 +197,7 @@ class TestSubprocess(unittest.TestCase):
# closing stdout (so _readall finishes) # closing stdout (so _readall finishes)
cmd = 'trap "" TERM; echo; exec sleep 100 > /dev/null' cmd = 'trap "" TERM; echo; exec sleep 100 > /dev/null'
r, w = os.pipe() r, w = compat.pipe()
p = node.Subprocess(cmd, shell = True, stdout = w) p = node.Subprocess(cmd, shell = True, stdout = w)
os.close(w) os.close(w)
self.assertEqual(_readall(r), b"\n") # wait for trap to be installed self.assertEqual(_readall(r), b"\n") # wait for trap to be installed
...@@ -202,9 +205,10 @@ class TestSubprocess(unittest.TestCase): ...@@ -202,9 +205,10 @@ class TestSubprocess(unittest.TestCase):
pid = p.pid pid = p.pid
os.kill(pid, 0) # verify process still there os.kill(pid, 0) # verify process still there
# Avoid the warning about the process being killed # Avoid the warning about the process being killed
old_err = sys.stderr
with open("/dev/null", "w") as sys.stderr: with open("/dev/null", "w") as sys.stderr:
p.destroy() p.destroy()
sys.stderr = sys.__stderr__ sys.stderr = old_err
self.assertRaises(OSError, os.kill, pid, 0) # should be dead by now self.assertRaises(OSError, os.kill, pid, 0) # should be dead by now
p = node.Subprocess(['sleep', '100']) p = node.Subprocess(['sleep', '100'])
...@@ -217,8 +221,8 @@ class TestSubprocess(unittest.TestCase): ...@@ -217,8 +221,8 @@ class TestSubprocess(unittest.TestCase):
node = nemu.Node(nonetns = True) node = nemu.Node(nonetns = True)
# repeat test with Popen interface # repeat test with Popen interface
r0, w0 = os.pipe() r0, w0 = compat.pipe()
r1, w1 = os.pipe() r1, w1 = compat.pipe()
p = node.Popen('cat', stdout = w0, stdin = r1) p = node.Popen('cat', stdout = w0, stdin = r1)
os.close(w0) os.close(w0)
os.close(r1) os.close(r1)
...@@ -228,7 +232,7 @@ class TestSubprocess(unittest.TestCase): ...@@ -228,7 +232,7 @@ class TestSubprocess(unittest.TestCase):
os.close(r0) os.close(r0)
# now with a socketpair, not using integers # now with a socketpair, not using integers
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = compat.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
p = node.Popen('cat', stdout = s0, stdin = s0) p = node.Popen('cat', stdout = s0, stdin = s0)
s0.close() s0.close()
s1.send(b"hello world\n") s1.send(b"hello world\n")
...@@ -255,9 +259,9 @@ class TestSubprocess(unittest.TestCase): ...@@ -255,9 +259,9 @@ class TestSubprocess(unittest.TestCase):
p = node.Popen('cat >&2', shell = True, stdin = sp.PIPE, p = node.Popen('cat >&2', shell = True, stdin = sp.PIPE,
stderr = sp.PIPE) stderr = sp.PIPE)
p.stdin.write("hello world\n") p.stdin.write(b"hello world\n")
p.stdin.close() p.stdin.close()
self.assertEqual(p.stderr.readlines(), ["hello world\n"]) self.assertEqual(p.stderr.readlines(), [b"hello world\n"])
self.assertEqual(p.stdout, None) self.assertEqual(p.stdout, None)
self.assertEqual(p.wait(), 0) self.assertEqual(p.wait(), 0)
...@@ -268,9 +272,9 @@ class TestSubprocess(unittest.TestCase): ...@@ -268,9 +272,9 @@ class TestSubprocess(unittest.TestCase):
# #
p = node.Popen(['sh', '-c', 'cat >&2'], p = node.Popen(['sh', '-c', 'cat >&2'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT)
p.stdin.write("hello world\n") p.stdin.write(b"hello world\n")
p.stdin.close() p.stdin.close()
self.assertEqual(p.stdout.readlines(), ["hello world\n"]) self.assertEqual(p.stdout.readlines(), [b"hello world\n"])
self.assertEqual(p.stderr, None) self.assertEqual(p.stderr, None)
self.assertEqual(p.wait(), 0) self.assertEqual(p.wait(), 0)
...@@ -281,9 +285,9 @@ class TestSubprocess(unittest.TestCase): ...@@ -281,9 +285,9 @@ class TestSubprocess(unittest.TestCase):
# #
p = node.Popen(['tee', '/dev/stderr'], p = node.Popen(['tee', '/dev/stderr'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT)
p.stdin.write("hello world\n") p.stdin.write(b"hello world\n")
p.stdin.close() p.stdin.close()
self.assertEqual(p.stdout.readlines(), ["hello world\n"] * 2) self.assertEqual(p.stdout.readlines(), [b"hello world\n"] * 2)
self.assertEqual(p.stderr, None) self.assertEqual(p.stderr, None)
self.assertEqual(p.wait(), 0) self.assertEqual(p.wait(), 0)
...@@ -295,10 +299,10 @@ class TestSubprocess(unittest.TestCase): ...@@ -295,10 +299,10 @@ class TestSubprocess(unittest.TestCase):
# #
p = node.Popen(['tee', '/dev/stderr'], p = node.Popen(['tee', '/dev/stderr'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.PIPE) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.PIPE)
p.stdin.write("hello world\n") p.stdin.write(b"hello world\n")
p.stdin.close() p.stdin.close()
self.assertEqual(p.stdout.readlines(), ["hello world\n"]) self.assertEqual(p.stdout.readlines(), [b"hello world\n"])
self.assertEqual(p.stderr.readlines(), ["hello world\n"]) self.assertEqual(p.stderr.readlines(), [b"hello world\n"])
self.assertEqual(p.wait(), 0) self.assertEqual(p.wait(), 0)
p = node.Popen(['tee', '/dev/stderr'], p = node.Popen(['tee', '/dev/stderr'],
......
...@@ -5,7 +5,7 @@ import os, re, subprocess, sys ...@@ -5,7 +5,7 @@ import os, re, subprocess, sys
import nemu.subprocess_ import nemu.subprocess_
from nemu.environ import * from nemu.environ import *
def process_ipcmd(str): def process_ipcmd(str: str):
cur = None cur = None
out = {} out = {}
for line in str.split("\n"): for line in str.split("\n"):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment