Commit 1e1e49df authored by Andrii Nakryiko's avatar Andrii Nakryiko

Merge branch 'sockmap: add sockmap support for unix stream socket'

Jiang Wang says:

====================

This patch series add support for unix stream type
for sockmap. Sockmap already supports TCP, UDP,
unix dgram types. The unix stream support is similar
to unix dgram.

Also add selftests for unix stream type in sockmap tests.
====================
Signed-off-by: default avatarAndrii Nakryiko <andrii@kernel.org>
parents edce1a24 31c50aee
......@@ -87,6 +87,8 @@ long unix_outq_len(struct sock *sk);
int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
int flags);
int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
int flags);
#ifdef CONFIG_SYSCTL
int unix_sysctl_register(struct net *net);
void unix_sysctl_unregister(struct net *net);
......@@ -96,9 +98,11 @@ static inline void unix_sysctl_unregister(struct net *net) {}
#endif
#ifdef CONFIG_BPF_SYSCALL
extern struct proto unix_proto;
extern struct proto unix_dgram_proto;
extern struct proto unix_stream_proto;
int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
void __init unix_bpf_build_proto(void);
#else
static inline void __init unix_bpf_build_proto(void)
......
......@@ -1494,6 +1494,7 @@ void sock_map_unhash(struct sock *sk)
rcu_read_unlock();
saved_unhash(sk);
}
EXPORT_SYMBOL_GPL(sock_map_unhash);
void sock_map_close(struct sock *sk, long timeout)
{
......
......@@ -679,6 +679,8 @@ static int unix_dgram_sendmsg(struct socket *, struct msghdr *, size_t);
static int unix_dgram_recvmsg(struct socket *, struct msghdr *, size_t, int);
static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor);
static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor);
static int unix_dgram_connect(struct socket *, struct sockaddr *,
int, int);
static int unix_seqpacket_sendmsg(struct socket *, struct msghdr *, size_t);
......@@ -732,6 +734,7 @@ static const struct proto_ops unix_stream_ops = {
.shutdown = unix_shutdown,
.sendmsg = unix_stream_sendmsg,
.recvmsg = unix_stream_recvmsg,
.read_sock = unix_stream_read_sock,
.mmap = sock_no_mmap,
.sendpage = unix_stream_sendpage,
.splice_read = unix_stream_splice_read,
......@@ -795,17 +798,35 @@ static void unix_close(struct sock *sk, long timeout)
*/
}
struct proto unix_proto = {
.name = "UNIX",
static void unix_unhash(struct sock *sk)
{
/* Nothing to do here, unix socket does not need a ->unhash().
* This is merely for sockmap.
*/
}
struct proto unix_dgram_proto = {
.name = "UNIX-DGRAM",
.owner = THIS_MODULE,
.obj_size = sizeof(struct unix_sock),
.close = unix_close,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = unix_bpf_update_proto,
.psock_update_sk_prot = unix_dgram_bpf_update_proto,
#endif
};
static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
struct proto unix_stream_proto = {
.name = "UNIX-STREAM",
.owner = THIS_MODULE,
.obj_size = sizeof(struct unix_sock),
.close = unix_close,
.unhash = unix_unhash,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = unix_stream_bpf_update_proto,
#endif
};
static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, int type)
{
struct sock *sk = NULL;
struct unix_sock *u;
......@@ -814,7 +835,11 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
if (atomic_long_read(&unix_nr_socks) > 2 * get_max_files())
goto out;
sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_proto, kern);
if (type == SOCK_STREAM)
sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_stream_proto, kern);
else /*dgram and seqpacket */
sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_dgram_proto, kern);
if (!sk)
goto out;
......@@ -876,7 +901,7 @@ static int unix_create(struct net *net, struct socket *sock, int protocol,
return -ESOCKTNOSUPPORT;
}
return unix_create1(net, sock, kern) ? 0 : -ENOMEM;
return unix_create1(net, sock, kern, sock->type) ? 0 : -ENOMEM;
}
static int unix_release(struct socket *sock)
......@@ -1290,7 +1315,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
err = -ENOMEM;
/* create new sock for complete connection */
newsk = unix_create1(sock_net(sk), NULL, 0);
newsk = unix_create1(sock_net(sk), NULL, 0, sock->type);
if (newsk == NULL)
goto out;
......@@ -2320,8 +2345,10 @@ static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t si
struct sock *sk = sock->sk;
#ifdef CONFIG_BPF_SYSCALL
if (sk->sk_prot != &unix_proto)
return sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
const struct proto *prot = READ_ONCE(sk->sk_prot);
if (prot != &unix_dgram_proto)
return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
flags & ~MSG_DONTWAIT, NULL);
#endif
return __unix_dgram_recvmsg(sk, msg, size, flags);
......@@ -2491,6 +2518,15 @@ static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk,
}
#endif
static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor)
{
if (unlikely(sk->sk_state != TCP_ESTABLISHED))
return -ENOTCONN;
return unix_read_sock(sk, desc, recv_actor);
}
static int unix_stream_read_generic(struct unix_stream_read_state *state,
bool freezable)
{
......@@ -2716,6 +2752,20 @@ static int unix_stream_read_actor(struct sk_buff *skb,
return ret ?: chunk;
}
int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg,
size_t size, int flags)
{
struct unix_stream_read_state state = {
.recv_actor = unix_stream_read_actor,
.socket = sk->sk_socket,
.msg = msg,
.size = size,
.flags = flags
};
return unix_stream_read_generic(&state, true);
}
static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
size_t size, int flags)
{
......@@ -2727,6 +2777,14 @@ static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
.flags = flags
};
#ifdef CONFIG_BPF_SYSCALL
struct sock *sk = sock->sk;
const struct proto *prot = READ_ONCE(sk->sk_prot);
if (prot != &unix_stream_proto)
return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
flags & ~MSG_DONTWAIT, NULL);
#endif
return unix_stream_read_generic(&state, true);
}
......@@ -2787,7 +2845,9 @@ static int unix_shutdown(struct socket *sock, int mode)
(sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)) {
int peer_mode = 0;
const struct proto *prot = READ_ONCE(other->sk_prot);
prot->unhash(other);
if (mode&RCV_SHUTDOWN)
peer_mode |= SEND_SHUTDOWN;
if (mode&SEND_SHUTDOWN)
......@@ -2796,10 +2856,12 @@ static int unix_shutdown(struct socket *sock, int mode)
other->sk_shutdown |= peer_mode;
unix_state_unlock(other);
other->sk_state_change(other);
if (peer_mode == SHUTDOWN_MASK)
if (peer_mode == SHUTDOWN_MASK) {
sk_wake_async(other, SOCK_WAKE_WAITD, POLL_HUP);
else if (peer_mode & RCV_SHUTDOWN)
other->sk_state = TCP_CLOSE;
} else if (peer_mode & RCV_SHUTDOWN) {
sk_wake_async(other, SOCK_WAKE_WAITD, POLL_IN);
}
}
if (other)
sock_put(other);
......@@ -3277,7 +3339,13 @@ static int __init af_unix_init(void)
BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb));
rc = proto_register(&unix_proto, 1);
rc = proto_register(&unix_dgram_proto, 1);
if (rc != 0) {
pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
goto out;
}
rc = proto_register(&unix_stream_proto, 1);
if (rc != 0) {
pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
goto out;
......@@ -3298,7 +3366,8 @@ static int __init af_unix_init(void)
static void __exit af_unix_exit(void)
{
sock_unregister(PF_UNIX);
proto_unregister(&unix_proto);
proto_unregister(&unix_dgram_proto);
proto_unregister(&unix_stream_proto);
unregister_pernet_subsys(&unix_net_ops);
}
......
......@@ -38,9 +38,18 @@ static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock,
return ret;
}
static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
size_t len, int nonblock, int flags,
int *addr_len)
static int __unix_recvmsg(struct sock *sk, struct msghdr *msg,
size_t len, int flags)
{
if (sk->sk_type == SOCK_DGRAM)
return __unix_dgram_recvmsg(sk, msg, len, flags);
else
return __unix_stream_recvmsg(sk, msg, len, flags);
}
static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
size_t len, int nonblock, int flags,
int *addr_len)
{
struct unix_sock *u = unix_sk(sk);
struct sk_psock *psock;
......@@ -48,14 +57,14 @@ static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
psock = sk_psock_get(sk);
if (unlikely(!psock))
return __unix_dgram_recvmsg(sk, msg, len, flags);
return __unix_recvmsg(sk, msg, len, flags);
mutex_lock(&u->iolock);
if (!skb_queue_empty(&sk->sk_receive_queue) &&
sk_psock_queue_empty(psock)) {
mutex_unlock(&u->iolock);
sk_psock_put(sk, psock);
return __unix_dgram_recvmsg(sk, msg, len, flags);
return __unix_recvmsg(sk, msg, len, flags);
}
msg_bytes_ready:
......@@ -71,7 +80,7 @@ static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
goto msg_bytes_ready;
mutex_unlock(&u->iolock);
sk_psock_put(sk, psock);
return __unix_dgram_recvmsg(sk, msg, len, flags);
return __unix_recvmsg(sk, msg, len, flags);
}
copied = -EAGAIN;
}
......@@ -80,30 +89,55 @@ static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
return copied;
}
static struct proto *unix_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_prot_lock);
static struct proto unix_bpf_prot;
static struct proto *unix_dgram_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_dgram_prot_lock);
static struct proto unix_dgram_bpf_prot;
static struct proto *unix_stream_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_stream_prot_lock);
static struct proto unix_stream_bpf_prot;
static void unix_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
*prot = *base;
prot->close = sock_map_close;
prot->recvmsg = unix_dgram_bpf_recvmsg;
prot->recvmsg = unix_bpf_recvmsg;
}
static void unix_stream_bpf_rebuild_protos(struct proto *prot,
const struct proto *base)
{
*prot = *base;
prot->close = sock_map_close;
prot->recvmsg = unix_bpf_recvmsg;
prot->unhash = sock_map_unhash;
}
static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops)
{
if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) {
spin_lock_bh(&unix_dgram_prot_lock);
if (likely(ops != unix_dgram_prot_saved)) {
unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops);
smp_store_release(&unix_dgram_prot_saved, ops);
}
spin_unlock_bh(&unix_dgram_prot_lock);
}
}
static void unix_bpf_check_needs_rebuild(struct proto *ops)
static void unix_stream_bpf_check_needs_rebuild(struct proto *ops)
{
if (unlikely(ops != smp_load_acquire(&unix_prot_saved))) {
spin_lock_bh(&unix_prot_lock);
if (likely(ops != unix_prot_saved)) {
unix_bpf_rebuild_protos(&unix_bpf_prot, ops);
smp_store_release(&unix_prot_saved, ops);
if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) {
spin_lock_bh(&unix_stream_prot_lock);
if (likely(ops != unix_stream_prot_saved)) {
unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops);
smp_store_release(&unix_stream_prot_saved, ops);
}
spin_unlock_bh(&unix_prot_lock);
spin_unlock_bh(&unix_stream_prot_lock);
}
}
int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
if (sk->sk_type != SOCK_DGRAM)
return -EOPNOTSUPP;
......@@ -114,12 +148,27 @@ int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
return 0;
}
unix_bpf_check_needs_rebuild(psock->sk_proto);
WRITE_ONCE(sk->sk_prot, &unix_bpf_prot);
unix_dgram_bpf_check_needs_rebuild(psock->sk_proto);
WRITE_ONCE(sk->sk_prot, &unix_dgram_bpf_prot);
return 0;
}
int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
if (restore) {
sk->sk_write_space = psock->saved_write_space;
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
return 0;
}
unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
WRITE_ONCE(sk->sk_prot, &unix_stream_bpf_prot);
return 0;
}
void __init unix_bpf_build_proto(void)
{
unix_bpf_rebuild_protos(&unix_bpf_prot, &unix_proto);
unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto);
unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto);
}
......@@ -1692,14 +1692,14 @@ static void test_reuseport(struct test_sockmap_listen *skel,
}
}
static int udp_socketpair(int family, int *s, int *c)
static int inet_socketpair(int family, int type, int *s, int *c)
{
struct sockaddr_storage addr;
socklen_t len;
int p0, c0;
int err;
p0 = socket_loopback(family, SOCK_DGRAM | SOCK_NONBLOCK);
p0 = socket_loopback(family, type | SOCK_NONBLOCK);
if (p0 < 0)
return p0;
......@@ -1708,7 +1708,7 @@ static int udp_socketpair(int family, int *s, int *c)
if (err)
goto close_peer0;
c0 = xsocket(family, SOCK_DGRAM | SOCK_NONBLOCK, 0);
c0 = xsocket(family, type | SOCK_NONBLOCK, 0);
if (c0 < 0) {
err = c0;
goto close_peer0;
......@@ -1747,10 +1747,10 @@ static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd,
zero_verdict_count(verd_mapfd);
err = udp_socketpair(family, &p0, &c0);
err = inet_socketpair(family, SOCK_DGRAM, &p0, &c0);
if (err)
return;
err = udp_socketpair(family, &p1, &c1);
err = inet_socketpair(family, SOCK_DGRAM, &p1, &c1);
if (err)
goto close_cli0;
......@@ -1825,7 +1825,7 @@ static void test_udp_redir(struct test_sockmap_listen *skel, struct bpf_map *map
udp_skb_redir_to_connected(skel, map, family);
}
static void udp_unix_redir_to_connected(int family, int sock_mapfd,
static void inet_unix_redir_to_connected(int family, int type, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
{
const char *log_prefix = redir_mode_str(mode);
......@@ -1843,7 +1843,7 @@ static void udp_unix_redir_to_connected(int family, int sock_mapfd,
return;
c0 = sfd[0], p0 = sfd[1];
err = udp_socketpair(family, &p1, &c1);
err = inet_socketpair(family, SOCK_DGRAM, &p1, &c1);
if (err)
goto close;
......@@ -1884,7 +1884,7 @@ static void udp_unix_redir_to_connected(int family, int sock_mapfd,
xclose(p0);
}
static void udp_unix_skb_redir_to_connected(struct test_sockmap_listen *skel,
static void inet_unix_skb_redir_to_connected(struct test_sockmap_listen *skel,
struct bpf_map *inner_map, int family)
{
int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
......@@ -1897,14 +1897,20 @@ static void udp_unix_skb_redir_to_connected(struct test_sockmap_listen *skel,
return;
skel->bss->test_ingress = false;
udp_unix_redir_to_connected(family, sock_map, verdict_map, REDIR_EGRESS);
inet_unix_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_EGRESS);
inet_unix_redir_to_connected(family, SOCK_STREAM, sock_map, verdict_map,
REDIR_EGRESS);
skel->bss->test_ingress = true;
udp_unix_redir_to_connected(family, sock_map, verdict_map, REDIR_INGRESS);
inet_unix_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_INGRESS);
inet_unix_redir_to_connected(family, SOCK_STREAM, sock_map, verdict_map,
REDIR_INGRESS);
xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_VERDICT);
}
static void unix_udp_redir_to_connected(int family, int sock_mapfd,
static void unix_inet_redir_to_connected(int family, int type, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
{
const char *log_prefix = redir_mode_str(mode);
......@@ -1917,7 +1923,7 @@ static void unix_udp_redir_to_connected(int family, int sock_mapfd,
zero_verdict_count(verd_mapfd);
err = udp_socketpair(family, &p0, &c0);
err = inet_socketpair(family, SOCK_DGRAM, &p0, &c0);
if (err)
return;
......@@ -1959,7 +1965,7 @@ static void unix_udp_redir_to_connected(int family, int sock_mapfd,
}
static void unix_udp_skb_redir_to_connected(struct test_sockmap_listen *skel,
static void unix_inet_skb_redir_to_connected(struct test_sockmap_listen *skel,
struct bpf_map *inner_map, int family)
{
int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
......@@ -1972,9 +1978,15 @@ static void unix_udp_skb_redir_to_connected(struct test_sockmap_listen *skel,
return;
skel->bss->test_ingress = false;
unix_udp_redir_to_connected(family, sock_map, verdict_map, REDIR_EGRESS);
unix_inet_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_EGRESS);
unix_inet_redir_to_connected(family, SOCK_STREAM, sock_map, verdict_map,
REDIR_EGRESS);
skel->bss->test_ingress = true;
unix_udp_redir_to_connected(family, sock_map, verdict_map, REDIR_INGRESS);
unix_inet_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_INGRESS);
unix_inet_redir_to_connected(family, SOCK_STREAM, sock_map, verdict_map,
REDIR_INGRESS);
xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_VERDICT);
}
......@@ -1990,8 +2002,8 @@ static void test_udp_unix_redir(struct test_sockmap_listen *skel, struct bpf_map
snprintf(s, sizeof(s), "%s %s %s", map_name, family_name, __func__);
if (!test__start_subtest(s))
return;
udp_unix_skb_redir_to_connected(skel, map, family);
unix_udp_skb_redir_to_connected(skel, map, family);
inet_unix_skb_redir_to_connected(skel, map, family);
unix_inet_skb_redir_to_connected(skel, map, family);
}
static void run_tests(struct test_sockmap_listen *skel, struct bpf_map *map,
......@@ -2020,11 +2032,13 @@ void test_sockmap_listen(void)
run_tests(skel, skel->maps.sock_map, AF_INET);
run_tests(skel, skel->maps.sock_map, AF_INET6);
test_unix_redir(skel, skel->maps.sock_map, SOCK_DGRAM);
test_unix_redir(skel, skel->maps.sock_map, SOCK_STREAM);
skel->bss->test_sockmap = false;
run_tests(skel, skel->maps.sock_hash, AF_INET);
run_tests(skel, skel->maps.sock_hash, AF_INET6);
test_unix_redir(skel, skel->maps.sock_hash, SOCK_DGRAM);
test_unix_redir(skel, skel->maps.sock_hash, SOCK_STREAM);
test_sockmap_listen__destroy(skel);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment