Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
N
nemu3
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
1
Issues
1
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
nexedi
nemu3
Commits
033ce168
Commit
033ce168
authored
Nov 16, 2023
by
Tom Niget
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add type hints
parent
7f17df26
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
412 additions
and
329 deletions
+412
-329
benchmarks/linear-raw-throughput.py
benchmarks/linear-raw-throughput.py
+1
-1
setup.py
setup.py
+1
-1
src/nemu/__init__.py
src/nemu/__init__.py
+2
-2
src/nemu/compat.py
src/nemu/compat.py
+3
-5
src/nemu/environ.py
src/nemu/environ.py
+5
-5
src/nemu/interface.py
src/nemu/interface.py
+104
-70
src/nemu/iproute.py
src/nemu/iproute.py
+225
-185
src/nemu/node.py
src/nemu/node.py
+9
-5
src/nemu/passfd.py
src/nemu/passfd.py
+4
-4
src/nemu/protocol.py
src/nemu/protocol.py
+21
-19
src/nemu/subprocess_.py
src/nemu/subprocess_.py
+29
-24
test/test_core.py
test/test_core.py
+1
-1
test/test_protocol.py
test/test_protocol.py
+1
-1
test/test_util.py
test/test_util.py
+6
-6
No files found.
benchmarks/linear-raw-throughput.py
View file @
033ce168
...
@@ -140,7 +140,7 @@ def main():
...
@@ -140,7 +140,7 @@ def main():
if
not
r
:
if
not
r
:
break
break
out
+=
r
out
+=
r
if
srv
.
poll
()
!=
None
or
clt
.
poll
()
!=
None
:
if
srv
.
poll
()
is
not
None
or
clt
.
poll
()
is
not
None
:
break
break
if
srv
.
poll
():
if
srv
.
poll
():
...
...
setup.py
View file @
033ce168
...
@@ -15,6 +15,6 @@ setup(
...
@@ -15,6 +15,6 @@ setup(
license
=
'GPLv2'
,
license
=
'GPLv2'
,
platforms
=
'Linux'
,
platforms
=
'Linux'
,
packages
=
[
'nemu'
],
packages
=
[
'nemu'
],
install_requires
=
[
'unshare'
,
'six'
],
install_requires
=
[
'unshare'
,
'six'
,
'attrs'
],
package_dir
=
{
''
:
'src'
}
package_dir
=
{
''
:
'src'
}
)
)
src/nemu/__init__.py
View file @
033ce168
...
@@ -41,7 +41,7 @@ class _Config(object):
...
@@ -41,7 +41,7 @@ class _Config(object):
except
KeyError
:
except
KeyError
:
pass
# User not found.
pass
# User not found.
def
_set_run_as
(
self
,
user
):
def
_set_run_as
(
self
,
user
:
str
|
int
):
"""Setter for `run_as'."""
"""Setter for `run_as'."""
if
str
(
user
).
isdigit
():
if
str
(
user
).
isdigit
():
uid
=
int
(
user
)
uid
=
int
(
user
)
...
@@ -61,7 +61,7 @@ class _Config(object):
...
@@ -61,7 +61,7 @@ class _Config(object):
self
.
_run_as
=
run_as
self
.
_run_as
=
run_as
return
run_as
return
run_as
def
_get_run_as
(
self
):
def
_get_run_as
(
self
)
->
str
:
"""Setter for `run_as'."""
"""Setter for `run_as'."""
return
self
.
_run_as
return
self
.
_run_as
...
...
src/nemu/compat.py
View file @
033ce168
...
@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]:
...
@@ -8,23 +8,21 @@ def pipe() -> tuple[int, int]:
os
.
set_inheritable
(
b
,
True
)
os
.
set_inheritable
(
b
,
True
)
return
a
,
b
return
a
,
b
def
socket
(
*
args
,
**
kwargs
)
->
pysocket
.
socket
:
def
socket
(
*
args
,
**
kwargs
)
->
pysocket
.
socket
:
s
=
pysocket
.
socket
(
*
args
,
**
kwargs
)
s
=
pysocket
.
socket
(
*
args
,
**
kwargs
)
s
.
set_inheritable
(
True
)
s
.
set_inheritable
(
True
)
return
s
return
s
def
socketpair
(
*
args
,
**
kwargs
)
->
tuple
[
pysocket
.
socket
,
pysocket
.
socket
]:
def
socketpair
(
*
args
,
**
kwargs
)
->
tuple
[
pysocket
.
socket
,
pysocket
.
socket
]:
a
,
b
=
pysocket
.
socketpair
(
*
args
,
**
kwargs
)
a
,
b
=
pysocket
.
socketpair
(
*
args
,
**
kwargs
)
a
.
set_inheritable
(
True
)
a
.
set_inheritable
(
True
)
b
.
set_inheritable
(
True
)
b
.
set_inheritable
(
True
)
return
a
,
b
return
a
,
b
def
fromfd
(
*
args
,
**
kwargs
)
->
pysocket
.
socket
:
def
fromfd
(
*
args
,
**
kwargs
)
->
pysocket
.
socket
:
s
=
pysocket
.
fromfd
(
*
args
,
**
kwargs
)
s
=
pysocket
.
fromfd
(
*
args
,
**
kwargs
)
s
.
set_inheritable
(
True
)
s
.
set_inheritable
(
True
)
return
s
return
s
def
fdopen
(
*
args
,
**
kwargs
)
->
pysocket
.
socket
:
s
=
os
.
fdopen
(
*
args
,
**
kwargs
)
s
.
set_inheritable
(
True
)
return
s
\ No newline at end of file
src/nemu/environ.py
View file @
033ce168
...
@@ -25,7 +25,7 @@ import subprocess
...
@@ -25,7 +25,7 @@ import subprocess
import
sys
import
sys
import
syslog
import
syslog
from
syslog
import
LOG_ERR
,
LOG_WARNING
,
LOG_NOTICE
,
LOG_INFO
,
LOG_DEBUG
from
syslog
import
LOG_ERR
,
LOG_WARNING
,
LOG_NOTICE
,
LOG_INFO
,
LOG_DEBUG
from
typing
import
TypeVar
,
Callable
from
typing
import
TypeVar
,
Callable
,
Optional
__all__
=
[
"IP_PATH"
,
"TC_PATH"
,
"BRCTL_PATH"
,
"SYSCTL_PATH"
,
"HZ"
]
__all__
=
[
"IP_PATH"
,
"TC_PATH"
,
"BRCTL_PATH"
,
"SYSCTL_PATH"
,
"HZ"
]
...
@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"]
...
@@ -39,7 +39,7 @@ __all__ += ["set_log_level", "logger"]
__all__
+=
[
"error"
,
"warning"
,
"notice"
,
"info"
,
"debug"
]
__all__
+=
[
"error"
,
"warning"
,
"notice"
,
"info"
,
"debug"
]
def
find_bin
(
name
,
extra_path
=
None
)
:
def
find_bin
(
name
:
str
,
extra_path
:
Optional
[
list
[
str
]]
=
None
)
->
Optional
[
str
]
:
"""Try hard to find the location of needed programs."""
"""Try hard to find the location of needed programs."""
search
=
[]
search
=
[]
if
"PATH"
in
os
.
environ
:
if
"PATH"
in
os
.
environ
:
...
@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None):
...
@@ -57,7 +57,7 @@ def find_bin(name, extra_path=None):
return
None
return
None
def
find_bin_or_die
(
name
,
extra_path
=
None
)
:
def
find_bin_or_die
(
name
:
str
,
extra_path
:
Optional
[
list
[
str
]]
=
None
)
->
str
:
"""Try hard to find the location of needed programs; raise on failure."""
"""Try hard to find the location of needed programs; raise on failure."""
res
=
find_bin
(
name
,
extra_path
)
res
=
find_bin
(
name
,
extra_path
)
if
not
res
:
if
not
res
:
...
@@ -156,7 +156,7 @@ _log_syslog_opts = ()
...
@@ -156,7 +156,7 @@ _log_syslog_opts = ()
_log_pid
=
os
.
getpid
()
_log_pid
=
os
.
getpid
()
def
set_log_level
(
level
):
def
set_log_level
(
level
:
int
):
"Sets the log level for console messages, does not affect syslog logging."
"Sets the log level for console messages, does not affect syslog logging."
global
_log_level
global
_log_level
assert
level
>
LOG_ERR
and
level
<=
LOG_DEBUG
assert
level
>
LOG_ERR
and
level
<=
LOG_DEBUG
...
@@ -191,7 +191,7 @@ def _init_log():
...
@@ -191,7 +191,7 @@ def _init_log():
info
(
"Syslog logging started"
)
info
(
"Syslog logging started"
)
def
logger
(
priority
,
message
):
def
logger
(
priority
:
int
,
message
:
str
):
"Print a log message in syslog, console or both."
"Print a log message in syslog, console or both."
if
_log_use_syslog
:
if
_log_use_syslog
:
if
os
.
getpid
()
!=
_log_pid
:
if
os
.
getpid
()
!=
_log_pid
:
...
...
src/nemu/interface.py
View file @
033ce168
...
@@ -19,33 +19,36 @@
...
@@ -19,33 +19,36 @@
import
os
import
os
import
weakref
import
weakref
from
typing
import
TypedDict
import
nemu.iproute
import
nemu.iproute
from
nemu.environ
import
*
from
nemu.environ
import
*
__all__
=
[
'NodeInterface'
,
'P2PInterface'
,
'ImportedInterface'
,
__all__
=
[
'NodeInterface'
,
'P2PInterface'
,
'ImportedInterface'
,
'ImportedNodeInterface'
,
'Switch'
]
'ImportedNodeInterface'
,
'Switch'
]
class
Interface
(
object
):
class
Interface
(
object
):
"""Just a base class for the *Interface classes: assign names and handle
"""Just a base class for the *Interface classes: assign names and handle
destruction."""
destruction."""
_nextid
=
0
_nextid
=
0
@
staticmethod
@
staticmethod
def
_gen_next_id
():
def
_gen_next_id
()
->
int
:
n
=
Interface
.
_nextid
n
=
Interface
.
_nextid
Interface
.
_nextid
+=
1
Interface
.
_nextid
+=
1
return
n
return
n
@
staticmethod
@
staticmethod
def
_gen_if_name
():
def
_gen_if_name
()
->
str
:
n
=
Interface
.
_gen_next_id
()
n
=
Interface
.
_gen_next_id
()
# Max 15 chars
# Max 15 chars
return
"NETNSif-%.4x%.3x"
%
(
os
.
getpid
()
&
0xffff
,
n
)
return
"NETNSif-%.4x%.3x"
%
(
os
.
getpid
()
&
0xffff
,
n
)
def
__init__
(
self
,
index
):
def
__init__
(
self
,
index
:
int
):
self
.
_idx
=
index
self
.
_idx
=
index
debug
(
"%s(0x%x).__init__(), index = %d"
%
(
self
.
__class__
.
__name__
,
debug
(
"%s(0x%x).__init__(), index = %d"
%
(
self
.
__class__
.
__name__
,
id
(
self
),
index
))
id
(
self
),
index
))
def
__del__
(
self
):
def
__del__
(
self
):
debug
(
"%s(0x%x).__del__()"
%
(
self
.
__class__
.
__name__
,
id
(
self
)))
debug
(
"%s(0x%x).__del__()"
%
(
self
.
__class__
.
__name__
,
id
(
self
)))
...
@@ -55,7 +58,7 @@ class Interface(object):
...
@@ -55,7 +58,7 @@ class Interface(object):
raise
NotImplementedError
raise
NotImplementedError
@
property
@
property
def
index
(
self
):
def
index
(
self
)
->
int
:
"""Interface index as seen by the kernel."""
"""Interface index as seen by the kernel."""
return
self
.
_idx
return
self
.
_idx
...
@@ -65,20 +68,35 @@ class Interface(object):
...
@@ -65,20 +68,35 @@ class Interface(object):
control interfaces can be put into a Switch, for example."""
control interfaces can be put into a Switch, for example."""
return
None
return
None
class
Ipv4Dict
(
TypedDict
):
address
:
str
prefix_len
:
int
broadcast
:
str
family
:
str
class
Ipv6Dict
(
TypedDict
):
address
:
str
prefix_len
:
int
family
:
str
class
NSInterface
(
Interface
):
class
NSInterface
(
Interface
):
"""Add user-facing methods for interfaces that go into a netns."""
"""Add user-facing methods for interfaces that go into a netns."""
def
__init__
(
self
,
node
,
index
):
def
__init__
(
self
,
node
:
"nemu.Node"
,
index
):
super
(
NSInterface
,
self
).
__init__
(
index
)
super
(
NSInterface
,
self
).
__init__
(
index
)
self
.
_slave
=
node
.
_slave
self
.
_slave
=
node
.
_slave
# Disable auto-configuration
# Disable auto-configuration
# you wish: need to take into account the nonetns mode; plus not
# you wish: need to take into account the nonetns mode; plus not
# touching some pre-existing ifaces
# touching some pre-existing ifaces
#node.system([SYSCTL_PATH, '-w', 'net.ipv6.conf.%s.autoconf=0' %
#
node.system([SYSCTL_PATH, '-w', 'net.ipv6.conf.%s.autoconf=0' %
#self.name])
#
self.name])
node
.
_add_interface
(
self
)
node
.
_add_interface
(
self
)
# some black magic to automatically get/set interface attributes
# some black magic to automatically get/set interface attributes
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
:
str
):
# If name starts with _, it must be a normal attr
# If name starts with _, it must be a normal attr
if
name
[
0
]
==
'_'
:
if
name
[
0
]
==
'_'
:
return
super
(
Interface
,
self
).
__getattribute__
(
name
)
return
super
(
Interface
,
self
).
__getattribute__
(
name
)
...
@@ -92,57 +110,59 @@ class NSInterface(Interface):
...
@@ -92,57 +110,59 @@ class NSInterface(Interface):
iface
=
slave
.
get_if_data
(
self
.
index
)
iface
=
slave
.
get_if_data
(
self
.
index
)
return
getattr
(
iface
,
name
)
return
getattr
(
iface
,
name
)
def
__setattr__
(
self
,
name
,
value
):
def
__setattr__
(
self
,
name
:
str
,
value
):
if
name
[
0
]
==
'_'
:
# forbid anything that doesn't start with a _
if
name
[
0
]
==
'_'
:
# forbid anything that doesn't start with a _
super
(
Interface
,
self
).
__setattr__
(
name
,
value
)
super
(
Interface
,
self
).
__setattr__
(
name
,
value
)
return
return
iface
=
nemu
.
iproute
.
interface
(
index
=
self
.
index
)
iface
=
nemu
.
iproute
.
interface
(
index
=
self
.
index
)
setattr
(
iface
,
name
,
value
)
setattr
(
iface
,
name
,
value
)
return
self
.
_slave
.
set_if
(
iface
)
return
self
.
_slave
.
set_if
(
iface
)
def
add_v4_address
(
self
,
address
,
prefix_len
,
broadcast
=
None
):
def
add_v4_address
(
self
,
address
:
str
,
prefix_len
:
int
,
broadcast
=
None
):
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
self
.
_slave
.
add_addr
(
self
.
index
,
addr
)
self
.
_slave
.
add_addr
(
self
.
index
,
addr
)
def
add_v6_address
(
self
,
address
,
prefix_len
):
def
add_v6_address
(
self
,
address
:
str
,
prefix_len
:
int
):
addr
=
nemu
.
iproute
.
ipv6address
(
address
,
prefix_len
)
addr
=
nemu
.
iproute
.
ipv6address
(
address
,
prefix_len
)
self
.
_slave
.
add_addr
(
self
.
index
,
addr
)
self
.
_slave
.
add_addr
(
self
.
index
,
addr
)
def
del_v4_address
(
self
,
address
,
prefix_len
,
broadcast
=
None
):
def
del_v4_address
(
self
,
address
:
str
,
prefix_len
:
int
,
broadcast
=
None
):
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
self
.
_slave
.
del_addr
(
self
.
index
,
addr
)
self
.
_slave
.
del_addr
(
self
.
index
,
addr
)
def
del_v6_address
(
self
,
address
,
prefix_len
):
def
del_v6_address
(
self
,
address
:
str
,
prefix_len
:
int
):
addr
=
nemu
.
iproute
.
ipv6address
(
address
,
prefix_len
)
addr
=
nemu
.
iproute
.
ipv6address
(
address
,
prefix_len
)
self
.
_slave
.
del_addr
(
self
.
index
,
addr
)
self
.
_slave
.
del_addr
(
self
.
index
,
addr
)
def
get_addresses
(
self
):
def
get_addresses
(
self
)
->
list
[
Ipv4Dict
|
Ipv6Dict
]
:
addresses
=
self
.
_slave
.
get_addr_data
(
self
.
index
)
addresses
=
self
.
_slave
.
get_addr_data
(
self
.
index
)
ret
=
[]
ret
=
[]
for
a
in
addresses
:
for
a
in
addresses
:
if
hasattr
(
a
,
'broadcast'
):
if
hasattr
(
a
,
'broadcast'
):
ret
.
append
(
dict
(
ret
.
append
(
dict
(
address
=
a
.
address
,
address
=
a
.
address
,
prefix_len
=
a
.
prefix_len
,
prefix_len
=
a
.
prefix_len
,
broadcast
=
a
.
broadcast
,
broadcast
=
a
.
broadcast
,
family
=
'inet'
))
family
=
'inet'
))
else
:
else
:
ret
.
append
(
dict
(
ret
.
append
(
dict
(
address
=
a
.
address
,
address
=
a
.
address
,
prefix_len
=
a
.
prefix_len
,
prefix_len
=
a
.
prefix_len
,
family
=
'inet6'
))
family
=
'inet6'
))
return
ret
return
ret
class
NodeInterface
(
NSInterface
):
class
NodeInterface
(
NSInterface
):
"""Class to create and handle a virtual interface inside a name space, it
"""Class to create and handle a virtual interface inside a name space, it
can be connected to a Switch object with emulation of link
can be connected to a Switch object with emulation of link
characteristics."""
characteristics."""
def
__init__
(
self
,
node
):
def
__init__
(
self
,
node
:
"nemu.Node"
):
"""Create a new interface. `node' is the name space in which this
"""Create a new interface. `node' is the name space in which this
interface should be put."""
interface should be put."""
self
.
_slave
=
None
self
.
_slave
=
None
if1
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
if1
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
if2
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
if2
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
ctl
,
ns
=
nemu
.
iproute
.
create_if_pair
(
if1
,
if2
)
ctl
,
ns
=
nemu
.
iproute
.
create_if_pair
(
if1
,
if2
)
try
:
try
:
nemu
.
iproute
.
change_netns
(
ns
,
node
.
pid
)
nemu
.
iproute
.
change_netns
(
ns
,
node
.
pid
)
...
@@ -165,18 +185,20 @@ class NodeInterface(NSInterface):
...
@@ -165,18 +185,20 @@ class NodeInterface(NSInterface):
self
.
_slave
.
del_if
(
self
.
index
)
self
.
_slave
.
del_if
(
self
.
index
)
self
.
_slave
=
None
self
.
_slave
=
None
class
P2PInterface
(
NSInterface
):
class
P2PInterface
(
NSInterface
):
"""Class to create and handle point-to-point interfaces between name
"""Class to create and handle point-to-point interfaces between name
spaces, without using Switch objects. Those do not allow any kind of
spaces, without using Switch objects. Those do not allow any kind of
traffic shaping.
traffic shaping.
As two interfaces need to be created, instead of using the class
As two interfaces need to be created, instead of using the class
constructor, use the P2PInterface.create_pair() static method."""
constructor, use the P2PInterface.create_pair() static method."""
@
staticmethod
@
staticmethod
def
create_pair
(
node1
,
node2
):
def
create_pair
(
node1
:
"nemu.Node"
,
node2
:
"nemu.Node"
):
"""Create and return a pair of connected P2PInterface objects,
"""Create and return a pair of connected P2PInterface objects,
assigned to name spaces represented by `node1' and `node2'."""
assigned to name spaces represented by `node1' and `node2'."""
if1
=
nemu
.
iproute
.
interface
(
name
=
P2PInterface
.
_gen_if_name
())
if1
=
nemu
.
iproute
.
interface
(
name
=
P2PInterface
.
_gen_if_name
())
if2
=
nemu
.
iproute
.
interface
(
name
=
P2PInterface
.
_gen_if_name
())
if2
=
nemu
.
iproute
.
interface
(
name
=
P2PInterface
.
_gen_if_name
())
pair
=
nemu
.
iproute
.
create_if_pair
(
if1
,
if2
)
pair
=
nemu
.
iproute
.
create_if_pair
(
if1
,
if2
)
try
:
try
:
nemu
.
iproute
.
change_netns
(
pair
[
0
],
node1
.
pid
)
nemu
.
iproute
.
change_netns
(
pair
[
0
],
node1
.
pid
)
...
@@ -206,6 +228,7 @@ class P2PInterface(NSInterface):
...
@@ -206,6 +228,7 @@ class P2PInterface(NSInterface):
self
.
_slave
.
del_if
(
self
.
index
)
self
.
_slave
.
del_if
(
self
.
index
)
self
.
_slave
=
None
self
.
_slave
=
None
class
ImportedNodeInterface
(
NSInterface
):
class
ImportedNodeInterface
(
NSInterface
):
"""Class to handle already existing interfaces inside a name space:
"""Class to handle already existing interfaces inside a name space:
real devices, tun devices, etc.
real devices, tun devices, etc.
...
@@ -213,7 +236,8 @@ class ImportedNodeInterface(NSInterface):
...
@@ -213,7 +236,8 @@ class ImportedNodeInterface(NSInterface):
to be moved inside the name space.
to be moved inside the name space.
On destruction, the interface will be restored to the original name space
On destruction, the interface will be restored to the original name space
and will try to restore the original state."""
and will try to restore the original state."""
def
__init__
(
self
,
node
,
iface
,
migrate
=
True
):
def
__init__
(
self
,
node
:
"nemu.Node"
,
iface
,
migrate
=
True
):
self
.
_slave
=
None
self
.
_slave
=
None
self
.
_migrate
=
migrate
self
.
_migrate
=
migrate
if
self
.
_migrate
:
if
self
.
_migrate
:
...
@@ -230,7 +254,7 @@ class ImportedNodeInterface(NSInterface):
...
@@ -230,7 +254,7 @@ class ImportedNodeInterface(NSInterface):
super
(
ImportedNodeInterface
,
self
).
__init__
(
node
,
iface
.
index
)
super
(
ImportedNodeInterface
,
self
).
__init__
(
node
,
iface
.
index
)
def
destroy
(
self
):
# override: restore as much as possible
def
destroy
(
self
):
# override: restore as much as possible
if
not
self
.
_slave
:
if
not
self
.
_slave
:
return
return
debug
(
"ImportedNodeInterface(0x%x).destroy()"
%
id
(
self
))
debug
(
"ImportedNodeInterface(0x%x).destroy()"
%
id
(
self
))
...
@@ -244,17 +268,19 @@ class ImportedNodeInterface(NSInterface):
...
@@ -244,17 +268,19 @@ class ImportedNodeInterface(NSInterface):
nemu
.
iproute
.
set_if
(
self
.
_original_state
)
nemu
.
iproute
.
set_if
(
self
.
_original_state
)
self
.
_slave
=
None
self
.
_slave
=
None
class
TapNodeInterface
(
NSInterface
):
class
TapNodeInterface
(
NSInterface
):
"""Class to create a tap interface inside a name space, it
"""Class to create a tap interface inside a name space, it
can be connected to a Switch object with emulation of link
can be connected to a Switch object with emulation of link
characteristics."""
characteristics."""
def
__init__
(
self
,
node
,
use_pi
=
False
):
def
__init__
(
self
,
node
,
use_pi
=
False
):
"""Create a new tap interface. 'node' is the name space in which this
"""Create a new tap interface. 'node' is the name space in which this
interface should be put."""
interface should be put."""
self
.
_fd
=
None
self
.
_fd
=
None
self
.
_slave
=
None
self
.
_slave
=
None
iface
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
iface
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
iface
,
self
.
_fd
=
nemu
.
iproute
.
create_tap
(
iface
,
use_pi
=
use_pi
)
iface
,
self
.
_fd
=
nemu
.
iproute
.
create_tap
(
iface
,
use_pi
=
use_pi
)
nemu
.
iproute
.
change_netns
(
iface
.
name
,
node
.
pid
)
nemu
.
iproute
.
change_netns
(
iface
.
name
,
node
.
pid
)
super
(
TapNodeInterface
,
self
).
__init__
(
node
,
iface
.
index
)
super
(
TapNodeInterface
,
self
).
__init__
(
node
,
iface
.
index
)
...
@@ -271,18 +297,20 @@ class TapNodeInterface(NSInterface):
...
@@ -271,18 +297,20 @@ class TapNodeInterface(NSInterface):
except
:
except
:
pass
pass
class
TunNodeInterface
(
NSInterface
):
class
TunNodeInterface
(
NSInterface
):
"""Class to create a tun interface inside a name space, it
"""Class to create a tun interface inside a name space, it
can be connected to a Switch object with emulation of link
can be connected to a Switch object with emulation of link
characteristics."""
characteristics."""
def
__init__
(
self
,
node
,
use_pi
=
False
):
def
__init__
(
self
,
node
,
use_pi
=
False
):
"""Create a new tap interface. 'node' is the name space in which this
"""Create a new tap interface. 'node' is the name space in which this
interface should be put."""
interface should be put."""
self
.
_fd
=
None
self
.
_fd
=
None
self
.
_slave
=
None
self
.
_slave
=
None
iface
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
iface
=
nemu
.
iproute
.
interface
(
name
=
self
.
_gen_if_name
())
iface
,
self
.
_fd
=
nemu
.
iproute
.
create_tap
(
iface
,
use_pi
=
use_pi
,
iface
,
self
.
_fd
=
nemu
.
iproute
.
create_tap
(
iface
,
use_pi
=
use_pi
,
tun
=
True
)
tun
=
True
)
nemu
.
iproute
.
change_netns
(
iface
.
name
,
node
.
pid
)
nemu
.
iproute
.
change_netns
(
iface
.
name
,
node
.
pid
)
super
(
TunNodeInterface
,
self
).
__init__
(
node
,
iface
.
index
)
super
(
TunNodeInterface
,
self
).
__init__
(
node
,
iface
.
index
)
...
@@ -299,9 +327,11 @@ class TunNodeInterface(NSInterface):
...
@@ -299,9 +327,11 @@ class TunNodeInterface(NSInterface):
except
:
except
:
pass
pass
class
ExternalInterface
(
Interface
):
class
ExternalInterface
(
Interface
):
"""Add user-facing methods for interfaces that run in the main
"""Add user-facing methods for interfaces that run in the main
namespace."""
namespace."""
@
property
@
property
def
control
(
self
):
def
control
(
self
):
# This is *the* control interface
# This is *the* control interface
...
@@ -313,14 +343,14 @@ class ExternalInterface(Interface):
...
@@ -313,14 +343,14 @@ class ExternalInterface(Interface):
return
getattr
(
iface
,
name
)
return
getattr
(
iface
,
name
)
def
__setattr__
(
self
,
name
,
value
):
def
__setattr__
(
self
,
name
,
value
):
if
name
[
0
]
==
'_'
:
# forbid anything that doesn't start with a _
if
name
[
0
]
==
'_'
:
# forbid anything that doesn't start with a _
super
(
ExternalInterface
,
self
).
__setattr__
(
name
,
value
)
super
(
ExternalInterface
,
self
).
__setattr__
(
name
,
value
)
return
return
iface
=
nemu
.
iproute
.
interface
(
index
=
self
.
index
)
iface
=
nemu
.
iproute
.
interface
(
index
=
self
.
index
)
setattr
(
iface
,
name
,
value
)
setattr
(
iface
,
name
,
value
)
return
nemu
.
iproute
.
set_if
(
iface
)
return
nemu
.
iproute
.
set_if
(
iface
)
def
add_v4_address
(
self
,
address
,
prefix_len
,
broadcast
=
None
):
def
add_v4_address
(
self
,
address
,
prefix_len
,
broadcast
=
None
):
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
nemu
.
iproute
.
add_addr
(
self
.
index
,
addr
)
nemu
.
iproute
.
add_addr
(
self
.
index
,
addr
)
...
@@ -328,7 +358,7 @@ class ExternalInterface(Interface):
...
@@ -328,7 +358,7 @@ class ExternalInterface(Interface):
addr
=
nemu
.
iproute
.
ipv6address
(
address
,
prefix_len
)
addr
=
nemu
.
iproute
.
ipv6address
(
address
,
prefix_len
)
nemu
.
iproute
.
add_addr
(
self
.
index
,
addr
)
nemu
.
iproute
.
add_addr
(
self
.
index
,
addr
)
def
del_v4_address
(
self
,
address
,
prefix_len
,
broadcast
=
None
):
def
del_v4_address
(
self
,
address
,
prefix_len
,
broadcast
=
None
):
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
addr
=
nemu
.
iproute
.
ipv4address
(
address
,
prefix_len
,
broadcast
)
nemu
.
iproute
.
del_addr
(
self
.
index
,
addr
)
nemu
.
iproute
.
del_addr
(
self
.
index
,
addr
)
...
@@ -342,23 +372,26 @@ class ExternalInterface(Interface):
...
@@ -342,23 +372,26 @@ class ExternalInterface(Interface):
for
a
in
addresses
:
for
a
in
addresses
:
if
hasattr
(
a
,
'broadcast'
):
if
hasattr
(
a
,
'broadcast'
):
ret
.
append
(
dict
(
ret
.
append
(
dict
(
address
=
a
.
address
,
address
=
a
.
address
,
prefix_len
=
a
.
prefix_len
,
prefix_len
=
a
.
prefix_len
,
broadcast
=
a
.
broadcast
,
broadcast
=
a
.
broadcast
,
family
=
'inet'
))
family
=
'inet'
))
else
:
else
:
ret
.
append
(
dict
(
ret
.
append
(
dict
(
address
=
a
.
address
,
address
=
a
.
address
,
prefix_len
=
a
.
prefix_len
,
prefix_len
=
a
.
prefix_len
,
family
=
'inet6'
))
family
=
'inet6'
))
return
ret
return
ret
class
SlaveInterface
(
ExternalInterface
):
class
SlaveInterface
(
ExternalInterface
):
"""Class to handle the main-name-space-facing half of NodeInterface.
"""Class to handle the main-name-space-facing half of NodeInterface.
Does nothing, just avoids any destroy code."""
Does nothing, just avoids any destroy code."""
def
destroy
(
self
):
def
destroy
(
self
):
pass
pass
class
ImportedInterface
(
ExternalInterface
):
class
ImportedInterface
(
ExternalInterface
):
"""Class to handle already existing interfaces. Analogous to
"""Class to handle already existing interfaces. Analogous to
ImportedNodeInterface, this class only differs in that the interface is
ImportedNodeInterface, this class only differs in that the interface is
...
@@ -366,6 +399,7 @@ class ImportedInterface(ExternalInterface):
...
@@ -366,6 +399,7 @@ class ImportedInterface(ExternalInterface):
connected to Switch objects and not assigned to a name space. On
connected to Switch objects and not assigned to a name space. On
destruction, the code will try to restore the interface to the state it
destruction, the code will try to restore the interface to the state it
was in before being imported into nemu."""
was in before being imported into nemu."""
def
__init__
(
self
,
iface
):
def
__init__
(
self
,
iface
):
self
.
_original_state
=
None
self
.
_original_state
=
None
iface
=
nemu
.
iproute
.
get_if
(
iface
)
iface
=
nemu
.
iproute
.
get_if
(
iface
)
...
@@ -373,12 +407,13 @@ class ImportedInterface(ExternalInterface):
...
@@ -373,12 +407,13 @@ class ImportedInterface(ExternalInterface):
super
(
ImportedInterface
,
self
).
__init__
(
iface
.
index
)
super
(
ImportedInterface
,
self
).
__init__
(
iface
.
index
)
# FIXME: register somewhere for destruction!
# FIXME: register somewhere for destruction!
def
destroy
(
self
):
# override: restore as much as possible
def
destroy
(
self
):
# override: restore as much as possible
if
self
.
_original_state
:
if
self
.
_original_state
:
debug
(
"ImportedInterface(0x%x).destroy()"
%
id
(
self
))
debug
(
"ImportedInterface(0x%x).destroy()"
%
id
(
self
))
nemu
.
iproute
.
set_if
(
self
.
_original_state
)
nemu
.
iproute
.
set_if
(
self
.
_original_state
)
self
.
_original_state
=
None
self
.
_original_state
=
None
# Switch is just another interface type
# Switch is just another interface type
class
Switch
(
ExternalInterface
):
class
Switch
(
ExternalInterface
):
...
@@ -412,7 +447,7 @@ class Switch(ExternalInterface):
...
@@ -412,7 +447,7 @@ class Switch(ExternalInterface):
return
getattr
(
iface
,
name
)
return
getattr
(
iface
,
name
)
def
__setattr__
(
self
,
name
,
value
):
def
__setattr__
(
self
,
name
,
value
):
if
name
[
0
]
==
'_'
:
# forbid anything that doesn't start with a _
if
name
[
0
]
==
'_'
:
# forbid anything that doesn't start with a _
super
(
Switch
,
self
).
__setattr__
(
name
,
value
)
super
(
Switch
,
self
).
__setattr__
(
name
,
value
)
return
return
# Set ports
# Set ports
...
@@ -421,7 +456,7 @@ class Switch(ExternalInterface):
...
@@ -421,7 +456,7 @@ class Switch(ExternalInterface):
if
self
.
_check_port
(
i
.
index
):
if
self
.
_check_port
(
i
.
index
):
setattr
(
i
,
name
,
value
)
setattr
(
i
,
name
,
value
)
# Set bridge
# Set bridge
iface
=
nemu
.
iproute
.
bridge
(
index
=
self
.
index
)
iface
=
nemu
.
iproute
.
bridge
(
index
=
self
.
index
)
setattr
(
iface
,
name
,
value
)
setattr
(
iface
,
name
,
value
)
nemu
.
iproute
.
set_bridge
(
iface
)
nemu
.
iproute
.
set_bridge
(
iface
)
...
@@ -460,7 +495,7 @@ class Switch(ExternalInterface):
...
@@ -460,7 +495,7 @@ class Switch(ExternalInterface):
return
True
return
True
# else
# else
warning
(
"Switch(0x%x): Port (index = %d) went away."
%
(
id
(
self
),
warning
(
"Switch(0x%x): Port (index = %d) went away."
%
(
id
(
self
),
port_index
))
port_index
))
del
self
.
_ports
[
port_index
]
del
self
.
_ports
[
port_index
]
return
False
return
False
...
@@ -472,12 +507,12 @@ class Switch(ExternalInterface):
...
@@ -472,12 +507,12 @@ class Switch(ExternalInterface):
self
.
_apply_parameters
({},
iface
.
control
)
self
.
_apply_parameters
({},
iface
.
control
)
del
self
.
_ports
[
iface
.
control
.
index
]
del
self
.
_ports
[
iface
.
control
.
index
]
def
set_parameters
(
self
,
bandwidth
=
None
,
def
set_parameters
(
self
,
bandwidth
=
None
,
delay
=
None
,
delay_jitter
=
None
,
delay
=
None
,
delay_jitter
=
None
,
delay_correlation
=
None
,
delay_distribution
=
None
,
delay_correlation
=
None
,
delay_distribution
=
None
,
loss
=
None
,
loss_correlation
=
None
,
loss
=
None
,
loss_correlation
=
None
,
dup
=
None
,
dup_correlation
=
None
,
dup
=
None
,
dup_correlation
=
None
,
corrupt
=
None
,
corrupt_correlation
=
None
):
corrupt
=
None
,
corrupt_correlation
=
None
):
"""Set the parameters that control the link characteristics. For the
"""Set the parameters that control the link characteristics. For the
description of each, refer to netem documentation:
description of each, refer to netem documentation:
http://www.linuxfoundation.org/collaborate/workgroups/networking/netem
http://www.linuxfoundation.org/collaborate/workgroups/networking/netem
...
@@ -492,13 +527,13 @@ class Switch(ExternalInterface):
...
@@ -492,13 +527,13 @@ class Switch(ExternalInterface):
`dup_correlation', `corrupt', and `corrupt_correlation' take a
`dup_correlation', `corrupt', and `corrupt_correlation' take a
percentage value in the form of a number between 0 and 1. (50% is
percentage value in the form of a number between 0 and 1. (50% is
passed as 0.5)."""
passed as 0.5)."""
parameters
=
dict
(
bandwidth
=
bandwidth
,
parameters
=
dict
(
bandwidth
=
bandwidth
,
delay
=
delay
,
delay_jitter
=
delay_jitter
,
delay
=
delay
,
delay_jitter
=
delay_jitter
,
delay_correlation
=
delay_correlation
,
delay_correlation
=
delay_correlation
,
delay_distribution
=
delay_distribution
,
delay_distribution
=
delay_distribution
,
loss
=
loss
,
loss_correlation
=
loss_correlation
,
loss
=
loss
,
loss_correlation
=
loss_correlation
,
dup
=
dup
,
dup_correlation
=
dup_correlation
,
dup
=
dup
,
dup_correlation
=
dup_correlation
,
corrupt
=
corrupt
,
corrupt_correlation
=
corrupt_correlation
)
corrupt
=
corrupt
,
corrupt_correlation
=
corrupt_correlation
)
try
:
try
:
self
.
_apply_parameters
(
parameters
)
self
.
_apply_parameters
(
parameters
)
except
:
except
:
...
@@ -506,7 +541,6 @@ class Switch(ExternalInterface):
...
@@ -506,7 +541,6 @@ class Switch(ExternalInterface):
raise
raise
self
.
_parameters
=
parameters
self
.
_parameters
=
parameters
def
_apply_parameters
(
self
,
parameters
,
port
=
None
):
def
_apply_parameters
(
self
,
parameters
,
port
=
None
):
for
i
in
[
port
]
if
port
else
list
(
self
.
_ports
.
values
()):
for
i
in
[
port
]
if
port
else
list
(
self
.
_ports
.
values
()):
nemu
.
iproute
.
set_tc
(
i
.
index
,
**
parameters
)
nemu
.
iproute
.
set_tc
(
i
.
index
,
**
parameters
)
src/nemu/iproute.py
View file @
033ce168
...
@@ -24,7 +24,10 @@ import re
...
@@ -24,7 +24,10 @@ import re
import
socket
import
socket
import
struct
import
struct
import
sys
import
sys
from
typing
import
TypeVar
,
Callable
,
Literal
from
attr
import
evolve
from
attrs
import
define
,
setters
,
field
import
six
import
six
from
nemu.environ
import
*
from
nemu.environ
import
*
...
@@ -46,18 +49,21 @@ def _any_to_bool(any):
...
@@ -46,18 +49,21 @@ def _any_to_bool(any):
return
any
!=
""
return
any
!=
""
return
bool
(
any
)
return
bool
(
any
)
def
_positive
(
val
):
def
_positive
(
val
):
v
=
int
(
val
)
v
=
int
(
val
)
if
v
<=
0
:
if
v
<=
0
:
raise
ValueError
(
"Invalid value: %d"
%
v
)
raise
ValueError
(
"Invalid value: %d"
%
v
)
return
v
return
v
def
_non_empty_str
(
val
):
def
_non_empty_str
(
val
):
if
val
==
""
:
if
val
==
""
:
return
None
return
None
else
:
else
:
return
str
(
val
)
return
str
(
val
)
def
_fix_lladdr
(
addr
):
def
_fix_lladdr
(
addr
):
foo
=
addr
.
lower
()
foo
=
addr
.
lower
()
if
":"
in
addr
:
if
":"
in
addr
:
...
@@ -77,106 +83,100 @@ def _fix_lladdr(addr):
...
@@ -77,106 +83,100 @@ def _fix_lladdr(addr):
# Glue
# Glue
return
":"
.
join
(
m
.
groups
())
return
":"
.
join
(
m
.
groups
())
def
_make_getter
(
attr
,
conv
=
lambda
x
:
x
):
def
_make_getter
(
attr
,
conv
=
lambda
x
:
x
):
def
getter
(
self
):
def
getter
(
self
):
return
conv
(
getattr
(
self
,
attr
))
return
conv
(
getattr
(
self
,
attr
))
return
getter
return
getter
def
_make_setter
(
attr
,
conv
=
lambda
x
:
x
):
def
_make_setter
(
attr
,
conv
=
lambda
x
:
x
):
def
setter
(
self
,
value
):
def
setter
(
self
,
value
):
if
value
==
None
:
if
value
is
None
:
setattr
(
self
,
attr
,
None
)
setattr
(
self
,
attr
,
None
)
else
:
else
:
setattr
(
self
,
attr
,
conv
(
value
))
setattr
(
self
,
attr
,
conv
(
value
))
return
setter
return
setter
T
=
TypeVar
(
"T"
)
U
=
TypeVar
(
"U"
)
def
_if_any
(
conv
:
Callable
[[
T
],
U
]):
def
c
(
val
:
T
)
->
U
:
if
val
is
None
:
return
None
else
:
return
conv
(
val
)
return
c
# classes for internal use
# classes for internal use
class
interface
(
object
):
@
define
(
repr
=
False
)
class
interface
:
"""Class for internal use. It is mostly a data container used to easily
"""Class for internal use. It is mostly a data container used to easily
pass information around; with some convenience methods."""
pass information around; with some convenience methods."""
# information for other parts of the code
# information for other parts of the code
changeable_attributes
=
[
"name"
,
"mtu"
,
"lladdr"
,
"broadcast"
,
"up"
,
changeable_attributes
=
[
"name"
,
"mtu"
,
"lladdr"
,
"broadcast"
,
"up"
,
"multicast"
,
"arp"
]
"multicast"
,
"arp"
]
# Index should be read-only
index
:
int
=
field
(
default
=
None
,
converter
=
_if_any
(
_positive
),
on_setattr
=
setters
.
frozen
)
index
=
property
(
_make_getter
(
"_index"
))
name
:
str
=
field
(
default
=
None
)
up
=
property
(
_make_getter
(
"_up"
),
_make_setter
(
"_up"
,
_any_to_bool
))
up
:
bool
=
field
(
default
=
None
,
converter
=
_if_any
(
_any_to_bool
))
mtu
=
property
(
_make_getter
(
"_mtu"
),
_make_setter
(
"_mtu"
,
_positive
))
mtu
:
int
=
field
(
default
=
None
,
converter
=
_if_any
(
_positive
))
lladdr
=
property
(
_make_getter
(
"_lladdr"
),
lladdr
:
str
=
field
(
default
=
None
,
converter
=
_if_any
(
_fix_lladdr
))
_make_setter
(
"_lladdr"
,
_fix_lladdr
))
broadcast
:
str
=
field
(
default
=
None
)
arp
=
property
(
_make_getter
(
"_arp"
),
_make_setter
(
"_arp"
,
_any_to_bool
))
multicast
:
bool
=
field
(
default
=
None
,
converter
=
_if_any
(
_any_to_bool
))
multicast
=
property
(
_make_getter
(
"_mc"
),
_make_setter
(
"_mc"
,
_any_to_bool
))
arp
:
bool
=
field
(
default
=
None
,
converter
=
_if_any
(
_any_to_bool
))
def
__init__
(
self
,
index
=
None
,
name
=
None
,
up
=
None
,
mtu
=
None
,
lladdr
=
None
,
broadcast
=
None
,
multicast
=
None
,
arp
=
None
):
self
.
_index
=
_positive
(
index
)
if
index
is
not
None
else
None
self
.
name
=
name
self
.
up
=
up
self
.
mtu
=
mtu
self
.
lladdr
=
lladdr
self
.
broadcast
=
broadcast
self
.
multicast
=
multicast
self
.
arp
=
arp
def
__repr__
(
self
):
def
__repr__
(
self
):
s
=
"%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
s
=
"%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
s
+=
"broadcast = %s, multicast = %s, arp = %s)"
s
+=
"broadcast = %s, multicast = %s, arp = %s)"
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
self
.
index
.
__repr__
(),
self
.
name
.
__repr__
(),
self
.
index
.
__repr__
(),
self
.
name
.
__repr__
(),
self
.
up
.
__repr__
(),
self
.
mtu
.
__repr__
(),
self
.
up
.
__repr__
(),
self
.
mtu
.
__repr__
(),
self
.
lladdr
.
__repr__
(),
self
.
broadcast
.
__repr__
(),
self
.
lladdr
.
__repr__
(),
self
.
broadcast
.
__repr__
(),
self
.
multicast
.
__repr__
(),
self
.
arp
.
__repr__
())
self
.
multicast
.
__repr__
(),
self
.
arp
.
__repr__
())
def
__sub__
(
self
,
o
):
def
__sub__
(
self
,
o
):
"""Compare attributes and return a new object with just the attributes
"""Compare attributes and return a new object with just the attributes
that differ set (with the value they have in the first operand). The
that differ set (with the value they have in the first operand). The
index remains equal to the first operand."""
index remains equal to the first operand."""
name
=
None
if
self
.
name
==
o
.
name
else
self
.
name
name
=
None
if
self
.
name
==
o
.
name
else
self
.
name
up
=
None
if
self
.
up
==
o
.
up
else
self
.
up
up
=
None
if
self
.
up
==
o
.
up
else
self
.
up
mtu
=
None
if
self
.
mtu
==
o
.
mtu
else
self
.
mtu
mtu
=
None
if
self
.
mtu
==
o
.
mtu
else
self
.
mtu
lladdr
=
None
if
self
.
lladdr
==
o
.
lladdr
else
self
.
lladdr
lladdr
=
None
if
self
.
lladdr
==
o
.
lladdr
else
self
.
lladdr
broadcast
=
None
if
self
.
broadcast
==
o
.
broadcast
else
self
.
broadcast
broadcast
=
None
if
self
.
broadcast
==
o
.
broadcast
else
self
.
broadcast
multicast
=
None
if
self
.
multicast
==
o
.
multicast
else
self
.
multicast
multicast
=
None
if
self
.
multicast
==
o
.
multicast
else
self
.
multicast
arp
=
None
if
self
.
arp
==
o
.
arp
else
self
.
arp
arp
=
None
if
self
.
arp
==
o
.
arp
else
self
.
arp
return
self
.
__class__
(
self
.
index
,
name
,
up
,
mtu
,
lladdr
,
broadcast
,
return
interface
(
self
.
index
,
name
,
up
,
mtu
,
lladdr
,
broadcast
,
multicast
,
arp
)
multicast
,
arp
)
def
copy
(
self
):
def
copy
(
self
):
return
copy
.
copy
(
self
)
return
copy
.
copy
(
self
)
@
define
(
repr
=
False
)
class
bridge
(
interface
):
class
bridge
(
interface
):
changeable_attributes
=
interface
.
changeable_attributes
+
[
"stp"
,
changeable_attributes
=
interface
.
changeable_attributes
+
[
"stp"
,
"forward_delay"
,
"hello_time"
,
"ageing_time"
,
"max_age"
]
"forward_delay"
,
"hello_time"
,
"ageing_time"
,
"max_age"
]
# Index should be read-only
stp
:
bool
=
field
(
default
=
None
,
converter
=
_if_any
(
_any_to_bool
))
stp
=
property
(
_make_getter
(
"_stp"
),
_make_setter
(
"_stp"
,
_any_to_bool
))
forward_delay
:
float
=
field
(
default
=
None
,
converter
=
_if_any
(
float
))
forward_delay
=
property
(
_make_getter
(
"_forward_delay"
),
hello_time
:
float
=
field
(
default
=
None
,
converter
=
_if_any
(
float
))
_make_setter
(
"_forward_delay"
,
float
))
ageing_time
:
float
=
field
(
default
=
None
,
converter
=
_if_any
(
float
))
hello_time
=
property
(
_make_getter
(
"_hello_time"
),
max_age
:
float
=
field
(
default
=
None
,
converter
=
_if_any
(
float
))
_make_setter
(
"_hello_time"
,
float
))
ageing_time
=
property
(
_make_getter
(
"_ageing_time"
),
_make_setter
(
"_ageing_time"
,
float
))
max_age
=
property
(
_make_getter
(
"_max_age"
),
_make_setter
(
"_max_age"
,
float
))
@
classmethod
@
classmethod
def
upgrade
(
cls
,
iface
,
*
kargs
,
**
kwargs
):
def
upgrade
(
cls
,
iface
,
*
kargs
,
**
kwargs
):
"""Upgrade a interface to a bridge."""
"""Upgrade a interface to a bridge."""
return
cls
(
iface
.
index
,
iface
.
name
,
iface
.
up
,
iface
.
mtu
,
iface
.
lladdr
,
return
cls
(
iface
.
index
,
iface
.
name
,
iface
.
up
,
iface
.
mtu
,
iface
.
lladdr
,
iface
.
broadcast
,
iface
.
multicast
,
iface
.
arp
,
*
kargs
,
**
kwargs
)
iface
.
broadcast
,
iface
.
multicast
,
iface
.
arp
,
*
kargs
,
**
kwargs
)
def
__init__
(
self
,
index
=
None
,
name
=
None
,
up
=
None
,
mtu
=
None
,
lladdr
=
None
,
broadcast
=
None
,
multicast
=
None
,
arp
=
None
,
stp
=
None
,
forward_delay
=
None
,
hello_time
=
None
,
ageing_time
=
None
,
max_age
=
None
):
super
(
bridge
,
self
).
__init__
(
index
,
name
,
up
,
mtu
,
lladdr
,
broadcast
,
multicast
,
arp
)
self
.
stp
=
stp
self
.
forward_delay
=
forward_delay
self
.
hello_time
=
hello_time
self
.
ageing_time
=
ageing_time
self
.
max_age
=
max_age
def
__repr__
(
self
):
def
__repr__
(
self
):
s
=
"%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
s
=
"%s.%s(index = %s, name = %s, up = %s, mtu = %s, lladdr = %s, "
...
@@ -184,33 +184,40 @@ class bridge(interface):
...
@@ -184,33 +184,40 @@ class bridge(interface):
s
+=
"forward_delay = %s, hello_time = %s, ageing_time = %s, "
s
+=
"forward_delay = %s, hello_time = %s, ageing_time = %s, "
s
+=
"max_age = %s)"
s
+=
"max_age = %s)"
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
self
.
index
.
__repr__
(),
self
.
name
.
__repr__
(),
self
.
index
.
__repr__
(),
self
.
name
.
__repr__
(),
self
.
up
.
__repr__
(),
self
.
mtu
.
__repr__
(),
self
.
up
.
__repr__
(),
self
.
mtu
.
__repr__
(),
self
.
lladdr
.
__repr__
(),
self
.
broadcast
.
__repr__
(),
self
.
lladdr
.
__repr__
(),
self
.
broadcast
.
__repr__
(),
self
.
multicast
.
__repr__
(),
self
.
arp
.
__repr__
(),
self
.
multicast
.
__repr__
(),
self
.
arp
.
__repr__
(),
self
.
stp
.
__repr__
(),
self
.
forward_delay
.
__repr__
(),
self
.
stp
.
__repr__
(),
self
.
forward_delay
.
__repr__
(),
self
.
hello_time
.
__repr__
(),
self
.
ageing_time
.
__repr__
(),
self
.
hello_time
.
__repr__
(),
self
.
ageing_time
.
__repr__
(),
self
.
max_age
.
__repr__
())
self
.
max_age
.
__repr__
())
def
__sub__
(
self
,
o
):
def
__sub__
(
self
,
o
):
r
=
super
(
bridge
,
self
).
__sub__
(
o
)
r
=
bridge
.
upgrade
(
super
().
__sub__
(
o
)
)
if
type
(
o
)
==
interface
:
if
type
(
o
)
==
interface
:
return
r
return
r
r
.
stp
=
None
if
self
.
stp
==
o
.
stp
else
self
.
stp
r
.
stp
=
None
if
self
.
stp
==
o
.
stp
else
self
.
stp
r
.
hello_time
=
None
if
self
.
hello_time
==
o
.
hello_time
else
\
r
.
hello_time
=
None
if
self
.
hello_time
==
o
.
hello_time
else
\
self
.
hello_time
self
.
hello_time
r
.
forward_delay
=
None
if
self
.
forward_delay
==
o
.
forward_delay
else
\
r
.
forward_delay
=
None
if
self
.
forward_delay
==
o
.
forward_delay
else
\
self
.
forward_delay
self
.
forward_delay
r
.
ageing_time
=
None
if
self
.
ageing_time
==
o
.
ageing_time
else
\
r
.
ageing_time
=
None
if
self
.
ageing_time
==
o
.
ageing_time
else
\
self
.
ageing_time
self
.
ageing_time
r
.
max_age
=
None
if
self
.
max_age
==
o
.
max_age
else
self
.
max_age
r
.
max_age
=
None
if
self
.
max_age
==
o
.
max_age
else
self
.
max_age
return
r
return
r
class
address
(
object
):
class
address
(
object
):
"""Class for internal use. It is mostly a data container used to easily
"""Class for internal use. It is mostly a data container used to easily
pass information around; with some convenience methods. __eq__ and
pass information around; with some convenience methods. __eq__ and
__hash__ are defined just to be able to easily find duplicated
__hash__ are defined just to be able to easily find duplicated
addresses."""
addresses."""
def
__init__
(
self
,
address
:
str
,
prefix_len
:
int
,
family
:
socket
.
AddressFamily
):
self
.
address
=
address
self
.
prefix_len
=
int
(
prefix_len
)
self
.
family
=
family
# broadcast is not taken into account for differentiating addresses
# broadcast is not taken into account for differentiating addresses
def
__eq__
(
self
,
o
):
def
__eq__
(
self
,
o
):
if
not
isinstance
(
o
,
address
):
if
not
isinstance
(
o
,
address
):
...
@@ -220,52 +227,50 @@ class address(object):
...
@@ -220,52 +227,50 @@ class address(object):
def
__hash__
(
self
):
def
__hash__
(
self
):
h
=
(
self
.
address
.
__hash__
()
^
self
.
prefix_len
.
__hash__
()
^
h
=
(
self
.
address
.
__hash__
()
^
self
.
prefix_len
.
__hash__
()
^
self
.
family
.
__hash__
())
self
.
family
.
__hash__
())
return
h
return
h
class
ipv4address
(
address
):
class
ipv4address
(
address
):
def
__init__
(
self
,
address
,
prefix_len
,
broadcast
):
def
__init__
(
self
,
address
:
str
,
prefix_len
:
int
,
broadcast
):
self
.
address
=
address
super
().
__init__
(
address
,
prefix_len
,
socket
.
AF_INET
)
self
.
prefix_len
=
int
(
prefix_len
)
self
.
broadcast
=
broadcast
self
.
broadcast
=
broadcast
self
.
family
=
socket
.
AF_INET
def
__repr__
(
self
):
def
__repr__
(
self
):
s
=
"%s.%s(address = %s, prefix_len = %d, broadcast = %s)"
s
=
"%s.%s(address = %s, prefix_len = %d, broadcast = %s)"
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
self
.
address
.
__repr__
(),
self
.
prefix_len
,
self
.
address
.
__repr__
(),
self
.
prefix_len
,
self
.
broadcast
.
__repr__
())
self
.
broadcast
.
__repr__
())
class
ipv6address
(
address
):
class
ipv6address
(
address
):
def
__init__
(
self
,
address
,
prefix_len
):
def
__init__
(
self
,
address
:
str
,
prefix_len
:
int
):
self
.
address
=
address
super
().
__init__
(
address
,
prefix_len
,
socket
.
AF_INET6
)
self
.
prefix_len
=
int
(
prefix_len
)
self
.
family
=
socket
.
AF_INET6
def
__repr__
(
self
):
def
__repr__
(
self
):
s
=
"%s.%s(address = %s, prefix_len = %d)"
s
=
"%s.%s(address = %s, prefix_len = %d)"
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
self
.
address
.
__repr__
(),
self
.
prefix_len
)
self
.
address
.
__repr__
(),
self
.
prefix_len
)
class
route
(
object
):
class
route
(
object
):
tipes
=
[
"unicast"
,
"local"
,
"broadcast"
,
"multicast"
,
"throw"
,
tipes
=
[
"unicast"
,
"local"
,
"broadcast"
,
"multicast"
,
"throw"
,
"unreachable"
,
"prohibit"
,
"blackhole"
,
"nat"
]
"unreachable"
,
"prohibit"
,
"blackhole"
,
"nat"
]
tipe
=
property
(
_make_getter
(
"_tipe"
,
tipes
.
__getitem__
),
tipe
=
property
(
_make_getter
(
"_tipe"
,
tipes
.
__getitem__
),
_make_setter
(
"_tipe"
,
tipes
.
index
))
_make_setter
(
"_tipe"
,
tipes
.
index
))
prefix
=
property
(
_make_getter
(
"_prefix"
),
prefix
=
property
(
_make_getter
(
"_prefix"
),
_make_setter
(
"_prefix"
,
_non_empty_str
))
_make_setter
(
"_prefix"
,
_non_empty_str
))
prefix_len
=
property
(
_make_getter
(
"_plen"
),
prefix_len
=
property
(
_make_getter
(
"_plen"
),
lambda
s
,
v
:
setattr
(
s
,
"_plen"
,
int
(
v
or
0
)))
lambda
s
,
v
:
setattr
(
s
,
"_plen"
,
int
(
v
or
0
)))
nexthop
=
property
(
_make_getter
(
"_nexthop"
),
nexthop
=
property
(
_make_getter
(
"_nexthop"
),
_make_setter
(
"_nexthop"
,
_non_empty_str
))
_make_setter
(
"_nexthop"
,
_non_empty_str
))
interface
=
property
(
_make_getter
(
"_interface"
),
interface
=
property
(
_make_getter
(
"_interface"
),
_make_setter
(
"_interface"
,
_positive
))
_make_setter
(
"_interface"
,
_positive
))
metric
=
property
(
_make_getter
(
"_metric"
),
metric
=
property
(
_make_getter
(
"_metric"
),
lambda
s
,
v
:
setattr
(
s
,
"_metric"
,
int
(
v
or
0
)))
lambda
s
,
v
:
setattr
(
s
,
"_metric"
,
int
(
v
or
0
)))
def
__init__
(
self
,
tipe
=
"unicast"
,
prefix
=
None
,
prefix_len
=
0
,
def
__init__
(
self
,
tipe
=
"unicast"
,
prefix
=
None
,
prefix_len
=
0
,
nexthop
=
None
,
interface
=
None
,
metric
=
0
):
nexthop
=
None
,
interface
=
None
,
metric
=
0
):
self
.
tipe
=
tipe
self
.
tipe
=
tipe
self
.
prefix
=
prefix
self
.
prefix
=
prefix
self
.
prefix_len
=
prefix_len
self
.
prefix_len
=
prefix_len
...
@@ -278,9 +283,9 @@ class route(object):
...
@@ -278,9 +283,9 @@ class route(object):
s
=
"%s.%s(tipe = %s, prefix = %s, prefix_len = %s, nexthop = %s, "
s
=
"%s.%s(tipe = %s, prefix = %s, prefix_len = %s, nexthop = %s, "
s
+=
"interface = %s, metric = %s)"
s
+=
"interface = %s, metric = %s)"
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
return
s
%
(
self
.
__module__
,
self
.
__class__
.
__name__
,
self
.
tipe
.
__repr__
(),
self
.
prefix
.
__repr__
(),
self
.
tipe
.
__repr__
(),
self
.
prefix
.
__repr__
(),
self
.
prefix_len
.
__repr__
(),
self
.
nexthop
.
__repr__
(),
self
.
prefix_len
.
__repr__
(),
self
.
nexthop
.
__repr__
(),
self
.
interface
.
__repr__
(),
self
.
metric
.
__repr__
())
self
.
interface
.
__repr__
(),
self
.
metric
.
__repr__
())
def
__eq__
(
self
,
o
):
def
__eq__
(
self
,
o
):
if
not
isinstance
(
o
,
route
):
if
not
isinstance
(
o
,
route
):
...
@@ -289,20 +294,22 @@ class route(object):
...
@@ -289,20 +294,22 @@ class route(object):
self
.
prefix_len
==
o
.
prefix_len
and
self
.
nexthop
==
o
.
nexthop
self
.
prefix_len
==
o
.
prefix_len
and
self
.
nexthop
==
o
.
nexthop
and
self
.
interface
==
o
.
interface
and
self
.
metric
==
o
.
metric
)
and
self
.
interface
==
o
.
interface
and
self
.
metric
==
o
.
metric
)
# helpers
# helpers
def
_get_if_name
(
iface
):
def
_get_if_name
(
iface
:
interface
|
int
|
str
):
if
isinstance
(
iface
,
interface
):
if
isinstance
(
iface
,
interface
):
if
iface
.
name
!=
None
:
if
iface
.
name
is
not
None
:
return
iface
.
name
return
iface
.
name
if
isinstance
(
iface
,
str
):
if
isinstance
(
iface
,
str
):
return
iface
return
iface
return
get_if
(
iface
).
name
return
get_if
(
iface
).
name
# XXX: ideally this should be replaced by netlink communication
# XXX: ideally this should be replaced by netlink communication
# Interface handling
# Interface handling
# FIXME: try to lower the amount of calls to retrieve data!!
# FIXME: try to lower the amount of calls to retrieve data!!
def
get_if_data
():
def
get_if_data
()
->
tuple
[
dict
[
int
,
interface
],
dict
[
str
,
interface
]]
:
"""Gets current interface information. Returns a tuple (byidx, bynam) in
"""Gets current interface information. Returns a tuple (byidx, bynam) in
which each element is a dictionary with the same data, but using different
which each element is a dictionary with the same data, but using different
keys: interface indexes and interface names.
keys: interface indexes and interface names.
...
@@ -323,21 +330,22 @@ def get_if_data():
...
@@ -323,21 +330,22 @@ def get_if_data():
r'
brd
([
0
-
9
a
-
f
:]
+
))
?
', line)
r'
brd
([
0
-
9
a
-
f
:]
+
))
?
', line)
flags = match.group(3).split(",")
flags = match.group(3).split(",")
i = interface(
i = interface(
index =
match.group(1),
index=
match.group(1),
name =
match.group(2),
name=
match.group(2),
up =
"UP" in flags,
up=
"UP" in flags,
mtu =
match.group(4),
mtu=
match.group(4),
lladdr =
match.group(5),
lladdr=
match.group(5),
arp =
not ("NOARP" in flags),
arp=
not ("NOARP" in flags),
broadcast =
match.group(6),
broadcast=
match.group(6),
multicast =
"MULTICAST" in flags)
multicast=
"MULTICAST" in flags)
byidx[idx] = bynam[i.name] = i
byidx[idx] = bynam[i.name] = i
return byidx, bynam
return byidx, bynam
def get_if(iface):
def get_if(iface: interface | int | str) -> interface:
ifdata = get_if_data()
ifdata = get_if_data()
if isinstance(iface, interface):
if isinstance(iface, interface):
if iface.index
!=
None:
if iface.index
is not
None:
return ifdata[0][iface.index]
return ifdata[0][iface.index]
else:
else:
return ifdata[1][iface.name]
return ifdata[1][iface.name]
...
@@ -345,7 +353,8 @@ def get_if(iface):
...
@@ -345,7 +353,8 @@ def get_if(iface):
return ifdata[0][iface]
return ifdata[0][iface]
return ifdata[1][iface]
return ifdata[1][iface]
def create_if_pair(if1, if2):
def create_if_pair(if1: interface, if2: interface) -> tuple[interface, interface]:
assert if1.name and if2.name
assert if1.name and if2.name
cmd = [[], []]
cmd = [[], []]
...
@@ -375,22 +384,24 @@ def create_if_pair(if1, if2):
...
@@ -375,22 +384,24 @@ def create_if_pair(if1, if2):
interfaces = get_if_data()[1]
interfaces = get_if_data()[1]
return interfaces[if1.name], interfaces[if2.name]
return interfaces[if1.name], interfaces[if2.name]
def del_if(iface):
def del_if(iface):
ifname = _get_if_name(iface)
ifname = _get_if_name(iface)
execute([IP_PATH, "link", "del", ifname])
execute([IP_PATH, "link", "del", ifname])
def set_if(iface, recover = True):
def do_cmds(cmds, orig_iface):
def set_if(iface: interface, recover=True):
def do_cmds(cmds: list[list[str]], orig_iface: interface):
for c in cmds:
for c in cmds:
try:
try:
execute(c)
execute(c)
except:
except:
if recover:
if recover:
set_if(orig_iface, recover
= False)
# rollback
set_if(orig_iface, recover
=False)
# rollback
raise
raise
orig_iface = get_if(iface)
orig_iface = get_if(iface)
diff = iface - orig_iface # Only set what'
s
needed
diff = iface - orig_iface
# Only set what'
s
needed
# Name goes first
# Name goes first
if
diff
.
name
:
if
diff
.
name
:
...
@@ -410,26 +421,28 @@ def set_if(iface, recover = True):
...
@@ -410,26 +421,28 @@ def set_if(iface, recover = True):
# iface needs to be down
# iface needs to be down
cmds
.
append
(
_ils
+
[
"down"
])
cmds
.
append
(
_ils
+
[
"down"
])
cmds
.
append
(
_ils
+
[
"address"
,
diff
.
lladdr
])
cmds
.
append
(
_ils
+
[
"address"
,
diff
.
lladdr
])
if
orig_iface
.
up
and
diff
.
up
==
None
:
if
orig_iface
.
up
and
diff
.
up
is
None
:
# restore if it was up and it's not going to be set later
# restore if it was up and it's not going to be set later
cmds
.
append
(
_ils
+
[
"up"
])
cmds
.
append
(
_ils
+
[
"up"
])
if
diff
.
mtu
:
if
diff
.
mtu
:
cmds
.
append
(
_ils
+
[
"mtu"
,
str
(
diff
.
mtu
)])
cmds
.
append
(
_ils
+
[
"mtu"
,
str
(
diff
.
mtu
)])
if
diff
.
broadcast
:
if
diff
.
broadcast
:
cmds
.
append
(
_ils
+
[
"broadcast"
,
diff
.
broadcast
])
cmds
.
append
(
_ils
+
[
"broadcast"
,
diff
.
broadcast
])
if
diff
.
multicast
!=
None
:
if
diff
.
multicast
is
not
None
:
cmds
.
append
(
_ils
+
[
"multicast"
,
"on"
if
diff
.
multicast
else
"off"
])
cmds
.
append
(
_ils
+
[
"multicast"
,
"on"
if
diff
.
multicast
else
"off"
])
if
diff
.
arp
!=
None
:
if
diff
.
arp
is
not
None
:
cmds
.
append
(
_ils
+
[
"arp"
,
"on"
if
diff
.
arp
else
"off"
])
cmds
.
append
(
_ils
+
[
"arp"
,
"on"
if
diff
.
arp
else
"off"
])
if
diff
.
up
!=
None
:
if
diff
.
up
is
not
None
:
cmds
.
append
(
_ils
+
[
"up"
if
diff
.
up
else
"down"
])
cmds
.
append
(
_ils
+
[
"up"
if
diff
.
up
else
"down"
])
do_cmds
(
cmds
,
orig_iface
)
do_cmds
(
cmds
,
orig_iface
)
def
change_netns
(
iface
,
netns
):
def
change_netns
(
iface
,
netns
):
ifname
=
_get_if_name
(
iface
)
ifname
=
_get_if_name
(
iface
)
execute
([
IP_PATH
,
"link"
,
"set"
,
"dev"
,
ifname
,
"netns"
,
str
(
netns
)])
execute
([
IP_PATH
,
"link"
,
"set"
,
"dev"
,
ifname
,
"netns"
,
str
(
netns
)])
# Address handling
# Address handling
def
get_addr_data
():
def
get_addr_data
():
...
@@ -459,16 +472,16 @@ def get_addr_data():
...
@@ -459,16 +472,16 @@ def get_addr_data():
match
=
re
.
search
(
r'^\
s*i
net ([0-9.]+)/(\
d+)(?:
brd ([0-9.]+))?'
,
line
)
match
=
re
.
search
(
r'^\
s*i
net ([0-9.]+)/(\
d+)(?:
brd ([0-9.]+))?'
,
line
)
if
match
:
if
match
:
bynam
[
current
].
append
(
ipv4address
(
bynam
[
current
].
append
(
ipv4address
(
address
=
match
.
group
(
1
),
address
=
match
.
group
(
1
),
prefix_len
=
match
.
group
(
2
),
prefix_len
=
match
.
group
(
2
),
broadcast
=
match
.
group
(
3
)))
broadcast
=
match
.
group
(
3
)))
continue
continue
match
=
re
.
search
(
r'^\
s*i
net6 ([0-9a-f:]+)/(\
d+)
', line)
match
=
re
.
search
(
r'^\
s*i
net6 ([0-9a-f:]+)/(\
d+)
', line)
if match:
if match:
bynam[current].append(ipv6address(
bynam[current].append(ipv6address(
address
=
match.group(1),
address
=
match.group(1),
prefix_len
=
match.group(2)))
prefix_len
=
match.group(2)))
continue
continue
# Extra info, ignored.
# Extra info, ignored.
...
@@ -476,26 +489,29 @@ def get_addr_data():
...
@@ -476,26 +489,29 @@ def get_addr_data():
return byidx, bynam
return byidx, bynam
def add_addr(iface, address):
def add_addr(iface, address):
ifname = _get_if_name(iface)
ifname = _get_if_name(iface)
addresses = get_addr_data()[1][ifname]
addresses = get_addr_data()[1][ifname]
assert address not in addresses
assert address not in addresses
cmd = [IP_PATH, "addr", "add", "dev", ifname, "local",
cmd = [IP_PATH, "addr", "add", "dev", ifname, "local",
"%s/%d" % (address.address, int(address.prefix_len))]
"%s/%d" % (address.address, int(address.prefix_len))]
if hasattr(address, "broadcast"):
if hasattr(address, "broadcast"):
cmd += ["broadcast", address.broadcast if address.broadcast else "+"]
cmd += ["broadcast", address.broadcast if address.broadcast else "+"]
execute(cmd)
execute(cmd)
def del_addr(iface, address):
def del_addr(iface, address):
ifname = _get_if_name(iface)
ifname = _get_if_name(iface)
addresses = get_addr_data()[1][ifname]
addresses = get_addr_data()[1][ifname]
assert address in addresses
assert address in addresses
cmd = [IP_PATH, "addr", "del", "dev", ifname, "local",
cmd = [IP_PATH, "addr", "del", "dev", ifname, "local",
"%s/%d" % (address.address, int(address.prefix_len))]
"%s/%d" % (address.address, int(address.prefix_len))]
execute(cmd)
execute(cmd)
# Bridge handling
# Bridge handling
def _sysfs_read_br(brname):
def _sysfs_read_br(brname):
def readval(fname):
def readval(fname):
...
@@ -509,12 +525,13 @@ def _sysfs_read_br(brname):
...
@@ -509,12 +525,13 @@ def _sysfs_read_br(brname):
except:
except:
return None
return None
return dict(
return dict(
stp = readval(p + "stp_state"),
stp=readval(p + "stp_state"),
forward_delay = float(readval(p + "forward_delay")) / 100,
forward_delay=float(readval(p + "forward_delay")) / 100,
hello_time = float(readval(p + "hello_time")) / 100,
hello_time=float(readval(p + "hello_time")) / 100,
ageing_time = float(readval(p + "ageing_time")) / 100,
ageing_time=float(readval(p + "ageing_time")) / 100,
max_age = float(readval(p + "max_age")) / 100,
max_age=float(readval(p + "max_age")) / 100,
ports = os.listdir(p2))
ports=os.listdir(p2))
def get_bridge_data():
def get_bridge_data():
# brctl stinks too much; it is better to directly use sysfs, it is
# brctl stinks too much; it is better to directly use sysfs, it is
...
@@ -525,24 +542,26 @@ def get_bridge_data():
...
@@ -525,24 +542,26 @@ def get_bridge_data():
ifdata = get_if_data()
ifdata = get_if_data()
for iface in ifdata[0].values():
for iface in ifdata[0].values():
brdata = _sysfs_read_br(iface.name)
brdata = _sysfs_read_br(iface.name)
if brdata
==
None:
if brdata
is
None:
continue
continue
ports[iface.index] = [ifdata[1][x].index for x in brdata["ports"]]
ports[iface.index] = [ifdata[1][x].index for x in brdata["ports"]]
del brdata["ports"]
del brdata["ports"]
bynam[iface.name] = byidx[iface.index] = \
bynam[iface.name] = byidx[iface.index] = \
bridge.upgrade(iface, **brdata)
bridge.upgrade(iface, **brdata)
return byidx, bynam, ports
return byidx, bynam, ports
def get_bridge(br):
def get_bridge(br):
iface = get_if(br)
iface = get_if(br)
brdata = _sysfs_read_br(iface.name)
brdata = _sysfs_read_br(iface.name)
#ports = [ifdata[1][x].index for x in brdata["ports"]]
#
ports = [ifdata[1][x].index for x in brdata["ports"]]
del brdata["ports"]
del brdata["ports"]
return bridge.upgrade(iface, **brdata)
return bridge.upgrade(iface, **brdata)
def create_bridge(br):
def create_bridge(br):
if isinstance(br, str):
if isinstance(br, str):
br = interface(name
=
br)
br = interface(name
=
br)
assert br.name
assert br.name
execute([BRCTL_PATH, "addbr", br.name])
execute([BRCTL_PATH, "addbr", br.name])
try:
try:
...
@@ -556,58 +575,64 @@ def create_bridge(br):
...
@@ -556,58 +575,64 @@ def create_bridge(br):
six.reraise(t, v, bt)
six.reraise(t, v, bt)
return get_if_data()[1][br.name]
return get_if_data()[1][br.name]
def del_bridge(br):
def del_bridge(br):
brname = _get_if_name(br)
brname = _get_if_name(br)
execute([BRCTL_PATH, "delbr", brname])
execute([BRCTL_PATH, "delbr", brname])
def set_bridge(br, recover = True):
def set_bridge(br, recover=True):
def saveval(fname, val):
def saveval(fname, val):
f = open(fname, "w")
f = open(fname, "w")
f.write(str(val))
f.write(str(val))
f.close()
f.close()
def do_cmds(basename, cmds, orig_br):
def do_cmds(basename, cmds, orig_br):
for n, v in cmds:
for n, v in cmds:
try:
try:
saveval(basename + n, v)
saveval(basename + n, v)
except:
except:
if recover:
if recover:
set_bridge(orig_br, recover
= False)
# rollback
set_bridge(orig_br, recover
=False)
# rollback
set_if(orig_br, recover
= False)
# rollback
set_if(orig_br, recover
=False)
# rollback
raise
raise
orig_br = get_bridge(br)
orig_br = get_bridge(br)
diff = br - orig_br # Only set what'
s
needed
diff = br - orig_br
# Only set what'
s
needed
cmds
=
[]
cmds
=
[]
if
diff
.
stp
!=
None
:
if
diff
.
stp
is
not
None
:
cmds
.
append
((
"stp_state"
,
int
(
diff
.
stp
)))
cmds
.
append
((
"stp_state"
,
int
(
diff
.
stp
)))
if
diff
.
forward_delay
!=
None
:
if
diff
.
forward_delay
is
not
None
:
cmds
.
append
((
"forward_delay"
,
int
(
diff
.
forward_delay
)))
cmds
.
append
((
"forward_delay"
,
int
(
diff
.
forward_delay
)))
if
diff
.
hello_time
!=
None
:
if
diff
.
hello_time
is
not
None
:
cmds
.
append
((
"hello_time"
,
int
(
diff
.
hello_time
)))
cmds
.
append
((
"hello_time"
,
int
(
diff
.
hello_time
)))
if
diff
.
ageing_time
!=
None
:
if
diff
.
ageing_time
is
not
None
:
cmds
.
append
((
"ageing_time"
,
int
(
diff
.
ageing_time
)))
cmds
.
append
((
"ageing_time"
,
int
(
diff
.
ageing_time
)))
if
diff
.
max_age
!=
None
:
if
diff
.
max_age
is
not
None
:
cmds
.
append
((
"max_age"
,
int
(
diff
.
max_age
)))
cmds
.
append
((
"max_age"
,
int
(
diff
.
max_age
)))
set_if
(
diff
)
set_if
(
diff
)
name
=
diff
.
name
if
diff
.
name
!=
None
else
orig_br
.
name
name
=
diff
.
name
if
diff
.
name
is
not
None
else
orig_br
.
name
do_cmds
(
"/sys/class/net/%s/bridge/"
%
name
,
cmds
,
orig_br
)
do_cmds
(
"/sys/class/net/%s/bridge/"
%
name
,
cmds
,
orig_br
)
def
add_bridge_port
(
br
,
iface
):
def
add_bridge_port
(
br
,
iface
):
ifname
=
_get_if_name
(
iface
)
ifname
=
_get_if_name
(
iface
)
brname
=
_get_if_name
(
br
)
brname
=
_get_if_name
(
br
)
execute
([
BRCTL_PATH
,
"addif"
,
brname
,
ifname
])
execute
([
BRCTL_PATH
,
"addif"
,
brname
,
ifname
])
def
del_bridge_port
(
br
,
iface
):
def
del_bridge_port
(
br
,
iface
):
ifname
=
_get_if_name
(
iface
)
ifname
=
_get_if_name
(
iface
)
brname
=
_get_if_name
(
br
)
brname
=
_get_if_name
(
br
)
execute
([
BRCTL_PATH
,
"delif"
,
brname
,
ifname
])
execute
([
BRCTL_PATH
,
"delif"
,
brname
,
ifname
])
# Routing
# Routing
def
get_all_route_data
():
def
get_all_route_data
():
ipdata
=
backticks
([
IP_PATH
,
"-o"
,
"route"
,
"list"
])
# "table", "all"
ipdata
=
backticks
([
IP_PATH
,
"-o"
,
"route"
,
"list"
])
# "table", "all"
ipdata
+=
backticks
([
IP_PATH
,
"-o"
,
"-f"
,
"inet6"
,
"route"
,
"list"
])
ipdata
+=
backticks
([
IP_PATH
,
"-o"
,
"-f"
,
"inet6"
,
"route"
,
"list"
])
ifdata
=
get_if_data
()[
1
]
ifdata
=
get_if_data
()[
1
]
...
@@ -616,8 +641,8 @@ def get_all_route_data():
...
@@ -616,8 +641,8 @@ def get_all_route_data():
if
line
==
""
:
if
line
==
""
:
continue
continue
match
=
re
.
match
(
r'(?:(unicast|local|broadcast|multicast|throw|'
+
match
=
re
.
match
(
r'(?:(unicast|local|broadcast|multicast|throw|'
+
r'unreachable|prohibit|blackhole|nat) )?'
+
r'unreachable|prohibit|blackhole|nat) )?'
+
r'(\
S+)(?:
via (\
S+))? de
v (\
S+).*(?: me
tric (\
d+))?
', line)
r'(\
S+)(?:
via (\
S+))? de
v (\
S+).*(?: me
tric (\
d+))?
', line)
if not match:
if not match:
raise RuntimeError("Invalid output from `ip route'
:
`
%
s
'" % line)
raise RuntimeError("Invalid output from `ip route'
:
`
%
s
'" % line)
tipe = match.group(1) or "unicast"
tipe = match.group(1) or "unicast"
...
@@ -633,26 +658,30 @@ def get_all_route_data():
...
@@ -633,26 +658,30 @@ def get_all_route_data():
prefix = match.group(1)
prefix = match.group(1)
prefix_len = int(match.group(2) or 32)
prefix_len = int(match.group(2) or 32)
ret.append(route(tipe, prefix, prefix_len, nexthop, interface.index,
ret.append(route(tipe, prefix, prefix_len, nexthop, interface.index,
metric))
metric))
return ret
return ret
def get_route_data():
def get_route_data() -> list[route]:
# filter out non-unicast routes
# filter out non-unicast routes
return [x for x in get_all_route_data() if x.tipe == "unicast"]
return [x for x in get_all_route_data() if x.tipe == "unicast"]
def add_route(route):
def add_route(route: route):
# Cannot really test this
# Cannot really test this
#if route in get_all_route_data():
#
if route in get_all_route_data():
# raise ValueError("Route already exists")
# raise ValueError("Route already exists")
_add_del_route("add", route)
_add_del_route("add", route)
def del_route(route):
def del_route(route: route):
# Cannot really test this
# Cannot really test this
#if route not in get_all_route_data():
#
if route not in get_all_route_data():
# raise ValueError("Route does not exist")
# raise ValueError("Route does not exist")
_add_del_route("del", route)
_add_del_route("del", route)
def _add_del_route(action, route):
def _add_del_route(action: Literal["add", "del"], route: route):
cmd = [IP_PATH, "route", action]
cmd = [IP_PATH, "route", action]
if route.tipe != "unicast":
if route.tipe != "unicast":
cmd += [route.tipe]
cmd += [route.tipe]
...
@@ -666,6 +695,7 @@ def _add_del_route(action, route):
...
@@ -666,6 +695,7 @@ def _add_del_route(action, route):
cmd += ["dev", _get_if_name(route.interface)]
cmd += ["dev", _get_if_name(route.interface)]
execute(cmd)
execute(cmd)
# TC stuff
# TC stuff
def get_tc_tree():
def get_tc_tree():
...
@@ -676,13 +706,13 @@ def get_tc_tree():
...
@@ -676,13 +706,13 @@ def get_tc_tree():
if line == "":
if line == "":
continue
continue
match = re.match(r'
qdisc
(
\
S
+
)
([
0
-
9
a
-
f
]
+
):[
0
-
9
a
-
f
]
*
dev
(
\
S
+
)
' +
match = re.match(r'
qdisc
(
\
S
+
)
([
0
-
9
a
-
f
]
+
):[
0
-
9
a
-
f
]
*
dev
(
\
S
+
)
' +
r'
(
?
:
parent
([
0
-
9
a
-
f
]
*
):[
0
-
9
a
-
f
]
*|
root
)
\
s
*
(.
*
)
', line)
r'
(
?
:
parent
([
0
-
9
a
-
f
]
*
):[
0
-
9
a
-
f
]
*|
root
)
\
s
*
(.
*
)
', line)
if not match:
if not match:
raise RuntimeError("Invalid output from `tc qdisc'
:
`
%
s
'" % line)
raise RuntimeError("Invalid output from `tc qdisc'
:
`
%
s
'" % line)
qdisc = match.group(1)
qdisc = match.group(1)
handle = match.group(2)
handle = match.group(2)
iface = match.group(3)
iface = match.group(3)
parent = match.group(4) # or None
parent = match.group(4)
# or None
extra = match.group(5)
extra = match.group(5)
if parent == "":
if parent == "":
# XXX: Still not sure what is this, shows in newer kernels for wlan
# XXX: Still not sure what is this, shows in newer kernels for wlan
...
@@ -706,15 +736,19 @@ def get_tc_tree():
...
@@ -706,15 +736,19 @@ def get_tc_tree():
for h in data[data_node[0]]:
for h in data[data_node[0]]:
node["children"].append(gen_tree(data, h))
node["children"].append(gen_tree(data, h))
return node
return node
tree[iface] = gen_tree(data[iface], data[iface][None][0])
tree[iface] = gen_tree(data[iface], data[iface][None][0])
return tree
return tree
_multipliers = {"M": 1000000, "K": 1000}
_multipliers = {"M": 1000000, "K": 1000}
_dividers = {"m": 1000, "u": 1000000}
_dividers = {"m": 1000, "u": 1000000}
def _parse_netem_delay(line):
def _parse_netem_delay(line):
ret = {}
ret = {}
match = re.search(r'
delay
([
\
d
.]
+
)([
mu
]
?
)
s
(
?
:
+
([
\
d
.]
+
)([
mu
]
?
)
s
)
?
' +
match = re.search(r'
delay
([
\
d
.]
+
)([
mu
]
?
)
s
(
?
:
+
([
\
d
.]
+
)([
mu
]
?
)
s
)
?
' +
r'
(
?
:
*
([
\
d
.]
+
)
%
)
?
(
?
:
*
distribution
(
\
S
+
))
?
', line)
r'
(
?
:
*
([
\
d
.]
+
)
%
)
?
(
?
:
*
distribution
(
\
S
+
))
?
', line)
if not match:
if not match:
return ret
return ret
...
@@ -737,6 +771,7 @@ def _parse_netem_delay(line):
...
@@ -737,6 +771,7 @@ def _parse_netem_delay(line):
return ret
return ret
def _parse_netem_loss(line):
def _parse_netem_loss(line):
ret = {}
ret = {}
match = re.search(r'
loss
([
\
d
.]
+
)
%
(
?
:
*
([
\
d
.]
+
)
%
)
?
', line)
match = re.search(r'
loss
([
\
d
.]
+
)
%
(
?
:
*
([
\
d
.]
+
)
%
)
?
', line)
...
@@ -748,6 +783,7 @@ def _parse_netem_loss(line):
...
@@ -748,6 +783,7 @@ def _parse_netem_loss(line):
ret["loss_correlation"] = float(match.group(2)) / 100
ret["loss_correlation"] = float(match.group(2)) / 100
return ret
return ret
def _parse_netem_dup(line):
def _parse_netem_dup(line):
ret = {}
ret = {}
match = re.search(r'
duplicate
([
\
d
.]
+
)
%
(
?
:
*
([
\
d
.]
+
)
%
)
?
', line)
match = re.search(r'
duplicate
([
\
d
.]
+
)
%
(
?
:
*
([
\
d
.]
+
)
%
)
?
', line)
...
@@ -759,6 +795,7 @@ def _parse_netem_dup(line):
...
@@ -759,6 +795,7 @@ def _parse_netem_dup(line):
ret["dup_correlation"] = float(match.group(2)) / 100
ret["dup_correlation"] = float(match.group(2)) / 100
return ret
return ret
def _parse_netem_corrupt(line):
def _parse_netem_corrupt(line):
ret = {}
ret = {}
match = re.search(r'
corrupt
([
\
d
.]
+
)
%
(
?
:
*
([
\
d
.]
+
)
%
)
?
', line)
match = re.search(r'
corrupt
([
\
d
.]
+
)
%
(
?
:
*
([
\
d
.]
+
)
%
)
?
', line)
...
@@ -770,6 +807,7 @@ def _parse_netem_corrupt(line):
...
@@ -770,6 +807,7 @@ def _parse_netem_corrupt(line):
ret["corrupt_correlation"] = float(match.group(2)) / 100
ret["corrupt_correlation"] = float(match.group(2)) / 100
return ret
return ret
def get_tc_data():
def get_tc_data():
tree = get_tc_tree()
tree = get_tc_tree()
ifdata = get_if_data()
ifdata = get_if_data()
...
@@ -802,7 +840,7 @@ def get_tc_data():
...
@@ -802,7 +840,7 @@ def get_tc_data():
continue
continue
tbf = node["extra"], node["handle"]
tbf = node["extra"], node["handle"]
netem = node["children"][0]["extra"],
\
netem = node["children"][0]["extra"],
\
node["children"][0]["handle"]
node["children"][0]["handle"]
if tbf:
if tbf:
ret[i]["qdiscs"]["tbf"] = tbf[1]
ret[i]["qdiscs"]["tbf"] = tbf[1]
...
@@ -823,22 +861,24 @@ def get_tc_data():
...
@@ -823,22 +861,24 @@ def get_tc_data():
ret[i].update(_parse_netem_corrupt(netem[0]))
ret[i].update(_parse_netem_corrupt(netem[0]))
return ret, ifdata[0], ifdata[1]
return ret, ifdata[0], ifdata[1]
def clear_tc(iface):
def clear_tc(iface):
iface = get_if(iface)
iface = get_if(iface)
tcdata = get_tc_data()[0]
tcdata = get_tc_data()[0]
if tcdata[iface.index]
==
None:
if tcdata[iface.index]
is
None:
return
return
# Any other case, we clean
# Any other case, we clean
execute([TC_PATH, "qdisc", "del", "dev", iface.name, "root"])
execute([TC_PATH, "qdisc", "del", "dev", iface.name, "root"])
def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
delay_correlation = None, delay_distribution = None,
def set_tc(iface, bandwidth=None, delay=None, delay_jitter=None,
loss = None, loss_correlation = None,
delay_correlation=None, delay_distribution=None,
dup = None, dup_correlation = None,
loss=None, loss_correlation=None,
corrupt = None, corrupt_correlation = None):
dup=None, dup_correlation=None,
corrupt=None, corrupt_correlation=None):
use_netem = bool(delay or delay_jitter or delay_correlation or
use_netem = bool(delay or delay_jitter or delay_correlation or
delay_distribution or loss or loss_correlation or dup or
delay_distribution or loss or loss_correlation or dup or
dup_correlation or corrupt or corrupt_correlation)
dup_correlation or corrupt or corrupt_correlation)
iface = get_if(iface)
iface = get_if(iface)
tcdata, ifdata = get_tc_data()[0:2]
tcdata, ifdata = get_tc_data()[0:2]
...
@@ -846,7 +886,7 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
...
@@ -846,7 +886,7 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
if tcdata[iface.index] == '
foreign
':
if tcdata[iface.index] == '
foreign
':
# Avoid the overhead of calling tc+ip again
# Avoid the overhead of calling tc+ip again
commands.append([TC_PATH, "qdisc", "del", "dev", iface.name, "root"])
commands.append([TC_PATH, "qdisc", "del", "dev", iface.name, "root"])
tcdata[iface.index] = {'
qdiscs
':
[]}
tcdata[iface.index] = {'
qdiscs
': []}
has_netem = '
netem
' in tcdata[iface.index]['
qdiscs
']
has_netem = '
netem
' in tcdata[iface.index]['
qdiscs
']
has_tbf = '
tbf
' in tcdata[iface.index]['
qdiscs
']
has_tbf = '
tbf
' in tcdata[iface.index]['
qdiscs
']
...
@@ -862,20 +902,20 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
...
@@ -862,20 +902,20 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
# Too much work to do better :)
# Too much work to do better :)
if has_netem or has_tbf:
if has_netem or has_tbf:
commands.append([TC_PATH, "qdisc", "del", "dev", iface.name,
commands.append([TC_PATH, "qdisc", "del", "dev", iface.name,
"root"])
"root"])
cmd = "add"
cmd = "add"
if bandwidth:
if bandwidth:
rate = "%dbit" % int(bandwidth)
rate = "%dbit" % int(bandwidth)
mtu = ifdata[iface.index].mtu
mtu = ifdata[iface.index].mtu
burst = max(mtu, int(bandwidth) // HZ)
burst = max(mtu, int(bandwidth) // HZ)
limit = burst * 2 # FIXME?
limit = burst * 2
# FIXME?
handle = "1:"
handle = "1:"
if cmd == "change":
if cmd == "change":
handle = "%d:" % int(tcdata[iface.index]["qdiscs"]["tbf"])
handle = "%d:" % int(tcdata[iface.index]["qdiscs"]["tbf"])
command = [TC_PATH, "qdisc", cmd, "dev", iface.name, "root", "handle",
command = [TC_PATH, "qdisc", cmd, "dev", iface.name, "root", "handle",
handle, "tbf", "rate", rate, "limit", str(limit), "burst",
handle, "tbf", "rate", rate, "limit", str(limit), "burst",
str(burst)]
str(burst)]
commands.append(command)
commands.append(command)
if use_netem:
if use_netem:
...
@@ -920,16 +960,17 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
...
@@ -920,16 +960,17 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
for c in commands:
for c in commands:
execute(c)
execute(c)
def create_tap(iface, use_pi = False, tun = False):
def create_tap(iface, use_pi=False, tun=False):
"""Creates a tap/tun device and returns the associated file descriptor"""
"""Creates a tap/tun device and returns the associated file descriptor"""
if isinstance(iface, str):
if isinstance(iface, str):
iface = interface(name
=
iface)
iface = interface(name
=
iface)
assert iface.name
assert iface.name
IFF_TUN
= 0x0001
IFF_TUN = 0x0001
IFF_TAP
= 0x0002
IFF_TAP = 0x0002
IFF_NO_PI
= 0x1000
IFF_NO_PI = 0x1000
TUNSETIFF
= 0x400454ca
TUNSETIFF = 0x400454ca
if tun:
if tun:
mode = IFF_TUN
mode = IFF_TUN
else:
else:
...
@@ -952,4 +993,3 @@ def create_tap(iface, use_pi = False, tun = False):
...
@@ -952,4 +993,3 @@ def create_tap(iface, use_pi = False, tun = False):
raise
raise
interfaces = get_if_data()[1]
interfaces = get_if_data()[1]
return interfaces[iface.name], fd
return interfaces[iface.name], fd
src/nemu/node.py
View file @
033ce168
...
@@ -21,10 +21,13 @@ import os
...
@@ -21,10 +21,13 @@ import os
import
socket
import
socket
import
sys
import
sys
import
traceback
import
traceback
from
typing
import
MutableMapping
import
unshare
import
unshare
import
weakref
import
weakref
import
nemu.interface
import
nemu.interface
import
nemu.iproute
import
nemu.protocol
import
nemu.protocol
import
nemu.subprocess_
import
nemu.subprocess_
from
nemu
import
compat
from
nemu
import
compat
...
@@ -33,10 +36,11 @@ from nemu.environ import *
...
@@ -33,10 +36,11 @@ from nemu.environ import *
__all__
=
[
'Node'
,
'get_nodes'
,
'import_if'
]
__all__
=
[
'Node'
,
'get_nodes'
,
'import_if'
]
class
Node
(
object
):
class
Node
(
object
):
_nodes
=
weakref
.
WeakValueDictionary
()
_nodes
:
MutableMapping
[
int
,
"Node"
]
=
weakref
.
WeakValueDictionary
()
_nextnode
=
0
_nextnode
=
0
_processes
:
MutableMapping
[
int
,
nemu
.
subprocess_
.
Subprocess
]
@
staticmethod
@
staticmethod
def
get_nodes
():
def
get_nodes
()
->
list
[
"Node"
]
:
s
=
sorted
(
list
(
Node
.
_nodes
.
items
()),
key
=
lambda
x
:
x
[
0
])
s
=
sorted
(
list
(
Node
.
_nodes
.
items
()),
key
=
lambda
x
:
x
[
0
])
return
[
x
[
1
]
for
x
in
s
]
return
[
x
[
1
]
for
x
in
s
]
...
@@ -98,7 +102,7 @@ class Node(object):
...
@@ -98,7 +102,7 @@ class Node(object):
return
self
.
_pid
return
self
.
_pid
# Subprocesses
# Subprocesses
def
_add_subprocess
(
self
,
subprocess
):
def
_add_subprocess
(
self
,
subprocess
:
nemu
.
subprocess_
.
Subprocess
):
self
.
_processes
[
subprocess
.
pid
]
=
subprocess
self
.
_processes
[
subprocess
.
pid
]
=
subprocess
def
Subprocess
(
self
,
*
kargs
,
**
kwargs
):
def
Subprocess
(
self
,
*
kargs
,
**
kwargs
):
...
@@ -188,13 +192,13 @@ class Node(object):
...
@@ -188,13 +192,13 @@ class Node(object):
r
=
self
.
route
(
*
args
,
**
kwargs
)
r
=
self
.
route
(
*
args
,
**
kwargs
)
return
self
.
_slave
.
del_route
(
r
)
return
self
.
_slave
.
del_route
(
r
)
def
get_routes
(
self
):
def
get_routes
(
self
)
->
list
[
route
]
:
return
self
.
_slave
.
get_route_data
()
return
self
.
_slave
.
get_route_data
()
# Handle the creation of the child; parent gets (fd, pid), child creates and
# Handle the creation of the child; parent gets (fd, pid), child creates and
# runs a Server(); never returns.
# runs a Server(); never returns.
# Requires CAP_SYS_ADMIN privileges to run.
# Requires CAP_SYS_ADMIN privileges to run.
def
_start_child
(
nonetns
)
->
(
socket
.
socket
,
int
):
def
_start_child
(
nonetns
:
bool
)
->
(
socket
.
socket
,
int
):
# Create socket pair to communicate
# Create socket pair to communicate
(
s0
,
s1
)
=
compat
.
socketpair
(
socket
.
AF_UNIX
,
socket
.
SOCK_STREAM
,
0
)
(
s0
,
s1
)
=
compat
.
socketpair
(
socket
.
AF_UNIX
,
socket
.
SOCK_STREAM
,
0
)
# Spawn a child that will run in a loop
# Spawn a child that will run in a loop
...
...
src/nemu/passfd.py
View file @
033ce168
...
@@ -21,7 +21,7 @@ import struct
...
@@ -21,7 +21,7 @@ import struct
from
io
import
IOBase
from
io
import
IOBase
def
__check_socket
(
sock
:
socket
.
socket
|
IOBase
):
def
__check_socket
(
sock
:
socket
.
socket
|
IOBase
)
->
socket
.
socket
:
if
hasattr
(
sock
,
'family'
)
and
sock
.
family
!=
socket
.
AF_UNIX
:
if
hasattr
(
sock
,
'family'
)
and
sock
.
family
!=
socket
.
AF_UNIX
:
raise
ValueError
(
"Only AF_UNIX sockets are allowed"
)
raise
ValueError
(
"Only AF_UNIX sockets are allowed"
)
...
@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase):
...
@@ -33,7 +33,7 @@ def __check_socket(sock: socket.socket | IOBase):
return
sock
return
sock
def
__check_fd
(
fd
):
def
__check_fd
(
fd
)
->
int
:
try
:
try
:
fd
=
fd
.
fileno
()
fd
=
fd
.
fileno
()
except
AttributeError
:
except
AttributeError
:
...
@@ -44,7 +44,7 @@ def __check_fd(fd):
...
@@ -44,7 +44,7 @@ def __check_fd(fd):
return
fd
return
fd
def
recvfd
(
sock
:
socket
.
socket
|
IOBase
,
msg_buf
:
int
=
4096
):
def
recvfd
(
sock
:
socket
.
socket
|
IOBase
,
msg_buf
:
int
=
4096
)
->
tuple
[
int
,
str
]
:
size
=
struct
.
calcsize
(
"@i"
)
size
=
struct
.
calcsize
(
"@i"
)
msg
,
ancdata
,
flags
,
addr
=
__check_socket
(
sock
).
recvmsg
(
msg_buf
,
socket
.
CMSG_SPACE
(
size
))
msg
,
ancdata
,
flags
,
addr
=
__check_socket
(
sock
).
recvmsg
(
msg_buf
,
socket
.
CMSG_SPACE
(
size
))
cmsg_level
,
cmsg_type
,
cmsg_data
=
ancdata
[
0
]
cmsg_level
,
cmsg_type
,
cmsg_data
=
ancdata
[
0
]
...
@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
...
@@ -59,7 +59,7 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
return
fd
,
msg
.
decode
(
"utf-8"
)
return
fd
,
msg
.
decode
(
"utf-8"
)
def
sendfd
(
sock
:
socket
.
socket
|
IOBase
,
fd
:
int
,
message
:
bytes
=
b"NONE"
):
def
sendfd
(
sock
:
socket
.
socket
|
IOBase
,
fd
:
int
,
message
:
bytes
=
b"NONE"
)
->
int
:
return
__check_socket
(
sock
).
sendmsg
(
return
__check_socket
(
sock
).
sendmsg
(
[
message
],
[
message
],
[(
socket
.
SOL_SOCKET
,
socket
.
SCM_RIGHTS
,
struct
.
pack
(
"@i"
,
fd
))])
[(
socket
.
SOL_SOCKET
,
socket
.
SCM_RIGHTS
,
struct
.
pack
(
"@i"
,
fd
))])
\ No newline at end of file
src/nemu/protocol.py
View file @
033ce168
...
@@ -29,6 +29,7 @@ import tempfile
...
@@ -29,6 +29,7 @@ import tempfile
import
time
import
time
import
traceback
import
traceback
from
pickle
import
loads
,
dumps
from
pickle
import
loads
,
dumps
from
typing
import
Literal
import
nemu.iproute
import
nemu.iproute
import
nemu.subprocess_
import
nemu.subprocess_
...
@@ -278,7 +279,7 @@ class Server(object):
...
@@ -278,7 +279,7 @@ class Server(object):
self
.
reply
(
220
,
"Hello."
);
self
.
reply
(
220
,
"Hello."
);
while
not
self
.
_closed
:
while
not
self
.
_closed
:
cmd
=
self
.
readcmd
()
cmd
=
self
.
readcmd
()
if
cmd
==
None
:
if
cmd
is
None
:
continue
continue
try
:
try
:
cmd
[
0
](
cmd
[
1
],
*
cmd
[
2
])
cmd
[
0
](
cmd
[
1
],
*
cmd
[
2
])
...
@@ -422,7 +423,7 @@ class Server(object):
...
@@ -422,7 +423,7 @@ class Server(object):
else
:
else
:
ret
=
nemu
.
subprocess_
.
wait
(
pid
)
ret
=
nemu
.
subprocess_
.
wait
(
pid
)
if
ret
!=
None
:
if
ret
is
not
None
:
self
.
_children
.
remove
(
pid
)
self
.
_children
.
remove
(
pid
)
if
pid
in
self
.
_xauthfiles
:
if
pid
in
self
.
_xauthfiles
:
try
:
try
:
...
@@ -449,7 +450,7 @@ class Server(object):
...
@@ -449,7 +450,7 @@ class Server(object):
self
.
reply
(
200
,
"Process signalled."
)
self
.
reply
(
200
,
"Process signalled."
)
def
do_IF_LIST
(
self
,
cmdname
,
ifnr
=
None
):
def
do_IF_LIST
(
self
,
cmdname
,
ifnr
=
None
):
if
ifnr
==
None
:
if
ifnr
is
None
:
ifdata
=
nemu
.
iproute
.
get_if_data
()[
0
]
ifdata
=
nemu
.
iproute
.
get_if_data
()[
0
]
else
:
else
:
ifdata
=
nemu
.
iproute
.
get_if
(
ifnr
)
ifdata
=
nemu
.
iproute
.
get_if
(
ifnr
)
...
@@ -479,7 +480,7 @@ class Server(object):
...
@@ -479,7 +480,7 @@ class Server(object):
def
do_ADDR_LIST
(
self
,
cmdname
,
ifnr
=
None
):
def
do_ADDR_LIST
(
self
,
cmdname
,
ifnr
=
None
):
addrdata
=
nemu
.
iproute
.
get_addr_data
()[
0
]
addrdata
=
nemu
.
iproute
.
get_addr_data
()[
0
]
if
ifnr
!=
None
:
if
ifnr
is
not
None
:
addrdata
=
addrdata
[
ifnr
]
addrdata
=
addrdata
[
ifnr
]
self
.
reply
(
200
,
[
"# Address data follows."
,
self
.
reply
(
200
,
[
"# Address data follows."
,
_b64
(
dumps
(
addrdata
,
protocol
=
2
))])
_b64
(
dumps
(
addrdata
,
protocol
=
2
))])
...
@@ -652,7 +653,7 @@ class Client(object):
...
@@ -652,7 +653,7 @@ class Client(object):
stdin/stdout/stderr can only be None or a open file descriptor.
stdin/stdout/stderr can only be None or a open file descriptor.
See nemu.subprocess_.spawn for details."""
See nemu.subprocess_.spawn for details."""
if
executable
==
None
:
if
executable
is
None
:
executable
=
argv
[
0
]
executable
=
argv
[
0
]
params
=
[
"PROC"
,
"CRTE"
,
_b64
(
executable
)]
params
=
[
"PROC"
,
"CRTE"
,
_b64
(
executable
)]
for
i
in
argv
:
for
i
in
argv
:
...
@@ -663,28 +664,28 @@ class Client(object):
...
@@ -663,28 +664,28 @@ class Client(object):
# After this, if we get an error, we have to abort the PROC
# After this, if we get an error, we have to abort the PROC
try
:
try
:
if
user
!=
None
:
if
user
is
not
None
:
self
.
_send_cmd
(
"PROC"
,
"USER"
,
_b64
(
user
))
self
.
_send_cmd
(
"PROC"
,
"USER"
,
_b64
(
user
))
self
.
_read_and_check_reply
()
self
.
_read_and_check_reply
()
if
cwd
!=
None
:
if
cwd
is
not
None
:
self
.
_send_cmd
(
"PROC"
,
"CWD"
,
_b64
(
cwd
))
self
.
_send_cmd
(
"PROC"
,
"CWD"
,
_b64
(
cwd
))
self
.
_read_and_check_reply
()
self
.
_read_and_check_reply
()
if
env
!=
None
:
if
env
is
not
None
:
params
=
[]
params
=
[]
for
k
,
v
in
env
.
items
():
for
k
,
v
in
env
.
items
():
params
.
extend
([
_b64
(
k
),
_b64
(
v
)])
params
.
extend
([
_b64
(
k
),
_b64
(
v
)])
self
.
_send_cmd
(
"PROC"
,
"ENV"
,
*
params
)
self
.
_send_cmd
(
"PROC"
,
"ENV"
,
*
params
)
self
.
_read_and_check_reply
()
self
.
_read_and_check_reply
()
if
stdin
!=
None
:
if
stdin
is
not
None
:
os
.
set_inheritable
(
stdin
,
True
)
os
.
set_inheritable
(
stdin
,
True
)
self
.
_send_fd
(
"SIN"
,
stdin
)
self
.
_send_fd
(
"SIN"
,
stdin
)
if
stdout
!=
None
:
if
stdout
is
not
None
:
os
.
set_inheritable
(
stdout
,
True
)
os
.
set_inheritable
(
stdout
,
True
)
self
.
_send_fd
(
"SOUT"
,
stdout
)
self
.
_send_fd
(
"SOUT"
,
stdout
)
if
stderr
!=
None
:
if
stderr
is
not
None
:
os
.
set_inheritable
(
stderr
,
True
)
os
.
set_inheritable
(
stderr
,
True
)
self
.
_send_fd
(
"SERR"
,
stderr
)
self
.
_send_fd
(
"SERR"
,
stderr
)
except
:
except
:
...
@@ -739,7 +740,7 @@ class Client(object):
...
@@ -739,7 +740,7 @@ class Client(object):
cmd
=
[
"IF"
,
"SET"
,
interface
.
index
]
cmd
=
[
"IF"
,
"SET"
,
interface
.
index
]
for
k
in
interface
.
changeable_attributes
:
for
k
in
interface
.
changeable_attributes
:
v
=
getattr
(
interface
,
k
)
v
=
getattr
(
interface
,
k
)
if
v
!=
None
:
if
v
is
not
None
:
cmd
+=
[
k
,
str
(
v
)]
cmd
+=
[
k
,
str
(
v
)]
self
.
_send_cmd
(
*
cmd
)
self
.
_send_cmd
(
*
cmd
)
...
@@ -761,7 +762,7 @@ class Client(object):
...
@@ -761,7 +762,7 @@ class Client(object):
data
=
self
.
_read_and_check_reply
()
data
=
self
.
_read_and_check_reply
()
return
loads
(
_db64
(
data
.
partition
(
"
\
n
"
)[
2
]))
return
loads
(
_db64
(
data
.
partition
(
"
\
n
"
)[
2
]))
def
add_addr
(
self
,
ifnr
,
address
):
def
add_addr
(
self
,
ifnr
:
int
,
address
:
nemu
.
iproute
.
address
):
if
hasattr
(
address
,
"broadcast"
)
and
address
.
broadcast
:
if
hasattr
(
address
,
"broadcast"
)
and
address
.
broadcast
:
self
.
_send_cmd
(
"ADDR"
,
"ADD"
,
ifnr
,
address
.
address
,
self
.
_send_cmd
(
"ADDR"
,
"ADD"
,
ifnr
,
address
.
address
,
address
.
prefix_len
,
address
.
broadcast
)
address
.
prefix_len
,
address
.
broadcast
)
...
@@ -770,7 +771,7 @@ class Client(object):
...
@@ -770,7 +771,7 @@ class Client(object):
address
.
prefix_len
)
address
.
prefix_len
)
self
.
_read_and_check_reply
()
self
.
_read_and_check_reply
()
def
del_addr
(
self
,
ifnr
,
address
):
def
del_addr
(
self
,
ifnr
:
int
,
address
:
nemu
.
iproute
.
address
):
self
.
_send_cmd
(
"ADDR"
,
"DEL"
,
ifnr
,
address
.
address
,
address
.
prefix_len
)
self
.
_send_cmd
(
"ADDR"
,
"DEL"
,
ifnr
,
address
.
address
,
address
.
prefix_len
)
self
.
_read_and_check_reply
()
self
.
_read_and_check_reply
()
...
@@ -785,14 +786,14 @@ class Client(object):
...
@@ -785,14 +786,14 @@ class Client(object):
def
del_route
(
self
,
route
):
def
del_route
(
self
,
route
):
self
.
_add_del_route
(
"DEL"
,
route
)
self
.
_add_del_route
(
"DEL"
,
route
)
def
_add_del_route
(
self
,
action
,
route
):
def
_add_del_route
(
self
,
action
:
Literal
[
"ADD"
,
"DEL"
],
route
:
nemu
.
iproute
.
route
):
args
=
[
"ROUT"
,
action
,
_b64
(
route
.
tipe
),
_b64
(
route
.
prefix
),
args
=
[
"ROUT"
,
action
,
_b64
(
route
.
tipe
),
_b64
(
route
.
prefix
),
route
.
prefix_len
or
0
,
_b64
(
route
.
nexthop
),
route
.
prefix_len
or
0
,
_b64
(
route
.
nexthop
),
route
.
interface
or
0
,
route
.
metric
or
0
]
route
.
interface
or
0
,
route
.
metric
or
0
]
self
.
_send_cmd
(
*
args
)
self
.
_send_cmd
(
*
args
)
self
.
_read_and_check_reply
()
self
.
_read_and_check_reply
()
def
set_x11
(
self
,
protoname
,
hexkey
)
:
def
set_x11
(
self
,
protoname
:
str
,
hexkey
:
str
)
->
socket
.
socket
:
# Returns a socket ready to accept() connections
# Returns a socket ready to accept() connections
self
.
_send_cmd
(
"X11"
,
"SET"
,
protoname
,
hexkey
)
self
.
_send_cmd
(
"X11"
,
"SET"
,
protoname
,
hexkey
)
self
.
_read_and_check_reply
()
self
.
_read_and_check_reply
()
...
@@ -823,7 +824,7 @@ class Client(object):
...
@@ -823,7 +824,7 @@ class Client(object):
def
_b64_OLD
(
text
:
str
|
bytes
)
->
str
:
def
_b64_OLD
(
text
:
str
|
bytes
)
->
str
:
if
text
==
None
:
if
text
is
None
:
# easier this way
# easier this way
text
=
''
text
=
''
if
type
(
text
)
is
str
:
if
type
(
text
)
is
str
:
...
@@ -833,11 +834,12 @@ def _b64_OLD(text: str | bytes) -> str:
...
@@ -833,11 +834,12 @@ def _b64_OLD(text: str | bytes) -> str:
else
:
else
:
btext
=
text
btext
=
text
if
len
(
btext
)
==
0
or
any
(
x
for
x
in
btext
if
x
<=
ord
(
" "
)
or
if
len
(
btext
)
==
0
or
any
(
x
for
x
in
btext
if
x
<=
ord
(
" "
)
or
x
>
ord
(
"z"
)
or
x
==
ord
(
"="
)):
x
>
ord
(
"z"
)
or
x
==
ord
(
"="
)):
return
"="
+
base64
.
b64encode
(
btext
).
decode
(
"ascii"
)
return
"="
+
base64
.
b64encode
(
btext
).
decode
(
"ascii"
)
else
:
else
:
return
text
return
text
def
_b64
(
text
)
->
str
:
def
_b64
(
text
)
->
str
:
if
text
is
None
:
if
text
is
None
:
# easier this way
# easier this way
...
@@ -848,7 +850,7 @@ def _b64(text) -> str:
...
@@ -848,7 +850,7 @@ def _b64(text) -> str:
else
:
else
:
enc
=
str
(
text
).
encode
(
"utf-8"
)
enc
=
str
(
text
).
encode
(
"utf-8"
)
if
len
(
enc
)
==
0
or
any
(
x
for
x
in
enc
if
x
<=
ord
(
" "
)
or
if
len
(
enc
)
==
0
or
any
(
x
for
x
in
enc
if
x
<=
ord
(
" "
)
or
x
>
ord
(
"z"
)
or
x
==
ord
(
"="
)):
x
>
ord
(
"z"
)
or
x
==
ord
(
"="
)):
return
"="
+
base64
.
b64encode
(
enc
).
decode
(
"ascii"
)
return
"="
+
base64
.
b64encode
(
enc
).
decode
(
"ascii"
)
else
:
else
:
return
enc
.
decode
(
"utf-8"
)
return
enc
.
decode
(
"utf-8"
)
...
...
src/nemu/subprocess_.py
View file @
033ce168
...
@@ -27,7 +27,10 @@ import signal
...
@@ -27,7 +27,10 @@ import signal
import
sys
import
sys
import
time
import
time
import
traceback
import
traceback
import
typing
if
typing
.
TYPE_CHECKING
:
from
nemu
import
Node
from
nemu
import
compat
from
nemu
import
compat
from
nemu.environ
import
eintr_wrapper
from
nemu.environ
import
eintr_wrapper
...
@@ -46,7 +49,7 @@ class Subprocess(object):
...
@@ -46,7 +49,7 @@ class Subprocess(object):
# FIXME
# FIXME
default_user
=
None
default_user
=
None
def
__init__
(
self
,
node
,
argv
:
str
|
list
[
str
],
executable
=
None
,
def
__init__
(
self
,
node
:
"Node"
,
argv
:
str
|
list
[
str
],
executable
=
None
,
stdin
=
None
,
stdout
=
None
,
stderr
=
None
,
stdin
=
None
,
stdout
=
None
,
stderr
=
None
,
shell
=
False
,
cwd
=
None
,
env
=
None
,
user
=
None
):
shell
=
False
,
cwd
=
None
,
env
=
None
,
user
=
None
):
self
.
_slave
=
node
.
_slave
self
.
_slave
=
node
.
_slave
...
@@ -78,7 +81,7 @@ class Subprocess(object):
...
@@ -78,7 +81,7 @@ class Subprocess(object):
Exceptions occurred while trying to set up the environment or executing
Exceptions occurred while trying to set up the environment or executing
the program are propagated to the parent."""
the program are propagated to the parent."""
if
user
==
None
:
if
user
is
None
:
user
=
Subprocess
.
default_user
user
=
Subprocess
.
default_user
if
isinstance
(
argv
,
str
):
if
isinstance
(
argv
,
str
):
...
@@ -106,20 +109,20 @@ class Subprocess(object):
...
@@ -106,20 +109,20 @@ class Subprocess(object):
def
poll
(
self
):
def
poll
(
self
):
"""Checks status of program, returns exitcode or None if still running.
"""Checks status of program, returns exitcode or None if still running.
See Popen.poll."""
See Popen.poll."""
if
self
.
_returncode
==
None
:
if
self
.
_returncode
is
None
:
self
.
_returncode
=
self
.
_slave
.
poll
(
self
.
_pid
)
self
.
_returncode
=
self
.
_slave
.
poll
(
self
.
_pid
)
return
self
.
returncode
return
self
.
returncode
def
wait
(
self
):
def
wait
(
self
):
"""Waits for program to complete and returns the exitcode.
"""Waits for program to complete and returns the exitcode.
See Popen.wait"""
See Popen.wait"""
if
self
.
_returncode
==
None
:
if
self
.
_returncode
is
None
:
self
.
_returncode
=
self
.
_slave
.
wait
(
self
.
_pid
)
self
.
_returncode
=
self
.
_slave
.
wait
(
self
.
_pid
)
return
self
.
returncode
return
self
.
returncode
def
signal
(
self
,
sig
=
signal
.
SIGTERM
):
def
signal
(
self
,
sig
=
signal
.
SIGTERM
):
"""Sends a signal to the process."""
"""Sends a signal to the process."""
if
self
.
_returncode
==
None
:
if
self
.
_returncode
is
None
:
self
.
_slave
.
signal
(
self
.
_pid
,
sig
)
self
.
_slave
.
signal
(
self
.
_pid
,
sig
)
@
property
@
property
...
@@ -128,7 +131,7 @@ class Subprocess(object):
...
@@ -128,7 +131,7 @@ class Subprocess(object):
communicate, wait, or poll), returns the signal that killed the
communicate, wait, or poll), returns the signal that killed the
program, if negative; otherwise, it is the exit code of the program.
program, if negative; otherwise, it is the exit code of the program.
"""
"""
if
self
.
_returncode
==
None
:
if
self
.
_returncode
is
None
:
return
None
return
None
if
os
.
WIFSIGNALED
(
self
.
_returncode
):
if
os
.
WIFSIGNALED
(
self
.
_returncode
):
return
-
os
.
WTERMSIG
(
self
.
_returncode
)
return
-
os
.
WTERMSIG
(
self
.
_returncode
)
...
@@ -140,12 +143,12 @@ class Subprocess(object):
...
@@ -140,12 +143,12 @@ class Subprocess(object):
self
.
destroy
()
self
.
destroy
()
def
destroy
(
self
):
def
destroy
(
self
):
if
self
.
_returncode
!=
None
or
self
.
_pid
==
None
:
if
self
.
_returncode
is
not
None
or
self
.
_pid
is
None
:
return
return
self
.
signal
()
self
.
signal
()
now
=
time
.
time
()
now
=
time
.
time
()
while
time
.
time
()
-
now
<
KILL_WAIT
:
while
time
.
time
()
-
now
<
KILL_WAIT
:
if
self
.
poll
()
!=
None
:
if
self
.
poll
()
is
not
None
:
return
return
time
.
sleep
(
0.1
)
time
.
sleep
(
0.1
)
sys
.
stderr
.
write
(
"WARNING: killing forcefully process %d.
\
n
"
%
sys
.
stderr
.
write
(
"WARNING: killing forcefully process %d.
\
n
"
%
...
@@ -179,7 +182,7 @@ class Popen(Subprocess):
...
@@ -179,7 +182,7 @@ class Popen(Subprocess):
fdmap
=
{
"stdin"
:
stdin
,
"stdout"
:
stdout
,
"stderr"
:
stderr
}
fdmap
=
{
"stdin"
:
stdin
,
"stdout"
:
stdout
,
"stderr"
:
stderr
}
# if PIPE: all should be closed at the end
# if PIPE: all should be closed at the end
for
k
,
v
in
fdmap
.
items
():
for
k
,
v
in
fdmap
.
items
():
if
v
==
None
:
if
v
is
None
:
continue
continue
if
v
==
PIPE
:
if
v
==
PIPE
:
r
,
w
=
compat
.
pipe
()
r
,
w
=
compat
.
pipe
()
...
@@ -206,27 +209,29 @@ class Popen(Subprocess):
...
@@ -206,27 +209,29 @@ class Popen(Subprocess):
# Close pipes, they have been dup()ed to the child
# Close pipes, they have been dup()ed to the child
for
k
,
v
in
fdmap
.
items
():
for
k
,
v
in
fdmap
.
items
():
if
getattr
(
self
,
k
)
!=
None
:
if
getattr
(
self
,
k
)
is
not
None
:
eintr_wrapper
(
os
.
close
,
v
)
eintr_wrapper
(
os
.
close
,
v
)
def
communicate
(
self
,
input
:
bytes
=
None
)
->
tuple
[
bytes
,
bytes
]:
def
communicate
(
self
,
input
:
bytes
|
str
=
None
)
->
tuple
[
bytes
,
bytes
]:
"""See Popen.communicate."""
"""See Popen.communicate."""
# FIXME: almost verbatim from stdlib version, need to be removed or
# FIXME: almost verbatim from stdlib version, need to be removed or
# something
# something
if
type
(
input
)
is
str
:
input
=
input
.
encode
(
"utf-8"
)
wset
=
[]
wset
=
[]
rset
=
[]
rset
=
[]
err
=
None
err
=
None
out
=
None
out
=
None
if
self
.
stdin
!=
None
:
if
self
.
stdin
is
not
None
:
self
.
stdin
.
flush
()
self
.
stdin
.
flush
()
if
input
:
if
input
:
wset
.
append
(
self
.
stdin
)
wset
.
append
(
self
.
stdin
)
else
:
else
:
self
.
stdin
.
close
()
self
.
stdin
.
close
()
if
self
.
stdout
!=
None
:
if
self
.
stdout
is
not
None
:
rset
.
append
(
self
.
stdout
)
rset
.
append
(
self
.
stdout
)
out
=
[]
out
=
[]
if
self
.
stderr
!=
None
:
if
self
.
stderr
is
not
None
:
rset
.
append
(
self
.
stderr
)
rset
.
append
(
self
.
stderr
)
err
=
[]
err
=
[]
...
@@ -253,9 +258,9 @@ class Popen(Subprocess):
...
@@ -253,9 +258,9 @@ class Popen(Subprocess):
else
:
else
:
err
.
append
(
d
)
err
.
append
(
d
)
if
out
!=
None
:
if
out
is
not
None
:
out
=
b''
.
join
(
out
)
out
=
b''
.
join
(
out
)
if
err
!=
None
:
if
err
is
not
None
:
err
=
b''
.
join
(
err
)
err
=
b''
.
join
(
err
)
self
.
wait
()
self
.
wait
()
return
(
out
,
err
)
return
(
out
,
err
)
...
@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
...
@@ -313,15 +318,15 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
is not supported here. Also, the original descriptors are not closed.
is not supported here. Also, the original descriptors are not closed.
"""
"""
userfd
=
[
stdin
,
stdout
,
stderr
]
userfd
=
[
stdin
,
stdout
,
stderr
]
filtered_userfd
=
[
x
for
x
in
userfd
if
x
!=
None
and
x
>=
0
]
filtered_userfd
=
[
x
for
x
in
userfd
if
x
is
not
None
and
x
>=
0
]
for
i
in
range
(
3
):
for
i
in
range
(
3
):
if
userfd
[
i
]
!=
None
and
not
isinstance
(
userfd
[
i
],
int
):
if
userfd
[
i
]
is
not
None
and
not
isinstance
(
userfd
[
i
],
int
):
userfd
[
i
]
=
userfd
[
i
].
fileno
()
# pragma: no cover
userfd
[
i
]
=
userfd
[
i
].
fileno
()
# pragma: no cover
# Verify there is no clash
# Verify there is no clash
assert
not
(
set
([
0
,
1
,
2
])
&
set
(
filtered_userfd
))
assert
not
(
set
([
0
,
1
,
2
])
&
set
(
filtered_userfd
))
if
user
!=
None
:
if
user
is
not
None
:
user
,
uid
,
gid
=
get_user
(
user
)
user
,
uid
,
gid
=
get_user
(
user
)
home
=
pwd
.
getpwuid
(
uid
)[
5
]
home
=
pwd
.
getpwuid
(
uid
)[
5
]
groups
=
[
x
[
2
]
for
x
in
grp
.
getgrall
()
if
user
in
x
[
3
]]
groups
=
[
x
[
2
]
for
x
in
grp
.
getgrall
()
if
user
in
x
[
3
]]
...
@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
...
@@ -337,7 +342,7 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
try
:
try
:
# Set up stdio piping
# Set up stdio piping
for
i
in
range
(
3
):
for
i
in
range
(
3
):
if
userfd
[
i
]
!=
None
and
userfd
[
i
]
>=
0
:
if
userfd
[
i
]
is
not
None
and
userfd
[
i
]
>=
0
:
os
.
dup2
(
userfd
[
i
],
i
)
os
.
dup2
(
userfd
[
i
],
i
)
if
userfd
[
i
]
!=
i
and
userfd
[
i
]
not
in
userfd
[
0
:
i
]:
if
userfd
[
i
]
!=
i
and
userfd
[
i
]
not
in
userfd
[
0
:
i
]:
eintr_wrapper
(
os
.
close
,
userfd
[
i
])
# only in child!
eintr_wrapper
(
os
.
close
,
userfd
[
i
])
# only in child!
...
@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
...
@@ -362,22 +367,22 @@ def spawn(executable, argv=None, cwd=None, env=None, close_fds=False,
# (it is necessary to kill the forked subprocesses)
# (it is necessary to kill the forked subprocesses)
os
.
setpgrp
()
os
.
setpgrp
()
if
user
!=
None
:
if
user
is
not
None
:
# Change user
# Change user
os
.
setgid
(
gid
)
os
.
setgid
(
gid
)
os
.
setgroups
(
groups
)
os
.
setgroups
(
groups
)
os
.
setuid
(
uid
)
os
.
setuid
(
uid
)
if
cwd
!=
None
:
if
cwd
is
not
None
:
os
.
chdir
(
cwd
)
os
.
chdir
(
cwd
)
if
not
argv
:
if
not
argv
:
argv
=
[
executable
]
argv
=
[
executable
]
if
'/'
in
executable
:
# Should not search in PATH
if
'/'
in
executable
:
# Should not search in PATH
if
env
!=
None
:
if
env
is
not
None
:
os
.
execve
(
executable
,
argv
,
env
)
os
.
execve
(
executable
,
argv
,
env
)
else
:
else
:
os
.
execv
(
executable
,
argv
)
os
.
execv
(
executable
,
argv
)
else
:
# use PATH
else
:
# use PATH
if
env
!=
None
:
if
env
is
not
None
:
os
.
execvpe
(
executable
,
argv
,
env
)
os
.
execvpe
(
executable
,
argv
,
env
)
else
:
else
:
os
.
execvp
(
executable
,
argv
)
os
.
execvp
(
executable
,
argv
)
...
...
test/test_core.py
View file @
033ce168
...
@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase):
...
@@ -136,7 +136,7 @@ class TestGlobal(unittest.TestCase):
os
.
write
(
if1
.
fd
,
s
)
os
.
write
(
if1
.
fd
,
s
)
if
not
s
:
if
not
s
:
break
break
if
subproc
.
poll
()
!=
None
:
if
subproc
.
poll
()
is
not
None
:
break
break
@
test_util
.
skipUnless
(
os
.
getuid
()
==
0
,
"Test requires root privileges"
)
@
test_util
.
skipUnless
(
os
.
getuid
()
==
0
,
"Test requires root privileges"
)
...
...
test/test_protocol.py
View file @
033ce168
...
@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase):
...
@@ -107,7 +107,7 @@ class TestServer(unittest.TestCase):
def
check_ok
(
self
,
cmd
,
func
,
args
):
def
check_ok
(
self
,
cmd
,
func
,
args
):
s1
.
write
(
"%s
\
n
"
%
cmd
)
s1
.
write
(
"%s
\
n
"
%
cmd
)
ccmd
=
" "
.
join
(
cmd
.
upper
().
split
()[
0
:
2
])
ccmd
=
" "
.
join
(
cmd
.
upper
().
split
()[
0
:
2
])
if
func
==
None
:
if
func
is
None
:
self
.
assertEqual
(
srv
.
readcmd
()[
1
:
3
],
(
ccmd
,
args
))
self
.
assertEqual
(
srv
.
readcmd
()[
1
:
3
],
(
ccmd
,
args
))
else
:
else
:
self
.
assertEqual
(
srv
.
readcmd
(),
(
func
,
ccmd
,
args
))
self
.
assertEqual
(
srv
.
readcmd
(),
(
func
,
ccmd
,
args
))
...
...
test/test_util.py
View file @
033ce168
...
@@ -15,7 +15,7 @@ def process_ipcmd(str: str):
...
@@ -15,7 +15,7 @@ def process_ipcmd(str: str):
match
=
re
.
search
(
r'^(\
d+): ([^@
\s]+)(?:@\
S+)?: <(
\S+)> mtu (\
d+)
'
match
=
re
.
search
(
r'^(\
d+): ([^@
\s]+)(?:@\
S+)?: <(
\S+)> mtu (\
d+)
'
r'
qdisc
(
\
S
+
)
',
r'
qdisc
(
\
S
+
)
',
line)
line)
if match
!=
None:
if match
is not
None:
cur = match.group(2)
cur = match.group(2)
out[cur] = {
out[cur] = {
'
idx
': int(match.group(1)),
'
idx
': int(match.group(1)),
...
@@ -27,14 +27,14 @@ def process_ipcmd(str: str):
...
@@ -27,14 +27,14 @@ def process_ipcmd(str: str):
out[cur]['
up
'] = '
UP
' in out[cur]['
flags
']
out[cur]['
up
'] = '
UP
' in out[cur]['
flags
']
continue
continue
# Assume cur is defined
# Assume cur is defined
assert cur
!=
None
assert cur
is not
None
match = re.search(r'
^
\
s
+
link
/
\
S
*
(
?
:
([
0
-
9
a
-
f
:]
+
))
?
(
?
:
|
$
)
', line)
match = re.search(r'
^
\
s
+
link
/
\
S
*
(
?
:
([
0
-
9
a
-
f
:]
+
))
?
(
?
:
|
$
)
', line)
if match
!=
None:
if match
is not
None:
out[cur]['
lladdr
'] = match.group(1)
out[cur]['
lladdr
'] = match.group(1)
continue
continue
match = re.search(r'
^
\
s
+
inet
([
0
-
9.
]
+
)
/
(
\
d
+
)(
?
:
brd
([
0
-
9.
]
+
))
?
', line)
match = re.search(r'
^
\
s
+
inet
([
0
-
9.
]
+
)
/
(
\
d
+
)(
?
:
brd
([
0
-
9.
]
+
))
?
', line)
if match
!=
None:
if match
is not
None:
out[cur]['
addr
'].append({
out[cur]['
addr
'].append({
'
address
': match.group(1),
'
address
': match.group(1),
'
prefix_len
': int(match.group(2)),
'
prefix_len
': int(match.group(2)),
...
@@ -43,7 +43,7 @@ def process_ipcmd(str: str):
...
@@ -43,7 +43,7 @@ def process_ipcmd(str: str):
continue
continue
match = re.search(r'
^
\
s
+
inet6
([
0
-
9
a
-
f
:]
+
)
/
(
\
d
+
)(
?
:
|
$
)
', line)
match = re.search(r'
^
\
s
+
inet6
([
0
-
9
a
-
f
:]
+
)
/
(
\
d
+
)(
?
:
|
$
)
', line)
if match
!=
None:
if match
is not
None:
out[cur]['
addr
'].append({
out[cur]['
addr
'].append({
'
address
': match.group(1),
'
address
': match.group(1),
'
prefix_len
': int(match.group(2)),
'
prefix_len
': int(match.group(2)),
...
@@ -51,7 +51,7 @@ def process_ipcmd(str: str):
...
@@ -51,7 +51,7 @@ def process_ipcmd(str: str):
continue
continue
match = re.search(r'
^
\
s
{
4
}
', line)
match = re.search(r'
^
\
s
{
4
}
', line)
assert match
!=
None
assert match
is not
None
return out
return out
def get_devs():
def get_devs():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment