Commit 05ce8bd4 authored by David S. Miller

Merge branch 'l2tp-register-sessions-atomically'

Guillaume Nault says:

====================
l2tp: register sessions atomically

Currently l2tp_session_create() allocates a session, partially
initialises it and finally registers it. It therefore exposes sessions
that aren't fully initialised to the rest of the system, because
pseudo-wire specific initialisation can only happen after
l2tp_session_create() returns.
This leads to several crashes when these sessions are used or deleted.

This series starts by splitting session registration out of
l2tp_session_create() (patch #1). Thus allowing pseudo-wires code to
terminate the initialisation phase before registration.

Then patch #2 fixes the eth pseudo-wire code. This requires protecting
the session's netdevice pointer with RCU, because it still needs to be
updated concurrently after the session got registered.

Remaining patches take care of ppp pseudo-wires. RCU protection is
needed there too, for the same reasons. This time it's the pppol2tp
socket pointer that gets protected. For clarity, and since the
conversion requires more modifications, introducing RCU is done in
its own patch (#3). Then patch #4 only has to take care of fixing
sessions initialisation and registration (and adapting part of the
deletion process).
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents 949cf8b1 f98be6c6
......@@ -322,8 +322,8 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
}
EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);
static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
struct l2tp_session *session)
int l2tp_session_register(struct l2tp_session *session,
struct l2tp_tunnel *tunnel)
{
struct l2tp_session *session_walk;
struct hlist_head *g_head;
......@@ -371,6 +371,10 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
hlist_add_head(&session->hlist, head);
write_unlock_bh(&tunnel->hlist_lock);
/* Ignore management session in session count value */
if (session->session_id != 0)
atomic_inc(&l2tp_session_count);
return 0;
err_tlock_pnlock:
......@@ -380,6 +384,7 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
return err;
}
EXPORT_SYMBOL_GPL(l2tp_session_register);
/* Lookup a tunnel by id
*/
......@@ -1788,7 +1793,6 @@ EXPORT_SYMBOL_GPL(l2tp_session_set_header_len);
struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg)
{
struct l2tp_session *session;
int err;
session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL);
if (session != NULL) {
......@@ -1846,17 +1850,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
refcount_set(&session->ref_count, 1);
err = l2tp_session_add_to_tunnel(tunnel, session);
if (err) {
kfree(session);
return ERR_PTR(err);
}
/* Ignore management session in session count value */
if (session->session_id != 0)
atomic_inc(&l2tp_session_count);
return session;
}
......
......@@ -263,6 +263,9 @@ struct l2tp_session *l2tp_session_create(int priv_size,
struct l2tp_tunnel *tunnel,
u32 session_id, u32 peer_session_id,
struct l2tp_session_cfg *cfg);
int l2tp_session_register(struct l2tp_session *session,
struct l2tp_tunnel *tunnel);
void __l2tp_session_unhash(struct l2tp_session *session);
int l2tp_session_delete(struct l2tp_session *session);
void l2tp_session_free(struct l2tp_session *session);
......
......@@ -54,7 +54,7 @@ struct l2tp_eth {
/* via l2tp_session_priv() */
struct l2tp_eth_sess {
struct net_device *dev;
struct net_device __rcu *dev;
};
......@@ -72,7 +72,14 @@ static int l2tp_eth_dev_init(struct net_device *dev)
static void l2tp_eth_dev_uninit(struct net_device *dev)
{
dev_put(dev);
struct l2tp_eth *priv = netdev_priv(dev);
struct l2tp_eth_sess *spriv;
spriv = l2tp_session_priv(priv->session);
RCU_INIT_POINTER(spriv->dev, NULL);
/* No need for synchronize_net() here. We're called by
* unregister_netdev*(), which does the synchronisation for us.
*/
}
static int l2tp_eth_dev_xmit(struct sk_buff *skb, struct net_device *dev)
......@@ -130,8 +137,8 @@ static void l2tp_eth_dev_setup(struct net_device *dev)
static void l2tp_eth_dev_recv(struct l2tp_session *session, struct sk_buff *skb, int data_len)
{
struct l2tp_eth_sess *spriv = l2tp_session_priv(session);
struct net_device *dev = spriv->dev;
struct l2tp_eth *priv = netdev_priv(dev);
struct net_device *dev;
struct l2tp_eth *priv;
if (session->debug & L2TP_MSG_DATA) {
unsigned int length;
......@@ -155,16 +162,25 @@ static void l2tp_eth_dev_recv(struct l2tp_session *session, struct sk_buff *skb,
skb_dst_drop(skb);
nf_reset(skb);
rcu_read_lock();
dev = rcu_dereference(spriv->dev);
if (!dev)
goto error_rcu;
priv = netdev_priv(dev);
if (dev_forward_skb(dev, skb) == NET_RX_SUCCESS) {
atomic_long_inc(&priv->rx_packets);
atomic_long_add(data_len, &priv->rx_bytes);
} else {
atomic_long_inc(&priv->rx_errors);
}
rcu_read_unlock();
return;
error_rcu:
rcu_read_unlock();
error:
atomic_long_inc(&priv->rx_errors);
kfree_skb(skb);
}
......@@ -175,11 +191,15 @@ static void l2tp_eth_delete(struct l2tp_session *session)
if (session) {
spriv = l2tp_session_priv(session);
dev = spriv->dev;
rtnl_lock();
dev = rtnl_dereference(spriv->dev);
if (dev) {
unregister_netdev(dev);
spriv->dev = NULL;
unregister_netdevice(dev);
rtnl_unlock();
module_put(THIS_MODULE);
} else {
rtnl_unlock();
}
}
}
......@@ -189,9 +209,20 @@ static void l2tp_eth_show(struct seq_file *m, void *arg)
{
struct l2tp_session *session = arg;
struct l2tp_eth_sess *spriv = l2tp_session_priv(session);
struct net_device *dev = spriv->dev;
struct net_device *dev;
rcu_read_lock();
dev = rcu_dereference(spriv->dev);
if (!dev) {
rcu_read_unlock();
return;
}
dev_hold(dev);
rcu_read_unlock();
seq_printf(m, " interface %s\n", dev->name);
dev_put(dev);
}
#endif
......@@ -268,14 +299,14 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel,
peer_session_id, cfg);
if (IS_ERR(session)) {
rc = PTR_ERR(session);
goto out;
goto err;
}
dev = alloc_netdev(sizeof(*priv), name, name_assign_type,
l2tp_eth_dev_setup);
if (!dev) {
rc = -ENOMEM;
goto out_del_session;
goto err_sess;
}
dev_net_set(dev, net);
......@@ -295,26 +326,48 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel,
#endif
spriv = l2tp_session_priv(session);
spriv->dev = dev;
rc = register_netdev(dev);
if (rc < 0)
goto out_del_dev;
l2tp_session_inc_refcount(session);
rtnl_lock();
/* Register both device and session while holding the rtnl lock. This
* ensures that l2tp_eth_delete() will see that there's a device to
* unregister, even if it happened to run before we assign spriv->dev.
*/
rc = l2tp_session_register(session, tunnel);
if (rc < 0) {
rtnl_unlock();
goto err_sess_dev;
}
rc = register_netdevice(dev);
if (rc < 0) {
rtnl_unlock();
l2tp_session_delete(session);
l2tp_session_dec_refcount(session);
free_netdev(dev);
return rc;
}
__module_get(THIS_MODULE);
/* Must be done after register_netdev() */
strlcpy(session->ifname, dev->name, IFNAMSIZ);
rcu_assign_pointer(spriv->dev, dev);
dev_hold(dev);
rtnl_unlock();
l2tp_session_dec_refcount(session);
__module_get(THIS_MODULE);
return 0;
out_del_dev:
err_sess_dev:
l2tp_session_dec_refcount(session);
free_netdev(dev);
spriv->dev = NULL;
out_del_session:
l2tp_session_delete(session);
out:
err_sess:
kfree(session);
err:
return rc;
}
......
......@@ -122,8 +122,11 @@
struct pppol2tp_session {
int owner; /* pid that opened the socket */
struct sock *sock; /* Pointer to the session
struct mutex sk_lock; /* Protects .sk */
struct sock __rcu *sk; /* Pointer to the session
* PPPoX socket */
struct sock *__sk; /* Copy of .sk, for cleanup */
struct rcu_head rcu; /* For asynchronous release */
struct sock *tunnel_sock; /* Pointer to the tunnel UDP
* socket */
int flags; /* accessed by PPPIOCGFLAGS.
......@@ -138,6 +141,24 @@ static const struct ppp_channel_ops pppol2tp_chan_ops = {
static const struct proto_ops pppol2tp_ops;
/* Retrieves the pppol2tp socket associated to a session.
 *
 * The socket pointer lives in the session's private area (obtained with
 * l2tp_session_priv()) as the RCU-protected field ->sk. It is read inside
 * an rcu_read_lock() section and, when non-NULL, a reference is taken with
 * sock_hold() before the read-side section is left.
 *
 * Returns the socket with a reference held, or NULL if ->sk is NULL.
 * A non-NULL result must be released by the caller with sock_put().
 */
static struct sock *pppol2tp_session_get_sock(struct l2tp_session *session)
{
struct pppol2tp_session *ps = l2tp_session_priv(session);
struct sock *sk;
rcu_read_lock();
sk = rcu_dereference(ps->sk);
/* Take the reference while still inside the read-side section, so the
 * pointer stays usable after rcu_read_unlock().
 */
if (sk)
sock_hold(sk);
rcu_read_unlock();
return sk;
}
/* Helpers to obtain tunnel/session contexts from sockets.
*/
static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
......@@ -224,7 +245,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
/* If the socket is bound, send it in to PPP's input queue. Otherwise
* queue it on the session socket.
*/
sk = ps->sock;
rcu_read_lock();
sk = rcu_dereference(ps->sk);
if (sk == NULL)
goto no_sock;
......@@ -247,30 +269,16 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
kfree_skb(skb);
}
}
rcu_read_unlock();
return;
no_sock:
rcu_read_unlock();
l2tp_info(session, L2TP_MSG_DATA, "%s: no socket\n", session->name);
kfree_skb(skb);
}
static void pppol2tp_session_sock_hold(struct l2tp_session *session)
{
struct pppol2tp_session *ps = l2tp_session_priv(session);
if (ps->sock)
sock_hold(ps->sock);
}
static void pppol2tp_session_sock_put(struct l2tp_session *session)
{
struct pppol2tp_session *ps = l2tp_session_priv(session);
if (ps->sock)
sock_put(ps->sock);
}
/************************************************************************
* Transmit handling
***********************************************************************/
......@@ -431,17 +439,16 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
*/
static void pppol2tp_session_close(struct l2tp_session *session)
{
struct pppol2tp_session *ps = l2tp_session_priv(session);
struct sock *sk = ps->sock;
struct socket *sock = sk->sk_socket;
struct sock *sk;
BUG_ON(session->magic != L2TP_SESSION_MAGIC);
if (sock)
inet_shutdown(sock, SEND_SHUTDOWN);
/* Don't let the session go away before our socket does */
l2tp_session_inc_refcount(session);
sk = pppol2tp_session_get_sock(session);
if (sk) {
if (sk->sk_socket)
inet_shutdown(sk->sk_socket, SEND_SHUTDOWN);
sock_put(sk);
}
}
/* Really kill the session socket. (Called from sock_put() if
......@@ -461,6 +468,14 @@ static void pppol2tp_session_destruct(struct sock *sk)
}
}
static void pppol2tp_put_sk(struct rcu_head *head)
{
struct pppol2tp_session *ps;
ps = container_of(head, typeof(*ps), rcu);
sock_put(ps->__sk);
}
/* Called when the PPPoX socket (session) is closed.
*/
static int pppol2tp_release(struct socket *sock)
......@@ -486,11 +501,23 @@ static int pppol2tp_release(struct socket *sock)
session = pppol2tp_sock_to_session(sk);
/* Purge any queued data */
if (session != NULL) {
__l2tp_session_unhash(session);
l2tp_session_queue_purge(session);
sock_put(sk);
struct pppol2tp_session *ps;
l2tp_session_delete(session);
ps = l2tp_session_priv(session);
mutex_lock(&ps->sk_lock);
ps->__sk = rcu_dereference_protected(ps->sk,
lockdep_is_held(&ps->sk_lock));
RCU_INIT_POINTER(ps->sk, NULL);
mutex_unlock(&ps->sk_lock);
call_rcu(&ps->rcu, pppol2tp_put_sk);
/* Rely on the sock_put() call at the end of the function for
* dropping the reference held by pppol2tp_sock_to_session().
* The last reference will be dropped by pppol2tp_put_sk().
*/
}
release_sock(sk);
......@@ -557,16 +584,47 @@ static int pppol2tp_create(struct net *net, struct socket *sock, int kern)
static void pppol2tp_show(struct seq_file *m, void *arg)
{
struct l2tp_session *session = arg;
struct pppol2tp_session *ps = l2tp_session_priv(session);
struct sock *sk;
sk = pppol2tp_session_get_sock(session);
if (sk) {
struct pppox_sock *po = pppox_sk(sk);
if (ps) {
struct pppox_sock *po = pppox_sk(ps->sock);
if (po)
seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan));
seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan));
sock_put(sk);
}
}
#endif
/* Set up the PPP-specific part of a freshly created session.
 *
 * Installs the pppol2tp callbacks on @session, initialises the private
 * data (socket mutex, tunnel socket pointer, owner pid) and, when the
 * tunnel socket has a dst entry reporting a non-zero MTU, derives the
 * session MTU/MRU from it.
 */
static void pppol2tp_session_init(struct l2tp_session *session)
{
	struct pppol2tp_session *ps = l2tp_session_priv(session);
	struct dst_entry *dst;
	u32 pmtu;

	/* Session callbacks */
	session->recv_skb = pppol2tp_recv;
	session->session_close = pppol2tp_session_close;
#if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
	session->show = pppol2tp_show;
#endif

	/* Private data */
	mutex_init(&ps->sk_lock);
	ps->tunnel_sock = session->tunnel->sock;
	ps->owner = current->pid;

	/* Use the discovered path MTU, if PMTU discovery was enabled */
	dst = sk_dst_get(session->tunnel->sock);
	if (!dst)
		return;

	pmtu = dst_mtu(dst);
	if (pmtu) {
		session->mtu = pmtu - PPPOL2TP_HEADER_OVERHEAD;
		session->mru = pmtu - PPPOL2TP_HEADER_OVERHEAD;
	}

	dst_release(dst);
}
/* connect() handler. Attach a PPPoX socket to a tunnel UDP socket
*/
static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
......@@ -578,7 +636,6 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
struct l2tp_session *session = NULL;
struct l2tp_tunnel *tunnel;
struct pppol2tp_session *ps;
struct dst_entry *dst;
struct l2tp_session_cfg cfg = { 0, };
int error = 0;
u32 tunnel_id, peer_tunnel_id;
......@@ -693,13 +750,17 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
/* Using a pre-existing session is fine as long as it hasn't
* been connected yet.
*/
if (ps->sock) {
mutex_lock(&ps->sk_lock);
if (rcu_dereference_protected(ps->sk,
lockdep_is_held(&ps->sk_lock))) {
mutex_unlock(&ps->sk_lock);
error = -EEXIST;
goto end;
}
/* consistency checks */
if (ps->tunnel_sock != tunnel->sock) {
mutex_unlock(&ps->sk_lock);
error = -EEXIST;
goto end;
}
......@@ -715,35 +776,19 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
error = PTR_ERR(session);
goto end;
}
}
/* Associate session with its PPPoL2TP socket */
ps = l2tp_session_priv(session);
ps->owner = current->pid;
ps->sock = sk;
ps->tunnel_sock = tunnel->sock;
session->recv_skb = pppol2tp_recv;
session->session_close = pppol2tp_session_close;
#if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
session->show = pppol2tp_show;
#endif
/* We need to know each time a skb is dropped from the reorder
* queue.
*/
session->ref = pppol2tp_session_sock_hold;
session->deref = pppol2tp_session_sock_put;
/* If PMTU discovery was enabled, use the MTU that was discovered */
dst = sk_dst_get(tunnel->sock);
if (dst != NULL) {
u32 pmtu = dst_mtu(dst);
pppol2tp_session_init(session);
ps = l2tp_session_priv(session);
l2tp_session_inc_refcount(session);
if (pmtu != 0)
session->mtu = session->mru = pmtu -
PPPOL2TP_HEADER_OVERHEAD;
dst_release(dst);
mutex_lock(&ps->sk_lock);
error = l2tp_session_register(session, tunnel);
if (error < 0) {
mutex_unlock(&ps->sk_lock);
kfree(session);
goto end;
}
drop_refcnt = true;
}
/* Special case: if source & dest session_id == 0x0000, this
......@@ -768,12 +813,23 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
po->chan.mtu = session->mtu;
error = ppp_register_net_channel(sock_net(sk), &po->chan);
if (error)
if (error) {
mutex_unlock(&ps->sk_lock);
goto end;
}
out_no_ppp:
/* This is how we get the session context from the socket. */
sk->sk_user_data = session;
rcu_assign_pointer(ps->sk, sk);
mutex_unlock(&ps->sk_lock);
/* Keep the reference we've grabbed on the session: sk doesn't expect
* the session to disappear. pppol2tp_session_destruct() is responsible
* for dropping it.
*/
drop_refcnt = false;
sk->sk_state = PPPOX_CONNECTED;
l2tp_info(session, L2TP_MSG_CONTROL, "%s: created\n",
session->name);
......@@ -795,12 +851,11 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
{
int error;
struct l2tp_session *session;
struct pppol2tp_session *ps;
/* Error if tunnel socket is not prepped */
if (!tunnel->sock) {
error = -ENOENT;
goto out;
goto err;
}
/* Default MTU values. */
......@@ -815,18 +870,20 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
peer_session_id, cfg);
if (IS_ERR(session)) {
error = PTR_ERR(session);
goto out;
goto err;
}
ps = l2tp_session_priv(session);
ps->tunnel_sock = tunnel->sock;
pppol2tp_session_init(session);
l2tp_info(session, L2TP_MSG_CONTROL, "%s: created\n",
session->name);
error = l2tp_session_register(session, tunnel);
if (error < 0)
goto err_sess;
error = 0;
return 0;
out:
err_sess:
kfree(session);
err:
return error;
}
......@@ -987,12 +1044,10 @@ static int pppol2tp_session_ioctl(struct l2tp_session *session,
"%s: pppol2tp_session_ioctl(cmd=%#x, arg=%#lx)\n",
session->name, cmd, arg);
sk = ps->sock;
sk = pppol2tp_session_get_sock(session);
if (!sk)
return -EBADR;
sock_hold(sk);
switch (cmd) {
case SIOCGIFMTU:
err = -ENXIO;
......@@ -1268,7 +1323,6 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
int optname, int val)
{
int err = 0;
struct pppol2tp_session *ps = l2tp_session_priv(session);
switch (optname) {
case PPPOL2TP_SO_RECVSEQ:
......@@ -1289,8 +1343,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
}
session->send_seq = !!val;
{
struct sock *ssk = ps->sock;
struct pppox_sock *po = pppox_sk(ssk);
struct pppox_sock *po = pppox_sk(sk);
po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ :
PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
}
......@@ -1629,8 +1683,9 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
{
struct l2tp_session *session = v;
struct l2tp_tunnel *tunnel = session->tunnel;
struct pppol2tp_session *ps = l2tp_session_priv(session);
struct pppox_sock *po = pppox_sk(ps->sock);
unsigned char state;
char user_data_ok;
struct sock *sk;
u32 ip = 0;
u16 port = 0;
......@@ -1640,6 +1695,15 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
port = ntohs(inet->inet_sport);
}
sk = pppol2tp_session_get_sock(session);
if (sk) {
state = sk->sk_state;
user_data_ok = (session == sk->sk_user_data) ? 'Y' : 'N';
} else {
state = 0;
user_data_ok = 'N';
}
seq_printf(m, " SESSION '%s' %08X/%d %04X/%04X -> "
"%04X/%04X %d %c\n",
session->name, ip, port,
......@@ -1647,9 +1711,7 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
session->session_id,
tunnel->peer_tunnel_id,
session->peer_session_id,
ps->sock->sk_state,
(session == ps->sock->sk_user_data) ?
'Y' : 'N');
state, user_data_ok);
seq_printf(m, " %d/%d/%c/%c/%s %08x %u\n",
session->mtu, session->mru,
session->recv_seq ? 'R' : '-',
......@@ -1666,8 +1728,12 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
atomic_long_read(&session->stats.rx_bytes),
atomic_long_read(&session->stats.rx_errors));
if (po)
if (sk) {
struct pppox_sock *po = pppox_sk(sk);
seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan));
sock_put(sk);
}
}
static int pppol2tp_seq_show(struct seq_file *m, void *v)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment