Commit 113214be authored by Daniel Borkmann's avatar Daniel Borkmann Committed by David S. Miller

bpf: refactor bpf_prog_get and type check into helper

Since bpf_prog_get() and program type check is used in a couple of places,
refactor this into a small helper function that we can make use of. Since
the non RO prog->aux part is not used in performance critical paths and a
program destruction via RCU is rather very unlikley when doing the put, we
shouldn't have an issue just doing the bpf_prog_get() + prog->type != type
check, but actually not taking the ref at all (due to being in fdget() /
fdput() section of the bpf fd) is even cleaner and makes the diff smaller
as well, so just go for that. Callsites are changed to make use of the new
helper where possible.
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Acked-by: default avatarAlexei Starovoitov <ast@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 1aacde3d
...@@ -218,6 +218,7 @@ void bpf_register_prog_type(struct bpf_prog_type_list *tl); ...@@ -218,6 +218,7 @@ void bpf_register_prog_type(struct bpf_prog_type_list *tl);
void bpf_register_map_type(struct bpf_map_type_list *tl); void bpf_register_map_type(struct bpf_map_type_list *tl);
struct bpf_prog *bpf_prog_get(u32 ufd); struct bpf_prog *bpf_prog_get(u32 ufd);
struct bpf_prog *bpf_prog_get_type(u32 ufd, enum bpf_prog_type type);
struct bpf_prog *bpf_prog_inc(struct bpf_prog *prog); struct bpf_prog *bpf_prog_inc(struct bpf_prog *prog);
void bpf_prog_put(struct bpf_prog *prog); void bpf_prog_put(struct bpf_prog *prog);
...@@ -277,6 +278,12 @@ static inline struct bpf_prog *bpf_prog_get(u32 ufd) ...@@ -277,6 +278,12 @@ static inline struct bpf_prog *bpf_prog_get(u32 ufd)
return ERR_PTR(-EOPNOTSUPP); return ERR_PTR(-EOPNOTSUPP);
} }
static inline struct bpf_prog *bpf_prog_get_type(u32 ufd,
enum bpf_prog_type type)
{
return ERR_PTR(-EOPNOTSUPP);
}
static inline void bpf_prog_put(struct bpf_prog *prog) static inline void bpf_prog_put(struct bpf_prog *prog)
{ {
} }
......
...@@ -657,7 +657,7 @@ int bpf_prog_new_fd(struct bpf_prog *prog) ...@@ -657,7 +657,7 @@ int bpf_prog_new_fd(struct bpf_prog *prog)
O_RDWR | O_CLOEXEC); O_RDWR | O_CLOEXEC);
} }
static struct bpf_prog *__bpf_prog_get(struct fd f) static struct bpf_prog *____bpf_prog_get(struct fd f)
{ {
if (!f.file) if (!f.file)
return ERR_PTR(-EBADF); return ERR_PTR(-EBADF);
...@@ -678,24 +678,35 @@ struct bpf_prog *bpf_prog_inc(struct bpf_prog *prog) ...@@ -678,24 +678,35 @@ struct bpf_prog *bpf_prog_inc(struct bpf_prog *prog)
return prog; return prog;
} }
/* called by sockets/tracing/seccomp before attaching program to an event static struct bpf_prog *__bpf_prog_get(u32 ufd, enum bpf_prog_type *type)
* pairs with bpf_prog_put()
*/
struct bpf_prog *bpf_prog_get(u32 ufd)
{ {
struct fd f = fdget(ufd); struct fd f = fdget(ufd);
struct bpf_prog *prog; struct bpf_prog *prog;
prog = __bpf_prog_get(f); prog = ____bpf_prog_get(f);
if (IS_ERR(prog)) if (IS_ERR(prog))
return prog; return prog;
if (type && prog->type != *type) {
prog = ERR_PTR(-EINVAL);
goto out;
}
prog = bpf_prog_inc(prog); prog = bpf_prog_inc(prog);
out:
fdput(f); fdput(f);
return prog; return prog;
} }
EXPORT_SYMBOL_GPL(bpf_prog_get);
struct bpf_prog *bpf_prog_get(u32 ufd)
{
return __bpf_prog_get(ufd, NULL);
}
struct bpf_prog *bpf_prog_get_type(u32 ufd, enum bpf_prog_type type)
{
return __bpf_prog_get(ufd, &type);
}
EXPORT_SYMBOL_GPL(bpf_prog_get_type);
/* last field in 'union bpf_attr' used by this command */ /* last field in 'union bpf_attr' used by this command */
#define BPF_PROG_LOAD_LAST_FIELD kern_version #define BPF_PROG_LOAD_LAST_FIELD kern_version
......
...@@ -1301,21 +1301,10 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk) ...@@ -1301,21 +1301,10 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk)
static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk) static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk)
{ {
struct bpf_prog *prog;
if (sock_flag(sk, SOCK_FILTER_LOCKED)) if (sock_flag(sk, SOCK_FILTER_LOCKED))
return ERR_PTR(-EPERM); return ERR_PTR(-EPERM);
prog = bpf_prog_get(ufd); return bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER);
if (IS_ERR(prog))
return prog;
if (prog->type != BPF_PROG_TYPE_SOCKET_FILTER) {
bpf_prog_put(prog);
return ERR_PTR(-EINVAL);
}
return prog;
} }
int sk_attach_bpf(u32 ufd, struct sock *sk) int sk_attach_bpf(u32 ufd, struct sock *sk)
......
...@@ -1765,18 +1765,12 @@ static int kcm_attach_ioctl(struct socket *sock, struct kcm_attach *info) ...@@ -1765,18 +1765,12 @@ static int kcm_attach_ioctl(struct socket *sock, struct kcm_attach *info)
if (!csock) if (!csock)
return -ENOENT; return -ENOENT;
prog = bpf_prog_get(info->bpf_fd); prog = bpf_prog_get_type(info->bpf_fd, BPF_PROG_TYPE_SOCKET_FILTER);
if (IS_ERR(prog)) { if (IS_ERR(prog)) {
err = PTR_ERR(prog); err = PTR_ERR(prog);
goto out; goto out;
} }
if (prog->type != BPF_PROG_TYPE_SOCKET_FILTER) {
bpf_prog_put(prog);
err = -EINVAL;
goto out;
}
err = kcm_attach(sock, csock, prog); err = kcm_attach(sock, csock, prog);
if (err) { if (err) {
bpf_prog_put(prog); bpf_prog_put(prog);
......
...@@ -1588,13 +1588,9 @@ static int fanout_set_data_ebpf(struct packet_sock *po, char __user *data, ...@@ -1588,13 +1588,9 @@ static int fanout_set_data_ebpf(struct packet_sock *po, char __user *data,
if (copy_from_user(&fd, data, len)) if (copy_from_user(&fd, data, len))
return -EFAULT; return -EFAULT;
new = bpf_prog_get(fd); new = bpf_prog_get_type(fd, BPF_PROG_TYPE_SOCKET_FILTER);
if (IS_ERR(new)) if (IS_ERR(new))
return PTR_ERR(new); return PTR_ERR(new);
if (new->type != BPF_PROG_TYPE_SOCKET_FILTER) {
bpf_prog_put(new);
return -EINVAL;
}
__fanout_set_data_bpf(po->fanout, new); __fanout_set_data_bpf(po->fanout, new);
return 0; return 0;
......
...@@ -223,15 +223,10 @@ static int tcf_bpf_init_from_efd(struct nlattr **tb, struct tcf_bpf_cfg *cfg) ...@@ -223,15 +223,10 @@ static int tcf_bpf_init_from_efd(struct nlattr **tb, struct tcf_bpf_cfg *cfg)
bpf_fd = nla_get_u32(tb[TCA_ACT_BPF_FD]); bpf_fd = nla_get_u32(tb[TCA_ACT_BPF_FD]);
fp = bpf_prog_get(bpf_fd); fp = bpf_prog_get_type(bpf_fd, BPF_PROG_TYPE_SCHED_ACT);
if (IS_ERR(fp)) if (IS_ERR(fp))
return PTR_ERR(fp); return PTR_ERR(fp);
if (fp->type != BPF_PROG_TYPE_SCHED_ACT) {
bpf_prog_put(fp);
return -EINVAL;
}
if (tb[TCA_ACT_BPF_NAME]) { if (tb[TCA_ACT_BPF_NAME]) {
name = kmemdup(nla_data(tb[TCA_ACT_BPF_NAME]), name = kmemdup(nla_data(tb[TCA_ACT_BPF_NAME]),
nla_len(tb[TCA_ACT_BPF_NAME]), nla_len(tb[TCA_ACT_BPF_NAME]),
......
...@@ -272,15 +272,10 @@ static int cls_bpf_prog_from_efd(struct nlattr **tb, struct cls_bpf_prog *prog, ...@@ -272,15 +272,10 @@ static int cls_bpf_prog_from_efd(struct nlattr **tb, struct cls_bpf_prog *prog,
bpf_fd = nla_get_u32(tb[TCA_BPF_FD]); bpf_fd = nla_get_u32(tb[TCA_BPF_FD]);
fp = bpf_prog_get(bpf_fd); fp = bpf_prog_get_type(bpf_fd, BPF_PROG_TYPE_SCHED_CLS);
if (IS_ERR(fp)) if (IS_ERR(fp))
return PTR_ERR(fp); return PTR_ERR(fp);
if (fp->type != BPF_PROG_TYPE_SCHED_CLS) {
bpf_prog_put(fp);
return -EINVAL;
}
if (tb[TCA_BPF_NAME]) { if (tb[TCA_BPF_NAME]) {
name = kmemdup(nla_data(tb[TCA_BPF_NAME]), name = kmemdup(nla_data(tb[TCA_BPF_NAME]),
nla_len(tb[TCA_BPF_NAME]), nla_len(tb[TCA_BPF_NAME]),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment