Commit 29003875 authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov

bpf: Change bpf_setsockopt(SOL_SOCKET) to reuse sk_setsockopt()

After the prep work in the previous patches,
this patch removes most of the dup code from bpf_setsockopt(SOL_SOCKET)
and reuses them from sk_setsockopt().

The sock ptr test is added to the SO_RCVLOWAT because
the sk->sk_socket could be NULL in some of the bpf hooks.

The existing optname white-list is refactored into a new
function sol_socket_setsockopt().
Reviewed-by: default avatarStanislav Fomichev <sdf@google.com>
Signed-off-by: default avatarMartin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/r/20220817061804.4178920-1-kafai@fb.comSigned-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent ebf9e8e6
...@@ -1828,6 +1828,8 @@ void sock_pfree(struct sk_buff *skb); ...@@ -1828,6 +1828,8 @@ void sock_pfree(struct sk_buff *skb);
#define sock_edemux sock_efree #define sock_edemux sock_efree
#endif #endif
int sk_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen);
int sock_setsockopt(struct socket *sock, int level, int op, int sock_setsockopt(struct socket *sock, int level, int op,
sockptr_t optval, unsigned int optlen); sockptr_t optval, unsigned int optlen);
......
...@@ -5013,109 +5013,43 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = { ...@@ -5013,109 +5013,43 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = {
.arg1_type = ARG_PTR_TO_CTX, .arg1_type = ARG_PTR_TO_CTX,
}; };
static int __bpf_setsockopt(struct sock *sk, int level, int optname, static int sol_socket_setsockopt(struct sock *sk, int optname,
char *optval, int optlen) char *optval, int optlen)
{ {
char devname[IFNAMSIZ];
int val, valbool;
struct net *net;
int ifindex;
int ret = 0;
if (!sk_fullsock(sk))
return -EINVAL;
if (level == SOL_SOCKET) {
if (optlen != sizeof(int) && optname != SO_BINDTODEVICE)
return -EINVAL;
val = *((int *)optval);
valbool = val ? 1 : 0;
/* Only some socketops are supported */
switch (optname) { switch (optname) {
case SO_RCVBUF:
val = min_t(u32, val, sysctl_rmem_max);
val = min_t(int, val, INT_MAX / 2);
sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
WRITE_ONCE(sk->sk_rcvbuf,
max_t(int, val * 2, SOCK_MIN_RCVBUF));
break;
case SO_SNDBUF: case SO_SNDBUF:
val = min_t(u32, val, sysctl_wmem_max); case SO_RCVBUF:
val = min_t(int, val, INT_MAX / 2); case SO_KEEPALIVE:
sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
WRITE_ONCE(sk->sk_sndbuf,
max_t(int, val * 2, SOCK_MIN_SNDBUF));
break;
case SO_MAX_PACING_RATE: /* 32bit version */
if (val != ~0U)
cmpxchg(&sk->sk_pacing_status,
SK_PACING_NONE,
SK_PACING_NEEDED);
sk->sk_max_pacing_rate = (val == ~0U) ?
~0UL : (unsigned int)val;
sk->sk_pacing_rate = min(sk->sk_pacing_rate,
sk->sk_max_pacing_rate);
break;
case SO_PRIORITY: case SO_PRIORITY:
sk->sk_priority = val; case SO_REUSEPORT:
break;
case SO_RCVLOWAT: case SO_RCVLOWAT:
if (val < 0)
val = INT_MAX;
if (sk->sk_socket && sk->sk_socket->ops->set_rcvlowat)
ret = sk->sk_socket->ops->set_rcvlowat(sk, val);
else
WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
break;
case SO_MARK: case SO_MARK:
if (sk->sk_mark != val) { case SO_MAX_PACING_RATE:
sk->sk_mark = val;
sk_dst_reset(sk);
}
break;
case SO_BINDTODEVICE:
optlen = min_t(long, optlen, IFNAMSIZ - 1);
strncpy(devname, optval, optlen);
devname[optlen] = 0;
ifindex = 0;
if (devname[0] != '\0') {
struct net_device *dev;
ret = -ENODEV;
net = sock_net(sk);
dev = dev_get_by_name(net, devname);
if (!dev)
break;
ifindex = dev->ifindex;
dev_put(dev);
}
fallthrough;
case SO_BINDTOIFINDEX: case SO_BINDTOIFINDEX:
if (optname == SO_BINDTOIFINDEX)
ifindex = val;
ret = sock_bindtoindex(sk, ifindex, false);
break;
case SO_KEEPALIVE:
if (sk->sk_prot->keepalive)
sk->sk_prot->keepalive(sk, valbool);
sock_valbool_flag(sk, SOCK_KEEPOPEN, valbool);
break;
case SO_REUSEPORT:
sk->sk_reuseport = valbool;
break;
case SO_TXREHASH: case SO_TXREHASH:
if (val < -1 || val > 1) { if (optlen != sizeof(int))
ret = -EINVAL; return -EINVAL;
break; break;
} case SO_BINDTODEVICE:
sk->sk_txrehash = (u8)val;
break; break;
default: default:
ret = -EINVAL; return -EINVAL;
} }
return sk_setsockopt(sk, SOL_SOCKET, optname,
KERNEL_SOCKPTR(optval), optlen);
}
static int __bpf_setsockopt(struct sock *sk, int level, int optname,
char *optval, int optlen)
{
int val, ret = 0;
if (!sk_fullsock(sk))
return -EINVAL;
if (level == SOL_SOCKET) {
return sol_socket_setsockopt(sk, optname, optval, optlen);
} else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP) { } else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP) {
if (optlen != sizeof(int) || sk->sk_family != AF_INET) if (optlen != sizeof(int) || sk->sk_family != AF_INET)
return -EINVAL; return -EINVAL;
......
...@@ -1077,7 +1077,7 @@ EXPORT_SYMBOL(sockopt_capable); ...@@ -1077,7 +1077,7 @@ EXPORT_SYMBOL(sockopt_capable);
* at the socket level. Everything here is generic. * at the socket level. Everything here is generic.
*/ */
static int sk_setsockopt(struct sock *sk, int level, int optname, int sk_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen) sockptr_t optval, unsigned int optlen)
{ {
struct so_timestamping timestamping; struct so_timestamping timestamping;
...@@ -1264,7 +1264,7 @@ static int sk_setsockopt(struct sock *sk, int level, int optname, ...@@ -1264,7 +1264,7 @@ static int sk_setsockopt(struct sock *sk, int level, int optname,
case SO_RCVLOWAT: case SO_RCVLOWAT:
if (val < 0) if (val < 0)
val = INT_MAX; val = INT_MAX;
if (sock->ops->set_rcvlowat) if (sock && sock->ops->set_rcvlowat)
ret = sock->ops->set_rcvlowat(sk, val); ret = sock->ops->set_rcvlowat(sk, val);
else else
WRITE_ONCE(sk->sk_rcvlowat, val ? : 1); WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment