Commit a5fa25ad authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov

bpf: Change bpf_sk_release and bpf_sk_*cgroup_id to accept ARG_PTR_TO_BTF_ID_SOCK_COMMON

The previous patch allows the networking bpf prog to use the
bpf_skc_to_*() helpers to get a PTR_TO_BTF_ID socket pointer,
e.g. "struct tcp_sock *".  It allows the bpf prog to read all the
fields of the tcp_sock.

This patch changes the bpf_sk_release() and bpf_sk_*cgroup_id()
to take ARG_PTR_TO_BTF_ID_SOCK_COMMON such that they will
work with the pointer returned by the bpf_skc_to_*() helpers
also.  For example, the following will work:

	sk = bpf_skc_lookup_tcp(skb, tuple, tuplen, BPF_F_CURRENT_NETNS, 0);
	if (!sk)
		return;
	tp = bpf_skc_to_tcp_sock(sk);
	if (!tp) {
		bpf_sk_release(sk);
		return;
	}
	lsndtime = tp->lsndtime;
	/* Pass tp to bpf_sk_release() will also work */
	bpf_sk_release(tp);

Since PTR_TO_BTF_ID could be NULL, the helper taking
ARG_PTR_TO_BTF_ID_SOCK_COMMON has to check for NULL at runtime.

A btf_id of "struct sock" may not always mean a fullsock.  Regardless
the helper's running context may get a non-fullsock or not,
considering fullsock check/handling is pretty cheap, it is better to
keep the same verifier expectation on helper that takes ARG_PTR_TO_BTF_ID*
will be able to handle the minisock situation.  In the bpf_sk_*cgroup_id()
case,  it will try to get a fullsock by using sk_to_full_sk() as its
skb variant bpf_sk"b"_*cgroup_id() has already been doing.

bpf_sk_release can already handle minisock, so nothing special has to
be done.
Signed-off-by: default avatarMartin KaFai Lau <kafai@fb.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20200925000356.3856047-1-kafai@fb.com
parent 1df8f55a
...@@ -2512,7 +2512,7 @@ union bpf_attr { ...@@ -2512,7 +2512,7 @@ union bpf_attr {
* result is from *reuse*\ **->socks**\ [] using the hash of the * result is from *reuse*\ **->socks**\ [] using the hash of the
* tuple. * tuple.
* *
* long bpf_sk_release(struct bpf_sock *sock) * long bpf_sk_release(void *sock)
* Description * Description
* Release the reference held by *sock*. *sock* must be a * Release the reference held by *sock*. *sock* must be a
* non-**NULL** pointer that was returned from * non-**NULL** pointer that was returned from
...@@ -3234,11 +3234,11 @@ union bpf_attr { ...@@ -3234,11 +3234,11 @@ union bpf_attr {
* *
* **-EOVERFLOW** if an overflow happened: The same object will be tried again. * **-EOVERFLOW** if an overflow happened: The same object will be tried again.
* *
* u64 bpf_sk_cgroup_id(struct bpf_sock *sk) * u64 bpf_sk_cgroup_id(void *sk)
* Description * Description
* Return the cgroup v2 id of the socket *sk*. * Return the cgroup v2 id of the socket *sk*.
* *
* *sk* must be a non-**NULL** pointer to a full socket, e.g. one * *sk* must be a non-**NULL** pointer to a socket, e.g. one
* returned from **bpf_sk_lookup_xxx**\ (), * returned from **bpf_sk_lookup_xxx**\ (),
* **bpf_sk_fullsock**\ (), etc. The format of returned id is * **bpf_sk_fullsock**\ (), etc. The format of returned id is
* same as in **bpf_skb_cgroup_id**\ (). * same as in **bpf_skb_cgroup_id**\ ().
...@@ -3248,7 +3248,7 @@ union bpf_attr { ...@@ -3248,7 +3248,7 @@ union bpf_attr {
* Return * Return
* The id is returned or 0 in case the id could not be retrieved. * The id is returned or 0 in case the id could not be retrieved.
* *
* u64 bpf_sk_ancestor_cgroup_id(struct bpf_sock *sk, int ancestor_level) * u64 bpf_sk_ancestor_cgroup_id(void *sk, int ancestor_level)
* Description * Description
* Return id of cgroup v2 that is ancestor of cgroup associated * Return id of cgroup v2 that is ancestor of cgroup associated
* with the *sk* at the *ancestor_level*. The root cgroup is at * with the *sk* at the *ancestor_level*. The root cgroup is at
......
...@@ -4088,18 +4088,17 @@ static inline u64 __bpf_sk_cgroup_id(struct sock *sk) ...@@ -4088,18 +4088,17 @@ static inline u64 __bpf_sk_cgroup_id(struct sock *sk)
{ {
struct cgroup *cgrp; struct cgroup *cgrp;
sk = sk_to_full_sk(sk);
if (!sk || !sk_fullsock(sk))
return 0;
cgrp = sock_cgroup_ptr(&sk->sk_cgrp_data); cgrp = sock_cgroup_ptr(&sk->sk_cgrp_data);
return cgroup_id(cgrp); return cgroup_id(cgrp);
} }
BPF_CALL_1(bpf_skb_cgroup_id, const struct sk_buff *, skb) BPF_CALL_1(bpf_skb_cgroup_id, const struct sk_buff *, skb)
{ {
struct sock *sk = skb_to_full_sk(skb); return __bpf_sk_cgroup_id(skb->sk);
if (!sk || !sk_fullsock(sk))
return 0;
return __bpf_sk_cgroup_id(sk);
} }
static const struct bpf_func_proto bpf_skb_cgroup_id_proto = { static const struct bpf_func_proto bpf_skb_cgroup_id_proto = {
...@@ -4115,6 +4114,10 @@ static inline u64 __bpf_sk_ancestor_cgroup_id(struct sock *sk, ...@@ -4115,6 +4114,10 @@ static inline u64 __bpf_sk_ancestor_cgroup_id(struct sock *sk,
struct cgroup *ancestor; struct cgroup *ancestor;
struct cgroup *cgrp; struct cgroup *cgrp;
sk = sk_to_full_sk(sk);
if (!sk || !sk_fullsock(sk))
return 0;
cgrp = sock_cgroup_ptr(&sk->sk_cgrp_data); cgrp = sock_cgroup_ptr(&sk->sk_cgrp_data);
ancestor = cgroup_ancestor(cgrp, ancestor_level); ancestor = cgroup_ancestor(cgrp, ancestor_level);
if (!ancestor) if (!ancestor)
...@@ -4126,12 +4129,7 @@ static inline u64 __bpf_sk_ancestor_cgroup_id(struct sock *sk, ...@@ -4126,12 +4129,7 @@ static inline u64 __bpf_sk_ancestor_cgroup_id(struct sock *sk,
BPF_CALL_2(bpf_skb_ancestor_cgroup_id, const struct sk_buff *, skb, int, BPF_CALL_2(bpf_skb_ancestor_cgroup_id, const struct sk_buff *, skb, int,
ancestor_level) ancestor_level)
{ {
struct sock *sk = skb_to_full_sk(skb); return __bpf_sk_ancestor_cgroup_id(skb->sk, ancestor_level);
if (!sk || !sk_fullsock(sk))
return 0;
return __bpf_sk_ancestor_cgroup_id(sk, ancestor_level);
} }
static const struct bpf_func_proto bpf_skb_ancestor_cgroup_id_proto = { static const struct bpf_func_proto bpf_skb_ancestor_cgroup_id_proto = {
...@@ -4151,7 +4149,7 @@ static const struct bpf_func_proto bpf_sk_cgroup_id_proto = { ...@@ -4151,7 +4149,7 @@ static const struct bpf_func_proto bpf_sk_cgroup_id_proto = {
.func = bpf_sk_cgroup_id, .func = bpf_sk_cgroup_id,
.gpl_only = false, .gpl_only = false,
.ret_type = RET_INTEGER, .ret_type = RET_INTEGER,
.arg1_type = ARG_PTR_TO_SOCKET, .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
}; };
BPF_CALL_2(bpf_sk_ancestor_cgroup_id, struct sock *, sk, int, ancestor_level) BPF_CALL_2(bpf_sk_ancestor_cgroup_id, struct sock *, sk, int, ancestor_level)
...@@ -4163,7 +4161,7 @@ static const struct bpf_func_proto bpf_sk_ancestor_cgroup_id_proto = { ...@@ -4163,7 +4161,7 @@ static const struct bpf_func_proto bpf_sk_ancestor_cgroup_id_proto = {
.func = bpf_sk_ancestor_cgroup_id, .func = bpf_sk_ancestor_cgroup_id,
.gpl_only = false, .gpl_only = false,
.ret_type = RET_INTEGER, .ret_type = RET_INTEGER,
.arg1_type = ARG_PTR_TO_SOCKET, .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
.arg2_type = ARG_ANYTHING, .arg2_type = ARG_ANYTHING,
}; };
#endif #endif
...@@ -5697,7 +5695,7 @@ static const struct bpf_func_proto bpf_sk_lookup_udp_proto = { ...@@ -5697,7 +5695,7 @@ static const struct bpf_func_proto bpf_sk_lookup_udp_proto = {
BPF_CALL_1(bpf_sk_release, struct sock *, sk) BPF_CALL_1(bpf_sk_release, struct sock *, sk)
{ {
if (sk_is_refcounted(sk)) if (sk && sk_is_refcounted(sk))
sock_gen_put(sk); sock_gen_put(sk);
return 0; return 0;
} }
...@@ -5706,7 +5704,7 @@ static const struct bpf_func_proto bpf_sk_release_proto = { ...@@ -5706,7 +5704,7 @@ static const struct bpf_func_proto bpf_sk_release_proto = {
.func = bpf_sk_release, .func = bpf_sk_release,
.gpl_only = false, .gpl_only = false,
.ret_type = RET_INTEGER, .ret_type = RET_INTEGER,
.arg1_type = ARG_PTR_TO_SOCK_COMMON, .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
}; };
BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx, BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx,
......
...@@ -2512,7 +2512,7 @@ union bpf_attr { ...@@ -2512,7 +2512,7 @@ union bpf_attr {
* result is from *reuse*\ **->socks**\ [] using the hash of the * result is from *reuse*\ **->socks**\ [] using the hash of the
* tuple. * tuple.
* *
* long bpf_sk_release(struct bpf_sock *sock) * long bpf_sk_release(void *sock)
* Description * Description
* Release the reference held by *sock*. *sock* must be a * Release the reference held by *sock*. *sock* must be a
* non-**NULL** pointer that was returned from * non-**NULL** pointer that was returned from
...@@ -3234,11 +3234,11 @@ union bpf_attr { ...@@ -3234,11 +3234,11 @@ union bpf_attr {
* *
* **-EOVERFLOW** if an overflow happened: The same object will be tried again. * **-EOVERFLOW** if an overflow happened: The same object will be tried again.
* *
* u64 bpf_sk_cgroup_id(struct bpf_sock *sk) * u64 bpf_sk_cgroup_id(void *sk)
* Description * Description
* Return the cgroup v2 id of the socket *sk*. * Return the cgroup v2 id of the socket *sk*.
* *
* *sk* must be a non-**NULL** pointer to a full socket, e.g. one * *sk* must be a non-**NULL** pointer to a socket, e.g. one
* returned from **bpf_sk_lookup_xxx**\ (), * returned from **bpf_sk_lookup_xxx**\ (),
* **bpf_sk_fullsock**\ (), etc. The format of returned id is * **bpf_sk_fullsock**\ (), etc. The format of returned id is
* same as in **bpf_skb_cgroup_id**\ (). * same as in **bpf_skb_cgroup_id**\ ().
...@@ -3248,7 +3248,7 @@ union bpf_attr { ...@@ -3248,7 +3248,7 @@ union bpf_attr {
* Return * Return
* The id is returned or 0 in case the id could not be retrieved. * The id is returned or 0 in case the id could not be retrieved.
* *
* u64 bpf_sk_ancestor_cgroup_id(struct bpf_sock *sk, int ancestor_level) * u64 bpf_sk_ancestor_cgroup_id(void *sk, int ancestor_level)
* Description * Description
* Return id of cgroup v2 that is ancestor of cgroup associated * Return id of cgroup v2 that is ancestor of cgroup associated
* with the *sk* at the *ancestor_level*. The root cgroup is at * with the *sk* at the *ancestor_level*. The root cgroup is at
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment