Commit 585f5a62 authored by Daniel Borkmann, committed by Alexei Starovoitov

bpf, sockmap: fix sock_map_ctx_update_elem race with exist/noexist

The current code in sock_map_ctx_update_elem() allows for BPF_EXIST
and BPF_NOEXIST map update flags. While on array-like maps this approach
is rather uncommon, e.g. bpf_fd_array_map_update_elem() and others
enforce map update flags to be BPF_ANY such that xchg() can be used
directly, the current implementation in sock map does not guarantee
that such operation with BPF_EXIST / BPF_NOEXIST is atomic.

The initial test does a READ_ONCE(stab->sock_map[i]) to fetch the
socket from the slot which is then tested for NULL / non-NULL. However
later after __sock_map_ctx_update_elem(), the actual update is done
through osock = xchg(&stab->sock_map[i], sock). Problem is that in
the meantime a different CPU could have updated / deleted a socket
on that specific slot and thus flag constraints won't hold anymore.

I've been thinking whether best would be to just break UAPI and do
an enforcement of BPF_ANY to check if someone actually complains,
however trouble is that already in BPF kselftest we use BPF_NOEXIST
for the map update, and therefore it might have been copied into
applications already. The fix to keep the current behavior intact
would be to add a map lock similar to the sock hash bucket lock only
for covering the whole map.

Fixes: 174a79ff ("bpf: sockmap with sk redirect support")
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Acked-by: Song Liu <songliubraving@fb.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parent 166ab6f0
...@@ -58,6 +58,7 @@ struct bpf_stab { ...@@ -58,6 +58,7 @@ struct bpf_stab {
struct bpf_map map; struct bpf_map map;
struct sock **sock_map; struct sock **sock_map;
struct bpf_sock_progs progs; struct bpf_sock_progs progs;
raw_spinlock_t lock;
}; };
struct bucket { struct bucket {
...@@ -89,9 +90,9 @@ enum smap_psock_state { ...@@ -89,9 +90,9 @@ enum smap_psock_state {
struct smap_psock_map_entry { struct smap_psock_map_entry {
struct list_head list; struct list_head list;
struct bpf_map *map;
struct sock **entry; struct sock **entry;
struct htab_elem __rcu *hash_link; struct htab_elem __rcu *hash_link;
struct bpf_htab __rcu *htab;
}; };
struct smap_psock { struct smap_psock {
...@@ -343,13 +344,18 @@ static void bpf_tcp_close(struct sock *sk, long timeout) ...@@ -343,13 +344,18 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
e = psock_map_pop(sk, psock); e = psock_map_pop(sk, psock);
while (e) { while (e) {
if (e->entry) { if (e->entry) {
osk = cmpxchg(e->entry, sk, NULL); struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
raw_spin_lock_bh(&stab->lock);
osk = *e->entry;
if (osk == sk) { if (osk == sk) {
*e->entry = NULL;
smap_release_sock(psock, sk); smap_release_sock(psock, sk);
} }
raw_spin_unlock_bh(&stab->lock);
} else { } else {
struct htab_elem *link = rcu_dereference(e->hash_link); struct htab_elem *link = rcu_dereference(e->hash_link);
struct bpf_htab *htab = rcu_dereference(e->htab); struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
struct hlist_head *head; struct hlist_head *head;
struct htab_elem *l; struct htab_elem *l;
struct bucket *b; struct bucket *b;
...@@ -1642,6 +1648,7 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr) ...@@ -1642,6 +1648,7 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
bpf_map_init_from_attr(&stab->map, attr); bpf_map_init_from_attr(&stab->map, attr);
raw_spin_lock_init(&stab->lock);
/* make sure page count doesn't overflow */ /* make sure page count doesn't overflow */
cost = (u64) stab->map.max_entries * sizeof(struct sock *); cost = (u64) stab->map.max_entries * sizeof(struct sock *);
...@@ -1716,14 +1723,15 @@ static void sock_map_free(struct bpf_map *map) ...@@ -1716,14 +1723,15 @@ static void sock_map_free(struct bpf_map *map)
* and a grace period expire to ensure psock is really safe to remove. * and a grace period expire to ensure psock is really safe to remove.
*/ */
rcu_read_lock(); rcu_read_lock();
raw_spin_lock_bh(&stab->lock);
for (i = 0; i < stab->map.max_entries; i++) { for (i = 0; i < stab->map.max_entries; i++) {
struct smap_psock *psock; struct smap_psock *psock;
struct sock *sock; struct sock *sock;
sock = xchg(&stab->sock_map[i], NULL); sock = stab->sock_map[i];
if (!sock) if (!sock)
continue; continue;
stab->sock_map[i] = NULL;
psock = smap_psock_sk(sock); psock = smap_psock_sk(sock);
/* This check handles a racing sock event that can get the /* This check handles a racing sock event that can get the
* sk_callback_lock before this case but after xchg happens * sk_callback_lock before this case but after xchg happens
...@@ -1735,6 +1743,7 @@ static void sock_map_free(struct bpf_map *map) ...@@ -1735,6 +1743,7 @@ static void sock_map_free(struct bpf_map *map)
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
} }
} }
raw_spin_unlock_bh(&stab->lock);
rcu_read_unlock(); rcu_read_unlock();
sock_map_remove_complete(stab); sock_map_remove_complete(stab);
...@@ -1778,14 +1787,16 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key) ...@@ -1778,14 +1787,16 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
if (k >= map->max_entries) if (k >= map->max_entries)
return -EINVAL; return -EINVAL;
sock = xchg(&stab->sock_map[k], NULL); raw_spin_lock_bh(&stab->lock);
sock = stab->sock_map[k];
stab->sock_map[k] = NULL;
raw_spin_unlock_bh(&stab->lock);
if (!sock) if (!sock)
return -EINVAL; return -EINVAL;
psock = smap_psock_sk(sock); psock = smap_psock_sk(sock);
if (!psock) if (!psock)
goto out; return 0;
if (psock->bpf_parse) { if (psock->bpf_parse) {
write_lock_bh(&sock->sk_callback_lock); write_lock_bh(&sock->sk_callback_lock);
smap_stop_sock(psock, sock); smap_stop_sock(psock, sock);
...@@ -1793,7 +1804,6 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key) ...@@ -1793,7 +1804,6 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
} }
smap_list_map_remove(psock, &stab->sock_map[k]); smap_list_map_remove(psock, &stab->sock_map[k]);
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
out:
return 0; return 0;
} }
...@@ -1829,11 +1839,9 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key) ...@@ -1829,11 +1839,9 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
static int __sock_map_ctx_update_elem(struct bpf_map *map, static int __sock_map_ctx_update_elem(struct bpf_map *map,
struct bpf_sock_progs *progs, struct bpf_sock_progs *progs,
struct sock *sock, struct sock *sock,
struct sock **map_link,
void *key) void *key)
{ {
struct bpf_prog *verdict, *parse, *tx_msg; struct bpf_prog *verdict, *parse, *tx_msg;
struct smap_psock_map_entry *e = NULL;
struct smap_psock *psock; struct smap_psock *psock;
bool new = false; bool new = false;
int err = 0; int err = 0;
...@@ -1906,14 +1914,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, ...@@ -1906,14 +1914,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
new = true; new = true;
} }
if (map_link) {
e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
if (!e) {
err = -ENOMEM;
goto out_free;
}
}
/* 3. At this point we have a reference to a valid psock that is /* 3. At this point we have a reference to a valid psock that is
* running. Attach any BPF programs needed. * running. Attach any BPF programs needed.
*/ */
...@@ -1935,17 +1935,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, ...@@ -1935,17 +1935,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
write_unlock_bh(&sock->sk_callback_lock); write_unlock_bh(&sock->sk_callback_lock);
} }
/* 4. Place psock in sockmap for use and stop any programs on
* the old sock assuming its not the same sock we are replacing
* it with. Because we can only have a single set of programs if
* old_sock has a strp we can stop it.
*/
if (map_link) {
e->entry = map_link;
spin_lock_bh(&psock->maps_lock);
list_add_tail(&e->list, &psock->maps);
spin_unlock_bh(&psock->maps_lock);
}
return err; return err;
out_free: out_free:
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
...@@ -1956,7 +1945,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, ...@@ -1956,7 +1945,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
} }
if (tx_msg) if (tx_msg)
bpf_prog_put(tx_msg); bpf_prog_put(tx_msg);
kfree(e);
return err; return err;
} }
...@@ -1966,36 +1954,57 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1966,36 +1954,57 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
{ {
struct bpf_stab *stab = container_of(map, struct bpf_stab, map); struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
struct bpf_sock_progs *progs = &stab->progs; struct bpf_sock_progs *progs = &stab->progs;
struct sock *osock, *sock; struct sock *osock, *sock = skops->sk;
struct smap_psock_map_entry *e;
struct smap_psock *psock;
u32 i = *(u32 *)key; u32 i = *(u32 *)key;
int err; int err;
if (unlikely(flags > BPF_EXIST)) if (unlikely(flags > BPF_EXIST))
return -EINVAL; return -EINVAL;
if (unlikely(i >= stab->map.max_entries)) if (unlikely(i >= stab->map.max_entries))
return -E2BIG; return -E2BIG;
sock = READ_ONCE(stab->sock_map[i]); e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
if (flags == BPF_EXIST && !sock) if (!e)
return -ENOENT; return -ENOMEM;
else if (flags == BPF_NOEXIST && sock)
return -EEXIST;
sock = skops->sk; err = __sock_map_ctx_update_elem(map, progs, sock, key);
err = __sock_map_ctx_update_elem(map, progs, sock, &stab->sock_map[i],
key);
if (err) if (err)
goto out; goto out;
osock = xchg(&stab->sock_map[i], sock); /* psock guaranteed to be present. */
if (osock) { psock = smap_psock_sk(sock);
struct smap_psock *opsock = smap_psock_sk(osock); raw_spin_lock_bh(&stab->lock);
osock = stab->sock_map[i];
if (osock && flags == BPF_NOEXIST) {
err = -EEXIST;
goto out_unlock;
}
if (!osock && flags == BPF_EXIST) {
err = -ENOENT;
goto out_unlock;
}
e->entry = &stab->sock_map[i];
e->map = map;
spin_lock_bh(&psock->maps_lock);
list_add_tail(&e->list, &psock->maps);
spin_unlock_bh(&psock->maps_lock);
smap_list_map_remove(opsock, &stab->sock_map[i]); stab->sock_map[i] = sock;
smap_release_sock(opsock, osock); if (osock) {
psock = smap_psock_sk(osock);
smap_list_map_remove(psock, &stab->sock_map[i]);
smap_release_sock(psock, osock);
} }
raw_spin_unlock_bh(&stab->lock);
return 0;
out_unlock:
smap_release_sock(psock, sock);
raw_spin_unlock_bh(&stab->lock);
out: out:
kfree(e);
return err; return err;
} }
...@@ -2358,7 +2367,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -2358,7 +2367,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
b = __select_bucket(htab, hash); b = __select_bucket(htab, hash);
head = &b->head; head = &b->head;
err = __sock_map_ctx_update_elem(map, progs, sock, NULL, key); err = __sock_map_ctx_update_elem(map, progs, sock, key);
if (err) if (err)
goto err; goto err;
...@@ -2384,8 +2393,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -2384,8 +2393,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
} }
rcu_assign_pointer(e->hash_link, l_new); rcu_assign_pointer(e->hash_link, l_new);
rcu_assign_pointer(e->htab, e->map = map;
container_of(map, struct bpf_htab, map));
spin_lock_bh(&psock->maps_lock); spin_lock_bh(&psock->maps_lock);
list_add_tail(&e->list, &psock->maps); list_add_tail(&e->list, &psock->maps);
spin_unlock_bh(&psock->maps_lock); spin_unlock_bh(&psock->maps_lock);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment