Commit 7b36e8ee authored by Marek Lindner's avatar Marek Lindner

batman-adv: Correct rcu refcounting for orig_node

It might be possible that 2 threads access the same data in the same
rcu grace period. The first thread calls call_rcu() to decrement the
refcount and free the data while the second thread increases the
refcount to use the data. To avoid this race condition all refcount
operations have to be atomic.
Reported-by: default avatarSven Eckelmann <sven@narfation.org>
Signed-off-by: default avatarMarek Lindner <lindner_marek@yahoo.de>
parent 7aadf889
...@@ -53,9 +53,11 @@ void *gw_get_selected(struct bat_priv *bat_priv) ...@@ -53,9 +53,11 @@ void *gw_get_selected(struct bat_priv *bat_priv)
goto out; goto out;
orig_node = curr_gateway_tmp->orig_node; orig_node = curr_gateway_tmp->orig_node;
if (!orig_node)
goto out;
if (orig_node) if (!atomic_inc_not_zero(&orig_node->refcount))
kref_get(&orig_node->refcount); orig_node = NULL;
out: out:
rcu_read_unlock(); rcu_read_unlock();
......
...@@ -271,7 +271,7 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff, ...@@ -271,7 +271,7 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
if (neigh_node) if (neigh_node)
neigh_node_free_ref(neigh_node); neigh_node_free_ref(neigh_node);
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return len; return len;
} }
......
...@@ -102,13 +102,13 @@ struct neigh_node *create_neighbor(struct orig_node *orig_node, ...@@ -102,13 +102,13 @@ struct neigh_node *create_neighbor(struct orig_node *orig_node,
return neigh_node; return neigh_node;
} }
void orig_node_free_ref(struct kref *refcount) static void orig_node_free_rcu(struct rcu_head *rcu)
{ {
struct hlist_node *node, *node_tmp; struct hlist_node *node, *node_tmp;
struct neigh_node *neigh_node, *tmp_neigh_node; struct neigh_node *neigh_node, *tmp_neigh_node;
struct orig_node *orig_node; struct orig_node *orig_node;
orig_node = container_of(refcount, struct orig_node, refcount); orig_node = container_of(rcu, struct orig_node, rcu);
spin_lock_bh(&orig_node->neigh_list_lock); spin_lock_bh(&orig_node->neigh_list_lock);
...@@ -137,6 +137,12 @@ void orig_node_free_ref(struct kref *refcount) ...@@ -137,6 +137,12 @@ void orig_node_free_ref(struct kref *refcount)
kfree(orig_node); kfree(orig_node);
} }
void orig_node_free_ref(struct orig_node *orig_node)
{
if (atomic_dec_and_test(&orig_node->refcount))
call_rcu(&orig_node->rcu, orig_node_free_rcu);
}
void originator_free(struct bat_priv *bat_priv) void originator_free(struct bat_priv *bat_priv)
{ {
struct hashtable_t *hash = bat_priv->orig_hash; struct hashtable_t *hash = bat_priv->orig_hash;
...@@ -163,7 +169,7 @@ void originator_free(struct bat_priv *bat_priv) ...@@ -163,7 +169,7 @@ void originator_free(struct bat_priv *bat_priv)
head, hash_entry) { head, hash_entry) {
hlist_del_rcu(node); hlist_del_rcu(node);
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
} }
spin_unlock_bh(list_lock); spin_unlock_bh(list_lock);
} }
...@@ -196,7 +202,9 @@ struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr) ...@@ -196,7 +202,9 @@ struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr)
spin_lock_init(&orig_node->ogm_cnt_lock); spin_lock_init(&orig_node->ogm_cnt_lock);
spin_lock_init(&orig_node->bcast_seqno_lock); spin_lock_init(&orig_node->bcast_seqno_lock);
spin_lock_init(&orig_node->neigh_list_lock); spin_lock_init(&orig_node->neigh_list_lock);
kref_init(&orig_node->refcount);
/* extra reference for return */
atomic_set(&orig_node->refcount, 2);
orig_node->bat_priv = bat_priv; orig_node->bat_priv = bat_priv;
memcpy(orig_node->orig, addr, ETH_ALEN); memcpy(orig_node->orig, addr, ETH_ALEN);
...@@ -229,8 +237,6 @@ struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr) ...@@ -229,8 +237,6 @@ struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr)
if (hash_added < 0) if (hash_added < 0)
goto free_bcast_own_sum; goto free_bcast_own_sum;
/* extra reference for return */
kref_get(&orig_node->refcount);
return orig_node; return orig_node;
free_bcast_own_sum: free_bcast_own_sum:
kfree(orig_node->bcast_own_sum); kfree(orig_node->bcast_own_sum);
...@@ -348,8 +354,7 @@ static void _purge_orig(struct bat_priv *bat_priv) ...@@ -348,8 +354,7 @@ static void _purge_orig(struct bat_priv *bat_priv)
if (orig_node->gw_flags) if (orig_node->gw_flags)
gw_node_delete(bat_priv, orig_node); gw_node_delete(bat_priv, orig_node);
hlist_del_rcu(node); hlist_del_rcu(node);
kref_put(&orig_node->refcount, orig_node_free_ref(orig_node);
orig_node_free_ref);
continue; continue;
} }
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
int originator_init(struct bat_priv *bat_priv); int originator_init(struct bat_priv *bat_priv);
void originator_free(struct bat_priv *bat_priv); void originator_free(struct bat_priv *bat_priv);
void purge_orig_ref(struct bat_priv *bat_priv); void purge_orig_ref(struct bat_priv *bat_priv);
void orig_node_free_ref(struct kref *refcount); void orig_node_free_ref(struct orig_node *orig_node);
struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr); struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr);
struct neigh_node *create_neighbor(struct orig_node *orig_node, struct neigh_node *create_neighbor(struct orig_node *orig_node,
struct orig_node *orig_neigh_node, struct orig_node *orig_neigh_node,
...@@ -88,8 +88,10 @@ static inline struct orig_node *orig_hash_find(struct bat_priv *bat_priv, ...@@ -88,8 +88,10 @@ static inline struct orig_node *orig_hash_find(struct bat_priv *bat_priv,
if (!compare_eth(orig_node, data)) if (!compare_eth(orig_node, data))
continue; continue;
if (!atomic_inc_not_zero(&orig_node->refcount))
continue;
orig_node_tmp = orig_node; orig_node_tmp = orig_node;
kref_get(&orig_node_tmp->refcount);
break; break;
} }
rcu_read_unlock(); rcu_read_unlock();
......
...@@ -420,7 +420,7 @@ static void update_orig(struct bat_priv *bat_priv, ...@@ -420,7 +420,7 @@ static void update_orig(struct bat_priv *bat_priv,
neigh_node = create_neighbor(orig_node, orig_tmp, neigh_node = create_neighbor(orig_node, orig_tmp,
ethhdr->h_source, if_incoming); ethhdr->h_source, if_incoming);
kref_put(&orig_tmp->refcount, orig_node_free_ref); orig_node_free_ref(orig_tmp);
if (!neigh_node) if (!neigh_node)
goto unlock; goto unlock;
...@@ -604,7 +604,7 @@ static char count_real_packets(struct ethhdr *ethhdr, ...@@ -604,7 +604,7 @@ static char count_real_packets(struct ethhdr *ethhdr,
out: out:
spin_unlock_bh(&orig_node->ogm_cnt_lock); spin_unlock_bh(&orig_node->ogm_cnt_lock);
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return ret; return ret;
} }
...@@ -730,7 +730,7 @@ void receive_bat_packet(struct ethhdr *ethhdr, ...@@ -730,7 +730,7 @@ void receive_bat_packet(struct ethhdr *ethhdr,
bat_dbg(DBG_BATMAN, bat_priv, "Drop packet: " bat_dbg(DBG_BATMAN, bat_priv, "Drop packet: "
"originator packet from myself (via neighbor)\n"); "originator packet from myself (via neighbor)\n");
kref_put(&orig_neigh_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_neigh_node);
return; return;
} }
...@@ -835,10 +835,10 @@ void receive_bat_packet(struct ethhdr *ethhdr, ...@@ -835,10 +835,10 @@ void receive_bat_packet(struct ethhdr *ethhdr,
0, hna_buff_len, if_incoming); 0, hna_buff_len, if_incoming);
out_neigh: out_neigh:
if (!is_single_hop_neigh) if ((orig_neigh_node) && (!is_single_hop_neigh))
kref_put(&orig_neigh_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_neigh_node);
out: out:
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
} }
int recv_bat_packet(struct sk_buff *skb, struct batman_if *batman_if) int recv_bat_packet(struct sk_buff *skb, struct batman_if *batman_if)
...@@ -952,7 +952,7 @@ static int recv_my_icmp_packet(struct bat_priv *bat_priv, ...@@ -952,7 +952,7 @@ static int recv_my_icmp_packet(struct bat_priv *bat_priv,
if (neigh_node) if (neigh_node)
neigh_node_free_ref(neigh_node); neigh_node_free_ref(neigh_node);
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return ret; return ret;
} }
...@@ -1028,7 +1028,7 @@ static int recv_icmp_ttl_exceeded(struct bat_priv *bat_priv, ...@@ -1028,7 +1028,7 @@ static int recv_icmp_ttl_exceeded(struct bat_priv *bat_priv,
if (neigh_node) if (neigh_node)
neigh_node_free_ref(neigh_node); neigh_node_free_ref(neigh_node);
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return ret; return ret;
} }
...@@ -1134,7 +1134,7 @@ int recv_icmp_packet(struct sk_buff *skb, struct batman_if *recv_if) ...@@ -1134,7 +1134,7 @@ int recv_icmp_packet(struct sk_buff *skb, struct batman_if *recv_if)
if (neigh_node) if (neigh_node)
neigh_node_free_ref(neigh_node); neigh_node_free_ref(neigh_node);
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return ret; return ret;
} }
...@@ -1189,7 +1189,7 @@ struct neigh_node *find_router(struct bat_priv *bat_priv, ...@@ -1189,7 +1189,7 @@ struct neigh_node *find_router(struct bat_priv *bat_priv,
if (!primary_orig_node) if (!primary_orig_node)
goto return_router; goto return_router;
kref_put(&primary_orig_node->refcount, orig_node_free_ref); orig_node_free_ref(primary_orig_node);
} }
/* with less than 2 candidates, we can't do any /* with less than 2 candidates, we can't do any
...@@ -1401,7 +1401,7 @@ int route_unicast_packet(struct sk_buff *skb, struct batman_if *recv_if, ...@@ -1401,7 +1401,7 @@ int route_unicast_packet(struct sk_buff *skb, struct batman_if *recv_if,
if (neigh_node) if (neigh_node)
neigh_node_free_ref(neigh_node); neigh_node_free_ref(neigh_node);
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return ret; return ret;
} }
...@@ -1543,7 +1543,7 @@ int recv_bcast_packet(struct sk_buff *skb, struct batman_if *recv_if) ...@@ -1543,7 +1543,7 @@ int recv_bcast_packet(struct sk_buff *skb, struct batman_if *recv_if)
spin_unlock_bh(&bat_priv->orig_hash_lock); spin_unlock_bh(&bat_priv->orig_hash_lock);
out: out:
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return ret; return ret;
} }
......
...@@ -589,17 +589,20 @@ void hna_global_free(struct bat_priv *bat_priv) ...@@ -589,17 +589,20 @@ void hna_global_free(struct bat_priv *bat_priv)
struct orig_node *transtable_search(struct bat_priv *bat_priv, uint8_t *addr) struct orig_node *transtable_search(struct bat_priv *bat_priv, uint8_t *addr)
{ {
struct hna_global_entry *hna_global_entry; struct hna_global_entry *hna_global_entry;
struct orig_node *orig_node = NULL;
spin_lock_bh(&bat_priv->hna_ghash_lock); spin_lock_bh(&bat_priv->hna_ghash_lock);
hna_global_entry = hna_global_hash_find(bat_priv, addr); hna_global_entry = hna_global_hash_find(bat_priv, addr);
if (hna_global_entry) if (!hna_global_entry)
kref_get(&hna_global_entry->orig_node->refcount); goto out;
spin_unlock_bh(&bat_priv->hna_ghash_lock); if (!atomic_inc_not_zero(&hna_global_entry->orig_node->refcount))
goto out;
if (!hna_global_entry) orig_node = hna_global_entry->orig_node;
return NULL;
return hna_global_entry->orig_node; out:
spin_unlock_bh(&bat_priv->hna_ghash_lock);
return orig_node;
} }
...@@ -84,7 +84,8 @@ struct orig_node { ...@@ -84,7 +84,8 @@ struct orig_node {
struct hlist_head neigh_list; struct hlist_head neigh_list;
struct list_head frag_list; struct list_head frag_list;
spinlock_t neigh_list_lock; /* protects neighbor list */ spinlock_t neigh_list_lock; /* protects neighbor list */
struct kref refcount; atomic_t refcount;
struct rcu_head rcu;
struct hlist_node hash_entry; struct hlist_node hash_entry;
struct bat_priv *bat_priv; struct bat_priv *bat_priv;
unsigned long last_frag_packet; unsigned long last_frag_packet;
......
...@@ -211,7 +211,7 @@ int frag_reassemble_skb(struct sk_buff *skb, struct bat_priv *bat_priv, ...@@ -211,7 +211,7 @@ int frag_reassemble_skb(struct sk_buff *skb, struct bat_priv *bat_priv,
spin_unlock_bh(&bat_priv->orig_hash_lock); spin_unlock_bh(&bat_priv->orig_hash_lock);
out: out:
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return ret; return ret;
} }
...@@ -280,7 +280,7 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv) ...@@ -280,7 +280,7 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv)
{ {
struct ethhdr *ethhdr = (struct ethhdr *)skb->data; struct ethhdr *ethhdr = (struct ethhdr *)skb->data;
struct unicast_packet *unicast_packet; struct unicast_packet *unicast_packet;
struct orig_node *orig_node = NULL; struct orig_node *orig_node;
struct batman_if *batman_if; struct batman_if *batman_if;
struct neigh_node *neigh_node; struct neigh_node *neigh_node;
int data_len = skb->len; int data_len = skb->len;
...@@ -347,7 +347,7 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv) ...@@ -347,7 +347,7 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv)
if (neigh_node) if (neigh_node)
neigh_node_free_ref(neigh_node); neigh_node_free_ref(neigh_node);
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
if (ret == 1) if (ret == 1)
kfree_skb(skb); kfree_skb(skb);
return ret; return ret;
......
...@@ -826,7 +826,7 @@ static void unicast_vis_packet(struct bat_priv *bat_priv, ...@@ -826,7 +826,7 @@ static void unicast_vis_packet(struct bat_priv *bat_priv,
if (neigh_node) if (neigh_node)
neigh_node_free_ref(neigh_node); neigh_node_free_ref(neigh_node);
if (orig_node) if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref); orig_node_free_ref(orig_node);
return; return;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment