Commit 0fc6e204 authored by Tom Niget's avatar Tom Niget

more migration

parent 66f2c73b
...@@ -4,6 +4,8 @@ import socket, sqlite3, subprocess, sys, time, weakref ...@@ -4,6 +4,8 @@ import socket, sqlite3, subprocess, sys, time, weakref
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from threading import Thread from threading import Thread
from typing import Optional
IPTABLES = 'iptables' IPTABLES = 'iptables'
SCREEN = 'screen' SCREEN = 'screen'
VERBOSE = 4 VERBOSE = 4
...@@ -16,6 +18,8 @@ CA_DAYS = 1000 ...@@ -16,6 +18,8 @@ CA_DAYS = 1000
# Quick check to avoid wasting time if there is an error. # Quick check to avoid wasting time if there is an error.
for x in 're6stnet', 're6st-conf', 're6st-registry': for x in 're6stnet', 're6st-conf', 're6st-registry':
subprocess.check_call(('./py', x, '--help'), stdout=subprocess.DEVNULL) subprocess.check_call(('./py', x, '--help'), stdout=subprocess.DEVNULL)
# #
# Underlying network: # Underlying network:
# #
...@@ -45,59 +49,85 @@ for x in 're6stnet', 're6st-conf', 're6st-registry': ...@@ -45,59 +49,85 @@ for x in 're6stnet', 're6st-conf', 're6st-registry':
def disable_signal_on_children(sig): def disable_signal_on_children(sig):
pid = os.getpid() pid = os.getpid()
sigint = signal.signal(sig, lambda *x: os.getpid() == pid and sigint(*x)) sigint = signal.signal(sig, lambda *x: os.getpid() == pid and sigint(*x))
disable_signal_on_children(signal.SIGINT) disable_signal_on_children(signal.SIGINT)
Node__add_interface = nemu.Node._add_interface Node__add_interface = nemu.Node._add_interface
def _add_interface(node, iface): def _add_interface(node, iface):
iface.__dict__['node'] = weakref.proxy(node) iface.__dict__['node'] = weakref.proxy(node)
return Node__add_interface(node, iface) return Node__add_interface(node, iface)
nemu.Node._add_interface = _add_interface nemu.Node._add_interface = _add_interface
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('port', type = int, parser.add_argument('port', type=int,
help = 'port used to display tunnels') help='port used to display tunnels')
parser.add_argument('-d', '--duration', type = int, parser.add_argument('-d', '--duration', type=int,
help = 'time of the demo execution in seconds') help='time of the demo execution in seconds')
parser.add_argument('-p', '--ping', action = 'store_true', parser.add_argument('-p', '--ping', action='store_true',
help = 'execute ping utility') help='execute ping utility')
parser.add_argument('-m', '--hmac', action = 'store_true', parser.add_argument('-m', '--hmac', action='store_true',
help = 'execute HMAC test') help='execute HMAC test')
args = parser.parse_args() args = parser.parse_args()
def handler(signum, frame): def handler(signum, frame):
sys.exit() sys.exit()
if args.duration: if args.duration:
signal.signal(signal.SIGALRM, handler) signal.signal(signal.SIGALRM, handler)
signal.alarm(args.duration) signal.alarm(args.duration)
exec(compile(open("fixnemu.py", "rb").read(), "fixnemu.py", 'exec')) exec(compile(open("fixnemu.py", "rb").read(), "fixnemu.py", 'exec'))
class Re6stNode(nemu.Node):
name: str
short: str
re6st_cmdline: Optional[str]
def __init__(self, name, short):
super().__init__()
self.name = name
self.short = short
self.Popen(('sysctl', '-q',
'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait()
self._screen = self.Popen((SCREEN, '-DmS', name))
self.screen = (lambda name: lambda *cmd:
subprocess.call([SCREEN, '-r', name, '-X', 'eval'] + list(map(
"""screen sh -c 'set %s; "\$@"; echo "\$@"; exec $SHELL'"""
.__mod__, cmd))))(name)
self.re6st_cmdline = None
# create nodes # create nodes
for name in """internet=I registry=R internet = Re6stNode('internet', 'I')
gateway1=g1 machine1=1 machine2=2 registry = Re6stNode('registry', 'R')
gateway2=g2 machine3=3 machine4=4 machine5=5 gateway1 = Re6stNode('gateway1', 'g1')
machine6=6 machine7=7 machine8=8 machine9=9 machine1 = Re6stNode('machine1', '1')
registry2=R2 machine10=10 machine2 = Re6stNode('machine2', '2')
""".split(): gateway2 = Re6stNode('gateway2', 'g2')
name, short = name.split('=') machine3 = Re6stNode('machine3', '3')
globals()[name] = node = nemu.Node() machine4 = Re6stNode('machine4', '4')
node.name = name machine5 = Re6stNode('machine5', '5')
node.short = short machine6 = Re6stNode('machine6', '6')
node.Popen(('sysctl', '-q', machine7 = Re6stNode('machine7', '7')
'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait() machine8 = Re6stNode('machine8', '8')
node._screen = node.Popen((SCREEN, '-DmS', name)) machine9 = Re6stNode('machine9', '9')
node.screen = (lambda name: lambda *cmd: registry2 = Re6stNode('registry2', 'R2')
subprocess.call([SCREEN, '-r', name, '-X', 'eval'] + list(map( machine10 = Re6stNode('machine10', '10')
"""screen sh -c 'set %s; "\$@"; echo "\$@"; exec $SHELL'"""
.__mod__, cmd))))(name)
# create switch # create switch
switch1 = nemu.Switch() switch1 = nemu.Switch()
switch2 = nemu.Switch() switch2 = nemu.Switch()
switch3 = nemu.Switch() switch3 = nemu.Switch()
#create interfaces # create interfaces
re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet) re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet)
in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1) in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1)
in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2) in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2)
...@@ -178,12 +208,14 @@ m6_if_0.add_v6_address(address='fc42:6::1', prefix_len=16) ...@@ -178,12 +208,14 @@ m6_if_0.add_v6_address(address='fc42:6::1', prefix_len=16)
m7_if_0.add_v6_address(address='fc42:7::1', prefix_len=16) m7_if_0.add_v6_address(address='fc42:7::1', prefix_len=16)
m8_if_0.add_v6_address(address='fc42:8::1', prefix_len=16) m8_if_0.add_v6_address(address='fc42:8::1', prefix_len=16)
def add_llrtr(iface, peer, dst='default'): def add_llrtr(iface, peer, dst='default'):
for a in peer.get_addresses(): for a in peer.get_addresses():
a = a['address'] a = a['address']
if a.startswith('fe80:'): if a.startswith('fe80:'):
return iface.node.Popen(('ip', 'route', 'add', dst, 'via', a, return iface.node.Popen(('ip', 'route', 'add', dst, 'via', a,
'proto', 'static', 'dev', iface.name)).wait() 'proto', 'static', 'dev', iface.name)).wait()
# setup routes # setup routes
add_llrtr(re_if_0, in_if_0) add_llrtr(re_if_0, in_if_0)
...@@ -211,11 +243,13 @@ for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3': ...@@ -211,11 +243,13 @@ for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3':
else: else:
print("Connectivity IPv4 OK!") print("Connectivity IPv4 OK!")
nodes = [] nodes: list[Re6stNode] = []
gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid' gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid'
' -a %s -i %s' % (g1_if_1.name, g1_if_0_name)) ' -a %s -i %s' % (g1_if_1.name, g1_if_0_name))
@contextmanager @contextmanager
def new_network(registry, reg_addr, serial, ca): def new_network(registry: Re6stNode, reg_addr: str, serial: str, ca: str):
from OpenSSL import crypto from OpenSSL import crypto
import hashlib, sqlite3 import hashlib, sqlite3
os.path.exists(ca) or subprocess.check_call( os.path.exists(ca) or subprocess.check_call(
...@@ -228,11 +262,11 @@ def new_network(registry, reg_addr, serial, ca): ...@@ -228,11 +262,11 @@ def new_network(registry, reg_addr, serial, ca):
fingerprint = "sha256:" + hashlib.sha256( fingerprint = "sha256:" + hashlib.sha256(
crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)).hexdigest() crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)).hexdigest()
db_path = "%s/registry.db" % registry.name db_path = "%s/registry.db" % registry.name
registry.screen("./py re6st-registry @%s/re6st-registry.conf" registry.screen("\"%s\" ./py re6st-registry @%s/re6st-registry.conf"
" --db %s --mailhost %s -v%u" " --db %s --mailhost %s -v%u"
% (registry.name, db_path, os.path.abspath('mbox'), VERBOSE)) % (sys.executable, registry.name, db_path, os.path.abspath('mbox'), VERBOSE))
registry_url = 'http://%s/' % reg_addr registry_url = 'http://%s/' % reg_addr
registry.Popen(('python', '-c', """if 1: registry.Popen((sys.executable, '-c', """if 1:
import socket, time import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True: while True:
...@@ -243,16 +277,17 @@ def new_network(registry, reg_addr, serial, ca): ...@@ -243,16 +277,17 @@ def new_network(registry, reg_addr, serial, ca):
time.sleep(.1) time.sleep(.1)
""")).wait() """)).wait()
db = sqlite3.connect(db_path, isolation_level=None) db = sqlite3.connect(db_path, isolation_level=None)
def new_node(node, folder, args='', prefix_len=None, registry=registry_url):
def new_node(node: Re6stNode, folder: str, args='', prefix_len: Optional[int] = None, registry=registry_url):
nodes.append(node) nodes.append(node)
if not os.path.exists(folder + '/cert.crt'): if not os.path.exists(folder + '/cert.crt'):
dh_path = folder + '/dh2048.pem' dh_path = folder + '/dh2048.pem'
if not os.path.exists(dh_path): if not os.path.exists(dh_path):
os.symlink('../dh2048.pem', dh_path) os.symlink('../dh2048.pem', dh_path)
email = node.name + '@example.com' email = node.name + '@example.com'
p = node.Popen(('../py', 're6st-conf', '--registry', registry, p = node.Popen((sys.executable, '../py', 're6st-conf', '--registry', registry,
'--email', email, '--fingerprint', fingerprint), '--email', email, '--fingerprint', fingerprint),
stdin=subprocess.PIPE, cwd=folder) stdin=subprocess.PIPE, cwd=folder)
token = None token = None
while not token: while not token:
time.sleep(.1) time.sleep(.1)
...@@ -265,14 +300,16 @@ def new_network(registry, reg_addr, serial, ca): ...@@ -265,14 +300,16 @@ def new_network(registry, reg_addr, serial, ca):
os.remove(dh_path) os.remove(dh_path)
os.remove(folder + '/ca.crt') os.remove(folder + '/ca.crt')
node.re6st_cmdline = ( node.re6st_cmdline = (
'./py re6stnet @%s/re6stnet.conf -v%u --registry %s' '"%s" ./py re6stnet @%s/re6stnet.conf -v%u --registry %s'
' --console %s/run/console.sock %s' ' --console %s/run/console.sock %s'
) % (folder, VERBOSE, registry, folder, args) ) % (sys.executable, folder, VERBOSE, registry, folder, args)
node.screen(node.re6st_cmdline) node.screen(node.re6st_cmdline)
new_node(registry, registry.name, '--ip ' + reg_addr, registry='http://localhost/') new_node(registry, registry.name, '--ip ' + reg_addr, registry='http://localhost/')
yield new_node yield new_node
db.close() db.close()
with new_network(registry, REGISTRY, REGISTRY_SERIAL, 'ca.crt') as new_node: with new_network(registry, REGISTRY, REGISTRY_SERIAL, 'ca.crt') as new_node:
new_node(machine1, 'm1', '-I%s' % m1_if_0.name) new_node(machine1, 'm1', '-I%s' % m1_if_0.name)
new_node(machine2, 'm2', '--remote-gateway 10.1.1.1', prefix_len=77) new_node(machine2, 'm2', '--remote-gateway 10.1.1.1', prefix_len=77)
...@@ -300,15 +337,16 @@ if args.ping: ...@@ -300,15 +337,16 @@ if args.ping:
name = machine.name if machine.short[0] == 'R' else 'm' + machine.short name = machine.name if machine.short[0] == 'R' else 'm' + machine.short
machine.screen('python ping.py {} {}'.format(name, ' '.join(ips))) machine.screen('python ping.py {} {}'.format(name, ' '.join(ips)))
class testHMAC(Thread): class testHMAC(Thread):
def run(self): def run(self):
updateHMAC = ('python', '-c', "import urllib, sys; sys.exit(" updateHMAC = ('python', '-c', "import urllib, sys; sys.exit("
"204 != urllib.urlopen('http://127.0.0.1/updateHMAC').code)") "204 != urllib.urlopen('http://127.0.0.1/updateHMAC').code)")
reg1_db = sqlite3.connect('registry/registry.db', isolation_level=None, reg1_db = sqlite3.connect('registry/registry.db', isolation_level=None,
check_same_thread=False) check_same_thread=False)
reg2_db = sqlite3.connect('registry2/registry.db', isolation_level=None, reg2_db = sqlite3.connect('registry2/registry.db', isolation_level=None,
check_same_thread=False) check_same_thread=False)
reg1_db.text_factory = reg2_db.text_factory = str reg1_db.text_factory = reg2_db.text_factory = str
m_net1 = 'registry', 'm1', 'm2', 'm3', 'm4', 'm5', 'm6', 'm7', 'm8' m_net1 = 'registry', 'm1', 'm2', 'm3', 'm4', 'm5', 'm6', 'm7', 'm8'
m_net2 = 'registry2', 'm10' m_net2 = 'registry2', 'm10'
...@@ -340,15 +378,19 @@ class testHMAC(Thread): ...@@ -340,15 +378,19 @@ class testHMAC(Thread):
reg1_db.close() reg1_db.close()
reg2_db.close() reg2_db.close()
if args.hmac: if args.hmac:
import test_hmac import test_hmac
t = testHMAC() t = testHMAC()
t.deamon = 1 t.deamon = 1
t.start() t.start()
del t del t
_ll = {} _ll: dict[str, tuple[Re6stNode, bool]] = {}
def node_by_ll(addr):
def node_by_ll(addr: str) -> tuple[Re6stNode, bool]:
try: try:
return _ll[addr] return _ll[addr]
except KeyError: except KeyError:
...@@ -366,24 +408,26 @@ def node_by_ll(addr): ...@@ -366,24 +408,26 @@ def node_by_ll(addr):
if a.startswith('10.42.'): if a.startswith('10.42.'):
assert not p % 8 assert not p % 8
_ll[socket.inet_ntoa(socket.inet_aton( _ll[socket.inet_ntoa(socket.inet_aton(
a)[:p//8].ljust(4, b'\0'))] = n, t a)[:p // 8].ljust(4, b'\0'))] = n, t
elif a.startswith('2001:db8:'): elif a.startswith('2001:db8:'):
assert not p % 8 assert not p % 8
a = socket.inet_ntop(socket.AF_INET6, a = socket.inet_ntop(socket.AF_INET6,
socket.inet_pton(socket.AF_INET6, socket.inet_pton(socket.AF_INET6,
a)[:p//8].ljust(16, b'\0')) a)[:p // 8].ljust(16, b'\0'))
elif not a.startswith('fe80::'): elif not a.startswith('fe80::'):
continue continue
_ll[a] = n, t _ll[a] = n, t
return _ll[addr] return _ll[addr]
def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
graph = {} def route_svg(ipv4, z=4, default=type('', (), {'short': None})):
graph: dict[Re6stNode, dict[tuple[Re6stNode, bool], list[Re6stNode]]] = {}
for n in nodes: for n in nodes:
g = graph[n] = defaultdict(list) g = graph[n] = defaultdict(list)
g: dict[tuple[Re6stNode, bool], list[Re6stNode]]
for r in n.get_routes(): for r in n.get_routes():
if (r.prefix and r.prefix.startswith('10.42.') if ipv4 else if (r.prefix and r.prefix.startswith('10.42.') if ipv4 else
r.prefix is None or r.prefix.startswith('2001:db8:')): r.prefix is None or r.prefix.startswith('2001:db8:')):
try: try:
g[node_by_ll(r.nexthop)].append( g[node_by_ll(r.nexthop)].append(
node_by_ll(r.prefix)[0] if r.prefix else default) node_by_ll(r.prefix)[0] if r.prefix else default)
...@@ -394,39 +438,45 @@ def route_svg(ipv4, z = 4, default = type('', (), {'short': None})): ...@@ -394,39 +438,45 @@ def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
a = 2 * math.pi / N a = 2 * math.pi / N
edges = set() edges = set()
for i, n in enumerate(nodes): for i, n in enumerate(nodes):
i: int
gv.append('%s[pos="%s,%s!"];' gv.append('%s[pos="%s,%s!"];'
% (n.name, z * math.cos(a * i), z * math.sin(a * i))) % (n.name, z * math.cos(a * i), z * math.sin(a * i)))
l = [] l = []
for p, r in graph[n].items(): for p, r in graph[n].items():
j = abs(nodes.index(p[0]) - i) p: tuple[Re6stNode, bool]
r: list[Re6stNode]
j: int = abs(nodes.index(p[0]) - i)
l.append((min(j, N - j), p, r)) l.append((min(j, N - j), p, r))
for j, (l, (p, t), r) in enumerate(sorted(l)): for j, (_, (p2, t), r) in enumerate(sorted(l, key=lambda x: x[0])):
l = [] p2: Re6stNode
l2: list[str] = []
arrowhead = 'none' arrowhead = 'none'
for r in sorted(r.short for r in r): for r2 in sorted(r2.short for r2 in r):
if r: if r2:
if r == p.short: if r2 == p2.short:
r = '<font color="grey">%s</font>' % r r2 = '<font color="grey">%s</font>' % r2
l.append(r) l2.append(r2)
else: else:
arrowhead = 'dot' arrowhead = 'dot'
if (n.name, p.name) in edges: if (n.name, p2.name) in edges:
r = 'penwidth=0' r3 = 'penwidth=0'
else: else:
edges.add((p.name, n.name)) edges.add((p2.name, n.name))
r = 'style=solid' if t else 'style=dashed' r3 = 'style=solid' if t else 'style=dashed'
gv.append( gv.append(
'%s -> %s [labeldistance=%u, headlabel=<%s>, arrowhead=%s, %s];' '%s -> %s [labeldistance=%u, headlabel=<%s>, arrowhead=%s, %s];'
% (p.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l), % (p2.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l2),
arrowhead, r)) arrowhead, r3))
gv.append('}\n') gv.append('}\n')
return subprocess.Popen(('neato', '-Tsvg'), return subprocess.Popen(('neato', '-Tsvg'),
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
).communicate('\n'.join(gv).encode("utf-8"))[0].decode("utf-8") ).communicate('\n'.join(gv).encode("utf-8"))[0].decode("utf-8")
if args.port: if args.port:
import http.server, socketserver import http.server, socketserver
class Handler(http.server.SimpleHTTPRequestHandler): class Handler(http.server.SimpleHTTPRequestHandler):
_path_match = re.compile('/(.+)\.(html|svg)$').match _path_match = re.compile('/(.+)\.(html|svg)$').match
...@@ -474,8 +524,8 @@ if args.port: ...@@ -474,8 +524,8 @@ if args.port:
"""), stdout=subprocess.PIPE, cwd="..").communicate()[0].decode("utf-8") """), stdout=subprocess.PIPE, cwd="..").communicate()[0].decode("utf-8")
if body: if body:
body = subprocess.Popen(('neato', '-Tsvg'), body = subprocess.Popen(('neato', '-Tsvg'),
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
).communicate(body.encode("utf-8"))[0].decode("utf-8") ).communicate(body.encode("utf-8"))[0].decode("utf-8")
if not body: if not body:
self.send_error(500) self.send_error(500)
return return
...@@ -502,9 +552,9 @@ if args.port: ...@@ -502,9 +552,9 @@ if args.port:
%s %s
</body> </body>
</html>""" % (name, ' '.join(x if i == page else </html>""" % (name, ' '.join(x if i == page else
'<a href="%s.html">%s</a>' % (x, x) '<a href="%s.html">%s</a>' % (x, x)
for i, x in enumerate(self.pages)), for i, x in enumerate(self.pages)),
body[body.find('<svg'):]) body[body.find('<svg'):])
self.send_response(200) self.send_response(200)
body = body.encode("utf-8") body = body.encode("utf-8")
self.send_header('Content-Length', str(len(body))) self.send_header('Content-Length', str(len(body)))
...@@ -512,9 +562,13 @@ if args.port: ...@@ -512,9 +562,13 @@ if args.port:
self.end_headers() self.end_headers()
self.wfile.write(body) self.wfile.write(body)
class TCPServer(socketserver.TCPServer): class TCPServer(socketserver.TCPServer):
allow_reuse_address = True allow_reuse_address = True
TCPServer(('', args.port), Handler).serve_forever() TCPServer(('', args.port), Handler).serve_forever()
import pdb; pdb.set_trace() import pdb
pdb.set_trace()
...@@ -85,7 +85,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, ...@@ -85,7 +85,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
if hmac_sign: if hmac_sign:
def key(cmd, id: str, value): def key(cmd, id: str, value):
cmd += '-C', ('key type blake2s128 id %s value %s' % cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, binascii.hexlify(value))) (id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign) key(cmd, 'sign', hmac_sign)
default += ' key sign' default += ' key sign'
if hmac_accept is not None: if hmac_accept is not None:
......
...@@ -814,12 +814,12 @@ class RegistryClient(object): ...@@ -814,12 +814,12 @@ class RegistryClient(object):
def __getattr__(self, name): def __getattr__(self, name):
getcallargs = getattr(RegistryServer, name).getcallargs getcallargs = getattr(RegistryServer, name).getcallargs
def rpc(*args, **kw): def rpc(*args, **kw) -> bytes:
kw = getcallargs(*args, **kw) kw = getcallargs(*args, **kw)
query = '/' + name query = '/' + name
if kw: if kw:
if any(type(v) is not str for v in kw.values()): if any(type(v) is not str for v in kw.values()):
raise TypeError raise TypeError(kw)
query += '?' + urlencode(kw) query += '?' + urlencode(kw)
url = self._path + query url = self._path + query
client_prefix = kw.get('cn') client_prefix = kw.get('cn')
......
...@@ -344,6 +344,8 @@ class BaseTunnelManager(object): ...@@ -344,6 +344,8 @@ class BaseTunnelManager(object):
peer.hello0Sent() peer.hello0Sent()
def _sendto(self, to, msg, peer=None): def _sendto(self, to, msg, peer=None):
if type(msg) is str:
msg = msg.encode()
try: try:
r = self.sock.sendto(peer.encode(msg) if peer else msg, to) r = self.sock.sendto(peer.encode(msg) if peer else msg, to)
except socket.error as e: except socket.error as e:
......
...@@ -266,7 +266,7 @@ class Peer(object): ...@@ -266,7 +266,7 @@ class Peer(object):
seqno_struct = struct.Struct("!L") seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack): def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> str:
seqno, = _unpack(msg[:4]) seqno, = _unpack(msg[:4])
if seqno <= 2: if seqno <= 2:
msg = msg[4:] msg = msg[4:]
...@@ -280,10 +280,12 @@ class Peer(object): ...@@ -280,10 +280,12 @@ class Peer(object):
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno: if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None self._last = None
self._i = seqno self._i = seqno
return msg[4:i] return msg[4:i].decode()
def encode(self, msg, _pack=seqno_struct.pack): def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1 self._j += 1
if type(msg) is str:
msg = msg.encode()
msg = _pack(self._j) + msg msg = _pack(self._j) + msg
return msg + self._hmac(msg) return msg + self._hmac(msg)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment