Commit 6dadbe4b authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov

bpf: net: Change do_ipv6_getsockopt() to take the sockptr_t argument

Similar to the earlier patch that changes sk_getsockopt() to
take the sockptr_t argument .  This patch also changes
do_ipv6_getsockopt() to take the sockptr_t argument such that
a latter patch can make bpf_getsockopt(SOL_IPV6) to reuse
do_ipv6_getsockopt().

Note on the change in ip6_mc_msfget().  This function is to
return an array of sockaddr_storage in optval.  This function
is shared between ipv6_get_msfilter() and compat_ipv6_get_msfilter().
However, the sockaddr_storage is stored at different offset of the
optval because of the difference between group_filter and
compat_group_filter.  Thus, a new 'ss_offset' argument is
added to ip6_mc_msfget().
Signed-off-by: default avatarMartin KaFai Lau <martin.lau@kernel.org>
Link: https://lore.kernel.org/r/20220902002853.2892532-1-kafai@fb.comSigned-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent 9c3f9707
...@@ -27,7 +27,7 @@ struct sock; ...@@ -27,7 +27,7 @@ struct sock;
#ifdef CONFIG_IPV6_MROUTE #ifdef CONFIG_IPV6_MROUTE
extern int ip6_mroute_setsockopt(struct sock *, int, sockptr_t, unsigned int); extern int ip6_mroute_setsockopt(struct sock *, int, sockptr_t, unsigned int);
extern int ip6_mroute_getsockopt(struct sock *, int, char __user *, int __user *); extern int ip6_mroute_getsockopt(struct sock *, int, sockptr_t, sockptr_t);
extern int ip6_mr_input(struct sk_buff *skb); extern int ip6_mr_input(struct sk_buff *skb);
extern int ip6mr_ioctl(struct sock *sk, int cmd, void __user *arg); extern int ip6mr_ioctl(struct sock *sk, int cmd, void __user *arg);
extern int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg); extern int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg);
...@@ -42,7 +42,7 @@ static inline int ip6_mroute_setsockopt(struct sock *sock, int optname, ...@@ -42,7 +42,7 @@ static inline int ip6_mroute_setsockopt(struct sock *sock, int optname,
static inline static inline
int ip6_mroute_getsockopt(struct sock *sock, int ip6_mroute_getsockopt(struct sock *sock,
int optname, char __user *optval, int __user *optlen) int optname, sockptr_t optval, sockptr_t optlen)
{ {
return -ENOPROTOOPT; return -ENOPROTOOPT;
} }
......
...@@ -1209,7 +1209,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk, ...@@ -1209,7 +1209,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf, int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage *list); struct sockaddr_storage *list);
int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf, int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage __user *p); sockptr_t optval, size_t ss_offset);
#ifdef CONFIG_PROC_FS #ifdef CONFIG_PROC_FS
int ac6_proc_init(struct net *net); int ac6_proc_init(struct net *net);
......
...@@ -1827,8 +1827,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval, ...@@ -1827,8 +1827,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval,
* Getsock opt support for the multicast routing system. * Getsock opt support for the multicast routing system.
*/ */
int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, int ip6_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval,
int __user *optlen) sockptr_t optlen)
{ {
int olr; int olr;
int val; int val;
...@@ -1859,16 +1859,16 @@ int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, ...@@ -1859,16 +1859,16 @@ int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval,
return -ENOPROTOOPT; return -ENOPROTOOPT;
} }
if (get_user(olr, optlen)) if (copy_from_sockptr(&olr, optlen, sizeof(int)))
return -EFAULT; return -EFAULT;
olr = min_t(int, olr, sizeof(int)); olr = min_t(int, olr, sizeof(int));
if (olr < 0) if (olr < 0)
return -EINVAL; return -EINVAL;
if (put_user(olr, optlen)) if (copy_to_sockptr(optlen, &olr, sizeof(int)))
return -EFAULT; return -EFAULT;
if (copy_to_user(optval, &val, olr)) if (copy_to_sockptr(optval, &val, olr))
return -EFAULT; return -EFAULT;
return 0; return 0;
} }
......
...@@ -1030,7 +1030,7 @@ int ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, ...@@ -1030,7 +1030,7 @@ int ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
EXPORT_SYMBOL(ipv6_setsockopt); EXPORT_SYMBOL(ipv6_setsockopt);
static int ipv6_getsockopt_sticky(struct sock *sk, struct ipv6_txoptions *opt, static int ipv6_getsockopt_sticky(struct sock *sk, struct ipv6_txoptions *opt,
int optname, char __user *optval, int len) int optname, sockptr_t optval, int len)
{ {
struct ipv6_opt_hdr *hdr; struct ipv6_opt_hdr *hdr;
...@@ -1058,45 +1058,44 @@ static int ipv6_getsockopt_sticky(struct sock *sk, struct ipv6_txoptions *opt, ...@@ -1058,45 +1058,44 @@ static int ipv6_getsockopt_sticky(struct sock *sk, struct ipv6_txoptions *opt,
return 0; return 0;
len = min_t(unsigned int, len, ipv6_optlen(hdr)); len = min_t(unsigned int, len, ipv6_optlen(hdr));
if (copy_to_user(optval, hdr, len)) if (copy_to_sockptr(optval, hdr, len))
return -EFAULT; return -EFAULT;
return len; return len;
} }
static int ipv6_get_msfilter(struct sock *sk, void __user *optval, static int ipv6_get_msfilter(struct sock *sk, sockptr_t optval,
int __user *optlen, int len) sockptr_t optlen, int len)
{ {
const int size0 = offsetof(struct group_filter, gf_slist_flex); const int size0 = offsetof(struct group_filter, gf_slist_flex);
struct group_filter __user *p = optval;
struct group_filter gsf; struct group_filter gsf;
int num; int num;
int err; int err;
if (len < size0) if (len < size0)
return -EINVAL; return -EINVAL;
if (copy_from_user(&gsf, p, size0)) if (copy_from_sockptr(&gsf, optval, size0))
return -EFAULT; return -EFAULT;
if (gsf.gf_group.ss_family != AF_INET6) if (gsf.gf_group.ss_family != AF_INET6)
return -EADDRNOTAVAIL; return -EADDRNOTAVAIL;
num = gsf.gf_numsrc; num = gsf.gf_numsrc;
lock_sock(sk); lock_sock(sk);
err = ip6_mc_msfget(sk, &gsf, p->gf_slist_flex); err = ip6_mc_msfget(sk, &gsf, optval, size0);
if (!err) { if (!err) {
if (num > gsf.gf_numsrc) if (num > gsf.gf_numsrc)
num = gsf.gf_numsrc; num = gsf.gf_numsrc;
if (put_user(GROUP_FILTER_SIZE(num), optlen) || len = GROUP_FILTER_SIZE(num);
copy_to_user(p, &gsf, size0)) if (copy_to_sockptr(optlen, &len, sizeof(int)) ||
copy_to_sockptr(optval, &gsf, size0))
err = -EFAULT; err = -EFAULT;
} }
release_sock(sk); release_sock(sk);
return err; return err;
} }
static int compat_ipv6_get_msfilter(struct sock *sk, void __user *optval, static int compat_ipv6_get_msfilter(struct sock *sk, sockptr_t optval,
int __user *optlen, int len) sockptr_t optlen, int len)
{ {
const int size0 = offsetof(struct compat_group_filter, gf_slist_flex); const int size0 = offsetof(struct compat_group_filter, gf_slist_flex);
struct compat_group_filter __user *p = optval;
struct compat_group_filter gf32; struct compat_group_filter gf32;
struct group_filter gf; struct group_filter gf;
int err; int err;
...@@ -1105,7 +1104,7 @@ static int compat_ipv6_get_msfilter(struct sock *sk, void __user *optval, ...@@ -1105,7 +1104,7 @@ static int compat_ipv6_get_msfilter(struct sock *sk, void __user *optval,
if (len < size0) if (len < size0)
return -EINVAL; return -EINVAL;
if (copy_from_user(&gf32, p, size0)) if (copy_from_sockptr(&gf32, optval, size0))
return -EFAULT; return -EFAULT;
gf.gf_interface = gf32.gf_interface; gf.gf_interface = gf32.gf_interface;
gf.gf_fmode = gf32.gf_fmode; gf.gf_fmode = gf32.gf_fmode;
...@@ -1116,22 +1115,24 @@ static int compat_ipv6_get_msfilter(struct sock *sk, void __user *optval, ...@@ -1116,22 +1115,24 @@ static int compat_ipv6_get_msfilter(struct sock *sk, void __user *optval,
return -EADDRNOTAVAIL; return -EADDRNOTAVAIL;
lock_sock(sk); lock_sock(sk);
err = ip6_mc_msfget(sk, &gf, p->gf_slist_flex); err = ip6_mc_msfget(sk, &gf, optval, size0);
release_sock(sk); release_sock(sk);
if (err) if (err)
return err; return err;
if (num > gf.gf_numsrc) if (num > gf.gf_numsrc)
num = gf.gf_numsrc; num = gf.gf_numsrc;
len = GROUP_FILTER_SIZE(num) - (sizeof(gf)-sizeof(gf32)); len = GROUP_FILTER_SIZE(num) - (sizeof(gf)-sizeof(gf32));
if (put_user(len, optlen) || if (copy_to_sockptr(optlen, &len, sizeof(int)) ||
put_user(gf.gf_fmode, &p->gf_fmode) || copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_fmode),
put_user(gf.gf_numsrc, &p->gf_numsrc)) &gf.gf_fmode, sizeof(gf32.gf_fmode)) ||
copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_numsrc),
&gf.gf_numsrc, sizeof(gf32.gf_numsrc)))
return -EFAULT; return -EFAULT;
return 0; return 0;
} }
static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen) sockptr_t optval, sockptr_t optlen)
{ {
struct ipv6_pinfo *np = inet6_sk(sk); struct ipv6_pinfo *np = inet6_sk(sk);
int len; int len;
...@@ -1140,7 +1141,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1140,7 +1141,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (ip6_mroute_opt(optname)) if (ip6_mroute_opt(optname))
return ip6_mroute_getsockopt(sk, optname, optval, optlen); return ip6_mroute_getsockopt(sk, optname, optval, optlen);
if (get_user(len, optlen)) if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT; return -EFAULT;
switch (optname) { switch (optname) {
case IPV6_ADDRFORM: case IPV6_ADDRFORM:
...@@ -1164,10 +1165,15 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1164,10 +1165,15 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (sk->sk_type != SOCK_STREAM) if (sk->sk_type != SOCK_STREAM)
return -ENOPROTOOPT; return -ENOPROTOOPT;
msg.msg_control_user = optval; if (optval.is_kernel) {
msg.msg_control_is_user = false;
msg.msg_control = optval.kernel;
} else {
msg.msg_control_is_user = true;
msg.msg_control_user = optval.user;
}
msg.msg_controllen = len; msg.msg_controllen = len;
msg.msg_flags = 0; msg.msg_flags = 0;
msg.msg_control_is_user = true;
lock_sock(sk); lock_sock(sk);
skb = np->pktoptions; skb = np->pktoptions;
...@@ -1210,7 +1216,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1210,7 +1216,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
} }
} }
len -= msg.msg_controllen; len -= msg.msg_controllen;
return put_user(len, optlen); return copy_to_sockptr(optlen, &len, sizeof(int));
} }
case IPV6_MTU: case IPV6_MTU:
{ {
...@@ -1270,7 +1276,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1270,7 +1276,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
/* check if ipv6_getsockopt_sticky() returns err code */ /* check if ipv6_getsockopt_sticky() returns err code */
if (len < 0) if (len < 0)
return len; return len;
return put_user(len, optlen); return copy_to_sockptr(optlen, &len, sizeof(int));
} }
case IPV6_RECVHOPOPTS: case IPV6_RECVHOPOPTS:
...@@ -1324,9 +1330,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1324,9 +1330,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (!mtuinfo.ip6m_mtu) if (!mtuinfo.ip6m_mtu)
return -ENOTCONN; return -ENOTCONN;
if (put_user(len, optlen)) if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT; return -EFAULT;
if (copy_to_user(optval, &mtuinfo, len)) if (copy_to_sockptr(optval, &mtuinfo, len))
return -EFAULT; return -EFAULT;
return 0; return 0;
...@@ -1403,7 +1409,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1403,7 +1409,7 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (len < sizeof(freq)) if (len < sizeof(freq))
return -EINVAL; return -EINVAL;
if (copy_from_user(&freq, optval, sizeof(freq))) if (copy_from_sockptr(&freq, optval, sizeof(freq)))
return -EFAULT; return -EFAULT;
if (freq.flr_action != IPV6_FL_A_GET) if (freq.flr_action != IPV6_FL_A_GET)
...@@ -1418,9 +1424,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1418,9 +1424,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
if (val < 0) if (val < 0)
return val; return val;
if (put_user(len, optlen)) if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT; return -EFAULT;
if (copy_to_user(optval, &freq, len)) if (copy_to_sockptr(optval, &freq, len))
return -EFAULT; return -EFAULT;
return 0; return 0;
...@@ -1472,9 +1478,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1472,9 +1478,9 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
return -ENOPROTOOPT; return -ENOPROTOOPT;
} }
len = min_t(unsigned int, sizeof(int), len); len = min_t(unsigned int, sizeof(int), len);
if (put_user(len, optlen)) if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT; return -EFAULT;
if (copy_to_user(optval, &val, len)) if (copy_to_sockptr(optval, &val, len))
return -EFAULT; return -EFAULT;
return 0; return 0;
} }
...@@ -1490,7 +1496,8 @@ int ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1490,7 +1496,8 @@ int ipv6_getsockopt(struct sock *sk, int level, int optname,
if (level != SOL_IPV6) if (level != SOL_IPV6)
return -ENOPROTOOPT; return -ENOPROTOOPT;
err = do_ipv6_getsockopt(sk, level, optname, optval, optlen); err = do_ipv6_getsockopt(sk, level, optname,
USER_SOCKPTR(optval), USER_SOCKPTR(optlen));
#ifdef CONFIG_NETFILTER #ifdef CONFIG_NETFILTER
/* we need to exclude all possible ENOPROTOOPTs except default case */ /* we need to exclude all possible ENOPROTOOPTs except default case */
if (err == -ENOPROTOOPT && optname != IPV6_2292PKTOPTIONS) { if (err == -ENOPROTOOPT && optname != IPV6_2292PKTOPTIONS) {
......
...@@ -580,7 +580,7 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf, ...@@ -580,7 +580,7 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
} }
int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf, int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage __user *p) sockptr_t optval, size_t ss_offset)
{ {
struct ipv6_pinfo *inet6 = inet6_sk(sk); struct ipv6_pinfo *inet6 = inet6_sk(sk);
const struct in6_addr *group; const struct in6_addr *group;
...@@ -612,8 +612,7 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf, ...@@ -612,8 +612,7 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc; copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc;
gsf->gf_numsrc = count; gsf->gf_numsrc = count;
for (i = 0; i < copycount; i++) {
for (i = 0; i < copycount; i++, p++) {
struct sockaddr_in6 *psin6; struct sockaddr_in6 *psin6;
struct sockaddr_storage ss; struct sockaddr_storage ss;
...@@ -621,8 +620,9 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf, ...@@ -621,8 +620,9 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
memset(&ss, 0, sizeof(ss)); memset(&ss, 0, sizeof(ss));
psin6->sin6_family = AF_INET6; psin6->sin6_family = AF_INET6;
psin6->sin6_addr = psl->sl_addr[i]; psin6->sin6_addr = psl->sl_addr[i];
if (copy_to_user(p, &ss, sizeof(ss))) if (copy_to_sockptr_offset(optval, ss_offset, &ss, sizeof(ss)))
return -EFAULT; return -EFAULT;
ss_offset += sizeof(ss);
} }
return 0; return 0;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment