Commit 4df493a2 authored by Trond Myklebust's avatar Trond Myklebust Committed by J. Bruce Fields

SUNRPC: Cache the process user cred in the RPC server listener

In order to be able to interpret uids and gids correctly in knfsd, we
should cache the user namespace of the process that created the RPC
server's listener. To do so, we refcount the credential of that process.
Signed-off-by: default avatarTrond Myklebust <trond.myklebust@hammerspace.com>
Signed-off-by: default avatarJ. Bruce Fields <bfields@redhat.com>
parent e333f3bb
...@@ -190,12 +190,13 @@ static int create_lockd_listener(struct svc_serv *serv, const char *name, ...@@ -190,12 +190,13 @@ static int create_lockd_listener(struct svc_serv *serv, const char *name,
struct net *net, const int family, struct net *net, const int family,
const unsigned short port) const unsigned short port)
{ {
const struct cred *cred = current_cred();
struct svc_xprt *xprt; struct svc_xprt *xprt;
xprt = svc_find_xprt(serv, name, net, family, 0); xprt = svc_find_xprt(serv, name, net, family, 0);
if (xprt == NULL) if (xprt == NULL)
return svc_create_xprt(serv, name, net, family, port, return svc_create_xprt(serv, name, net, family, port,
SVC_SOCK_DEFAULTS); SVC_SOCK_DEFAULTS, cred);
svc_xprt_put(xprt); svc_xprt_put(xprt);
return 0; return 0;
} }
......
...@@ -41,11 +41,13 @@ static struct svc_program nfs4_callback_program; ...@@ -41,11 +41,13 @@ static struct svc_program nfs4_callback_program;
static int nfs4_callback_up_net(struct svc_serv *serv, struct net *net) static int nfs4_callback_up_net(struct svc_serv *serv, struct net *net)
{ {
const struct cred *cred = current_cred();
int ret; int ret;
struct nfs_net *nn = net_generic(net, nfs_net_id); struct nfs_net *nn = net_generic(net, nfs_net_id);
ret = svc_create_xprt(serv, "tcp", net, PF_INET, ret = svc_create_xprt(serv, "tcp", net, PF_INET,
nfs_callback_set_tcpport, SVC_SOCK_ANONYMOUS); nfs_callback_set_tcpport, SVC_SOCK_ANONYMOUS,
cred);
if (ret <= 0) if (ret <= 0)
goto out_err; goto out_err;
nn->nfs_callback_tcpport = ret; nn->nfs_callback_tcpport = ret;
...@@ -53,7 +55,8 @@ static int nfs4_callback_up_net(struct svc_serv *serv, struct net *net) ...@@ -53,7 +55,8 @@ static int nfs4_callback_up_net(struct svc_serv *serv, struct net *net)
nn->nfs_callback_tcpport, PF_INET, net->ns.inum); nn->nfs_callback_tcpport, PF_INET, net->ns.inum);
ret = svc_create_xprt(serv, "tcp", net, PF_INET6, ret = svc_create_xprt(serv, "tcp", net, PF_INET6,
nfs_callback_set_tcpport, SVC_SOCK_ANONYMOUS); nfs_callback_set_tcpport, SVC_SOCK_ANONYMOUS,
cred);
if (ret > 0) { if (ret > 0) {
nn->nfs_callback_tcpport6 = ret; nn->nfs_callback_tcpport6 = ret;
dprintk("NFS: Callback listener port = %u (af %u, net %x)\n", dprintk("NFS: Callback listener port = %u (af %u, net %x)\n",
......
...@@ -439,7 +439,7 @@ static ssize_t write_threads(struct file *file, char *buf, size_t size) ...@@ -439,7 +439,7 @@ static ssize_t write_threads(struct file *file, char *buf, size_t size)
return rv; return rv;
if (newthreads < 0) if (newthreads < 0)
return -EINVAL; return -EINVAL;
rv = nfsd_svc(newthreads, net); rv = nfsd_svc(newthreads, net, file->f_cred);
if (rv < 0) if (rv < 0)
return rv; return rv;
} else } else
...@@ -717,7 +717,7 @@ static ssize_t __write_ports_names(char *buf, struct net *net) ...@@ -717,7 +717,7 @@ static ssize_t __write_ports_names(char *buf, struct net *net)
* a socket of a supported family/protocol, and we use it as an * a socket of a supported family/protocol, and we use it as an
* nfsd listener. * nfsd listener.
*/ */
static ssize_t __write_ports_addfd(char *buf, struct net *net) static ssize_t __write_ports_addfd(char *buf, struct net *net, const struct cred *cred)
{ {
char *mesg = buf; char *mesg = buf;
int fd, err; int fd, err;
...@@ -736,7 +736,7 @@ static ssize_t __write_ports_addfd(char *buf, struct net *net) ...@@ -736,7 +736,7 @@ static ssize_t __write_ports_addfd(char *buf, struct net *net)
if (err != 0) if (err != 0)
return err; return err;
err = svc_addsock(nn->nfsd_serv, fd, buf, SIMPLE_TRANSACTION_LIMIT); err = svc_addsock(nn->nfsd_serv, fd, buf, SIMPLE_TRANSACTION_LIMIT, cred);
if (err < 0) { if (err < 0) {
nfsd_destroy(net); nfsd_destroy(net);
return err; return err;
...@@ -751,7 +751,7 @@ static ssize_t __write_ports_addfd(char *buf, struct net *net) ...@@ -751,7 +751,7 @@ static ssize_t __write_ports_addfd(char *buf, struct net *net)
* A transport listener is added by writing it's transport name and * A transport listener is added by writing it's transport name and
* a port number. * a port number.
*/ */
static ssize_t __write_ports_addxprt(char *buf, struct net *net) static ssize_t __write_ports_addxprt(char *buf, struct net *net, const struct cred *cred)
{ {
char transport[16]; char transport[16];
struct svc_xprt *xprt; struct svc_xprt *xprt;
...@@ -769,12 +769,12 @@ static ssize_t __write_ports_addxprt(char *buf, struct net *net) ...@@ -769,12 +769,12 @@ static ssize_t __write_ports_addxprt(char *buf, struct net *net)
return err; return err;
err = svc_create_xprt(nn->nfsd_serv, transport, net, err = svc_create_xprt(nn->nfsd_serv, transport, net,
PF_INET, port, SVC_SOCK_ANONYMOUS); PF_INET, port, SVC_SOCK_ANONYMOUS, cred);
if (err < 0) if (err < 0)
goto out_err; goto out_err;
err = svc_create_xprt(nn->nfsd_serv, transport, net, err = svc_create_xprt(nn->nfsd_serv, transport, net,
PF_INET6, port, SVC_SOCK_ANONYMOUS); PF_INET6, port, SVC_SOCK_ANONYMOUS, cred);
if (err < 0 && err != -EAFNOSUPPORT) if (err < 0 && err != -EAFNOSUPPORT)
goto out_close; goto out_close;
...@@ -799,10 +799,10 @@ static ssize_t __write_ports(struct file *file, char *buf, size_t size, ...@@ -799,10 +799,10 @@ static ssize_t __write_ports(struct file *file, char *buf, size_t size,
return __write_ports_names(buf, net); return __write_ports_names(buf, net);
if (isdigit(buf[0])) if (isdigit(buf[0]))
return __write_ports_addfd(buf, net); return __write_ports_addfd(buf, net, file->f_cred);
if (isalpha(buf[0])) if (isalpha(buf[0]))
return __write_ports_addxprt(buf, net); return __write_ports_addxprt(buf, net, file->f_cred);
return -EINVAL; return -EINVAL;
} }
......
...@@ -73,7 +73,7 @@ extern const struct seq_operations nfs_exports_op; ...@@ -73,7 +73,7 @@ extern const struct seq_operations nfs_exports_op;
/* /*
* Function prototypes. * Function prototypes.
*/ */
int nfsd_svc(int nrservs, struct net *net); int nfsd_svc(int nrservs, struct net *net, const struct cred *cred);
int nfsd_dispatch(struct svc_rqst *rqstp, __be32 *statp); int nfsd_dispatch(struct svc_rqst *rqstp, __be32 *statp);
int nfsd_nrthreads(struct net *); int nfsd_nrthreads(struct net *);
......
...@@ -283,7 +283,7 @@ int nfsd_nrthreads(struct net *net) ...@@ -283,7 +283,7 @@ int nfsd_nrthreads(struct net *net)
return rv; return rv;
} }
static int nfsd_init_socks(struct net *net) static int nfsd_init_socks(struct net *net, const struct cred *cred)
{ {
int error; int error;
struct nfsd_net *nn = net_generic(net, nfsd_net_id); struct nfsd_net *nn = net_generic(net, nfsd_net_id);
...@@ -292,12 +292,12 @@ static int nfsd_init_socks(struct net *net) ...@@ -292,12 +292,12 @@ static int nfsd_init_socks(struct net *net)
return 0; return 0;
error = svc_create_xprt(nn->nfsd_serv, "udp", net, PF_INET, NFS_PORT, error = svc_create_xprt(nn->nfsd_serv, "udp", net, PF_INET, NFS_PORT,
SVC_SOCK_DEFAULTS); SVC_SOCK_DEFAULTS, cred);
if (error < 0) if (error < 0)
return error; return error;
error = svc_create_xprt(nn->nfsd_serv, "tcp", net, PF_INET, NFS_PORT, error = svc_create_xprt(nn->nfsd_serv, "tcp", net, PF_INET, NFS_PORT,
SVC_SOCK_DEFAULTS); SVC_SOCK_DEFAULTS, cred);
if (error < 0) if (error < 0)
return error; return error;
...@@ -348,7 +348,7 @@ static bool nfsd_needs_lockd(struct nfsd_net *nn) ...@@ -348,7 +348,7 @@ static bool nfsd_needs_lockd(struct nfsd_net *nn)
return nfsd_vers(nn, 2, NFSD_TEST) || nfsd_vers(nn, 3, NFSD_TEST); return nfsd_vers(nn, 2, NFSD_TEST) || nfsd_vers(nn, 3, NFSD_TEST);
} }
static int nfsd_startup_net(int nrservs, struct net *net) static int nfsd_startup_net(int nrservs, struct net *net, const struct cred *cred)
{ {
struct nfsd_net *nn = net_generic(net, nfsd_net_id); struct nfsd_net *nn = net_generic(net, nfsd_net_id);
int ret; int ret;
...@@ -359,7 +359,7 @@ static int nfsd_startup_net(int nrservs, struct net *net) ...@@ -359,7 +359,7 @@ static int nfsd_startup_net(int nrservs, struct net *net)
ret = nfsd_startup_generic(nrservs); ret = nfsd_startup_generic(nrservs);
if (ret) if (ret)
return ret; return ret;
ret = nfsd_init_socks(net); ret = nfsd_init_socks(net, cred);
if (ret) if (ret)
goto out_socks; goto out_socks;
...@@ -697,7 +697,7 @@ int nfsd_set_nrthreads(int n, int *nthreads, struct net *net) ...@@ -697,7 +697,7 @@ int nfsd_set_nrthreads(int n, int *nthreads, struct net *net)
* this is the first time nrservs is nonzero. * this is the first time nrservs is nonzero.
*/ */
int int
nfsd_svc(int nrservs, struct net *net) nfsd_svc(int nrservs, struct net *net, const struct cred *cred)
{ {
int error; int error;
bool nfsd_up_before; bool nfsd_up_before;
...@@ -719,7 +719,7 @@ nfsd_svc(int nrservs, struct net *net) ...@@ -719,7 +719,7 @@ nfsd_svc(int nrservs, struct net *net)
nfsd_up_before = nn->nfsd_net_up; nfsd_up_before = nn->nfsd_net_up;
error = nfsd_startup_net(nrservs, net); error = nfsd_startup_net(nrservs, net, cred);
if (error) if (error)
goto out_destroy; goto out_destroy;
error = nn->nfsd_serv->sv_ops->svo_setup(nn->nfsd_serv, error = nn->nfsd_serv->sv_ops->svo_setup(nn->nfsd_serv,
......
...@@ -86,6 +86,7 @@ struct svc_xprt { ...@@ -86,6 +86,7 @@ struct svc_xprt {
struct list_head xpt_users; /* callbacks on free */ struct list_head xpt_users; /* callbacks on free */
struct net *xpt_net; struct net *xpt_net;
const struct cred *xpt_cred;
struct rpc_xprt *xpt_bc_xprt; /* NFSv4.1 backchannel */ struct rpc_xprt *xpt_bc_xprt; /* NFSv4.1 backchannel */
struct rpc_xprt_switch *xpt_bc_xps; /* NFSv4.1 backchannel */ struct rpc_xprt_switch *xpt_bc_xps; /* NFSv4.1 backchannel */
}; };
...@@ -119,7 +120,8 @@ void svc_unreg_xprt_class(struct svc_xprt_class *); ...@@ -119,7 +120,8 @@ void svc_unreg_xprt_class(struct svc_xprt_class *);
void svc_xprt_init(struct net *, struct svc_xprt_class *, struct svc_xprt *, void svc_xprt_init(struct net *, struct svc_xprt_class *, struct svc_xprt *,
struct svc_serv *); struct svc_serv *);
int svc_create_xprt(struct svc_serv *, const char *, struct net *, int svc_create_xprt(struct svc_serv *, const char *, struct net *,
const int, const unsigned short, int); const int, const unsigned short, int,
const struct cred *);
void svc_xprt_do_enqueue(struct svc_xprt *xprt); void svc_xprt_do_enqueue(struct svc_xprt *xprt);
void svc_xprt_enqueue(struct svc_xprt *xprt); void svc_xprt_enqueue(struct svc_xprt *xprt);
void svc_xprt_put(struct svc_xprt *xprt); void svc_xprt_put(struct svc_xprt *xprt);
......
...@@ -59,7 +59,8 @@ void svc_drop(struct svc_rqst *); ...@@ -59,7 +59,8 @@ void svc_drop(struct svc_rqst *);
void svc_sock_update_bufs(struct svc_serv *serv); void svc_sock_update_bufs(struct svc_serv *serv);
bool svc_alien_sock(struct net *net, int fd); bool svc_alien_sock(struct net *net, int fd);
int svc_addsock(struct svc_serv *serv, const int fd, int svc_addsock(struct svc_serv *serv, const int fd,
char *name_return, const size_t len); char *name_return, const size_t len,
const struct cred *cred);
void svc_init_xprt_sock(void); void svc_init_xprt_sock(void);
void svc_cleanup_xprt_sock(void); void svc_cleanup_xprt_sock(void);
struct svc_xprt *svc_sock_create(struct svc_serv *serv, int prot); struct svc_xprt *svc_sock_create(struct svc_serv *serv, int prot);
......
...@@ -136,6 +136,7 @@ static void svc_xprt_free(struct kref *kref) ...@@ -136,6 +136,7 @@ static void svc_xprt_free(struct kref *kref)
struct module *owner = xprt->xpt_class->xcl_owner; struct module *owner = xprt->xpt_class->xcl_owner;
if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags))
svcauth_unix_info_release(xprt); svcauth_unix_info_release(xprt);
put_cred(xprt->xpt_cred);
put_net(xprt->xpt_net); put_net(xprt->xpt_net);
/* See comment on corresponding get in xs_setup_bc_tcp(): */ /* See comment on corresponding get in xs_setup_bc_tcp(): */
if (xprt->xpt_bc_xprt) if (xprt->xpt_bc_xprt)
...@@ -252,7 +253,8 @@ void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new) ...@@ -252,7 +253,8 @@ void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new)
static int _svc_create_xprt(struct svc_serv *serv, const char *xprt_name, static int _svc_create_xprt(struct svc_serv *serv, const char *xprt_name,
struct net *net, const int family, struct net *net, const int family,
const unsigned short port, int flags) const unsigned short port, int flags,
const struct cred *cred)
{ {
struct svc_xprt_class *xcl; struct svc_xprt_class *xcl;
...@@ -273,6 +275,7 @@ static int _svc_create_xprt(struct svc_serv *serv, const char *xprt_name, ...@@ -273,6 +275,7 @@ static int _svc_create_xprt(struct svc_serv *serv, const char *xprt_name,
module_put(xcl->xcl_owner); module_put(xcl->xcl_owner);
return PTR_ERR(newxprt); return PTR_ERR(newxprt);
} }
newxprt->xpt_cred = get_cred(cred);
svc_add_new_perm_xprt(serv, newxprt); svc_add_new_perm_xprt(serv, newxprt);
newport = svc_xprt_local_port(newxprt); newport = svc_xprt_local_port(newxprt);
return newport; return newport;
...@@ -286,15 +289,16 @@ static int _svc_create_xprt(struct svc_serv *serv, const char *xprt_name, ...@@ -286,15 +289,16 @@ static int _svc_create_xprt(struct svc_serv *serv, const char *xprt_name,
int svc_create_xprt(struct svc_serv *serv, const char *xprt_name, int svc_create_xprt(struct svc_serv *serv, const char *xprt_name,
struct net *net, const int family, struct net *net, const int family,
const unsigned short port, int flags) const unsigned short port, int flags,
const struct cred *cred)
{ {
int err; int err;
dprintk("svc: creating transport %s[%d]\n", xprt_name, port); dprintk("svc: creating transport %s[%d]\n", xprt_name, port);
err = _svc_create_xprt(serv, xprt_name, net, family, port, flags); err = _svc_create_xprt(serv, xprt_name, net, family, port, flags, cred);
if (err == -EPROTONOSUPPORT) { if (err == -EPROTONOSUPPORT) {
request_module("svc%s", xprt_name); request_module("svc%s", xprt_name);
err = _svc_create_xprt(serv, xprt_name, net, family, port, flags); err = _svc_create_xprt(serv, xprt_name, net, family, port, flags, cred);
} }
if (err < 0) if (err < 0)
dprintk("svc: transport %s not found, err %d\n", dprintk("svc: transport %s not found, err %d\n",
......
...@@ -1332,13 +1332,14 @@ EXPORT_SYMBOL_GPL(svc_alien_sock); ...@@ -1332,13 +1332,14 @@ EXPORT_SYMBOL_GPL(svc_alien_sock);
* @fd: file descriptor of the new listener * @fd: file descriptor of the new listener
* @name_return: pointer to buffer to fill in with name of listener * @name_return: pointer to buffer to fill in with name of listener
* @len: size of the buffer * @len: size of the buffer
* @cred: credential
* *
* Fills in socket name and returns positive length of name if successful. * Fills in socket name and returns positive length of name if successful.
* Name is terminated with '\n'. On error, returns a negative errno * Name is terminated with '\n'. On error, returns a negative errno
* value. * value.
*/ */
int svc_addsock(struct svc_serv *serv, const int fd, char *name_return, int svc_addsock(struct svc_serv *serv, const int fd, char *name_return,
const size_t len) const size_t len, const struct cred *cred)
{ {
int err = 0; int err = 0;
struct socket *so = sockfd_lookup(fd, &err); struct socket *so = sockfd_lookup(fd, &err);
...@@ -1371,6 +1372,7 @@ int svc_addsock(struct svc_serv *serv, const int fd, char *name_return, ...@@ -1371,6 +1372,7 @@ int svc_addsock(struct svc_serv *serv, const int fd, char *name_return,
salen = kernel_getsockname(svsk->sk_sock, sin); salen = kernel_getsockname(svsk->sk_sock, sin);
if (salen >= 0) if (salen >= 0)
svc_xprt_set_local(&svsk->sk_xprt, sin, salen); svc_xprt_set_local(&svsk->sk_xprt, sin, salen);
svsk->sk_xprt.xpt_cred = get_cred(cred);
svc_add_new_perm_xprt(serv, &svsk->sk_xprt); svc_add_new_perm_xprt(serv, &svsk->sk_xprt);
return svc_one_sock_name(svsk, name_return, len); return svc_one_sock_name(svsk, name_return, len);
out: out:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment