Commit f484fc65 authored by Tom Niget's avatar Tom Niget Committed by Tom Niget

wip: fix various places mishandling strings and bytes

parent 2b63c9a7
...@@ -252,15 +252,18 @@ class Babel: ...@@ -252,15 +252,18 @@ class Babel:
unidentified = set(n) unidentified = set(n)
self.neighbours = neighbours = {} self.neighbours = neighbours = {}
a = len(self.network) a = len(self.network)
logging.info("Routes: %r", routes)
for route in routes: for route in routes:
assert route.flags & 1, route # installed assert route.flags & 1, route # installed
if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'): if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'):
logging.warning("Ignoring IPv4 route: %r", route)
continue continue
assert route.neigh_address == route.nexthop, route assert route.neigh_address == route.nexthop, route
address = route.neigh_address, route.ifindex address = route.neigh_address, route.ifindex
neigh_routes = n[address] neigh_routes = n[address]
ip = utils.binFromRawIp(route.prefix) ip = utils.binFromRawIp(route.prefix)
if ip[:a] == self.network: if ip[:a] == self.network:
logging.debug("Route is on the network: %r", route)
prefix = ip[a:route.plen] prefix = ip[a:route.plen]
if prefix and not route.refmetric: if prefix and not route.refmetric:
neighbours[prefix] = neigh_routes neighbours[prefix] = neigh_routes
...@@ -275,7 +278,9 @@ class Babel: ...@@ -275,7 +278,9 @@ class Babel:
socket.inet_ntop(socket.AF_INET6, route.prefix), socket.inet_ntop(socket.AF_INET6, route.prefix),
route.plen) route.plen)
else: else:
logging.debug("Route is not on the network: %r", route)
prefix = None prefix = None
logging.debug("Adding route %r to %r", route, neigh_routes)
neigh_routes[1][prefix] = route neigh_routes[1][prefix] = route
self.locked.clear() self.locked.clear()
if unidentified: if unidentified:
......
...@@ -206,13 +206,20 @@ class RegistryServer: ...@@ -206,13 +206,20 @@ class RegistryServer:
def recv(self, code): def recv(self, code):
try: try:
prefix, msg = self.sock.recv(1<<16).split(b'\x00', 1) data = self.sock.recv(1<<16)
logging.info("recv raw: %r", data)
prefix, msg = data.split(b'\x00', 1)
int(prefix, 2) int(prefix, 2)
except ValueError: except ValueError:
pass pass
else: else:
if msg and msg[0:1] == code: if msg:
return prefix, msg[1:] if msg[0:1] == bytes([code]):
return prefix.decode(), msg[1:]
else:
logging.error("Unexpected code: %r", msg)
else:
logging.error("Empty message")
return None, None return None, None
def select(self, r, w, t): def select(self, r, w, t):
...@@ -609,7 +616,7 @@ class RegistryServer: ...@@ -609,7 +616,7 @@ class RegistryServer:
return zlib.compress(json.dumps(config).encode("utf-8")) return zlib.compress(json.dumps(config).encode("utf-8"))
def _queryAddress(self, peer): def _queryAddress(self, peer):
logging.info("Querying address for %s/%s", int(peer, 2), len(peer)) logging.info("Querying address for %s/%s %r", int(peer, 2), len(peer), peer)
self.sendto(peer, 1) self.sendto(peer, 1)
s = self.sock, s = self.sock,
timeout = 3 timeout = 3
...@@ -617,9 +624,9 @@ class RegistryServer: ...@@ -617,9 +624,9 @@ class RegistryServer:
# Loop because there may be answers from previous requests. # Loop because there may be answers from previous requests.
while select.select(s, (), (), timeout)[0]: while select.select(s, (), (), timeout)[0]:
prefix, msg = self.recv(1) prefix, msg = self.recv(1)
logging.info("* received: %s - %s", prefix, msg) logging.info("* received: %r - %r", prefix, msg)
if prefix == peer: if prefix == peer:
return msg return msg.decode()
timeout = max(0, end - time.time()) timeout = max(0, end - time.time())
logging.info("Timeout while querying address for %s/%s", logging.info("Timeout while querying address for %s/%s",
int(peer, 2), len(peer)) int(peer, 2), len(peer))
...@@ -662,7 +669,7 @@ class RegistryServer: ...@@ -662,7 +669,7 @@ class RegistryServer:
cert = self.getCert(cn) cert = self.getCert(cn)
msg = "%s %s" % (peer, msg) msg = "%s %s" % (peer, msg)
logging.info("Sending bootstrap peer: %s", msg) logging.info("Sending bootstrap peer: %s", msg)
return x509.encrypt(cert, msg) return x509.encrypt(cert, msg.encode())
@rpc_private @rpc_private
def revoke(self, cn_or_serial): def revoke(self, cn_or_serial):
......
...@@ -252,6 +252,7 @@ def binFromSubnet(subnet): ...@@ -252,6 +252,7 @@ def binFromSubnet(subnet):
return bin(int(p))[2:].rjust(int(l), '0') return bin(int(p))[2:].rjust(int(l), '0')
def newHmacSecret(): def newHmacSecret():
"""returns bytes"""
from random import getrandbits as g from random import getrandbits as g
pack = struct.Struct(">QQI").pack pack = struct.Struct(">QQI").pack
assert len(pack(0,0,0)) == HMAC_LEN assert len(pack(0,0,0)) == HMAC_LEN
......
...@@ -31,7 +31,8 @@ def openssl(*args, fds=[]): ...@@ -31,7 +31,8 @@ def openssl(*args, fds=[]):
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, pass_fds=fds) stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data): def encrypt(cert, data: bytes) -> bytes:
assert isinstance(data, bytes)
r, w = os.pipe() r, w = os.pipe()
try: try:
threading.Thread(target=os.write, args=(w, cert)).start() threading.Thread(target=os.write, args=(w, cert)).start()
...@@ -182,6 +183,7 @@ class Cert: ...@@ -182,6 +183,7 @@ class Cert:
) )
def decrypt(self, data: bytes) -> bytes: def decrypt(self, data: bytes) -> bytes:
assert isinstance(data, bytes)
p = openssl('rsautl', '-decrypt', '-inkey', self.key_path) p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
out, err = p.communicate(data) out, err = p.communicate(data)
if p.returncode: if p.returncode:
...@@ -289,7 +291,8 @@ class Peer: ...@@ -289,7 +291,8 @@ class Peer:
seqno_struct = struct.Struct("!L") seqno_struct = struct.Struct("!L")
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> str: def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> bytes:
assert isinstance(msg, bytes)
seqno, = _unpack(msg[:4]) seqno, = _unpack(msg[:4])
if seqno <= 2: if seqno <= 2:
msg = msg[4:] msg = msg[4:]
...@@ -303,11 +306,7 @@ class Peer: ...@@ -303,11 +306,7 @@ class Peer:
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno: if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None self._last = None
self._i = seqno self._i = seqno
try: return msg[4:i]
return msg[4:i].decode()
except UnicodeDecodeError:
logging.error("Invalid message from %s: %r", self.prefix, msg)
raise
def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes: def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1 self._j += 1
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment