Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
R
re6stnet
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
2
Issues
2
List
Boards
Labels
Milestones
Merge Requests
4
Merge Requests
4
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
nexedi
re6stnet
Commits
fd5bda0a
Commit
fd5bda0a
authored
Jul 17, 2024
by
Tom Niget
Committed by
Tom Niget
Oct 02, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add some type annotations
parent
191b0781
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
166 additions
and
125 deletions
+166
-125
demo/demo
demo/demo
+9
-5
re6st/cache.py
re6st/cache.py
+16
-14
re6st/cli/conf.py
re6st/cli/conf.py
+1
-1
re6st/cli/node.py
re6st/cli/node.py
+1
-1
re6st/ctl.py
re6st/ctl.py
+6
-6
re6st/debug.py
re6st/debug.py
+3
-3
re6st/plib.py
re6st/plib.py
+12
-7
re6st/registry.py
re6st/registry.py
+39
-27
re6st/tests/test_unit/test_registry.py
re6st/tests/test_unit/test_registry.py
+8
-5
re6st/tests/tools.py
re6st/tests/tools.py
+3
-3
re6st/tunnel.py
re6st/tunnel.py
+23
-14
re6st/upnpigd.py
re6st/upnpigd.py
+2
-2
re6st/utils.py
re6st/utils.py
+15
-13
re6st/x509.py
re6st/x509.py
+28
-24
No files found.
demo/demo
View file @
fd5bda0a
...
@@ -4,6 +4,7 @@ import socket, sqlite3, subprocess, sys, time, weakref
...
@@ -4,6 +4,7 @@ import socket, sqlite3, subprocess, sys, time, weakref
from collections import defaultdict
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager
from threading import Thread
from threading import Thread
from typing import Optional
IPTABLES = 'iptables'
IPTABLES = 'iptables'
SCREEN = 'screen'
SCREEN = 'screen'
...
@@ -242,7 +243,7 @@ gateway1.screen(['miniupnpd', '-d', '-f', 'miniupnpd.conf', '-P',
...
@@ -242,7 +243,7 @@ gateway1.screen(['miniupnpd', '-d', '-f', 'miniupnpd.conf', '-P',
'miniupnpd.pid', '-a', g1_if_1.name, '-i', g1_if_0_name])
'miniupnpd.pid', '-a', g1_if_1.name, '-i', g1_if_0_name])
@contextmanager
@contextmanager
def new_network(registry
, reg_addr, serial, ca
):
def new_network(registry
: Re6stNode, reg_addr: str, serial: str, ca: str
):
from OpenSSL import crypto
from OpenSSL import crypto
import hashlib, sqlite3
import hashlib, sqlite3
os.path.exists(ca) or subprocess.check_call(
os.path.exists(ca) or subprocess.check_call(
...
@@ -272,7 +273,9 @@ def new_network(registry, reg_addr, serial, ca):
...
@@ -272,7 +273,9 @@ def new_network(registry, reg_addr, serial, ca):
time.sleep(.1)
time.sleep(.1)
""")).wait()
""")).wait()
db = sqlite3.connect(db_path, isolation_level=None)
db = sqlite3.connect(db_path, isolation_level=None)
def new_node(node, folder, args='', prefix_len=None, registry=registry_url):
def new_node(node: Re6stNode, folder: str, args: list[str]=[],
prefix_len: Optional[int] = None, registry=registry_url):
nodes.append(node)
nodes.append(node)
if not os.path.exists(folder + '/cert.crt'):
if not os.path.exists(folder + '/cert.crt'):
dh_path = folder + '/dh2048.pem'
dh_path = folder + '/dh2048.pem'
...
@@ -382,8 +385,9 @@ if args.hmac:
...
@@ -382,8 +385,9 @@ if args.hmac:
t.start()
t.start()
del t
del t
_ll = {}
_ll: dict[str, tuple[Re6stNode, bool]] = {}
def node_by_ll(addr):
def node_by_ll(addr: str) -> tuple[Re6stNode, bool]:
try:
try:
return _ll[addr]
return _ll[addr]
except KeyError:
except KeyError:
...
@@ -414,7 +418,7 @@ def node_by_ll(addr):
...
@@ -414,7 +418,7 @@ def node_by_ll(addr):
def route_svg(ipv4, z=4):
def route_svg(ipv4, z=4):
graph = {}
graph
: dict[Re6stNode, dict[tuple[Re6stNode, bool], list[Re6stNode]]]
= {}
for n in nodes:
for n in nodes:
g = graph[n] = defaultdict(list)
g = graph[n] = defaultdict(list)
for r in n.get_routes():
for r in n.get_routes():
...
...
re6st/cache.py
View file @
fd5bda0a
...
@@ -5,7 +5,7 @@ from . import utils, version, x509
...
@@ -5,7 +5,7 @@ from . import utils, version, x509
class
Cache
:
class
Cache
:
def
__init__
(
self
,
db_path
,
registry
,
c
ert
,
db_size
=
200
):
def
__init__
(
self
,
db_path
:
str
,
registry
,
cert
:
x509
.
C
ert
,
db_size
=
200
):
self
.
_prefix
=
cert
.
prefix
self
.
_prefix
=
cert
.
prefix
self
.
_db_size
=
db_size
self
.
_db_size
=
db_size
self
.
_decrypt
=
cert
.
decrypt
self
.
_decrypt
=
cert
.
decrypt
...
@@ -50,7 +50,7 @@ class Cache:
...
@@ -50,7 +50,7 @@ class Cache:
self
.
warnProtocol
()
self
.
warnProtocol
()
logging
.
info
(
"Cache initialized."
)
logging
.
info
(
"Cache initialized."
)
def
_open
(
self
,
path
)
:
def
_open
(
self
,
path
:
str
)
->
sqlite3
.
Connection
:
db
=
sqlite3
.
connect
(
path
,
isolation_level
=
None
)
db
=
sqlite3
.
connect
(
path
,
isolation_level
=
None
)
db
.
text_factory
=
str
db
.
text_factory
=
str
db
.
execute
(
"PRAGMA synchronous = OFF"
)
db
.
execute
(
"PRAGMA synchronous = OFF"
)
...
@@ -141,7 +141,7 @@ class Cache:
...
@@ -141,7 +141,7 @@ class Cache:
logging
.
warning
(
"There's a new version of re6stnet:"
logging
.
warning
(
"There's a new version of re6stnet:"
" you should update."
)
" you should update."
)
def
getDh
(
self
,
path
):
def
getDh
(
self
,
path
:
str
):
# We'd like to do a full check here but
# We'd like to do a full check here but
# from OpenSSL import SSL
# from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
...
@@ -173,11 +173,11 @@ class Cache:
...
@@ -173,11 +173,11 @@ class Cache:
logging
.
trace
(
"- %s: %s%s"
,
prefix
,
address
,
logging
.
trace
(
"- %s: %s%s"
,
prefix
,
address
,
' (blacklisted)'
if
_try
else
''
)
' (blacklisted)'
if
_try
else
''
)
def
cacheMinimize
(
self
,
size
):
def
cacheMinimize
(
self
,
size
:
int
):
with
self
.
_db
:
with
self
.
_db
:
self
.
_cacheMinimize
(
size
)
self
.
_cacheMinimize
(
size
)
def
_cacheMinimize
(
self
,
size
):
def
_cacheMinimize
(
self
,
size
:
int
):
a
=
self
.
_db
.
execute
(
a
=
self
.
_db
.
execute
(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1"
,
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1"
,
(
size
,)).
fetchall
()
(
size
,)).
fetchall
()
...
@@ -186,26 +186,26 @@ class Cache:
...
@@ -186,26 +186,26 @@ class Cache:
q
(
"DELETE FROM peer WHERE prefix IN (?)"
,
a
)
q
(
"DELETE FROM peer WHERE prefix IN (?)"
,
a
)
q
(
"DELETE FROM volatile.stat WHERE peer IN (?)"
,
a
)
q
(
"DELETE FROM volatile.stat WHERE peer IN (?)"
,
a
)
def
connecting
(
self
,
prefix
,
connecting
):
def
connecting
(
self
,
prefix
:
str
,
connecting
:
bool
):
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=? WHERE peer=?"
,
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=? WHERE peer=?"
,
(
connecting
,
prefix
))
(
connecting
,
prefix
))
def
resetConnecting
(
self
):
def
resetConnecting
(
self
):
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=0"
)
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=0"
)
def
getAddress
(
self
,
prefix
)
:
def
getAddress
(
self
,
prefix
:
str
)
->
bool
:
r
=
self
.
_db
.
execute
(
"SELECT address FROM peer, volatile.stat"
r
=
self
.
_db
.
execute
(
"SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0"
,
" WHERE prefix=? AND prefix=peer AND try=0"
,
(
prefix
,)).
fetchone
()
(
prefix
,)).
fetchone
()
return
r
and
r
[
0
]
return
r
and
r
[
0
]
@
property
@
property
def
my_address
(
self
):
def
my_address
(
self
)
->
str
:
for
x
,
in
self
.
_db
.
execute
(
"SELECT address FROM peer WHERE prefix=''"
):
for
x
,
in
self
.
_db
.
execute
(
"SELECT address FROM peer WHERE prefix=''"
):
return
x
return
x
@
my_address
.
setter
@
my_address
.
setter
def
my_address
(
self
,
value
):
def
my_address
(
self
,
value
:
str
):
if
value
:
if
value
:
with
self
.
_db
as
db
:
with
self
.
_db
as
db
:
db
.
execute
(
"INSERT OR REPLACE INTO peer VALUES ('', ?)"
,
db
.
execute
(
"INSERT OR REPLACE INTO peer VALUES ('', ?)"
,
...
@@ -223,13 +223,15 @@ class Cache:
...
@@ -223,13 +223,15 @@ class Cache:
# IOW, one should probably always put our own address there.
# IOW, one should probably always put our own address there.
_get_peer_sql
=
"SELECT %s FROM peer, volatile.stat"
\
_get_peer_sql
=
"SELECT %s FROM peer, volatile.stat"
\
" WHERE prefix=peer AND prefix!=? AND try=?"
" WHERE prefix=peer AND prefix!=? AND try=?"
def
getPeerList
(
self
,
failed
=
0
,
__sql
=
_get_peer_sql
%
"prefix, address"
def
getPeerList
(
self
,
failed
=
False
,
__sql
=
_get_peer_sql
%
"prefix, address"
+
" ORDER BY RANDOM()"
):
+
" ORDER BY RANDOM()"
):
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
))
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
))
def
getPeerCount
(
self
,
failed
=
0
,
__sql
=
_get_peer_sql
%
"COUNT(*)"
):
def
getPeerCount
(
self
,
failed
=
False
,
__sql
=
_get_peer_sql
%
"COUNT(*)"
)
\
->
int
:
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
)).
next
()[
0
]
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
)).
next
()[
0
]
def
getBootstrapPeer
(
self
):
def
getBootstrapPeer
(
self
)
->
tuple
[
str
,
str
]
:
logging
.
info
(
'Getting Boot peer...'
)
logging
.
info
(
'Getting Boot peer...'
)
try
:
try
:
bootpeer
=
self
.
_registry
.
getBootstrapPeer
(
self
.
_prefix
)
bootpeer
=
self
.
_registry
.
getBootstrapPeer
(
self
.
_prefix
)
...
@@ -243,7 +245,7 @@ class Cache:
...
@@ -243,7 +245,7 @@ class Cache:
return
prefix
,
address
return
prefix
,
address
logging
.
warning
(
'Buggy registry sent us our own address'
)
logging
.
warning
(
'Buggy registry sent us our own address'
)
def
addPeer
(
self
,
prefix
,
address
,
set_preferred
=
False
):
def
addPeer
(
self
,
prefix
:
str
,
address
:
str
,
set_preferred
=
False
):
logging
.
debug
(
'Adding peer %s: %s'
,
prefix
,
address
)
logging
.
debug
(
'Adding peer %s: %s'
,
prefix
,
address
)
with
self
.
_db
:
with
self
.
_db
:
q
=
self
.
_db
.
execute
q
=
self
.
_db
.
execute
...
@@ -267,7 +269,7 @@ class Cache:
...
@@ -267,7 +269,7 @@ class Cache:
q
(
"INSERT OR REPLACE INTO peer VALUES (?,?)"
,
(
prefix
,
address
))
q
(
"INSERT OR REPLACE INTO peer VALUES (?,?)"
,
(
prefix
,
address
))
q
(
"INSERT OR REPLACE INTO volatile.stat VALUES (?,0)"
,
(
prefix
,))
q
(
"INSERT OR REPLACE INTO volatile.stat VALUES (?,0)"
,
(
prefix
,))
def
getCountry
(
self
,
ip
)
:
def
getCountry
(
self
,
ip
:
str
)
->
str
:
try
:
try
:
return
self
.
_registry
.
getCountry
(
self
.
_prefix
,
ip
).
decode
()
return
self
.
_registry
.
getCountry
(
self
.
_prefix
,
ip
).
decode
()
except
socket
.
error
as
e
:
except
socket
.
error
as
e
:
...
...
re6st/cli/conf.py
View file @
fd5bda0a
...
@@ -13,7 +13,7 @@ def create(path, text=None, mode=0o666):
...
@@ -13,7 +13,7 @@ def create(path, text=None, mode=0o666):
finally
:
finally
:
os
.
close
(
fd
)
os
.
close
(
fd
)
def
loadCert
(
pem
):
def
loadCert
(
pem
:
bytes
):
return
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
pem
)
return
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
pem
)
def
main
():
def
main
():
...
...
re6st/cli/node.py
View file @
fd5bda0a
...
@@ -272,7 +272,7 @@ def main():
...
@@ -272,7 +272,7 @@ def main():
call
(
args
)
call
(
args
)
args
[
3
]
=
'del'
args
[
3
]
=
'del'
cleanup
.
append
(
lambda
:
subprocess
.
call
(
args
))
cleanup
.
append
(
lambda
:
subprocess
.
call
(
args
))
def
ip
(
object
,
*
args
):
def
ip
(
object
:
str
,
*
args
):
args
=
[
'ip'
,
'-6'
,
object
,
'add'
]
+
list
(
args
)
args
=
[
'ip'
,
'-6'
,
object
,
'add'
]
+
list
(
args
)
call
(
args
)
call
(
args
)
args
[
3
]
=
'del'
args
[
3
]
=
'del'
...
...
re6st/ctl.py
View file @
fd5bda0a
...
@@ -34,13 +34,13 @@ class Array:
...
@@ -34,13 +34,13 @@ class Array:
def
__init__
(
self
,
item
):
def
__init__
(
self
,
item
):
self
.
_item
=
item
self
.
_item
=
item
def
encode
(
self
,
buffer
,
value
):
def
encode
(
self
,
buffer
:
bytes
,
value
:
list
):
buffer
+=
uint16
.
pack
(
len
(
value
))
buffer
+=
uint16
.
pack
(
len
(
value
))
encode
=
self
.
_item
.
encode
encode
=
self
.
_item
.
encode
for
value
in
value
:
for
value
in
value
:
encode
(
buffer
,
value
)
encode
(
buffer
,
value
)
def
decode
(
self
,
buffer
,
offset
=
0
)
:
def
decode
(
self
,
buffer
:
bytes
,
offset
=
0
)
->
tuple
[
int
,
list
]
:
r
=
[]
r
=
[]
o
=
offset
+
2
o
=
offset
+
2
decode
=
self
.
_item
.
decode
decode
=
self
.
_item
.
decode
...
@@ -52,11 +52,11 @@ class Array:
...
@@ -52,11 +52,11 @@ class Array:
class
String
:
class
String
:
@
staticmethod
@
staticmethod
def
encode
(
buffer
,
value
):
def
encode
(
buffer
:
bytes
,
value
:
str
):
buffer
+=
value
.
encode
(
"utf-8"
)
+
b'
\
0
'
buffer
+=
value
.
encode
(
"utf-8"
)
+
b'
\
0
'
@
staticmethod
@
staticmethod
def
decode
(
buffer
,
offset
=
0
)
:
def
decode
(
buffer
:
bytes
,
offset
=
0
)
->
tuple
[
int
,
str
]
:
i
=
buffer
.
index
(
0
,
offset
)
i
=
buffer
.
index
(
0
,
offset
)
return
i
+
1
,
buffer
[
offset
:
i
].
decode
(
"utf-8"
)
return
i
+
1
,
buffer
[
offset
:
i
].
decode
(
"utf-8"
)
...
@@ -171,7 +171,7 @@ class Babel:
...
@@ -171,7 +171,7 @@ class Babel:
_decode
=
None
_decode
=
None
def
__init__
(
self
,
socket_path
,
handler
,
network
):
def
__init__
(
self
,
socket_path
:
str
,
handler
,
network
:
str
):
self
.
socket_path
=
socket_path
self
.
socket_path
=
socket_path
self
.
handler
=
handler
self
.
handler
=
handler
self
.
network
=
network
self
.
network
=
network
...
@@ -304,7 +304,7 @@ class iterRoutes:
...
@@ -304,7 +304,7 @@ class iterRoutes:
_waiting
=
True
_waiting
=
True
def
__new__
(
cls
,
control_socket
,
network
):
def
__new__
(
cls
,
control_socket
:
str
,
network
:
str
):
self
=
object
.
__new__
(
cls
)
self
=
object
.
__new__
(
cls
)
c
=
Babel
(
control_socket
,
self
,
network
)
c
=
Babel
(
control_socket
,
self
,
network
)
c
.
request_dump
()
c
.
request_dump
()
...
...
re6st/debug.py
View file @
fd5bda0a
...
@@ -3,7 +3,7 @@ import errno, os, socket, stat, threading
...
@@ -3,7 +3,7 @@ import errno, os, socket, stat, threading
class
Socket
:
class
Socket
:
def
__init__
(
self
,
socket
):
def
__init__
(
self
,
socket
:
socket
.
socket
):
# In case that the default timeout is not None.
# In case that the default timeout is not None.
socket
.
settimeout
(
None
)
socket
.
settimeout
(
None
)
self
.
_socket
=
socket
self
.
_socket
=
socket
...
@@ -12,10 +12,10 @@ class Socket:
...
@@ -12,10 +12,10 @@ class Socket:
def
close
(
self
):
def
close
(
self
):
self
.
_socket
.
close
()
self
.
_socket
.
close
()
def
write
(
self
,
data
):
def
write
(
self
,
data
:
bytes
):
self
.
_socket
.
send
(
data
)
self
.
_socket
.
send
(
data
)
def
readline
(
self
):
def
readline
(
self
)
->
bytes
:
recv
=
self
.
_socket
.
recv
recv
=
self
.
_socket
.
recv
data
=
self
.
_buf
data
=
self
.
_buf
while
True
:
while
True
:
...
...
re6st/plib.py
View file @
fd5bda0a
import
binascii
import
binascii
import
logging
,
errno
,
os
import
logging
,
errno
,
os
from
typing
import
Optional
from
.
import
utils
from
.
import
utils
here
=
os
.
path
.
realpath
(
os
.
path
.
dirname
(
__file__
))
here
=
os
.
path
.
realpath
(
os
.
path
.
dirname
(
__file__
))
ovpn_server
=
os
.
path
.
join
(
here
,
'ovpn-server'
)
ovpn_server
=
os
.
path
.
join
(
here
,
'ovpn-server'
)
ovpn_client
=
os
.
path
.
join
(
here
,
'ovpn-client'
)
ovpn_client
=
os
.
path
.
join
(
here
,
'ovpn-client'
)
ovpn_log
=
None
ovpn_log
:
Optional
[
str
]
=
None
def
openvpn
(
iface
,
encrypt
,
*
args
,
**
kw
)
:
def
openvpn
(
iface
:
str
,
encrypt
,
*
args
,
**
kw
)
->
utils
.
Popen
:
args
=
[
'openvpn'
,
args
=
[
'openvpn'
,
'--dev-type'
,
'tap'
,
'--dev-type'
,
'tap'
,
'--dev'
,
iface
,
'--dev'
,
iface
,
...
@@ -28,7 +29,8 @@ def openvpn(iface, encrypt, *args, **kw):
...
@@ -28,7 +29,8 @@ def openvpn(iface, encrypt, *args, **kw):
ovpn_link_mtu_dict
=
{
'udp4'
:
1432
,
'udp6'
:
1450
}
ovpn_link_mtu_dict
=
{
'udp4'
:
1432
,
'udp6'
:
1450
}
def
server
(
iface
,
max_clients
,
dh_path
,
fd
,
port
,
proto
,
encrypt
,
*
args
,
**
kw
):
def
server
(
iface
:
str
,
max_clients
:
int
,
dh_path
:
str
,
fd
:
int
,
port
:
int
,
proto
:
str
,
encrypt
:
bool
,
*
args
,
**
kw
)
->
utils
.
Popen
:
if
proto
==
'udp'
:
if
proto
==
'udp'
:
proto
=
'udp4'
proto
=
'udp4'
client_script
=
'%s %s'
%
(
ovpn_server
,
fd
)
client_script
=
'%s %s'
%
(
ovpn_server
,
fd
)
...
@@ -49,7 +51,8 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
...
@@ -49,7 +51,8 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
*
args
,
pass_fds
=
[
fd
],
**
kw
)
*
args
,
pass_fds
=
[
fd
],
**
kw
)
def
client
(
iface
,
address_list
,
encrypt
,
*
args
,
**
kw
):
def
client
(
iface
:
str
,
address_list
:
list
[
tuple
[
str
,
int
,
str
]],
encrypt
:
bool
,
*
args
,
**
kw
)
->
utils
.
Popen
:
remote
=
[
'--nobind'
,
'--client'
]
remote
=
[
'--nobind'
,
'--client'
]
# XXX: We'd like to pass <connection> sections at command-line.
# XXX: We'd like to pass <connection> sections at command-line.
link_mtu
=
set
()
link_mtu
=
set
()
...
@@ -65,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw):
...
@@ -65,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw):
return
openvpn
(
iface
,
encrypt
,
*
remote
,
**
kw
)
return
openvpn
(
iface
,
encrypt
,
*
remote
,
**
kw
)
def
router
(
ip
,
ip4
,
rt6
,
hello_interval
,
log_path
,
state_path
,
pidfile
,
def
router
(
ip
:
tuple
[
str
,
int
],
ip4
,
rt6
:
tuple
[
str
,
bool
,
bool
],
control_socket
,
default
,
hmac
,
*
args
,
**
kw
):
hello_interval
:
int
,
log_path
:
str
,
state_path
:
str
,
pidfile
:
str
,
control_socket
:
str
,
default
:
str
,
hmac
:
tuple
[
bytes
|
None
,
bytes
|
None
],
*
args
,
**
kw
)
->
utils
.
Popen
:
network
,
gateway
,
has_ipv6_subtrees
=
rt6
network
,
gateway
,
has_ipv6_subtrees
=
rt6
network_mask
=
int
(
network
[
network
.
index
(
'/'
)
+
1
:])
network_mask
=
int
(
network
[
network
.
index
(
'/'
)
+
1
:])
ip
,
n
=
ip
ip
,
n
=
ip
...
@@ -83,7 +88,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
...
@@ -83,7 +88,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C'
,
'redistribute local deny'
,
'-C'
,
'redistribute local deny'
,
'-C'
,
'redistribute ip %s/%s eq %s'
%
(
ip
,
n
,
n
)]
'-C'
,
'redistribute ip %s/%s eq %s'
%
(
ip
,
n
,
n
)]
if
hmac_sign
:
if
hmac_sign
:
def
key
(
cmd
,
id
,
value
):
def
key
(
cmd
:
list
[
str
],
id
:
str
,
value
:
bytes
):
cmd
+=
'-C'
,
(
'key type blake2s128 id %s value %s'
%
cmd
+=
'-C'
,
(
'key type blake2s128 id %s value %s'
%
(
id
,
binascii
.
hexlify
(
value
).
decode
()))
(
id
,
binascii
.
hexlify
(
value
).
decode
()))
key
(
cmd
,
'sign'
,
hmac_sign
)
key
(
cmd
,
'sign'
,
hmac_sign
)
...
...
re6st/registry.py
View file @
fd5bda0a
...
@@ -22,10 +22,13 @@ import base64, hmac, hashlib, http.client, inspect, json, logging
...
@@ -22,10 +22,13 @@ import base64, hmac, hashlib, http.client, inspect, json, logging
import
mailbox
,
os
,
platform
,
random
,
select
,
smtplib
,
socket
,
sqlite3
import
mailbox
,
os
,
platform
,
random
,
select
,
smtplib
,
socket
,
sqlite3
import
string
,
sys
,
threading
,
time
,
weakref
,
zlib
import
string
,
sys
,
threading
,
time
,
weakref
,
zlib
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Iterator
from
datetime
import
datetime
from
datetime
import
datetime
from
http.server
import
HTTPServer
,
BaseHTTPRequestHandler
from
http.server
import
HTTPServer
,
BaseHTTPRequestHandler
from
email.mime.text
import
MIMEText
from
email.mime.text
import
MIMEText
from
operator
import
itemgetter
from
operator
import
itemgetter
from
typing
import
Tuple
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
urllib.parse
import
urlparse
,
unquote
,
urlencode
from
urllib.parse
import
urlparse
,
unquote
,
urlencode
from
.
import
ctl
,
tunnel
,
utils
,
version
,
x509
from
.
import
ctl
,
tunnel
,
utils
,
version
,
x509
...
@@ -58,6 +61,8 @@ class RegistryServer:
...
@@ -58,6 +61,8 @@ class RegistryServer:
peers
=
0
,
()
peers
=
0
,
()
cert_duration
=
365
*
86400
cert_duration
=
365
*
86400
sessions
:
dict
[
str
,
list
[
tuple
[
bytes
,
int
]]]
def
_geoiplookup
(
self
,
ip
):
def
_geoiplookup
(
self
,
ip
):
raise
HTTPError
(
http
.
client
.
BAD_REQUEST
)
raise
HTTPError
(
http
.
client
.
BAD_REQUEST
)
...
@@ -140,7 +145,7 @@ class RegistryServer:
...
@@ -140,7 +145,7 @@ class RegistryServer:
if
self
.
geoip_db
:
if
self
.
geoip_db
:
from
geoip2
import
database
,
errors
from
geoip2
import
database
,
errors
country
=
database
.
Reader
(
self
.
geoip_db
).
country
country
=
database
.
Reader
(
self
.
geoip_db
).
country
def
geoiplookup
(
ip
)
:
def
geoiplookup
(
ip
:
str
)
->
Tuple
[
str
,
str
]
:
try
:
try
:
req
=
country
(
ip
)
req
=
country
(
ip
)
return
req
.
country
.
iso_code
,
req
.
continent
.
code
return
req
.
country
.
iso_code
,
req
.
continent
.
code
...
@@ -206,7 +211,7 @@ class RegistryServer:
...
@@ -206,7 +211,7 @@ class RegistryServer:
self
.
sock
.
sendto
(
prefix
.
encode
()
+
bytes
((
0
,
code
)),
self
.
sock
.
sendto
(
prefix
.
encode
()
+
bytes
((
0
,
code
)),
(
'::1'
,
tunnel
.
PORT
))
(
'::1'
,
tunnel
.
PORT
))
def
recv
(
self
,
code
)
:
def
recv
(
self
,
code
:
int
)
->
tuple
[
str
,
str
]
|
tuple
[
None
,
None
]
:
try
:
try
:
prefix
,
msg
=
self
.
sock
.
recv
(
1
<<
16
).
split
(
b'
\
0
'
,
1
)
prefix
,
msg
=
self
.
sock
.
recv
(
1
<<
16
).
split
(
b'
\
0
'
,
1
)
int
(
prefix
,
2
)
int
(
prefix
,
2
)
...
@@ -241,7 +246,7 @@ class RegistryServer:
...
@@ -241,7 +246,7 @@ class RegistryServer:
def
babel_dump
(
self
):
def
babel_dump
(
self
):
self
.
_wait_dump
=
False
self
.
_wait_dump
=
False
def
iterCert
(
self
):
def
iterCert
(
self
)
->
Iterator
[
Tuple
[
crypto
.
X509
,
str
,
str
]]
:
for
prefix
,
email
,
cert
in
self
.
db
.
execute
(
for
prefix
,
email
,
cert
in
self
.
db
.
execute
(
"SELECT * FROM cert WHERE cert IS NOT NULL"
):
"SELECT * FROM cert WHERE cert IS NOT NULL"
):
try
:
try
:
...
@@ -341,12 +346,12 @@ class RegistryServer:
...
@@ -341,12 +346,12 @@ class RegistryServer:
if
result
:
if
result
:
request
.
wfile
.
write
(
result
)
request
.
wfile
.
write
(
result
)
def
getPeerProtocol
(
self
,
cn
)
:
def
getPeerProtocol
(
self
,
cn
:
str
)
->
int
:
session
,
=
self
.
sessions
[
cn
]
session
,
=
self
.
sessions
[
cn
]
return
session
[
1
]
return
session
[
1
]
@
rpc
@
rpc
def
hello
(
self
,
client_prefix
,
protocol
=
'1'
)
:
def
hello
(
self
,
client_prefix
:
str
,
protocol
=
'1'
)
->
bytes
:
with
self
.
lock
:
with
self
.
lock
:
cert
=
self
.
getCert
(
client_prefix
)
cert
=
self
.
getCert
(
client_prefix
)
key
=
utils
.
newHmacSecret
()
key
=
utils
.
newHmacSecret
()
...
@@ -356,7 +361,7 @@ class RegistryServer:
...
@@ -356,7 +361,7 @@ class RegistryServer:
assert
len
(
key
)
==
len
(
sign
)
assert
len
(
key
)
==
len
(
sign
)
return
key
+
sign
return
key
+
sign
def
getCert
(
self
,
client_prefix
)
:
def
getCert
(
self
,
client_prefix
:
str
)
->
bytes
:
assert
self
.
lock
.
locked
()
assert
self
.
lock
.
locked
()
cert
=
self
.
db
.
execute
(
"SELECT cert FROM cert"
cert
=
self
.
db
.
execute
(
"SELECT cert FROM cert"
" WHERE prefix=? AND cert IS NOT NULL"
,
" WHERE prefix=? AND cert IS NOT NULL"
,
...
@@ -366,19 +371,19 @@ class RegistryServer:
...
@@ -366,19 +371,19 @@ class RegistryServer:
return
cert
[
0
]
return
cert
[
0
]
@
rpc_private
@
rpc_private
def
isToken
(
self
,
token
):
def
isToken
(
self
,
token
:
str
):
with
self
.
lock
:
with
self
.
lock
:
if
self
.
db
.
execute
(
"SELECT 1 FROM token WHERE token = ?"
,
if
self
.
db
.
execute
(
"SELECT 1 FROM token WHERE token = ?"
,
(
token
,)).
fetchone
():
(
token
,)).
fetchone
():
return
b"1"
return
b"1"
@
rpc_private
@
rpc_private
def
deleteToken
(
self
,
token
):
def
deleteToken
(
self
,
token
:
str
):
with
self
.
lock
:
with
self
.
lock
:
self
.
db
.
execute
(
"DELETE FROM token WHERE token = ?"
,
(
token
,))
self
.
db
.
execute
(
"DELETE FROM token WHERE token = ?"
,
(
token
,))
@
rpc_private
@
rpc_private
def
addToken
(
self
,
email
,
token
)
:
def
addToken
(
self
,
email
:
str
,
token
:
str
|
None
)
->
str
:
prefix_len
=
self
.
config
.
prefix_length
prefix_len
=
self
.
config
.
prefix_length
if
not
prefix_len
:
if
not
prefix_len
:
raise
HTTPError
(
http
.
client
.
FORBIDDEN
)
raise
HTTPError
(
http
.
client
.
FORBIDDEN
)
...
@@ -506,7 +511,8 @@ class RegistryServer:
...
@@ -506,7 +511,8 @@ class RegistryServer:
q
(
"UPDATE cert SET cert = 'reserved' WHERE prefix = ?"
,
(
prefix
,))
q
(
"UPDATE cert SET cert = 'reserved' WHERE prefix = ?"
,
(
prefix
,))
@
rpc
@
rpc
def
requestCertificate
(
self
,
token
,
req
,
location
=
''
,
ip
=
''
):
def
requestCertificate
(
self
,
token
:
str
|
None
,
req
:
bytes
,
location
:
str
=
''
,
ip
:
str
=
''
):
logging
.
debug
(
"Requesting certificate with token %s"
,
token
)
logging
.
debug
(
"Requesting certificate with token %s"
,
token
)
req
=
crypto
.
load_certificate_request
(
crypto
.
FILETYPE_PEM
,
req
)
req
=
crypto
.
load_certificate_request
(
crypto
.
FILETYPE_PEM
,
req
)
with
self
.
lock
:
with
self
.
lock
:
...
@@ -580,7 +586,7 @@ class RegistryServer:
...
@@ -580,7 +586,7 @@ class RegistryServer:
return
cert
return
cert
@
rpc
@
rpc
def
renewCertificate
(
self
,
cn
)
:
def
renewCertificate
(
self
,
cn
:
str
)
->
bytes
:
with
self
.
lock
:
with
self
.
lock
:
with
self
.
db
as
db
:
with
self
.
db
as
db
:
pem
=
self
.
getCert
(
cn
)
pem
=
self
.
getCert
(
cn
)
...
@@ -596,16 +602,16 @@ class RegistryServer:
...
@@ -596,16 +602,16 @@ class RegistryServer:
cert
.
get_subject
(),
cert
.
get_pubkey
(),
not_after
)
cert
.
get_subject
(),
cert
.
get_pubkey
(),
not_after
)
@
rpc
@
rpc
def
getCa
(
self
):
def
getCa
(
self
)
->
bytes
:
return
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
.
ca
)
return
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
.
ca
)
@
rpc
@
rpc
def
getDh
(
self
,
cn
)
:
def
getDh
(
self
,
cn
:
str
)
->
bytes
:
with
open
(
self
.
config
.
dh
,
"rb"
)
as
f
:
with
open
(
self
.
config
.
dh
,
"rb"
)
as
f
:
return
f
.
read
()
return
f
.
read
()
@
rpc
@
rpc
def
getNetworkConfig
(
self
,
cn
)
:
def
getNetworkConfig
(
self
,
cn
:
str
)
->
bytes
:
with
self
.
lock
:
with
self
.
lock
:
cert
=
self
.
getCert
(
cn
)
cert
=
self
.
getCert
(
cn
)
config
=
self
.
network_config
.
copy
()
config
=
self
.
network_config
.
copy
()
...
@@ -615,7 +621,7 @@ class RegistryServer:
...
@@ -615,7 +621,7 @@ class RegistryServer:
v
and
base64
.
b64encode
(
x509
.
encrypt
(
cert
,
v
)).
decode
()
v
and
base64
.
b64encode
(
x509
.
encrypt
(
cert
,
v
)).
decode
()
return
zlib
.
compress
(
json
.
dumps
(
config
).
encode
(
"utf-8"
))
return
zlib
.
compress
(
json
.
dumps
(
config
).
encode
(
"utf-8"
))
def
_queryAddress
(
self
,
peer
)
:
def
_queryAddress
(
self
,
peer
:
str
)
->
str
:
logging
.
info
(
"Querying address for %s/%s %r"
,
logging
.
info
(
"Querying address for %s/%s %r"
,
int
(
peer
,
2
),
len
(
peer
),
peer
)
int
(
peer
,
2
),
len
(
peer
),
peer
)
self
.
sendto
(
peer
,
1
)
self
.
sendto
(
peer
,
1
)
...
@@ -633,12 +639,12 @@ class RegistryServer:
...
@@ -633,12 +639,12 @@ class RegistryServer:
int
(
peer
,
2
),
len
(
peer
))
int
(
peer
,
2
),
len
(
peer
))
@
rpc
@
rpc
def
getCountry
(
self
,
cn
,
address
)
:
def
getCountry
(
self
,
cn
:
str
,
address
:
str
)
->
str
|
None
:
country
=
self
.
_geoiplookup
(
address
)[
0
]
country
=
self
.
_geoiplookup
(
address
)[
0
]
return
None
if
country
==
'*'
else
country
return
None
if
country
==
'*'
else
country
@
rpc
@
rpc
def
getBootstrapPeer
(
self
,
cn
)
:
def
getBootstrapPeer
(
self
,
cn
:
str
)
->
bytes
|
None
:
logging
.
info
(
"Answering bootstrap peer for %s"
,
cn
)
logging
.
info
(
"Answering bootstrap peer for %s"
,
cn
)
with
self
.
peers_lock
:
with
self
.
peers_lock
:
age
,
peers
=
self
.
peers
age
,
peers
=
self
.
peers
...
@@ -673,7 +679,7 @@ class RegistryServer:
...
@@ -673,7 +679,7 @@ class RegistryServer:
return
x509
.
encrypt
(
cert
,
msg
.
encode
())
return
x509
.
encrypt
(
cert
,
msg
.
encode
())
@
rpc_private
@
rpc_private
def
revoke
(
self
,
cn_or_serial
):
def
revoke
(
self
,
cn_or_serial
:
int
|
str
):
with
self
.
lock
,
self
.
db
:
with
self
.
lock
,
self
.
db
:
q
=
self
.
db
.
execute
q
=
self
.
db
.
execute
try
:
try
:
...
@@ -694,12 +700,12 @@ class RegistryServer:
...
@@ -694,12 +700,12 @@ class RegistryServer:
q
(
"INSERT INTO crl VALUES (?,?)"
,
(
serial
,
not_after
))
q
(
"INSERT INTO crl VALUES (?,?)"
,
(
serial
,
not_after
))
self
.
updateNetworkConfig
()
self
.
updateNetworkConfig
()
def
newHMAC
(
self
,
i
,
key
=
None
):
def
newHMAC
(
self
,
i
:
int
,
key
:
bytes
=
None
):
if
key
is
None
:
if
key
is
None
:
key
=
os
.
urandom
(
16
)
key
=
os
.
urandom
(
16
)
self
.
setConfig
(
BABEL_HMAC
[
i
],
key
)
self
.
setConfig
(
BABEL_HMAC
[
i
],
key
)
def
delHMAC
(
self
,
i
):
def
delHMAC
(
self
,
i
:
int
):
self
.
db
.
execute
(
"DELETE FROM config WHERE name=?"
,
(
BABEL_HMAC
[
i
],))
self
.
db
.
execute
(
"DELETE FROM config WHERE name=?"
,
(
BABEL_HMAC
[
i
],))
@
rpc_private
@
rpc_private
...
@@ -726,7 +732,7 @@ class RegistryServer:
...
@@ -726,7 +732,7 @@ class RegistryServer:
self
.
sendto
(
self
.
prefix
,
0
)
self
.
sendto
(
self
.
prefix
,
0
)
@
rpc_private
@
rpc_private
def
getNodePrefix
(
self
,
email
)
:
def
getNodePrefix
(
self
,
email
:
str
)
->
str
|
None
:
with
self
.
lock
,
self
.
db
:
with
self
.
lock
,
self
.
db
:
try
:
try
:
cert
,
=
next
(
cert
,
=
next
(
...
@@ -738,7 +744,7 @@ class RegistryServer:
...
@@ -738,7 +744,7 @@ class RegistryServer:
return
x509
.
subnetFromCert
(
certificate
)
return
x509
.
subnetFromCert
(
certificate
)
@
rpc_private
@
rpc_private
def
getIPv6Address
(
self
,
email
)
:
def
getIPv6Address
(
self
,
email
:
str
)
->
str
:
cn
=
self
.
getNodePrefix
(
email
)
cn
=
self
.
getNodePrefix
(
email
)
if
cn
:
if
cn
:
return
utils
.
ipFromBin
(
return
utils
.
ipFromBin
(
...
@@ -746,7 +752,7 @@ class RegistryServer:
...
@@ -746,7 +752,7 @@ class RegistryServer:
+
utils
.
binFromSubnet
(
cn
))
+
utils
.
binFromSubnet
(
cn
))
@
rpc_private
@
rpc_private
def
getIPv4Information
(
self
,
email
)
:
def
getIPv4Information
(
self
,
email
:
str
)
->
str
|
None
:
peer
=
self
.
getNodePrefix
(
email
)
peer
=
self
.
getNodePrefix
(
email
)
if
peer
:
if
peer
:
peer
=
utils
.
binFromSubnet
(
peer
)
peer
=
utils
.
binFromSubnet
(
peer
)
...
@@ -765,7 +771,7 @@ class RegistryServer:
...
@@ -765,7 +771,7 @@ class RegistryServer:
return
msg
.
split
(
','
)[
0
]
return
msg
.
split
(
','
)[
0
]
@
rpc_private
@
rpc_private
def
versions
(
self
):
def
versions
(
self
)
->
str
:
with
self
.
peers_lock
:
with
self
.
peers_lock
:
self
.
request_dump
()
self
.
request_dump
()
peers
=
{
prefix
peers
=
{
prefix
...
@@ -791,7 +797,7 @@ class RegistryServer:
...
@@ -791,7 +797,7 @@ class RegistryServer:
return
json
.
dumps
(
peer_dict
)
return
json
.
dumps
(
peer_dict
)
@
rpc_private
@
rpc_private
def
topology
(
self
):
def
topology
(
self
)
->
str
:
logging
.
debug
(
"Computing topology"
)
logging
.
debug
(
"Computing topology"
)
p
=
lambda
p
:
'%s/%s'
%
(
int
(
p
,
2
),
len
(
p
))
p
=
lambda
p
:
'%s/%s'
%
(
int
(
p
,
2
),
len
(
p
))
peers
=
deque
((
p
(
self
.
prefix
),))
peers
=
deque
((
p
(
self
.
prefix
),))
...
@@ -827,11 +833,17 @@ class RegistryServer:
...
@@ -827,11 +833,17 @@ class RegistryServer:
class
RegistryClient
:
class
RegistryClient
:
"""
Client for the re6st registry.
Method calls are forwarded to the registry server.
String results are always returned as bytes.
"""
_hmac
=
None
_hmac
=
None
user_agent
=
"re6stnet/%s, %s"
%
(
version
.
version
,
platform
.
platform
())
user_agent
=
"re6stnet/%s, %s"
%
(
version
.
version
,
platform
.
platform
())
def
__init__
(
self
,
url
,
c
ert
=
None
,
auto_close
=
True
):
def
__init__
(
self
,
url
:
str
,
cert
:
x509
.
C
ert
=
None
,
auto_close
=
True
):
self
.
cert
=
cert
self
.
cert
=
cert
self
.
auto_close
=
auto_close
self
.
auto_close
=
auto_close
url_parsed
=
urlparse
(
url
)
url_parsed
=
urlparse
(
url
)
...
@@ -843,7 +855,7 @@ class RegistryClient:
...
@@ -843,7 +855,7 @@ class RegistryClient:
)[
scheme
](
unquote
(
host
),
timeout
=
60
)
)[
scheme
](
unquote
(
host
),
timeout
=
60
)
self
.
_path
=
path
.
rstrip
(
'/'
)
self
.
_path
=
path
.
rstrip
(
'/'
)
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
:
str
):
getcallargs
=
getattr
(
RegistryServer
,
name
).
getcallargs
getcallargs
=
getattr
(
RegistryServer
,
name
).
getcallargs
def
rpc
(
*
args
,
**
kw
)
->
bytes
:
def
rpc
(
*
args
,
**
kw
)
->
bytes
:
kw
=
getcallargs
(
*
args
,
**
kw
)
kw
=
getcallargs
(
*
args
,
**
kw
)
...
...
re6st/tests/test_unit/test_registry.py
View file @
fd5bda0a
...
@@ -11,11 +11,13 @@ import hashlib
...
@@ -11,11 +11,13 @@ import hashlib
import
time
import
time
import
tempfile
import
tempfile
from
argparse
import
Namespace
from
argparse
import
Namespace
from
sqlite3
import
Cursor
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
mock
import
Mock
,
patch
from
mock
import
Mock
,
patch
from
pathlib
import
Path
from
pathlib
import
Path
from
re6st
import
registry
from
re6st
import
registry
,
x509
from
re6st.tests.tools
import
*
from
re6st.tests.tools
import
*
from
re6st.tests
import
DEMO_PATH
from
re6st.tests
import
DEMO_PATH
...
@@ -23,7 +25,7 @@ from re6st.tests import DEMO_PATH
...
@@ -23,7 +25,7 @@ from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
# getIPV4Information, versions
def
load_config
(
filename
=
"registry.json"
)
:
def
load_config
(
filename
:
str
=
"registry.json"
)
->
Namespace
:
with
open
(
filename
)
as
f
:
with
open
(
filename
)
as
f
:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
config
[
"dh"
]
=
DEMO_PATH
/
"dh2048.pem"
config
[
"dh"
]
=
DEMO_PATH
/
"dh2048.pem"
...
@@ -37,13 +39,14 @@ def load_config(filename="registry.json"):
...
@@ -37,13 +39,14 @@ def load_config(filename="registry.json"):
return
Namespace
(
**
config
)
return
Namespace
(
**
config
)
def
get_cert
(
cur
,
prefix
):
def
get_cert
(
cur
:
Cursor
,
prefix
:
str
):
res
=
cur
.
execute
(
res
=
cur
.
execute
(
"SELECT cert FROM cert WHERE prefix=?"
,
(
prefix
,)).
fetchone
()
"SELECT cert FROM cert WHERE prefix=?"
,
(
prefix
,)).
fetchone
()
return
res
[
0
]
return
res
[
0
]
def
insert_cert
(
cur
,
ca
,
prefix
,
not_after
=
None
,
email
=
None
):
def
insert_cert
(
cur
:
Cursor
,
ca
:
x509
.
Cert
,
prefix
:
str
,
not_after
=
None
,
email
=
None
):
key
,
csr
=
generate_csr
()
key
,
csr
=
generate_csr
()
cert
=
generate_cert
(
ca
.
ca
,
ca
.
key
,
csr
,
prefix
,
insert_cert
.
serial
,
not_after
)
cert
=
generate_cert
(
ca
.
ca
,
ca
.
key
,
csr
,
prefix
,
insert_cert
.
serial
,
not_after
)
cur
.
execute
(
"INSERT INTO cert VALUES (?,?,?)"
,
(
prefix
,
email
,
cert
))
cur
.
execute
(
"INSERT INTO cert VALUES (?,?,?)"
,
(
prefix
,
email
,
cert
))
...
@@ -54,7 +57,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
...
@@ -54,7 +57,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert
.
serial
=
0
insert_cert
.
serial
=
0
def
delete_cert
(
cur
,
prefix
):
def
delete_cert
(
cur
:
Cursor
,
prefix
:
str
):
cur
.
execute
(
"DELETE FROM cert WHERE prefix = ?"
,
(
prefix
,))
cur
.
execute
(
"DELETE FROM cert WHERE prefix = ?"
,
(
prefix
,))
...
...
re6st/tests/tools.py
View file @
fd5bda0a
...
@@ -92,14 +92,14 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
...
@@ -92,14 +92,14 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
return
key
,
cert
return
key
,
cert
def
prefix2cn
(
prefix
)
:
def
prefix2cn
(
prefix
:
str
)
->
str
:
return
"%u/%u"
%
(
int
(
prefix
,
2
),
len
(
prefix
))
return
"%u/%u"
%
(
int
(
prefix
,
2
),
len
(
prefix
))
def
serial2prefix
(
serial
)
:
def
serial2prefix
(
serial
:
int
)
->
str
:
return
bin
(
serial
)[
2
:].
rjust
(
16
,
'0'
)
return
bin
(
serial
)[
2
:].
rjust
(
16
,
'0'
)
# pkey: private key
# pkey: private key
def
decrypt
(
pkey
,
incontent
)
:
def
decrypt
(
pkey
:
bytes
,
incontent
:
bytes
)
->
bytes
:
with
open
(
"node.key"
,
'wb'
)
as
f
:
with
open
(
"node.key"
,
'wb'
)
as
f
:
f
.
write
(
pkey
)
f
.
write
(
pkey
)
args
=
"openssl rsautl -decrypt -inkey node.key"
.
split
()
args
=
"openssl rsautl -decrypt -inkey node.key"
.
split
()
...
...
re6st/tunnel.py
View file @
fd5bda0a
...
@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
...
@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
import
subprocess
,
struct
,
sys
,
time
,
weakref
import
subprocess
,
struct
,
sys
,
time
,
weakref
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
bisect
import
bisect
,
insort
from
bisect
import
bisect
,
insort
from
collections.abc
import
Iterator
,
Sequence
from
typing
import
Callable
,
TYPE_CHECKING
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
.
import
ctl
,
plib
,
utils
,
version
,
x509
from
.
import
ctl
,
plib
,
utils
,
version
,
x509
if
TYPE_CHECKING
:
from
.
import
cache
PORT
=
326
PORT
=
326
...
@@ -21,7 +26,8 @@ proto_dict = {
...
@@ -21,7 +26,8 @@ proto_dict = {
proto_dict
[
'tcp'
]
=
proto_dict
[
'tcp4'
]
proto_dict
[
'tcp'
]
=
proto_dict
[
'tcp4'
]
proto_dict
[
'udp'
]
=
proto_dict
[
'udp4'
]
proto_dict
[
'udp'
]
=
proto_dict
[
'udp4'
]
def
resolve
(
ip
,
port
,
proto
):
def
resolve
(
ip
,
port
,
proto
:
str
)
\
->
tuple
[
socket
.
AddressFamily
|
None
,
Iterator
[
str
]]:
try
:
try
:
family
,
proto
=
proto_dict
[
proto
]
family
,
proto
=
proto_dict
[
proto
]
except
KeyError
:
except
KeyError
:
...
@@ -31,16 +37,16 @@ def resolve(ip, port, proto):
...
@@ -31,16 +37,16 @@ def resolve(ip, port, proto):
class
MultiGatewayManager
(
dict
):
class
MultiGatewayManager
(
dict
):
def
__init__
(
self
,
gateway
):
def
__init__
(
self
,
gateway
:
Callable
[[
str
],
str
]
):
self
.
_gw
=
gateway
self
.
_gw
=
gateway
def
_route
(
self
,
cmd
,
dest
,
gw
):
def
_route
(
self
,
cmd
:
str
,
dest
:
str
,
gw
:
str
):
if
gw
:
if
gw
:
cmd
=
'ip'
,
'-4'
,
'route'
,
cmd
,
'%s/32'
%
dest
,
'via'
,
gw
cmd
=
'ip'
,
'-4'
,
'route'
,
cmd
,
'%s/32'
%
dest
,
'via'
,
gw
logging
.
trace
(
'%r'
,
cmd
)
logging
.
trace
(
'%r'
,
cmd
)
subprocess
.
check_call
(
cmd
)
subprocess
.
check_call
(
cmd
)
def
add
(
self
,
dest
,
route
):
def
add
(
self
,
dest
:
str
,
route
:
bool
):
try
:
try
:
self
[
dest
][
1
]
+=
1
self
[
dest
][
1
]
+=
1
except
KeyError
:
except
KeyError
:
...
@@ -48,7 +54,7 @@ class MultiGatewayManager(dict):
...
@@ -48,7 +54,7 @@ class MultiGatewayManager(dict):
self
[
dest
]
=
[
gw
,
0
]
self
[
dest
]
=
[
gw
,
0
]
self
.
_route
(
'add'
,
dest
,
gw
)
self
.
_route
(
'add'
,
dest
,
gw
)
def
remove
(
self
,
dest
):
def
remove
(
self
,
dest
:
str
):
gw
,
count
=
self
[
dest
]
gw
,
count
=
self
[
dest
]
if
count
:
if
count
:
self
[
dest
][
1
]
=
count
-
1
self
[
dest
][
1
]
=
count
-
1
...
@@ -65,7 +71,8 @@ class Connection:
...
@@ -65,7 +71,8 @@ class Connection:
serial
=
None
serial
=
None
time
=
float
(
'inf'
)
time
=
float
(
'inf'
)
def
__init__
(
self
,
tunnel_manager
,
address_list
,
iface
,
prefix
):
def
__init__
(
self
,
tunnel_manager
:
"TunnelManager"
,
address_list
,
iface
,
prefix
):
self
.
tunnel_manager
=
tunnel_manager
self
.
tunnel_manager
=
tunnel_manager
self
.
address_list
=
address_list
self
.
address_list
=
address_list
self
.
iface
=
iface
self
.
iface
=
iface
...
@@ -109,7 +116,7 @@ class Connection:
...
@@ -109,7 +116,7 @@ class Connection:
if
i
:
if
i
:
cache
.
addPeer
(
self
.
_prefix
,
','
.
join
(
self
.
address_list
[
i
]),
True
)
cache
.
addPeer
(
self
.
_prefix
,
','
.
join
(
self
.
address_list
[
i
]),
True
)
else
:
else
:
cache
.
connecting
(
self
.
_prefix
,
0
)
cache
.
connecting
(
self
.
_prefix
,
False
)
def
close
(
self
):
def
close
(
self
):
try
:
try
:
...
@@ -198,7 +205,8 @@ class BaseTunnelManager:
...
@@ -198,7 +205,8 @@ class BaseTunnelManager:
_geoiplookup
=
None
_geoiplookup
=
None
_forward
=
None
_forward
=
None
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
conf_country
,
address
=
()):
def
__init__
(
self
,
control_socket
,
cache
:
"cache.Cache"
,
cert
:
x509
.
Cert
,
conf_country
,
address
=
()):
self
.
cert
=
cert
self
.
cert
=
cert
self
.
_network
=
cert
.
network
self
.
_network
=
cert
.
network
self
.
_prefix
=
cert
.
prefix
self
.
_prefix
=
cert
.
prefix
...
@@ -329,7 +337,7 @@ class BaseTunnelManager:
...
@@ -329,7 +337,7 @@ class BaseTunnelManager:
def
_getPeer
(
self
,
prefix
):
def
_getPeer
(
self
,
prefix
):
return
self
.
_peers
[
bisect
(
self
.
_peers
,
prefix
)
-
1
]
return
self
.
_peers
[
bisect
(
self
.
_peers
,
prefix
)
-
1
]
def
sendto
(
self
,
prefix
,
msg
):
def
sendto
(
self
,
prefix
:
str
,
msg
):
to
=
utils
.
ipFromBin
(
self
.
_network
+
prefix
),
PORT
to
=
utils
.
ipFromBin
(
self
.
_network
+
prefix
),
PORT
peer
=
self
.
_getPeer
(
prefix
)
peer
=
self
.
_getPeer
(
prefix
)
if
peer
.
prefix
!=
prefix
:
if
peer
.
prefix
!=
prefix
:
...
@@ -451,7 +459,7 @@ class BaseTunnelManager:
...
@@ -451,7 +459,7 @@ class BaseTunnelManager:
peer
)
peer
)
def
_processPacket
(
self
,
msg
,
pee
r
=
None
):
def
_processPacket
(
self
,
msg
:
bytes
,
peer
:
x509
.
Peer
|
st
r
=
None
):
c
=
msg
[
0
]
c
=
msg
[
0
]
msg
=
msg
[
1
:]
msg
=
msg
[
1
:]
code
=
c
&
0x7f
code
=
c
&
0x7f
...
@@ -565,7 +573,7 @@ class BaseTunnelManager:
...
@@ -565,7 +573,7 @@ class BaseTunnelManager:
self
.
selectTimeout
(
time
.
time
()
+
1
+
self
.
cache
.
delay_restart
,
self
.
selectTimeout
(
time
.
time
()
+
1
+
self
.
cache
.
delay_restart
,
self
.
_restart
)
self
.
_restart
)
def
handleServerEvent
(
self
,
sock
):
def
handleServerEvent
(
self
,
sock
:
socket
.
socket
):
event
,
args
=
eval
(
sock
.
recv
(
65536
))
event
,
args
=
eval
(
sock
.
recv
(
65536
))
logging
.
debug
(
"%s%r"
,
event
,
args
)
logging
.
debug
(
"%s%r"
,
event
,
args
)
r
=
getattr
(
self
,
'_ovpn_'
+
event
.
replace
(
'-'
,
'_'
))(
*
args
)
r
=
getattr
(
self
,
'_ovpn_'
+
event
.
replace
(
'-'
,
'_'
))(
*
args
)
...
@@ -582,7 +590,7 @@ class BaseTunnelManager:
...
@@ -582,7 +590,7 @@ class BaseTunnelManager:
self
.
_gateway_manager
.
add
(
trusted_ip
,
False
)
self
.
_gateway_manager
.
add
(
trusted_ip
,
False
)
if
prefix
in
self
.
_connection_dict
and
self
.
_prefix
<
prefix
:
if
prefix
in
self
.
_connection_dict
and
self
.
_prefix
<
prefix
:
self
.
_kill
(
prefix
)
self
.
_kill
(
prefix
)
self
.
cache
.
connecting
(
prefix
,
0
)
self
.
cache
.
connecting
(
prefix
,
False
)
return
True
return
True
def
_ovpn_client_disconnect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
def
_ovpn_client_disconnect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
...
@@ -666,7 +674,8 @@ class TunnelManager(BaseTunnelManager):
...
@@ -666,7 +674,8 @@ class TunnelManager(BaseTunnelManager):
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
openvpn_args
,
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
openvpn_args
,
timeout
,
client_count
,
iface_list
,
conf_country
,
address
,
timeout
,
client_count
,
iface_list
,
conf_country
,
address
,
ip_changed
,
remote_gateway
,
disable_proto
,
neighbour_list
=
()):
ip_changed
,
remote_gateway
:
Callable
[[
str
],
str
],
disable_proto
:
Sequence
[
str
],
neighbour_list
=
()):
super
(
TunnelManager
,
self
).
__init__
(
control_socket
,
super
(
TunnelManager
,
self
).
__init__
(
control_socket
,
cache
,
cert
,
conf_country
,
address
)
cache
,
cert
,
conf_country
,
address
)
self
.
ovpn_args
=
openvpn_args
self
.
ovpn_args
=
openvpn_args
...
@@ -878,7 +887,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -878,7 +887,7 @@ class TunnelManager(BaseTunnelManager):
address_list
.
append
((
ip
,
x
[
1
],
x
[
2
]))
address_list
.
append
((
ip
,
x
[
1
],
x
[
2
]))
continue
continue
address_list
.
append
(
x
[:
3
])
address_list
.
append
(
x
[:
3
])
self
.
cache
.
connecting
(
prefix
,
1
)
self
.
cache
.
connecting
(
prefix
,
True
)
if
not
address_list
:
if
not
address_list
:
return
False
return
False
logging
.
info
(
'Establishing a connection with %u/%u'
,
logging
.
info
(
'Establishing a connection with %u/%u'
,
...
...
re6st/upnpigd.py
View file @
fd5bda0a
...
@@ -17,7 +17,7 @@ class Forwarder:
...
@@ -17,7 +17,7 @@ class Forwarder:
_lcg_n
=
0
_lcg_n
=
0
@
classmethod
@
classmethod
def
_getExternalPort
(
cls
):
def
_getExternalPort
(
cls
)
->
int
:
# Since _refresh() does not test all ports in a row, we prefer to
# Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port.
# return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without
# A linear congruential generator should be random enough, without
...
@@ -35,7 +35,7 @@ class Forwarder:
...
@@ -35,7 +35,7 @@ class Forwarder:
self
.
_u
.
discoverdelay
=
200
self
.
_u
.
discoverdelay
=
200
self
.
_rules
=
[]
self
.
_rules
=
[]
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
:
str
):
wrapped
=
getattr
(
self
.
_u
,
name
)
wrapped
=
getattr
(
self
.
_u
,
name
)
def
wrapper
(
*
args
,
**
kw
):
def
wrapper
(
*
args
,
**
kw
):
try
:
try
:
...
...
re6st/utils.py
View file @
fd5bda0a
import
argparse
,
errno
,
fcntl
,
hashlib
,
logging
,
os
,
select
as
_select
import
argparse
,
errno
,
fcntl
,
hashlib
,
logging
,
os
,
select
as
_select
import
shlex
,
signal
,
socket
,
sqlite3
,
struct
,
subprocess
import
shlex
,
signal
,
socket
,
sqlite3
,
struct
,
subprocess
import
sys
,
textwrap
,
threading
,
time
,
traceback
import
sys
,
textwrap
,
threading
,
time
,
traceback
from
collections.abc
import
Iterator
,
Mapping
HMAC_LEN
=
len
(
hashlib
.
sha1
(
b''
).
digest
())
HMAC_LEN
=
len
(
hashlib
.
sha1
(
b''
).
digest
())
...
@@ -40,7 +40,7 @@ class FileHandler(logging.FileHandler):
...
@@ -40,7 +40,7 @@ class FileHandler(logging.FileHandler):
if
self
.
lock
.
acquire
(
False
):
if
self
.
lock
.
acquire
(
False
):
self
.
release
()
self
.
release
()
def
setupLog
(
log_level
,
filenam
e
=
None
,
**
kw
):
def
setupLog
(
log_level
:
int
,
filename
:
str
|
Non
e
=
None
,
**
kw
):
if
log_level
and
filename
:
if
log_level
and
filename
:
makedirs
(
os
.
path
.
dirname
(
filename
))
makedirs
(
os
.
path
.
dirname
(
filename
))
handler
=
FileHandler
(
filename
)
handler
=
FileHandler
(
filename
)
...
@@ -184,7 +184,7 @@ def setCloexec(fd):
...
@@ -184,7 +184,7 @@ def setCloexec(fd):
flags
=
fcntl
.
fcntl
(
fd
,
fcntl
.
F_GETFD
)
flags
=
fcntl
.
fcntl
(
fd
,
fcntl
.
F_GETFD
)
fcntl
.
fcntl
(
fd
,
fcntl
.
F_SETFD
,
flags
|
fcntl
.
FD_CLOEXEC
)
fcntl
.
fcntl
(
fd
,
fcntl
.
F_SETFD
,
flags
|
fcntl
.
FD_CLOEXEC
)
def
select
(
R
,
W
,
T
):
def
select
(
R
:
Mapping
,
W
:
Mapping
,
T
):
try
:
try
:
r
,
w
,
_
=
_select
.
select
(
R
,
W
,
(),
r
,
w
,
_
=
_select
.
select
(
R
,
W
,
(),
max
(
0
,
min
(
T
)[
0
]
-
time
.
time
())
if
T
else
None
)
max
(
0
,
min
(
T
)[
0
]
-
time
.
time
())
if
T
else
None
)
...
@@ -208,15 +208,15 @@ def makedirs(*args):
...
@@ -208,15 +208,15 @@ def makedirs(*args):
if
e
.
errno
!=
errno
.
EEXIST
:
if
e
.
errno
!=
errno
.
EEXIST
:
raise
raise
def
binFromIp
(
ip
)
:
def
binFromIp
(
ip
:
str
)
->
str
:
return
binFromRawIp
(
socket
.
inet_pton
(
socket
.
AF_INET6
,
ip
))
return
binFromRawIp
(
socket
.
inet_pton
(
socket
.
AF_INET6
,
ip
))
def
binFromRawIp
(
ip
)
:
def
binFromRawIp
(
ip
:
bytes
)
->
str
:
ip1
,
ip2
=
struct
.
unpack
(
'>QQ'
,
ip
)
ip1
,
ip2
=
struct
.
unpack
(
'>QQ'
,
ip
)
return
bin
(
ip1
)[
2
:].
rjust
(
64
,
'0'
)
+
bin
(
ip2
)[
2
:].
rjust
(
64
,
'0'
)
return
bin
(
ip1
)[
2
:].
rjust
(
64
,
'0'
)
+
bin
(
ip2
)[
2
:].
rjust
(
64
,
'0'
)
def
ipFromBin
(
ip
,
suffix
=
''
)
:
def
ipFromBin
(
ip
:
str
,
suffix
=
''
)
->
str
:
suffix_len
=
128
-
len
(
ip
)
suffix_len
=
128
-
len
(
ip
)
if
suffix_len
>
0
:
if
suffix_len
>
0
:
ip
+=
suffix
.
rjust
(
suffix_len
,
'0'
)
ip
+=
suffix
.
rjust
(
suffix_len
,
'0'
)
...
@@ -225,11 +225,11 @@ def ipFromBin(ip, suffix=''):
...
@@ -225,11 +225,11 @@ def ipFromBin(ip, suffix=''):
return
socket
.
inet_ntop
(
socket
.
AF_INET6
,
return
socket
.
inet_ntop
(
socket
.
AF_INET6
,
struct
.
pack
(
'>QQ'
,
int
(
ip
[:
64
],
2
),
int
(
ip
[
64
:],
2
)))
struct
.
pack
(
'>QQ'
,
int
(
ip
[:
64
],
2
),
int
(
ip
[
64
:],
2
)))
def
dump_address
(
address
)
:
def
dump_address
(
address
:
str
)
->
str
:
return
';'
.
join
(
map
(
','
.
join
,
address
))
return
';'
.
join
(
map
(
','
.
join
,
address
))
# Yield ip, port, protocol, and country if it is in the address
# Yield ip, port, protocol, and country if it is in the address
def
parse_address
(
address_list
)
:
def
parse_address
(
address_list
:
str
)
->
Iterator
[
tuple
[
str
,
str
,
str
,
str
]]
:
for
address
in
address_list
.
split
(
';'
):
for
address
in
address_list
.
split
(
';'
):
try
:
try
:
a
=
address
.
split
(
','
)
a
=
address
.
split
(
','
)
...
@@ -239,16 +239,18 @@ def parse_address(address_list):
...
@@ -239,16 +239,18 @@ def parse_address(address_list):
logging
.
warning
(
"Failed to parse node address %r (%s)"
,
logging
.
warning
(
"Failed to parse node address %r (%s)"
,
address
,
e
)
address
,
e
)
def
binFromSubnet
(
subnet
)
:
def
binFromSubnet
(
subnet
:
str
)
->
str
:
p
,
l
=
subnet
.
split
(
'/'
)
p
,
l
=
subnet
.
split
(
'/'
)
return
bin
(
int
(
p
))[
2
:].
rjust
(
int
(
l
),
'0'
)
return
bin
(
int
(
p
))[
2
:].
rjust
(
int
(
l
),
'0'
)
def
newHmacSecret
():
def
_
newHmacSecret
():
from
random
import
getrandbits
as
g
from
random
import
getrandbits
as
g
pack
=
struct
.
Struct
(
">QQI"
).
pack
pack
=
struct
.
Struct
(
">QQI"
).
pack
assert
len
(
pack
(
0
,
0
,
0
))
==
HMAC_LEN
assert
len
(
pack
(
0
,
0
,
0
))
==
HMAC_LEN
# A closure is built to avoid rebuilding the `pack` function at each call.
return
lambda
x
=
None
:
pack
(
g
(
64
)
if
x
is
None
else
x
,
g
(
64
),
g
(
32
))
return
lambda
x
=
None
:
pack
(
g
(
64
)
if
x
is
None
else
x
,
g
(
64
),
g
(
32
))
newHmacSecret
=
newHmacSecret
()
newHmacSecret
=
_newHmacSecret
()
# https://github.com/python/mypy/issues/1174
### Integer serialization
### Integer serialization
# - supports values from 0 to 0x202020202020201f
# - supports values from 0 to 0x202020202020201f
...
@@ -256,7 +258,7 @@ newHmacSecret = newHmacSecret()
...
@@ -256,7 +258,7 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
# - the 3 first bits code the number of bytes
def
packInteger
(
i
)
:
def
packInteger
(
i
:
int
)
->
bytes
:
for
n
in
range
(
8
):
for
n
in
range
(
8
):
x
=
32
<<
8
*
n
x
=
32
<<
8
*
n
if
i
<
x
:
if
i
<
x
:
...
@@ -264,7 +266,7 @@ def packInteger(i):
...
@@ -264,7 +266,7 @@ def packInteger(i):
i
-=
x
i
-=
x
raise
OverflowError
raise
OverflowError
def
unpackInteger
(
x
)
:
def
unpackInteger
(
x
:
bytes
)
->
tuple
[
int
,
int
]
|
None
:
n
=
x
[
0
]
>>
5
n
=
x
[
0
]
>>
5
try
:
try
:
i
,
=
struct
.
unpack
(
"!Q"
,
b'
\
0
'
*
(
7
-
n
)
+
x
[:
n
+
1
])
i
,
=
struct
.
unpack
(
"!Q"
,
b'
\
0
'
*
(
7
-
n
)
+
x
[:
n
+
1
])
...
...
re6st/x509.py
View file @
fd5bda0a
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
calendar
,
hashlib
,
hmac
,
logging
,
os
,
struct
,
subprocess
,
threading
,
time
import
calendar
,
hashlib
,
hmac
,
logging
,
os
,
struct
,
subprocess
,
threading
,
time
from
typing
import
Callable
,
Any
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
cryptography.hazmat.primitives
import
hashes
from
cryptography.hazmat.primitives
import
hashes
from
cryptography.hazmat.primitives.asymmetric
import
padding
from
cryptography.hazmat.primitives.asymmetric
import
padding
...
@@ -9,14 +11,14 @@ from cryptography.x509 import load_pem_x509_certificate
...
@@ -9,14 +11,14 @@ from cryptography.x509 import load_pem_x509_certificate
from
.
import
utils
from
.
import
utils
from
.version
import
protocol
from
.version
import
protocol
def
newHmacSecret
():
def
newHmacSecret
()
->
bytes
:
return
utils
.
newHmacSecret
(
int
(
time
.
time
()
*
1000000
))
return
utils
.
newHmacSecret
(
int
(
time
.
time
()
*
1000000
))
def
networkFromCa
(
ca
)
:
def
networkFromCa
(
ca
:
crypto
.
X509
)
->
str
:
# TODO: will be ca.serial_number after migration to cryptography
# TODO: will be ca.serial_number after migration to cryptography
return
bin
(
ca
.
get_serial_number
())[
3
:]
return
bin
(
ca
.
get_serial_number
())[
3
:]
def
subnetFromCert
(
cert
)
:
def
subnetFromCert
(
cert
:
crypto
.
X509
)
->
str
:
return
cert
.
get_subject
().
CN
return
cert
.
get_subject
().
CN
def
notBefore
(
cert
:
crypto
.
X509
)
->
int
:
def
notBefore
(
cert
:
crypto
.
X509
)
->
int
:
...
@@ -27,13 +29,13 @@ def notAfter(cert: crypto.X509) -> int:
...
@@ -27,13 +29,13 @@ def notAfter(cert: crypto.X509) -> int:
return
calendar
.
timegm
(
time
.
strptime
(
cert
.
get_notAfter
().
decode
(),
return
calendar
.
timegm
(
time
.
strptime
(
cert
.
get_notAfter
().
decode
(),
'%Y%m%d%H%M%SZ'
))
'%Y%m%d%H%M%SZ'
))
def
openssl
(
*
args
,
fds
=
[])
:
def
openssl
(
*
args
:
str
,
fds
=
[])
->
utils
.
Popen
:
return
utils
.
Popen
((
'openssl'
,)
+
args
,
return
utils
.
Popen
((
'openssl'
,)
+
args
,
stdin
=
subprocess
.
PIPE
,
stdin
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
pass_fds
=
fds
)
stderr
=
subprocess
.
PIPE
,
pass_fds
=
fds
)
def
encrypt
(
cert
,
data
)
:
def
encrypt
(
cert
:
bytes
,
data
:
bytes
)
->
bytes
:
r
,
w
=
os
.
pipe
()
r
,
w
=
os
.
pipe
()
try
:
try
:
threading
.
Thread
(
target
=
os
.
write
,
args
=
(
w
,
cert
)).
start
()
threading
.
Thread
(
target
=
os
.
write
,
args
=
(
w
,
cert
)).
start
()
...
@@ -47,10 +49,12 @@ def encrypt(cert, data):
...
@@ -47,10 +49,12 @@ def encrypt(cert, data):
raise
subprocess
.
CalledProcessError
(
p
.
returncode
,
'openssl'
,
err
)
raise
subprocess
.
CalledProcessError
(
p
.
returncode
,
'openssl'
,
err
)
return
out
return
out
def
fingerprint
(
cert
,
alg
=
'sha1'
):
def
fingerprint
(
cert
:
crypto
.
X509
,
alg
=
'sha1'
):
return
hashlib
.
new
(
alg
,
crypto
.
dump_certificate
(
crypto
.
FILETYPE_ASN1
,
cert
))
return
hashlib
.
new
(
alg
,
crypto
.
dump_certificate
(
crypto
.
FILETYPE_ASN1
,
cert
))
def
maybe_renew
(
path
,
cert
,
info
,
renew
,
force
=
False
):
def
maybe_renew
(
path
:
str
,
cert
:
crypto
.
X509
,
info
:
str
,
renew
:
Callable
[[],
bytes
],
force
=
False
)
->
tuple
[
crypto
.
X509
,
int
]:
from
.registry
import
RENEW_PERIOD
from
.registry
import
RENEW_PERIOD
while
True
:
while
True
:
if
force
:
if
force
:
...
@@ -94,7 +98,7 @@ class NewSessionError(Exception):
...
@@ -94,7 +98,7 @@ class NewSessionError(Exception):
class
Cert
:
class
Cert
:
def
__init__
(
self
,
ca
,
key
,
cert
=
None
):
def
__init__
(
self
,
ca
:
str
,
key
:
str
,
cert
:
str
|
None
=
None
):
self
.
ca_path
=
ca
self
.
ca_path
=
ca
self
.
cert_path
=
cert
self
.
cert_path
=
cert
self
.
key_path
=
key
self
.
key_path
=
key
...
@@ -112,24 +116,24 @@ class Cert:
...
@@ -112,24 +116,24 @@ class Cert:
self
.
cert
=
self
.
loadVerify
(
f
.
read
().
encode
())
self
.
cert
=
self
.
loadVerify
(
f
.
read
().
encode
())
@
property
@
property
def
prefix
(
self
):
def
prefix
(
self
)
->
str
:
return
utils
.
binFromSubnet
(
subnetFromCert
(
self
.
cert
))
return
utils
.
binFromSubnet
(
subnetFromCert
(
self
.
cert
))
@
property
@
property
def
network
(
self
):
def
network
(
self
)
->
str
:
return
networkFromCa
(
self
.
ca
)
return
networkFromCa
(
self
.
ca
)
@
property
@
property
def
subject_serial
(
self
):
def
subject_serial
(
self
)
->
int
:
return
int
(
self
.
cert
.
get_subject
().
serialNumber
)
return
int
(
self
.
cert
.
get_subject
().
serialNumber
)
@
property
@
property
def
openvpn_args
(
self
):
def
openvpn_args
(
self
)
->
tuple
[
str
,
...]
:
return
(
'--ca'
,
self
.
ca_path
,
return
(
'--ca'
,
self
.
ca_path
,
'--cert'
,
self
.
cert_path
,
'--cert'
,
self
.
cert_path
,
'--key'
,
self
.
key_path
)
'--key'
,
self
.
key_path
)
def
maybeRenew
(
self
,
registry
,
crl
):
def
maybeRenew
(
self
,
registry
,
crl
)
->
int
:
self
.
cert
,
next_renew
=
maybe_renew
(
self
.
cert_path
,
self
.
cert
,
self
.
cert
,
next_renew
=
maybe_renew
(
self
.
cert_path
,
self
.
cert
,
"Certificate"
,
lambda
:
registry
.
renewCertificate
(
self
.
prefix
),
"Certificate"
,
lambda
:
registry
.
renewCertificate
(
self
.
prefix
),
self
.
cert
.
get_serial_number
()
in
crl
)
self
.
cert
.
get_serial_number
()
in
crl
)
...
@@ -165,7 +169,6 @@ class Cert:
...
@@ -165,7 +169,6 @@ class Cert:
return
r
return
r
def
verify
(
self
,
sign
:
bytes
,
data
:
bytes
):
def
verify
(
self
,
sign
:
bytes
,
data
:
bytes
):
assert
isinstance
(
data
,
bytes
)
pub_key
=
self
.
ca_crypto
.
public_key
()
pub_key
=
self
.
ca_crypto
.
public_key
()
pub_key
.
verify
(
pub_key
.
verify
(
sign
,
sign
,
...
@@ -175,14 +178,13 @@ class Cert:
...
@@ -175,14 +178,13 @@ class Cert:
)
)
def
sign
(
self
,
data
:
bytes
)
->
bytes
:
def
sign
(
self
,
data
:
bytes
)
->
bytes
:
assert
isinstance
(
data
,
bytes
)
return
self
.
key_crypto
.
sign
(
return
self
.
key_crypto
.
sign
(
data
,
data
,
padding
.
PKCS1v15
(),
padding
.
PKCS1v15
(),
hashes
.
SHA512
()
hashes
.
SHA512
()
)
)
def
decrypt
(
self
,
data
)
:
def
decrypt
(
self
,
data
:
bytes
)
->
bytes
:
p
=
openssl
(
'rsautl'
,
'-decrypt'
,
'-inkey'
,
self
.
key_path
)
p
=
openssl
(
'rsautl'
,
'-decrypt'
,
'-inkey'
,
self
.
key_path
)
out
,
err
=
p
.
communicate
(
data
)
out
,
err
=
p
.
communicate
(
data
)
if
p
.
returncode
:
if
p
.
returncode
:
...
@@ -232,8 +234,9 @@ class Peer:
...
@@ -232,8 +234,9 @@ class Peer:
serial
=
None
serial
=
None
stop_date
=
float
(
'inf'
)
stop_date
=
float
(
'inf'
)
version
=
b''
version
=
b''
cert
:
crypto
.
X509
def
__init__
(
self
,
prefix
):
def
__init__
(
self
,
prefix
:
str
):
self
.
prefix
=
prefix
self
.
prefix
=
prefix
@
property
@
property
...
@@ -249,7 +252,7 @@ class Peer:
...
@@ -249,7 +252,7 @@ class Peer:
def
__lt__
(
self
,
other
):
def
__lt__
(
self
,
other
):
return
self
.
prefix
<
(
other
if
type
(
other
)
is
str
else
other
.
prefix
)
return
self
.
prefix
<
(
other
if
type
(
other
)
is
str
else
other
.
prefix
)
def
hello0
(
self
,
cert
)
:
def
hello0
(
self
,
cert
:
crypto
.
X509
)
->
bytes
:
if
self
.
_hello
<
time
.
time
():
if
self
.
_hello
<
time
.
time
():
try
:
try
:
# Always assume peer is not old, in case it has just upgraded,
# Always assume peer is not old, in case it has just upgraded,
...
@@ -264,7 +267,7 @@ class Peer:
...
@@ -264,7 +267,7 @@ class Peer:
def
hello0Sent
(
self
):
def
hello0Sent
(
self
):
self
.
_hello
=
time
.
time
()
+
60
self
.
_hello
=
time
.
time
()
+
60
def
hello
(
self
,
cert
,
protocol
)
:
def
hello
(
self
,
cert
:
Cert
,
protocol
:
int
)
->
bytes
:
key
=
self
.
_key
=
newHmacSecret
()
key
=
self
.
_key
=
newHmacSecret
()
h
=
encrypt
(
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
),
h
=
encrypt
(
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
),
key
)
key
)
...
@@ -274,10 +277,10 @@ class Peer:
...
@@ -274,10 +277,10 @@ class Peer:
return
b''
.
join
((
b'
\
0
\
0
\
0
\
2
'
,
PACKED_PROTOCOL
if
protocol
else
b''
,
return
b''
.
join
((
b'
\
0
\
0
\
0
\
2
'
,
PACKED_PROTOCOL
if
protocol
else
b''
,
h
,
cert
.
sign
(
h
)))
h
,
cert
.
sign
(
h
)))
def
_hmac
(
self
,
msg
)
:
def
_hmac
(
self
,
msg
:
bytes
)
->
bytes
:
return
hmac
.
HMAC
(
self
.
_key
,
msg
,
hashlib
.
sha1
).
digest
()
return
hmac
.
HMAC
(
self
.
_key
,
msg
,
hashlib
.
sha1
).
digest
()
def
newSession
(
self
,
key
,
protocol
):
def
newSession
(
self
,
key
:
bytes
,
protocol
:
int
):
if
key
<=
self
.
_key
:
if
key
<=
self
.
_key
:
raise
NewSessionError
(
self
.
_key
,
key
)
raise
NewSessionError
(
self
.
_key
,
key
)
self
.
_key
=
key
self
.
_key
=
key
...
@@ -285,12 +288,13 @@ class Peer:
...
@@ -285,12 +288,13 @@ class Peer:
self
.
_last
=
None
self
.
_last
=
None
self
.
protocol
=
protocol
self
.
protocol
=
protocol
def
verify
(
self
,
sign
,
data
):
def
verify
(
self
,
sign
:
bytes
,
data
:
bytes
):
crypto
.
verify
(
self
.
cert
,
sign
,
data
,
'sha512'
)
crypto
.
verify
(
self
.
cert
,
sign
,
data
,
'sha512'
)
seqno_struct
=
struct
.
Struct
(
"!L"
)
seqno_struct
=
struct
.
Struct
(
"!L"
)
def
decode
(
self
,
msg
,
_unpack
=
seqno_struct
.
unpack
):
def
decode
(
self
,
msg
:
bytes
,
_unpack
=
seqno_struct
.
unpack
)
\
->
tuple
[
int
,
bytes
,
int
|
None
]
|
bytes
:
seqno
,
=
_unpack
(
msg
[:
4
])
seqno
,
=
_unpack
(
msg
[:
4
])
if
seqno
<=
2
:
if
seqno
<=
2
:
msg
=
msg
[
4
:]
msg
=
msg
[
4
:]
...
@@ -306,7 +310,7 @@ class Peer:
...
@@ -306,7 +310,7 @@ class Peer:
self
.
_i
=
seqno
self
.
_i
=
seqno
return
msg
[
4
:
i
]
return
msg
[
4
:
i
]
def
encode
(
self
,
msg
,
_pack
=
seqno_struct
.
pack
)
:
def
encode
(
self
,
msg
:
str
|
bytes
,
_pack
=
seqno_struct
.
pack
)
->
bytes
:
self
.
_j
+=
1
self
.
_j
+=
1
if
type
(
msg
)
is
str
:
if
type
(
msg
)
is
str
:
msg
=
msg
.
encode
()
msg
=
msg
.
encode
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment