Commit eb07b358 authored by Tom Niget's avatar Tom Niget Committed by Tom Niget

python3: add type annotations

parent 389a9a01
......@@ -5,7 +5,7 @@ if 're6st' not in sys.modules:
from re6st import utils, x509
from OpenSSL import crypto
with open("/etc/re6stnet/ca.crt") as f:
with open("/etc/re6stnet/ca.crt", "rb") as f:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
network = x509.networkFromCa(ca)
......
......@@ -5,7 +5,7 @@ from . import utils, version, x509
class Cache:
def __init__(self, db_path, registry, cert: x509.Cert, db_size=200):
def __init__(self, db_path: str, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix
self._db_size = db_size
self._decrypt = cert.decrypt
......@@ -50,7 +50,7 @@ class Cache:
self.warnProtocol()
logging.info("Cache initialized.")
def _open(self, path):
def _open(self, path: str) -> sqlite3.Connection:
db = sqlite3.connect(path, isolation_level=None)
db.text_factory = str
db.execute("PRAGMA synchronous = OFF")
......@@ -147,7 +147,7 @@ class Cache:
logging.warning("There's a new version of re6stnet:"
" you should update.")
def getDh(self, path):
def getDh(self, path: str):
# We'd like to do a full check here but
# from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
......@@ -179,11 +179,11 @@ class Cache:
logging.trace("- %s: %s%s", prefix, address,
' (blacklisted)' if _try else '')
def cacheMinimize(self, size):
def cacheMinimize(self, size: int):
with self._db:
self._cacheMinimize(size)
def _cacheMinimize(self, size):
def _cacheMinimize(self, size: int):
a = self._db.execute(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1",
(size,)).fetchall()
......@@ -192,26 +192,26 @@ class Cache:
q("DELETE FROM peer WHERE prefix IN (?)", a)
q("DELETE FROM volatile.stat WHERE peer IN (?)", a)
def connecting(self, prefix, connecting):
def connecting(self, prefix: str, connecting: bool):
self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?",
(connecting, prefix))
def resetConnecting(self):
self._db.execute("UPDATE volatile.stat SET try=0")
def getAddress(self, prefix):
def getAddress(self, prefix: str) -> bool:
r = self._db.execute("SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0",
(prefix,)).fetchone()
return r and r[0]
@property
def my_address(self):
def my_address(self) -> str:
for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"):
return x
@my_address.setter
def my_address(self, value):
def my_address(self, value: str):
if value:
with self._db as db:
db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)",
......@@ -229,14 +229,14 @@ class Cache:
# IOW, one should probably always put our own address there.
_get_peer_sql = "SELECT %s FROM peer, volatile.stat" \
" WHERE prefix=peer AND prefix!=? AND try=?"
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address"
def getPeerList(self, failed=False, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"):
return self._db.execute(__sql, (self._prefix, failed))
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)") -> int:
def getPeerCount(self, failed=False, __sql=_get_peer_sql % "COUNT(*)") -> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self):
def getBootstrapPeer(self) -> tuple[str, str]:
logging.info('Getting Boot peer...')
try:
bootpeer = self._registry.getBootstrapPeer(self._prefix)
......@@ -250,7 +250,7 @@ class Cache:
return prefix, address
logging.warning('Buggy registry sent us our own address')
def addPeer(self, prefix, address, set_preferred=False):
def addPeer(self, prefix: str, address: str, set_preferred=False):
logging.debug('Adding peer %s: %s', prefix, address)
with self._db:
q = self._db.execute
......@@ -274,7 +274,7 @@ class Cache:
q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))
def getCountry(self, ip):
def getCountry(self, ip: str) -> str:
try:
return self._registry.getCountry(self._prefix, ip).decode()
except socket.error as e:
......
......@@ -272,7 +272,7 @@ def main():
call(args)
args[3] = 'del'
cleanup.append(lambda: subprocess.call(args))
def ip(object, *args):
def ip(object: str, *args):
args = ['ip', '-6', object, 'add'] + list(args)
call(args)
args[3] = 'del'
......
......@@ -171,7 +171,7 @@ class Babel:
_decode = None
def __init__(self, socket_path, handler, network):
def __init__(self, socket_path: str, handler, network: str):
self.socket_path = socket_path
self.handler = handler
self.network = network
......@@ -252,15 +252,18 @@ class Babel:
unidentified = set(n)
self.neighbours = neighbours = {}
a = len(self.network)
logging.info("Routes: %r", routes)
for route in routes:
assert route.flags & 1, route # installed
if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'):
logging.warning("Ignoring IPv4 route: %r", route)
continue
assert route.neigh_address == route.nexthop, route
address = route.neigh_address, route.ifindex
neigh_routes = n[address]
ip = utils.binFromRawIp(route.prefix)
if ip[:a] == self.network:
logging.debug("Route is on the network: %r", route)
prefix = ip[a:route.plen]
if prefix and not route.refmetric:
neighbours[prefix] = neigh_routes
......@@ -275,7 +278,9 @@ class Babel:
socket.inet_ntop(socket.AF_INET6, route.prefix),
route.plen)
else:
logging.debug("Route is not on the network: %r", route)
prefix = None
logging.debug("Adding route %r to %r", route, neigh_routes)
neigh_routes[1][prefix] = route
self.locked.clear()
if unidentified:
......@@ -299,7 +304,7 @@ class iterRoutes:
_waiting = True
def __new__(cls, control_socket, network):
def __new__(cls, control_socket: str, network: str):
self = object.__new__(cls)
c = Babel(control_socket, self, network)
c.request_dump()
......
......@@ -3,30 +3,30 @@ import errno, os, socket, stat, threading
class Socket:
def __init__(self, socket):
def __init__(self, socket: socket.socket):
# In case that the default timeout is not None.
socket.settimeout(None)
self._socket = socket
self._buf = ''
self._buf = b''
def close(self):
self._socket.close()
def write(self, data):
def write(self, data: bytes):
self._socket.send(data)
def readline(self):
def readline(self) -> bytes:
recv = self._socket.recv
data = self._buf
while True:
i = 1 + data.find('\n')
i = 1 + data.find(b'\n')
if i:
self._buf = data[i:]
return data[:i]
d = recv(4096)
data += d
if not d:
self._buf = ''
self._buf = b''
return data
def flush(self):
......
......@@ -8,7 +8,7 @@ ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw):
def openvpn(iface: str, encrypt, *args, **kw) -> utils.Popen:
args = ['openvpn',
'--dev-type', 'tap',
'--dev', iface,
......@@ -28,7 +28,7 @@ def openvpn(iface, encrypt, *args, **kw):
ovpn_link_mtu_dict = {'udp4': 1432, 'udp6': 1450}
def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
def server(iface: str, max_clients: int, dh_path: str, fd: int, port: int, proto: str, encrypt: bool, *args, **kw) -> utils.Popen:
if proto == 'udp':
proto = 'udp4'
client_script = '%s %s' % (ovpn_server, fd)
......@@ -49,7 +49,7 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
*args, pass_fds=[fd], **kw)
def client(iface, address_list, encrypt, *args, **kw):
def client(iface: str, address_list: list[tuple[str, int, str]], encrypt: bool, *args, **kw) -> utils.Popen:
remote = ['--nobind', '--client']
# XXX: We'd like to pass <connection> sections at command-line.
link_mtu = set()
......@@ -65,8 +65,8 @@ def client(iface, address_list, encrypt, *args, **kw):
return openvpn(iface, encrypt, *remote, **kw)
def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
control_socket, default, hmac, *args, **kw):
def router(ip: tuple[str, int], ip4, rt6: tuple[str, bool, bool], hello_interval: int, log_path: str, state_path: str, pidfile: str,
control_socket: str, default: str, hmac: tuple[bytes | None, bytes | None], *args, **kw) -> utils.Popen:
network, gateway, has_ipv6_subtrees = rt6
network_mask = int(network[network.index('/')+1:])
ip, n = ip
......@@ -83,7 +83,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign:
def key(cmd, id: str, value):
def key(cmd: list[str], id: str, value: bytes):
cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign)
......
This diff is collapsed.
......@@ -11,11 +11,13 @@ import hashlib
import time
import tempfile
from argparse import Namespace
from sqlite3 import Cursor
from OpenSSL import crypto
from mock import Mock, patch
from pathlib import Path
from re6st import registry
from re6st import registry, x509
from re6st.tests.tools import *
from re6st.tests import DEMO_PATH
......@@ -23,7 +25,7 @@ from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
def load_config(filename="registry.json"):
def load_config(filename: str="registry.json") -> Namespace:
with open(filename) as f:
config = json.load(f)
config["dh"] = DEMO_PATH / "dh2048.pem"
......@@ -37,13 +39,13 @@ def load_config(filename="registry.json"):
return Namespace(**config)
def get_cert(cur, prefix):
def get_cert(cur: Cursor, prefix: str):
res = cur.execute(
"SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone()
return res[0]
def insert_cert(cur, ca, prefix, not_after=None, email=None):
def insert_cert(cur: Cursor, ca: x509.Cert, prefix: str, not_after=None, email=None):
key, csr = generate_csr()
cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after)
cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert))
......@@ -54,7 +56,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert.serial = 0
def delete_cert(cur, prefix):
def delete_cert(cur: Cursor, prefix: str):
cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,))
......
......@@ -92,18 +92,15 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
return key, cert
def prefix2cn(prefix):
def prefix2cn(prefix: str) -> str:
return "%u/%u" % (int(prefix, 2), len(prefix))
def serial2prefix(serial):
def serial2prefix(serial: int) -> str:
return bin(serial)[2:].rjust(16, '0')
# pkey: private key
def decrypt(pkey, incontent):
with open("node.key", 'w') as f:
f.write(pkey.decode())
def decrypt(pkey: bytes, incontent: bytes) -> bytes:
with open("node.key", 'wb') as f:
f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split()
with subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
outcontent, err = p.communicate(incontent)
return outcontent
return subprocess.run(args, input=incontent, stdout=subprocess.PIPE).stdout
......@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
import subprocess, struct, sys, time, weakref
from collections import defaultdict, deque
from bisect import bisect, insort
from collections.abc import Iterator, Sequence
from typing import Callable, TYPE_CHECKING
from OpenSSL import crypto
from . import ctl, plib, utils, version, x509
if TYPE_CHECKING:
from . import cache
PORT = 326
......@@ -21,7 +26,7 @@ proto_dict = {
proto_dict['tcp'] = proto_dict['tcp4']
proto_dict['udp'] = proto_dict['udp4']
def resolve(ip, port, proto):
def resolve(ip, port, proto: str) -> tuple[socket.AddressFamily | None, Iterator[str]]:
try:
family, proto = proto_dict[proto]
except KeyError:
......@@ -31,16 +36,16 @@ def resolve(ip, port, proto):
class MultiGatewayManager(dict):
def __init__(self, gateway):
def __init__(self, gateway: Callable[[str], str]):
self._gw = gateway
def _route(self, cmd, dest, gw):
def _route(self, cmd: str, dest: str, gw: str):
if gw:
cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw
logging.trace('%r', cmd)
subprocess.check_call(cmd)
def add(self, dest, route):
def add(self, dest: str, route: bool):
try:
self[dest][1] += 1
except KeyError:
......@@ -48,7 +53,7 @@ class MultiGatewayManager(dict):
self[dest] = [gw, 0]
self._route('add', dest, gw)
def remove(self, dest):
def remove(self, dest: str):
gw, count = self[dest]
if count:
self[dest][1] = count - 1
......@@ -65,7 +70,7 @@ class Connection:
serial = None
time = float('inf')
def __init__(self, tunnel_manager, address_list, iface, prefix):
def __init__(self, tunnel_manager: "TunnelManager", address_list, iface, prefix):
self.tunnel_manager = tunnel_manager
self.address_list = address_list
self.iface = iface
......@@ -109,7 +114,7 @@ class Connection:
if i:
cache.addPeer(self._prefix, ','.join(self.address_list[i]), True)
else:
cache.connecting(self._prefix, 0)
cache.connecting(self._prefix, False)
def close(self):
try:
......@@ -198,7 +203,7 @@ class BaseTunnelManager:
_geoiplookup = None
_forward = None
def __init__(self, control_socket, cache, cert, conf_country, address=()):
def __init__(self, control_socket, cache: "cache.Cache", cert: x509.Cert, conf_country, address=()):
self.cert = cert
self._network = cert.network
self._prefix = cert.prefix
......@@ -450,7 +455,7 @@ class BaseTunnelManager:
self._sendto(to, msg[0:1] + answer.encode() if answer else b'', peer)
def _processPacket(self, msg, peer=None):
def _processPacket(self, msg: bytes, peer: x509.Peer|str=None):
c = msg[0]
msg = msg[1:]
code = c & 0x7f
......@@ -564,12 +569,12 @@ class BaseTunnelManager:
self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
self._restart)
def handleServerEvent(self, sock):
def handleServerEvent(self, sock: socket.socket):
event, args = eval(sock.recv(65536))
logging.debug("%s%r", event, args)
r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args)
if r is not None:
sock.send(chr(r))
sock.send(bytes([r]))
def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip):
if serial in self.cache.crl:
......@@ -581,7 +586,7 @@ class BaseTunnelManager:
self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix)
self.cache.connecting(prefix, 0)
self.cache.connecting(prefix, False)
return True
def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip):
......@@ -665,7 +670,7 @@ class TunnelManager(BaseTunnelManager):
def __init__(self, control_socket, cache, cert, openvpn_args,
timeout, client_count, iface_list, conf_country, address,
ip_changed, remote_gateway, disable_proto, neighbour_list=()):
ip_changed, remote_gateway: Callable[[str], str], disable_proto: Sequence[str], neighbour_list=()):
super(TunnelManager, self).__init__(control_socket,
cache, cert, conf_country, address)
self.ovpn_args = openvpn_args
......@@ -877,7 +882,7 @@ class TunnelManager(BaseTunnelManager):
address_list.append((ip, x[1], x[2]))
continue
address_list.append(x[:3])
self.cache.connecting(prefix, 1)
self.cache.connecting(prefix, True)
if not address_list:
return False
logging.info('Establishing a connection with %u/%u',
......
......@@ -17,7 +17,7 @@ class Forwarder:
_lcg_n = 0
@classmethod
def _getExternalPort(cls):
def _getExternalPort(cls) -> int:
# Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without
......@@ -35,7 +35,7 @@ class Forwarder:
self._u.discoverdelay = 200
self._rules = []
def __getattr__(self, name):
def __getattr__(self, name: str):
wrapped = getattr(self._u, name)
def wrapper(*args, **kw):
try:
......
......@@ -40,7 +40,7 @@ class FileHandler(logging.FileHandler):
if self.lock.acquire(False):
self.release()
def setupLog(log_level, filename=None, **kw):
def setupLog(log_level: int, filename: str | None=None, **kw):
if log_level and filename:
makedirs(os.path.dirname(filename))
handler = FileHandler(filename)
......@@ -184,7 +184,7 @@ def setCloexec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def select(R, W, T):
def select(R: Mapping, W: Mapping, T):
try:
r, w, _ = _select.select(R, W, (),
max(0, min(T)[0] - time.time()) if T else None)
......@@ -208,15 +208,15 @@ def makedirs(*args):
if e.errno != errno.EEXIST:
raise
def binFromIp(ip):
def binFromIp(ip: str) -> str:
return binFromRawIp(socket.inet_pton(socket.AF_INET6, ip))
def binFromRawIp(ip):
def binFromRawIp(ip: bytes) -> str:
ip1, ip2 = struct.unpack('>QQ', ip)
return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0')
def ipFromBin(ip, suffix=''):
def ipFromBin(ip: str, suffix='') -> str:
suffix_len = 128 - len(ip)
if suffix_len > 0:
ip += suffix.rjust(suffix_len, '0')
......@@ -225,11 +225,11 @@ def ipFromBin(ip, suffix=''):
return socket.inet_ntop(socket.AF_INET6,
struct.pack('>QQ', int(ip[:64], 2), int(ip[64:], 2)))
def dump_address(address):
def dump_address(address: str) -> str:
return ';'.join(map(','.join, address))
# Yield ip, port, protocol, and country if it is in the address
def parse_address(address_list):
def parse_address(address_list: str) -> Iterator[tuple[str, str, str, str]]:
for address in address_list.split(';'):
try:
a = address.split(',')
......@@ -239,16 +239,18 @@ def parse_address(address_list):
logging.warning("Failed to parse node address %r (%s)",
address, e)
def binFromSubnet(subnet):
def binFromSubnet(subnet: str) -> str:
p, l = subnet.split('/')
return bin(int(p))[2:].rjust(int(l), '0')
def newHmacSecret():
def _newHmacSecret():
from random import getrandbits as g
pack = struct.Struct(">QQI").pack
assert len(pack(0,0,0)) == HMAC_LEN
# A closure is built to avoid rebuilding the `pack` function at each call.
return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret()
newHmacSecret = _newHmacSecret() # https://github.com/python/mypy/issues/1174
### Integer serialization
# - supports values from 0 to 0x202020202020201f
......
# -*- coding: utf-8 -*-
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
from typing import Callable, Any
from OpenSSL import crypto
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
......@@ -9,29 +11,29 @@ from cryptography.x509 import load_pem_x509_certificate
from . import utils
from .version import protocol
def newHmacSecret():
def newHmacSecret() -> bytes:
return utils.newHmacSecret(int(time.time() * 1000000))
def networkFromCa(ca):
def networkFromCa(ca: crypto.X509) -> str:
# TODO: will be ca.serial_number after migration to cryptography
return bin(ca.get_serial_number())[3:]
def subnetFromCert(cert):
def subnetFromCert(cert: crypto.X509) -> str:
return cert.get_subject().CN
def notBefore(cert):
def notBefore(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ'))
def notAfter(cert):
def notAfter(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ'))
def openssl(*args, fds=[]):
def openssl(*args: str, fds=[]) -> utils.Popen:
return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data):
def encrypt(cert: bytes, data: bytes) -> bytes:
r, w = os.pipe()
try:
threading.Thread(target=os.write, args=(w, cert)).start()
......@@ -45,10 +47,10 @@ def encrypt(cert, data):
raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
return out
def fingerprint(cert, alg='sha1'):
def fingerprint(cert: crypto.X509, alg='sha1'):
return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert))
def maybe_renew(path, cert, info, renew, force=False):
def maybe_renew(path: str, cert: crypto.X509, info: str, renew: Callable[[], bytes], force=False) -> tuple[crypto.X509, int]:
from .registry import RENEW_PERIOD
while True:
if force:
......@@ -58,7 +60,7 @@ def maybe_renew(path, cert, info, renew, force=False):
if time.time() < next_renew:
return cert, next_renew
try:
pem: bytes = renew()
pem = renew()
if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert):
exc_info = 0
......@@ -92,7 +94,7 @@ class NewSessionError(Exception):
class Cert:
def __init__(self, ca, key, cert=None):
def __init__(self, ca: str, key: str, cert: str | None=None):
self.ca_path = ca
self.cert_path = cert
self.key_path = key
......@@ -110,24 +112,24 @@ class Cert:
self.cert = self.loadVerify(f.read().encode())
@property
def prefix(self):
def prefix(self) -> str:
return utils.binFromSubnet(subnetFromCert(self.cert))
@property
def network(self):
def network(self) -> str:
return networkFromCa(self.ca)
@property
def subject_serial(self):
def subject_serial(self) -> int:
return int(self.cert.get_subject().serialNumber)
@property
def openvpn_args(self):
def openvpn_args(self) -> tuple[str, ...]:
return ('--ca', self.ca_path,
'--cert', self.cert_path,
'--key', self.key_path)
def maybeRenew(self, registry, crl):
def maybeRenew(self, registry, crl) -> int:
self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
"Certificate", lambda: registry.renewCertificate(self.prefix),
self.cert.get_serial_number() in crl)
......@@ -163,7 +165,6 @@ class Cert:
return r
def verify(self, sign: bytes, data: bytes):
assert isinstance(data, bytes)
pub_key = self.ca_crypto.public_key()
pub_key.verify(
sign,
......@@ -173,7 +174,6 @@ class Cert:
)
def sign(self, data: bytes) -> bytes:
assert isinstance(data, bytes)
return self.key_crypto.sign(
data,
padding.PKCS1v15(),
......@@ -230,6 +230,7 @@ class Peer:
serial = None
stop_date = float('inf')
version = b''
cert: crypto.X509
def __init__(self, prefix: str):
self.prefix = prefix
......@@ -247,7 +248,7 @@ class Peer:
def __lt__(self, other):
return self.prefix < (other if type(other) is str else other.prefix)
def hello0(self, cert):
def hello0(self, cert: crypto.X509) -> bytes:
if self._hello < time.time():
try:
# Always assume peer is not old, in case it has just upgraded,
......@@ -262,7 +263,7 @@ class Peer:
def hello0Sent(self):
self._hello = time.time() + 60
def hello(self, cert, protocol):
def hello(self, cert: Cert, protocol: int) -> bytes:
key = self._key = newHmacSecret()
h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert),
key)
......@@ -272,10 +273,10 @@ class Peer:
return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'',
h, cert.sign(h)))
def _hmac(self, msg):
def _hmac(self, msg: bytes) -> bytes:
return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key: bytes, protocol):
def newSession(self, key: bytes, protocol: int):
if key <= self._key:
raise NewSessionError(self._key, key)
self._key = key
......@@ -283,12 +284,12 @@ class Peer:
self._last = None
self.protocol = protocol
def verify(self, sign, data):
def verify(self, sign: bytes, data: bytes):
crypto.verify(self.cert, sign, data, 'sha512')
seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack):
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> tuple[int, bytes, int | None] | bytes:
seqno, = _unpack(msg[:4])
if seqno <= 2:
msg = msg[4:]
......@@ -302,9 +303,9 @@ class Peer:
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None
self._i = seqno
return msg[4:i].decode()
return msg[4:i]
def encode(self, msg, _pack=seqno_struct.pack):
def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1
if type(msg) is str:
msg = msg.encode()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment