Commit 16ea35a7 authored by Tom Niget's avatar Tom Niget Committed by Tom Niget

python3: migrate to python3

style: stop using deprecated getargspec for rpc


python3: migrate setup.py


style: remove new-style class code from the Python2 days


demo: properly escape screen command parameters



bug: fix wrong sending of HMAC header causing test failure


python3: revert changes from 90a624f0 in debian/control and re6stnet.spec

setup: indicate minimum Python versions

doc: update dependencies in readme


style: revert pep8 changes

Fix permission for ovpn-client and ovpn-server

vcs: add test-related generated files to gitignore
parent f2fd7247
...@@ -5,3 +5,12 @@ ...@@ -5,3 +5,12 @@
/build/ /build/
/dist/ /dist/
/re6stnet.egg-info/ /re6stnet.egg-info/
.idea
*.log
*.pid
*.db
*.state
*.crt
*.pem
demo/mbox
*.sock
...@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes. ...@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes.
Requirements Requirements
============ ============
- Python 2.7 - Python 3.11
- OpenSSL binary and development libraries - OpenSSL binary and development libraries
- OpenVPN 2.4.* - OpenVPN 2.4.*
- Babel_ (with Nexedi patches) - Babel_ (with Nexedi patches)
......
This diff is collapsed.
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
import re import re
import os import os
from new import function
from nemu.iproute import backticks, get_if_data, route, \ from nemu.iproute import backticks, get_if_data, route, \
get_addr_data, get_all_route_data, interface get_addr_data, get_all_route_data, interface
from nemu.interface import Switch, Interface from nemu.interface import Switch, Interface
from types import FunctionType
def _get_all_route_data(): def _get_all_route_data():
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all" ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
...@@ -56,7 +56,7 @@ def _get_all_route_data(): ...@@ -56,7 +56,7 @@ def _get_all_route_data():
metric)) metric))
return ret return ret
get_all_route_data.func_code = _get_all_route_data.func_code get_all_route_data.__code__ = _get_all_route_data.__code__
interface__init__ = interface.__init__ interface__init__ = interface.__init__
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
...@@ -65,12 +65,12 @@ def __init__(self, *args, **kw): ...@@ -65,12 +65,12 @@ def __init__(self, *args, **kw):
self.name = self.name.split('@',1)[0] self.name = self.name.split('@',1)[0]
interface.__init__ = __init__ interface.__init__ = __init__
get_addr_data.orig = function(get_addr_data.func_code, get_addr_data.orig = FunctionType(get_addr_data.__code__,
get_addr_data.func_globals) get_addr_data.__globals__)
def _get_addr_data(): def _get_addr_data():
byidx, bynam = get_addr_data.orig() byidx, bynam = get_addr_data.orig()
return byidx, {name.split('@',1)[0]: a for name, a in bynam.iteritems()} return byidx, {name.split('@',1)[0]: a for name, a in bynam.items()}
get_addr_data.func_code = _get_addr_data.func_code get_addr_data.__code__ = _get_addr_data.__code__
@staticmethod @staticmethod
def _gen_if_name(): def _gen_if_name():
......
#!/usr/bin/env python #!/usr/bin/env python3
def __file__(): def __file__():
import argparse, os, sys import argparse, os, sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
...@@ -30,4 +30,5 @@ def __file__(): ...@@ -30,4 +30,5 @@ def __file__():
return os.path.join(sys.path[0], sys.argv[1]) return os.path.join(sys.path[0], sys.argv[1])
__file__ = __file__() __file__ = __file__()
execfile(__file__) with open(__file__) as f:
exec(compile(f.read(), __file__, 'exec'))
...@@ -34,7 +34,7 @@ def checkHMAC(db, machines): ...@@ -34,7 +34,7 @@ def checkHMAC(db, machines):
else: else:
i = 0 if hmac[0] else 1 i = 0 if hmac[0] else 1
if hmac[i] != sign or hmac[i+1] != accept: if hmac[i] != sign or hmac[i+1] != accept:
print 'HMAC config wrong for in %s' % args print('HMAC config wrong for in %s' % args)
rc = False rc = False
if rc: if rc:
print('All nodes use Babel with the correct HMAC configuration') print('All nodes use Babel with the correct HMAC configuration')
......
...@@ -11,7 +11,7 @@ from re6st.registry import RegistryServer ...@@ -11,7 +11,7 @@ from re6st.registry import RegistryServer
@apply @apply
class proxy(object): class proxy:
def __init__(self): def __init__(self):
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
......
import json, logging, os, sqlite3, socket, subprocess, sys, time, zlib import base64, json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
from itertools import chain from itertools import chain
from .registry import RegistryClient from .registry import RegistryClient
from . import utils, version, x509 from . import utils, version, x509
class Cache(object): class Cache:
def __init__(self, db_path, registry, cert, db_size=200): def __init__(self, db_path, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix self._prefix = cert.prefix
self._db_size = db_size self._db_size = db_size
self._decrypt = cert.decrypt self._decrypt = cert.decrypt
...@@ -65,7 +65,7 @@ class Cache(object): ...@@ -65,7 +65,7 @@ class Cache(object):
@staticmethod @staticmethod
def _selectConfig(execute): # BBB: blob def _selectConfig(execute): # BBB: blob
return ((k, str(v) if type(v) is buffer else v) return ((k, str(v) if type(v) is memoryview else v)
for k, v in execute("SELECT * FROM config")) for k, v in execute("SELECT * FROM config"))
def _loadConfig(self, config): def _loadConfig(self, config):
...@@ -89,24 +89,24 @@ class Cache(object): ...@@ -89,24 +89,24 @@ class Cache(object):
logging.info("Getting new network parameters from registry...") logging.info("Getting new network parameters from registry...")
try: try:
# TODO: When possible, the registry should be queried via the re6st. # TODO: When possible, the registry should be queried via the re6st.
network_config = self._registry.getNetworkConfig(self._prefix)
logging.debug('getNetworkConfig result: %r', network_config)
x = json.loads(zlib.decompress( x = json.loads(zlib.decompress(
self._registry.getNetworkConfig(self._prefix))) network_config))
base64 = x.pop('', ()) base64_list = x.pop('', ())
config = {} config = {}
for k, v in x.iteritems(): for k, v in x.items():
k = str(k) k = str(k)
if k.startswith('babel_hmac'): if k.startswith('babel_hmac'):
if v: if v:
v = self._decrypt(v.decode('base64')) v = self._decrypt(base64.b64decode(v))
elif k in base64: elif k in base64_list:
v = v.decode('base64') v = base64.b64decode(v)
elif type(v) is unicode:
v = str(v)
elif isinstance(v, (list, dict)): elif isinstance(v, (list, dict)):
k += ':json' k += ':json'
v = json.dumps(v) v = json.dumps(v)
config[k] = v config[k] = v
except socket.error, e: except socket.error as e:
logging.warning(e) logging.warning(e)
return return
except Exception: except Exception:
...@@ -133,13 +133,13 @@ class Cache(object): ...@@ -133,13 +133,13 @@ class Cache(object):
# BBB: Use buffer because of http://bugs.python.org/issue13676 # BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6 # on Python 2.6
db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)", db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)",
((k, buffer(v) if k in base64 or ((k, memoryview(v) if k in base64_list or
k.startswith('babel_hmac') else v) k.startswith('babel_hmac') else v)
for k, v in config.iteritems())) for k, v in config.items()))
self._loadConfig(config.iteritems()) self._loadConfig(config.items())
return [k[:-5] if k.endswith(':json') else k return [k[:-5] if k.endswith(':json') else k
for k in chain(remove, (k for k in chain(remove, (k
for k, v in config.iteritems() for k, v in config.items()
if k not in old or old[k] != v))] if k not in old or old[k] != v))]
def warnProtocol(self): def warnProtocol(self):
...@@ -232,15 +232,16 @@ class Cache(object): ...@@ -232,15 +232,16 @@ class Cache(object):
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address" def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"): + " ORDER BY RANDOM()"):
return self._db.execute(__sql, (self._prefix, failed)) return self._db.execute(__sql, (self._prefix, failed))
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)"):
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)") -> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0] return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self): def getBootstrapPeer(self):
logging.info('Getting Boot peer...') logging.info('Getting Boot peer...')
try: try:
bootpeer = self._registry.getBootstrapPeer(self._prefix) bootpeer = self._registry.getBootstrapPeer(self._prefix)
prefix, address = self._decrypt(bootpeer).split() prefix, address = self._decrypt(bootpeer).decode().split()
except (socket.error, subprocess.CalledProcessError, ValueError), e: except (socket.error, subprocess.CalledProcessError, ValueError) as e:
logging.warning('Failed to bootstrap (%s)', logging.warning('Failed to bootstrap (%s)',
e if bootpeer else 'no peer returned') e if bootpeer else 'no peer returned')
else: else:
...@@ -275,6 +276,6 @@ class Cache(object): ...@@ -275,6 +276,6 @@ class Cache(object):
def getCountry(self, ip): def getCountry(self, ip):
try: try:
return self._registry.getCountry(self._prefix, ip) return self._registry.getCountry(self._prefix, ip).decode()
except socket.error, e: except socket.error as e:
logging.warning('Failed to get country (%s)', ip) logging.warning('Failed to get country (%s)', ip)
#!/usr/bin/python2 #!/usr/bin/env python3
import argparse, atexit, binascii, errno, hashlib import argparse, atexit, binascii, errno, hashlib
import os, subprocess, sqlite3, sys, time import os, subprocess, sqlite3, sys, time
from OpenSSL import crypto from OpenSSL import crypto
...@@ -6,14 +6,14 @@ if 're6st' not in sys.modules: ...@@ -6,14 +6,14 @@ if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0])) sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, x509 from re6st import registry, utils, x509
def create(path, text=None, mode=0666): def create(path, text=None, mode=0o666):
fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, mode) fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, mode)
try: try:
os.write(fd, text) os.write(fd, text)
finally: finally:
os.close(fd) os.close(fd)
def loadCert(pem): def loadCert(pem: bytes):
return crypto.load_certificate(crypto.FILETYPE_PEM, pem) return crypto.load_certificate(crypto.FILETYPE_PEM, pem)
def main(): def main():
...@@ -68,12 +68,12 @@ def main(): ...@@ -68,12 +68,12 @@ def main():
fingerprint = binascii.a2b_hex(fingerprint) fingerprint = binascii.a2b_hex(fingerprint)
if hashlib.new(alg).digest_size != len(fingerprint): if hashlib.new(alg).digest_size != len(fingerprint):
raise ValueError("wrong size") raise ValueError("wrong size")
except StandardError, e: except Exception as e:
parser.error("invalid fingerprint: %s" % e) parser.error("invalid fingerprint: %s" % e)
if x509.fingerprint(ca, alg).digest() != fingerprint: if x509.fingerprint(ca, alg).digest() != fingerprint:
sys.exit("CA fingerprint doesn't match") sys.exit("CA fingerprint doesn't match")
else: else:
print "WARNING: it is strongly recommended to use --fingerprint option." print("WARNING: it is strongly recommended to use --fingerprint option.")
network = x509.networkFromCa(ca) network = x509.networkFromCa(ca)
if config.is_needed: if config.is_needed:
route, err = subprocess.Popen(('ip', '-6', '-o', 'route', 'get', route, err = subprocess.Popen(('ip', '-6', '-o', 'route', 'get',
...@@ -91,17 +91,17 @@ def main(): ...@@ -91,17 +91,17 @@ def main():
try: try:
with open(cert_path) as f: with open(cert_path) as f:
cert = loadCert(f.read()) cert = loadCert(f.read())
components = dict(cert.get_subject().get_components()) components = {k.decode(): v for k, v in cert.get_subject().get_components()}
for k in reserved: for k in reserved:
components.pop(k, None) components.pop(k, None)
except IOError, e: except IOError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
components = {} components = {}
if config.req: if config.req:
components.update(config.req) components.update(config.req)
subj = req.get_subject() subj = req.get_subject()
for k, v in components.iteritems(): for k, v in components.items():
if k in reserved: if k in reserved:
sys.exit(k + " field is reserved.") sys.exit(k + " field is reserved.")
if v: if v:
...@@ -116,35 +116,35 @@ def main(): ...@@ -116,35 +116,35 @@ def main():
token = '' token = ''
elif not token: elif not token:
if not config.email: if not config.email:
config.email = raw_input('Please enter your email address: ') config.email = input('Please enter your email address: ')
s.requestToken(config.email) s.requestToken(config.email)
token_advice = "Use --token to retry without asking a new token\n" token_advice = "Use --token to retry without asking a new token\n"
while not token: while not token:
token = raw_input('Please enter your token: ') token = input('Please enter your token: ')
try: try:
with open(key_path) as f: with open(key_path) as f:
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
key = None key = None
print "Reusing existing key." print("Reusing existing key.")
except IOError, e: except IOError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
bits = ca.get_pubkey().bits() bits = ca.get_pubkey().bits()
print "Generating %s-bit key ..." % bits print("Generating %s-bit key ..." % bits)
pkey = crypto.PKey() pkey = crypto.PKey()
pkey.generate_key(crypto.TYPE_RSA, bits) pkey.generate_key(crypto.TYPE_RSA, bits)
key = crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey) key = crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)
create(key_path, key, 0600) create(key_path, key, 0o600)
req.set_pubkey(pkey) req.set_pubkey(pkey)
req.sign(pkey, 'sha512') req.sign(pkey, 'sha512')
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req) req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req).decode()
# First make sure we can open certificate file for writing, # First make sure we can open certificate file for writing,
# to avoid using our token for nothing. # to avoid using our token for nothing.
cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0666) cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0o666)
print "Requesting certificate ..." print("Requesting certificate ...")
if config.location: if config.location:
cert = s.requestCertificate(token, req, location=config.location) cert = s.requestCertificate(token, req, location=config.location)
else: else:
...@@ -173,7 +173,7 @@ def main(): ...@@ -173,7 +173,7 @@ def main():
key_path)) key_path))
if not os.path.lexists(conf_path): if not os.path.lexists(conf_path):
create(conf_path, """\ create(conf_path, ("""\
registry %s registry %s
ca %s ca %s
cert %s cert %s
...@@ -187,14 +187,14 @@ key %s ...@@ -187,14 +187,14 @@ key %s
#O--verb #O--verb
#O3 #O3
""" % (config.registry, ca_path, cert_path, key_path, """ % (config.registry, ca_path, cert_path, key_path,
('country ' + config.location.split(',', 1)[0]) \ ('country ' + config.location.split(',', 1)[0])
if config.location else '')) if config.location else '')).encode())
print "Sample configuration file created." print("Sample configuration file created.")
cn = x509.subnetFromCert(cert) cn = x509.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn) subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \ print("Your subnet: %s/%u (CN=%s)"
% (utils.ipFromBin(subnet), len(subnet), cn) % (utils.ipFromBin(subnet), len(subnet), cn))
if __name__ == "__main__": if __name__ == "__main__":
main() main()
#!/usr/bin/python2 #!/usr/bin/env python3
import atexit, errno, logging, os, shutil, signal import atexit, errno, logging, os, shutil, signal
import socket, struct, subprocess, sys import socket, struct, subprocess, sys
from collections import deque from collections import deque
...@@ -246,7 +246,7 @@ def main(): ...@@ -246,7 +246,7 @@ def main():
try: try:
from re6st.upnpigd import Forwarder from re6st.upnpigd import Forwarder
forwarder = Forwarder('re6stnet openvpn server') forwarder = Forwarder('re6stnet openvpn server')
except Exception, e: except Exception as e:
if ipv4: if ipv4:
raise raise
logging.info("%s: assume we are not NATed", e) logging.info("%s: assume we are not NATed", e)
...@@ -299,7 +299,7 @@ def main(): ...@@ -299,7 +299,7 @@ def main():
timeout = 4 * cache.hello timeout = 4 * cache.hello
cleanup = [lambda: cache.cacheMinimize(config.client_count), cleanup = [lambda: cache.cacheMinimize(config.client_count),
lambda: shutil.rmtree(config.run, True)] lambda: shutil.rmtree(config.run, True)]
utils.makedirs(config.run, 0700) utils.makedirs(config.run, 0o700)
control_socket = os.path.join(config.run, 'babeld.sock') control_socket = os.path.join(config.run, 'babeld.sock')
if config.client_count and not config.client: if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(control_socket, tunnel_manager = tunnel.TunnelManager(control_socket,
...@@ -362,7 +362,7 @@ def main(): ...@@ -362,7 +362,7 @@ def main():
if not dh: if not dh:
dh = os.path.join(config.state, "dh.pem") dh = os.path.join(config.state, "dh.pem")
cache.getDh(dh) cache.getDh(dh)
for iface, (port, proto) in server_tunnels.iteritems(): for iface, (port, proto) in server_tunnels.items():
r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM) r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
utils.setCloexec(r) utils.setCloexec(r)
cleanup.append(plib.server(iface, config.max_clients, cleanup.append(plib.server(iface, config.max_clients,
...@@ -442,7 +442,7 @@ def main(): ...@@ -442,7 +442,7 @@ def main():
except: except:
pass pass
exit.release() exit.release()
except ReexecException, e: except ReexecException as e:
logging.info(e) logging.info(e)
except Exception: except Exception:
utils.log_exception() utils.log_exception()
...@@ -455,7 +455,7 @@ def main(): ...@@ -455,7 +455,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
try: try:
main() main()
except SystemExit, e: except SystemExit as e:
if type(e.code) is str: if type(e.code) is str:
if hasattr(logging, 'trace'): # utils.setupLog called if hasattr(logging, 'trace'): # utils.setupLog called
logging.critical(e.code) logging.critical(e.code)
......
#!/usr/bin/python2 #!/usr/bin/env python3
import httplib, logging, os, socket, sys import http.client, logging, os, socket, sys
from BaseHTTPServer import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
from SocketServer import ThreadingTCPServer from socketserver import ThreadingTCPServer
from urlparse import parse_qsl from urllib.parse import parse_qsl
if 're6st' not in sys.modules: if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0])) sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, version from re6st import registry, utils, version
...@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler): ...@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler):
path = self.path path = self.path
query = {} query = {}
else: else:
query = dict(parse_qsl(query, keep_blank_values=1, query = dict(parse_qsl(query, keep_blank_values=True,
strict_parsing=1)) strict_parsing=True))
_, path = path.split('/') _, path = path.split('/')
if not _: if not _:
return self.server.handle_request(self, path, query) return self.server.handle_request(self, path, query)
except Exception: except Exception:
logging.info(self.requestline, exc_info=1) logging.info(self.requestline, exc_info=True)
self.send_error(httplib.BAD_REQUEST) self.send_error(http.client.BAD_REQUEST)
def log_error(*args): def log_error(*args):
pass pass
......
...@@ -5,7 +5,7 @@ from . import utils ...@@ -5,7 +5,7 @@ from . import utils
uint16 = struct.Struct("!H") uint16 = struct.Struct("!H")
header = struct.Struct("!HI") header = struct.Struct("!HI")
class Struct(object): class Struct:
def __init__(self, format, *args): def __init__(self, format, *args):
if args: if args:
...@@ -29,39 +29,39 @@ class Struct(object): ...@@ -29,39 +29,39 @@ class Struct(object):
self.encode = encode self.encode = encode
self.decode = decode self.decode = decode
class Array(object): class Array:
def __init__(self, item): def __init__(self, item):
self._item = item self._item = item
def encode(self, buffer, value): def encode(self, buffer: bytes, value: list):
buffer += uint16.pack(len(value)) buffer += uint16.pack(len(value))
encode = self._item.encode encode = self._item.encode
for value in value: for value in value:
encode(buffer, value) encode(buffer, value)
def decode(self, buffer, offset=0): def decode(self, buffer: bytes, offset=0) -> tuple[int, list]:
r = [] r = []
o = offset + 2 o = offset + 2
decode = self._item.decode decode = self._item.decode
for i in xrange(*uint16.unpack_from(buffer, offset)): for i in range(*uint16.unpack_from(buffer, offset)):
o, x = decode(buffer, o) o, x = decode(buffer, o)
r.append(x) r.append(x)
return o, r return o, r
class String(object): class String:
@staticmethod @staticmethod
def encode(buffer, value): def encode(buffer: bytes, value: str):
buffer += value + "\0" buffer += value.encode("utf-8") + b'\x00'
@staticmethod @staticmethod
def decode(buffer, offset=0): def decode(buffer: bytes, offset=0) -> tuple[int, str]:
i = buffer.index("\0", offset) i = buffer.index(0, offset)
return i + 1, buffer[offset:i] return i + 1, buffer[offset:i].decode("utf-8")
class Buffer(object): class Buffer:
def __init__(self): def __init__(self):
self._buf = bytearray() self._buf = bytearray()
...@@ -104,21 +104,6 @@ class Buffer(object): ...@@ -104,21 +104,6 @@ class Buffer(object):
self._seek(r) self._seek(r)
return value return value
try: # BBB: Python < 2.7.4 (http://bugs.python.org/issue10212)
uint16.unpack_from(bytearray(uint16.size))
except TypeError:
def unpack_from(self, struct):
r = self._r
x = r + struct.size
value = struct.unpack(buffer(self._buf)[r:x])
self._seek(x)
return value
def decode(self, decode):
r = self._r
size, value = decode(buffer(self._buf)[r:])
self._seek(r + size)
return value
# writing # writing
def send(self, socket, *args): def send(self, socket, *args):
...@@ -129,7 +114,7 @@ class Buffer(object): ...@@ -129,7 +114,7 @@ class Buffer(object):
struct.pack_into(self._buf, offset, *args) struct.pack_into(self._buf, offset, *args)
class Packet(object): class Packet:
response_dict = {} response_dict = {}
...@@ -149,7 +134,7 @@ class Packet(object): ...@@ -149,7 +134,7 @@ class Packet(object):
logging.trace('send %s%r', self.__class__.__name__, logging.trace('send %s%r', self.__class__.__name__,
(self.id,) + self.args) (self.id,) + self.args)
offset = len(buffer) offset = len(buffer)
buffer += '\0' * header.size buffer += b'\x00' * header.size
r = self.request r = self.request
if isinstance(r, Struct): if isinstance(r, Struct):
r.encode(buffer, self.args) r.encode(buffer, self.args)
...@@ -182,7 +167,7 @@ class ConnectionClosed(BabelException): ...@@ -182,7 +167,7 @@ class ConnectionClosed(BabelException):
return "connection to babeld closed (%s)" % self.args return "connection to babeld closed (%s)" % self.args
class Babel(object): class Babel:
_decode = None _decode = None
...@@ -206,11 +191,11 @@ class Babel(object): ...@@ -206,11 +191,11 @@ class Babel(object):
def select(*args): def select(*args):
try: try:
s.connect(self.socket_path) s.connect(self.socket_path)
except socket.error, e: except socket.error as e:
logging.debug("Can't connect to %r (%r)", self.socket_path, e) logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e return e
s.send("\1") s.send(b'\x01')
s.setblocking(0) s.setblocking(False)
del self.select del self.select
self.socket = s self.socket = s
return self.select(*args) return self.select(*args)
...@@ -269,7 +254,7 @@ class Babel(object): ...@@ -269,7 +254,7 @@ class Babel(object):
a = len(self.network) a = len(self.network)
for route in routes: for route in routes:
assert route.flags & 1, route # installed assert route.flags & 1, route # installed
if route.prefix.startswith('\0\0\0\0\0\0\0\0\0\0\xff\xff'): if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'):
continue continue
assert route.neigh_address == route.nexthop, route assert route.neigh_address == route.nexthop, route
address = route.neigh_address, route.ifindex address = route.neigh_address, route.ifindex
...@@ -310,7 +295,7 @@ class Babel(object): ...@@ -310,7 +295,7 @@ class Babel(object):
pass pass
class iterRoutes(object): class iterRoutes:
_waiting = True _waiting = True
...@@ -323,7 +308,7 @@ class iterRoutes(object): ...@@ -323,7 +308,7 @@ class iterRoutes(object):
c.select(*args) c.select(*args)
utils.select(*args) utils.select(*args)
return (prefix return (prefix
for neigh_routes in c.neighbours.itervalues() for neigh_routes in c.neighbours.values()
for prefix in neigh_routes[1] for prefix in neigh_routes[1]
if prefix) if prefix)
......
import errno, os, socket, stat, threading import errno, os, socket, stat, threading
class Socket(object): class Socket:
def __init__(self, socket): def __init__(self, socket):
# In case that the default timeout is not None. # In case that the default timeout is not None.
...@@ -37,14 +37,14 @@ class Socket(object): ...@@ -37,14 +37,14 @@ class Socket(object):
try: try:
self._socket.recv(0) self._socket.recv(0)
return True return True
except socket.error, (err, _): except socket.error as e:
if err != errno.EAGAIN: if e.errno != errno.EAGAIN:
raise raise
self._socket.setblocking(1) self._socket.setblocking(1)
return False return False
class Console(object): class Console:
def __init__(self, path, pdb): def __init__(self, path, pdb):
self.path = path self.path = path
...@@ -52,7 +52,7 @@ class Console(object): ...@@ -52,7 +52,7 @@ class Console(object):
socket.SOCK_STREAM | socket.SOCK_CLOEXEC) socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
try: try:
self._removeSocket() self._removeSocket()
except OSError, e: except OSError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
s.bind(path) s.bind(path)
......
...@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs ...@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs
freeifaddrs.restype = None freeifaddrs.restype = None
freeifaddrs.argtypes = [POINTER(struct_ifaddrs)] freeifaddrs.argtypes = [POINTER(struct_ifaddrs)]
class unpacker(object): class unpacker:
def __init__(self, buf): def __init__(self, buf):
self._buf = buf self._buf = buf
...@@ -55,7 +55,7 @@ class unpacker(object): ...@@ -55,7 +55,7 @@ class unpacker(object):
self._offset += s.size self._offset += s.size
return result return result
class PimDm(object): class PimDm:
def __init__(self): def __init__(self):
s_netlink = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE) s_netlink = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)
......
#!/usr/bin/python2 -S #!/usr/bin/env -S python3 -S
import os, sys import os, sys
script_type = os.environ['script_type'] script_type = os.environ['script_type']
...@@ -14,4 +14,4 @@ if script_type == 'up': ...@@ -14,4 +14,4 @@ if script_type == 'up':
if script_type == 'route-up': if script_type == 'route-up':
import time import time
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(), os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip']))) int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])).encode())
#!/usr/bin/python2 -S #!/usr/bin/env -S python3 -S
import os, sys import os, sys
script_type = os.environ['script_type'] script_type = os.environ['script_type']
...@@ -7,10 +7,10 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6'] ...@@ -7,10 +7,10 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events # Write into pipe connect/disconnect events
fd = int(sys.argv[1]) fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'], os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip)))) int(os.environ['tls_serial_0']), external_ip))).encode("utf-8"))
if script_type == 'client-connect': if script_type == 'client-connect':
if os.read(fd, 1) == '\0': if os.read(fd, 1) == b'\x00':
sys.exit(1) sys.exit(1)
# Send client its external ip address # Send client its external ip address
with open(sys.argv[2], 'w') as f: with open(sys.argv[2], 'w') as f:
......
import binascii
import logging, errno, os import logging, errno, os
from typing import Optional
from . import utils from . import utils
here = os.path.realpath(os.path.dirname(__file__)) here = os.path.realpath(os.path.dirname(__file__))
ovpn_server = os.path.join(here, 'ovpn-server') ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client') ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log = None ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw): def openvpn(iface, encrypt, *args, **kw):
args = ['openvpn', args = ['openvpn',
...@@ -43,7 +45,7 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw): ...@@ -43,7 +45,7 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
'--max-clients', str(max_clients), '--max-clients', str(max_clients),
'--port', str(port), '--port', str(port),
'--proto', proto, '--proto', proto,
*args, **kw) *args, pass_fds=[fd], **kw)
def client(iface, address_list, encrypt, *args, **kw): def client(iface, address_list, encrypt, *args, **kw):
...@@ -80,9 +82,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, ...@@ -80,9 +82,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny', '-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)] '-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign: if hmac_sign:
def key(cmd, id, value): def key(cmd, id: str, value):
cmd += '-C', ('key type blake2s128 id %s value %s' % cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, value.encode('hex'))) (id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign) key(cmd, 'sign', hmac_sign)
default += ' key sign' default += ' key sign'
if hmac_accept is not None: if hmac_accept is not None:
...@@ -132,7 +134,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, ...@@ -132,7 +134,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
# WKRD: babeld fails to start if pidfile already exists # WKRD: babeld fails to start if pidfile already exists
try: try:
os.remove(pidfile) os.remove(pidfile)
except OSError, e: except OSError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
logging.info('%r', cmd) logging.info('%r', cmd)
......
This diff is collapsed.
from pathlib2 import Path from pathlib import Path
DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo" DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo"
...@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH ...@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH
DH_FILE = DEMO_PATH / "dh2048.pem" DH_FILE = DEMO_PATH / "dh2048.pem"
class DummyNode(object): class DummyNode:
"""fake node to reuse Re6stRegistry """fake node to reuse Re6stRegistry
error: node.Popen has destory method which not in subprocess.Popen error: node.Popen has destory method which not in subprocess.Popen
...@@ -60,7 +60,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -60,7 +60,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# read token from db # read token from db
db = sqlite3.connect(str(self.server.db), isolation_level=None) db = sqlite3.connect(str(self.server.db), isolation_level=None)
token = None token = None
for _ in xrange(100): for _ in range(100):
time.sleep(.1) time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?", token = db.execute("SELECT token FROM token WHERE email=?",
(email,)).fetchone() (email,)).fetchone()
...@@ -70,7 +70,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestRegistryClientInteract(unittest.TestCase):
self.fail("Request token failed, no token in database") self.fail("Request token failed, no token in database")
# token: tuple[unicode,] # token: tuple[unicode,]
token = str(token[0]) token = str(token[0])
self.assertEqual(client.isToken(token), "1") self.assertEqual(client.isToken(token).decode(), "1")
# request ca # request ca
ca = client.getCa() ca = client.getCa()
...@@ -78,7 +78,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -78,7 +78,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# request a cert and get cn # request a cert and get cn
key, csr = tools.generate_csr() key, csr = tools.generate_csr()
cert = client.requestCertificate(token, csr) cert = client.requestCertificate(token, csr)
self.assertEqual(client.isToken(token), '', "token should be deleted") self.assertEqual(client.isToken(token).decode(), '', "token should be deleted")
# creat x509.cert object # creat x509.cert object
def write_to_temp(text): def write_to_temp(text):
...@@ -97,18 +97,19 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -97,18 +97,19 @@ class TestRegistryClientInteract(unittest.TestCase):
# verfiy cn and prefix # verfiy cn and prefix
prefix = client.cert.prefix prefix = client.cert.prefix
cn = client.getNodePrefix(email) cn = client.getNodePrefix(email).decode()
self.assertEqual(tools.prefix2cn(prefix), cn) self.assertEqual(tools.prefix2cn(prefix), cn)
# simulate the process in cache # simulate the process in cache
# just prove works # just prove works
net_config = client.getNetworkConfig(prefix) net_config = client.getNetworkConfig(prefix)
self.assertIsNotNone(net_config)
net_config = json.loads(zlib.decompress(net_config)) net_config = json.loads(zlib.decompress(net_config))
self.assertEqual(net_config[u'max_clients'], self.max_clients) self.assertEqual(net_config[u'max_clients'], self.max_clients)
# no re6stnet, empty result # no re6stnet, empty result
bootpeer = client.getBootstrapPeer(prefix) bootpeer = client.getBootstrapPeer(prefix)
self.assertEqual(bootpeer, "") self.assertEqual(bootpeer.decode(), "")
# server should not die # server should not die
self.assertIsNone(self.server.proc.poll()) self.assertIsNone(self.server.proc.poll())
......
...@@ -4,7 +4,7 @@ import nemu ...@@ -4,7 +4,7 @@ import nemu
import time import time
import weakref import weakref
from subprocess import PIPE from subprocess import PIPE
from pathlib2 import Path from pathlib import Path
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
...@@ -50,7 +50,7 @@ class Node(nemu.Node): ...@@ -50,7 +50,7 @@ class Node(nemu.Node):
if_s.add_v4_address(ip, prefix_len=prefix_len) if_s.add_v4_address(ip, prefix_len=prefix_len)
return if_s return if_s
class NetManager(object): class NetManager:
"""contain all the nemu object created, so they can live more time""" """contain all the nemu object created, so they can live more time"""
def __init__(self): def __init__(self):
self.object = [] self.object = []
...@@ -60,7 +60,7 @@ class NetManager(object): ...@@ -60,7 +60,7 @@ class NetManager(object):
Raise: Raise:
AssertionError AssertionError
""" """
for reg, nodes in self.registries.iteritems(): for reg, nodes in self.registries.items():
for node in nodes: for node in nodes:
app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE) app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE)
ret = app0.wait() ret = app0.wait()
......
...@@ -6,13 +6,15 @@ import ipaddress ...@@ -6,13 +6,15 @@ import ipaddress
import json import json
import logging import logging
import re import re
import shlex
import shutil import shutil
import sqlite3 import sqlite3
import sys
import tempfile import tempfile
import time import time
import weakref import weakref
from subprocess import PIPE from subprocess import PIPE
from pathlib2 import Path from pathlib import Path
from re6st.tests import tools from re6st.tests import tools
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
...@@ -20,9 +22,10 @@ from re6st.tests import DEMO_PATH ...@@ -20,9 +22,10 @@ from re6st.tests import DEMO_PATH
WORK_DIR = Path(__file__).parent / "temp_net_test" WORK_DIR = Path(__file__).parent / "temp_net_test"
DH_FILE = DEMO_PATH / "dh2048.pem" DH_FILE = DEMO_PATH / "dh2048.pem"
RE6STNET = "python -m re6st.cli.node" PYTHON = shlex.quote(sys.executable)
RE6ST_REGISTRY = "python -m re6st.cli.registry" RE6STNET = PYTHON + " -m re6st.cli.node"
RE6ST_CONF = "python -m re6st.cli.conf" RE6ST_REGISTRY = PYTHON + " -m re6st.cli.registry"
RE6ST_CONF = PYTHON + " -m re6st.cli.conf"
def initial(): def initial():
"""create the workplace""" """create the workplace"""
...@@ -36,7 +39,7 @@ def ip_to_serial(ip6): ...@@ -36,7 +39,7 @@ def ip_to_serial(ip6):
return int(ip6, 16) return int(ip6, 16)
class Re6stRegistry(object): class Re6stRegistry:
"""class run a re6st-registry service on a namespace""" """class run a re6st-registry service on a namespace"""
registry_seq = 0 registry_seq = 0
...@@ -72,7 +75,7 @@ class Re6stRegistry(object): ...@@ -72,7 +75,7 @@ class Re6stRegistry(object):
self.run() self.run()
# wait the servcice started # wait the servcice started
p = self.node.Popen(['python', '-c', """if 1: p = self.node.Popen([sys.executable, '-c', """if 1:
import socket, time import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True: while True:
...@@ -115,7 +118,7 @@ class Re6stRegistry(object): ...@@ -115,7 +118,7 @@ class Re6stRegistry(object):
'--client-count', (self.client_number+1)//2, '--port', self.port] '--client-count', (self.client_number+1)//2, '--port', self.port]
#PY3: convert PosixPath to str, can be remove in Python 3 #PY3: convert PosixPath to str, can be remove in Python 3
cmd = map(str, cmd) cmd = list(map(str, cmd))
cmd[:0] = RE6ST_REGISTRY.split() cmd[:0] = RE6ST_REGISTRY.split()
...@@ -139,7 +142,7 @@ class Re6stRegistry(object): ...@@ -139,7 +142,7 @@ class Re6stRegistry(object):
pass pass
class Re6stNode(object): class Re6stNode:
"""class run a re6stnet service on a namespace""" """class run a re6stnet service on a namespace"""
node_seq = 0 node_seq = 0
...@@ -210,7 +213,7 @@ class Re6stNode(object): ...@@ -210,7 +213,7 @@ class Re6stNode(object):
# read token # read token
db = sqlite3.connect(str(self.registry.db), isolation_level=None) db = sqlite3.connect(str(self.registry.db), isolation_level=None)
token = None token = None
for _ in xrange(100): for _ in range(100):
time.sleep(.1) time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?", token = db.execute("SELECT token FROM token WHERE email=?",
(self.email,)).fetchone() (self.email,)).fetchone()
...@@ -223,7 +226,7 @@ class Re6stNode(object): ...@@ -223,7 +226,7 @@ class Re6stNode(object):
out, _ = p.communicate(str(token[0])) out, _ = p.communicate(str(token[0]))
# logging.debug("re6st-conf output: {}".format(out)) # logging.debug("re6st-conf output: {}".format(out))
# find the ipv6 subnet of node # find the ipv6 subnet of node
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out).group(0) self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out.decode("utf-8")).group(0)
data = {'ip6': self.ip6, 'hash': self.registry.ident} data = {'ip6': self.ip6, 'hash': self.registry.ident}
with open(str(self.data_file), 'w') as f: with open(str(self.data_file), 'w') as f:
json.dump(data, f) json.dump(data, f)
...@@ -236,7 +239,7 @@ class Re6stNode(object): ...@@ -236,7 +239,7 @@ class Re6stNode(object):
'--key', self.key, '-v4', '--registry', self.registry.url, '--key', self.key, '-v4', '--registry', self.registry.url,
'--console', self.console] '--console', self.console]
#PY3: same as for Re6stRegistry.run #PY3: same as for Re6stRegistry.run
cmd = map(str, cmd) cmd = list(map(str, cmd))
cmd[:0] = RE6STNET.split() cmd[:0] = RE6STNET.split()
cmd += args cmd += args
......
"""contain ping-test for re6set net""" """contain ping-test for re6set net"""
import os import os
import sys
import unittest import unittest
import time import time
import psutil import psutil
import logging import logging
import random import random
from pathlib2 import Path from pathlib import Path
import network_build from . import network_build, re6st_wrap
import re6st_wrap
PING_PATH = str(Path(__file__).parent.resolve() / "ping.py") PING_PATH = str(Path(__file__).parent.resolve() / "ping.py")
...@@ -47,12 +47,12 @@ def wait_stable(nodes, timeout=240): ...@@ -47,12 +47,12 @@ def wait_stable(nodes, timeout=240):
for node in nodes: for node in nodes:
sub_ips = set(ips) - {node.ip6} sub_ips = set(ips) - {node.ip6}
node.ping_proc = node.node.Popen( node.ping_proc = node.node.Popen(
["python", PING_PATH, '--retry', '-a'] + list(sub_ips)) [sys.executable, PING_PATH, '--retry', '-a'] + list(sub_ips), env=os.environ)
# check all the node network can ping each other, in order reverse # check all the node network can ping each other, in order reverse
unfinished = list(nodes) unfinished = list(nodes)
while unfinished: while unfinished:
for i in xrange(len(unfinished)-1, -1, -1): for i in range(len(unfinished)-1, -1, -1):
node = unfinished[i] node = unfinished[i]
if node.ping_proc.poll() is not None: if node.ping_proc.poll() is not None:
logging.debug("%s 's network is stable", node.name) logging.debug("%s 's network is stable", node.name)
......
#!/usr/bin/python2 #!/usr/bin/env python3
""" unit test for re6st-conf """ unit test for re6st-conf
""" """
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
import sys import sys
import unittest import unittest
from shutil import rmtree from shutil import rmtree
from StringIO import StringIO from io import StringIO
from mock import patch from mock import patch
from OpenSSL import crypto from OpenSSL import crypto
...@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
# mocked server cert and pkey # mocked server cert and pkey
cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull) cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull)
cls.fingerprint = "".join( cls.cert.digest("sha1").split(":")) cls.fingerprint = "".join( cls.cert.digest("sha1").decode().split(":"))
# client.getCa should return a string form cert # client.getCa should return a string form cert
cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert) cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert)
...@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase): ...@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase):
# go back to original dir # go back to original dir
os.chdir(self.origin_dir) os.chdir(self.origin_dir)
@patch("__builtin__.raw_input") @patch("builtins.input")
def test_basic(self, mock_raw_input): def test_basic(self, mock_raw_input):
""" go through all the step """ go through all the step
getCa, requestToken, requestCertificate getCa, requestToken, requestCertificate
......
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
import random import random
import string import string
import json import json
import httplib import http.client
import base64 import base64
import unittest import unittest
import hmac import hmac
...@@ -13,12 +13,13 @@ import tempfile ...@@ -13,12 +13,13 @@ import tempfile
from argparse import Namespace from argparse import Namespace
from OpenSSL import crypto from OpenSSL import crypto
from mock import Mock, patch from mock import Mock, patch
from pathlib2 import Path from pathlib import Path
from re6st import registry from re6st import registry
from re6st.tests.tools import * from re6st.tests.tools import *
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer # TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions # getIPV4Information, versions
...@@ -49,6 +50,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None): ...@@ -49,6 +50,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert.serial += 1 insert_cert.serial += 1
return key, cert return key, cert
insert_cert.serial = 0 insert_cert.serial = 0
...@@ -77,17 +79,26 @@ class TestRegistryServer(unittest.TestCase): ...@@ -77,17 +79,26 @@ class TestRegistryServer(unittest.TestCase):
def setUp(self): def setUp(self):
self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \ self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \
+ "@mail.com" + "@mail.com"
def test_recv(self): def test_recv(self):
recv = self.server.sock.recv = Mock() side_effect = iter([
recv.side_effect = [
"0001001001001a_msg", "0001001001001a_msg",
"0001001001002\0001dqdq", "0001001001002\0001dqdq",
"0001001001001\000a_msg", "0001001001001\000a_msg",
"0001001001001\000\4a_msg", "0001001001001\000\4a_msg",
"0000000000000\0" # ERROR, IndexError: msg is null "0000000000000\0" # ERROR, IndexError: msg is null
] ])
class SocketProxy:
def __init__(self, wrappee):
self.wrappee = wrappee
self.recv = lambda _: next(side_effect)
def __getattr__(self, attr):
return getattr(self.wrappee, attr)
self.server.sock = SocketProxy(self.server.sock)
try: try:
res1 = self.server.recv(4) res1 = self.server.recv(4)
...@@ -115,7 +126,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -115,7 +126,7 @@ class TestRegistryServer(unittest.TestCase):
now = int(time.time()) - self.config.grace_period + 20 now = int(time.time()) - self.config.grace_period + 20
# makeup data # makeup data
insert_cert(cur, self.server.cert, prefix_old, 1) insert_cert(cur, self.server.cert, prefix_old, 1)
insert_cert(cur, self.server.cert, prefix, now -1) insert_cert(cur, self.server.cert, prefix, now - 1)
cur.execute("INSERT INTO token VALUES (?,?,?,?)", cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token_old, self.email, 4, 2)) (token_old, self.email, 4, 2))
cur.execute("INSERT INTO token VALUES (?,?,?,?)", cur.execute("INSERT INTO token VALUES (?,?,?,?)",
...@@ -143,16 +154,16 @@ class TestRegistryServer(unittest.TestCase): ...@@ -143,16 +154,16 @@ class TestRegistryServer(unittest.TestCase):
prefix = "0000000011111111" prefix = "0000000011111111"
method = "func" method = "func"
protocol = 7 protocol = 7
params = {"cn" : prefix, "a" : 1, "b" : 2} params = {"cn": prefix, "a": 1, "b": 2}
func.getcallargs.return_value = params func.getcallargs.return_value = params
del func._private del func._private
func.return_value = result = "this_is_a_result" func.return_value = result = b"this_is_a_result"
key = "this_is_a_key" key = b"this_is_a_key"
self.server.sessions[prefix] = [(key, protocol)] self.server.sessions[prefix] = [(key, protocol)]
request = Mock() request = Mock()
request.path = "/func?a=1&b=2&cn=0000000011111111" request.path = "/func?a=1&b=2&cn=0000000011111111"
request.headers = {registry.HMAC_HEADER: base64.b64encode( request.headers = {registry.HMAC_HEADER: base64.b64encode(
hmac.HMAC(key, request.path, hashlib.sha1).digest())} hmac.HMAC(key, request.path.encode(), hashlib.sha1).digest())}
self.server.handle_request(request, method, params) self.server.handle_request(request, method, params)
...@@ -162,11 +173,11 @@ class TestRegistryServer(unittest.TestCase): ...@@ -162,11 +173,11 @@ class TestRegistryServer(unittest.TestCase):
[(hashlib.sha1(key).digest(), protocol)]) [(hashlib.sha1(key).digest(), protocol)])
func.assert_called_once_with(**params) func.assert_called_once_with(**params)
# http response check # http response check
request.send_response.assert_called_once_with(httplib.OK) request.send_response.assert_called_once_with(http.client.OK)
request.send_header.assert_any_call("Content-Length", str(len(result))) request.send_header.assert_any_call("Content-Length", str(len(result)))
request.send_header.assert_any_call( request.send_header.assert_any_call(
registry.HMAC_HEADER, registry.HMAC_HEADER,
base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest())) base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest()).decode("ascii"))
request.wfile.write.assert_called_once_with(result) request.wfile.write.assert_called_once_with(result)
# remove the create session \n # remove the create session \n
...@@ -176,12 +187,12 @@ class TestRegistryServer(unittest.TestCase): ...@@ -176,12 +187,12 @@ class TestRegistryServer(unittest.TestCase):
def test_handle_request_private(self, func): def test_handle_request_private(self, func):
"""case request with _private attr""" """case request with _private attr"""
method = "func" method = "func"
params = {"a" : 1, "b" : 2} params = {"a": 1, "b": 2}
func.getcallargs.return_value = params func.getcallargs.return_value = params
func.return_value = None func.return_value = None
request_good = Mock() request_good = Mock()
request_good.client_address = self.config.authorized_origin request_good.client_address = self.config.authorized_origin
request_good.headers = {'X-Forwarded-For':self.config.authorized_origin[0]} request_good.headers = {'X-Forwarded-For': self.config.authorized_origin[0]}
request_bad = Mock() request_bad = Mock()
request_bad.client_address = ["wrong_address"] request_bad.client_address = ["wrong_address"]
...@@ -189,8 +200,8 @@ class TestRegistryServer(unittest.TestCase): ...@@ -189,8 +200,8 @@ class TestRegistryServer(unittest.TestCase):
self.server.handle_request(request_bad, method, params) self.server.handle_request(request_bad, method, params)
func.assert_called_once_with(**params) func.assert_called_once_with(**params)
request_bad.send_error.assert_called_once_with(httplib.FORBIDDEN) request_bad.send_error.assert_called_once_with(http.client.FORBIDDEN)
request_good.send_response.assert_called_once_with(httplib.NO_CONTENT) request_good.send_response.assert_called_once_with(http.client.NO_CONTENT)
# will cause valueError, if a node send hello twice to a registry # will cause valueError, if a node send hello twice to a registry
def test_getPeerProtocol(self): def test_getPeerProtocol(self):
...@@ -213,7 +224,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -213,7 +224,7 @@ class TestRegistryServer(unittest.TestCase):
res = self.server.hello(prefix, protocol=protocol) res = self.server.hello(prefix, protocol=protocol)
# decrypt # decrypt
length = len(res)/2 length = len(res) // 2
key, sign = res[:length], res[length:] key, sign = res[:length], res[length:]
key = decrypt(pkey, key) key = decrypt(pkey, key)
self.assertEqual(self.server.sessions[prefix][-1][0], key, self.assertEqual(self.server.sessions[prefix][-1][0], key,
...@@ -282,7 +293,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -282,7 +293,7 @@ class TestRegistryServer(unittest.TestCase):
nb_less = 0 nb_less = 0
for cert in self.server.iterCert(): for cert in self.server.iterCert():
s = cert[0].get_subject().serialNumber s = cert[0].get_subject().serialNumber
if(s and int(s) <= serial): if s and int(s) <= serial:
nb_less += 1 nb_less += 1
self.assertEqual(nb_less, serial) self.assertEqual(nb_less, serial)
...@@ -378,7 +389,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -378,7 +389,7 @@ class TestRegistryServer(unittest.TestCase):
hmacs = get_hmac() hmacs = get_hmac()
key_1 = hmacs[1] key_1 = hmacs[1]
self.assertEqual(hmacs, [None, key_1, '']) self.assertEqual(hmacs, [None, key_1, b''])
# step 2 # step 2
self.server.updateHMAC() self.server.updateHMAC()
...@@ -397,12 +408,11 @@ class TestRegistryServer(unittest.TestCase): ...@@ -397,12 +408,11 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(get_hmac(), [None, key_2, key_1]) self.assertEqual(get_hmac(), [None, key_2, key_1])
#setp 5 # step 5
self.server.updateHMAC() self.server.updateHMAC()
self.assertEqual(get_hmac(), [key_2, None, None]) self.assertEqual(get_hmac(), [key_2, None, None])
def test_getNodePrefix(self): def test_getNodePrefix(self):
# prefix in short format # prefix in short format
prefix = "0000000101" prefix = "0000000101"
...@@ -426,19 +436,33 @@ class TestRegistryServer(unittest.TestCase): ...@@ -426,19 +436,33 @@ class TestRegistryServer(unittest.TestCase):
('0000000000000001', '2 0/16 6/16') ('0000000000000001', '2 0/16 6/16')
] ]
recv.side_effect = recv_case recv.side_effect = recv_case
def side_effct(rlist, wlist, elist, timeout): def side_effct(rlist, wlist, elist, timeout):
# rlist is true until the len(recv_case)th call # rlist is true until the len(recv_case)th call
side_effct.i -= side_effct.i > 0 side_effct.i -= side_effct.i > 0
return [side_effct.i, wlist, None] return [side_effct.i, wlist, None]
side_effct.i = len(recv_case) + 1 side_effct.i = len(recv_case) + 1
select.side_effect = side_effct select.side_effect = side_effct
res = self.server.topology() res = self.server.topology()
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \ class CustomDecoder(json.JSONDecoder):
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \ def __init__(self, **kwargs):
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\ json.JSONDecoder.__init__(self, **kwargs)
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}''' self.parse_array = self.JSONArray
self.scan_once = json.scanner.py_make_scanner(self)
def JSONArray(self, s_and_end, scan_once, **kwargs):
values, end = json.decoder.JSONArray(s_and_end, scan_once, **kwargs)
return set(values), end
res = json.loads(res, cls=CustomDecoder)
expect_res = {"36893488147419103232/80": {"0/16", "7/16"},
"": {"36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"}, "4/16": {"0/16"},
"3/16": {"0/16", "7/16"}, "0/16": {"6/16", "7/16"}, "1/16": {"6/16", "0/16"},
"7/16": {"6/16", "4/16"}}
self.assertEqual(res, expect_res) self.assertEqual(res, expect_res)
......
...@@ -2,7 +2,7 @@ import sys ...@@ -2,7 +2,7 @@ import sys
import os import os
import unittest import unittest
import hmac import hmac
import httplib import http.client
import base64 import base64
import hashlib import hashlib
from mock import Mock, patch from mock import Mock, patch
...@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase): ...@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase):
self.assertEqual(client1._path, "/example") self.assertEqual(client1._path, "/example")
self.assertEqual(client1._conn.host, "localhost") self.assertEqual(client1._conn.host, "localhost")
self.assertIsInstance(client1._conn, httplib.HTTPSConnection) self.assertIsInstance(client1._conn, http.client.HTTPSConnection)
self.assertIsInstance(client2._conn, httplib.HTTPConnection) self.assertIsInstance(client2._conn, http.client.HTTPConnection)
def test_rpc_hello(self): def test_rpc_hello(self):
prefix = "0000000011111111" prefix = "0000000011111111"
protocol = "7" protocol = "7"
body = "a_hmac_key" body = "a_hmac_key"
query = "/hello?client_prefix=0000000011111111&protocol=7" query = "/hello?client_prefix=0000000011111111&protocol=7"
response = fakeResponse(body, httplib.OK) response = fakeResponse(body, http.client.OK)
self.client._conn.getresponse.return_value = response self.client._conn.getresponse.return_value = response
res = self.client.hello(prefix, protocol) res = self.client.hello(prefix, protocol)
...@@ -52,14 +52,14 @@ class TestRegistryClient(unittest.TestCase): ...@@ -52,14 +52,14 @@ class TestRegistryClient(unittest.TestCase):
self.client._hmac = None self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb") self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock() self.client.cert = Mock()
key = "this_is_a_key" key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest() h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest() key = hashlib.sha1(key).digest()
# response part # response part
body = None body = b'this is a body'
response = fakeResponse(body, httplib.NO_CONTENT) response = fakeResponse(body, http.client.NO_CONTENT)
response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest()) response.msg = dict(Re6stHMAC=base64.b64encode(hmac.HMAC(key, body, hashlib.sha1).digest()))
self.client._conn.getresponse.return_value = response self.client._conn.getresponse.return_value = response
res = self.client.getNetworkConfig(cn) res = self.client.getNetworkConfig(cn)
......
#!/usr/bin/python2 #!/usr/bin/env python3
import os import os
import sys import sys
import unittest import unittest
...@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True) # @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel): # def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """ # """code is 1, peer and msg not none """
# c = chr(1) # c = b"\x01"
# msg = "address" # msg = "address"
# peer = x509.Peer("000001") # peer = x509.Peer("000001")
# self.tunnel._connecting = {peer} # self.tunnel._connecting = {peer}
...@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
def test_processPacket_address(self): def test_processPacket_address(self):
"""code is 1, for address. And peer or msg are none""" """code is 1, for address. And peer or msg are none"""
c = chr(1) c = b"\x01"
self.tunnel._address = {1: "1,1", 2: "2,2"} self.tunnel._address = {1: "1,1", 2: "2,2"}
res = self.tunnel._processPacket(c) res = self.tunnel._processPacket(c)
...@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
and each address join by ; and each address join by ;
it will truncate address which has more than 3 element it will truncate address which has more than 3 element
""" """
c = chr(1) c = b"\x01"
peer = x509.Peer("000001") peer = x509.Peer("000001")
peer.protocol = 1 peer.protocol = 1
self.tunnel._peers.append(peer) self.tunnel._peers.append(peer)
...@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
"""code is 0, for network version, peer is not none """code is 0, for network version, peer is not none
2 case, one modify the version, one not 2 case, one modify the version, one not
""" """
c = chr(0) c = b"\x00"
peer = x509.Peer("000001") peer = x509.Peer("000001")
version1 = "00003" version1 = b"00003"
version2 = "00007" version2 = b"00007"
self.tunnel._version = version3 = "00005" self.tunnel._version = version3 = b"00005"
self.tunnel._peers.append(peer) self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer) res = self.tunnel._processPacket(c + version1, peer)
......
#!/usr/bin/python2 #!/usr/bin/env python3
import os import os
import sys import sys
import unittest import unittest
......
...@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
return return
crypto.X509Cert in pem format crypto.X509Cert in pem format
""" """
if type(ca) is str: if type(ca) is bytes:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca) ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca)
if type(ca_key) is str: if type(ca_key) is bytes:
ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key) ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr) req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
...@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
if not_after: if not_after:
cert.set_notAfter( cert.set_notAfter(
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after))) time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)).encode())
else: else:
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration) cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject = req.get_subject() subject = req.get_subject()
...@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial): def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr() pkey, csr = generate_csr()
cert = generate_cert(ca, ca_key, csr, prefix, serial) cert = generate_cert(ca, ca_key, csr, prefix, serial)
with open(pkey_file, 'w') as f: with open(pkey_file, 'wb') as f:
f.write(pkey) f.write(pkey)
with open(cert_file, 'w') as f: with open(cert_file, 'wb') as f:
f.write(cert) f.write(cert)
return pkey, cert return pkey, cert
...@@ -84,9 +84,9 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042): ...@@ -84,9 +84,9 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
cert.set_pubkey(key) cert.set_pubkey(key)
cert.sign(key, "sha512") cert.sign(key, "sha512")
with open(pkey_file, 'w') as pkey_file: with open(pkey_file, 'wb') as pkey_file:
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key)) pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file: with open(cert_file, 'wb') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert return key, cert
...@@ -101,7 +101,7 @@ def serial2prefix(serial): ...@@ -101,7 +101,7 @@ def serial2prefix(serial):
# pkey: private key # pkey: private key
def decrypt(pkey, incontent): def decrypt(pkey, incontent):
with open("node.key", 'w') as f: with open("node.key", 'w') as f:
f.write(pkey) f.write(pkey.decode())
args = "openssl rsautl -decrypt -inkey node.key".split() args = "openssl rsautl -decrypt -inkey node.key".split()
p = subprocess.Popen( p = subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
......
This diff is collapsed.
...@@ -7,7 +7,7 @@ class UPnPException(Exception): ...@@ -7,7 +7,7 @@ class UPnPException(Exception):
pass pass
class Forwarder(object): class Forwarder:
""" """
External port is chosen randomly between 32768 & 49151 included. External port is chosen randomly between 32768 & 49151 included.
""" """
...@@ -40,7 +40,7 @@ class Forwarder(object): ...@@ -40,7 +40,7 @@ class Forwarder(object):
def wrapper(*args, **kw): def wrapper(*args, **kw):
try: try:
return wrapped(*args, **kw) return wrapped(*args, **kw)
except Exception, e: except Exception as e:
raise UPnPException(str(e)) raise UPnPException(str(e))
return wraps(wrapped)(wrapper) return wraps(wrapped)(wrapper)
...@@ -68,14 +68,14 @@ class Forwarder(object): ...@@ -68,14 +68,14 @@ class Forwarder(object):
else: else:
try: try:
return self._refresh() return self._refresh()
except UPnPException, e: except UPnPException as e:
logging.debug("UPnP failure", exc_info=1) logging.debug("UPnP failure", exc_info=True)
self.clear() self.clear()
try: try:
self.discover() self.discover()
self.selectigd() self.selectigd()
return self._refresh() return self._refresh()
except UPnPException, e: except UPnPException as e:
self.next_refresh = self._next_retry = time.time() + 60 self.next_refresh = self._next_retry = time.time() + 60
logging.info(str(e)) logging.info(str(e))
self.clear() self.clear()
...@@ -109,7 +109,7 @@ class Forwarder(object): ...@@ -109,7 +109,7 @@ class Forwarder(object):
try: try:
self.addportmapping(port, *args) self.addportmapping(port, *args)
break break
except UPnPException, e: except UPnPException as e:
if str(e) != 'ConflictInMappingEntry': if str(e) != 'ConflictInMappingEntry':
raise raise
port = None port = None
......
import argparse, errno, fcntl, hashlib, logging, os, select as _select import argparse, errno, fcntl, hashlib, logging, os, select as _select
import shlex, signal, socket, sqlite3, struct, subprocess import shlex, signal, socket, sqlite3, struct, subprocess
import sys, textwrap, threading, time, traceback import sys, textwrap, threading, time, traceback
from collections.abc import Iterator, Mapping
# PY3: It will be even better to use Popen(pass_fds=...), HMAC_LEN = len(hashlib.sha1(b'').digest())
# and then socket.SOCK_CLOEXEC will be useless.
# (We already follow the good practice that consists in not
# relying on the GC for the closing of file descriptors.)
socket.SOCK_CLOEXEC = 0x80000
HMAC_LEN = len(hashlib.sha1('').digest())
class ReexecException(Exception): class ReexecException(Exception):
pass pass
...@@ -37,12 +32,12 @@ class FileHandler(logging.FileHandler): ...@@ -37,12 +32,12 @@ class FileHandler(logging.FileHandler):
finally: finally:
self.lock.release() self.lock.release()
# In the rare case _reopen is set just before the lock was released # In the rare case _reopen is set just before the lock was released
if self._reopen and self.lock.acquire(0): if self._reopen and self.lock.acquire(False):
self.release() self.release()
def async_reopen(self, *_): def async_reopen(self, *_):
self._reopen = True self._reopen = True
if self.lock.acquire(0): if self.lock.acquire(False):
self.release() self.release()
def setupLog(log_level, filename=None, **kw): def setupLog(log_level, filename=None, **kw):
...@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser): ...@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser):
ca /etc/re6stnet/ca.crt""", **kw) ca /etc/re6stnet/ca.crt""", **kw)
class exit(object): class exit:
status = None status = None
...@@ -150,7 +145,7 @@ class exit(object): ...@@ -150,7 +145,7 @@ class exit(object):
def handler(*args): def handler(*args):
if self.status is None: if self.status is None:
self.status = status self.status = status
if self.acquire(0): if self.acquire(False):
self.release() self.release()
for sig in sigs: for sig in sigs:
signal.signal(sig, handler) signal.signal(sig, handler)
...@@ -164,7 +159,7 @@ class Popen(subprocess.Popen): ...@@ -164,7 +159,7 @@ class Popen(subprocess.Popen):
self._args = tuple(args[0] if args else kw['args']) self._args = tuple(args[0] if args else kw['args'])
try: try:
super(Popen, self).__init__(*args, **kw) super(Popen, self).__init__(*args, **kw)
except OSError, e: except OSError as e:
if e.errno != errno.ENOMEM: if e.errno != errno.ENOMEM:
raise raise
self.returncode = -1 self.returncode = -1
...@@ -179,9 +174,9 @@ class Popen(subprocess.Popen): ...@@ -179,9 +174,9 @@ class Popen(subprocess.Popen):
self.terminate() self.terminate()
t = threading.Timer(5, self.kill) t = threading.Timer(5, self.kill)
t.start() t.start()
# PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel() r = os.waitid(os.P_PID, self.pid, os.WNOWAIT)
r = self.wait()
t.cancel() t.cancel()
self.poll()
return r return r
...@@ -209,7 +204,7 @@ def select(R, W, T): ...@@ -209,7 +204,7 @@ def select(R, W, T):
def makedirs(*args): def makedirs(*args):
try: try:
os.makedirs(*args) os.makedirs(*args)
except OSError, e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
...@@ -240,7 +235,7 @@ def parse_address(address_list): ...@@ -240,7 +235,7 @@ def parse_address(address_list):
a = address.split(',') a = address.split(',')
int(a[1]) # Check if port is an int int(a[1]) # Check if port is an int
yield tuple(a[:4]) yield tuple(a[:4])
except ValueError, e: except ValueError as e:
logging.warning("Failed to parse node address %r (%s)", logging.warning("Failed to parse node address %r (%s)",
address, e) address, e)
...@@ -261,21 +256,21 @@ newHmacSecret = newHmacSecret() ...@@ -261,21 +256,21 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value # - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes # - the 3 first bits code the number of bytes
def packInteger(i): def packInteger(i: int) -> bytes:
for n in xrange(8): for n in range(8):
x = 32 << 8 * n x = 32 << 8 * n
if i < x: if i < x:
return struct.pack("!Q", i + n * x)[7-n:] return struct.pack("!Q", i + n * x)[7-n:]
i -= x i -= x
raise OverflowError raise OverflowError
def unpackInteger(x): def unpackInteger(x: bytes) -> tuple[int, int] | None:
n = ord(x[0]) >> 5 n = x[0] >> 5
try: try:
i, = struct.unpack("!Q", '\0' * (7 - n) + x[:n+1]) i, = struct.unpack("!Q", b'\0' * (7 - n) + x[:n+1])
except struct.error: except struct.error:
return return
return sum((32 << 8 * i for i in xrange(n)), return sum((32 << 8 * i for i in range(n)),
i - (n * 32 << 8 * n)), n + 1 i - (n * 32 << 8 * n)), n + 1
### ###
......
...@@ -40,4 +40,4 @@ protocol = 8 ...@@ -40,4 +40,4 @@ protocol = 8
min_protocol = 1 min_protocol = 1
if __name__ == "__main__": if __name__ == "__main__":
print version print(version)
...@@ -14,23 +14,23 @@ def subnetFromCert(cert): ...@@ -14,23 +14,23 @@ def subnetFromCert(cert):
return cert.get_subject().CN return cert.get_subject().CN
def notBefore(cert): def notBefore(cert):
return calendar.timegm(time.strptime(cert.get_notBefore(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ'))
def notAfter(cert): def notAfter(cert):
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ'))
def openssl(*args): def openssl(*args, fds=[]):
return utils.Popen(('openssl',) + args, return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data): def encrypt(cert, data):
r, w = os.pipe() r, w = os.pipe()
try: try:
threading.Thread(target=os.write, args=(w, cert)).start() threading.Thread(target=os.write, args=(w, cert)).start()
p = openssl('rsautl', '-encrypt', '-certin', p = openssl('rsautl', '-encrypt', '-certin',
'-inkey', '/proc/self/fd/%u' % r) '-inkey', '/proc/self/fd/%u' % r, fds=[r])
out, err = p.communicate(data) out, err = p.communicate(data)
finally: finally:
os.close(r) os.close(r)
...@@ -52,7 +52,7 @@ def maybe_renew(path, cert, info, renew, force=False): ...@@ -52,7 +52,7 @@ def maybe_renew(path, cert, info, renew, force=False):
if time.time() < next_renew: if time.time() < next_renew:
return cert, next_renew return cert, next_renew
try: try:
pem = renew() pem: bytes = renew()
if not pem or pem == crypto.dump_certificate( if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert): crypto.FILETYPE_PEM, cert):
exc_info = 0 exc_info = 0
...@@ -62,7 +62,7 @@ def maybe_renew(path, cert, info, renew, force=False): ...@@ -62,7 +62,7 @@ def maybe_renew(path, cert, info, renew, force=False):
exc_info = 1 exc_info = 1
break break
new_path = path + '.new' new_path = path + '.new'
with open(new_path, 'w') as f: with open(new_path, 'wb') as f:
f.write(pem) f.write(pem)
try: try:
s = os.stat(path) s = os.stat(path)
...@@ -84,19 +84,19 @@ class NewSessionError(Exception): ...@@ -84,19 +84,19 @@ class NewSessionError(Exception):
pass pass
class Cert(object): class Cert:
def __init__(self, ca, key, cert=None): def __init__(self, ca, key, cert=None):
self.ca_path = ca self.ca_path = ca
self.cert_path = cert self.cert_path = cert
self.key_path = key self.key_path = key
with open(ca) as f: with open(ca, "rb") as f:
self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
with open(key) as f: with open(key, "rb") as f:
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
if cert: if cert:
with open(cert) as f: with open(cert) as f:
self.cert = self.loadVerify(f.read()) self.cert = self.loadVerify(f.read().encode())
@property @property
def prefix(self): def prefix(self):
...@@ -143,21 +143,21 @@ class Cert(object): ...@@ -143,21 +143,21 @@ class Cert(object):
"error running openssl, assuming cert is invalid") "error running openssl, assuming cert is invalid")
# BBB: With old versions of openssl, detailed # BBB: With old versions of openssl, detailed
# error is printed to standard output. # error is printed to standard output.
for err in err, out: for stream in err, out:
for x in err.splitlines(): for x in stream.decode(errors='replace').splitlines():
if x.startswith('error '): if x.startswith('error '):
x, msg = x.split(':', 1) x, msg = x.split(':', 1)
_, code, _, depth, _ = x.split(None, 4) _, code, _, depth, _ = x.split(None, 4)
raise VerifyError(int(code), int(depth), msg.strip()) raise VerifyError(int(code), int(depth), msg.strip())
return r return r
def verify(self, sign, data): def verify(self, sign: bytes, data):
crypto.verify(self.ca, sign, data, 'sha512') crypto.verify(self.ca, sign, data, 'sha512')
def sign(self, data): def sign(self, data) -> bytes:
return crypto.sign(self.key, data, 'sha512') return crypto.sign(self.key, data, 'sha512')
def decrypt(self, data): def decrypt(self, data: bytes) -> bytes:
p = openssl('rsautl', '-decrypt', '-inkey', self.key_path) p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
out, err = p.communicate(data) out, err = p.communicate(data)
if p.returncode: if p.returncode:
...@@ -166,7 +166,7 @@ class Cert(object): ...@@ -166,7 +166,7 @@ class Cert(object):
def verifyVersion(self, version): def verifyVersion(self, version):
try: try:
n = 1 + (ord(version[0]) >> 5) n = 1 + (version[0] >> 5)
self.verify(version[n:], version[:n]) self.verify(version[n:], version[:n])
except (IndexError, crypto.Error): except (IndexError, crypto.Error):
raise VerifyError(None, None, 'invalid network version') raise VerifyError(None, None, 'invalid network version')
...@@ -175,7 +175,7 @@ class Cert(object): ...@@ -175,7 +175,7 @@ class Cert(object):
PACKED_PROTOCOL = utils.packInteger(protocol) PACKED_PROTOCOL = utils.packInteger(protocol)
class Peer(object): class Peer:
""" """
UDP: A ─────────────────────────────────────────────> B UDP: A ─────────────────────────────────────────────> B
...@@ -206,9 +206,9 @@ class Peer(object): ...@@ -206,9 +206,9 @@ class Peer(object):
_key = newHmacSecret() _key = newHmacSecret()
serial = None serial = None
stop_date = float('inf') stop_date = float('inf')
version = '' version = b''
def __init__(self, prefix): def __init__(self, prefix: str):
self.prefix = prefix self.prefix = prefix
@property @property
...@@ -229,11 +229,11 @@ class Peer(object): ...@@ -229,11 +229,11 @@ class Peer(object):
try: try:
# Always assume peer is not old, in case it has just upgraded, # Always assume peer is not old, in case it has just upgraded,
# else we would be stuck with the old protocol. # else we would be stuck with the old protocol.
msg = ('\0\0\0\1' msg = (b'\0\0\0\1'
+ PACKED_PROTOCOL + PACKED_PROTOCOL
+ fingerprint(self.cert).digest()) + fingerprint(self.cert).digest())
except AttributeError: except AttributeError:
msg = '\0\0\0\0' msg = b'\0\0\0\0'
return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
def hello0Sent(self): def hello0Sent(self):
...@@ -246,13 +246,13 @@ class Peer(object): ...@@ -246,13 +246,13 @@ class Peer(object):
self._i = self._j = 2 self._i = self._j = 2
self._last = 0 self._last = 0
self.protocol = protocol self.protocol = protocol
return ''.join(('\0\0\0\2', PACKED_PROTOCOL if protocol else '', return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'',
h, cert.sign(h))) h, cert.sign(h)))
def _hmac(self, msg): def _hmac(self, msg):
return hmac.HMAC(self._key, msg, hashlib.sha1).digest() return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key, protocol): def newSession(self, key: bytes, protocol):
if key <= self._key: if key <= self._key:
raise NewSessionError(self._key, key) raise NewSessionError(self._key, key)
self._key = key self._key = key
...@@ -279,10 +279,12 @@ class Peer(object): ...@@ -279,10 +279,12 @@ class Peer(object):
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno: if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None self._last = None
self._i = seqno self._i = seqno
return msg[4:i] return msg[4:i].decode()
def encode(self, msg, _pack=seqno_struct.pack): def encode(self, msg, _pack=seqno_struct.pack):
self._j += 1 self._j += 1
if type(msg) is str:
msg = msg.encode()
msg = _pack(self._j) + msg msg = _pack(self._j) + msg
return msg + self._hmac(msg) return msg + self._hmac(msg)
......
...@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py ...@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py
from distutils import log from distutils import log
version = {"__file__": "re6st/version.py"} version = {"__file__": "re6st/version.py"}
execfile(version["__file__"], version) with open(version["__file__"]) as f:
code = compile(f.read(), version["__file__"], 'exec')
exec(code, version)
def copy_file(self, infile, outfile, *args, **kw): def copy_file(self, infile, outfile, *args, **kw):
if infile == version["__file__"]: if infile == version["__file__"]:
if not self.dry_run: if not self.dry_run:
log.info("generating %s -> %s", infile, outfile) log.info("generating %s -> %s", infile, outfile)
with open(outfile, "wb") as f: with open(outfile, "w") as f:
for x in sorted(version.iteritems()): for x in sorted(version.items()):
if not x[0].startswith("_"): if not x[0].startswith("_"):
f.write("%s = %r\n" % x) f.write("%s = %r\n" % x)
return outfile, 1 return outfile, 1
elif isinstance(self, build_py) and \ elif isinstance(self, build_py) and \
os.stat(infile).st_mode & stat.S_IEXEC: os.stat(infile).st_mode & stat.S_IEXEC:
if os.path.isdir(infile) and os.path.isdir(outfile): if os.path.isdir(infile) and os.path.isdir(outfile):
return (outfile, 0) return outfile, 0
# Adjust interpreter of OpenVPN hooks. # Adjust interpreter of OpenVPN hooks.
with open(infile) as src: with open(infile) as src:
first_line = src.readline() first_line = src.readline()
...@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw): ...@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw):
patched += src.read() patched += src.read()
dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC) dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC)
try: try:
os.write(dst, patched) os.write(dst, patched.encode())
finally: finally:
os.close(dst) os.close(dst)
return outfile, 1 return outfile, 1
...@@ -51,7 +53,8 @@ Environment :: Console ...@@ -51,7 +53,8 @@ Environment :: Console
License :: OSI Approved :: GNU General Public License (GPL) License :: OSI Approved :: GNU General Public License (GPL)
Natural Language :: English Natural Language :: English
Operating System :: POSIX :: Linux Operating System :: POSIX :: Linux
Programming Language :: Python :: 2.7 Programming Language :: Python :: 3
Programming Language :: Python :: 3.11
Topic :: Internet Topic :: Internet
Topic :: System :: Networking Topic :: System :: Networking
""" """
...@@ -73,6 +76,7 @@ setup( ...@@ -73,6 +76,7 @@ setup(
license = 'GPL 2+', license = 'GPL 2+',
platforms = ["any"], platforms = ["any"],
classifiers=classifiers.splitlines(), classifiers=classifiers.splitlines(),
python_requires = '>=3.11',
long_description = ".. contents::\n\n" + open('README.rst').read() long_description = ".. contents::\n\n" + open('README.rst').read()
+ "\n" + open('CHANGES.rst').read() + git_rev, + "\n" + open('CHANGES.rst').read() + git_rev,
packages = find_packages(), packages = find_packages(),
...@@ -95,7 +99,7 @@ setup( ...@@ -95,7 +99,7 @@ setup(
extras_require = { extras_require = {
'geoip': ['geoip2'], 'geoip': ['geoip2'],
'multicast': ['PyYAML'], 'multicast': ['PyYAML'],
'test': ['mock', 'pathlib2', 'nemu', 'python-unshare', 'python-passfd', 'multiping'] 'test': ['mock', 'nemu3', 'unshare', 'multiping']
}, },
#dependency_links = [ #dependency_links = [
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7", # "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment