o net: module refcounting for sk_alloc/sk_free

I had to move the rtnetlink_init and init_netlink calls to af_netlink init time, so that
the sk_alloc called down the rtnetlink_init callchain is done after the PF_NETLINK
net_proto_family is sock_registered, and because of that the af_netlink init function
call had to be moved to earlier by means of subsys_initcall (DaveM's suggestion).
parent 5ac8451a
...@@ -140,6 +140,9 @@ struct net_proto_family { ...@@ -140,6 +140,9 @@ struct net_proto_family {
struct module *owner; struct module *owner;
}; };
extern int net_family_get(int family);
extern void net_family_put(int family);
struct iovec; struct iovec;
extern int sock_wake_async(struct socket *sk, int how, int band); extern int sock_wake_async(struct socket *sk, int how, int band);
......
...@@ -589,8 +589,10 @@ static kmem_cache_t *sk_cachep; ...@@ -589,8 +589,10 @@ static kmem_cache_t *sk_cachep;
*/ */
struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab) struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab)
{ {
struct sock *sk; struct sock *sk = NULL;
if (!net_family_get(family))
goto out;
if (!slab) if (!slab)
slab = sk_cachep; slab = sk_cachep;
sk = kmem_cache_alloc(slab, priority); sk = kmem_cache_alloc(slab, priority);
...@@ -602,14 +604,16 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab) ...@@ -602,14 +604,16 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab)
sock_lock_init(sk); sock_lock_init(sk);
} }
sk->slab = slab; sk->slab = slab;
} } else
net_family_put(family);
out:
return sk; return sk;
} }
void sk_free(struct sock *sk) void sk_free(struct sock *sk)
{ {
struct sk_filter *filter; struct sk_filter *filter;
const int family = sk->family;
if (sk->destruct) if (sk->destruct)
sk->destruct(sk); sk->destruct(sk);
...@@ -624,6 +628,7 @@ void sk_free(struct sock *sk) ...@@ -624,6 +628,7 @@ void sk_free(struct sock *sk)
printk(KERN_DEBUG "sk_free: optmem leakage (%d bytes) detected.\n", atomic_read(&sk->omem_alloc)); printk(KERN_DEBUG "sk_free: optmem leakage (%d bytes) detected.\n", atomic_read(&sk->omem_alloc));
kmem_cache_free(sk->slab, sk); kmem_cache_free(sk->slab, sk);
net_family_put(family);
} }
void __init sk_init(void) void __init sk_init(void)
......
...@@ -1052,6 +1052,7 @@ struct proto_ops netlink_ops = { ...@@ -1052,6 +1052,7 @@ struct proto_ops netlink_ops = {
struct net_proto_family netlink_family_ops = { struct net_proto_family netlink_family_ops = {
.family = PF_NETLINK, .family = PF_NETLINK,
.create = netlink_create, .create = netlink_create,
.owner = THIS_MODULE, /* for consistency 8) */
}; };
static int __init netlink_proto_init(void) static int __init netlink_proto_init(void)
...@@ -1065,6 +1066,11 @@ static int __init netlink_proto_init(void) ...@@ -1065,6 +1066,11 @@ static int __init netlink_proto_init(void)
sock_register(&netlink_family_ops); sock_register(&netlink_family_ops);
#ifdef CONFIG_PROC_FS #ifdef CONFIG_PROC_FS
create_proc_read_entry("net/netlink", 0, 0, netlink_read_proc, NULL); create_proc_read_entry("net/netlink", 0, 0, netlink_read_proc, NULL);
#endif
/* The netlink device handler may be needed early. */
rtnetlink_init();
#ifdef CONFIG_NETLINK_DEV
init_netlink();
#endif #endif
return 0; return 0;
} }
...@@ -1075,7 +1081,7 @@ static void __exit netlink_proto_exit(void) ...@@ -1075,7 +1081,7 @@ static void __exit netlink_proto_exit(void)
remove_proc_entry("net/netlink", NULL); remove_proc_entry("net/netlink", NULL);
} }
module_init(netlink_proto_init); subsys_initcall(netlink_proto_init);
module_exit(netlink_proto_exit); module_exit(netlink_proto_exit);
MODULE_LICENSE("GPL"); MODULE_LICENSE("GPL");
...@@ -69,8 +69,6 @@ ...@@ -69,8 +69,6 @@
#include <linux/proc_fs.h> #include <linux/proc_fs.h>
#include <linux/seq_file.h> #include <linux/seq_file.h>
#include <linux/wanrouter.h> #include <linux/wanrouter.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <linux/if_bridge.h> #include <linux/if_bridge.h>
#include <linux/init.h> #include <linux/init.h>
#include <linux/poll.h> #include <linux/poll.h>
...@@ -143,6 +141,31 @@ static struct file_operations socket_file_ops = { ...@@ -143,6 +141,31 @@ static struct file_operations socket_file_ops = {
static struct net_proto_family *net_families[NPROTO]; static struct net_proto_family *net_families[NPROTO];
static __inline__ void net_family_bug(int family)
{
printk(KERN_ERR "%d is not yet sock_registered!\n", family);
BUG();
}
int net_family_get(int family)
{
int rc = 1;
if (likely(net_families[family] != NULL))
rc = try_module_get(net_families[family]->owner);
else
net_family_bug(family);
return rc;
}
void net_family_put(int family)
{
if (likely(net_families[family] != NULL))
module_put(net_families[family]->owner);
else
net_family_bug(family);
}
#if defined(CONFIG_SMP) || defined(CONFIG_PREEMPT) #if defined(CONFIG_SMP) || defined(CONFIG_PREEMPT)
static atomic_t net_family_lockct = ATOMIC_INIT(0); static atomic_t net_family_lockct = ATOMIC_INIT(0);
static spinlock_t net_family_lock = SPIN_LOCK_UNLOCKED; static spinlock_t net_family_lock = SPIN_LOCK_UNLOCKED;
...@@ -511,7 +534,7 @@ void sock_release(struct socket *sock) ...@@ -511,7 +534,7 @@ void sock_release(struct socket *sock)
sock->ops->release(sock); sock->ops->release(sock);
sock->ops = NULL; sock->ops = NULL;
module_put(net_families[family]->owner); net_family_put(family);
} }
if (sock->fasync_list) if (sock->fasync_list)
...@@ -1064,7 +1087,7 @@ int sock_create(int family, int type, int protocol, struct socket **res) ...@@ -1064,7 +1087,7 @@ int sock_create(int family, int type, int protocol, struct socket **res)
sock->type = type; sock->type = type;
i = -EBUSY; i = -EBUSY;
if (!try_module_get(net_families[family]->owner)) if (!net_family_get(family))
goto out_release; goto out_release;
if ((i = net_families[family]->create(sock, protocol)) < 0) if ((i = net_families[family]->create(sock, protocol)) < 0)
...@@ -1953,17 +1976,6 @@ void __init sock_init(void) ...@@ -1953,17 +1976,6 @@ void __init sock_init(void)
* do_initcalls is run. * do_initcalls is run.
*/ */
/*
* The netlink device handler may be needed early.
*/
#ifdef CONFIG_NET
rtnetlink_init();
#endif
#ifdef CONFIG_NETLINK_DEV
init_netlink();
#endif
#ifdef CONFIG_NETFILTER #ifdef CONFIG_NETFILTER
netfilter_init(); netfilter_init();
#endif #endif
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment