Commit c4d48a58 authored by Cong Wang, committed by David S. Miller

l2tp: convert l2tp_tunnel_list to idr

l2tp uses l2tp_tunnel_list to track all registered tunnels and
to allocate tunnel IDs. An IDR can do the same job.

More importantly, with an IDR we can reserve the ID before a
successful registration, so we no longer need to worry about late
error handling; it is not easy to roll back the socket changes.

This is a preparation for the following fix.
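
A rough sketch of the reserve-then-publish idiom this relies on, mirroring
what l2tp_tunnel_register() does below (tunnel_idr_sketch() and
setup_tunnel_socket() are illustrative stand-ins, not functions added by
this patch):

    /* Reserve the tunnel ID up front with a NULL placeholder so a
     * duplicate ID fails before any socket state is touched, then
     * publish the real pointer only once the rest of the setup has
     * succeeded.
     */
    static int tunnel_idr_sketch(struct l2tp_net *pn, struct l2tp_tunnel *tunnel)
    {
            u32 id = tunnel->tunnel_id;
            int ret;

            spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
            ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &id, id, GFP_ATOMIC);
            spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
            if (ret)
                    return ret == -ENOSPC ? -EEXIST : ret;  /* ID already in use */

            ret = setup_tunnel_socket(tunnel);              /* hypothetical setup step */

            spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
            if (ret)
                    idr_remove(&pn->l2tp_tunnel_idr, id);   /* drop the reservation */
            else
                    idr_replace(&pn->l2tp_tunnel_idr, tunnel, id);  /* publish */
            spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
            return ret;
    }

Until idr_replace() runs, idr_find() on that ID returns NULL, and the error
path is a plain idr_remove() instead of the old err_sock socket rollback.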

Cc: Tetsuo Handa <penguin-kernel@I-love.SAKURA.ne.jp>
Cc: Guillaume Nault <gnault@redhat.com>
Cc: Jakub Sitnicki <jakub@cloudflare.com>
Cc: Eric Dumazet <edumazet@google.com>
Cc: Tom Parkin <tparkin@katalix.com>
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
parent 3a415d59
net/l2tp/l2tp_core.c

@@ -104,9 +104,9 @@ static struct workqueue_struct *l2tp_wq;
 /* per-net private data for this module */
 static unsigned int l2tp_net_id;
 struct l2tp_net {
-        struct list_head l2tp_tunnel_list;
-        /* Lock for write access to l2tp_tunnel_list */
-        spinlock_t l2tp_tunnel_list_lock;
+        /* Lock for write access to l2tp_tunnel_idr */
+        spinlock_t l2tp_tunnel_idr_lock;
+        struct idr l2tp_tunnel_idr;
         struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2];
         /* Lock for write access to l2tp_session_hlist */
         spinlock_t l2tp_session_hlist_lock;
@@ -208,13 +208,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
         struct l2tp_tunnel *tunnel;
 
         rcu_read_lock_bh();
-        list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-                if (tunnel->tunnel_id == tunnel_id &&
-                    refcount_inc_not_zero(&tunnel->ref_count)) {
-                        rcu_read_unlock_bh();
-
-                        return tunnel;
-                }
+        tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id);
+        if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) {
+                rcu_read_unlock_bh();
+                return tunnel;
         }
         rcu_read_unlock_bh();
 
@@ -224,13 +221,14 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_get);
 
 struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth)
 {
-        const struct l2tp_net *pn = l2tp_pernet(net);
+        struct l2tp_net *pn = l2tp_pernet(net);
+        unsigned long tunnel_id, tmp;
         struct l2tp_tunnel *tunnel;
         int count = 0;
 
         rcu_read_lock_bh();
-        list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-                if (++count > nth &&
+        idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
+                if (tunnel && ++count > nth &&
                     refcount_inc_not_zero(&tunnel->ref_count)) {
                         rcu_read_unlock_bh();
                         return tunnel;
@@ -1227,6 +1225,15 @@ static void l2tp_udp_encap_destroy(struct sock *sk)
                 l2tp_tunnel_delete(tunnel);
 }
 
+static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel)
+{
+        struct l2tp_net *pn = l2tp_pernet(net);
+
+        spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+        idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id);
+        spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
+}
+
 /* Workqueue tunnel deletion function */
 static void l2tp_tunnel_del_work(struct work_struct *work)
 {
@@ -1234,7 +1241,6 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
                                                   del_work);
         struct sock *sk = tunnel->sock;
         struct socket *sock = sk->sk_socket;
-        struct l2tp_net *pn;
 
         l2tp_tunnel_closeall(tunnel);
 
@@ -1248,12 +1254,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
                 }
         }
 
-        /* Remove the tunnel struct from the tunnel list */
-        pn = l2tp_pernet(tunnel->l2tp_net);
-        spin_lock_bh(&pn->l2tp_tunnel_list_lock);
-        list_del_rcu(&tunnel->list);
-        spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
-
+        l2tp_tunnel_remove(tunnel->l2tp_net, tunnel);
 
         /* drop initial ref */
         l2tp_tunnel_dec_refcount(tunnel);
@@ -1455,12 +1456,19 @@ static int l2tp_validate_socket(const struct sock *sk, const struct net *net,
 int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
                          struct l2tp_tunnel_cfg *cfg)
 {
-        struct l2tp_tunnel *tunnel_walk;
-        struct l2tp_net *pn;
+        struct l2tp_net *pn = l2tp_pernet(net);
+        u32 tunnel_id = tunnel->tunnel_id;
         struct socket *sock;
         struct sock *sk;
         int ret;
 
+        spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+        ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id,
+                            GFP_ATOMIC);
+        spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
+        if (ret)
+                return ret == -ENOSPC ? -EEXIST : ret;
+
         if (tunnel->fd < 0) {
                 ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id,
                                               tunnel->peer_tunnel_id, cfg,
@@ -1481,23 +1489,13 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
         rcu_assign_sk_user_data(sk, tunnel);
         write_unlock_bh(&sk->sk_callback_lock);
 
-        tunnel->l2tp_net = net;
-        pn = l2tp_pernet(net);
-
         sock_hold(sk);
         tunnel->sock = sk;
+        tunnel->l2tp_net = net;
 
-        spin_lock_bh(&pn->l2tp_tunnel_list_lock);
-        list_for_each_entry(tunnel_walk, &pn->l2tp_tunnel_list, list) {
-                if (tunnel_walk->tunnel_id == tunnel->tunnel_id) {
-                        spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
-                        sock_put(sk);
-                        ret = -EEXIST;
-                        goto err_sock;
-                }
-        }
-        list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
-        spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
+        spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+        idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id);
+        spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
 
         if (tunnel->encap == L2TP_ENCAPTYPE_UDP) {
                 struct udp_tunnel_sock_cfg udp_cfg = {
@@ -1523,9 +1521,6 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
 
         return 0;
 
-err_sock:
-        write_lock_bh(&sk->sk_callback_lock);
-        rcu_assign_sk_user_data(sk, NULL);
 err_inval_sock:
         write_unlock_bh(&sk->sk_callback_lock);
 
@@ -1534,6 +1529,7 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
         else
                 sockfd_put(sock);
 err:
+        l2tp_tunnel_remove(net, tunnel);
         return ret;
 }
 EXPORT_SYMBOL_GPL(l2tp_tunnel_register);
@@ -1647,8 +1643,8 @@ static __net_init int l2tp_init_net(struct net *net)
         struct l2tp_net *pn = net_generic(net, l2tp_net_id);
         int hash;
 
-        INIT_LIST_HEAD(&pn->l2tp_tunnel_list);
-        spin_lock_init(&pn->l2tp_tunnel_list_lock);
+        idr_init(&pn->l2tp_tunnel_idr);
+        spin_lock_init(&pn->l2tp_tunnel_idr_lock);
 
         for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
                 INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]);
@@ -1662,11 +1658,13 @@ static __net_exit void l2tp_exit_net(struct net *net)
 {
         struct l2tp_net *pn = l2tp_pernet(net);
         struct l2tp_tunnel *tunnel = NULL;
+        unsigned long tunnel_id, tmp;
         int hash;
 
         rcu_read_lock_bh();
-        list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-                l2tp_tunnel_delete(tunnel);
+        idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
+                if (tunnel)
+                        l2tp_tunnel_delete(tunnel);
         }
         rcu_read_unlock_bh();
 
@@ -1676,6 +1674,7 @@ static __net_exit void l2tp_exit_net(struct net *net)
 
         for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
                 WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
+        idr_destroy(&pn->l2tp_tunnel_idr);
 }
 
 static struct pernet_operations l2tp_net_ops = {