Commit 65ddc82d authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov

bpf: Change bpf_getsockopt(SOL_SOCKET) to reuse sk_getsockopt()

This patch changes bpf_getsockopt(SOL_SOCKET) to reuse
sk_getsockopt().  It removes all duplicated code from
bpf_getsockopt(SOL_SOCKET).

Before this patch, there were some optnames available to
bpf_setsockopt(SOL_SOCKET) but missing in bpf_getsockopt(SOL_SOCKET).
It surprises users from time to time.  For example, SO_REUSEADDR,
SO_KEEPALIVE, SO_RCVLOWAT, and SO_MAX_PACING_RATE.  This patch
automatically closes this gap without duplicating more code.
The only exception is SO_BINDTODEVICE because it needs to acquire a
blocking lock.  Thus, SO_BINDTODEVICE is not supported.
Signed-off-by: default avatarMartin KaFai Lau <martin.lau@kernel.org>
Link: https://lore.kernel.org/r/20220902002912.2894040-1-kafai@fb.comSigned-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent c2b063ca
...@@ -1833,6 +1833,8 @@ int sk_setsockopt(struct sock *sk, int level, int optname, ...@@ -1833,6 +1833,8 @@ int sk_setsockopt(struct sock *sk, int level, int optname,
int sock_setsockopt(struct socket *sock, int level, int op, int sock_setsockopt(struct socket *sock, int level, int op,
sockptr_t optval, unsigned int optlen); sockptr_t optval, unsigned int optlen);
int sk_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen);
int sock_getsockopt(struct socket *sock, int level, int op, int sock_getsockopt(struct socket *sock, int level, int op,
char __user *optval, int __user *optlen); char __user *optval, int __user *optlen);
int sock_gettstamp(struct socket *sock, void __user *userstamp, int sock_gettstamp(struct socket *sock, void __user *userstamp,
......
...@@ -5017,8 +5017,9 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = { ...@@ -5017,8 +5017,9 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = {
.arg1_type = ARG_PTR_TO_CTX, .arg1_type = ARG_PTR_TO_CTX,
}; };
static int sol_socket_setsockopt(struct sock *sk, int optname, static int sol_socket_sockopt(struct sock *sk, int optname,
char *optval, int optlen) char *optval, int *optlen,
bool getopt)
{ {
switch (optname) { switch (optname) {
case SO_REUSEADDR: case SO_REUSEADDR:
...@@ -5032,7 +5033,7 @@ static int sol_socket_setsockopt(struct sock *sk, int optname, ...@@ -5032,7 +5033,7 @@ static int sol_socket_setsockopt(struct sock *sk, int optname,
case SO_MAX_PACING_RATE: case SO_MAX_PACING_RATE:
case SO_BINDTOIFINDEX: case SO_BINDTOIFINDEX:
case SO_TXREHASH: case SO_TXREHASH:
if (optlen != sizeof(int)) if (*optlen != sizeof(int))
return -EINVAL; return -EINVAL;
break; break;
case SO_BINDTODEVICE: case SO_BINDTODEVICE:
...@@ -5041,8 +5042,16 @@ static int sol_socket_setsockopt(struct sock *sk, int optname, ...@@ -5041,8 +5042,16 @@ static int sol_socket_setsockopt(struct sock *sk, int optname,
return -EINVAL; return -EINVAL;
} }
if (getopt) {
if (optname == SO_BINDTODEVICE)
return -EINVAL;
return sk_getsockopt(sk, SOL_SOCKET, optname,
KERNEL_SOCKPTR(optval),
KERNEL_SOCKPTR(optlen));
}
return sk_setsockopt(sk, SOL_SOCKET, optname, return sk_setsockopt(sk, SOL_SOCKET, optname,
KERNEL_SOCKPTR(optval), optlen); KERNEL_SOCKPTR(optval), *optlen);
} }
static int bpf_sol_tcp_setsockopt(struct sock *sk, int optname, static int bpf_sol_tcp_setsockopt(struct sock *sk, int optname,
...@@ -5168,7 +5177,7 @@ static int __bpf_setsockopt(struct sock *sk, int level, int optname, ...@@ -5168,7 +5177,7 @@ static int __bpf_setsockopt(struct sock *sk, int level, int optname,
return -EINVAL; return -EINVAL;
if (level == SOL_SOCKET) if (level == SOL_SOCKET)
return sol_socket_setsockopt(sk, optname, optval, optlen); return sol_socket_sockopt(sk, optname, optval, &optlen, false);
else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP) else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP)
return sol_ip_setsockopt(sk, optname, optval, optlen); return sol_ip_setsockopt(sk, optname, optval, optlen);
else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6) else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6)
...@@ -5190,38 +5199,13 @@ static int _bpf_setsockopt(struct sock *sk, int level, int optname, ...@@ -5190,38 +5199,13 @@ static int _bpf_setsockopt(struct sock *sk, int level, int optname,
static int __bpf_getsockopt(struct sock *sk, int level, int optname, static int __bpf_getsockopt(struct sock *sk, int level, int optname,
char *optval, int optlen) char *optval, int optlen)
{ {
int err = 0, saved_optlen = optlen;
if (!sk_fullsock(sk)) if (!sk_fullsock(sk))
goto err_clear; goto err_clear;
if (level == SOL_SOCKET) { if (level == SOL_SOCKET) {
if (optlen != sizeof(int)) err = sol_socket_sockopt(sk, optname, optval, &optlen, true);
goto err_clear;
switch (optname) {
case SO_RCVBUF:
*((int *)optval) = sk->sk_rcvbuf;
break;
case SO_SNDBUF:
*((int *)optval) = sk->sk_sndbuf;
break;
case SO_MARK:
*((int *)optval) = sk->sk_mark;
break;
case SO_PRIORITY:
*((int *)optval) = sk->sk_priority;
break;
case SO_BINDTOIFINDEX:
*((int *)optval) = sk->sk_bound_dev_if;
break;
case SO_REUSEPORT:
*((int *)optval) = sk->sk_reuseport;
break;
case SO_TXREHASH:
*((int *)optval) = sk->sk_txrehash;
break;
default:
goto err_clear;
}
} else if (IS_ENABLED(CONFIG_INET) && } else if (IS_ENABLED(CONFIG_INET) &&
level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) { level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) {
struct inet_connection_sock *icsk; struct inet_connection_sock *icsk;
...@@ -5278,7 +5262,12 @@ static int __bpf_getsockopt(struct sock *sk, int level, int optname, ...@@ -5278,7 +5262,12 @@ static int __bpf_getsockopt(struct sock *sk, int level, int optname,
} else { } else {
goto err_clear; goto err_clear;
} }
return 0;
if (err)
optlen = 0;
if (optlen < saved_optlen)
memset(optval + optlen, 0, saved_optlen - optlen);
return err;
err_clear: err_clear:
memset(optval, 0, optlen); memset(optval, 0, optlen);
return -EINVAL; return -EINVAL;
......
...@@ -1583,8 +1583,8 @@ static int groups_to_user(sockptr_t dst, const struct group_info *src) ...@@ -1583,8 +1583,8 @@ static int groups_to_user(sockptr_t dst, const struct group_info *src)
return 0; return 0;
} }
static int sk_getsockopt(struct sock *sk, int level, int optname, int sk_getsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, sockptr_t optlen) sockptr_t optval, sockptr_t optlen)
{ {
struct socket *sock = sk->sk_socket; struct socket *sock = sk->sk_socket;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment