Commit 4b7761b4 authored by Trond Myklebust's avatar Trond Myklebust

[PATCH] Cleanup for SunRPC auth code

Converts the RPC client auth code to use 'list_head' rather than a
custom pointer scheme.

Fixes a (relatively harmless) race which could cause several cred
entries to be created for the same user.
parent 0c5bb195
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
* Client user credentials * Client user credentials
*/ */
struct rpc_cred { struct rpc_cred {
struct rpc_cred * cr_next; /* linked list */ struct list_head cr_hash; /* hash chain */
struct rpc_auth * cr_auth; struct rpc_auth * cr_auth;
struct rpc_credops * cr_ops; struct rpc_credops * cr_ops;
unsigned long cr_expire; /* when to gc */ unsigned long cr_expire; /* when to gc */
...@@ -49,7 +49,7 @@ struct rpc_cred { ...@@ -49,7 +49,7 @@ struct rpc_cred {
#define RPC_CREDCACHE_NR 8 #define RPC_CREDCACHE_NR 8
#define RPC_CREDCACHE_MASK (RPC_CREDCACHE_NR - 1) #define RPC_CREDCACHE_MASK (RPC_CREDCACHE_NR - 1)
struct rpc_auth { struct rpc_auth {
struct rpc_cred * au_credcache[RPC_CREDCACHE_NR]; struct list_head au_credcache[RPC_CREDCACHE_NR];
unsigned long au_expire; /* cache expiry interval */ unsigned long au_expire; /* cache expiry interval */
unsigned long au_nextgc; /* next garbage collection */ unsigned long au_nextgc; /* next garbage collection */
unsigned int au_cslack; /* call cred size estimate */ unsigned int au_cslack; /* call cred size estimate */
...@@ -101,8 +101,6 @@ struct rpc_cred * rpcauth_bindcred(struct rpc_task *); ...@@ -101,8 +101,6 @@ struct rpc_cred * rpcauth_bindcred(struct rpc_task *);
void rpcauth_holdcred(struct rpc_task *); void rpcauth_holdcred(struct rpc_task *);
void put_rpccred(struct rpc_cred *); void put_rpccred(struct rpc_cred *);
void rpcauth_unbindcred(struct rpc_task *); void rpcauth_unbindcred(struct rpc_task *);
int rpcauth_matchcred(struct rpc_auth *,
struct rpc_cred *, int);
u32 * rpcauth_marshcred(struct rpc_task *, u32 *); u32 * rpcauth_marshcred(struct rpc_task *, u32 *);
u32 * rpcauth_checkverf(struct rpc_task *, u32 *); u32 * rpcauth_checkverf(struct rpc_task *, u32 *);
int rpcauth_refreshcred(struct rpc_task *); int rpcauth_refreshcred(struct rpc_task *);
...@@ -110,8 +108,6 @@ void rpcauth_invalcred(struct rpc_task *); ...@@ -110,8 +108,6 @@ void rpcauth_invalcred(struct rpc_task *);
int rpcauth_uptodatecred(struct rpc_task *); int rpcauth_uptodatecred(struct rpc_task *);
void rpcauth_init_credcache(struct rpc_auth *); void rpcauth_init_credcache(struct rpc_auth *);
void rpcauth_free_credcache(struct rpc_auth *); void rpcauth_free_credcache(struct rpc_auth *);
void rpcauth_insert_credcache(struct rpc_auth *,
struct rpc_cred *);
static inline static inline
struct rpc_cred * get_rpccred(struct rpc_cred *cred) struct rpc_cred * get_rpccred(struct rpc_cred *cred)
......
...@@ -75,7 +75,9 @@ static spinlock_t rpc_credcache_lock = SPIN_LOCK_UNLOCKED; ...@@ -75,7 +75,9 @@ static spinlock_t rpc_credcache_lock = SPIN_LOCK_UNLOCKED;
void void
rpcauth_init_credcache(struct rpc_auth *auth) rpcauth_init_credcache(struct rpc_auth *auth)
{ {
memset(auth->au_credcache, 0, sizeof(auth->au_credcache)); int i;
for (i = 0; i < RPC_CREDCACHE_NR; i++)
INIT_LIST_HEAD(&auth->au_credcache[i]);
auth->au_nextgc = jiffies + (auth->au_expire >> 1); auth->au_nextgc = jiffies + (auth->au_expire >> 1);
} }
...@@ -86,11 +88,10 @@ static inline void ...@@ -86,11 +88,10 @@ static inline void
rpcauth_crdestroy(struct rpc_cred *cred) rpcauth_crdestroy(struct rpc_cred *cred)
{ {
#ifdef RPC_DEBUG #ifdef RPC_DEBUG
if (cred->cr_magic != RPCAUTH_CRED_MAGIC) BUG_ON(cred->cr_magic != RPCAUTH_CRED_MAGIC ||
BUG(); atomic_read(&cred->cr_count) ||
!list_empty(&cred->cr_hash));
cred->cr_magic = 0; cred->cr_magic = 0;
if (atomic_read(&cred->cr_count) || cred->cr_auth)
BUG();
#endif #endif
cred->cr_ops->crdestroy(cred); cred->cr_ops->crdestroy(cred);
} }
...@@ -99,12 +100,13 @@ rpcauth_crdestroy(struct rpc_cred *cred) ...@@ -99,12 +100,13 @@ rpcauth_crdestroy(struct rpc_cred *cred)
* Destroy a list of credentials * Destroy a list of credentials
*/ */
static inline static inline
void rpcauth_destroy_credlist(struct rpc_cred *head) void rpcauth_destroy_credlist(struct list_head *head)
{ {
struct rpc_cred *cred; struct rpc_cred *cred;
while ((cred = head) != NULL) { while (!list_empty(head)) {
head = cred->cr_next; cred = list_entry(head->next, struct rpc_cred, cr_hash);
list_del_init(&cred->cr_hash);
rpcauth_crdestroy(cred); rpcauth_crdestroy(cred);
} }
} }
...@@ -116,137 +118,117 @@ void rpcauth_destroy_credlist(struct rpc_cred *head) ...@@ -116,137 +118,117 @@ void rpcauth_destroy_credlist(struct rpc_cred *head)
void void
rpcauth_free_credcache(struct rpc_auth *auth) rpcauth_free_credcache(struct rpc_auth *auth)
{ {
struct rpc_cred **q, *cred, *free = NULL; LIST_HEAD(free);
struct list_head *pos, *next;
struct rpc_cred *cred;
int i; int i;
spin_lock(&rpc_credcache_lock); spin_lock(&rpc_credcache_lock);
for (i = 0; i < RPC_CREDCACHE_NR; i++) { for (i = 0; i < RPC_CREDCACHE_NR; i++) {
q = &auth->au_credcache[i]; list_for_each_safe(pos, next, &auth->au_credcache[i]) {
while ((cred = *q) != NULL) { cred = list_entry(pos, struct rpc_cred, cr_hash);
*q = cred->cr_next;
cred->cr_auth = NULL; cred->cr_auth = NULL;
if (atomic_read(&cred->cr_count) == 0) { list_del_init(&cred->cr_hash);
cred->cr_next = free; if (atomic_read(&cred->cr_count) == 0)
free = cred; list_add(&cred->cr_hash, &free);
} else
cred->cr_next = NULL;
} }
} }
spin_unlock(&rpc_credcache_lock); spin_unlock(&rpc_credcache_lock);
rpcauth_destroy_credlist(free); rpcauth_destroy_credlist(&free);
}
static inline int
rpcauth_prune_expired(struct rpc_cred *cred, struct list_head *free)
{
if (atomic_read(&cred->cr_count) != 0)
return 0;
if (time_before(jiffies, cred->cr_expire))
return 0;
cred->cr_auth = NULL;
list_del(&cred->cr_hash);
list_add(&cred->cr_hash, free);
return 1;
} }
/* /*
* Remove stale credentials. Avoid sleeping inside the loop. * Remove stale credentials. Avoid sleeping inside the loop.
*/ */
static void static void
rpcauth_gc_credcache(struct rpc_auth *auth) rpcauth_gc_credcache(struct rpc_auth *auth, struct list_head *free)
{ {
struct rpc_cred **q, *cred, *free = NULL; struct list_head *pos, *next;
struct rpc_cred *cred;
int i; int i;
dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth); dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth);
spin_lock(&rpc_credcache_lock);
for (i = 0; i < RPC_CREDCACHE_NR; i++) { for (i = 0; i < RPC_CREDCACHE_NR; i++) {
q = &auth->au_credcache[i]; list_for_each_safe(pos, next, &auth->au_credcache[i]) {
while ((cred = *q) != NULL) { cred = list_entry(pos, struct rpc_cred, cr_hash);
if (!atomic_read(&cred->cr_count) && rpcauth_prune_expired(cred, free);
time_before(cred->cr_expire, jiffies)) {
*q = cred->cr_next;
cred->cr_auth = NULL;
cred->cr_next = free;
free = cred;
continue;
}
q = &cred->cr_next;
} }
} }
spin_unlock(&rpc_credcache_lock);
rpcauth_destroy_credlist(free);
auth->au_nextgc = jiffies + auth->au_expire; auth->au_nextgc = jiffies + auth->au_expire;
} }
/*
* Insert credential into cache
*/
void
rpcauth_insert_credcache(struct rpc_auth *auth, struct rpc_cred *cred)
{
int nr;
nr = (cred->cr_uid & RPC_CREDCACHE_MASK);
spin_lock(&rpc_credcache_lock);
cred->cr_next = auth->au_credcache[nr];
auth->au_credcache[nr] = cred;
cred->cr_auth = auth;
get_rpccred(cred);
spin_unlock(&rpc_credcache_lock);
}
/* /*
* Look up a process' credentials in the authentication cache * Look up a process' credentials in the authentication cache
*/ */
static struct rpc_cred * static struct rpc_cred *
rpcauth_lookup_credcache(struct rpc_auth *auth, int taskflags) rpcauth_lookup_credcache(struct rpc_auth *auth, int taskflags)
{ {
struct rpc_cred **q, *cred = NULL; LIST_HEAD(free);
struct list_head *pos, *next;
struct rpc_cred *new = NULL,
*cred = NULL;
int nr = 0; int nr = 0;
if (!(taskflags & RPC_TASK_ROOTCREDS)) if (!(taskflags & RPC_TASK_ROOTCREDS))
nr = current->uid & RPC_CREDCACHE_MASK; nr = current->uid & RPC_CREDCACHE_MASK;
retry:
if (time_before(auth->au_nextgc, jiffies))
rpcauth_gc_credcache(auth);
spin_lock(&rpc_credcache_lock); spin_lock(&rpc_credcache_lock);
q = &auth->au_credcache[nr]; if (time_before(auth->au_nextgc, jiffies))
while ((cred = *q) != NULL) { rpcauth_gc_credcache(auth, &free);
if (!(cred->cr_flags & RPCAUTH_CRED_DEAD) && list_for_each_safe(pos, next, &auth->au_credcache[nr]) {
cred->cr_ops->crmatch(cred, taskflags)) { struct rpc_cred *entry;
*q = cred->cr_next; entry = list_entry(pos, struct rpc_cred, cr_hash);
if (entry->cr_flags & RPCAUTH_CRED_DEAD)
continue;
if (rpcauth_prune_expired(entry, &free))
continue;
if (entry->cr_ops->crmatch(entry, taskflags)) {
list_del(&entry->cr_hash);
cred = entry;
break; break;
} }
q = &cred->cr_next; }
if (new) {
if (cred)
list_add(&new->cr_hash, &free);
else
cred = new;
}
if (cred) {
list_add(&cred->cr_hash, &auth->au_credcache[nr]);
cred->cr_auth = auth;
get_rpccred(cred);
} }
spin_unlock(&rpc_credcache_lock); spin_unlock(&rpc_credcache_lock);
rpcauth_destroy_credlist(&free);
if (!cred) { if (!cred) {
cred = auth->au_ops->crcreate(taskflags); new = auth->au_ops->crcreate(taskflags);
if (new) {
#ifdef RPC_DEBUG #ifdef RPC_DEBUG
if (cred) new->cr_magic = RPCAUTH_CRED_MAGIC;
cred->cr_magic = RPCAUTH_CRED_MAGIC;
#endif #endif
goto retry;
}
} }
if (cred)
rpcauth_insert_credcache(auth, cred);
return (struct rpc_cred *) cred; return (struct rpc_cred *) cred;
} }
/*
* Remove cred handle from cache
*/
static void
rpcauth_remove_credcache(struct rpc_cred *cred)
{
struct rpc_auth *auth = cred->cr_auth;
struct rpc_cred **q, *cr;
int nr;
nr = (cred->cr_uid & RPC_CREDCACHE_MASK);
q = &auth->au_credcache[nr];
while ((cr = *q) != NULL) {
if (cred == cr) {
*q = cred->cr_next;
cred->cr_next = NULL;
cred->cr_auth = NULL;
break;
}
q = &cred->cr_next;
}
}
struct rpc_cred * struct rpc_cred *
rpcauth_lookupcred(struct rpc_auth *auth, int taskflags) rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
{ {
...@@ -268,14 +250,6 @@ rpcauth_bindcred(struct rpc_task *task) ...@@ -268,14 +250,6 @@ rpcauth_bindcred(struct rpc_task *task)
return task->tk_msg.rpc_cred; return task->tk_msg.rpc_cred;
} }
int
rpcauth_matchcred(struct rpc_auth *auth, struct rpc_cred *cred, int taskflags)
{
dprintk("RPC: matching %s cred %d\n",
auth->au_ops->au_name, taskflags);
return cred->cr_ops->crmatch(cred, taskflags);
}
void void
rpcauth_holdcred(struct rpc_task *task) rpcauth_holdcred(struct rpc_task *task)
{ {
...@@ -291,10 +265,10 @@ put_rpccred(struct rpc_cred *cred) ...@@ -291,10 +265,10 @@ put_rpccred(struct rpc_cred *cred)
if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock)) if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
return; return;
if (cred->cr_auth && cred->cr_flags & RPCAUTH_CRED_DEAD) if ((cred->cr_flags & RPCAUTH_CRED_DEAD) && !list_empty(&cred->cr_hash))
rpcauth_remove_credcache(cred); list_del_init(&cred->cr_hash);
if (!cred->cr_auth) { if (list_empty(&cred->cr_hash)) {
spin_unlock(&rpc_credcache_lock); spin_unlock(&rpc_credcache_lock);
rpcauth_crdestroy(cred); rpcauth_crdestroy(cred);
return; return;
......
...@@ -60,12 +60,7 @@ EXPORT_SYMBOL(xprt_set_timeout); ...@@ -60,12 +60,7 @@ EXPORT_SYMBOL(xprt_set_timeout);
/* Client credential cache */ /* Client credential cache */
EXPORT_SYMBOL(rpcauth_register); EXPORT_SYMBOL(rpcauth_register);
EXPORT_SYMBOL(rpcauth_unregister); EXPORT_SYMBOL(rpcauth_unregister);
EXPORT_SYMBOL(rpcauth_init_credcache);
EXPORT_SYMBOL(rpcauth_free_credcache);
EXPORT_SYMBOL(rpcauth_insert_credcache);
EXPORT_SYMBOL(rpcauth_lookupcred); EXPORT_SYMBOL(rpcauth_lookupcred);
EXPORT_SYMBOL(rpcauth_bindcred);
EXPORT_SYMBOL(rpcauth_matchcred);
EXPORT_SYMBOL(put_rpccred); EXPORT_SYMBOL(put_rpccred);
/* RPC server stuff */ /* RPC server stuff */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment