Commit e46dd903 authored by Donald Hunter's avatar Donald Hunter Committed by Jakub Kicinski

tools/net/ynl: Add support for netlink-raw families

Refactor the ynl code to encapsulate protocol specifics into
NetlinkProtocol and GenlProtocol.
Signed-off-by: default avatarDonald Hunter <donald.hunter@gmail.com>
Link: https://lore.kernel.org/r/20230825122756.7603-8-donald.hunter@gmail.comSigned-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent fb0a06d4
...@@ -25,6 +25,7 @@ class Netlink: ...@@ -25,6 +25,7 @@ class Netlink:
NETLINK_ADD_MEMBERSHIP = 1 NETLINK_ADD_MEMBERSHIP = 1
NETLINK_CAP_ACK = 10 NETLINK_CAP_ACK = 10
NETLINK_EXT_ACK = 11 NETLINK_EXT_ACK = 11
NETLINK_GET_STRICT_CHK = 12
# Netlink message # Netlink message
NLMSG_ERROR = 2 NLMSG_ERROR = 2
...@@ -228,6 +229,9 @@ class NlMsg: ...@@ -228,6 +229,9 @@ class NlMsg:
desc += f" ({spec['doc']})" desc += f" ({spec['doc']})"
self.extack['miss-type'] = desc self.extack['miss-type'] = desc
def cmd(self):
return self.nl_type
def __repr__(self): def __repr__(self):
msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n" msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
if self.error: if self.error:
...@@ -322,6 +326,9 @@ class GenlMsg: ...@@ -322,6 +326,9 @@ class GenlMsg:
self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0) self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
self.raw = nl_msg.raw[4:] self.raw = nl_msg.raw[4:]
def cmd(self):
return self.genl_cmd
def __repr__(self): def __repr__(self):
msg = repr(self.nl) msg = repr(self.nl)
msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
...@@ -330,9 +337,41 @@ class GenlMsg: ...@@ -330,9 +337,41 @@ class GenlMsg:
return msg return msg
class GenlFamily: class NetlinkProtocol:
def __init__(self, family_name): def __init__(self, family_name, proto_num):
self.family_name = family_name self.family_name = family_name
self.proto_num = proto_num
def _message(self, nl_type, nl_flags, seq=None):
if seq is None:
seq = random.randint(1, 1024)
nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
return nlmsg
def message(self, flags, command, version, seq=None):
return self._message(command, flags, seq)
def _decode(self, nl_msg):
return nl_msg
def decode(self, ynl, nl_msg):
msg = self._decode(nl_msg)
fixed_header_size = 0
if ynl:
op = ynl.rsp_by_value[msg.cmd()]
fixed_header_size = ynl._fixed_header_size(op)
msg.raw_attrs = NlAttrs(msg.raw[fixed_header_size:])
return msg
def get_mcast_id(self, mcast_name, mcast_groups):
if mcast_name not in mcast_groups:
raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
return mcast_groups[mcast_name].value
class GenlProtocol(NetlinkProtocol):
def __init__(self, family_name):
super().__init__(family_name, Netlink.NETLINK_GENERIC)
global genl_family_name_to_id global genl_family_name_to_id
if genl_family_name_to_id is None: if genl_family_name_to_id is None:
...@@ -341,6 +380,19 @@ class GenlFamily: ...@@ -341,6 +380,19 @@ class GenlFamily:
self.genl_family = genl_family_name_to_id[family_name] self.genl_family = genl_family_name_to_id[family_name]
self.family_id = genl_family_name_to_id[family_name]['id'] self.family_id = genl_family_name_to_id[family_name]['id']
def message(self, flags, command, version, seq=None):
nlmsg = self._message(self.family_id, flags, seq)
genlmsg = struct.pack("BBH", command, version, 0)
return nlmsg + genlmsg
def _decode(self, nl_msg):
return GenlMsg(nl_msg)
def get_mcast_id(self, mcast_name, mcast_groups):
if mcast_name not in self.genl_family['mcast']:
raise Exception(f'Multicast group "{mcast_name}" not present in the family')
return self.genl_family['mcast'][mcast_name]
# #
# YNL implementation details. # YNL implementation details.
...@@ -353,9 +405,19 @@ class YnlFamily(SpecFamily): ...@@ -353,9 +405,19 @@ class YnlFamily(SpecFamily):
self.include_raw = False self.include_raw = False
self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) try:
if self.proto == "netlink-raw":
self.nlproto = NetlinkProtocol(self.yaml['name'],
self.yaml['protonum'])
else:
self.nlproto = GenlProtocol(self.yaml['name'])
except KeyError:
raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
self.async_msg_ids = set() self.async_msg_ids = set()
self.async_msg_queue = [] self.async_msg_queue = []
...@@ -368,18 +430,12 @@ class YnlFamily(SpecFamily): ...@@ -368,18 +430,12 @@ class YnlFamily(SpecFamily):
bound_f = functools.partial(self._op, op_name) bound_f = functools.partial(self._op, op_name)
setattr(self, op.ident_name, bound_f) setattr(self, op.ident_name, bound_f)
try:
self.family = GenlFamily(self.yaml['name'])
except KeyError:
raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
def ntf_subscribe(self, mcast_name): def ntf_subscribe(self, mcast_name):
if mcast_name not in self.family.genl_family['mcast']: mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
raise Exception(f'Multicast group "{mcast_name}" not present in the family')
self.sock.bind((0, 0)) self.sock.bind((0, 0))
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
self.family.genl_family['mcast'][mcast_name]) mcast_id)
def _add_attr(self, space, name, value): def _add_attr(self, space, name, value):
try: try:
...@@ -505,11 +561,9 @@ class YnlFamily(SpecFamily): ...@@ -505,11 +561,9 @@ class YnlFamily(SpecFamily):
if 'bad-attr-offs' not in extack: if 'bad-attr-offs' not in extack:
return return
genl_req = GenlMsg(NlMsg(request, 0, op.attr_set)) msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
fixed_header_size = self._fixed_header_size(op) offset = 20 + self._fixed_header_size(op)
offset = 20 + fixed_header_size path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
path = self._decode_extack_path(NlAttrs(genl_req.raw[fixed_header_size:]),
op.attr_set, offset,
extack['bad-attr-offs']) extack['bad-attr-offs'])
if path: if path:
del extack['bad-attr-offs'] del extack['bad-attr-offs']
...@@ -539,14 +593,17 @@ class YnlFamily(SpecFamily): ...@@ -539,14 +593,17 @@ class YnlFamily(SpecFamily):
fixed_header_attrs[m.name] = value fixed_header_attrs[m.name] = value
return fixed_header_attrs return fixed_header_attrs
def handle_ntf(self, nl_msg, genl_msg): def handle_ntf(self, decoded):
msg = dict() msg = dict()
if self.include_raw: if self.include_raw:
msg['nlmsg'] = nl_msg msg['raw'] = decoded
msg['genlmsg'] = genl_msg op = self.rsp_by_value[decoded.cmd()]
op = self.rsp_by_value[genl_msg.genl_cmd] attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
if op.fixed_header:
attrs.update(self._decode_fixed_header(decoded, op.fixed_header))
msg['name'] = op['name'] msg['name'] = op['name']
msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name) msg['msg'] = attrs
self.async_msg_queue.append(msg) self.async_msg_queue.append(msg)
def check_ntf(self): def check_ntf(self):
...@@ -566,12 +623,12 @@ class YnlFamily(SpecFamily): ...@@ -566,12 +623,12 @@ class YnlFamily(SpecFamily):
print("Netlink done while checking for ntf!?") print("Netlink done while checking for ntf!?")
continue continue
gm = GenlMsg(nl_msg) decoded = self.nlproto.decode(self, nl_msg)
if gm.genl_cmd not in self.async_msg_ids: if decoded.cmd() not in self.async_msg_ids:
print("Unexpected msg id done while checking for ntf", gm) print("Unexpected msg id done while checking for ntf", decoded)
continue continue
self.handle_ntf(nl_msg, gm) self.handle_ntf(decoded)
def operation_do_attributes(self, name): def operation_do_attributes(self, name):
""" """
...@@ -592,7 +649,7 @@ class YnlFamily(SpecFamily): ...@@ -592,7 +649,7 @@ class YnlFamily(SpecFamily):
nl_flags |= Netlink.NLM_F_DUMP nl_flags |= Netlink.NLM_F_DUMP
req_seq = random.randint(1024, 65535) req_seq = random.randint(1024, 65535)
msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq) msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
fixed_header_members = [] fixed_header_members = []
if op.fixed_header: if op.fixed_header:
fixed_header_members = self.consts[op.fixed_header].members fixed_header_members = self.consts[op.fixed_header].members
...@@ -624,19 +681,20 @@ class YnlFamily(SpecFamily): ...@@ -624,19 +681,20 @@ class YnlFamily(SpecFamily):
done = True done = True
break break
gm = GenlMsg(nl_msg) decoded = self.nlproto.decode(self, nl_msg)
# Check if this is a reply to our request # Check if this is a reply to our request
if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value: if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
if gm.genl_cmd in self.async_msg_ids: if decoded.cmd() in self.async_msg_ids:
self.handle_ntf(nl_msg, gm) self.handle_ntf(decoded)
continue continue
else: else:
print('Unexpected message: ' + repr(gm)) print('Unexpected message: ' + repr(decoded))
continue continue
rsp_msg = self._decode(NlAttrs(gm.raw), op.attr_set.name) rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
if op.fixed_header: if op.fixed_header:
rsp_msg.update(self._decode_fixed_header(gm, op.fixed_header)) rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header))
rsp.append(rsp_msg) rsp.append(rsp_msg)
if not rsp: if not rsp:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment