o net: improve the current module infrastructure

As per discussions in netdev we'll probably be moving to a brand new scheme, but this
set of changesets have been discussed and are an improvement to the current situation
and were already done prior to this thread happening.
parent f48b1d88
...@@ -89,9 +89,11 @@ struct page; ...@@ -89,9 +89,11 @@ struct page;
struct kiocb; struct kiocb;
struct sockaddr; struct sockaddr;
struct msghdr; struct msghdr;
struct module;
struct proto_ops { struct proto_ops {
int family; int family;
struct module *owner;
int (*release) (struct socket *sock); int (*release) (struct socket *sock);
int (*bind) (struct socket *sock, int (*bind) (struct socket *sock,
struct sockaddr *umyaddr, struct sockaddr *umyaddr,
...@@ -127,8 +129,6 @@ struct proto_ops { ...@@ -127,8 +129,6 @@ struct proto_ops {
int offset, size_t size, int flags); int offset, size_t size, int flags);
}; };
struct module;
struct net_proto_family { struct net_proto_family {
int family; int family;
int (*create)(struct socket *sock, int protocol); int (*create)(struct socket *sock, int protocol);
...@@ -140,9 +140,6 @@ struct net_proto_family { ...@@ -140,9 +140,6 @@ struct net_proto_family {
struct module *owner; struct module *owner;
}; };
extern int net_family_get(int family);
extern void net_family_put(int family);
struct iovec; struct iovec;
extern int sock_wake_async(struct socket *sk, int how, int band); extern int sock_wake_async(struct socket *sk, int how, int band);
...@@ -227,7 +224,7 @@ SOCKCALL_WRAP(name, mmap, (struct file *file, struct socket *sock, struct vm_are ...@@ -227,7 +224,7 @@ SOCKCALL_WRAP(name, mmap, (struct file *file, struct socket *sock, struct vm_are
\ \
static struct proto_ops name##_ops = { \ static struct proto_ops name##_ops = { \
.family = fam, \ .family = fam, \
\ .owner = THIS_MODULE, \
.release = __lock_##name##_release, \ .release = __lock_##name##_release, \
.bind = __lock_##name##_bind, \ .bind = __lock_##name##_bind, \
.connect = __lock_##name##_connect, \ .connect = __lock_##name##_connect, \
......
...@@ -43,7 +43,7 @@ ...@@ -43,7 +43,7 @@
#include <linux/config.h> #include <linux/config.h>
#include <linux/timer.h> #include <linux/timer.h>
#include <linux/cache.h> #include <linux/cache.h>
#include <linux/module.h>
#include <linux/netdevice.h> #include <linux/netdevice.h>
#include <linux/skbuff.h> /* struct sk_buff */ #include <linux/skbuff.h> /* struct sk_buff */
#include <linux/security.h> #include <linux/security.h>
...@@ -197,6 +197,7 @@ struct sock { ...@@ -197,6 +197,7 @@ struct sock {
void *user_data; void *user_data;
/* Callbacks */ /* Callbacks */
struct module *owner;
void (*state_change)(struct sock *sk); void (*state_change)(struct sock *sk);
void (*data_ready)(struct sock *sk,int bytes); void (*data_ready)(struct sock *sk,int bytes);
void (*write_space)(struct sock *sk); void (*write_space)(struct sock *sk);
...@@ -270,6 +271,23 @@ struct proto { ...@@ -270,6 +271,23 @@ struct proto {
} stats[NR_CPUS]; } stats[NR_CPUS];
}; };
static __inline__ void sk_set_owner(struct sock *sk, struct module *owner)
{
/*
* One should use sk_set_owner just once, after struct sock creation,
* be it shortly after sk_alloc or after a function that returns a new
* struct sock (and that down the call chain called sk_alloc), e.g. the
* IPv4 and IPv6 modules share tcp_create_openreq_child, so if
* tcp_create_openreq_child called sk_set_owner IPv6 would have to
* change the ownership of this struct sock, with one not needed
* transient sk_set_owner call.
*/
if (unlikely(sk->owner != NULL))
BUG();
sk->owner = owner;
__module_get(owner);
}
/* Called with local bh disabled */ /* Called with local bh disabled */
static __inline__ void sock_prot_inc_use(struct proto *prot) static __inline__ void sock_prot_inc_use(struct proto *prot)
{ {
......
...@@ -591,8 +591,6 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab) ...@@ -591,8 +591,6 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab)
{ {
struct sock *sk = NULL; struct sock *sk = NULL;
if (!net_family_get(family))
goto out;
if (!slab) if (!slab)
slab = sk_cachep; slab = sk_cachep;
sk = kmem_cache_alloc(slab, priority); sk = kmem_cache_alloc(slab, priority);
...@@ -604,16 +602,14 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab) ...@@ -604,16 +602,14 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab)
sock_lock_init(sk); sock_lock_init(sk);
} }
sk->slab = slab; sk->slab = slab;
} else }
net_family_put(family);
out:
return sk; return sk;
} }
void sk_free(struct sock *sk) void sk_free(struct sock *sk)
{ {
struct sk_filter *filter; struct sk_filter *filter;
const int family = sk->family; struct module *owner = sk->owner;
if (sk->destruct) if (sk->destruct)
sk->destruct(sk); sk->destruct(sk);
...@@ -628,7 +624,7 @@ void sk_free(struct sock *sk) ...@@ -628,7 +624,7 @@ void sk_free(struct sock *sk)
printk(KERN_DEBUG "sk_free: optmem leakage (%d bytes) detected.\n", atomic_read(&sk->omem_alloc)); printk(KERN_DEBUG "sk_free: optmem leakage (%d bytes) detected.\n", atomic_read(&sk->omem_alloc));
kmem_cache_free(sk->slab, sk); kmem_cache_free(sk->slab, sk);
net_family_put(family); module_put(owner);
} }
void __init sk_init(void) void __init sk_init(void)
...@@ -1112,6 +1108,7 @@ void sock_init_data(struct socket *sock, struct sock *sk) ...@@ -1112,6 +1108,7 @@ void sock_init_data(struct socket *sock, struct sock *sk)
sk->rcvlowat = 1; sk->rcvlowat = 1;
sk->rcvtimeo = MAX_SCHEDULE_TIMEOUT; sk->rcvtimeo = MAX_SCHEDULE_TIMEOUT;
sk->sndtimeo = MAX_SCHEDULE_TIMEOUT; sk->sndtimeo = MAX_SCHEDULE_TIMEOUT;
sk->owner = NULL;
atomic_set(&sk->refcnt, 1); atomic_set(&sk->refcnt, 1);
} }
...@@ -141,36 +141,6 @@ static struct file_operations socket_file_ops = { ...@@ -141,36 +141,6 @@ static struct file_operations socket_file_ops = {
static struct net_proto_family *net_families[NPROTO]; static struct net_proto_family *net_families[NPROTO];
static __inline__ void net_family_bug(int family)
{
printk(KERN_ERR "%d is not yet sock_registered!\n", family);
BUG();
}
int net_family_get(int family)
{
struct net_proto_family *prot = net_families[family];
int rc = 1;
barrier();
if (likely(prot != NULL))
rc = try_module_get(prot->owner);
else
net_family_bug(family);
return rc;
}
void net_family_put(int family)
{
struct net_proto_family *prot = net_families[family];
barrier();
if (likely(prot != NULL))
module_put(prot->owner);
else
net_family_bug(family);
}
#if defined(CONFIG_SMP) || defined(CONFIG_PREEMPT) #if defined(CONFIG_SMP) || defined(CONFIG_PREEMPT)
static atomic_t net_family_lockct = ATOMIC_INIT(0); static atomic_t net_family_lockct = ATOMIC_INIT(0);
static spinlock_t net_family_lock = SPIN_LOCK_UNLOCKED; static spinlock_t net_family_lock = SPIN_LOCK_UNLOCKED;
...@@ -535,11 +505,11 @@ struct file_operations bad_sock_fops = { ...@@ -535,11 +505,11 @@ struct file_operations bad_sock_fops = {
void sock_release(struct socket *sock) void sock_release(struct socket *sock)
{ {
if (sock->ops) { if (sock->ops) {
const int family = sock->ops->family; struct module *owner = sock->ops->owner;
sock->ops->release(sock); sock->ops->release(sock);
sock->ops = NULL; sock->ops = NULL;
net_family_put(family); module_put(owner);
} }
if (sock->fasync_list) if (sock->fasync_list)
...@@ -1091,19 +1061,37 @@ int sock_create(int family, int type, int protocol, struct socket **res) ...@@ -1091,19 +1061,37 @@ int sock_create(int family, int type, int protocol, struct socket **res)
sock->type = type; sock->type = type;
/*
* We will call the ->create function, that possibly is in a loadable
* module, so we have to bump that loadable module refcnt first.
*/
i = -EAFNOSUPPORT; i = -EAFNOSUPPORT;
if (!net_family_get(family)) if (!try_module_get(net_families[family]->owner))
goto out_release; goto out_release;
if ((i = net_families[family]->create(sock, protocol)) < 0) if ((i = net_families[family]->create(sock, protocol)) < 0)
goto out_release; goto out_module_put;
/*
* Now to bump the refcnt of the [loadable] module that owns this
* socket at sock_release time we decrement its refcnt.
*/
if (!try_module_get(sock->ops->owner)) {
sock->ops = NULL;
goto out_module_put;
}
/*
* Now that we're done with the ->create function, the [loadable]
* module can have its refcnt decremented
*/
module_put(net_families[family]->owner);
*res = sock; *res = sock;
security_socket_post_create(sock, family, type, protocol); security_socket_post_create(sock, family, type, protocol);
out: out:
net_family_read_unlock(); net_family_read_unlock();
return i; return i;
out_module_put:
module_put(net_families[family]->owner);
out_release: out_release:
sock_release(sock); sock_release(sock);
goto out; goto out;
...@@ -1288,28 +1276,30 @@ asmlinkage long sys_accept(int fd, struct sockaddr *upeer_sockaddr, int *upeer_a ...@@ -1288,28 +1276,30 @@ asmlinkage long sys_accept(int fd, struct sockaddr *upeer_sockaddr, int *upeer_a
if (err) if (err)
goto out_release; goto out_release;
err = -EAFNOSUPPORT; /*
if (!net_family_get(sock->ops->family)) * We don't need try_module_get here, as the listening socket (sock)
goto out_release; * has the protocol module (sock->ops->owner) held.
*/
__module_get(sock->ops->owner);
err = sock->ops->accept(sock, newsock, sock->file->f_flags); err = sock->ops->accept(sock, newsock, sock->file->f_flags);
if (err < 0) if (err < 0)
goto out_family_put; goto out_module_put;
if (upeer_sockaddr) { if (upeer_sockaddr) {
if(newsock->ops->getname(newsock, (struct sockaddr *)address, &len, 2)<0) { if(newsock->ops->getname(newsock, (struct sockaddr *)address, &len, 2)<0) {
err = -ECONNABORTED; err = -ECONNABORTED;
goto out_family_put; goto out_module_put;
} }
err = move_addr_to_user(address, len, upeer_sockaddr, upeer_addrlen); err = move_addr_to_user(address, len, upeer_sockaddr, upeer_addrlen);
if (err < 0) if (err < 0)
goto out_family_put; goto out_module_put;
} }
/* File flags are not inherited via accept() unlike another OSes. */ /* File flags are not inherited via accept() unlike another OSes. */
if ((err = sock_map_fd(newsock)) < 0) if ((err = sock_map_fd(newsock)) < 0)
goto out_family_put; goto out_module_put;
security_socket_post_accept(sock, newsock); security_socket_post_accept(sock, newsock);
...@@ -1317,8 +1307,8 @@ asmlinkage long sys_accept(int fd, struct sockaddr *upeer_sockaddr, int *upeer_a ...@@ -1317,8 +1307,8 @@ asmlinkage long sys_accept(int fd, struct sockaddr *upeer_sockaddr, int *upeer_a
sockfd_put(sock); sockfd_put(sock);
out: out:
return err; return err;
out_family_put: out_module_put:
net_family_put(sock->ops->family); module_put(sock->ops->owner);
out_release: out_release:
sock_release(newsock); sock_release(newsock);
goto out_put; goto out_put;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment