Commit 97cf0ef9 authored by David S. Miller's avatar David S. Miller

Merge branch 'improve-msg_control-kernel-vs-user-pointer-handling'

Christoph Hellwig says:

====================
improve msg_control kernel vs user pointer handling

this series replace the msg_control in the kernel msghdr structure
with an anonymous union and separate fields for kernel vs user
pointers.  In addition to helping a bit with type safety and reducing
sparse warnings, this also allows to remove the set_fs() in
kernel_recvmsg, helping with an eventual entire removal of set_fs().
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 3242956b 1f466e1f
...@@ -50,7 +50,17 @@ struct msghdr { ...@@ -50,7 +50,17 @@ struct msghdr {
void *msg_name; /* ptr to socket address structure */ void *msg_name; /* ptr to socket address structure */
int msg_namelen; /* size of socket address structure */ int msg_namelen; /* size of socket address structure */
struct iov_iter msg_iter; /* data */ struct iov_iter msg_iter; /* data */
void *msg_control; /* ancillary data */
/*
* Ancillary data. msg_control_user is the user buffer used for the
* recv* side when msg_control_is_user is set, msg_control is the kernel
* buffer used for all other cases.
*/
union {
void *msg_control;
void __user *msg_control_user;
};
bool msg_control_is_user : 1;
__kernel_size_t msg_controllen; /* ancillary data buffer length */ __kernel_size_t msg_controllen; /* ancillary data buffer length */
unsigned int msg_flags; /* flags on received message */ unsigned int msg_flags; /* flags on received message */
struct kiocb *msg_iocb; /* ptr to iocb for async requests */ struct kiocb *msg_iocb; /* ptr to iocb for async requests */
...@@ -94,7 +104,10 @@ struct cmsghdr { ...@@ -94,7 +104,10 @@ struct cmsghdr {
#define CMSG_ALIGN(len) ( ((len)+sizeof(long)-1) & ~(sizeof(long)-1) ) #define CMSG_ALIGN(len) ( ((len)+sizeof(long)-1) & ~(sizeof(long)-1) )
#define CMSG_DATA(cmsg) ((void *)((char *)(cmsg) + sizeof(struct cmsghdr))) #define CMSG_DATA(cmsg) \
((void *)(cmsg) + sizeof(struct cmsghdr))
#define CMSG_USER_DATA(cmsg) \
((void __user *)(cmsg) + sizeof(struct cmsghdr))
#define CMSG_SPACE(len) (sizeof(struct cmsghdr) + CMSG_ALIGN(len)) #define CMSG_SPACE(len) (sizeof(struct cmsghdr) + CMSG_ALIGN(len))
#define CMSG_LEN(len) (sizeof(struct cmsghdr) + (len)) #define CMSG_LEN(len) (sizeof(struct cmsghdr) + (len))
......
...@@ -56,7 +56,8 @@ int __get_compat_msghdr(struct msghdr *kmsg, ...@@ -56,7 +56,8 @@ int __get_compat_msghdr(struct msghdr *kmsg,
if (kmsg->msg_namelen > sizeof(struct sockaddr_storage)) if (kmsg->msg_namelen > sizeof(struct sockaddr_storage))
kmsg->msg_namelen = sizeof(struct sockaddr_storage); kmsg->msg_namelen = sizeof(struct sockaddr_storage);
kmsg->msg_control = compat_ptr(msg.msg_control); kmsg->msg_control_is_user = true;
kmsg->msg_control_user = compat_ptr(msg.msg_control);
kmsg->msg_controllen = msg.msg_controllen; kmsg->msg_controllen = msg.msg_controllen;
if (save_addr) if (save_addr)
...@@ -121,7 +122,7 @@ int get_compat_msghdr(struct msghdr *kmsg, ...@@ -121,7 +122,7 @@ int get_compat_msghdr(struct msghdr *kmsg,
((ucmlen) >= sizeof(struct compat_cmsghdr) && \ ((ucmlen) >= sizeof(struct compat_cmsghdr) && \
(ucmlen) <= (unsigned long) \ (ucmlen) <= (unsigned long) \
((mhdr)->msg_controllen - \ ((mhdr)->msg_controllen - \
((char *)(ucmsg) - (char *)(mhdr)->msg_control))) ((char __user *)(ucmsg) - (char __user *)(mhdr)->msg_control_user)))
static inline struct compat_cmsghdr __user *cmsg_compat_nxthdr(struct msghdr *msg, static inline struct compat_cmsghdr __user *cmsg_compat_nxthdr(struct msghdr *msg,
struct compat_cmsghdr __user *cmsg, int cmsg_len) struct compat_cmsghdr __user *cmsg, int cmsg_len)
......
...@@ -212,16 +212,12 @@ EXPORT_SYMBOL(__scm_send); ...@@ -212,16 +212,12 @@ EXPORT_SYMBOL(__scm_send);
int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data) int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
{ {
struct cmsghdr __user *cm
= (__force struct cmsghdr __user *)msg->msg_control;
struct cmsghdr cmhdr;
int cmlen = CMSG_LEN(len); int cmlen = CMSG_LEN(len);
int err;
if (MSG_CMSG_COMPAT & msg->msg_flags) if (msg->msg_flags & MSG_CMSG_COMPAT)
return put_cmsg_compat(msg, level, type, len, data); return put_cmsg_compat(msg, level, type, len, data);
if (cm==NULL || msg->msg_controllen < sizeof(*cm)) { if (!msg->msg_control || msg->msg_controllen < sizeof(struct cmsghdr)) {
msg->msg_flags |= MSG_CTRUNC; msg->msg_flags |= MSG_CTRUNC;
return 0; /* XXX: return error? check spec. */ return 0; /* XXX: return error? check spec. */
} }
...@@ -229,23 +225,30 @@ int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data) ...@@ -229,23 +225,30 @@ int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
msg->msg_flags |= MSG_CTRUNC; msg->msg_flags |= MSG_CTRUNC;
cmlen = msg->msg_controllen; cmlen = msg->msg_controllen;
} }
if (msg->msg_control_is_user) {
struct cmsghdr __user *cm = msg->msg_control_user;
struct cmsghdr cmhdr;
cmhdr.cmsg_level = level; cmhdr.cmsg_level = level;
cmhdr.cmsg_type = type; cmhdr.cmsg_type = type;
cmhdr.cmsg_len = cmlen; cmhdr.cmsg_len = cmlen;
if (copy_to_user(cm, &cmhdr, sizeof cmhdr) ||
copy_to_user(CMSG_USER_DATA(cm), data, cmlen - sizeof(*cm)))
return -EFAULT;
} else {
struct cmsghdr *cm = msg->msg_control;
cm->cmsg_level = level;
cm->cmsg_type = type;
cm->cmsg_len = cmlen;
memcpy(CMSG_DATA(cm), data, cmlen - sizeof(*cm));
}
err = -EFAULT; cmlen = min(CMSG_SPACE(len), msg->msg_controllen);
if (copy_to_user(cm, &cmhdr, sizeof cmhdr))
goto out;
if (copy_to_user(CMSG_DATA(cm), data, cmlen - sizeof(struct cmsghdr)))
goto out;
cmlen = CMSG_SPACE(len);
if (msg->msg_controllen < cmlen)
cmlen = msg->msg_controllen;
msg->msg_control += cmlen; msg->msg_control += cmlen;
msg->msg_controllen -= cmlen; msg->msg_controllen -= cmlen;
err = 0; return 0;
out:
return err;
} }
EXPORT_SYMBOL(put_cmsg); EXPORT_SYMBOL(put_cmsg);
...@@ -277,78 +280,90 @@ void put_cmsg_scm_timestamping(struct msghdr *msg, struct scm_timestamping_inter ...@@ -277,78 +280,90 @@ void put_cmsg_scm_timestamping(struct msghdr *msg, struct scm_timestamping_inter
} }
EXPORT_SYMBOL(put_cmsg_scm_timestamping); EXPORT_SYMBOL(put_cmsg_scm_timestamping);
static int __scm_install_fd(struct file *file, int __user *ufd, int o_flags)
{
struct socket *sock;
int new_fd;
int error;
error = security_file_receive(file);
if (error)
return error;
new_fd = get_unused_fd_flags(o_flags);
if (new_fd < 0)
return new_fd;
error = put_user(new_fd, ufd);
if (error) {
put_unused_fd(new_fd);
return error;
}
/* Bump the usage count and install the file. */
sock = sock_from_file(file, &error);
if (sock) {
sock_update_netprioidx(&sock->sk->sk_cgrp_data);
sock_update_classid(&sock->sk->sk_cgrp_data);
}
fd_install(new_fd, get_file(file));
return error;
}
static int scm_max_fds(struct msghdr *msg)
{
if (msg->msg_controllen <= sizeof(struct cmsghdr))
return 0;
return (msg->msg_controllen - sizeof(struct cmsghdr)) / sizeof(int);
}
void scm_detach_fds(struct msghdr *msg, struct scm_cookie *scm) void scm_detach_fds(struct msghdr *msg, struct scm_cookie *scm)
{ {
struct cmsghdr __user *cm struct cmsghdr __user *cm
= (__force struct cmsghdr __user*)msg->msg_control; = (__force struct cmsghdr __user*)msg->msg_control;
int o_flags = (msg->msg_flags & MSG_CMSG_CLOEXEC) ? O_CLOEXEC : 0;
int fdmax = 0; int fdmax = min_t(int, scm_max_fds(msg), scm->fp->count);
int fdnum = scm->fp->count; int __user *cmsg_data = CMSG_USER_DATA(cm);
struct file **fp = scm->fp->fp;
int __user *cmfptr;
int err = 0, i; int err = 0, i;
if (MSG_CMSG_COMPAT & msg->msg_flags) { if (msg->msg_flags & MSG_CMSG_COMPAT) {
scm_detach_fds_compat(msg, scm); scm_detach_fds_compat(msg, scm);
return; return;
} }
if (msg->msg_controllen > sizeof(struct cmsghdr)) /* no use for FD passing from kernel space callers */
fdmax = ((msg->msg_controllen - sizeof(struct cmsghdr)) if (WARN_ON_ONCE(!msg->msg_control_is_user))
/ sizeof(int)); return;
if (fdnum < fdmax)
fdmax = fdnum;
for (i=0, cmfptr=(__force int __user *)CMSG_DATA(cm); i<fdmax; for (i = 0; i < fdmax; i++) {
i++, cmfptr++) err = __scm_install_fd(scm->fp->fp[i], cmsg_data + i, o_flags);
{
struct socket *sock;
int new_fd;
err = security_file_receive(fp[i]);
if (err) if (err)
break; break;
err = get_unused_fd_flags(MSG_CMSG_CLOEXEC & msg->msg_flags
? O_CLOEXEC : 0);
if (err < 0)
break;
new_fd = err;
err = put_user(new_fd, cmfptr);
if (err) {
put_unused_fd(new_fd);
break;
}
/* Bump the usage count and install the file. */
sock = sock_from_file(fp[i], &err);
if (sock) {
sock_update_netprioidx(&sock->sk->sk_cgrp_data);
sock_update_classid(&sock->sk->sk_cgrp_data);
}
fd_install(new_fd, get_file(fp[i]));
} }
if (i > 0) if (i > 0) {
{ int cmlen = CMSG_LEN(i * sizeof(int));
int cmlen = CMSG_LEN(i*sizeof(int));
err = put_user(SOL_SOCKET, &cm->cmsg_level); err = put_user(SOL_SOCKET, &cm->cmsg_level);
if (!err) if (!err)
err = put_user(SCM_RIGHTS, &cm->cmsg_type); err = put_user(SCM_RIGHTS, &cm->cmsg_type);
if (!err) if (!err)
err = put_user(cmlen, &cm->cmsg_len); err = put_user(cmlen, &cm->cmsg_len);
if (!err) { if (!err) {
cmlen = CMSG_SPACE(i*sizeof(int)); cmlen = CMSG_SPACE(i * sizeof(int));
if (msg->msg_controllen < cmlen) if (msg->msg_controllen < cmlen)
cmlen = msg->msg_controllen; cmlen = msg->msg_controllen;
msg->msg_control += cmlen; msg->msg_control += cmlen;
msg->msg_controllen -= cmlen; msg->msg_controllen -= cmlen;
} }
} }
if (i < fdnum || (fdnum && fdmax <= 0))
if (i < scm->fp->count || (scm->fp->count && fdmax <= 0))
msg->msg_flags |= MSG_CTRUNC; msg->msg_flags |= MSG_CTRUNC;
/* /*
* All of the files that fit in the message have had their * All of the files that fit in the message have had their usage counts
* usage counts incremented, so we just free the list. * incremented, so we just free the list.
*/ */
__scm_destroy(scm); __scm_destroy(scm);
} }
......
...@@ -1492,7 +1492,8 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1492,7 +1492,8 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
if (sk->sk_type != SOCK_STREAM) if (sk->sk_type != SOCK_STREAM)
return -ENOPROTOOPT; return -ENOPROTOOPT;
msg.msg_control = (__force void *) optval; msg.msg_control_is_user = true;
msg.msg_control_user = optval;
msg.msg_controllen = len; msg.msg_controllen = len;
msg.msg_flags = flags; msg.msg_flags = flags;
......
...@@ -924,14 +924,9 @@ EXPORT_SYMBOL(sock_recvmsg); ...@@ -924,14 +924,9 @@ EXPORT_SYMBOL(sock_recvmsg);
int kernel_recvmsg(struct socket *sock, struct msghdr *msg, int kernel_recvmsg(struct socket *sock, struct msghdr *msg,
struct kvec *vec, size_t num, size_t size, int flags) struct kvec *vec, size_t num, size_t size, int flags)
{ {
mm_segment_t oldfs = get_fs(); msg->msg_control_is_user = false;
int result;
iov_iter_kvec(&msg->msg_iter, READ, vec, num, size); iov_iter_kvec(&msg->msg_iter, READ, vec, num, size);
set_fs(KERNEL_DS); return sock_recvmsg(sock, msg, flags);
result = sock_recvmsg(sock, msg, flags);
set_fs(oldfs);
return result;
} }
EXPORT_SYMBOL(kernel_recvmsg); EXPORT_SYMBOL(kernel_recvmsg);
...@@ -2239,7 +2234,8 @@ int __copy_msghdr_from_user(struct msghdr *kmsg, ...@@ -2239,7 +2234,8 @@ int __copy_msghdr_from_user(struct msghdr *kmsg,
if (copy_from_user(&msg, umsg, sizeof(*umsg))) if (copy_from_user(&msg, umsg, sizeof(*umsg)))
return -EFAULT; return -EFAULT;
kmsg->msg_control = (void __force *)msg.msg_control; kmsg->msg_control_is_user = true;
kmsg->msg_control_user = msg.msg_control;
kmsg->msg_controllen = msg.msg_controllen; kmsg->msg_controllen = msg.msg_controllen;
kmsg->msg_flags = msg.msg_flags; kmsg->msg_flags = msg.msg_flags;
...@@ -2331,16 +2327,10 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys, ...@@ -2331,16 +2327,10 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
goto out; goto out;
} }
err = -EFAULT; err = -EFAULT;
/* if (copy_from_user(ctl_buf, msg_sys->msg_control_user, ctl_len))
* Careful! Before this, msg_sys->msg_control contains a user pointer.
* Afterwards, it will be a kernel pointer. Thus the compiler-assisted
* checking falls down on this.
*/
if (copy_from_user(ctl_buf,
(void __user __force *)msg_sys->msg_control,
ctl_len))
goto out_freectl; goto out_freectl;
msg_sys->msg_control = ctl_buf; msg_sys->msg_control = ctl_buf;
msg_sys->msg_control_is_user = false;
} }
msg_sys->msg_flags = flags; msg_sys->msg_flags = flags;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment