Commit e5cd3abc authored by John Fastabend, committed by Daniel Borkmann

bpf: sockmap, refactor sockmap routines to work with hashmap

This patch only refactors the existing sockmap code. This will allow
much of the psock initialization code path and bpf helper codes to
work for both sockmap bpf map types that are backed by an array, the
currently supported type, and the new hash backed bpf map type
sockhash.

Most of the fallout comes from three changes:

  - Pushing bpf programs into an independent structure so we
    can use it from the htab struct in the next patch.
  - Generalizing helpers to use void *key instead of the hardcoded
    u32.
  - Instead of passing map/key through the metadata we now do
    the lookup inline. This avoids storing the key in the metadata
    which will be useful when keys can be longer than 4 bytes. We
    rename the sk pointers to sk_redir at this point as well to
    avoid any confusion between the current sk pointer and the
    redirect pointer sk_redir.
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
Acked-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
parent f2467c2d
...@@ -515,9 +515,8 @@ struct sk_msg_buff { ...@@ -515,9 +515,8 @@ struct sk_msg_buff {
int sg_end; int sg_end;
struct scatterlist sg_data[MAX_SKB_FRAGS]; struct scatterlist sg_data[MAX_SKB_FRAGS];
bool sg_copy[MAX_SKB_FRAGS]; bool sg_copy[MAX_SKB_FRAGS];
__u32 key;
__u32 flags; __u32 flags;
struct bpf_map *map; struct sock *sk_redir;
struct sk_buff *skb; struct sk_buff *skb;
struct list_head list; struct list_head list;
}; };
......
...@@ -814,9 +814,8 @@ struct tcp_skb_cb { ...@@ -814,9 +814,8 @@ struct tcp_skb_cb {
#endif #endif
} header; /* For incoming skbs */ } header; /* For incoming skbs */
struct { struct {
__u32 key;
__u32 flags; __u32 flags;
struct bpf_map *map; struct sock *sk_redir;
void *data_end; void *data_end;
} bpf; } bpf;
}; };
......
...@@ -48,14 +48,18 @@ ...@@ -48,14 +48,18 @@
#define SOCK_CREATE_FLAG_MASK \ #define SOCK_CREATE_FLAG_MASK \
(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY) (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
struct bpf_stab { struct bpf_sock_progs {
struct bpf_map map;
struct sock **sock_map;
struct bpf_prog *bpf_tx_msg; struct bpf_prog *bpf_tx_msg;
struct bpf_prog *bpf_parse; struct bpf_prog *bpf_parse;
struct bpf_prog *bpf_verdict; struct bpf_prog *bpf_verdict;
}; };
struct bpf_stab {
struct bpf_map map;
struct sock **sock_map;
struct bpf_sock_progs progs;
};
enum smap_psock_state { enum smap_psock_state {
SMAP_TX_RUNNING, SMAP_TX_RUNNING,
}; };
...@@ -461,7 +465,7 @@ static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md) ...@@ -461,7 +465,7 @@ static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md) static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
{ {
return ((_rc == SK_PASS) ? return ((_rc == SK_PASS) ?
(md->map ? __SK_REDIRECT : __SK_PASS) : (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
__SK_DROP); __SK_DROP);
} }
...@@ -1092,7 +1096,7 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb) ...@@ -1092,7 +1096,7 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
* when we orphan the skb so that we don't have the possibility * when we orphan the skb so that we don't have the possibility
* to reference a stale map. * to reference a stale map.
*/ */
TCP_SKB_CB(skb)->bpf.map = NULL; TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
skb->sk = psock->sock; skb->sk = psock->sock;
bpf_compute_data_pointers(skb); bpf_compute_data_pointers(skb);
preempt_disable(); preempt_disable();
...@@ -1102,7 +1106,7 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb) ...@@ -1102,7 +1106,7 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
/* Moving return codes from UAPI namespace into internal namespace */ /* Moving return codes from UAPI namespace into internal namespace */
return rc == SK_PASS ? return rc == SK_PASS ?
(TCP_SKB_CB(skb)->bpf.map ? __SK_REDIRECT : __SK_PASS) : (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
__SK_DROP; __SK_DROP;
} }
...@@ -1372,7 +1376,6 @@ static int smap_init_sock(struct smap_psock *psock, ...@@ -1372,7 +1376,6 @@ static int smap_init_sock(struct smap_psock *psock,
} }
static void smap_init_progs(struct smap_psock *psock, static void smap_init_progs(struct smap_psock *psock,
struct bpf_stab *stab,
struct bpf_prog *verdict, struct bpf_prog *verdict,
struct bpf_prog *parse) struct bpf_prog *parse)
{ {
...@@ -1450,14 +1453,13 @@ static void smap_gc_work(struct work_struct *w) ...@@ -1450,14 +1453,13 @@ static void smap_gc_work(struct work_struct *w)
kfree(psock); kfree(psock);
} }
static struct smap_psock *smap_init_psock(struct sock *sock, static struct smap_psock *smap_init_psock(struct sock *sock, int node)
struct bpf_stab *stab)
{ {
struct smap_psock *psock; struct smap_psock *psock;
psock = kzalloc_node(sizeof(struct smap_psock), psock = kzalloc_node(sizeof(struct smap_psock),
GFP_ATOMIC | __GFP_NOWARN, GFP_ATOMIC | __GFP_NOWARN,
stab->map.numa_node); node);
if (!psock) if (!psock)
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
...@@ -1662,40 +1664,26 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key) ...@@ -1662,40 +1664,26 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
* - sock_map must use READ_ONCE and (cmp)xchg operations * - sock_map must use READ_ONCE and (cmp)xchg operations
* - BPF verdict/parse programs must use READ_ONCE and xchg operations * - BPF verdict/parse programs must use READ_ONCE and xchg operations
*/ */
static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
struct bpf_map *map, static int __sock_map_ctx_update_elem(struct bpf_map *map,
void *key, u64 flags) struct bpf_sock_progs *progs,
struct sock *sock,
struct sock **map_link,
void *key)
{ {
struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
struct smap_psock_map_entry *e = NULL;
struct bpf_prog *verdict, *parse, *tx_msg; struct bpf_prog *verdict, *parse, *tx_msg;
struct sock *osock, *sock; struct smap_psock_map_entry *e = NULL;
struct smap_psock *psock; struct smap_psock *psock;
u32 i = *(u32 *)key;
bool new = false; bool new = false;
int err; int err;
if (unlikely(flags > BPF_EXIST))
return -EINVAL;
if (unlikely(i >= stab->map.max_entries))
return -E2BIG;
sock = READ_ONCE(stab->sock_map[i]);
if (flags == BPF_EXIST && !sock)
return -ENOENT;
else if (flags == BPF_NOEXIST && sock)
return -EEXIST;
sock = skops->sk;
/* 1. If sock map has BPF programs those will be inherited by the /* 1. If sock map has BPF programs those will be inherited by the
* sock being added. If the sock is already attached to BPF programs * sock being added. If the sock is already attached to BPF programs
* this results in an error. * this results in an error.
*/ */
verdict = READ_ONCE(stab->bpf_verdict); verdict = READ_ONCE(progs->bpf_verdict);
parse = READ_ONCE(stab->bpf_parse); parse = READ_ONCE(progs->bpf_parse);
tx_msg = READ_ONCE(stab->bpf_tx_msg); tx_msg = READ_ONCE(progs->bpf_tx_msg);
if (parse && verdict) { if (parse && verdict) {
/* bpf prog refcnt may be zero if a concurrent attach operation /* bpf prog refcnt may be zero if a concurrent attach operation
...@@ -1703,11 +1691,11 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1703,11 +1691,11 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
* we increment the refcnt. If this is the case abort with an * we increment the refcnt. If this is the case abort with an
* error. * error.
*/ */
verdict = bpf_prog_inc_not_zero(stab->bpf_verdict); verdict = bpf_prog_inc_not_zero(progs->bpf_verdict);
if (IS_ERR(verdict)) if (IS_ERR(verdict))
return PTR_ERR(verdict); return PTR_ERR(verdict);
parse = bpf_prog_inc_not_zero(stab->bpf_parse); parse = bpf_prog_inc_not_zero(progs->bpf_parse);
if (IS_ERR(parse)) { if (IS_ERR(parse)) {
bpf_prog_put(verdict); bpf_prog_put(verdict);
return PTR_ERR(parse); return PTR_ERR(parse);
...@@ -1715,7 +1703,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1715,7 +1703,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
} }
if (tx_msg) { if (tx_msg) {
tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg); tx_msg = bpf_prog_inc_not_zero(progs->bpf_tx_msg);
if (IS_ERR(tx_msg)) { if (IS_ERR(tx_msg)) {
if (verdict) if (verdict)
bpf_prog_put(verdict); bpf_prog_put(verdict);
...@@ -1748,7 +1736,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1748,7 +1736,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
goto out_progs; goto out_progs;
} }
} else { } else {
psock = smap_init_psock(sock, stab); psock = smap_init_psock(sock, map->numa_node);
if (IS_ERR(psock)) { if (IS_ERR(psock)) {
err = PTR_ERR(psock); err = PTR_ERR(psock);
goto out_progs; goto out_progs;
...@@ -1763,7 +1751,6 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1763,7 +1751,6 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
err = -ENOMEM; err = -ENOMEM;
goto out_progs; goto out_progs;
} }
e->entry = &stab->sock_map[i];
/* 3. At this point we have a reference to a valid psock that is /* 3. At this point we have a reference to a valid psock that is
* running. Attach any BPF programs needed. * running. Attach any BPF programs needed.
...@@ -1780,7 +1767,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1780,7 +1767,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
err = smap_init_sock(psock, sock); err = smap_init_sock(psock, sock);
if (err) if (err)
goto out_free; goto out_free;
smap_init_progs(psock, stab, verdict, parse); smap_init_progs(psock, verdict, parse);
smap_start_sock(psock, sock); smap_start_sock(psock, sock);
} }
...@@ -1789,19 +1776,12 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1789,19 +1776,12 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
* it with. Because we can only have a single set of programs if * it with. Because we can only have a single set of programs if
* old_sock has a strp we can stop it. * old_sock has a strp we can stop it.
*/ */
if (map_link) {
e->entry = map_link;
list_add_tail(&e->list, &psock->maps); list_add_tail(&e->list, &psock->maps);
write_unlock_bh(&sock->sk_callback_lock);
osock = xchg(&stab->sock_map[i], sock);
if (osock) {
struct smap_psock *opsock = smap_psock_sk(osock);
write_lock_bh(&osock->sk_callback_lock);
smap_list_remove(opsock, &stab->sock_map[i]);
smap_release_sock(opsock, osock);
write_unlock_bh(&osock->sk_callback_lock);
} }
return 0; write_unlock_bh(&sock->sk_callback_lock);
return err;
out_free: out_free:
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
out_progs: out_progs:
...@@ -1816,23 +1796,69 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1816,23 +1796,69 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
return err; return err;
} }
int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type) static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
struct bpf_map *map,
void *key, u64 flags)
{ {
struct bpf_stab *stab = container_of(map, struct bpf_stab, map); struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
struct bpf_sock_progs *progs = &stab->progs;
struct sock *osock, *sock;
u32 i = *(u32 *)key;
int err;
if (unlikely(flags > BPF_EXIST))
return -EINVAL;
if (unlikely(i >= stab->map.max_entries))
return -E2BIG;
sock = READ_ONCE(stab->sock_map[i]);
if (flags == BPF_EXIST && !sock)
return -ENOENT;
else if (flags == BPF_NOEXIST && sock)
return -EEXIST;
sock = skops->sk;
err = __sock_map_ctx_update_elem(map, progs, sock, &stab->sock_map[i],
key);
if (err)
goto out;
osock = xchg(&stab->sock_map[i], sock);
if (osock) {
struct smap_psock *opsock = smap_psock_sk(osock);
write_lock_bh(&osock->sk_callback_lock);
smap_list_remove(opsock, &stab->sock_map[i]);
smap_release_sock(opsock, osock);
write_unlock_bh(&osock->sk_callback_lock);
}
out:
return 0;
}
int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
{
struct bpf_sock_progs *progs;
struct bpf_prog *orig; struct bpf_prog *orig;
if (unlikely(map->map_type != BPF_MAP_TYPE_SOCKMAP)) if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
progs = &stab->progs;
} else {
return -EINVAL; return -EINVAL;
}
switch (type) { switch (type) {
case BPF_SK_MSG_VERDICT: case BPF_SK_MSG_VERDICT:
orig = xchg(&stab->bpf_tx_msg, prog); orig = xchg(&progs->bpf_tx_msg, prog);
break; break;
case BPF_SK_SKB_STREAM_PARSER: case BPF_SK_SKB_STREAM_PARSER:
orig = xchg(&stab->bpf_parse, prog); orig = xchg(&progs->bpf_parse, prog);
break; break;
case BPF_SK_SKB_STREAM_VERDICT: case BPF_SK_SKB_STREAM_VERDICT:
orig = xchg(&stab->bpf_verdict, prog); orig = xchg(&progs->bpf_verdict, prog);
break; break;
default: default:
return -EOPNOTSUPP; return -EOPNOTSUPP;
...@@ -1881,16 +1907,18 @@ static int sock_map_update_elem(struct bpf_map *map, ...@@ -1881,16 +1907,18 @@ static int sock_map_update_elem(struct bpf_map *map,
static void sock_map_release(struct bpf_map *map) static void sock_map_release(struct bpf_map *map)
{ {
struct bpf_stab *stab = container_of(map, struct bpf_stab, map); struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
struct bpf_sock_progs *progs;
struct bpf_prog *orig; struct bpf_prog *orig;
orig = xchg(&stab->bpf_parse, NULL); progs = &stab->progs;
orig = xchg(&progs->bpf_parse, NULL);
if (orig) if (orig)
bpf_prog_put(orig); bpf_prog_put(orig);
orig = xchg(&stab->bpf_verdict, NULL); orig = xchg(&progs->bpf_verdict, NULL);
if (orig) if (orig)
bpf_prog_put(orig); bpf_prog_put(orig);
orig = xchg(&stab->bpf_tx_msg, NULL); orig = xchg(&progs->bpf_tx_msg, NULL);
if (orig) if (orig)
bpf_prog_put(orig); bpf_prog_put(orig);
} }
......
...@@ -2083,9 +2083,10 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb, ...@@ -2083,9 +2083,10 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
if (unlikely(flags & ~(BPF_F_INGRESS))) if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP; return SK_DROP;
tcb->bpf.key = key;
tcb->bpf.flags = flags; tcb->bpf.flags = flags;
tcb->bpf.map = map; tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
if (!tcb->bpf.sk_redir)
return SK_DROP;
return SK_PASS; return SK_PASS;
} }
...@@ -2093,16 +2094,8 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb, ...@@ -2093,16 +2094,8 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
struct sock *do_sk_redirect_map(struct sk_buff *skb) struct sock *do_sk_redirect_map(struct sk_buff *skb)
{ {
struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
struct sock *sk = NULL;
if (tcb->bpf.map) {
sk = __sock_map_lookup_elem(tcb->bpf.map, tcb->bpf.key);
tcb->bpf.key = 0; return tcb->bpf.sk_redir;
tcb->bpf.map = NULL;
}
return sk;
} }
static const struct bpf_func_proto bpf_sk_redirect_map_proto = { static const struct bpf_func_proto bpf_sk_redirect_map_proto = {
...@@ -2122,25 +2115,17 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg, ...@@ -2122,25 +2115,17 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
if (unlikely(flags & ~(BPF_F_INGRESS))) if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP; return SK_DROP;
msg->key = key;
msg->flags = flags; msg->flags = flags;
msg->map = map; msg->sk_redir = __sock_map_lookup_elem(map, key);
if (!msg->sk_redir)
return SK_DROP;
return SK_PASS; return SK_PASS;
} }
struct sock *do_msg_redirect_map(struct sk_msg_buff *msg) struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
{ {
struct sock *sk = NULL; return msg->sk_redir;
if (msg->map) {
sk = __sock_map_lookup_elem(msg->map, msg->key);
msg->key = 0;
msg->map = NULL;
}
return sk;
} }
static const struct bpf_func_proto bpf_msg_redirect_map_proto = { static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment