Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
R
re6stnet
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
2
Issues
2
List
Boards
Labels
Milestones
Merge Requests
4
Merge Requests
4
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
nexedi
re6stnet
Commits
829117fd
Commit
829117fd
authored
Jul 17, 2024
by
Tom Niget
3
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
python3: add type annotations
parent
8931044f
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
195 additions
and
148 deletions
+195
-148
demo/demo
demo/demo
+9
-5
draft/re6st-cn
draft/re6st-cn
+1
-1
re6st/cache.py
re6st/cache.py
+16
-14
re6st/cli/conf.py
re6st/cli/conf.py
+1
-1
re6st/cli/node.py
re6st/cli/node.py
+1
-1
re6st/ctl.py
re6st/ctl.py
+11
-6
re6st/debug.py
re6st/debug.py
+6
-6
re6st/plib.py
re6st/plib.py
+12
-7
re6st/registry.py
re6st/registry.py
+52
-36
re6st/tests/test_unit/test_registry.py
re6st/tests/test_unit/test_registry.py
+11
-7
re6st/tests/tools.py
re6st/tests/tools.py
+6
-9
re6st/tunnel.py
re6st/tunnel.py
+23
-15
re6st/upnpigd.py
re6st/upnpigd.py
+2
-2
re6st/utils.py
re6st/utils.py
+15
-13
re6st/x509.py
re6st/x509.py
+29
-25
No files found.
demo/demo
View file @
829117fd
...
...
@@ -4,6 +4,7 @@ import socket, sqlite3, subprocess, sys, time, weakref
from collections import defaultdict
from contextlib import contextmanager
from threading import Thread
from typing import Optional
IPTABLES = 'iptables'
SCREEN = 'screen'
...
...
@@ -242,7 +243,7 @@ gateway1.screen(['miniupnpd', '-d', '-f', 'miniupnpd.conf', '-P',
'miniupnpd.pid', '-a', g1_if_1.name, '-i', g1_if_0_name])
@contextmanager
def new_network(registry
, reg_addr, serial, ca
):
def new_network(registry
: Re6stNode, reg_addr: str, serial: str, ca: str
):
from OpenSSL import crypto
import hashlib, sqlite3
os.path.exists(ca) or subprocess.check_call(
...
...
@@ -272,7 +273,9 @@ def new_network(registry, reg_addr, serial, ca):
time.sleep(.1)
""")).wait()
db = sqlite3.connect(db_path, isolation_level=None)
def new_node(node, folder, args='', prefix_len=None, registry=registry_url):
def new_node(node: Re6stNode, folder: str, args: list[str]=[],
prefix_len: Optional[int] = None, registry=registry_url):
nodes.append(node)
if not os.path.exists(folder + '/cert.crt'):
dh_path = folder + '/dh2048.pem'
...
...
@@ -378,8 +381,9 @@ if args.hmac:
t.start()
del t
_ll = {}
def node_by_ll(addr):
_ll: dict[str, tuple[Re6stNode, bool]] = {}
def node_by_ll(addr: str) -> tuple[Re6stNode, bool]:
try:
return _ll[addr]
except KeyError:
...
...
@@ -410,7 +414,7 @@ def node_by_ll(addr):
def route_svg(ipv4, z=4):
graph = {}
graph
: dict[Re6stNode, dict[tuple[Re6stNode, bool], list[Re6stNode]]]
= {}
for n in nodes:
g = graph[n] = defaultdict(list)
for r in n.get_routes():
...
...
draft/re6st-cn
View file @
829117fd
...
...
@@ -5,7 +5,7 @@ if 're6st' not in sys.modules:
from
re6st
import
utils
,
x509
from
OpenSSL
import
crypto
with
open
(
"/etc/re6stnet/ca.crt"
)
as
f
:
with
open
(
"/etc/re6stnet/ca.crt"
,
"rb"
)
as
f
:
ca
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
f
.
read
())
network
=
x509
.
networkFromCa
(
ca
)
...
...
re6st/cache.py
View file @
829117fd
...
...
@@ -5,7 +5,7 @@ from . import utils, version, x509
class
Cache
:
def
__init__
(
self
,
db_path
,
registry
,
c
ert
,
db_size
=
200
):
def
__init__
(
self
,
db_path
:
str
,
registry
,
cert
:
x509
.
C
ert
,
db_size
=
200
):
self
.
_prefix
=
cert
.
prefix
self
.
_db_size
=
db_size
self
.
_decrypt
=
cert
.
decrypt
...
...
@@ -50,7 +50,7 @@ class Cache:
self
.
warnProtocol
()
logging
.
info
(
"Cache initialized."
)
def
_open
(
self
,
path
)
:
def
_open
(
self
,
path
:
str
)
->
sqlite3
.
Connection
:
db
=
sqlite3
.
connect
(
path
,
isolation_level
=
None
)
db
.
text_factory
=
str
db
.
execute
(
"PRAGMA synchronous = OFF"
)
...
...
@@ -146,7 +146,7 @@ class Cache:
logging
.
warning
(
"There's a new version of re6stnet:"
" you should update."
)
def
getDh
(
self
,
path
):
def
getDh
(
self
,
path
:
str
):
# We'd like to do a full check here but
# from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
...
...
@@ -178,11 +178,11 @@ class Cache:
logging
.
trace
(
"- %s: %s%s"
,
prefix
,
address
,
' (blacklisted)'
if
_try
else
''
)
def
cacheMinimize
(
self
,
size
):
def
cacheMinimize
(
self
,
size
:
int
):
with
self
.
_db
:
self
.
_cacheMinimize
(
size
)
def
_cacheMinimize
(
self
,
size
):
def
_cacheMinimize
(
self
,
size
:
int
):
a
=
self
.
_db
.
execute
(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1"
,
(
size
,)).
fetchall
()
...
...
@@ -191,26 +191,26 @@ class Cache:
q
(
"DELETE FROM peer WHERE prefix IN (?)"
,
a
)
q
(
"DELETE FROM volatile.stat WHERE peer IN (?)"
,
a
)
def
connecting
(
self
,
prefix
,
connecting
):
def
connecting
(
self
,
prefix
:
str
,
connecting
:
bool
):
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=? WHERE peer=?"
,
(
connecting
,
prefix
))
def
resetConnecting
(
self
):
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=0"
)
def
getAddress
(
self
,
prefix
)
:
def
getAddress
(
self
,
prefix
:
str
)
->
bool
:
r
=
self
.
_db
.
execute
(
"SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0"
,
(
prefix
,)).
fetchone
()
return
r
and
r
[
0
]
@
property
def
my_address
(
self
):
def
my_address
(
self
)
->
str
:
for
x
,
in
self
.
_db
.
execute
(
"SELECT address FROM peer WHERE prefix=''"
):
return
x
@
my_address
.
setter
def
my_address
(
self
,
value
):
def
my_address
(
self
,
value
:
str
):
if
value
:
with
self
.
_db
as
db
:
db
.
execute
(
"INSERT OR REPLACE INTO peer VALUES ('', ?)"
,
...
...
@@ -228,13 +228,15 @@ class Cache:
# IOW, one should probably always put our own address there.
_get_peer_sql
=
"SELECT %s FROM peer, volatile.stat"
\
" WHERE prefix=peer AND prefix!=? AND try=?"
def
getPeerList
(
self
,
failed
=
0
,
__sql
=
_get_peer_sql
%
"prefix, address"
def
getPeerList
(
self
,
failed
=
False
,
__sql
=
_get_peer_sql
%
"prefix, address"
+
" ORDER BY RANDOM()"
):
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
))
def
getPeerCount
(
self
,
failed
=
0
,
__sql
=
_get_peer_sql
%
"COUNT(*)"
):
def
getPeerCount
(
self
,
failed
=
False
,
__sql
=
_get_peer_sql
%
"COUNT(*)"
)
\
->
int
:
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
)).
next
()[
0
]
def
getBootstrapPeer
(
self
):
def
getBootstrapPeer
(
self
)
->
tuple
[
str
,
str
]
:
logging
.
info
(
'Getting Boot peer...'
)
try
:
bootpeer
=
self
.
_registry
.
getBootstrapPeer
(
self
.
_prefix
)
...
...
@@ -248,7 +250,7 @@ class Cache:
return
prefix
,
address
logging
.
warning
(
'Buggy registry sent us our own address'
)
def
addPeer
(
self
,
prefix
,
address
,
set_preferred
=
False
):
def
addPeer
(
self
,
prefix
:
str
,
address
:
str
,
set_preferred
=
False
):
logging
.
debug
(
'Adding peer %s: %s'
,
prefix
,
address
)
with
self
.
_db
:
q
=
self
.
_db
.
execute
...
...
@@ -272,7 +274,7 @@ class Cache:
q
(
"INSERT OR REPLACE INTO peer VALUES (?,?)"
,
(
prefix
,
address
))
q
(
"INSERT OR REPLACE INTO volatile.stat VALUES (?,0)"
,
(
prefix
,))
def
getCountry
(
self
,
ip
)
:
def
getCountry
(
self
,
ip
:
str
)
->
str
:
try
:
return
self
.
_registry
.
getCountry
(
self
.
_prefix
,
ip
).
decode
()
except
socket
.
error
as
e
:
...
...
re6st/cli/conf.py
View file @
829117fd
...
...
@@ -13,7 +13,7 @@ def create(path, text=None, mode=0o666):
finally
:
os
.
close
(
fd
)
def
loadCert
(
pem
):
def
loadCert
(
pem
:
bytes
):
return
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
pem
)
def
main
():
...
...
re6st/cli/node.py
View file @
829117fd
...
...
@@ -272,7 +272,7 @@ def main():
call
(
args
)
args
[
3
]
=
'del'
cleanup
.
append
(
lambda
:
subprocess
.
call
(
args
))
def
ip
(
object
,
*
args
):
def
ip
(
object
:
str
,
*
args
):
args
=
[
'ip'
,
'-6'
,
object
,
'add'
]
+
list
(
args
)
call
(
args
)
args
[
3
]
=
'del'
...
...
re6st/ctl.py
View file @
829117fd
...
...
@@ -34,13 +34,13 @@ class Array:
def
__init__
(
self
,
item
):
self
.
_item
=
item
def
encode
(
self
,
buffer
,
value
):
def
encode
(
self
,
buffer
:
bytes
,
value
:
list
):
buffer
+=
uint16
.
pack
(
len
(
value
))
encode
=
self
.
_item
.
encode
for
value
in
value
:
encode
(
buffer
,
value
)
def
decode
(
self
,
buffer
,
offset
=
0
)
:
def
decode
(
self
,
buffer
:
bytes
,
offset
=
0
)
->
tuple
[
int
,
list
]
:
r
=
[]
o
=
offset
+
2
decode
=
self
.
_item
.
decode
...
...
@@ -52,11 +52,11 @@ class Array:
class
String
:
@
staticmethod
def
encode
(
buffer
,
value
):
def
encode
(
buffer
:
bytes
,
value
:
str
):
buffer
+=
value
.
encode
(
"utf-8"
)
+
b'
\
x00
'
@
staticmethod
def
decode
(
buffer
,
offset
=
0
)
:
def
decode
(
buffer
:
bytes
,
offset
=
0
)
->
tuple
[
int
,
str
]
:
i
=
buffer
.
index
(
0
,
offset
)
return
i
+
1
,
buffer
[
offset
:
i
].
decode
(
"utf-8"
)
...
...
@@ -171,7 +171,7 @@ class Babel:
_decode
=
None
def
__init__
(
self
,
socket_path
,
handler
,
network
):
def
__init__
(
self
,
socket_path
:
str
,
handler
,
network
:
str
):
self
.
socket_path
=
socket_path
self
.
handler
=
handler
self
.
network
=
network
...
...
@@ -252,15 +252,18 @@ class Babel:
unidentified
=
set
(
n
)
self
.
neighbours
=
neighbours
=
{}
a
=
len
(
self
.
network
)
logging
.
info
(
"Routes: %r"
,
routes
)
for
route
in
routes
:
assert
route
.
flags
&
1
,
route
# installed
if
route
.
prefix
.
startswith
(
b'
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
xff
\
xff
'
):
logging
.
warning
(
"Ignoring IPv4 route: %r"
,
route
)
continue
assert
route
.
neigh_address
==
route
.
nexthop
,
route
address
=
route
.
neigh_address
,
route
.
ifindex
neigh_routes
=
n
[
address
]
ip
=
utils
.
binFromRawIp
(
route
.
prefix
)
if
ip
[:
a
]
==
self
.
network
:
logging
.
debug
(
"Route is on the network: %r"
,
route
)
prefix
=
ip
[
a
:
route
.
plen
]
if
prefix
and
not
route
.
refmetric
:
neighbours
[
prefix
]
=
neigh_routes
...
...
@@ -275,7 +278,9 @@ class Babel:
socket
.
inet_ntop
(
socket
.
AF_INET6
,
route
.
prefix
),
route
.
plen
)
else
:
logging
.
debug
(
"Route is not on the network: %r"
,
route
)
prefix
=
None
logging
.
debug
(
"Adding route %r to %r"
,
route
,
neigh_routes
)
neigh_routes
[
1
][
prefix
]
=
route
self
.
locked
.
clear
()
if
unidentified
:
...
...
@@ -299,7 +304,7 @@ class iterRoutes:
_waiting
=
True
def
__new__
(
cls
,
control_socket
,
network
):
def
__new__
(
cls
,
control_socket
:
str
,
network
:
str
):
self
=
object
.
__new__
(
cls
)
c
=
Babel
(
control_socket
,
self
,
network
)
c
.
request_dump
()
...
...
re6st/debug.py
View file @
829117fd
...
...
@@ -3,30 +3,30 @@ import errno, os, socket, stat, threading
class
Socket
:
def
__init__
(
self
,
socket
):
def
__init__
(
self
,
socket
:
socket
.
socket
):
# In case that the default timeout is not None.
socket
.
settimeout
(
None
)
self
.
_socket
=
socket
self
.
_buf
=
''
self
.
_buf
=
b
''
def
close
(
self
):
self
.
_socket
.
close
()
def
write
(
self
,
data
):
def
write
(
self
,
data
:
bytes
):
self
.
_socket
.
send
(
data
)
def
readline
(
self
):
def
readline
(
self
)
->
bytes
:
recv
=
self
.
_socket
.
recv
data
=
self
.
_buf
while
True
:
i
=
1
+
data
.
find
(
'
\
n
'
)
i
=
1
+
data
.
find
(
b
'
\
n
'
)
if
i
:
self
.
_buf
=
data
[
i
:]
return
data
[:
i
]
d
=
recv
(
4096
)
data
+=
d
if
not
d
:
self
.
_buf
=
''
self
.
_buf
=
b
''
return
data
def
flush
(
self
):
...
...
re6st/plib.py
View file @
829117fd
import
binascii
import
logging
,
errno
,
os
from
typing
import
Optional
from
.
import
utils
here
=
os
.
path
.
realpath
(
os
.
path
.
dirname
(
__file__
))
ovpn_server
=
os
.
path
.
join
(
here
,
'ovpn-server'
)
ovpn_client
=
os
.
path
.
join
(
here
,
'ovpn-client'
)
ovpn_log
=
None
ovpn_log
:
Optional
[
str
]
=
None
def
openvpn
(
iface
,
encrypt
,
*
args
,
**
kw
)
:
def
openvpn
(
iface
:
str
,
encrypt
,
*
args
,
**
kw
)
->
utils
.
Popen
:
args
=
[
'openvpn'
,
'--dev-type'
,
'tap'
,
'--dev'
,
iface
,
...
...
@@ -28,7 +29,8 @@ def openvpn(iface, encrypt, *args, **kw):
ovpn_link_mtu_dict
=
{
'udp4'
:
1432
,
'udp6'
:
1450
}
def
server
(
iface
,
max_clients
,
dh_path
,
fd
,
port
,
proto
,
encrypt
,
*
args
,
**
kw
):
def
server
(
iface
:
str
,
max_clients
:
int
,
dh_path
:
str
,
fd
:
int
,
port
:
int
,
proto
:
str
,
encrypt
:
bool
,
*
args
,
**
kw
)
->
utils
.
Popen
:
if
proto
==
'udp'
:
proto
=
'udp4'
client_script
=
'%s %s'
%
(
ovpn_server
,
fd
)
...
...
@@ -49,7 +51,8 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
*
args
,
pass_fds
=
[
fd
],
**
kw
)
def
client
(
iface
,
address_list
,
encrypt
,
*
args
,
**
kw
):
def
client
(
iface
:
str
,
address_list
:
list
[
tuple
[
str
,
int
,
str
]],
encrypt
:
bool
,
*
args
,
**
kw
)
->
utils
.
Popen
:
remote
=
[
'--nobind'
,
'--client'
]
# XXX: We'd like to pass <connection> sections at command-line.
link_mtu
=
set
()
...
...
@@ -65,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw):
return
openvpn
(
iface
,
encrypt
,
*
remote
,
**
kw
)
def
router
(
ip
,
ip4
,
rt6
,
hello_interval
,
log_path
,
state_path
,
pidfile
,
control_socket
,
default
,
hmac
,
*
args
,
**
kw
):
def
router
(
ip
:
tuple
[
str
,
int
],
ip4
,
rt6
:
tuple
[
str
,
bool
,
bool
],
hello_interval
:
int
,
log_path
:
str
,
state_path
:
str
,
pidfile
:
str
,
control_socket
:
str
,
default
:
str
,
hmac
:
tuple
[
bytes
|
None
,
bytes
|
None
],
*
args
,
**
kw
)
->
utils
.
Popen
:
network
,
gateway
,
has_ipv6_subtrees
=
rt6
network_mask
=
int
(
network
[
network
.
index
(
'/'
)
+
1
:])
ip
,
n
=
ip
...
...
@@ -83,7 +88,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C'
,
'redistribute local deny'
,
'-C'
,
'redistribute ip %s/%s eq %s'
%
(
ip
,
n
,
n
)]
if
hmac_sign
:
def
key
(
cmd
,
id
,
value
):
def
key
(
cmd
:
list
[
str
],
id
:
str
,
value
:
bytes
):
cmd
+=
'-C'
,
(
'key type blake2s128 id %s value %s'
%
(
id
,
binascii
.
hexlify
(
value
).
decode
()))
key
(
cmd
,
'sign'
,
hmac_sign
)
...
...
re6st/registry.py
View file @
829117fd
This diff is collapsed.
Click to expand it.
re6st/tests/test_unit/test_registry.py
View file @
829117fd
...
...
@@ -11,11 +11,13 @@ import hashlib
import
time
import
tempfile
from
argparse
import
Namespace
from
sqlite3
import
Cursor
from
OpenSSL
import
crypto
from
mock
import
Mock
,
patch
from
pathlib
import
Path
from
re6st
import
registry
from
re6st
import
registry
,
x509
from
re6st.tests.tools
import
*
from
re6st.tests
import
DEMO_PATH
...
...
@@ -23,7 +25,7 @@ from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
def
load_config
(
filename
=
"registry.json"
)
:
def
load_config
(
filename
:
str
=
"registry.json"
)
->
Namespace
:
with
open
(
filename
)
as
f
:
config
=
json
.
load
(
f
)
config
[
"dh"
]
=
DEMO_PATH
/
"dh2048.pem"
...
...
@@ -37,13 +39,14 @@ def load_config(filename="registry.json"):
return
Namespace
(
**
config
)
def
get_cert
(
cur
,
prefix
):
def
get_cert
(
cur
:
Cursor
,
prefix
:
str
):
res
=
cur
.
execute
(
"SELECT cert FROM cert WHERE prefix=?"
,
(
prefix
,)).
fetchone
()
return
res
[
0
]
def
insert_cert
(
cur
,
ca
,
prefix
,
not_after
=
None
,
email
=
None
):
def
insert_cert
(
cur
:
Cursor
,
ca
:
x509
.
Cert
,
prefix
:
str
,
not_after
=
None
,
email
=
None
):
key
,
csr
=
generate_csr
()
cert
=
generate_cert
(
ca
.
ca
,
ca
.
key
,
csr
,
prefix
,
insert_cert
.
serial
,
not_after
)
cur
.
execute
(
"INSERT INTO cert VALUES (?,?,?)"
,
(
prefix
,
email
,
cert
))
...
...
@@ -54,7 +57,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert
.
serial
=
0
def
delete_cert
(
cur
,
prefix
):
def
delete_cert
(
cur
:
Cursor
,
prefix
:
str
):
cur
.
execute
(
"DELETE FROM cert WHERE prefix = ?"
,
(
prefix
,))
...
...
@@ -80,7 +83,7 @@ class TestRegistryServer(unittest.TestCase):
def
setUp
(
self
):
self
.
email
=
''
.
join
(
random
.
sample
(
string
.
ascii_lowercase
,
4
))
\
+
"@mail.com"
+
"@mail.com"
def
test_recv
(
self
):
side_effect
=
iter
([
...
...
@@ -178,7 +181,8 @@ class TestRegistryServer(unittest.TestCase):
request
.
send_header
.
assert_any_call
(
"Content-Length"
,
str
(
len
(
result
)))
request
.
send_header
.
assert_any_call
(
registry
.
HMAC_HEADER
,
base64
.
b64encode
(
hmac
.
HMAC
(
key
,
result
,
hashlib
.
sha1
).
digest
()).
decode
(
"ascii"
))
base64
.
b64encode
(
hmac
.
HMAC
(
key
,
result
,
hashlib
.
sha1
).
digest
())
.
decode
(
"ascii"
))
request
.
wfile
.
write
.
assert_called_once_with
(
result
)
# remove the create session \n
...
...
re6st/tests/tools.py
View file @
829117fd
...
...
@@ -92,18 +92,15 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
return
key
,
cert
def
prefix2cn
(
prefix
)
:
def
prefix2cn
(
prefix
:
str
)
->
str
:
return
"%u/%u"
%
(
int
(
prefix
,
2
),
len
(
prefix
))
def
serial2prefix
(
serial
)
:
def
serial2prefix
(
serial
:
int
)
->
str
:
return
bin
(
serial
)[
2
:].
rjust
(
16
,
'0'
)
# pkey: private key
def
decrypt
(
pkey
,
incontent
)
:
with
open
(
"node.key"
,
'w'
)
as
f
:
f
.
write
(
pkey
.
decode
()
)
def
decrypt
(
pkey
:
bytes
,
incontent
:
bytes
)
->
bytes
:
with
open
(
"node.key"
,
'w
b
'
)
as
f
:
f
.
write
(
pkey
)
args
=
"openssl rsautl -decrypt -inkey node.key"
.
split
()
with
subprocess
.
Popen
(
args
,
stdin
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
as
p
:
outcontent
,
err
=
p
.
communicate
(
incontent
)
return
outcontent
return
subprocess
.
run
(
args
,
input
=
incontent
,
stdout
=
subprocess
.
PIPE
).
stdout
re6st/tunnel.py
View file @
829117fd
...
...
@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
import
subprocess
,
struct
,
sys
,
time
,
weakref
from
collections
import
defaultdict
,
deque
from
bisect
import
bisect
,
insort
from
collections.abc
import
Iterator
,
Sequence
from
typing
import
Callable
,
TYPE_CHECKING
from
OpenSSL
import
crypto
from
.
import
ctl
,
plib
,
utils
,
version
,
x509
if
TYPE_CHECKING
:
from
.
import
cache
PORT
=
326
...
...
@@ -21,7 +26,8 @@ proto_dict = {
proto_dict
[
'tcp'
]
=
proto_dict
[
'tcp4'
]
proto_dict
[
'udp'
]
=
proto_dict
[
'udp4'
]
def
resolve
(
ip
,
port
,
proto
):
def
resolve
(
ip
,
port
,
proto
:
str
)
\
->
tuple
[
socket
.
AddressFamily
|
None
,
Iterator
[
str
]]:
try
:
family
,
proto
=
proto_dict
[
proto
]
except
KeyError
:
...
...
@@ -31,16 +37,16 @@ def resolve(ip, port, proto):
class
MultiGatewayManager
(
dict
):
def
__init__
(
self
,
gateway
):
def
__init__
(
self
,
gateway
:
Callable
[[
str
],
str
]
):
self
.
_gw
=
gateway
def
_route
(
self
,
cmd
,
dest
,
gw
):
def
_route
(
self
,
cmd
:
str
,
dest
:
str
,
gw
:
str
):
if
gw
:
cmd
=
'ip'
,
'-4'
,
'route'
,
cmd
,
'%s/32'
%
dest
,
'via'
,
gw
logging
.
trace
(
'%r'
,
cmd
)
subprocess
.
check_call
(
cmd
)
def
add
(
self
,
dest
,
route
):
def
add
(
self
,
dest
:
str
,
route
:
bool
):
try
:
self
[
dest
][
1
]
+=
1
except
KeyError
:
...
...
@@ -48,7 +54,7 @@ class MultiGatewayManager(dict):
self
[
dest
]
=
[
gw
,
0
]
self
.
_route
(
'add'
,
dest
,
gw
)
def
remove
(
self
,
dest
):
def
remove
(
self
,
dest
:
str
):
gw
,
count
=
self
[
dest
]
if
count
:
self
[
dest
][
1
]
=
count
-
1
...
...
@@ -65,7 +71,7 @@ class Connection:
serial
=
None
time
=
float
(
'inf'
)
def
__init__
(
self
,
tunnel_manager
,
address_list
,
iface
,
prefix
):
def
__init__
(
self
,
tunnel_manager
:
"TunnelManager"
,
address_list
,
iface
,
prefix
):
self
.
tunnel_manager
=
tunnel_manager
self
.
address_list
=
address_list
self
.
iface
=
iface
...
...
@@ -109,7 +115,7 @@ class Connection:
if
i
:
cache
.
addPeer
(
self
.
_prefix
,
','
.
join
(
self
.
address_list
[
i
]),
True
)
else
:
cache
.
connecting
(
self
.
_prefix
,
0
)
cache
.
connecting
(
self
.
_prefix
,
False
)
def
close
(
self
):
try
:
...
...
@@ -198,7 +204,8 @@ class BaseTunnelManager:
_geoiplookup
=
None
_forward
=
None
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
conf_country
,
address
=
()):
def
__init__
(
self
,
control_socket
,
cache
:
"cache.Cache"
,
cert
:
x509
.
Cert
,
conf_country
,
address
=
()):
self
.
cert
=
cert
self
.
_network
=
cert
.
network
self
.
_prefix
=
cert
.
prefix
...
...
@@ -329,7 +336,7 @@ class BaseTunnelManager:
def
_getPeer
(
self
,
prefix
):
return
self
.
_peers
[
bisect
(
self
.
_peers
,
prefix
)
-
1
]
def
sendto
(
self
,
prefix
,
msg
):
def
sendto
(
self
,
prefix
:
str
,
msg
):
to
=
utils
.
ipFromBin
(
self
.
_network
+
prefix
),
PORT
peer
=
self
.
_getPeer
(
prefix
)
if
peer
.
prefix
!=
prefix
:
...
...
@@ -451,7 +458,7 @@ class BaseTunnelManager:
peer
)
def
_processPacket
(
self
,
msg
,
pee
r
=
None
):
def
_processPacket
(
self
,
msg
:
bytes
,
peer
:
x509
.
Peer
|
st
r
=
None
):
c
=
msg
[
0
]
msg
=
msg
[
1
:]
code
=
c
&
0x7f
...
...
@@ -565,12 +572,12 @@ class BaseTunnelManager:
self
.
selectTimeout
(
time
.
time
()
+
1
+
self
.
cache
.
delay_restart
,
self
.
_restart
)
def
handleServerEvent
(
self
,
sock
):
def
handleServerEvent
(
self
,
sock
:
socket
.
socket
):
event
,
args
=
eval
(
sock
.
recv
(
65536
))
logging
.
debug
(
"%s%r"
,
event
,
args
)
r
=
getattr
(
self
,
'_ovpn_'
+
event
.
replace
(
'-'
,
'_'
))(
*
args
)
if
r
is
not
None
:
sock
.
send
(
chr
(
r
))
sock
.
send
(
bytes
([
r
]
))
def
_ovpn_client_connect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
if
serial
in
self
.
cache
.
crl
:
...
...
@@ -582,7 +589,7 @@ class BaseTunnelManager:
self
.
_gateway_manager
.
add
(
trusted_ip
,
False
)
if
prefix
in
self
.
_connection_dict
and
self
.
_prefix
<
prefix
:
self
.
_kill
(
prefix
)
self
.
cache
.
connecting
(
prefix
,
0
)
self
.
cache
.
connecting
(
prefix
,
False
)
return
True
def
_ovpn_client_disconnect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
...
...
@@ -666,7 +673,8 @@ class TunnelManager(BaseTunnelManager):
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
openvpn_args
,
timeout
,
client_count
,
iface_list
,
conf_country
,
address
,
ip_changed
,
remote_gateway
,
disable_proto
,
neighbour_list
=
()):
ip_changed
,
remote_gateway
:
Callable
[[
str
],
str
],
disable_proto
:
Sequence
[
str
],
neighbour_list
=
()):
super
(
TunnelManager
,
self
).
__init__
(
control_socket
,
cache
,
cert
,
conf_country
,
address
)
self
.
ovpn_args
=
openvpn_args
...
...
@@ -878,7 +886,7 @@ class TunnelManager(BaseTunnelManager):
address_list
.
append
((
ip
,
x
[
1
],
x
[
2
]))
continue
address_list
.
append
(
x
[:
3
])
self
.
cache
.
connecting
(
prefix
,
1
)
self
.
cache
.
connecting
(
prefix
,
True
)
if
not
address_list
:
return
False
logging
.
info
(
'Establishing a connection with %u/%u'
,
...
...
re6st/upnpigd.py
View file @
829117fd
...
...
@@ -17,7 +17,7 @@ class Forwarder:
_lcg_n
=
0
@
classmethod
def
_getExternalPort
(
cls
):
def
_getExternalPort
(
cls
)
->
int
:
# Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without
...
...
@@ -35,7 +35,7 @@ class Forwarder:
self
.
_u
.
discoverdelay
=
200
self
.
_rules
=
[]
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
:
str
):
wrapped
=
getattr
(
self
.
_u
,
name
)
def
wrapper
(
*
args
,
**
kw
):
try
:
...
...
re6st/utils.py
View file @
829117fd
import
argparse
,
errno
,
fcntl
,
hashlib
,
logging
,
os
,
select
as
_select
import
shlex
,
signal
,
socket
,
sqlite3
,
struct
,
subprocess
import
sys
,
textwrap
,
threading
,
time
,
traceback
from
collections.abc
import
Iterator
,
Mapping
HMAC_LEN
=
len
(
hashlib
.
sha1
(
b''
).
digest
())
...
...
@@ -40,7 +40,7 @@ class FileHandler(logging.FileHandler):
if
self
.
lock
.
acquire
(
False
):
self
.
release
()
def
setupLog
(
log_level
,
filenam
e
=
None
,
**
kw
):
def
setupLog
(
log_level
:
int
,
filename
:
str
|
Non
e
=
None
,
**
kw
):
if
log_level
and
filename
:
makedirs
(
os
.
path
.
dirname
(
filename
))
handler
=
FileHandler
(
filename
)
...
...
@@ -184,7 +184,7 @@ def setCloexec(fd):
flags
=
fcntl
.
fcntl
(
fd
,
fcntl
.
F_GETFD
)
fcntl
.
fcntl
(
fd
,
fcntl
.
F_SETFD
,
flags
|
fcntl
.
FD_CLOEXEC
)
def
select
(
R
,
W
,
T
):
def
select
(
R
:
Mapping
,
W
:
Mapping
,
T
):
try
:
r
,
w
,
_
=
_select
.
select
(
R
,
W
,
(),
max
(
0
,
min
(
T
)[
0
]
-
time
.
time
())
if
T
else
None
)
...
...
@@ -208,15 +208,15 @@ def makedirs(*args):
if
e
.
errno
!=
errno
.
EEXIST
:
raise
def
binFromIp
(
ip
)
:
def
binFromIp
(
ip
:
str
)
->
str
:
return
binFromRawIp
(
socket
.
inet_pton
(
socket
.
AF_INET6
,
ip
))
def
binFromRawIp
(
ip
)
:
def
binFromRawIp
(
ip
:
bytes
)
->
str
:
ip1
,
ip2
=
struct
.
unpack
(
'>QQ'
,
ip
)
return
bin
(
ip1
)[
2
:].
rjust
(
64
,
'0'
)
+
bin
(
ip2
)[
2
:].
rjust
(
64
,
'0'
)
def
ipFromBin
(
ip
,
suffix
=
''
)
:
def
ipFromBin
(
ip
:
str
,
suffix
=
''
)
->
str
:
suffix_len
=
128
-
len
(
ip
)
if
suffix_len
>
0
:
ip
+=
suffix
.
rjust
(
suffix_len
,
'0'
)
...
...
@@ -225,11 +225,11 @@ def ipFromBin(ip, suffix=''):
return
socket
.
inet_ntop
(
socket
.
AF_INET6
,
struct
.
pack
(
'>QQ'
,
int
(
ip
[:
64
],
2
),
int
(
ip
[
64
:],
2
)))
def
dump_address
(
address
)
:
def
dump_address
(
address
:
str
)
->
str
:
return
';'
.
join
(
map
(
','
.
join
,
address
))
# Yield ip, port, protocol, and country if it is in the address
def
parse_address
(
address_list
)
:
def
parse_address
(
address_list
:
str
)
->
Iterator
[
tuple
[
str
,
str
,
str
,
str
]]
:
for
address
in
address_list
.
split
(
';'
):
try
:
a
=
address
.
split
(
','
)
...
...
@@ -239,16 +239,18 @@ def parse_address(address_list):
logging
.
warning
(
"Failed to parse node address %r (%s)"
,
address
,
e
)
def
binFromSubnet
(
subnet
)
:
def
binFromSubnet
(
subnet
:
str
)
->
str
:
p
,
l
=
subnet
.
split
(
'/'
)
return
bin
(
int
(
p
))[
2
:].
rjust
(
int
(
l
),
'0'
)
def
newHmacSecret
():
def
_
newHmacSecret
():
from
random
import
getrandbits
as
g
pack
=
struct
.
Struct
(
">QQI"
).
pack
assert
len
(
pack
(
0
,
0
,
0
))
==
HMAC_LEN
# A closure is built to avoid rebuilding the `pack` function at each call.
return
lambda
x
=
None
:
pack
(
g
(
64
)
if
x
is
None
else
x
,
g
(
64
),
g
(
32
))
newHmacSecret
=
newHmacSecret
()
newHmacSecret
=
_newHmacSecret
()
# https://github.com/python/mypy/issues/1174
### Integer serialization
# - supports values from 0 to 0x202020202020201f
...
...
@@ -256,7 +258,7 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
def
packInteger
(
i
)
:
def
packInteger
(
i
:
int
)
->
bytes
:
for
n
in
range
(
8
):
x
=
32
<<
8
*
n
if
i
<
x
:
...
...
@@ -264,7 +266,7 @@ def packInteger(i):
i
-=
x
raise
OverflowError
def
unpackInteger
(
x
)
:
def
unpackInteger
(
x
:
bytes
)
->
tuple
[
int
,
int
]
|
None
:
n
=
x
[
0
]
>>
5
try
:
i
,
=
struct
.
unpack
(
"!Q"
,
b'
\
0
'
*
(
7
-
n
)
+
x
[:
n
+
1
])
...
...
re6st/x509.py
View file @
829117fd
# -*- coding: utf-8 -*-
import
calendar
,
hashlib
,
hmac
,
logging
,
os
,
struct
,
subprocess
,
threading
,
time
from
typing
import
Callable
,
Any
from
OpenSSL
import
crypto
from
cryptography.hazmat.primitives
import
hashes
from
cryptography.hazmat.primitives.asymmetric
import
padding
...
...
@@ -9,14 +11,14 @@ from cryptography.x509 import load_pem_x509_certificate
from
.
import
utils
from
.version
import
protocol
def
newHmacSecret
():
def
newHmacSecret
()
->
bytes
:
return
utils
.
newHmacSecret
(
int
(
time
.
time
()
*
1000000
))
def
networkFromCa
(
ca
)
:
def
networkFromCa
(
ca
:
crypto
.
X509
)
->
str
:
# TODO: will be ca.serial_number after migration to cryptography
return
bin
(
ca
.
get_serial_number
())[
3
:]
def
subnetFromCert
(
cert
)
:
def
subnetFromCert
(
cert
:
crypto
.
X509
)
->
str
:
return
cert
.
get_subject
().
CN
def
notBefore
(
cert
:
crypto
.
X509
)
->
int
:
...
...
@@ -27,13 +29,13 @@ def notAfter(cert: crypto.X509) -> int:
return
calendar
.
timegm
(
time
.
strptime
(
cert
.
get_notAfter
().
decode
(),
'%Y%m%d%H%M%SZ'
))
def
openssl
(
*
args
,
fds
=
[])
:
def
openssl
(
*
args
:
str
,
fds
=
[])
->
utils
.
Popen
:
return
utils
.
Popen
((
'openssl'
,)
+
args
,
stdin
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
pass_fds
=
fds
)
def
encrypt
(
cert
,
data
)
:
def
encrypt
(
cert
:
bytes
,
data
:
bytes
)
->
bytes
:
r
,
w
=
os
.
pipe
()
try
:
threading
.
Thread
(
target
=
os
.
write
,
args
=
(
w
,
cert
)).
start
()
...
...
@@ -47,10 +49,12 @@ def encrypt(cert, data):
raise
subprocess
.
CalledProcessError
(
p
.
returncode
,
'openssl'
,
err
)
return
out
def
fingerprint
(
cert
,
alg
=
'sha1'
):
def
fingerprint
(
cert
:
crypto
.
X509
,
alg
=
'sha1'
):
return
hashlib
.
new
(
alg
,
crypto
.
dump_certificate
(
crypto
.
FILETYPE_ASN1
,
cert
))
def
maybe_renew
(
path
,
cert
,
info
,
renew
,
force
=
False
):
def
maybe_renew
(
path
:
str
,
cert
:
crypto
.
X509
,
info
:
str
,
renew
:
Callable
[[],
bytes
],
force
=
False
)
->
tuple
[
crypto
.
X509
,
int
]:
from
.registry
import
RENEW_PERIOD
while
True
:
if
force
:
...
...
@@ -94,7 +98,7 @@ class NewSessionError(Exception):
class
Cert
:
def
__init__
(
self
,
ca
,
key
,
cert
=
None
):
def
__init__
(
self
,
ca
:
str
,
key
:
str
,
cert
:
str
|
None
=
None
):
self
.
ca_path
=
ca
self
.
cert_path
=
cert
self
.
key_path
=
key
...
...
@@ -112,24 +116,24 @@ class Cert:
self
.
cert
=
self
.
loadVerify
(
f
.
read
().
encode
())
@
property
def
prefix
(
self
):
def
prefix
(
self
)
->
str
:
return
utils
.
binFromSubnet
(
subnetFromCert
(
self
.
cert
))
@
property
def
network
(
self
):
def
network
(
self
)
->
str
:
return
networkFromCa
(
self
.
ca
)
@
property
def
subject_serial
(
self
):
def
subject_serial
(
self
)
->
int
:
return
int
(
self
.
cert
.
get_subject
().
serialNumber
)
@
property
def
openvpn_args
(
self
):
def
openvpn_args
(
self
)
->
tuple
[
str
,
...]
:
return
(
'--ca'
,
self
.
ca_path
,
'--cert'
,
self
.
cert_path
,
'--key'
,
self
.
key_path
)
def
maybeRenew
(
self
,
registry
,
crl
):
def
maybeRenew
(
self
,
registry
,
crl
)
->
int
:
self
.
cert
,
next_renew
=
maybe_renew
(
self
.
cert_path
,
self
.
cert
,
"Certificate"
,
lambda
:
registry
.
renewCertificate
(
self
.
prefix
),
self
.
cert
.
get_serial_number
()
in
crl
)
...
...
@@ -165,7 +169,6 @@ class Cert:
return
r
def
verify
(
self
,
sign
:
bytes
,
data
:
bytes
):
assert
isinstance
(
data
,
bytes
)
pub_key
=
self
.
ca_crypto
.
public_key
()
pub_key
.
verify
(
sign
,
...
...
@@ -175,14 +178,13 @@ class Cert:
)
def
sign
(
self
,
data
:
bytes
)
->
bytes
:
assert
isinstance
(
data
,
bytes
)
return
self
.
key_crypto
.
sign
(
data
,
padding
.
PKCS1v15
(),
hashes
.
SHA512
()
)
def
decrypt
(
self
,
data
)
:
def
decrypt
(
self
,
data
:
bytes
)
->
bytes
:
p
=
openssl
(
'rsautl'
,
'-decrypt'
,
'-inkey'
,
self
.
key_path
)
out
,
err
=
p
.
communicate
(
data
)
if
p
.
returncode
:
...
...
@@ -232,8 +234,9 @@ class Peer:
serial
=
None
stop_date
=
float
(
'inf'
)
version
=
b''
cert
:
crypto
.
X509
def
__init__
(
self
,
prefix
):
def
__init__
(
self
,
prefix
:
str
):
self
.
prefix
=
prefix
@
property
...
...
@@ -249,7 +252,7 @@ class Peer:
def
__lt__
(
self
,
other
):
return
self
.
prefix
<
(
other
if
type
(
other
)
is
str
else
other
.
prefix
)
def
hello0
(
self
,
cert
)
:
def
hello0
(
self
,
cert
:
crypto
.
X509
)
->
bytes
:
if
self
.
_hello
<
time
.
time
():
try
:
# Always assume peer is not old, in case it has just upgraded,
...
...
@@ -264,7 +267,7 @@ class Peer:
def
hello0Sent
(
self
):
self
.
_hello
=
time
.
time
()
+
60
def
hello
(
self
,
cert
,
protocol
)
:
def
hello
(
self
,
cert
:
Cert
,
protocol
:
int
)
->
bytes
:
key
=
self
.
_key
=
newHmacSecret
()
h
=
encrypt
(
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
),
key
)
...
...
@@ -274,10 +277,10 @@ class Peer:
return
b''
.
join
((
b'
\
0
\
0
\
0
\
2
'
,
PACKED_PROTOCOL
if
protocol
else
b''
,
h
,
cert
.
sign
(
h
)))
def
_hmac
(
self
,
msg
)
:
def
_hmac
(
self
,
msg
:
bytes
)
->
bytes
:
return
hmac
.
HMAC
(
self
.
_key
,
msg
,
hashlib
.
sha1
).
digest
()
def
newSession
(
self
,
key
,
protocol
):
def
newSession
(
self
,
key
:
bytes
,
protocol
:
int
):
if
key
<=
self
.
_key
:
raise
NewSessionError
(
self
.
_key
,
key
)
self
.
_key
=
key
...
...
@@ -285,12 +288,13 @@ class Peer:
self
.
_last
=
None
self
.
protocol
=
protocol
def
verify
(
self
,
sign
,
data
):
def
verify
(
self
,
sign
:
bytes
,
data
:
bytes
):
crypto
.
verify
(
self
.
cert
,
sign
,
data
,
'sha512'
)
seqno_struct
=
struct
.
Struct
(
"!L"
)
def
decode
(
self
,
msg
,
_unpack
=
seqno_struct
.
unpack
):
def
decode
(
self
,
msg
:
bytes
,
_unpack
=
seqno_struct
.
unpack
)
\
->
tuple
[
int
,
bytes
,
int
|
None
]
|
bytes
:
seqno
,
=
_unpack
(
msg
[:
4
])
if
seqno
<=
2
:
msg
=
msg
[
4
:]
...
...
@@ -304,9 +308,9 @@ class Peer:
if
self
.
_hmac
(
msg
[:
i
])
==
msg
[
i
:]
and
self
.
_i
<
seqno
:
self
.
_last
=
None
self
.
_i
=
seqno
return
msg
[
4
:
i
]
.
decode
()
return
msg
[
4
:
i
]
def
encode
(
self
,
msg
,
_pack
=
seqno_struct
.
pack
)
:
def
encode
(
self
,
msg
:
str
|
bytes
,
_pack
=
seqno_struct
.
pack
)
->
bytes
:
self
.
_j
+=
1
if
type
(
msg
)
is
str
:
msg
=
msg
.
encode
()
...
...
Julien Muchembled
@jm
mentioned in merge request
!46 (merged)
·
Oct 01, 2024
mentioned in merge request
!46 (merged)
mentioned in merge request !46
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment