Commit 6dd1ec6c authored by Yonghong Song's avatar Yonghong Song Committed by Alexei Starovoitov

bpf: fix kernel page fault in lpm map trie_get_next_key

Commit b471f2f1 ("bpf: implement MAP_GET_NEXT_KEY command
for LPM_TRIE map") introduces a bug likes below:

    if (!rcu_dereference(trie->root))
        return -ENOENT;
    if (!key || key->prefixlen > trie->max_prefixlen) {
        root = &trie->root;
        goto find_leftmost;
    }
    ......
  find_leftmost:
    for (node = rcu_dereference(*root); node;) {

In the code after label find_leftmost, it is assumed
that *root should not be NULL, but it is not true as
it is possbile trie->root is changed to NULL by an
asynchronous delete operation.

The issue is reported by syzbot and Eric Dumazet with the
below error log:
  ......
  kasan: CONFIG_KASAN_INLINE enabled
  kasan: GPF could be caused by NULL-ptr deref or user memory access
  general protection fault: 0000 [#1] SMP KASAN
  Dumping ftrace buffer:
     (ftrace buffer empty)
  Modules linked in:
  CPU: 1 PID: 8033 Comm: syz-executor3 Not tainted 4.15.0-rc8+ #4
  Hardware name: Google Google Compute Engine/Google Compute Engine, BIOS Google 01/01/2011
  RIP: 0010:trie_get_next_key+0x3c2/0xf10 kernel/bpf/lpm_trie.c:682
  ......

This patch fixed the issue by use local rcu_dereferenced
pointer instead of *(&trie->root) later on.

Fixes: b471f2f1 ("bpf: implement MAP_GET_NEXT_KEY command or LPM_TRIE map")
Reported-by: default avatarsyzbot <syzkaller@googlegroups.com>
Reported-by: default avatarEric Dumazet <edumazet@google.com>
Signed-off-by: default avatarYonghong Song <yhs@fb.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent 1651e39e
...@@ -593,11 +593,10 @@ static void trie_free(struct bpf_map *map) ...@@ -593,11 +593,10 @@ static void trie_free(struct bpf_map *map)
static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key) static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
{ {
struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
struct lpm_trie *trie = container_of(map, struct lpm_trie, map); struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
struct bpf_lpm_trie_key *key = _key, *next_key = _next_key; struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
struct lpm_trie_node *node, *next_node = NULL, *parent;
struct lpm_trie_node **node_stack = NULL; struct lpm_trie_node **node_stack = NULL;
struct lpm_trie_node __rcu **root;
int err = 0, stack_ptr = -1; int err = 0, stack_ptr = -1;
unsigned int next_bit; unsigned int next_bit;
size_t matchlen; size_t matchlen;
...@@ -614,14 +613,13 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key) ...@@ -614,14 +613,13 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
*/ */
/* Empty trie */ /* Empty trie */
if (!rcu_dereference(trie->root)) search_root = rcu_dereference(trie->root);
if (!search_root)
return -ENOENT; return -ENOENT;
/* For invalid key, find the leftmost node in the trie */ /* For invalid key, find the leftmost node in the trie */
if (!key || key->prefixlen > trie->max_prefixlen) { if (!key || key->prefixlen > trie->max_prefixlen)
root = &trie->root;
goto find_leftmost; goto find_leftmost;
}
node_stack = kmalloc(trie->max_prefixlen * sizeof(struct lpm_trie_node *), node_stack = kmalloc(trie->max_prefixlen * sizeof(struct lpm_trie_node *),
GFP_ATOMIC | __GFP_NOWARN); GFP_ATOMIC | __GFP_NOWARN);
...@@ -629,7 +627,7 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key) ...@@ -629,7 +627,7 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
return -ENOMEM; return -ENOMEM;
/* Try to find the exact node for the given key */ /* Try to find the exact node for the given key */
for (node = rcu_dereference(trie->root); node;) { for (node = search_root; node;) {
node_stack[++stack_ptr] = node; node_stack[++stack_ptr] = node;
matchlen = longest_prefix_match(trie, node, key); matchlen = longest_prefix_match(trie, node, key);
if (node->prefixlen != matchlen || if (node->prefixlen != matchlen ||
...@@ -640,10 +638,8 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key) ...@@ -640,10 +638,8 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
node = rcu_dereference(node->child[next_bit]); node = rcu_dereference(node->child[next_bit]);
} }
if (!node || node->prefixlen != key->prefixlen || if (!node || node->prefixlen != key->prefixlen ||
(node->flags & LPM_TREE_NODE_FLAG_IM)) { (node->flags & LPM_TREE_NODE_FLAG_IM))
root = &trie->root;
goto find_leftmost; goto find_leftmost;
}
/* The node with the exactly-matching key has been found, /* The node with the exactly-matching key has been found,
* find the first node in postorder after the matched node. * find the first node in postorder after the matched node.
...@@ -651,9 +647,9 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key) ...@@ -651,9 +647,9 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
node = node_stack[stack_ptr]; node = node_stack[stack_ptr];
while (stack_ptr > 0) { while (stack_ptr > 0) {
parent = node_stack[stack_ptr - 1]; parent = node_stack[stack_ptr - 1];
if (rcu_dereference(parent->child[0]) == node && if (rcu_dereference(parent->child[0]) == node) {
rcu_dereference(parent->child[1])) { search_root = rcu_dereference(parent->child[1]);
root = &parent->child[1]; if (search_root)
goto find_leftmost; goto find_leftmost;
} }
if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) { if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) {
...@@ -673,7 +669,7 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key) ...@@ -673,7 +669,7 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
/* Find the leftmost non-intermediate node, all intermediate nodes /* Find the leftmost non-intermediate node, all intermediate nodes
* have exact two children, so this function will never return NULL. * have exact two children, so this function will never return NULL.
*/ */
for (node = rcu_dereference(*root); node;) { for (node = search_root; node;) {
if (!(node->flags & LPM_TREE_NODE_FLAG_IM)) if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
next_node = node; next_node = node;
node = rcu_dereference(node->child[0]); node = rcu_dereference(node->child[0]);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment