Commit 26abe143 authored by Eric W. Biederman's avatar Eric W. Biederman Committed by David S. Miller

net: Modify sk_alloc to not reference count the netns of kernel sockets.

Now that sk_alloc knows when a kernel socket is being allocated modify
it to not reference count the network namespace of kernel sockets.

Keep track of if a socket needs reference counting by adding a flag to
struct sock called sk_net_refcnt.

Update all of the callers of sock_create_kern to stop using
sk_change_net and sk_release_kernel as those hacks are no longer
needed, to avoid reference counting a kernel socket.
Signed-off-by: default avatar"Eric W. Biederman" <ebiederm@xmission.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 11aa9c28
......@@ -41,7 +41,7 @@ int inet_recv_error(struct sock *sk, struct msghdr *msg, int len,
static inline void inet_ctl_sock_destroy(struct sock *sk)
{
sk_release_kernel(sk);
sock_release(sk->sk_socket);
}
#endif
......@@ -184,6 +184,7 @@ struct sock_common {
unsigned char skc_reuse:4;
unsigned char skc_reuseport:1;
unsigned char skc_ipv6only:1;
unsigned char skc_net_refcnt:1;
int skc_bound_dev_if;
union {
struct hlist_node skc_bind_node;
......@@ -323,6 +324,7 @@ struct sock {
#define sk_reuse __sk_common.skc_reuse
#define sk_reuseport __sk_common.skc_reuseport
#define sk_ipv6only __sk_common.skc_ipv6only
#define sk_net_refcnt __sk_common.skc_net_refcnt
#define sk_bound_dev_if __sk_common.skc_bound_dev_if
#define sk_bind_node __sk_common.skc_bind_node
#define sk_prot __sk_common.skc_prot
......
......@@ -1412,7 +1412,10 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
*/
sk->sk_prot = sk->sk_prot_creator = prot;
sock_lock_init(sk);
sock_net_set(sk, get_net(net));
sk->sk_net_refcnt = kern ? 0 : 1;
if (likely(sk->sk_net_refcnt))
get_net(net);
sock_net_set(sk, net);
atomic_set(&sk->sk_wmem_alloc, 1);
sock_update_classid(sk);
......@@ -1446,7 +1449,8 @@ static void __sk_free(struct sock *sk)
if (sk->sk_peer_cred)
put_cred(sk->sk_peer_cred);
put_pid(sk->sk_peer_pid);
put_net(sock_net(sk));
if (likely(sk->sk_net_refcnt))
put_net(sock_net(sk));
sk_prot_free(sk->sk_prot_creator, sk);
}
......
......@@ -1430,7 +1430,7 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family,
struct net *net)
{
struct socket *sock;
int rc = sock_create_kern(&init_net, family, type, protocol, &sock);
int rc = sock_create_kern(net, family, type, protocol, &sock);
if (rc == 0) {
*sk = sock->sk;
......@@ -1440,8 +1440,6 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family,
* we do not wish this socket to see incoming packets.
*/
(*sk)->sk_prot->unhash(*sk);
sk_change_net(*sk, net);
}
return rc;
}
......
......@@ -15,12 +15,10 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg,
struct socket *sock = NULL;
struct sockaddr_in udp_addr;
err = sock_create_kern(&init_net, AF_INET, SOCK_DGRAM, 0, &sock);
err = sock_create_kern(net, AF_INET, SOCK_DGRAM, 0, &sock);
if (err < 0)
goto error;
sk_change_net(sock->sk, net);
udp_addr.sin_family = AF_INET;
udp_addr.sin_addr = cfg->local_ip;
udp_addr.sin_port = cfg->local_udp_port;
......@@ -47,7 +45,7 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg,
error:
if (sock) {
kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk);
sock_release(sock);
}
*sockp = NULL;
return err;
......@@ -101,7 +99,7 @@ void udp_tunnel_sock_release(struct socket *sock)
{
rcu_assign_sk_user_data(sock->sk, NULL);
kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk);
sock_release(sock);
}
EXPORT_SYMBOL_GPL(udp_tunnel_sock_release);
......
......@@ -19,12 +19,10 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg,
int err;
struct socket *sock = NULL;
err = sock_create_kern(&init_net, AF_INET6, SOCK_DGRAM, 0, &sock);
err = sock_create_kern(net, AF_INET6, SOCK_DGRAM, 0, &sock);
if (err < 0)
goto error;
sk_change_net(sock->sk, net);
udp6_addr.sin6_family = AF_INET6;
memcpy(&udp6_addr.sin6_addr, &cfg->local_ip6,
sizeof(udp6_addr.sin6_addr));
......@@ -55,7 +53,7 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg,
error:
if (sock) {
kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk);
sock_release(sock);
}
*sockp = NULL;
return err;
......
......@@ -1334,9 +1334,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
if (sock)
inet_shutdown(sock, 2);
} else {
if (sock)
if (sock) {
kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sk);
sock_release(sock);
}
}
l2tp_tunnel_sock_put(sk);
......@@ -1399,13 +1400,11 @@ static int l2tp_tunnel_sock_create(struct net *net,
if (cfg->local_ip6 && cfg->peer_ip6) {
struct sockaddr_l2tpip6 ip6_addr = {0};
err = sock_create_kern(&init_net, AF_INET6, SOCK_DGRAM,
err = sock_create_kern(net, AF_INET6, SOCK_DGRAM,
IPPROTO_L2TP, &sock);
if (err < 0)
goto out;
sk_change_net(sock->sk, net);
ip6_addr.l2tp_family = AF_INET6;
memcpy(&ip6_addr.l2tp_addr, cfg->local_ip6,
sizeof(ip6_addr.l2tp_addr));
......@@ -1429,13 +1428,11 @@ static int l2tp_tunnel_sock_create(struct net *net,
{
struct sockaddr_l2tpip ip_addr = {0};
err = sock_create_kern(&init_net, AF_INET, SOCK_DGRAM,
err = sock_create_kern(net, AF_INET, SOCK_DGRAM,
IPPROTO_L2TP, &sock);
if (err < 0)
goto out;
sk_change_net(sock->sk, net);
ip_addr.l2tp_family = AF_INET;
ip_addr.l2tp_addr = cfg->local_ip;
ip_addr.l2tp_conn_id = tunnel_id;
......@@ -1462,7 +1459,7 @@ static int l2tp_tunnel_sock_create(struct net *net,
*sockp = sock;
if ((err < 0) && sock) {
kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk);
sock_release(sock);
*sockp = NULL;
}
......
......@@ -1457,18 +1457,12 @@ static struct socket *make_send_sock(struct net *net, int id)
struct socket *sock;
int result;
/* First create a socket move it to right name space later */
result = sock_create_kern(&init_net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);
/* First create a socket */
result = sock_create_kern(net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);
if (result < 0) {
pr_err("Error during creation of socket; terminating\n");
return ERR_PTR(result);
}
/*
* Kernel sockets that are a part of a namespace, should not
* hold a reference to a namespace in order to allow to stop it.
* After sk_change_net should be released using sk_release_kernel.
*/
sk_change_net(sock->sk, net);
result = set_mcast_if(sock->sk, ipvs->master_mcast_ifn);
if (result < 0) {
pr_err("Error setting outbound mcast interface\n");
......@@ -1497,7 +1491,7 @@ static struct socket *make_send_sock(struct net *net, int id)
return sock;
error:
sk_release_kernel(sock->sk);
sock_release(sock);
return ERR_PTR(result);
}
......@@ -1518,17 +1512,11 @@ static struct socket *make_receive_sock(struct net *net, int id)
int result;
/* First create a socket */
result = sock_create_kern(&init_net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);
result = sock_create_kern(net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);
if (result < 0) {
pr_err("Error during creation of socket; terminating\n");
return ERR_PTR(result);
}
/*
* Kernel sockets that are a part of a namespace, should not
* hold a reference to a namespace in order to allow to stop it.
* After sk_change_net should be released using sk_release_kernel.
*/
sk_change_net(sock->sk, net);
/* it is equivalent to the REUSEADDR option in user-space */
sock->sk->sk_reuse = SK_CAN_REUSE;
result = sysctl_sync_sock_size(ipvs);
......@@ -1554,7 +1542,7 @@ static struct socket *make_receive_sock(struct net *net, int id)
return sock;
error:
sk_release_kernel(sock->sk);
sock_release(sock);
return ERR_PTR(result);
}
......@@ -1692,7 +1680,7 @@ static int sync_thread_master(void *data)
ip_vs_sync_buff_release(sb);
/* release the sending multicast socket */
sk_release_kernel(tinfo->sock->sk);
sock_release(tinfo->sock);
kfree(tinfo);
return 0;
......@@ -1729,7 +1717,7 @@ static int sync_thread_backup(void *data)
}
/* release the sending multicast socket */
sk_release_kernel(tinfo->sock->sk);
sock_release(tinfo->sock);
kfree(tinfo->buf);
kfree(tinfo);
......@@ -1854,11 +1842,11 @@ int start_sync_thread(struct net *net, int state, char *mcast_ifn, __u8 syncid)
return 0;
outsocket:
sk_release_kernel(sock->sk);
sock_release(sock);
outtinfo:
if (tinfo) {
sk_release_kernel(tinfo->sock->sk);
sock_release(tinfo->sock);
kfree(tinfo->buf);
kfree(tinfo);
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment