Commit 4257c8ca authored by Jens Axboe's avatar Jens Axboe

net: separate out the msghdr copy from ___sys_{send,recv}msg()

This is in preparation for enabling the io_uring helpers for sendmsg
and recvmsg to first copy the header for validation before continuing
with the operation.

There should be no functional changes in this patch.
Acked-by: default avatarDavid S. Miller <davem@davemloft.net>
Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 8042d6ce
...@@ -2264,15 +2264,10 @@ static int copy_msghdr_from_user(struct msghdr *kmsg, ...@@ -2264,15 +2264,10 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
return err < 0 ? err : 0; return err < 0 ? err : 0;
} }
static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
struct msghdr *msg_sys, unsigned int flags, unsigned int flags, struct used_address *used_address,
struct used_address *used_address,
unsigned int allowed_msghdr_flags) unsigned int allowed_msghdr_flags)
{ {
struct compat_msghdr __user *msg_compat =
(struct compat_msghdr __user *)msg;
struct sockaddr_storage address;
struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
unsigned char ctl[sizeof(struct cmsghdr) + 20] unsigned char ctl[sizeof(struct cmsghdr) + 20]
__aligned(sizeof(__kernel_size_t)); __aligned(sizeof(__kernel_size_t));
/* 20 is size of ipv6_pktinfo */ /* 20 is size of ipv6_pktinfo */
...@@ -2280,19 +2275,10 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2280,19 +2275,10 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
int ctl_len; int ctl_len;
ssize_t err; ssize_t err;
msg_sys->msg_name = &address;
if (MSG_CMSG_COMPAT & flags)
err = get_compat_msghdr(msg_sys, msg_compat, NULL, &iov);
else
err = copy_msghdr_from_user(msg_sys, msg, NULL, &iov);
if (err < 0)
return err;
err = -ENOBUFS; err = -ENOBUFS;
if (msg_sys->msg_controllen > INT_MAX) if (msg_sys->msg_controllen > INT_MAX)
goto out_freeiov; goto out;
flags |= (msg_sys->msg_flags & allowed_msghdr_flags); flags |= (msg_sys->msg_flags & allowed_msghdr_flags);
ctl_len = msg_sys->msg_controllen; ctl_len = msg_sys->msg_controllen;
if ((MSG_CMSG_COMPAT & flags) && ctl_len) { if ((MSG_CMSG_COMPAT & flags) && ctl_len) {
...@@ -2300,7 +2286,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2300,7 +2286,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
cmsghdr_from_user_compat_to_kern(msg_sys, sock->sk, ctl, cmsghdr_from_user_compat_to_kern(msg_sys, sock->sk, ctl,
sizeof(ctl)); sizeof(ctl));
if (err) if (err)
goto out_freeiov; goto out;
ctl_buf = msg_sys->msg_control; ctl_buf = msg_sys->msg_control;
ctl_len = msg_sys->msg_controllen; ctl_len = msg_sys->msg_controllen;
} else if (ctl_len) { } else if (ctl_len) {
...@@ -2309,7 +2295,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2309,7 +2295,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
if (ctl_len > sizeof(ctl)) { if (ctl_len > sizeof(ctl)) {
ctl_buf = sock_kmalloc(sock->sk, ctl_len, GFP_KERNEL); ctl_buf = sock_kmalloc(sock->sk, ctl_len, GFP_KERNEL);
if (ctl_buf == NULL) if (ctl_buf == NULL)
goto out_freeiov; goto out;
} }
err = -EFAULT; err = -EFAULT;
/* /*
...@@ -2355,7 +2341,47 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2355,7 +2341,47 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
out_freectl: out_freectl:
if (ctl_buf != ctl) if (ctl_buf != ctl)
sock_kfree_s(sock->sk, ctl_buf, ctl_len); sock_kfree_s(sock->sk, ctl_buf, ctl_len);
out_freeiov: out:
return err;
}
static int sendmsg_copy_msghdr(struct msghdr *msg,
struct user_msghdr __user *umsg, unsigned flags,
struct iovec **iov)
{
int err;
if (flags & MSG_CMSG_COMPAT) {
struct compat_msghdr __user *msg_compat;
msg_compat = (struct compat_msghdr __user *) umsg;
err = get_compat_msghdr(msg, msg_compat, NULL, iov);
} else {
err = copy_msghdr_from_user(msg, umsg, NULL, iov);
}
if (err < 0)
return err;
return 0;
}
static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
struct msghdr *msg_sys, unsigned int flags,
struct used_address *used_address,
unsigned int allowed_msghdr_flags)
{
struct sockaddr_storage address;
struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
ssize_t err;
msg_sys->msg_name = &address;
err = sendmsg_copy_msghdr(msg_sys, msg, flags, &iov);
if (err < 0)
return err;
err = ____sys_sendmsg(sock, msg_sys, flags, used_address,
allowed_msghdr_flags);
kfree(iov); kfree(iov);
return err; return err;
} }
...@@ -2474,33 +2500,41 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg, ...@@ -2474,33 +2500,41 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg,
return __sys_sendmmsg(fd, mmsg, vlen, flags, true); return __sys_sendmmsg(fd, mmsg, vlen, flags, true);
} }
static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, static int recvmsg_copy_msghdr(struct msghdr *msg,
struct msghdr *msg_sys, unsigned int flags, int nosec) struct user_msghdr __user *umsg, unsigned flags,
struct sockaddr __user **uaddr,
struct iovec **iov)
{ {
struct compat_msghdr __user *msg_compat =
(struct compat_msghdr __user *)msg;
struct iovec iovstack[UIO_FASTIOV];
struct iovec *iov = iovstack;
unsigned long cmsg_ptr;
int len;
ssize_t err; ssize_t err;
/* kernel mode address */ if (MSG_CMSG_COMPAT & flags) {
struct sockaddr_storage addr; struct compat_msghdr __user *msg_compat;
/* user mode address pointers */
struct sockaddr __user *uaddr;
int __user *uaddr_len = COMPAT_NAMELEN(msg);
msg_sys->msg_name = &addr;
if (MSG_CMSG_COMPAT & flags) msg_compat = (struct compat_msghdr __user *) umsg;
err = get_compat_msghdr(msg_sys, msg_compat, &uaddr, &iov); err = get_compat_msghdr(msg, msg_compat, uaddr, iov);
else } else {
err = copy_msghdr_from_user(msg_sys, msg, &uaddr, &iov); err = copy_msghdr_from_user(msg, umsg, uaddr, iov);
}
if (err < 0) if (err < 0)
return err; return err;
return 0;
}
static int ____sys_recvmsg(struct socket *sock, struct msghdr *msg_sys,
struct user_msghdr __user *msg,
struct sockaddr __user *uaddr,
unsigned int flags, int nosec)
{
struct compat_msghdr __user *msg_compat =
(struct compat_msghdr __user *) msg;
int __user *uaddr_len = COMPAT_NAMELEN(msg);
struct sockaddr_storage addr;
unsigned long cmsg_ptr;
int len;
ssize_t err;
msg_sys->msg_name = &addr;
cmsg_ptr = (unsigned long)msg_sys->msg_control; cmsg_ptr = (unsigned long)msg_sys->msg_control;
msg_sys->msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT); msg_sys->msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
...@@ -2511,7 +2545,7 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2511,7 +2545,7 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
flags |= MSG_DONTWAIT; flags |= MSG_DONTWAIT;
err = (nosec ? sock_recvmsg_nosec : sock_recvmsg)(sock, msg_sys, flags); err = (nosec ? sock_recvmsg_nosec : sock_recvmsg)(sock, msg_sys, flags);
if (err < 0) if (err < 0)
goto out_freeiov; goto out;
len = err; len = err;
if (uaddr != NULL) { if (uaddr != NULL) {
...@@ -2519,12 +2553,12 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2519,12 +2553,12 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
msg_sys->msg_namelen, uaddr, msg_sys->msg_namelen, uaddr,
uaddr_len); uaddr_len);
if (err < 0) if (err < 0)
goto out_freeiov; goto out;
} }
err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT), err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT),
COMPAT_FLAGS(msg)); COMPAT_FLAGS(msg));
if (err) if (err)
goto out_freeiov; goto out;
if (MSG_CMSG_COMPAT & flags) if (MSG_CMSG_COMPAT & flags)
err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr, err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
&msg_compat->msg_controllen); &msg_compat->msg_controllen);
...@@ -2532,10 +2566,25 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2532,10 +2566,25 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr, err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
&msg->msg_controllen); &msg->msg_controllen);
if (err) if (err)
goto out_freeiov; goto out;
err = len; err = len;
out:
return err;
}
static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
struct msghdr *msg_sys, unsigned int flags, int nosec)
{
struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
/* user mode address pointers */
struct sockaddr __user *uaddr;
ssize_t err;
err = recvmsg_copy_msghdr(msg_sys, msg, flags, &uaddr, &iov);
if (err < 0)
return err;
out_freeiov: err = ____sys_recvmsg(sock, msg_sys, msg, uaddr, flags, nosec);
kfree(iov); kfree(iov);
return err; return err;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment