Commit 97f68c6b authored by NeilBrown's avatar NeilBrown Committed by Anna Schumaker

SUNRPC: add 'struct cred *' to auth_cred and rpc_cred

The SUNRPC credential framework was put together before
Linux has 'struct cred'.  Now that we have it, it makes sense to
use it.
This first step just includes a suitable 'struct cred *' pointer
in every 'struct auth_cred' and almost every 'struct rpc_cred'.

The rpc_cred used for auth_null has a NULL 'struct cred *' as nothing
else really makes sense.

For rpc_cred, the pointer is reference counted.
For auth_cred it isn't.  struct auth_cred are either allocated on
the stack, in which case the thread owns a reference to the auth,
or are part of 'struct generic_cred' in which case gc_base owns the
reference, and "acred" shares it.
Signed-off-by: default avatarNeilBrown <neilb@suse.com>
Signed-off-by: default avatarAnna Schumaker <Anna.Schumaker@Netapp.com>
parent f06bc033
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <linux/nfs_fs.h> #include <linux/nfs_fs.h>
#include <linux/nfs_page.h> #include <linux/nfs_page.h>
#include <linux/module.h> #include <linux/module.h>
#include <linux/sched/mm.h>
#include <linux/sunrpc/metrics.h> #include <linux/sunrpc/metrics.h>
...@@ -415,6 +416,7 @@ ff_layout_alloc_lseg(struct pnfs_layout_hdr *lh, ...@@ -415,6 +416,7 @@ ff_layout_alloc_lseg(struct pnfs_layout_hdr *lh,
struct nfs4_ff_layout_mirror *mirror; struct nfs4_ff_layout_mirror *mirror;
struct auth_cred acred = { .group_info = ff_zero_group }; struct auth_cred acred = { .group_info = ff_zero_group };
struct rpc_cred __rcu *cred; struct rpc_cred __rcu *cred;
struct cred *kcred;
u32 ds_count, fh_count, id; u32 ds_count, fh_count, id;
int j; int j;
...@@ -491,8 +493,23 @@ ff_layout_alloc_lseg(struct pnfs_layout_hdr *lh, ...@@ -491,8 +493,23 @@ ff_layout_alloc_lseg(struct pnfs_layout_hdr *lh,
acred.gid = make_kgid(&init_user_ns, id); acred.gid = make_kgid(&init_user_ns, id);
if (gfp_flags & __GFP_FS)
kcred = prepare_kernel_cred(NULL);
else {
unsigned int nofs_flags = memalloc_nofs_save();
kcred = prepare_kernel_cred(NULL);
memalloc_nofs_restore(nofs_flags);
}
rc = -ENOMEM;
if (!kcred)
goto out_err_free;
kcred->fsuid = acred.uid;
kcred->fsgid = acred.gid;
acred.cred = kcred;
/* find the cred for it */ /* find the cred for it */
rcu_assign_pointer(cred, rpc_lookup_generic_cred(&acred, 0, gfp_flags)); rcu_assign_pointer(cred, rpc_lookup_generic_cred(&acred, 0, gfp_flags));
put_cred(kcred);
if (IS_ERR(cred)) { if (IS_ERR(cred)) {
rc = PTR_ERR(cred); rc = PTR_ERR(cred);
goto out_err_free; goto out_err_free;
......
...@@ -858,10 +858,21 @@ static struct rpc_cred *get_backchannel_cred(struct nfs4_client *clp, struct rpc ...@@ -858,10 +858,21 @@ static struct rpc_cred *get_backchannel_cred(struct nfs4_client *clp, struct rpc
} else { } else {
struct rpc_auth *auth = client->cl_auth; struct rpc_auth *auth = client->cl_auth;
struct auth_cred acred = {}; struct auth_cred acred = {};
struct cred *kcred;
struct rpc_cred *ret;
kcred = prepare_kernel_cred(NULL);
if (!kcred)
return NULL;
acred.uid = ses->se_cb_sec.uid; acred.uid = ses->se_cb_sec.uid;
acred.gid = ses->se_cb_sec.gid; acred.gid = ses->se_cb_sec.gid;
return auth->au_ops->lookup_cred(client->cl_auth, &acred, 0); kcred->uid = acred.uid;
kcred->gid = acred.gid;
acred.cred = kcred;
ret = auth->au_ops->lookup_cred(client->cl_auth, &acred, 0);
put_cred(kcred);
return ret;
} }
} }
......
...@@ -46,6 +46,7 @@ enum { ...@@ -46,6 +46,7 @@ enum {
/* Work around the lack of a VFS credential */ /* Work around the lack of a VFS credential */
struct auth_cred { struct auth_cred {
const struct cred *cred;
kuid_t uid; kuid_t uid;
kgid_t gid; kgid_t gid;
struct group_info *group_info; struct group_info *group_info;
...@@ -68,6 +69,7 @@ struct rpc_cred { ...@@ -68,6 +69,7 @@ struct rpc_cred {
unsigned long cr_expire; /* when to gc */ unsigned long cr_expire; /* when to gc */
unsigned long cr_flags; /* various flags */ unsigned long cr_flags; /* various flags */
refcount_t cr_count; /* ref count */ refcount_t cr_count; /* ref count */
const struct cred *cr_cred;
kuid_t cr_uid; kuid_t cr_uid;
......
...@@ -659,6 +659,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags) ...@@ -659,6 +659,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags)
acred.uid = cred->fsuid; acred.uid = cred->fsuid;
acred.gid = cred->fsgid; acred.gid = cred->fsgid;
acred.group_info = cred->group_info; acred.group_info = cred->group_info;
acred.cred = cred;
ret = auth->au_ops->lookup_cred(auth, &acred, flags); ret = auth->au_ops->lookup_cred(auth, &acred, flags);
return ret; return ret;
} }
...@@ -674,6 +675,7 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, ...@@ -674,6 +675,7 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
cred->cr_auth = auth; cred->cr_auth = auth;
cred->cr_ops = ops; cred->cr_ops = ops;
cred->cr_expire = jiffies; cred->cr_expire = jiffies;
cred->cr_cred = get_cred(acred->cred);
cred->cr_uid = acred->uid; cred->cr_uid = acred->uid;
} }
EXPORT_SYMBOL_GPL(rpcauth_init_cred); EXPORT_SYMBOL_GPL(rpcauth_init_cred);
...@@ -694,11 +696,15 @@ rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) ...@@ -694,11 +696,15 @@ rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
struct auth_cred acred = { struct auth_cred acred = {
.uid = GLOBAL_ROOT_UID, .uid = GLOBAL_ROOT_UID,
.gid = GLOBAL_ROOT_GID, .gid = GLOBAL_ROOT_GID,
.cred = get_task_cred(&init_task),
}; };
struct rpc_cred *ret;
dprintk("RPC: %5u looking up %s cred\n", dprintk("RPC: %5u looking up %s cred\n",
task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
return auth->au_ops->lookup_cred(auth, &acred, lookupflags); ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
put_cred(acred.cred);
return ret;
} }
static struct rpc_cred * static struct rpc_cred *
......
...@@ -61,11 +61,15 @@ struct rpc_cred *rpc_lookup_machine_cred(const char *service_name) ...@@ -61,11 +61,15 @@ struct rpc_cred *rpc_lookup_machine_cred(const char *service_name)
.gid = RPC_MACHINE_CRED_GROUPID, .gid = RPC_MACHINE_CRED_GROUPID,
.principal = service_name, .principal = service_name,
.machine_cred = 1, .machine_cred = 1,
.cred = get_task_cred(&init_task),
}; };
struct rpc_cred *ret;
dprintk("RPC: looking up machine cred for service %s\n", dprintk("RPC: looking up machine cred for service %s\n",
service_name); service_name);
return generic_auth.au_ops->lookup_cred(&generic_auth, &acred, 0); ret = generic_auth.au_ops->lookup_cred(&generic_auth, &acred, 0);
put_cred(acred.cred);
return ret;
} }
EXPORT_SYMBOL_GPL(rpc_lookup_machine_cred); EXPORT_SYMBOL_GPL(rpc_lookup_machine_cred);
...@@ -110,6 +114,7 @@ generic_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, g ...@@ -110,6 +114,7 @@ generic_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, g
gcred->acred.uid = acred->uid; gcred->acred.uid = acred->uid;
gcred->acred.gid = acred->gid; gcred->acred.gid = acred->gid;
gcred->acred.group_info = acred->group_info; gcred->acred.group_info = acred->group_info;
gcred->acred.cred = gcred->gc_base.cr_cred;
gcred->acred.ac_flags = 0; gcred->acred.ac_flags = 0;
if (gcred->acred.group_info != NULL) if (gcred->acred.group_info != NULL)
get_group_info(gcred->acred.group_info); get_group_info(gcred->acred.group_info);
...@@ -132,6 +137,7 @@ generic_free_cred(struct rpc_cred *cred) ...@@ -132,6 +137,7 @@ generic_free_cred(struct rpc_cred *cred)
dprintk("RPC: generic_free_cred %p\n", gcred); dprintk("RPC: generic_free_cred %p\n", gcred);
if (gcred->acred.group_info != NULL) if (gcred->acred.group_info != NULL)
put_group_info(gcred->acred.group_info); put_group_info(gcred->acred.group_info);
put_cred(cred->cr_cred);
kfree(gcred); kfree(gcred);
} }
......
...@@ -1343,6 +1343,7 @@ gss_destroy_nullcred(struct rpc_cred *cred) ...@@ -1343,6 +1343,7 @@ gss_destroy_nullcred(struct rpc_cred *cred)
struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1); struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1);
RCU_INIT_POINTER(gss_cred->gc_ctx, NULL); RCU_INIT_POINTER(gss_cred->gc_ctx, NULL);
put_cred(cred->cr_cred);
call_rcu(&cred->cr_rcu, gss_free_cred_callback); call_rcu(&cred->cr_rcu, gss_free_cred_callback);
if (ctx) if (ctx)
gss_put_ctx(ctx); gss_put_ctx(ctx);
...@@ -1608,6 +1609,7 @@ static int gss_renew_cred(struct rpc_task *task) ...@@ -1608,6 +1609,7 @@ static int gss_renew_cred(struct rpc_task *task)
struct rpc_auth *auth = oldcred->cr_auth; struct rpc_auth *auth = oldcred->cr_auth;
struct auth_cred acred = { struct auth_cred acred = {
.uid = oldcred->cr_uid, .uid = oldcred->cr_uid,
.cred = oldcred->cr_cred,
.principal = gss_cred->gc_principal, .principal = gss_cred->gc_principal,
.machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0), .machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0),
}; };
......
...@@ -97,6 +97,7 @@ static void ...@@ -97,6 +97,7 @@ static void
unx_free_cred(struct unx_cred *unx_cred) unx_free_cred(struct unx_cred *unx_cred)
{ {
dprintk("RPC: unx_free_cred %p\n", unx_cred); dprintk("RPC: unx_free_cred %p\n", unx_cred);
put_cred(unx_cred->uc_base.cr_cred);
kfree(unx_cred); kfree(unx_cred);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment