Commit 8e46c353 authored by Daniel Borkmann's avatar Daniel Borkmann

Merge branch 'bpf-sk-storage-clone'

Stanislav Fomichev says:

====================
Currently there is no way to propagate sk storage from the listener
socket to a newly accepted one. Consider the following use case:

        fd = socket();
        setsockopt(fd, SOL_IP, IP_TOS,...);
        /* ^^^ setsockopt BPF program triggers here and saves something
         * into sk storage of the listener.
         */
        listen(fd, ...);
        while (client = accept(fd)) {
                /* At this point all association between listener
                 * socket and newly accepted one is gone. New
                 * socket will not have any sk storage attached.
                 */
        }

Let's add new BPF_F_CLONE flag that can be specified when creating
a socket storage map. This new flag indicates that map contents
should be cloned when the socket is cloned.

v4:
* drop 'goto err' in bpf_sk_storage_clone (Yonghong Song)
* add comment about race with bpf_sk_storage_map_free to the
  bpf_sk_storage_clone side as well (Daniel Borkmann)

v3:
* make sure BPF_F_NO_PREALLOC is always present when creating
  a map (Martin KaFai Lau)
* don't call bpf_sk_storage_free explicitly, rely on
  sk_free_unlock_clone to do the cleanup (Martin KaFai Lau)

v2:
* remove spinlocks around selem_link_map/sk (Martin KaFai Lau)
* BPF_F_CLONE on a map, not selem (Martin KaFai Lau)
* hold a map while cloning (Martin KaFai Lau)
* use BTF maps in selftests (Yonghong Song)
* do proper cleanup selftests; don't call close(-1) (Yonghong Song)
* export bpf_map_inc_not_zero
====================
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parents fae55527 c3bbf176
......@@ -647,6 +647,8 @@ void bpf_map_free_id(struct bpf_map *map, bool do_idr_lock);
struct bpf_map *bpf_map_get_with_uref(u32 ufd);
struct bpf_map *__bpf_map_get(struct fd f);
struct bpf_map * __must_check bpf_map_inc(struct bpf_map *map, bool uref);
struct bpf_map * __must_check bpf_map_inc_not_zero(struct bpf_map *map,
bool uref);
void bpf_map_put_with_uref(struct bpf_map *map);
void bpf_map_put(struct bpf_map *map);
int bpf_map_charge_memlock(struct bpf_map *map, u32 pages);
......
......@@ -10,4 +10,14 @@ void bpf_sk_storage_free(struct sock *sk);
extern const struct bpf_func_proto bpf_sk_storage_get_proto;
extern const struct bpf_func_proto bpf_sk_storage_delete_proto;
#ifdef CONFIG_BPF_SYSCALL
int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk);
#else
static inline int bpf_sk_storage_clone(const struct sock *sk,
struct sock *newsk)
{
return 0;
}
#endif
#endif /* _BPF_SK_STORAGE_H */
......@@ -337,6 +337,9 @@ enum bpf_attach_type {
#define BPF_F_RDONLY_PROG (1U << 7)
#define BPF_F_WRONLY_PROG (1U << 8)
/* Clone map from listener for newly accepted socket */
#define BPF_F_CLONE (1U << 9)
/* flags for BPF_PROG_QUERY */
#define BPF_F_QUERY_EFFECTIVE (1U << 0)
......
......@@ -683,8 +683,8 @@ struct bpf_map *bpf_map_get_with_uref(u32 ufd)
}
/* map_idr_lock should have been held */
static struct bpf_map *bpf_map_inc_not_zero(struct bpf_map *map,
bool uref)
static struct bpf_map *__bpf_map_inc_not_zero(struct bpf_map *map,
bool uref)
{
int refold;
......@@ -704,6 +704,16 @@ static struct bpf_map *bpf_map_inc_not_zero(struct bpf_map *map,
return map;
}
struct bpf_map *bpf_map_inc_not_zero(struct bpf_map *map, bool uref)
{
spin_lock_bh(&map_idr_lock);
map = __bpf_map_inc_not_zero(map, uref);
spin_unlock_bh(&map_idr_lock);
return map;
}
EXPORT_SYMBOL_GPL(bpf_map_inc_not_zero);
int __weak bpf_stackmap_copy(struct bpf_map *map, void *key, void *value)
{
return -ENOTSUPP;
......@@ -2177,7 +2187,7 @@ static int bpf_map_get_fd_by_id(const union bpf_attr *attr)
spin_lock_bh(&map_idr_lock);
map = idr_find(&map_idr, id);
if (map)
map = bpf_map_inc_not_zero(map, true);
map = __bpf_map_inc_not_zero(map, true);
else
map = ERR_PTR(-ENOENT);
spin_unlock_bh(&map_idr_lock);
......
......@@ -12,6 +12,9 @@
static atomic_t cache_idx;
#define SK_STORAGE_CREATE_FLAG_MASK \
(BPF_F_NO_PREALLOC | BPF_F_CLONE)
struct bucket {
struct hlist_head list;
raw_spinlock_t lock;
......@@ -209,7 +212,6 @@ static void selem_unlink_sk(struct bpf_sk_storage_elem *selem)
kfree_rcu(sk_storage, rcu);
}
/* sk_storage->lock must be held and sk_storage->list cannot be empty */
static void __selem_link_sk(struct bpf_sk_storage *sk_storage,
struct bpf_sk_storage_elem *selem)
{
......@@ -509,7 +511,7 @@ static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
return 0;
}
/* Called by __sk_destruct() */
/* Called by __sk_destruct() & bpf_sk_storage_clone() */
void bpf_sk_storage_free(struct sock *sk)
{
struct bpf_sk_storage_elem *selem;
......@@ -557,6 +559,11 @@ static void bpf_sk_storage_map_free(struct bpf_map *map)
smap = (struct bpf_sk_storage_map *)map;
/* Note that this map might be concurrently cloned from
* bpf_sk_storage_clone. Wait for any existing bpf_sk_storage_clone
* RCU read section to finish before proceeding. New RCU
* read sections should be prevented via bpf_map_inc_not_zero.
*/
synchronize_rcu();
/* bpf prog and the userspace can no longer access this map
......@@ -601,7 +608,9 @@ static void bpf_sk_storage_map_free(struct bpf_map *map)
static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
{
if (attr->map_flags != BPF_F_NO_PREALLOC || attr->max_entries ||
if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
!(attr->map_flags & BPF_F_NO_PREALLOC) ||
attr->max_entries ||
attr->key_size != sizeof(int) || !attr->value_size ||
/* Enforce BTF for userspace sk dumping */
!attr->btf_key_type_id || !attr->btf_value_type_id)
......@@ -739,6 +748,95 @@ static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
return err;
}
static struct bpf_sk_storage_elem *
bpf_sk_storage_clone_elem(struct sock *newsk,
struct bpf_sk_storage_map *smap,
struct bpf_sk_storage_elem *selem)
{
struct bpf_sk_storage_elem *copy_selem;
copy_selem = selem_alloc(smap, newsk, NULL, true);
if (!copy_selem)
return NULL;
if (map_value_has_spin_lock(&smap->map))
copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
SDATA(selem)->data, true);
else
copy_map_value(&smap->map, SDATA(copy_selem)->data,
SDATA(selem)->data);
return copy_selem;
}
int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
{
struct bpf_sk_storage *new_sk_storage = NULL;
struct bpf_sk_storage *sk_storage;
struct bpf_sk_storage_elem *selem;
int ret = 0;
RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
rcu_read_lock();
sk_storage = rcu_dereference(sk->sk_bpf_storage);
if (!sk_storage || hlist_empty(&sk_storage->list))
goto out;
hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
struct bpf_sk_storage_elem *copy_selem;
struct bpf_sk_storage_map *smap;
struct bpf_map *map;
smap = rcu_dereference(SDATA(selem)->smap);
if (!(smap->map.map_flags & BPF_F_CLONE))
continue;
/* Note that for lockless listeners adding new element
* here can race with cleanup in bpf_sk_storage_map_free.
* Try to grab map refcnt to make sure that it's still
* alive and prevent concurrent removal.
*/
map = bpf_map_inc_not_zero(&smap->map, false);
if (IS_ERR(map))
continue;
copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
if (!copy_selem) {
ret = -ENOMEM;
bpf_map_put(map);
goto out;
}
if (new_sk_storage) {
selem_link_map(smap, copy_selem);
__selem_link_sk(new_sk_storage, copy_selem);
} else {
ret = sk_storage_alloc(newsk, smap, copy_selem);
if (ret) {
kfree(copy_selem);
atomic_sub(smap->elem_size,
&newsk->sk_omem_alloc);
bpf_map_put(map);
goto out;
}
new_sk_storage = rcu_dereference(copy_selem->sk_storage);
}
bpf_map_put(map);
}
out:
rcu_read_unlock();
/* In case of an error, don't free anything explicitly here, the
* caller is responsible to call bpf_sk_storage_free.
*/
return ret;
}
BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
void *, value, u64, flags)
{
......
......@@ -1851,9 +1851,12 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
goto out;
}
RCU_INIT_POINTER(newsk->sk_reuseport_cb, NULL);
#ifdef CONFIG_BPF_SYSCALL
RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
#endif
if (bpf_sk_storage_clone(sk, newsk)) {
sk_free_unlock_clone(newsk);
newsk = NULL;
goto out;
}
newsk->sk_err = 0;
newsk->sk_err_soft = 0;
......
......@@ -337,6 +337,9 @@ enum bpf_attach_type {
#define BPF_F_RDONLY_PROG (1U << 7)
#define BPF_F_WRONLY_PROG (1U << 8)
/* Clone map from listener for newly accepted socket */
#define BPF_F_CLONE (1U << 9)
/* flags for BPF_PROG_QUERY */
#define BPF_F_QUERY_EFFECTIVE (1U << 0)
......
......@@ -42,4 +42,5 @@ xdping
test_sockopt
test_sockopt_sk
test_sockopt_multi
test_sockopt_inherit
test_tcp_rtt
......@@ -29,7 +29,7 @@ TEST_GEN_PROGS = test_verifier test_tag test_maps test_lru_map test_lpm_map test
test_cgroup_storage test_select_reuseport test_section_names \
test_netcnt test_tcpnotify_user test_sock_fields test_sysctl test_hashmap \
test_btf_dump test_cgroup_attach xdping test_sockopt test_sockopt_sk \
test_sockopt_multi test_tcp_rtt
test_sockopt_multi test_sockopt_inherit test_tcp_rtt
BPF_OBJ_FILES = $(patsubst %.c,%.o, $(notdir $(wildcard progs/*.c)))
TEST_GEN_FILES = $(BPF_OBJ_FILES)
......@@ -111,6 +111,7 @@ $(OUTPUT)/test_cgroup_attach: cgroup_helpers.c
$(OUTPUT)/test_sockopt: cgroup_helpers.c
$(OUTPUT)/test_sockopt_sk: cgroup_helpers.c
$(OUTPUT)/test_sockopt_multi: cgroup_helpers.c
$(OUTPUT)/test_sockopt_inherit: cgroup_helpers.c
$(OUTPUT)/test_tcp_rtt: cgroup_helpers.c
.PHONY: force
......
// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include "bpf_helpers.h"
char _license[] SEC("license") = "GPL";
__u32 _version SEC("version") = 1;
#define SOL_CUSTOM 0xdeadbeef
#define CUSTOM_INHERIT1 0
#define CUSTOM_INHERIT2 1
#define CUSTOM_LISTENER 2
struct sockopt_inherit {
__u8 val;
};
struct {
__uint(type, BPF_MAP_TYPE_SK_STORAGE);
__uint(map_flags, BPF_F_NO_PREALLOC | BPF_F_CLONE);
__type(key, int);
__type(value, struct sockopt_inherit);
} cloned1_map SEC(".maps");
struct {
__uint(type, BPF_MAP_TYPE_SK_STORAGE);
__uint(map_flags, BPF_F_NO_PREALLOC | BPF_F_CLONE);
__type(key, int);
__type(value, struct sockopt_inherit);
} cloned2_map SEC(".maps");
struct {
__uint(type, BPF_MAP_TYPE_SK_STORAGE);
__uint(map_flags, BPF_F_NO_PREALLOC);
__type(key, int);
__type(value, struct sockopt_inherit);
} listener_only_map SEC(".maps");
static __inline struct sockopt_inherit *get_storage(struct bpf_sockopt *ctx)
{
if (ctx->optname == CUSTOM_INHERIT1)
return bpf_sk_storage_get(&cloned1_map, ctx->sk, 0,
BPF_SK_STORAGE_GET_F_CREATE);
else if (ctx->optname == CUSTOM_INHERIT2)
return bpf_sk_storage_get(&cloned2_map, ctx->sk, 0,
BPF_SK_STORAGE_GET_F_CREATE);
else
return bpf_sk_storage_get(&listener_only_map, ctx->sk, 0,
BPF_SK_STORAGE_GET_F_CREATE);
}
SEC("cgroup/getsockopt")
int _getsockopt(struct bpf_sockopt *ctx)
{
__u8 *optval_end = ctx->optval_end;
struct sockopt_inherit *storage;
__u8 *optval = ctx->optval;
if (ctx->level != SOL_CUSTOM)
return 1; /* only interested in SOL_CUSTOM */
if (optval + 1 > optval_end)
return 0; /* EPERM, bounds check */
storage = get_storage(ctx);
if (!storage)
return 0; /* EPERM, couldn't get sk storage */
ctx->retval = 0; /* Reset system call return value to zero */
optval[0] = storage->val;
ctx->optlen = 1;
return 1;
}
SEC("cgroup/setsockopt")
int _setsockopt(struct bpf_sockopt *ctx)
{
__u8 *optval_end = ctx->optval_end;
struct sockopt_inherit *storage;
__u8 *optval = ctx->optval;
if (ctx->level != SOL_CUSTOM)
return 1; /* only interested in SOL_CUSTOM */
if (optval + 1 > optval_end)
return 0; /* EPERM, bounds check */
storage = get_storage(ctx);
if (!storage)
return 0; /* EPERM, couldn't get sk storage */
storage->val = optval[0];
ctx->optlen = -1;
return 1;
}
// SPDX-License-Identifier: GPL-2.0
#include <error.h>
#include <errno.h>
#include <stdio.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <pthread.h>
#include <linux/filter.h>
#include <bpf/bpf.h>
#include <bpf/libbpf.h>
#include "bpf_rlimit.h"
#include "bpf_util.h"
#include "cgroup_helpers.h"
#define CG_PATH "/sockopt_inherit"
#define SOL_CUSTOM 0xdeadbeef
#define CUSTOM_INHERIT1 0
#define CUSTOM_INHERIT2 1
#define CUSTOM_LISTENER 2
static int connect_to_server(int server_fd)
{
struct sockaddr_storage addr;
socklen_t len = sizeof(addr);
int fd;
fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
log_err("Failed to create client socket");
return -1;
}
if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
log_err("Failed to get server addr");
goto out;
}
if (connect(fd, (const struct sockaddr *)&addr, len) < 0) {
log_err("Fail to connect to server");
goto out;
}
return fd;
out:
close(fd);
return -1;
}
static int verify_sockopt(int fd, int optname, const char *msg, char expected)
{
socklen_t optlen = 1;
char buf = 0;
int err;
err = getsockopt(fd, SOL_CUSTOM, optname, &buf, &optlen);
if (err) {
log_err("%s: failed to call getsockopt", msg);
return 1;
}
printf("%s %d: got=0x%x ? expected=0x%x\n", msg, optname, buf, expected);
if (buf != expected) {
log_err("%s: unexpected getsockopt value %d != %d", msg,
buf, expected);
return 1;
}
return 0;
}
static void *server_thread(void *arg)
{
struct sockaddr_storage addr;
socklen_t len = sizeof(addr);
int fd = *(int *)arg;
int client_fd;
int err = 0;
if (listen(fd, 1) < 0)
error(1, errno, "Failed to listed on socket");
err += verify_sockopt(fd, CUSTOM_INHERIT1, "listen", 1);
err += verify_sockopt(fd, CUSTOM_INHERIT2, "listen", 1);
err += verify_sockopt(fd, CUSTOM_LISTENER, "listen", 1);
client_fd = accept(fd, (struct sockaddr *)&addr, &len);
if (client_fd < 0)
error(1, errno, "Failed to accept client");
err += verify_sockopt(client_fd, CUSTOM_INHERIT1, "accept", 1);
err += verify_sockopt(client_fd, CUSTOM_INHERIT2, "accept", 1);
err += verify_sockopt(client_fd, CUSTOM_LISTENER, "accept", 0);
close(client_fd);
return (void *)(long)err;
}
static int start_server(void)
{
struct sockaddr_in addr = {
.sin_family = AF_INET,
.sin_addr.s_addr = htonl(INADDR_LOOPBACK),
};
char buf;
int err;
int fd;
int i;
fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
log_err("Failed to create server socket");
return -1;
}
for (i = CUSTOM_INHERIT1; i <= CUSTOM_LISTENER; i++) {
buf = 0x01;
err = setsockopt(fd, SOL_CUSTOM, i, &buf, 1);
if (err) {
log_err("Failed to call setsockopt(%d)", i);
close(fd);
return -1;
}
}
if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) {
log_err("Failed to bind socket");
close(fd);
return -1;
}
return fd;
}
static int prog_attach(struct bpf_object *obj, int cgroup_fd, const char *title)
{
enum bpf_attach_type attach_type;
enum bpf_prog_type prog_type;
struct bpf_program *prog;
int err;
err = libbpf_prog_type_by_name(title, &prog_type, &attach_type);
if (err) {
log_err("Failed to deduct types for %s BPF program", title);
return -1;
}
prog = bpf_object__find_program_by_title(obj, title);
if (!prog) {
log_err("Failed to find %s BPF program", title);
return -1;
}
err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd,
attach_type, 0);
if (err) {
log_err("Failed to attach %s BPF program", title);
return -1;
}
return 0;
}
static int run_test(int cgroup_fd)
{
struct bpf_prog_load_attr attr = {
.file = "./sockopt_inherit.o",
};
int server_fd = -1, client_fd;
struct bpf_object *obj;
void *server_err;
pthread_t tid;
int ignored;
int err;
err = bpf_prog_load_xattr(&attr, &obj, &ignored);
if (err) {
log_err("Failed to load BPF object");
return -1;
}
err = prog_attach(obj, cgroup_fd, "cgroup/getsockopt");
if (err)
goto close_bpf_object;
err = prog_attach(obj, cgroup_fd, "cgroup/setsockopt");
if (err)
goto close_bpf_object;
server_fd = start_server();
if (server_fd < 0) {
err = -1;
goto close_bpf_object;
}
pthread_create(&tid, NULL, server_thread, (void *)&server_fd);
client_fd = connect_to_server(server_fd);
if (client_fd < 0) {
err = -1;
goto close_server_fd;
}
err += verify_sockopt(client_fd, CUSTOM_INHERIT1, "connect", 0);
err += verify_sockopt(client_fd, CUSTOM_INHERIT2, "connect", 0);
err += verify_sockopt(client_fd, CUSTOM_LISTENER, "connect", 0);
pthread_join(tid, &server_err);
err += (int)(long)server_err;
close(client_fd);
close_server_fd:
close(server_fd);
close_bpf_object:
bpf_object__close(obj);
return err;
}
int main(int args, char **argv)
{
int cgroup_fd;
int err = EXIT_SUCCESS;
if (setup_cgroup_environment())
return err;
cgroup_fd = create_and_get_cgroup(CG_PATH);
if (cgroup_fd < 0)
goto cleanup_cgroup_env;
if (join_cgroup(CG_PATH))
goto cleanup_cgroup;
if (run_test(cgroup_fd))
err = EXIT_FAILURE;
printf("test_sockopt_inherit: %s\n",
err == EXIT_SUCCESS ? "PASSED" : "FAILED");
cleanup_cgroup:
close(cgroup_fd);
cleanup_cgroup_env:
cleanup_cgroup_environment();
return err;
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment