Commit 0b20a133 authored by Alexei Starovoitov's avatar Alexei Starovoitov

Merge branch 'bpf: net: Remove duplicated code from bpf_getsockopt()'

Martin KaFai Lau says:

====================

From: Martin KaFai Lau <martin.lau@kernel.org>

The earlier commits [0] removed duplicated code from bpf_setsockopt().
This series is to remove duplicated code from bpf_getsockopt().

Unlike the setsockopt() which had already changed to take
the sockptr_t argument, the same has not been done to
getsockopt().  This is the extra step being done in this
series.

[0]: https://lore.kernel.org/all/20220817061704.4174272-1-kafai@fb.com/

v2:
- The previous v2 did not reach the list. It is a resend.
- Add comments on bpf_getsockopt() should not free
  the saved_syn (Stanislav)
- Explicitly null-terminate the tcp-cc name (Stanislav)
====================
Reviewed-by: default avatarStanislav Fomichev <sdf@google.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents af515a55 f649f992
......@@ -900,8 +900,7 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk);
int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk);
void sk_reuseport_prog_free(struct bpf_prog *prog);
int sk_detach_filter(struct sock *sk);
int sk_get_filter(struct sock *sk, struct sock_filter __user *filter,
unsigned int len);
int sk_get_filter(struct sock *sk, sockptr_t optval, unsigned int len);
bool sk_filter_charge(struct sock *sk, struct sk_filter *fp);
void sk_filter_uncharge(struct sock *sk, struct sk_filter *fp);
......
......@@ -118,9 +118,9 @@ extern int ip_mc_source(int add, int omode, struct sock *sk,
struct ip_mreq_source *mreqs, int ifindex);
extern int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf,int ifindex);
extern int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
struct ip_msfilter __user *optval, int __user *optlen);
sockptr_t optval, sockptr_t optlen);
extern int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage __user *p);
sockptr_t optval, size_t offset);
extern int ip_mc_sf_allow(struct sock *sk, __be32 local, __be32 rmt,
int dif, int sdif);
extern void ip_mc_init_dev(struct in_device *);
......
......@@ -17,7 +17,7 @@ static inline int ip_mroute_opt(int opt)
}
int ip_mroute_setsockopt(struct sock *, int, sockptr_t, unsigned int);
int ip_mroute_getsockopt(struct sock *, int, char __user *, int __user *);
int ip_mroute_getsockopt(struct sock *, int, sockptr_t, sockptr_t);
int ipmr_ioctl(struct sock *sk, int cmd, void __user *arg);
int ipmr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg);
int ip_mr_init(void);
......@@ -29,8 +29,8 @@ static inline int ip_mroute_setsockopt(struct sock *sock, int optname,
return -ENOPROTOOPT;
}
static inline int ip_mroute_getsockopt(struct sock *sock, int optname,
char __user *optval, int __user *optlen)
static inline int ip_mroute_getsockopt(struct sock *sk, int optname,
sockptr_t optval, sockptr_t optlen)
{
return -ENOPROTOOPT;
}
......
......@@ -27,7 +27,7 @@ struct sock;
#ifdef CONFIG_IPV6_MROUTE
extern int ip6_mroute_setsockopt(struct sock *, int, sockptr_t, unsigned int);
extern int ip6_mroute_getsockopt(struct sock *, int, char __user *, int __user *);
extern int ip6_mroute_getsockopt(struct sock *, int, sockptr_t, sockptr_t);
extern int ip6_mr_input(struct sk_buff *skb);
extern int ip6mr_ioctl(struct sock *sk, int cmd, void __user *arg);
extern int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg);
......@@ -42,7 +42,7 @@ static inline int ip6_mroute_setsockopt(struct sock *sock, int optname,
static inline
int ip6_mroute_getsockopt(struct sock *sock,
int optname, char __user *optval, int __user *optlen)
int optname, sockptr_t optval, sockptr_t optlen)
{
return -ENOPROTOOPT;
}
......
......@@ -64,6 +64,11 @@ static inline int copy_to_sockptr_offset(sockptr_t dst, size_t offset,
return 0;
}
static inline int copy_to_sockptr(sockptr_t dst, const void *src, size_t size)
{
return copy_to_sockptr_offset(dst, 0, src, size);
}
static inline void *memdup_sockptr(sockptr_t src, size_t len)
{
void *p = kmalloc_track_caller(len, GFP_USER | __GFP_NOWARN);
......
......@@ -747,6 +747,8 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
int ip_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
int do_ip_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen);
int ip_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
int __user *optlen);
int ip_ra_control(struct sock *sk, unsigned char on,
......
......@@ -1160,6 +1160,8 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval
unsigned int optlen);
int ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen);
int ipv6_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen);
......@@ -1209,7 +1211,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage *list);
int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage __user *p);
sockptr_t optval, size_t ss_offset);
#ifdef CONFIG_PROC_FS
int ac6_proc_init(struct net *net);
......
......@@ -83,6 +83,8 @@ struct ipv6_bpf_stub {
struct sk_buff *skb);
int (*ipv6_setsockopt)(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen);
int (*ipv6_getsockopt)(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen);
};
extern const struct ipv6_bpf_stub *ipv6_bpf_stub __read_mostly;
......
......@@ -1833,6 +1833,8 @@ int sk_setsockopt(struct sock *sk, int level, int optname,
int sock_setsockopt(struct socket *sock, int level, int op,
sockptr_t optval, unsigned int optlen);
int sk_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen);
int sock_getsockopt(struct socket *sock, int level, int op,
char __user *optval, int __user *optlen);
int sock_gettstamp(struct socket *sock, void __user *userstamp,
......
......@@ -402,6 +402,8 @@ void tcp_init_sock(struct sock *sk);
void tcp_init_transfer(struct sock *sk, int bpf_op, struct sk_buff *skb);
__poll_t tcp_poll(struct file *file, struct socket *sock,
struct poll_table_struct *wait);
int do_tcp_getsockopt(struct sock *sk, int level,
int optname, sockptr_t optval, sockptr_t optlen);
int tcp_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen);
bool tcp_bpf_bypass_getsockopt(int level, int optname);
......
......@@ -5017,8 +5017,9 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = {
.arg1_type = ARG_PTR_TO_CTX,
};
static int sol_socket_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
static int sol_socket_sockopt(struct sock *sk, int optname,
char *optval, int *optlen,
bool getopt)
{
switch (optname) {
case SO_REUSEADDR:
......@@ -5032,7 +5033,7 @@ static int sol_socket_setsockopt(struct sock *sk, int optname,
case SO_MAX_PACING_RATE:
case SO_BINDTOIFINDEX:
case SO_TXREHASH:
if (optlen != sizeof(int))
if (*optlen != sizeof(int))
return -EINVAL;
break;
case SO_BINDTODEVICE:
......@@ -5041,8 +5042,16 @@ static int sol_socket_setsockopt(struct sock *sk, int optname,
return -EINVAL;
}
if (getopt) {
if (optname == SO_BINDTODEVICE)
return -EINVAL;
return sk_getsockopt(sk, SOL_SOCKET, optname,
KERNEL_SOCKPTR(optval),
KERNEL_SOCKPTR(optlen));
}
return sk_setsockopt(sk, SOL_SOCKET, optname,
KERNEL_SOCKPTR(optval), optlen);
KERNEL_SOCKPTR(optval), *optlen);
}
static int bpf_sol_tcp_setsockopt(struct sock *sk, int optname,
......@@ -5091,8 +5100,9 @@ static int bpf_sol_tcp_setsockopt(struct sock *sk, int optname,
return 0;
}
static int sol_tcp_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
static int sol_tcp_sockopt(struct sock *sk, int optname,
char *optval, int *optlen,
bool getopt)
{
if (sk->sk_prot->setsockopt != tcp_setsockopt)
return -EINVAL;
......@@ -5109,40 +5119,81 @@ static int sol_tcp_setsockopt(struct sock *sk, int optname,
case TCP_USER_TIMEOUT:
case TCP_NOTSENT_LOWAT:
case TCP_SAVE_SYN:
if (optlen != sizeof(int))
if (*optlen != sizeof(int))
return -EINVAL;
break;
case TCP_CONGESTION:
if (*optlen < 2)
return -EINVAL;
break;
case TCP_SAVED_SYN:
if (*optlen < 1)
return -EINVAL;
break;
default:
return bpf_sol_tcp_setsockopt(sk, optname, optval, optlen);
if (getopt)
return -EINVAL;
return bpf_sol_tcp_setsockopt(sk, optname, optval, *optlen);
}
if (getopt) {
if (optname == TCP_SAVED_SYN) {
struct tcp_sock *tp = tcp_sk(sk);
if (!tp->saved_syn ||
*optlen > tcp_saved_syn_len(tp->saved_syn))
return -EINVAL;
memcpy(optval, tp->saved_syn->data, *optlen);
/* It cannot free tp->saved_syn here because it
* does not know if the user space still needs it.
*/
return 0;
}
if (optname == TCP_CONGESTION) {
if (!inet_csk(sk)->icsk_ca_ops)
return -EINVAL;
/* BPF expects NULL-terminated tcp-cc string */
optval[--(*optlen)] = '\0';
}
return do_tcp_getsockopt(sk, SOL_TCP, optname,
KERNEL_SOCKPTR(optval),
KERNEL_SOCKPTR(optlen));
}
return do_tcp_setsockopt(sk, SOL_TCP, optname,
KERNEL_SOCKPTR(optval), optlen);
KERNEL_SOCKPTR(optval), *optlen);
}
static int sol_ip_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
static int sol_ip_sockopt(struct sock *sk, int optname,
char *optval, int *optlen,
bool getopt)
{
if (sk->sk_family != AF_INET)
return -EINVAL;
switch (optname) {
case IP_TOS:
if (optlen != sizeof(int))
if (*optlen != sizeof(int))
return -EINVAL;
break;
default:
return -EINVAL;
}
if (getopt)
return do_ip_getsockopt(sk, SOL_IP, optname,
KERNEL_SOCKPTR(optval),
KERNEL_SOCKPTR(optlen));
return do_ip_setsockopt(sk, SOL_IP, optname,
KERNEL_SOCKPTR(optval), optlen);
KERNEL_SOCKPTR(optval), *optlen);
}
static int sol_ipv6_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
static int sol_ipv6_sockopt(struct sock *sk, int optname,
char *optval, int *optlen,
bool getopt)
{
if (sk->sk_family != AF_INET6)
return -EINVAL;
......@@ -5150,15 +5201,20 @@ static int sol_ipv6_setsockopt(struct sock *sk, int optname,
switch (optname) {
case IPV6_TCLASS:
case IPV6_AUTOFLOWLABEL:
if (optlen != sizeof(int))
if (*optlen != sizeof(int))
return -EINVAL;
break;
default:
return -EINVAL;
}
if (getopt)
return ipv6_bpf_stub->ipv6_getsockopt(sk, SOL_IPV6, optname,
KERNEL_SOCKPTR(optval),
KERNEL_SOCKPTR(optlen));
return ipv6_bpf_stub->ipv6_setsockopt(sk, SOL_IPV6, optname,
KERNEL_SOCKPTR(optval), optlen);
KERNEL_SOCKPTR(optval), *optlen);
}
static int __bpf_setsockopt(struct sock *sk, int level, int optname,
......@@ -5168,13 +5224,13 @@ static int __bpf_setsockopt(struct sock *sk, int level, int optname,
return -EINVAL;
if (level == SOL_SOCKET)
return sol_socket_setsockopt(sk, optname, optval, optlen);
return sol_socket_sockopt(sk, optname, optval, &optlen, false);
else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP)
return sol_ip_setsockopt(sk, optname, optval, optlen);
return sol_ip_sockopt(sk, optname, optval, &optlen, false);
else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6)
return sol_ipv6_setsockopt(sk, optname, optval, optlen);
return sol_ipv6_sockopt(sk, optname, optval, &optlen, false);
else if (IS_ENABLED(CONFIG_INET) && level == SOL_TCP)
return sol_tcp_setsockopt(sk, optname, optval, optlen);
return sol_tcp_sockopt(sk, optname, optval, &optlen, false);
return -EINVAL;
}
......@@ -5190,101 +5246,30 @@ static int _bpf_setsockopt(struct sock *sk, int level, int optname,
static int __bpf_getsockopt(struct sock *sk, int level, int optname,
char *optval, int optlen)
{
if (!sk_fullsock(sk))
goto err_clear;
int err, saved_optlen = optlen;
if (level == SOL_SOCKET) {
if (optlen != sizeof(int))
goto err_clear;
switch (optname) {
case SO_RCVBUF:
*((int *)optval) = sk->sk_rcvbuf;
break;
case SO_SNDBUF:
*((int *)optval) = sk->sk_sndbuf;
break;
case SO_MARK:
*((int *)optval) = sk->sk_mark;
break;
case SO_PRIORITY:
*((int *)optval) = sk->sk_priority;
break;
case SO_BINDTOIFINDEX:
*((int *)optval) = sk->sk_bound_dev_if;
break;
case SO_REUSEPORT:
*((int *)optval) = sk->sk_reuseport;
break;
case SO_TXREHASH:
*((int *)optval) = sk->sk_txrehash;
break;
default:
goto err_clear;
}
#ifdef CONFIG_INET
} else if (level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) {
struct inet_connection_sock *icsk;
struct tcp_sock *tp;
switch (optname) {
case TCP_CONGESTION:
icsk = inet_csk(sk);
if (!icsk->icsk_ca_ops || optlen <= 1)
goto err_clear;
strncpy(optval, icsk->icsk_ca_ops->name, optlen);
optval[optlen - 1] = 0;
break;
case TCP_SAVED_SYN:
tp = tcp_sk(sk);
if (optlen <= 0 || !tp->saved_syn ||
optlen > tcp_saved_syn_len(tp->saved_syn))
goto err_clear;
memcpy(optval, tp->saved_syn->data, optlen);
break;
default:
goto err_clear;
}
} else if (level == SOL_IP) {
struct inet_sock *inet = inet_sk(sk);
if (optlen != sizeof(int) || sk->sk_family != AF_INET)
goto err_clear;
/* Only some options are supported */
switch (optname) {
case IP_TOS:
*((int *)optval) = (int)inet->tos;
break;
default:
goto err_clear;
}
#if IS_ENABLED(CONFIG_IPV6)
} else if (level == SOL_IPV6) {
struct ipv6_pinfo *np = inet6_sk(sk);
if (!sk_fullsock(sk)) {
err = -EINVAL;
goto done;
}
if (optlen != sizeof(int) || sk->sk_family != AF_INET6)
goto err_clear;
if (level == SOL_SOCKET)
err = sol_socket_sockopt(sk, optname, optval, &optlen, true);
else if (IS_ENABLED(CONFIG_INET) && level == SOL_TCP)
err = sol_tcp_sockopt(sk, optname, optval, &optlen, true);
else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP)
err = sol_ip_sockopt(sk, optname, optval, &optlen, true);
else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6)
err = sol_ipv6_sockopt(sk, optname, optval, &optlen, true);
else
err = -EINVAL;
/* Only some options are supported */
switch (optname) {
case IPV6_TCLASS:
*((int *)optval) = (int)np->tclass;
break;
default:
goto err_clear;
}
#endif
#endif
} else {
goto err_clear;
}
return 0;
err_clear:
memset(optval, 0, optlen);
return -EINVAL;
done:
if (err)
optlen = 0;
if (optlen < saved_optlen)
memset(optval + optlen, 0, saved_optlen - optlen);
return err;
}
static int _bpf_getsockopt(struct sock *sk, int level, int optname,
......@@ -10716,14 +10701,13 @@ int sk_detach_filter(struct sock *sk)
}
EXPORT_SYMBOL_GPL(sk_detach_filter);
int sk_get_filter(struct sock *sk, struct sock_filter __user *ubuf,
unsigned int len)
int sk_get_filter(struct sock *sk, sockptr_t optval, unsigned int len)
{
struct sock_fprog_kern *fprog;
struct sk_filter *filter;
int ret = 0;
lock_sock(sk);
sockopt_lock_sock(sk);
filter = rcu_dereference_protected(sk->sk_filter,
lockdep_sock_is_held(sk));
if (!filter)
......@@ -10748,7 +10732,7 @@ int sk_get_filter(struct sock *sk, struct sock_filter __user *ubuf,
goto out;
ret = -EFAULT;
if (copy_to_user(ubuf, fprog->filter, bpf_classic_proglen(fprog)))
if (copy_to_sockptr(optval, fprog->filter, bpf_classic_proglen(fprog)))
goto out;
/* Instead of bytes, the API requests to return the number
......@@ -10756,7 +10740,7 @@ int sk_get_filter(struct sock *sk, struct sock_filter __user *ubuf,
*/
ret = fprog->len;
out:
release_sock(sk);
sockopt_release_sock(sk);
return ret;
}
......
......@@ -712,8 +712,8 @@ static int sock_setbindtodevice(struct sock *sk, sockptr_t optval, int optlen)
return ret;
}
static int sock_getbindtodevice(struct sock *sk, char __user *optval,
int __user *optlen, int len)
static int sock_getbindtodevice(struct sock *sk, sockptr_t optval,
sockptr_t optlen, int len)
{
int ret = -ENOPROTOOPT;
#ifdef CONFIG_NETDEVICES
......@@ -737,12 +737,12 @@ static int sock_getbindtodevice(struct sock *sk, char __user *optval,
len = strlen(devname) + 1;
ret = -EFAULT;
if (copy_to_user(optval, devname, len))
if (copy_to_sockptr(optval, devname, len))
goto out;
zero:
ret = -EFAULT;
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
goto out;
ret = 0;
......@@ -1568,22 +1568,25 @@ static void cred_to_ucred(struct pid *pid, const struct cred *cred,
}
}
static int groups_to_user(gid_t __user *dst, const struct group_info *src)
static int groups_to_user(sockptr_t dst, const struct group_info *src)
{
struct user_namespace *user_ns = current_user_ns();
int i;
for (i = 0; i < src->ngroups; i++)
if (put_user(from_kgid_munged(user_ns, src->gid[i]), dst + i))
for (i = 0; i < src->ngroups; i++) {
gid_t gid = from_kgid_munged(user_ns, src->gid[i]);
if (copy_to_sockptr_offset(dst, i * sizeof(gid), &gid, sizeof(gid)))
return -EFAULT;
}
return 0;
}
int sock_getsockopt(struct socket *sock, int level, int optname,
char __user *optval, int __user *optlen)
int sk_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen)
{
struct sock *sk = sock->sk;
struct socket *sock = sk->sk_socket;
union {
int val;
......@@ -1600,7 +1603,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
int lv = sizeof(int);
int len;
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
if (len < 0)
return -EINVAL;
......@@ -1735,7 +1738,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
cred_to_ucred(sk->sk_peer_pid, sk->sk_peer_cred, &peercred);
spin_unlock(&sk->sk_peer_lock);
if (copy_to_user(optval, &peercred, len))
if (copy_to_sockptr(optval, &peercred, len))
return -EFAULT;
goto lenout;
}
......@@ -1753,11 +1756,11 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
if (len < n * sizeof(gid_t)) {
len = n * sizeof(gid_t);
put_cred(cred);
return put_user(len, optlen) ? -EFAULT : -ERANGE;
return copy_to_sockptr(optlen, &len, sizeof(int)) ? -EFAULT : -ERANGE;
}
len = n * sizeof(gid_t);
ret = groups_to_user((gid_t __user *)optval, cred->group_info);
ret = groups_to_user(optval, cred->group_info);
put_cred(cred);
if (ret)
return ret;
......@@ -1773,7 +1776,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
return -ENOTCONN;
if (lv < len)
return -EINVAL;
if (copy_to_user(optval, address, len))
if (copy_to_sockptr(optval, address, len))
return -EFAULT;
goto lenout;
}
......@@ -1790,7 +1793,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
break;
case SO_PEERSEC:
return security_socket_getpeersec_stream(sock, optval, optlen, len);
return security_socket_getpeersec_stream(sock, optval.user, optlen.user, len);
case SO_MARK:
v.val = sk->sk_mark;
......@@ -1822,7 +1825,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
return sock_getbindtodevice(sk, optval, optlen, len);
case SO_GET_FILTER:
len = sk_get_filter(sk, (struct sock_filter __user *)optval, len);
len = sk_get_filter(sk, optval, len);
if (len < 0)
return len;
......@@ -1870,7 +1873,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
sk_get_meminfo(sk, meminfo);
len = min_t(unsigned int, len, sizeof(meminfo));
if (copy_to_user(optval, &meminfo, len))
if (copy_to_sockptr(optval, &meminfo, len))
return -EFAULT;
goto lenout;
......@@ -1939,14 +1942,22 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
if (len > lv)
len = lv;
if (copy_to_user(optval, &v, len))
if (copy_to_sockptr(optval, &v, len))
return -EFAULT;
lenout:
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
return 0;
}
int sock_getsockopt(struct socket *sock, int level, int optname,
char __user *optval, int __user *optlen)
{
return sk_getsockopt(sock->sk, level, optname,
USER_SOCKPTR(optval),
USER_SOCKPTR(optlen));
}
/*
* Initialize an sk_lock.
*
......
......@@ -2529,11 +2529,10 @@ int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
err = ip_mc_leave_group(sk, &imr);
return err;
}
int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
struct ip_msfilter __user *optval, int __user *optlen)
sockptr_t optval, sockptr_t optlen)
{
int err, len, count, copycount;
int err, len, count, copycount, msf_size;
struct ip_mreqn imr;
__be32 addr = msf->imsf_multiaddr;
struct ip_mc_socklist *pmc;
......@@ -2575,12 +2574,15 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
copycount = count < msf->imsf_numsrc ? count : msf->imsf_numsrc;
len = flex_array_size(psl, sl_addr, copycount);
msf->imsf_numsrc = count;
if (put_user(IP_MSFILTER_SIZE(copycount), optlen) ||
copy_to_user(optval, msf, IP_MSFILTER_SIZE(0))) {
msf_size = IP_MSFILTER_SIZE(copycount);
if (copy_to_sockptr(optlen, &msf_size, sizeof(int)) ||
copy_to_sockptr(optval, msf, IP_MSFILTER_SIZE(0))) {
return -EFAULT;
}
if (len &&
copy_to_user(&optval->imsf_slist_flex[0], psl->sl_addr, len))
copy_to_sockptr_offset(optval,
offsetof(struct ip_msfilter, imsf_slist_flex),
psl->sl_addr, len))
return -EFAULT;
return 0;
done:
......@@ -2588,7 +2590,7 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
}
int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage __user *p)
sockptr_t optval, size_t ss_offset)
{
int i, count, copycount;
struct sockaddr_in *psin;
......@@ -2618,15 +2620,17 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
count = psl ? psl->sl_count : 0;
copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc;
gsf->gf_numsrc = count;
for (i = 0; i < copycount; i++, p++) {
for (i = 0; i < copycount; i++) {
struct sockaddr_storage ss;
psin = (struct sockaddr_in *)&ss;
memset(&ss, 0, sizeof(ss));
psin->sin_family = AF_INET;
psin->sin_addr.s_addr = psl->sl_addr[i];
if (copy_to_user(p, &ss, sizeof(ss)))
if (copy_to_sockptr_offset(optval, ss_offset,
&ss, sizeof(ss)))
return -EFAULT;
ss_offset += sizeof(ss);
}
return 0;
}
......
......@@ -1462,37 +1462,37 @@ static bool getsockopt_needs_rtnl(int optname)
return false;
}
static int ip_get_mcast_msfilter(struct sock *sk, void __user *optval,
int __user *optlen, int len)
static int ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval,
sockptr_t optlen, int len)
{
const int size0 = offsetof(struct group_filter, gf_slist_flex);
struct group_filter __user *p = optval;
struct group_filter gsf;
int num;
int num, gsf_size;
int err;
if (len < size0)
return -EINVAL;
if (copy_from_user(&gsf, p, size0))
if (copy_from_sockptr(&gsf, optval, size0))
return -EFAULT;
num = gsf.gf_numsrc;
err = ip_mc_gsfget(sk, &gsf, p->gf_slist_flex);
err = ip_mc_gsfget(sk, &gsf, optval,
offsetof(struct group_filter, gf_slist_flex));
if (err)
return err;
if (gsf.gf_numsrc < num)
num = gsf.gf_numsrc;
if (put_user(GROUP_FILTER_SIZE(num), optlen) ||
copy_to_user(p, &gsf, size0))
gsf_size = GROUP_FILTER_SIZE(num);
if (copy_to_sockptr(optlen, &gsf_size, sizeof(int)) ||
copy_to_sockptr(optval, &gsf, size0))
return -EFAULT;
return 0;
}
static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval,
int __user *optlen, int len)
static int compat_ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval,
sockptr_t optlen, int len)
{
const int size0 = offsetof(struct compat_group_filter, gf_slist_flex);
struct compat_group_filter __user *p = optval;
struct compat_group_filter gf32;
struct group_filter gf;
int num;
......@@ -1500,7 +1500,7 @@ static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval,
if (len < size0)
return -EINVAL;
if (copy_from_user(&gf32, p, size0))
if (copy_from_sockptr(&gf32, optval, size0))
return -EFAULT;
gf.gf_interface = gf32.gf_interface;
......@@ -1508,21 +1508,24 @@ static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval,
num = gf.gf_numsrc = gf32.gf_numsrc;
gf.gf_group = gf32.gf_group;
err = ip_mc_gsfget(sk, &gf, p->gf_slist_flex);
err = ip_mc_gsfget(sk, &gf, optval,
offsetof(struct compat_group_filter, gf_slist_flex));
if (err)
return err;
if (gf.gf_numsrc < num)
num = gf.gf_numsrc;
len = GROUP_FILTER_SIZE(num) - (sizeof(gf) - sizeof(gf32));
if (put_user(len, optlen) ||
put_user(gf.gf_fmode, &p->gf_fmode) ||
put_user(gf.gf_numsrc, &p->gf_numsrc))
if (copy_to_sockptr(optlen, &len, sizeof(int)) ||
copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_fmode),
&gf.gf_fmode, sizeof(gf.gf_fmode)) ||
copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_numsrc),
&gf.gf_numsrc, sizeof(gf.gf_numsrc)))
return -EFAULT;
return 0;
}
static int do_ip_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen)
int do_ip_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen)
{
struct inet_sock *inet = inet_sk(sk);
bool needs_rtnl = getsockopt_needs_rtnl(optname);
......@@ -1535,14 +1538,14 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
if (ip_mroute_opt(optname))
return ip_mroute_getsockopt(sk, optname, optval, optlen);
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
if (len < 0)
return -EINVAL;
if (needs_rtnl)
rtnl_lock();
lock_sock(sk);
sockopt_lock_sock(sk);
switch (optname) {
case IP_OPTIONS:
......@@ -1558,17 +1561,19 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
memcpy(optbuf, &inet_opt->opt,
sizeof(struct ip_options) +
inet_opt->opt.optlen);
release_sock(sk);
sockopt_release_sock(sk);
if (opt->optlen == 0)
return put_user(0, optlen);
if (opt->optlen == 0) {
len = 0;
return copy_to_sockptr(optlen, &len, sizeof(int));
}
ip_options_undo(opt);
len = min_t(unsigned int, len, opt->optlen);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, opt->__data, len))
if (copy_to_sockptr(optval, opt->__data, len))
return -EFAULT;
return 0;
}
......@@ -1632,7 +1637,7 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
dst_release(dst);
}
if (!val) {
release_sock(sk);
sockopt_release_sock(sk);
return -ENOTCONN;
}
break;
......@@ -1657,11 +1662,11 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
struct in_addr addr;
len = min_t(unsigned int, len, sizeof(struct in_addr));
addr.s_addr = inet->mc_addr;
release_sock(sk);
sockopt_release_sock(sk);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &addr, len))
if (copy_to_sockptr(optval, &addr, len))
return -EFAULT;
return 0;
}
......@@ -1673,12 +1678,11 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
err = -EINVAL;
goto out;
}
if (copy_from_user(&msf, optval, IP_MSFILTER_SIZE(0))) {
if (copy_from_sockptr(&msf, optval, IP_MSFILTER_SIZE(0))) {
err = -EFAULT;
goto out;
}
err = ip_mc_msfget(sk, &msf,
(struct ip_msfilter __user *)optval, optlen);
err = ip_mc_msfget(sk, &msf, optval, optlen);
goto out;
}
case MCAST_MSFILTER:
......@@ -1695,13 +1699,18 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
{
struct msghdr msg;
release_sock(sk);
sockopt_release_sock(sk);
if (sk->sk_type != SOCK_STREAM)
return -ENOPROTOOPT;
msg.msg_control_is_user = true;
msg.msg_control_user = optval;
if (optval.is_kernel) {
msg.msg_control_is_user = false;
msg.msg_control = optval.kernel;
} else {
msg.msg_control_is_user = true;
msg.msg_control_user = optval.user;
}
msg.msg_controllen = len;
msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0;
......@@ -1722,7 +1731,7 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos);
}
len -= msg.msg_controllen;
return put_user(len, optlen);
return copy_to_sockptr(optlen, &len, sizeof(int));
}
case IP_FREEBIND:
val = inet->freebind;
......@@ -1734,29 +1743,29 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
val = inet->min_ttl;
break;
default:
release_sock(sk);
sockopt_release_sock(sk);
return -ENOPROTOOPT;
}
release_sock(sk);
sockopt_release_sock(sk);
if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) {
unsigned char ucval = (unsigned char)val;
len = 1;
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &ucval, 1))
if (copy_to_sockptr(optval, &ucval, 1))
return -EFAULT;
} else {
len = min_t(unsigned int, sizeof(int), len);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &val, len))
if (copy_to_sockptr(optval, &val, len))
return -EFAULT;
}
return 0;
out:
release_sock(sk);
sockopt_release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return err;
......@@ -1767,7 +1776,8 @@ int ip_getsockopt(struct sock *sk, int level,
{
int err;
err = do_ip_getsockopt(sk, level, optname, optval, optlen);
err = do_ip_getsockopt(sk, level, optname,
USER_SOCKPTR(optval), USER_SOCKPTR(optlen));
#if IS_ENABLED(CONFIG_BPFILTER_UMH)
if (optname >= BPFILTER_IPT_SO_GET_INFO &&
......
......@@ -1546,7 +1546,8 @@ int ip_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval,
}
/* Getsock opt support for the multicast routing system. */
int ip_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, int __user *optlen)
int ip_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval,
sockptr_t optlen)
{
int olr;
int val;
......@@ -1577,14 +1578,14 @@ int ip_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, int
return -ENOPROTOOPT;
}
if (get_user(olr, optlen))
if (copy_from_sockptr(&olr, optlen, sizeof(int)))
return -EFAULT;
olr = min_t(unsigned int, olr, sizeof(int));
if (olr < 0)
return -EINVAL;
if (put_user(olr, optlen))
if (copy_to_sockptr(optlen, &olr, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &val, olr))
if (copy_to_sockptr(optval, &val, olr))
return -EFAULT;
return 0;
}
......
......@@ -4043,15 +4043,15 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk,
return stats;
}
static int do_tcp_getsockopt(struct sock *sk, int level,
int optname, char __user *optval, int __user *optlen)
int do_tcp_getsockopt(struct sock *sk, int level,
int optname, sockptr_t optval, sockptr_t optlen)
{
struct inet_connection_sock *icsk = inet_csk(sk);
struct tcp_sock *tp = tcp_sk(sk);
struct net *net = sock_net(sk);
int val, len;
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
len = min_t(unsigned int, len, sizeof(int));
......@@ -4101,15 +4101,15 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
case TCP_INFO: {
struct tcp_info info;
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
tcp_get_info(sk, &info);
len = min_t(unsigned int, len, sizeof(info));
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &info, len))
if (copy_to_sockptr(optval, &info, len))
return -EFAULT;
return 0;
}
......@@ -4119,7 +4119,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
size_t sz = 0;
int attr;
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
ca_ops = icsk->icsk_ca_ops;
......@@ -4127,9 +4127,9 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
sz = ca_ops->get_info(sk, ~0U, &attr, &info);
len = min_t(unsigned int, len, sz);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &info, len))
if (copy_to_sockptr(optval, &info, len))
return -EFAULT;
return 0;
}
......@@ -4138,27 +4138,28 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
break;
case TCP_CONGESTION:
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
len = min_t(unsigned int, len, TCP_CA_NAME_MAX);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, icsk->icsk_ca_ops->name, len))
if (copy_to_sockptr(optval, icsk->icsk_ca_ops->name, len))
return -EFAULT;
return 0;
case TCP_ULP:
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
len = min_t(unsigned int, len, TCP_ULP_NAME_MAX);
if (!icsk->icsk_ulp_ops) {
if (put_user(0, optlen))
len = 0;
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
return 0;
}
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, icsk->icsk_ulp_ops->name, len))
if (copy_to_sockptr(optval, icsk->icsk_ulp_ops->name, len))
return -EFAULT;
return 0;
......@@ -4166,15 +4167,15 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
u64 key[TCP_FASTOPEN_KEY_BUF_LENGTH / sizeof(u64)];
unsigned int key_len;
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
key_len = tcp_fastopen_get_cipher(net, icsk, key) *
TCP_FASTOPEN_KEY_LENGTH;
len = min_t(unsigned int, len, key_len);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, key, len))
if (copy_to_sockptr(optval, key, len))
return -EFAULT;
return 0;
}
......@@ -4200,7 +4201,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
case TCP_REPAIR_WINDOW: {
struct tcp_repair_window opt;
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
if (len != sizeof(opt))
......@@ -4215,7 +4216,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
opt.rcv_wnd = tp->rcv_wnd;
opt.rcv_wup = tp->rcv_wup;
if (copy_to_user(optval, &opt, len))
if (copy_to_sockptr(optval, &opt, len))
return -EFAULT;
return 0;
}
......@@ -4261,35 +4262,35 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
val = tp->save_syn;
break;
case TCP_SAVED_SYN: {
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
lock_sock(sk);
sockopt_lock_sock(sk);
if (tp->saved_syn) {
if (len < tcp_saved_syn_len(tp->saved_syn)) {
if (put_user(tcp_saved_syn_len(tp->saved_syn),
optlen)) {
release_sock(sk);
len = tcp_saved_syn_len(tp->saved_syn);
if (copy_to_sockptr(optlen, &len, sizeof(int))) {
sockopt_release_sock(sk);
return -EFAULT;
}
release_sock(sk);
sockopt_release_sock(sk);
return -EINVAL;
}
len = tcp_saved_syn_len(tp->saved_syn);
if (put_user(len, optlen)) {
release_sock(sk);
if (copy_to_sockptr(optlen, &len, sizeof(int))) {
sockopt_release_sock(sk);
return -EFAULT;
}
if (copy_to_user(optval, tp->saved_syn->data, len)) {
release_sock(sk);
if (copy_to_sockptr(optval, tp->saved_syn->data, len)) {
sockopt_release_sock(sk);
return -EFAULT;
}
tcp_saved_syn_free(tp);
release_sock(sk);
sockopt_release_sock(sk);
} else {
release_sock(sk);
sockopt_release_sock(sk);
len = 0;
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
}
return 0;
......@@ -4300,31 +4301,31 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
struct tcp_zerocopy_receive zc = {};
int err;
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
if (len < 0 ||
len < offsetofend(struct tcp_zerocopy_receive, length))
return -EINVAL;
if (unlikely(len > sizeof(zc))) {
err = check_zeroed_user(optval + sizeof(zc),
len - sizeof(zc));
err = check_zeroed_sockptr(optval, sizeof(zc),
len - sizeof(zc));
if (err < 1)
return err == 0 ? -EINVAL : err;
len = sizeof(zc);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
}
if (copy_from_user(&zc, optval, len))
if (copy_from_sockptr(&zc, optval, len))
return -EFAULT;
if (zc.reserved)
return -EINVAL;
if (zc.msg_flags & ~(TCP_VALID_ZC_MSG_FLAGS))
return -EINVAL;
lock_sock(sk);
sockopt_lock_sock(sk);
err = tcp_zerocopy_receive(sk, &zc, &tss);
err = BPF_CGROUP_RUN_PROG_GETSOCKOPT_KERN(sk, level, optname,
&zc, &len, err);
release_sock(sk);
sockopt_release_sock(sk);
if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags))
goto zerocopy_rcv_cmsg;
switch (len) {
......@@ -4354,7 +4355,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
zerocopy_rcv_inq:
zc.inq = tcp_inq_hint(sk);
zerocopy_rcv_out:
if (!err && copy_to_user(optval, &zc, len))
if (!err && copy_to_sockptr(optval, &zc, len))
err = -EFAULT;
return err;
}
......@@ -4363,9 +4364,9 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
return -ENOPROTOOPT;
}
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &val, len))
if (copy_to_sockptr(optval, &val, len))
return -EFAULT;
return 0;
}
......@@ -4390,7 +4391,8 @@ int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
if (level != SOL_TCP)
return icsk->icsk_af_ops->getsockopt(sk, level, optname,
optval, optlen);
return do_tcp_getsockopt(sk, level, optname, optval, optlen);
return do_tcp_getsockopt(sk, level, optname, USER_SOCKPTR(optval),
USER_SOCKPTR(optlen));
}
EXPORT_SYMBOL(tcp_getsockopt);
......
......@@ -1058,6 +1058,7 @@ static const struct ipv6_bpf_stub ipv6_bpf_stub_impl = {
.inet6_bind = __inet6_bind,
.udp6_lib_lookup = __udp6_lib_lookup,
.ipv6_setsockopt = do_ipv6_setsockopt,
.ipv6_getsockopt = do_ipv6_getsockopt,
};
static int __init inet6_init(void)
......
......@@ -1827,8 +1827,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval,
* Getsock opt support for the multicast routing system.
*/
int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval,
int __user *optlen)
int ip6_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval,
sockptr_t optlen)
{
int olr;
int val;
......@@ -1859,16 +1859,16 @@ int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval,
return -ENOPROTOOPT;
}
if (get_user(olr, optlen))
if (copy_from_sockptr(&olr, optlen, sizeof(int)))
return -EFAULT;
olr = min_t(int, olr, sizeof(int));
if (olr < 0)
return -EINVAL;
if (put_user(olr, optlen))
if (copy_to_sockptr(optlen, &olr, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &val, olr))
if (copy_to_sockptr(optval, &val, olr))
return -EFAULT;
return 0;
}
......
......@@ -1030,7 +1030,7 @@ int ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
EXPORT_SYMBOL(ipv6_setsockopt);
static int ipv6_getsockopt_sticky(struct sock *sk, struct ipv6_txoptions *opt,
int optname, char __user *optval, int len)
int optname, sockptr_t optval, int len)
{
struct ipv6_opt_hdr *hdr;
......@@ -1058,56 +1058,53 @@ static int ipv6_getsockopt_sticky(struct sock *sk, struct ipv6_txoptions *opt,
return 0;
len = min_t(unsigned int, len, ipv6_optlen(hdr));
if (copy_to_user(optval, hdr, len))
if (copy_to_sockptr(optval, hdr, len))
return -EFAULT;
return len;
}
static int ipv6_get_msfilter(struct sock *sk, void __user *optval,
int __user *optlen, int len)
static int ipv6_get_msfilter(struct sock *sk, sockptr_t optval,
sockptr_t optlen, int len)
{
const int size0 = offsetof(struct group_filter, gf_slist_flex);
struct group_filter __user *p = optval;
struct group_filter gsf;
int num;
int err;
if (len < size0)
return -EINVAL;
if (copy_from_user(&gsf, p, size0))
if (copy_from_sockptr(&gsf, optval, size0))
return -EFAULT;
if (gsf.gf_group.ss_family != AF_INET6)
return -EADDRNOTAVAIL;
num = gsf.gf_numsrc;
lock_sock(sk);
err = ip6_mc_msfget(sk, &gsf, p->gf_slist_flex);
sockopt_lock_sock(sk);
err = ip6_mc_msfget(sk, &gsf, optval, size0);
if (!err) {
if (num > gsf.gf_numsrc)
num = gsf.gf_numsrc;
if (put_user(GROUP_FILTER_SIZE(num), optlen) ||
copy_to_user(p, &gsf, size0))
len = GROUP_FILTER_SIZE(num);
if (copy_to_sockptr(optlen, &len, sizeof(int)) ||
copy_to_sockptr(optval, &gsf, size0))
err = -EFAULT;
}
release_sock(sk);
sockopt_release_sock(sk);
return err;
}
static int compat_ipv6_get_msfilter(struct sock *sk, void __user *optval,
int __user *optlen)
static int compat_ipv6_get_msfilter(struct sock *sk, sockptr_t optval,
sockptr_t optlen, int len)
{
const int size0 = offsetof(struct compat_group_filter, gf_slist_flex);
struct compat_group_filter __user *p = optval;
struct compat_group_filter gf32;
struct group_filter gf;
int len, err;
int err;
int num;
if (get_user(len, optlen))
return -EFAULT;
if (len < size0)
return -EINVAL;
if (copy_from_user(&gf32, p, size0))
if (copy_from_sockptr(&gf32, optval, size0))
return -EFAULT;
gf.gf_interface = gf32.gf_interface;
gf.gf_fmode = gf32.gf_fmode;
......@@ -1117,23 +1114,25 @@ static int compat_ipv6_get_msfilter(struct sock *sk, void __user *optval,
if (gf.gf_group.ss_family != AF_INET6)
return -EADDRNOTAVAIL;
lock_sock(sk);
err = ip6_mc_msfget(sk, &gf, p->gf_slist_flex);
release_sock(sk);
sockopt_lock_sock(sk);
err = ip6_mc_msfget(sk, &gf, optval, size0);
sockopt_release_sock(sk);
if (err)
return err;
if (num > gf.gf_numsrc)
num = gf.gf_numsrc;
len = GROUP_FILTER_SIZE(num) - (sizeof(gf)-sizeof(gf32));
if (put_user(len, optlen) ||
put_user(gf.gf_fmode, &p->gf_fmode) ||
put_user(gf.gf_numsrc, &p->gf_numsrc))
if (copy_to_sockptr(optlen, &len, sizeof(int)) ||
copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_fmode),
&gf.gf_fmode, sizeof(gf32.gf_fmode)) ||
copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_numsrc),
&gf.gf_numsrc, sizeof(gf32.gf_numsrc)))
return -EFAULT;
return 0;
}
static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen, unsigned int flags)
int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen)
{
struct ipv6_pinfo *np = inet6_sk(sk);
int len;
......@@ -1142,7 +1141,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (ip6_mroute_opt(optname))
return ip6_mroute_getsockopt(sk, optname, optval, optlen);
if (get_user(len, optlen))
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
switch (optname) {
case IPV6_ADDRFORM:
......@@ -1156,7 +1155,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
break;
case MCAST_MSFILTER:
if (in_compat_syscall())
return compat_ipv6_get_msfilter(sk, optval, optlen);
return compat_ipv6_get_msfilter(sk, optval, optlen, len);
return ipv6_get_msfilter(sk, optval, optlen, len);
case IPV6_2292PKTOPTIONS:
{
......@@ -1166,16 +1165,21 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (sk->sk_type != SOCK_STREAM)
return -ENOPROTOOPT;
msg.msg_control_user = optval;
if (optval.is_kernel) {
msg.msg_control_is_user = false;
msg.msg_control = optval.kernel;
} else {
msg.msg_control_is_user = true;
msg.msg_control_user = optval.user;
}
msg.msg_controllen = len;
msg.msg_flags = flags;
msg.msg_control_is_user = true;
msg.msg_flags = 0;
lock_sock(sk);
sockopt_lock_sock(sk);
skb = np->pktoptions;
if (skb)
ip6_datagram_recv_ctl(sk, &msg, skb);
release_sock(sk);
sockopt_release_sock(sk);
if (!skb) {
if (np->rxopt.bits.rxinfo) {
struct in6_pktinfo src_info;
......@@ -1212,7 +1216,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
}
}
len -= msg.msg_controllen;
return put_user(len, optlen);
return copy_to_sockptr(optlen, &len, sizeof(int));
}
case IPV6_MTU:
{
......@@ -1264,15 +1268,15 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
{
struct ipv6_txoptions *opt;
lock_sock(sk);
sockopt_lock_sock(sk);
opt = rcu_dereference_protected(np->opt,
lockdep_sock_is_held(sk));
len = ipv6_getsockopt_sticky(sk, opt, optname, optval, len);
release_sock(sk);
sockopt_release_sock(sk);
/* check if ipv6_getsockopt_sticky() returns err code */
if (len < 0)
return len;
return put_user(len, optlen);
return copy_to_sockptr(optlen, &len, sizeof(int));
}
case IPV6_RECVHOPOPTS:
......@@ -1326,9 +1330,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (!mtuinfo.ip6m_mtu)
return -ENOTCONN;
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &mtuinfo, len))
if (copy_to_sockptr(optval, &mtuinfo, len))
return -EFAULT;
return 0;
......@@ -1405,7 +1409,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (len < sizeof(freq))
return -EINVAL;
if (copy_from_user(&freq, optval, sizeof(freq)))
if (copy_from_sockptr(&freq, optval, sizeof(freq)))
return -EFAULT;
if (freq.flr_action != IPV6_FL_A_GET)
......@@ -1420,9 +1424,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (val < 0)
return val;
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &freq, len))
if (copy_to_sockptr(optval, &freq, len))
return -EFAULT;
return 0;
......@@ -1474,9 +1478,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
return -ENOPROTOOPT;
}
len = min_t(unsigned int, sizeof(int), len);
if (put_user(len, optlen))
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_user(optval, &val, len))
if (copy_to_sockptr(optval, &val, len))
return -EFAULT;
return 0;
}
......@@ -1492,7 +1496,8 @@ int ipv6_getsockopt(struct sock *sk, int level, int optname,
if (level != SOL_IPV6)
return -ENOPROTOOPT;
err = do_ipv6_getsockopt(sk, level, optname, optval, optlen, 0);
err = do_ipv6_getsockopt(sk, level, optname,
USER_SOCKPTR(optval), USER_SOCKPTR(optlen));
#ifdef CONFIG_NETFILTER
/* we need to exclude all possible ENOPROTOOPTs except default case */
if (err == -ENOPROTOOPT && optname != IPV6_2292PKTOPTIONS) {
......
......@@ -580,7 +580,7 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
}
int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage __user *p)
sockptr_t optval, size_t ss_offset)
{
struct ipv6_pinfo *inet6 = inet6_sk(sk);
const struct in6_addr *group;
......@@ -612,8 +612,7 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc;
gsf->gf_numsrc = count;
for (i = 0; i < copycount; i++, p++) {
for (i = 0; i < copycount; i++) {
struct sockaddr_in6 *psin6;
struct sockaddr_storage ss;
......@@ -621,8 +620,9 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
memset(&ss, 0, sizeof(ss));
psin6->sin6_family = AF_INET6;
psin6->sin6_addr = psl->sl_addr[i];
if (copy_to_user(p, &ss, sizeof(ss)))
if (copy_to_sockptr_offset(optval, ss_offset, &ss, sizeof(ss)))
return -EFAULT;
ss_offset += sizeof(ss);
}
return 0;
}
......
......@@ -38,6 +38,7 @@
#define TCP_USER_TIMEOUT 18
#define TCP_NOTSENT_LOWAT 25
#define TCP_SAVE_SYN 27
#define TCP_SAVED_SYN 28
#define TCP_CA_NAME_MAX 16
#define TCP_NAGLE_OFF 1
......
......@@ -52,7 +52,6 @@ static const struct sockopt_test sol_socket_tests[] = {
static const struct sockopt_test sol_tcp_tests[] = {
{ .opt = TCP_NODELAY, .flip = 1, },
{ .opt = TCP_MAXSEG, .new = 1314, .expected = 1314, },
{ .opt = TCP_KEEPIDLE, .new = 123, .expected = 123, .restore = 321, },
{ .opt = TCP_KEEPINTVL, .new = 123, .expected = 123, .restore = 321, },
{ .opt = TCP_KEEPCNT, .new = 123, .expected = 123, .restore = 124, },
......@@ -62,7 +61,6 @@ static const struct sockopt_test sol_tcp_tests[] = {
{ .opt = TCP_THIN_LINEAR_TIMEOUTS, .flip = 1, },
{ .opt = TCP_USER_TIMEOUT, .new = 123400, .expected = 123400, },
{ .opt = TCP_NOTSENT_LOWAT, .new = 1314, .expected = 1314, },
{ .opt = TCP_SAVE_SYN, .new = 1, .expected = 1, },
{ .opt = 0, },
};
......@@ -82,102 +80,6 @@ struct loop_ctx {
struct sock *sk;
};
static int __bpf_getsockopt(void *ctx, struct sock *sk,
int level, int opt, int *optval,
int optlen)
{
if (level == SOL_SOCKET) {
switch (opt) {
case SO_REUSEADDR:
*optval = !!BPF_CORE_READ_BITFIELD(sk, sk_reuse);
break;
case SO_KEEPALIVE:
*optval = !!(sk->sk_flags & (1UL << 3));
break;
case SO_RCVLOWAT:
*optval = sk->sk_rcvlowat;
break;
case SO_MAX_PACING_RATE:
*optval = sk->sk_max_pacing_rate;
break;
default:
return bpf_getsockopt(ctx, level, opt, optval, optlen);
}
return 0;
}
if (level == IPPROTO_TCP) {
struct tcp_sock *tp = bpf_skc_to_tcp_sock(sk);
if (!tp)
return -1;
switch (opt) {
case TCP_NODELAY:
*optval = !!(BPF_CORE_READ_BITFIELD(tp, nonagle) & TCP_NAGLE_OFF);
break;
case TCP_MAXSEG:
*optval = tp->rx_opt.user_mss;
break;
case TCP_KEEPIDLE:
*optval = tp->keepalive_time / CONFIG_HZ;
break;
case TCP_SYNCNT:
*optval = tp->inet_conn.icsk_syn_retries;
break;
case TCP_KEEPINTVL:
*optval = tp->keepalive_intvl / CONFIG_HZ;
break;
case TCP_KEEPCNT:
*optval = tp->keepalive_probes;
break;
case TCP_WINDOW_CLAMP:
*optval = tp->window_clamp;
break;
case TCP_THIN_LINEAR_TIMEOUTS:
*optval = !!BPF_CORE_READ_BITFIELD(tp, thin_lto);
break;
case TCP_USER_TIMEOUT:
*optval = tp->inet_conn.icsk_user_timeout;
break;
case TCP_NOTSENT_LOWAT:
*optval = tp->notsent_lowat;
break;
case TCP_SAVE_SYN:
*optval = BPF_CORE_READ_BITFIELD(tp, save_syn);
break;
default:
return bpf_getsockopt(ctx, level, opt, optval, optlen);
}
return 0;
}
if (level == IPPROTO_IPV6) {
switch (opt) {
case IPV6_AUTOFLOWLABEL: {
__u16 proto = sk->sk_protocol;
struct inet_sock *inet_sk;
if (proto == IPPROTO_TCP)
inet_sk = (struct inet_sock *)bpf_skc_to_tcp_sock(sk);
else
inet_sk = (struct inet_sock *)bpf_skc_to_udp6_sock(sk);
if (!inet_sk)
return -1;
*optval = !!inet_sk->pinet6->autoflowlabel;
break;
}
default:
return bpf_getsockopt(ctx, level, opt, optval, optlen);
}
return 0;
}
return bpf_getsockopt(ctx, level, opt, optval, optlen);
}
static int bpf_test_sockopt_flip(void *ctx, struct sock *sk,
const struct sockopt_test *t,
int level)
......@@ -186,7 +88,7 @@ static int bpf_test_sockopt_flip(void *ctx, struct sock *sk,
opt = t->opt;
if (__bpf_getsockopt(ctx, sk, level, opt, &old, sizeof(old)))
if (bpf_getsockopt(ctx, level, opt, &old, sizeof(old)))
return 1;
/* kernel initialized txrehash to 255 */
if (level == SOL_SOCKET && opt == SO_TXREHASH && old != 0 && old != 1)
......@@ -195,7 +97,7 @@ static int bpf_test_sockopt_flip(void *ctx, struct sock *sk,
new = !old;
if (bpf_setsockopt(ctx, level, opt, &new, sizeof(new)))
return 1;
if (__bpf_getsockopt(ctx, sk, level, opt, &tmp, sizeof(tmp)) ||
if (bpf_getsockopt(ctx, level, opt, &tmp, sizeof(tmp)) ||
tmp != new)
return 1;
......@@ -218,13 +120,13 @@ static int bpf_test_sockopt_int(void *ctx, struct sock *sk,
else
expected = t->expected;
if (__bpf_getsockopt(ctx, sk, level, opt, &old, sizeof(old)) ||
if (bpf_getsockopt(ctx, level, opt, &old, sizeof(old)) ||
old == new)
return 1;
if (bpf_setsockopt(ctx, level, opt, &new, sizeof(new)))
return 1;
if (__bpf_getsockopt(ctx, sk, level, opt, &tmp, sizeof(tmp)) ||
if (bpf_getsockopt(ctx, level, opt, &tmp, sizeof(tmp)) ||
tmp != expected)
return 1;
......@@ -410,6 +312,34 @@ static int binddev_test(void *ctx)
return 0;
}
static int test_tcp_maxseg(void *ctx, struct sock *sk)
{
int val = 1314, tmp;
if (sk->sk_state != TCP_ESTABLISHED)
return bpf_setsockopt(ctx, IPPROTO_TCP, TCP_MAXSEG,
&val, sizeof(val));
if (bpf_getsockopt(ctx, IPPROTO_TCP, TCP_MAXSEG, &tmp, sizeof(tmp)) ||
tmp > val)
return -1;
return 0;
}
static int test_tcp_saved_syn(void *ctx, struct sock *sk)
{
__u8 saved_syn[20];
int one = 1;
if (sk->sk_state == TCP_LISTEN)
return bpf_setsockopt(ctx, IPPROTO_TCP, TCP_SAVE_SYN,
&one, sizeof(one));
return bpf_getsockopt(ctx, IPPROTO_TCP, TCP_SAVED_SYN,
saved_syn, sizeof(saved_syn));
}
SEC("lsm_cgroup/socket_post_create")
int BPF_PROG(socket_post_create, struct socket *sock, int family,
int type, int protocol, int kern)
......@@ -440,16 +370,22 @@ int skops_sockopt(struct bpf_sock_ops *skops)
switch (skops->op) {
case BPF_SOCK_OPS_TCP_LISTEN_CB:
nr_listen += !bpf_test_sockopt(skops, sk);
nr_listen += !(bpf_test_sockopt(skops, sk) ||
test_tcp_maxseg(skops, sk) ||
test_tcp_saved_syn(skops, sk));
break;
case BPF_SOCK_OPS_TCP_CONNECT_CB:
nr_connect += !bpf_test_sockopt(skops, sk);
nr_connect += !(bpf_test_sockopt(skops, sk) ||
test_tcp_maxseg(skops, sk));
break;
case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
nr_active += !bpf_test_sockopt(skops, sk);
nr_active += !(bpf_test_sockopt(skops, sk) ||
test_tcp_maxseg(skops, sk));
break;
case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
nr_passive += !bpf_test_sockopt(skops, sk);
nr_passive += !(bpf_test_sockopt(skops, sk) ||
test_tcp_maxseg(skops, sk) ||
test_tcp_saved_syn(skops, sk));
break;
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment