Commit 26abe143 authored by Eric W. Biederman's avatar Eric W. Biederman Committed by David S. Miller

net: Modify sk_alloc to not reference count the netns of kernel sockets.

Now that sk_alloc knows when a kernel socket is being allocated modify
it to not reference count the network namespace of kernel sockets.

Keep track of if a socket needs reference counting by adding a flag to
struct sock called sk_net_refcnt.

Update all of the callers of sock_create_kern to stop using
sk_change_net and sk_release_kernel as those hacks are no longer
needed, to avoid reference counting a kernel socket.
Signed-off-by: default avatar"Eric W. Biederman" <ebiederm@xmission.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 11aa9c28
...@@ -41,7 +41,7 @@ int inet_recv_error(struct sock *sk, struct msghdr *msg, int len, ...@@ -41,7 +41,7 @@ int inet_recv_error(struct sock *sk, struct msghdr *msg, int len,
static inline void inet_ctl_sock_destroy(struct sock *sk) static inline void inet_ctl_sock_destroy(struct sock *sk)
{ {
sk_release_kernel(sk); sock_release(sk->sk_socket);
} }
#endif #endif
...@@ -184,6 +184,7 @@ struct sock_common { ...@@ -184,6 +184,7 @@ struct sock_common {
unsigned char skc_reuse:4; unsigned char skc_reuse:4;
unsigned char skc_reuseport:1; unsigned char skc_reuseport:1;
unsigned char skc_ipv6only:1; unsigned char skc_ipv6only:1;
unsigned char skc_net_refcnt:1;
int skc_bound_dev_if; int skc_bound_dev_if;
union { union {
struct hlist_node skc_bind_node; struct hlist_node skc_bind_node;
...@@ -323,6 +324,7 @@ struct sock { ...@@ -323,6 +324,7 @@ struct sock {
#define sk_reuse __sk_common.skc_reuse #define sk_reuse __sk_common.skc_reuse
#define sk_reuseport __sk_common.skc_reuseport #define sk_reuseport __sk_common.skc_reuseport
#define sk_ipv6only __sk_common.skc_ipv6only #define sk_ipv6only __sk_common.skc_ipv6only
#define sk_net_refcnt __sk_common.skc_net_refcnt
#define sk_bound_dev_if __sk_common.skc_bound_dev_if #define sk_bound_dev_if __sk_common.skc_bound_dev_if
#define sk_bind_node __sk_common.skc_bind_node #define sk_bind_node __sk_common.skc_bind_node
#define sk_prot __sk_common.skc_prot #define sk_prot __sk_common.skc_prot
......
...@@ -1412,7 +1412,10 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority, ...@@ -1412,7 +1412,10 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
*/ */
sk->sk_prot = sk->sk_prot_creator = prot; sk->sk_prot = sk->sk_prot_creator = prot;
sock_lock_init(sk); sock_lock_init(sk);
sock_net_set(sk, get_net(net)); sk->sk_net_refcnt = kern ? 0 : 1;
if (likely(sk->sk_net_refcnt))
get_net(net);
sock_net_set(sk, net);
atomic_set(&sk->sk_wmem_alloc, 1); atomic_set(&sk->sk_wmem_alloc, 1);
sock_update_classid(sk); sock_update_classid(sk);
...@@ -1446,7 +1449,8 @@ static void __sk_free(struct sock *sk) ...@@ -1446,7 +1449,8 @@ static void __sk_free(struct sock *sk)
if (sk->sk_peer_cred) if (sk->sk_peer_cred)
put_cred(sk->sk_peer_cred); put_cred(sk->sk_peer_cred);
put_pid(sk->sk_peer_pid); put_pid(sk->sk_peer_pid);
put_net(sock_net(sk)); if (likely(sk->sk_net_refcnt))
put_net(sock_net(sk));
sk_prot_free(sk->sk_prot_creator, sk); sk_prot_free(sk->sk_prot_creator, sk);
} }
......
...@@ -1430,7 +1430,7 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family, ...@@ -1430,7 +1430,7 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family,
struct net *net) struct net *net)
{ {
struct socket *sock; struct socket *sock;
int rc = sock_create_kern(&init_net, family, type, protocol, &sock); int rc = sock_create_kern(net, family, type, protocol, &sock);
if (rc == 0) { if (rc == 0) {
*sk = sock->sk; *sk = sock->sk;
...@@ -1440,8 +1440,6 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family, ...@@ -1440,8 +1440,6 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family,
* we do not wish this socket to see incoming packets. * we do not wish this socket to see incoming packets.
*/ */
(*sk)->sk_prot->unhash(*sk); (*sk)->sk_prot->unhash(*sk);
sk_change_net(*sk, net);
} }
return rc; return rc;
} }
......
...@@ -15,12 +15,10 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, ...@@ -15,12 +15,10 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg,
struct socket *sock = NULL; struct socket *sock = NULL;
struct sockaddr_in udp_addr; struct sockaddr_in udp_addr;
err = sock_create_kern(&init_net, AF_INET, SOCK_DGRAM, 0, &sock); err = sock_create_kern(net, AF_INET, SOCK_DGRAM, 0, &sock);
if (err < 0) if (err < 0)
goto error; goto error;
sk_change_net(sock->sk, net);
udp_addr.sin_family = AF_INET; udp_addr.sin_family = AF_INET;
udp_addr.sin_addr = cfg->local_ip; udp_addr.sin_addr = cfg->local_ip;
udp_addr.sin_port = cfg->local_udp_port; udp_addr.sin_port = cfg->local_udp_port;
...@@ -47,7 +45,7 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, ...@@ -47,7 +45,7 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg,
error: error:
if (sock) { if (sock) {
kernel_sock_shutdown(sock, SHUT_RDWR); kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk); sock_release(sock);
} }
*sockp = NULL; *sockp = NULL;
return err; return err;
...@@ -101,7 +99,7 @@ void udp_tunnel_sock_release(struct socket *sock) ...@@ -101,7 +99,7 @@ void udp_tunnel_sock_release(struct socket *sock)
{ {
rcu_assign_sk_user_data(sock->sk, NULL); rcu_assign_sk_user_data(sock->sk, NULL);
kernel_sock_shutdown(sock, SHUT_RDWR); kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk); sock_release(sock);
} }
EXPORT_SYMBOL_GPL(udp_tunnel_sock_release); EXPORT_SYMBOL_GPL(udp_tunnel_sock_release);
......
...@@ -19,12 +19,10 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg, ...@@ -19,12 +19,10 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg,
int err; int err;
struct socket *sock = NULL; struct socket *sock = NULL;
err = sock_create_kern(&init_net, AF_INET6, SOCK_DGRAM, 0, &sock); err = sock_create_kern(net, AF_INET6, SOCK_DGRAM, 0, &sock);
if (err < 0) if (err < 0)
goto error; goto error;
sk_change_net(sock->sk, net);
udp6_addr.sin6_family = AF_INET6; udp6_addr.sin6_family = AF_INET6;
memcpy(&udp6_addr.sin6_addr, &cfg->local_ip6, memcpy(&udp6_addr.sin6_addr, &cfg->local_ip6,
sizeof(udp6_addr.sin6_addr)); sizeof(udp6_addr.sin6_addr));
...@@ -55,7 +53,7 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg, ...@@ -55,7 +53,7 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg,
error: error:
if (sock) { if (sock) {
kernel_sock_shutdown(sock, SHUT_RDWR); kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk); sock_release(sock);
} }
*sockp = NULL; *sockp = NULL;
return err; return err;
......
...@@ -1334,9 +1334,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work) ...@@ -1334,9 +1334,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
if (sock) if (sock)
inet_shutdown(sock, 2); inet_shutdown(sock, 2);
} else { } else {
if (sock) if (sock) {
kernel_sock_shutdown(sock, SHUT_RDWR); kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sk); sock_release(sock);
}
} }
l2tp_tunnel_sock_put(sk); l2tp_tunnel_sock_put(sk);
...@@ -1399,13 +1400,11 @@ static int l2tp_tunnel_sock_create(struct net *net, ...@@ -1399,13 +1400,11 @@ static int l2tp_tunnel_sock_create(struct net *net,
if (cfg->local_ip6 && cfg->peer_ip6) { if (cfg->local_ip6 && cfg->peer_ip6) {
struct sockaddr_l2tpip6 ip6_addr = {0}; struct sockaddr_l2tpip6 ip6_addr = {0};
err = sock_create_kern(&init_net, AF_INET6, SOCK_DGRAM, err = sock_create_kern(net, AF_INET6, SOCK_DGRAM,
IPPROTO_L2TP, &sock); IPPROTO_L2TP, &sock);
if (err < 0) if (err < 0)
goto out; goto out;
sk_change_net(sock->sk, net);
ip6_addr.l2tp_family = AF_INET6; ip6_addr.l2tp_family = AF_INET6;
memcpy(&ip6_addr.l2tp_addr, cfg->local_ip6, memcpy(&ip6_addr.l2tp_addr, cfg->local_ip6,
sizeof(ip6_addr.l2tp_addr)); sizeof(ip6_addr.l2tp_addr));
...@@ -1429,13 +1428,11 @@ static int l2tp_tunnel_sock_create(struct net *net, ...@@ -1429,13 +1428,11 @@ static int l2tp_tunnel_sock_create(struct net *net,
{ {
struct sockaddr_l2tpip ip_addr = {0}; struct sockaddr_l2tpip ip_addr = {0};
err = sock_create_kern(&init_net, AF_INET, SOCK_DGRAM, err = sock_create_kern(net, AF_INET, SOCK_DGRAM,
IPPROTO_L2TP, &sock); IPPROTO_L2TP, &sock);
if (err < 0) if (err < 0)
goto out; goto out;
sk_change_net(sock->sk, net);
ip_addr.l2tp_family = AF_INET; ip_addr.l2tp_family = AF_INET;
ip_addr.l2tp_addr = cfg->local_ip; ip_addr.l2tp_addr = cfg->local_ip;
ip_addr.l2tp_conn_id = tunnel_id; ip_addr.l2tp_conn_id = tunnel_id;
...@@ -1462,7 +1459,7 @@ static int l2tp_tunnel_sock_create(struct net *net, ...@@ -1462,7 +1459,7 @@ static int l2tp_tunnel_sock_create(struct net *net,
*sockp = sock; *sockp = sock;
if ((err < 0) && sock) { if ((err < 0) && sock) {
kernel_sock_shutdown(sock, SHUT_RDWR); kernel_sock_shutdown(sock, SHUT_RDWR);
sk_release_kernel(sock->sk); sock_release(sock);
*sockp = NULL; *sockp = NULL;
} }
......
...@@ -1457,18 +1457,12 @@ static struct socket *make_send_sock(struct net *net, int id) ...@@ -1457,18 +1457,12 @@ static struct socket *make_send_sock(struct net *net, int id)
struct socket *sock; struct socket *sock;
int result; int result;
/* First create a socket move it to right name space later */ /* First create a socket */
result = sock_create_kern(&init_net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock); result = sock_create_kern(net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);
if (result < 0) { if (result < 0) {
pr_err("Error during creation of socket; terminating\n"); pr_err("Error during creation of socket; terminating\n");
return ERR_PTR(result); return ERR_PTR(result);
} }
/*
* Kernel sockets that are a part of a namespace, should not
* hold a reference to a namespace in order to allow to stop it.
* After sk_change_net should be released using sk_release_kernel.
*/
sk_change_net(sock->sk, net);
result = set_mcast_if(sock->sk, ipvs->master_mcast_ifn); result = set_mcast_if(sock->sk, ipvs->master_mcast_ifn);
if (result < 0) { if (result < 0) {
pr_err("Error setting outbound mcast interface\n"); pr_err("Error setting outbound mcast interface\n");
...@@ -1497,7 +1491,7 @@ static struct socket *make_send_sock(struct net *net, int id) ...@@ -1497,7 +1491,7 @@ static struct socket *make_send_sock(struct net *net, int id)
return sock; return sock;
error: error:
sk_release_kernel(sock->sk); sock_release(sock);
return ERR_PTR(result); return ERR_PTR(result);
} }
...@@ -1518,17 +1512,11 @@ static struct socket *make_receive_sock(struct net *net, int id) ...@@ -1518,17 +1512,11 @@ static struct socket *make_receive_sock(struct net *net, int id)
int result; int result;
/* First create a socket */ /* First create a socket */
result = sock_create_kern(&init_net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock); result = sock_create_kern(net, PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);
if (result < 0) { if (result < 0) {
pr_err("Error during creation of socket; terminating\n"); pr_err("Error during creation of socket; terminating\n");
return ERR_PTR(result); return ERR_PTR(result);
} }
/*
* Kernel sockets that are a part of a namespace, should not
* hold a reference to a namespace in order to allow to stop it.
* After sk_change_net should be released using sk_release_kernel.
*/
sk_change_net(sock->sk, net);
/* it is equivalent to the REUSEADDR option in user-space */ /* it is equivalent to the REUSEADDR option in user-space */
sock->sk->sk_reuse = SK_CAN_REUSE; sock->sk->sk_reuse = SK_CAN_REUSE;
result = sysctl_sync_sock_size(ipvs); result = sysctl_sync_sock_size(ipvs);
...@@ -1554,7 +1542,7 @@ static struct socket *make_receive_sock(struct net *net, int id) ...@@ -1554,7 +1542,7 @@ static struct socket *make_receive_sock(struct net *net, int id)
return sock; return sock;
error: error:
sk_release_kernel(sock->sk); sock_release(sock);
return ERR_PTR(result); return ERR_PTR(result);
} }
...@@ -1692,7 +1680,7 @@ static int sync_thread_master(void *data) ...@@ -1692,7 +1680,7 @@ static int sync_thread_master(void *data)
ip_vs_sync_buff_release(sb); ip_vs_sync_buff_release(sb);
/* release the sending multicast socket */ /* release the sending multicast socket */
sk_release_kernel(tinfo->sock->sk); sock_release(tinfo->sock);
kfree(tinfo); kfree(tinfo);
return 0; return 0;
...@@ -1729,7 +1717,7 @@ static int sync_thread_backup(void *data) ...@@ -1729,7 +1717,7 @@ static int sync_thread_backup(void *data)
} }
/* release the sending multicast socket */ /* release the sending multicast socket */
sk_release_kernel(tinfo->sock->sk); sock_release(tinfo->sock);
kfree(tinfo->buf); kfree(tinfo->buf);
kfree(tinfo); kfree(tinfo);
...@@ -1854,11 +1842,11 @@ int start_sync_thread(struct net *net, int state, char *mcast_ifn, __u8 syncid) ...@@ -1854,11 +1842,11 @@ int start_sync_thread(struct net *net, int state, char *mcast_ifn, __u8 syncid)
return 0; return 0;
outsocket: outsocket:
sk_release_kernel(sock->sk); sock_release(sock);
outtinfo: outtinfo:
if (tinfo) { if (tinfo) {
sk_release_kernel(tinfo->sock->sk); sock_release(tinfo->sock);
kfree(tinfo->buf); kfree(tinfo->buf);
kfree(tinfo); kfree(tinfo);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment