Commit aed8ee7f authored by Kumar Kartikeya Dwivedi's avatar Kumar Kartikeya Dwivedi Committed by Alexei Starovoitov

net: netfilter: Deduplicate code in bpf_{xdp,skb}_ct_lookup

Move common checks inside the common function, and maintain the only
difference the two being how to obtain the struct net * from ctx.
No functional change intended.
Signed-off-by: default avatarKumar Kartikeya Dwivedi <memxor@gmail.com>
Link: https://lore.kernel.org/r/20220721134245.2450-7-memxor@gmail.comSigned-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent 63e564eb
...@@ -57,16 +57,19 @@ enum { ...@@ -57,16 +57,19 @@ enum {
static struct nf_conn *__bpf_nf_ct_lookup(struct net *net, static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
struct bpf_sock_tuple *bpf_tuple, struct bpf_sock_tuple *bpf_tuple,
u32 tuple_len, u8 protonum, u32 tuple_len, struct bpf_ct_opts *opts,
s32 netns_id, u8 *dir) u32 opts_len)
{ {
struct nf_conntrack_tuple_hash *hash; struct nf_conntrack_tuple_hash *hash;
struct nf_conntrack_tuple tuple; struct nf_conntrack_tuple tuple;
struct nf_conn *ct; struct nf_conn *ct;
if (unlikely(protonum != IPPROTO_TCP && protonum != IPPROTO_UDP)) if (!opts || !bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
opts_len != NF_BPF_CT_OPTS_SZ)
return ERR_PTR(-EINVAL);
if (unlikely(opts->l4proto != IPPROTO_TCP && opts->l4proto != IPPROTO_UDP))
return ERR_PTR(-EPROTO); return ERR_PTR(-EPROTO);
if (unlikely(netns_id < BPF_F_CURRENT_NETNS)) if (unlikely(opts->netns_id < BPF_F_CURRENT_NETNS))
return ERR_PTR(-EINVAL); return ERR_PTR(-EINVAL);
memset(&tuple, 0, sizeof(tuple)); memset(&tuple, 0, sizeof(tuple));
...@@ -89,23 +92,22 @@ static struct nf_conn *__bpf_nf_ct_lookup(struct net *net, ...@@ -89,23 +92,22 @@ static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
return ERR_PTR(-EAFNOSUPPORT); return ERR_PTR(-EAFNOSUPPORT);
} }
tuple.dst.protonum = protonum; tuple.dst.protonum = opts->l4proto;
if (netns_id >= 0) { if (opts->netns_id >= 0) {
net = get_net_ns_by_id(net, netns_id); net = get_net_ns_by_id(net, opts->netns_id);
if (unlikely(!net)) if (unlikely(!net))
return ERR_PTR(-ENONET); return ERR_PTR(-ENONET);
} }
hash = nf_conntrack_find_get(net, &nf_ct_zone_dflt, &tuple); hash = nf_conntrack_find_get(net, &nf_ct_zone_dflt, &tuple);
if (netns_id >= 0) if (opts->netns_id >= 0)
put_net(net); put_net(net);
if (!hash) if (!hash)
return ERR_PTR(-ENOENT); return ERR_PTR(-ENOENT);
ct = nf_ct_tuplehash_to_ctrack(hash); ct = nf_ct_tuplehash_to_ctrack(hash);
if (dir) opts->dir = NF_CT_DIRECTION(hash);
*dir = NF_CT_DIRECTION(hash);
return ct; return ct;
} }
...@@ -138,19 +140,10 @@ bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple, ...@@ -138,19 +140,10 @@ bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple,
struct net *caller_net; struct net *caller_net;
struct nf_conn *nfct; struct nf_conn *nfct;
BUILD_BUG_ON(sizeof(struct bpf_ct_opts) != NF_BPF_CT_OPTS_SZ);
if (!opts)
return NULL;
if (!bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
opts__sz != NF_BPF_CT_OPTS_SZ) {
opts->error = -EINVAL;
return NULL;
}
caller_net = dev_net(ctx->rxq->dev); caller_net = dev_net(ctx->rxq->dev);
nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts->l4proto, nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
opts->netns_id, &opts->dir);
if (IS_ERR(nfct)) { if (IS_ERR(nfct)) {
if (opts)
opts->error = PTR_ERR(nfct); opts->error = PTR_ERR(nfct);
return NULL; return NULL;
} }
...@@ -181,19 +174,10 @@ bpf_skb_ct_lookup(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple, ...@@ -181,19 +174,10 @@ bpf_skb_ct_lookup(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple,
struct net *caller_net; struct net *caller_net;
struct nf_conn *nfct; struct nf_conn *nfct;
BUILD_BUG_ON(sizeof(struct bpf_ct_opts) != NF_BPF_CT_OPTS_SZ);
if (!opts)
return NULL;
if (!bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
opts__sz != NF_BPF_CT_OPTS_SZ) {
opts->error = -EINVAL;
return NULL;
}
caller_net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk); caller_net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk);
nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts->l4proto, nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
opts->netns_id, &opts->dir);
if (IS_ERR(nfct)) { if (IS_ERR(nfct)) {
if (opts)
opts->error = PTR_ERR(nfct); opts->error = PTR_ERR(nfct);
return NULL; return NULL;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment