Commit 89d69c5d authored by Alexei Starovoitov's avatar Alexei Starovoitov

Merge branch 'sockmap: introduce BPF_SK_SKB_VERDICT and support UDP'

Cong Wang says:

====================

From: Cong Wang <cong.wang@bytedance.com>

We have thousands of services connected to a daemon on every host
via AF_UNIX dgram sockets, after they are moved into VM, we have to
add a proxy to forward these communications from VM to host, because
rewriting thousands of them is not practical. This proxy uses an
AF_UNIX socket connected to services and a UDP socket to connect to
the host. It is inefficient because data is copied between kernel
space and user space twice, and we can not use splice() which only
supports TCP. Therefore, we want to use sockmap to do the splicing
without going to user-space at all (after the initial setup).

Currently sockmap only fully supports TCP, UDP is partially supported
as it is only allowed to add into sockmap. This patchset, as the second
part of the original large patchset, extends sockmap with:
1) cross-protocol support with BPF_SK_SKB_VERDICT; 2) full UDP support.

On the high level, ->read_sock() is required for each protocol to support
sockmap redirection, and in order to do sock proto update, a new ops
->psock_update_sk_prot() is introduced, which is also required. And the
BPF ->recvmsg() is also needed to replace the original ->recvmsg() to
retrieve skmsg. To make life easier, we have to get rid of lock_sock()
in sk_psock_handle_skb(), otherwise we would have to implement
->sendmsg_locked() on top of ->sendmsg(), which is ugly.

Please see each patch for more details.

To see the big picture, the original patchset is available here:
https://github.com/congwang/linux/tree/sockmap
this patchset is also available:
https://github.com/congwang/linux/tree/sockmap2
---
v8: get rid of 'offset' in udp_read_sock()
    add checks for skb_verdict/stream_verdict conflict
    add two cleanup patches for sock_map_link()
    add a new test case

v7: use work_mutex to protect psock->work
    return err in udp_read_sock()
    add patch 6/13
    clean up test case

v6: get rid of sk_psock_zap_ingress()
    add rcu work patch

v5: use INDIRECT_CALL_2() for function pointers
    use ingress_lock to fix a race condition found by Jacub
    rename two helper functions

v4: get rid of lock_sock() in sk_psock_handle_skb()
    get rid of udp_sendmsg_locked()
    remove an empty line
    update cover letter

v3: export tcp/udp_update_proto()
    rename sk->sk_prot->psock_update_sk_prot()
    improve changelogs

v2: separate from the original large patchset
    rebase to the latest bpf-next
    split UDP test case
    move inet_csk_has_ulp() check to tcp_bpf.c
    clean up udp_read_sock()
====================
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents e27bfefb 8d7cb74f
......@@ -3626,6 +3626,7 @@ int skb_splice_bits(struct sk_buff *skb, struct sock *sk, unsigned int offset,
unsigned int flags);
int skb_send_sock_locked(struct sock *sk, struct sk_buff *skb, int offset,
int len);
int skb_send_sock(struct sock *sk, struct sk_buff *skb, int offset, int len);
void skb_copy_and_csum_dev(const struct sk_buff *skb, u8 *to);
unsigned int skb_zerocopy_headlen(const struct sk_buff *from);
int skb_zerocopy(struct sk_buff *to, struct sk_buff *from,
......
......@@ -58,6 +58,7 @@ struct sk_psock_progs {
struct bpf_prog *msg_parser;
struct bpf_prog *stream_parser;
struct bpf_prog *stream_verdict;
struct bpf_prog *skb_verdict;
};
enum sk_psock_state_bits {
......@@ -89,6 +90,7 @@ struct sk_psock {
#endif
struct sk_buff_head ingress_skb;
struct list_head ingress_msg;
spinlock_t ingress_lock;
unsigned long state;
struct list_head link;
spinlock_t link_lock;
......@@ -97,13 +99,12 @@ struct sk_psock {
void (*saved_close)(struct sock *sk, long timeout);
void (*saved_write_space)(struct sock *sk);
void (*saved_data_ready)(struct sock *sk);
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
struct proto *sk_proto;
struct mutex work_mutex;
struct sk_psock_work_state work_state;
struct work_struct work;
union {
struct rcu_head rcu;
struct work_struct gc;
};
struct rcu_work rwork;
};
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
......@@ -124,6 +125,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes);
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags);
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
......@@ -284,7 +289,45 @@ static inline struct sk_psock *sk_psock(const struct sock *sk)
static inline void sk_psock_queue_msg(struct sk_psock *psock,
struct sk_msg *msg)
{
spin_lock_bh(&psock->ingress_lock);
list_add_tail(&msg->list, &psock->ingress_msg);
spin_unlock_bh(&psock->ingress_lock);
}
static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
{
struct sk_msg *msg;
spin_lock_bh(&psock->ingress_lock);
msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
if (msg)
list_del(&msg->list);
spin_unlock_bh(&psock->ingress_lock);
return msg;
}
static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
{
struct sk_msg *msg;
spin_lock_bh(&psock->ingress_lock);
msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
spin_unlock_bh(&psock->ingress_lock);
return msg;
}
static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
struct sk_msg *msg)
{
struct sk_msg *ret;
spin_lock_bh(&psock->ingress_lock);
if (list_is_last(&msg->list, &psock->ingress_msg))
ret = NULL;
else
ret = list_next_entry(msg, list);
spin_unlock_bh(&psock->ingress_lock);
return ret;
}
static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
......@@ -292,6 +335,13 @@ static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
return psock ? list_empty(&psock->ingress_msg) : true;
}
static inline void kfree_sk_msg(struct sk_msg *msg)
{
if (msg->skb)
consume_skb(msg->skb);
kfree(msg);
}
static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
struct sock *sk = psock->sk;
......@@ -301,6 +351,7 @@ static inline void sk_psock_report_error(struct sk_psock *psock, int err)
}
struct sk_psock *sk_psock_init(struct sock *sk, int node);
void sk_psock_stop(struct sk_psock *psock, bool wait);
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
......@@ -349,25 +400,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
}
}
static inline void sk_psock_update_proto(struct sock *sk,
struct sk_psock *psock,
struct proto *ops)
{
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, ops);
}
static inline void sk_psock_restore_proto(struct sock *sk,
struct sk_psock *psock)
{
sk->sk_prot->unhash = psock->saved_unhash;
if (inet_csk_has_ulp(sk)) {
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
if (psock->psock_update_sk_prot)
psock->psock_update_sk_prot(sk, true);
}
static inline void sk_psock_set_state(struct sk_psock *psock,
......@@ -442,6 +480,7 @@ static inline void psock_progs_drop(struct sk_psock_progs *progs)
psock_set_prog(&progs->msg_parser, NULL);
psock_set_prog(&progs->stream_parser, NULL);
psock_set_prog(&progs->stream_verdict, NULL);
psock_set_prog(&progs->skb_verdict, NULL);
}
int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
......
......@@ -1184,6 +1184,9 @@ struct proto {
void (*unhash)(struct sock *sk);
void (*rehash)(struct sock *sk);
int (*get_port)(struct sock *sk, unsigned short snum);
#ifdef CONFIG_BPF_SYSCALL
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
#endif
/* Keeping track of sockets in use */
#ifdef CONFIG_PROC_FS
......
......@@ -2203,13 +2203,12 @@ struct sk_psock;
#ifdef CONFIG_BPF_SYSCALL
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int tcp_bpf_update_proto(struct sock *sk, bool restore);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#endif /* CONFIG_BPF_SYSCALL */
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
int flags);
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len, int flags);
#endif /* CONFIG_NET_SOCK_MSG */
#if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG)
......
......@@ -329,6 +329,8 @@ struct sock *__udp6_lib_lookup(struct net *net,
struct sk_buff *skb);
struct sock *udp6_lib_lookup_skb(const struct sk_buff *skb,
__be16 sport, __be16 dport);
int udp_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor);
/* UDP uses skb->dev_scratch to cache as much information as possible and avoid
* possibly multiple cache miss on dequeue()
......@@ -518,6 +520,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
#ifdef CONFIG_BPF_SYSCALL
struct sk_psock;
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
int udp_bpf_update_proto(struct sock *sk, bool restore);
#endif
#endif /* _UDP_H */
......@@ -957,6 +957,7 @@ enum bpf_attach_type {
BPF_XDP_CPUMAP,
BPF_SK_LOOKUP,
BPF_XDP,
BPF_SK_SKB_VERDICT,
__MAX_BPF_ATTACH_TYPE
};
......
......@@ -2948,6 +2948,7 @@ attach_type_to_prog_type(enum bpf_attach_type attach_type)
return BPF_PROG_TYPE_SK_MSG;
case BPF_SK_SKB_STREAM_PARSER:
case BPF_SK_SKB_STREAM_VERDICT:
case BPF_SK_SKB_VERDICT:
return BPF_PROG_TYPE_SK_SKB;
case BPF_LIRC_MODE2:
return BPF_PROG_TYPE_LIRC_MODE2;
......
......@@ -2500,9 +2500,32 @@ int skb_splice_bits(struct sk_buff *skb, struct sock *sk, unsigned int offset,
}
EXPORT_SYMBOL_GPL(skb_splice_bits);
/* Send skb data on a socket. Socket must be locked. */
int skb_send_sock_locked(struct sock *sk, struct sk_buff *skb, int offset,
int len)
static int sendmsg_unlocked(struct sock *sk, struct msghdr *msg,
struct kvec *vec, size_t num, size_t size)
{
struct socket *sock = sk->sk_socket;
if (!sock)
return -EINVAL;
return kernel_sendmsg(sock, msg, vec, num, size);
}
static int sendpage_unlocked(struct sock *sk, struct page *page, int offset,
size_t size, int flags)
{
struct socket *sock = sk->sk_socket;
if (!sock)
return -EINVAL;
return kernel_sendpage(sock, page, offset, size, flags);
}
typedef int (*sendmsg_func)(struct sock *sk, struct msghdr *msg,
struct kvec *vec, size_t num, size_t size);
typedef int (*sendpage_func)(struct sock *sk, struct page *page, int offset,
size_t size, int flags);
static int __skb_send_sock(struct sock *sk, struct sk_buff *skb, int offset,
int len, sendmsg_func sendmsg, sendpage_func sendpage)
{
unsigned int orig_len = len;
struct sk_buff *head = skb;
......@@ -2522,7 +2545,8 @@ int skb_send_sock_locked(struct sock *sk, struct sk_buff *skb, int offset,
memset(&msg, 0, sizeof(msg));
msg.msg_flags = MSG_DONTWAIT;
ret = kernel_sendmsg_locked(sk, &msg, &kv, 1, slen);
ret = INDIRECT_CALL_2(sendmsg, kernel_sendmsg_locked,
sendmsg_unlocked, sk, &msg, &kv, 1, slen);
if (ret <= 0)
goto error;
......@@ -2553,7 +2577,9 @@ int skb_send_sock_locked(struct sock *sk, struct sk_buff *skb, int offset,
slen = min_t(size_t, len, skb_frag_size(frag) - offset);
while (slen) {
ret = kernel_sendpage_locked(sk, skb_frag_page(frag),
ret = INDIRECT_CALL_2(sendpage, kernel_sendpage_locked,
sendpage_unlocked, sk,
skb_frag_page(frag),
skb_frag_off(frag) + offset,
slen, MSG_DONTWAIT);
if (ret <= 0)
......@@ -2587,8 +2613,23 @@ int skb_send_sock_locked(struct sock *sk, struct sk_buff *skb, int offset,
error:
return orig_len == len ? ret : orig_len - len;
}
/* Send skb data on a socket. Socket must be locked. */
int skb_send_sock_locked(struct sock *sk, struct sk_buff *skb, int offset,
int len)
{
return __skb_send_sock(sk, skb, offset, len, kernel_sendmsg_locked,
kernel_sendpage_locked);
}
EXPORT_SYMBOL_GPL(skb_send_sock_locked);
/* Send skb data on a socket. Socket must be unlocked. */
int skb_send_sock(struct sock *sk, struct sk_buff *skb, int offset, int len)
{
return __skb_send_sock(sk, skb, offset, len, sendmsg_unlocked,
sendpage_unlocked);
}
/**
* skb_store_bits - store bits from kernel buffer to skb
* @skb: destination buffer
......
......@@ -399,6 +399,104 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
ret = sk_wait_event(sk, &timeo,
!list_empty(&psock->ingress_msg) ||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_wait_data);
/* Receive sk_msg from psock->ingress_msg to @msg. */
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags)
{
struct iov_iter *iter = &msg->msg_iter;
int peek = flags & MSG_PEEK;
struct sk_msg *msg_rx;
int i, copied = 0;
msg_rx = sk_psock_peek_msg(psock);
while (copied != len) {
struct scatterlist *sge;
if (unlikely(!msg_rx))
break;
i = msg_rx->sg.start;
do {
struct page *page;
int copy;
sge = sk_msg_elem(msg_rx, i);
copy = sge->length;
page = sg_page(sge);
if (copied + copy > len)
copy = len - copied;
copy = copy_page_to_iter(page, sge->offset, copy, iter);
if (!copy)
return copied ? copied : -EFAULT;
copied += copy;
if (likely(!peek)) {
sge->offset += copy;
sge->length -= copy;
if (!msg_rx->skb)
sk_mem_uncharge(sk, copy);
msg_rx->sg.size -= copy;
if (!sge->length) {
sk_msg_iter_var_next(i);
if (!msg_rx->skb)
put_page(page);
}
} else {
/* Lets not optimize peek case if copy_page_to_iter
* didn't copy the entire length lets just break.
*/
if (copy != sge->length)
return copied;
sk_msg_iter_var_next(i);
}
if (copied == len)
break;
} while (i != msg_rx->sg.end);
if (unlikely(peek)) {
msg_rx = sk_psock_next_msg(psock, msg_rx);
if (!msg_rx)
break;
continue;
}
msg_rx->sg.start = i;
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
msg_rx = sk_psock_dequeue_msg(psock);
kfree_sk_msg(msg_rx);
}
msg_rx = sk_psock_peek_msg(psock);
}
return copied;
}
EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
struct sk_buff *skb)
{
......@@ -410,7 +508,7 @@ static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
if (!sk_rmem_schedule(sk, skb, skb->truesize))
return NULL;
msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_KERNEL);
if (unlikely(!msg))
return NULL;
......@@ -497,7 +595,7 @@ static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
if (!ingress) {
if (!sock_writeable(psock->sk))
return -EAGAIN;
return skb_send_sock_locked(psock->sk, skb, off, len);
return skb_send_sock(psock->sk, skb, off, len);
}
return sk_psock_skb_ingress(psock, skb);
}
......@@ -511,8 +609,7 @@ static void sk_psock_backlog(struct work_struct *work)
u32 len, off;
int ret;
/* Lock sock to avoid losing sk_socket during loop. */
lock_sock(psock->sk);
mutex_lock(&psock->work_mutex);
if (state->skb) {
skb = state->skb;
len = state->len;
......@@ -529,7 +626,7 @@ static void sk_psock_backlog(struct work_struct *work)
skb_bpf_redirect_clear(skb);
do {
ret = -EIO;
if (likely(psock->sk->sk_socket))
if (!sock_flag(psock->sk, SOCK_DEAD))
ret = sk_psock_handle_skb(psock, skb, off,
len, ingress);
if (ret <= 0) {
......@@ -553,7 +650,7 @@ static void sk_psock_backlog(struct work_struct *work)
kfree_skb(skb);
}
end:
release_sock(psock->sk);
mutex_unlock(&psock->work_mutex);
}
struct sk_psock *sk_psock_init(struct sock *sk, int node)
......@@ -563,11 +660,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
write_lock_bh(&sk->sk_callback_lock);
if (inet_csk_has_ulp(sk)) {
psock = ERR_PTR(-EINVAL);
goto out;
}
if (sk->sk_user_data) {
psock = ERR_PTR(-EBUSY);
goto out;
......@@ -591,7 +683,9 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
spin_lock_init(&psock->link_lock);
INIT_WORK(&psock->work, sk_psock_backlog);
mutex_init(&psock->work_mutex);
INIT_LIST_HEAD(&psock->ingress_msg);
spin_lock_init(&psock->ingress_lock);
skb_queue_head_init(&psock->ingress_skb);
sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
......@@ -630,11 +724,11 @@ static void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
}
}
static void sk_psock_zap_ingress(struct sk_psock *psock)
static void __sk_psock_zap_ingress(struct sk_psock *psock)
{
struct sk_buff *skb;
while ((skb = __skb_dequeue(&psock->ingress_skb)) != NULL) {
while ((skb = skb_dequeue(&psock->ingress_skb)) != NULL) {
skb_bpf_redirect_clear(skb);
kfree_skb(skb);
}
......@@ -651,23 +745,35 @@ static void sk_psock_link_destroy(struct sk_psock *psock)
}
}
void sk_psock_stop(struct sk_psock *psock, bool wait)
{
spin_lock_bh(&psock->ingress_lock);
sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
sk_psock_cork_free(psock);
__sk_psock_zap_ingress(psock);
spin_unlock_bh(&psock->ingress_lock);
if (wait)
cancel_work_sync(&psock->work);
}
static void sk_psock_done_strp(struct sk_psock *psock);
static void sk_psock_destroy_deferred(struct work_struct *gc)
static void sk_psock_destroy(struct work_struct *work)
{
struct sk_psock *psock = container_of(gc, struct sk_psock, gc);
struct sk_psock *psock = container_of(to_rcu_work(work),
struct sk_psock, rwork);
/* No sk_callback_lock since already detached. */
sk_psock_done_strp(psock);
cancel_work_sync(&psock->work);
mutex_destroy(&psock->work_mutex);
psock_progs_drop(&psock->progs);
sk_psock_link_destroy(psock);
sk_psock_cork_free(psock);
sk_psock_zap_ingress(psock);
if (psock->sk_redir)
sock_put(psock->sk_redir);
......@@ -675,30 +781,21 @@ static void sk_psock_destroy_deferred(struct work_struct *gc)
kfree(psock);
}
static void sk_psock_destroy(struct rcu_head *rcu)
{
struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);
INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
schedule_work(&psock->gc);
}
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
sk_psock_cork_free(psock);
sk_psock_zap_ingress(psock);
sk_psock_stop(psock, false);
write_lock_bh(&sk->sk_callback_lock);
sk_psock_restore_proto(sk, psock);
rcu_assign_sk_user_data(sk, NULL);
if (psock->progs.stream_parser)
sk_psock_stop_strp(sk, psock);
else if (psock->progs.stream_verdict)
else if (psock->progs.stream_verdict || psock->progs.skb_verdict)
sk_psock_stop_verdict(sk, psock);
write_unlock_bh(&sk->sk_callback_lock);
sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
call_rcu(&psock->rcu, sk_psock_destroy);
INIT_RCU_WORK(&psock->rwork, sk_psock_destroy);
queue_rcu_work(system_wq, &psock->rwork);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);
......@@ -767,14 +864,20 @@ static void sk_psock_skb_redirect(struct sk_buff *skb)
* error that caused the pipe to break. We can't send a packet on
* a socket that is in this state so we drop the skb.
*/
if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
!sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
if (!psock_other || sock_flag(sk_other, SOCK_DEAD)) {
kfree_skb(skb);
return;
}
spin_lock_bh(&psock_other->ingress_lock);
if (!sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
spin_unlock_bh(&psock_other->ingress_lock);
kfree_skb(skb);
return;
}
skb_queue_tail(&psock_other->ingress_skb, skb);
schedule_work(&psock_other->work);
spin_unlock_bh(&psock_other->ingress_lock);
}
static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict)
......@@ -842,9 +945,13 @@ static void sk_psock_verdict_apply(struct sk_psock *psock,
err = sk_psock_skb_ingress_self(psock, skb);
}
if (err < 0) {
spin_lock_bh(&psock->ingress_lock);
if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
skb_queue_tail(&psock->ingress_skb, skb);
schedule_work(&psock->work);
}
spin_unlock_bh(&psock->ingress_lock);
}
break;
case __SK_REDIRECT:
sk_psock_skb_redirect(skb);
......@@ -1010,6 +1117,8 @@ static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb,
}
skb_set_owner_r(skb, sk);
prog = READ_ONCE(psock->progs.stream_verdict);
if (!prog)
prog = READ_ONCE(psock->progs.skb_verdict);
if (likely(prog)) {
skb_dst_drop(skb);
skb_bpf_redirect_clear(skb);
......
......@@ -26,6 +26,7 @@ struct bpf_stab {
static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
struct bpf_prog *old, u32 which);
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
......@@ -155,6 +156,8 @@ static void sock_map_del_link(struct sock *sk,
strp_stop = true;
if (psock->saved_data_ready && stab->progs.stream_verdict)
verdict_stop = true;
if (psock->saved_data_ready && stab->progs.skb_verdict)
verdict_stop = true;
list_del(&link->list);
sk_psock_free_link(link);
}
......@@ -182,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
struct proto *prot;
switch (sk->sk_type) {
case SOCK_STREAM:
prot = tcp_bpf_get_proto(sk, psock);
break;
case SOCK_DGRAM:
prot = udp_bpf_get_proto(sk, psock);
break;
default:
if (!sk->sk_prot->psock_update_sk_prot)
return -EINVAL;
}
if (IS_ERR(prot))
return PTR_ERR(prot);
sk_psock_update_proto(sk, psock, prot);
return 0;
psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
return sk->sk_prot->psock_update_sk_prot(sk, false);
}
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
......@@ -224,13 +211,25 @@ static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
return psock;
}
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
struct sock *sk)
static bool sock_map_redirect_allowed(const struct sock *sk);
static int sock_map_link(struct bpf_map *map, struct sock *sk)
{
struct bpf_prog *msg_parser, *stream_parser, *stream_verdict;
struct sk_psock_progs *progs = sock_map_progs(map);
struct bpf_prog *stream_verdict = NULL;
struct bpf_prog *stream_parser = NULL;
struct bpf_prog *skb_verdict = NULL;
struct bpf_prog *msg_parser = NULL;
struct sk_psock *psock;
int ret;
/* Only sockets we can redirect into/from in BPF need to hold
* refs to parser/verdict progs and have their sk_data_ready
* and sk_write_space callbacks overridden.
*/
if (!sock_map_redirect_allowed(sk))
goto no_progs;
stream_verdict = READ_ONCE(progs->stream_verdict);
if (stream_verdict) {
stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
......@@ -256,6 +255,16 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
}
}
skb_verdict = READ_ONCE(progs->skb_verdict);
if (skb_verdict) {
skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
if (IS_ERR(skb_verdict)) {
ret = PTR_ERR(skb_verdict);
goto out_put_msg_parser;
}
}
no_progs:
psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock)) {
ret = PTR_ERR(psock);
......@@ -265,6 +274,9 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
if (psock) {
if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
(stream_parser && READ_ONCE(psock->progs.stream_parser)) ||
(skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
(skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
(stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
(stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
sk_psock_put(sk, psock);
ret = -EBUSY;
......@@ -296,6 +308,9 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
} else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
sk_psock_start_verdict(sk,psock);
} else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
sk_psock_start_verdict(sk, psock);
}
write_unlock_bh(&sk->sk_callback_lock);
return 0;
......@@ -304,6 +319,9 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
out_drop:
sk_psock_put(sk, psock);
out_progs:
if (skb_verdict)
bpf_prog_put(skb_verdict);
out_put_msg_parser:
if (msg_parser)
bpf_prog_put(msg_parser);
out_put_stream_parser:
......@@ -315,27 +333,6 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
return ret;
}
static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
struct sk_psock *psock;
int ret;
psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock))
return PTR_ERR(psock);
if (!psock) {
psock = sk_psock_init(sk, map->numa_node);
if (IS_ERR(psock))
return PTR_ERR(psock);
}
ret = sock_map_init_proto(sk, psock);
if (ret < 0)
sk_psock_put(sk, psock);
return ret;
}
static void sock_map_free(struct bpf_map *map)
{
struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
......@@ -466,8 +463,6 @@ static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
return 0;
}
static bool sock_map_redirect_allowed(const struct sock *sk);
static int sock_map_update_common(struct bpf_map *map, u32 idx,
struct sock *sk, u64 flags)
{
......@@ -487,14 +482,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
if (!link)
return -ENOMEM;
/* Only sockets we can redirect into/from in BPF need to hold
* refs to parser/verdict progs and have their sk_data_ready
* and sk_write_space callbacks overridden.
*/
if (sock_map_redirect_allowed(sk))
ret = sock_map_link(map, &stab->progs, sk);
else
ret = sock_map_link_no_progs(map, sk);
ret = sock_map_link(map, sk);
if (ret < 0)
goto out_free;
......@@ -547,12 +535,15 @@ static bool sk_is_udp(const struct sock *sk)
static bool sock_map_redirect_allowed(const struct sock *sk)
{
return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
if (sk_is_tcp(sk))
return sk->sk_state != TCP_LISTEN;
else
return sk->sk_state == TCP_ESTABLISHED;
}
static bool sock_map_sk_is_suitable(const struct sock *sk)
{
return sk_is_tcp(sk) || sk_is_udp(sk);
return !!sk->sk_prot->psock_update_sk_prot;
}
static bool sock_map_sk_state_allowed(const struct sock *sk)
......@@ -999,14 +990,7 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
if (!link)
return -ENOMEM;
/* Only sockets we can redirect into/from in BPF need to hold
* refs to parser/verdict progs and have their sk_data_ready
* and sk_write_space callbacks overridden.
*/
if (sock_map_redirect_allowed(sk))
ret = sock_map_link(map, &htab->progs, sk);
else
ret = sock_map_link_no_progs(map, sk);
ret = sock_map_link(map, sk);
if (ret < 0)
goto out_free;
......@@ -1466,8 +1450,15 @@ static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
break;
#endif
case BPF_SK_SKB_STREAM_VERDICT:
if (progs->skb_verdict)
return -EBUSY;
pprog = &progs->stream_verdict;
break;
case BPF_SK_SKB_VERDICT:
if (progs->stream_verdict)
return -EBUSY;
pprog = &progs->skb_verdict;
break;
default:
return -EOPNOTSUPP;
}
......@@ -1540,6 +1531,7 @@ void sock_map_close(struct sock *sk, long timeout)
saved_close = psock->saved_close;
sock_map_remove_links(sk, psock);
rcu_read_unlock();
sk_psock_stop(psock, true);
release_sock(sk);
saved_close(sk, timeout);
}
......
......@@ -1070,6 +1070,7 @@ const struct proto_ops inet_dgram_ops = {
.setsockopt = sock_common_setsockopt,
.getsockopt = sock_common_getsockopt,
.sendmsg = inet_sendmsg,
.read_sock = udp_read_sock,
.recvmsg = inet_recvmsg,
.mmap = sock_no_mmap,
.sendpage = inet_sendpage,
......
......@@ -10,86 +10,6 @@
#include <net/inet_common.h>
#include <net/tls.h>
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len, int flags)
{
struct iov_iter *iter = &msg->msg_iter;
int peek = flags & MSG_PEEK;
struct sk_msg *msg_rx;
int i, copied = 0;
msg_rx = list_first_entry_or_null(&psock->ingress_msg,
struct sk_msg, list);
while (copied != len) {
struct scatterlist *sge;
if (unlikely(!msg_rx))
break;
i = msg_rx->sg.start;
do {
struct page *page;
int copy;
sge = sk_msg_elem(msg_rx, i);
copy = sge->length;
page = sg_page(sge);
if (copied + copy > len)
copy = len - copied;
copy = copy_page_to_iter(page, sge->offset, copy, iter);
if (!copy)
return copied ? copied : -EFAULT;
copied += copy;
if (likely(!peek)) {
sge->offset += copy;
sge->length -= copy;
if (!msg_rx->skb)
sk_mem_uncharge(sk, copy);
msg_rx->sg.size -= copy;
if (!sge->length) {
sk_msg_iter_var_next(i);
if (!msg_rx->skb)
put_page(page);
}
} else {
/* Lets not optimize peek case if copy_page_to_iter
* didn't copy the entire length lets just break.
*/
if (copy != sge->length)
return copied;
sk_msg_iter_var_next(i);
}
if (copied == len)
break;
} while (i != msg_rx->sg.end);
if (unlikely(peek)) {
if (msg_rx == list_last_entry(&psock->ingress_msg,
struct sk_msg, list))
break;
msg_rx = list_next_entry(msg_rx, list);
continue;
}
msg_rx->sg.start = i;
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
list_del(&msg_rx->list);
if (msg_rx->skb)
consume_skb(msg_rx->skb);
kfree(msg_rx);
}
msg_rx = list_first_entry_or_null(&psock->ingress_msg,
struct sk_msg, list);
}
return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
struct sk_msg *msg, u32 apply_bytes, int flags)
{
......@@ -243,28 +163,6 @@ static bool tcp_bpf_stream_read(const struct sock *sk)
return !empty;
}
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
int flags, long timeo, int *err)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
ret = sk_wait_event(sk, &timeo,
!list_empty(&psock->ingress_msg) ||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len)
{
......@@ -284,13 +182,13 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
}
lock_sock(sk);
msg_bytes_ready:
copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
if (!copied) {
int data, err = 0;
long timeo;
timeo = sock_rcvtimeo(sk, nonblock);
data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
if (data) {
if (!sk_psock_queue_empty(psock))
goto msg_bytes_ready;
......@@ -601,20 +499,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
int tcp_bpf_update_proto(struct sock *sk, bool restore)
{
struct sk_psock *psock = sk_psock(sk);
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
if (restore) {
if (inet_csk_has_ulp(sk)) {
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
} else {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
}
return 0;
}
if (inet_csk_has_ulp(sk))
return -EINVAL;
if (sk->sk_family == AF_INET6) {
if (tcp_bpf_assert_proto_ops(psock->sk_proto))
return ERR_PTR(-EINVAL);
return -EINVAL;
tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
}
return &tcp_bpf_prots[family][config];
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
/* If a child got cloned from a listening socket that had tcp_bpf
* protocol callbacks installed, we need to restore the callbacks to
......
......@@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
.hash = inet_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free,
......
......@@ -1782,6 +1782,35 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
}
EXPORT_SYMBOL(__skb_recv_udp);
int udp_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor)
{
int copied = 0;
while (1) {
struct sk_buff *skb;
int err, used;
skb = skb_recv_udp(sk, 0, 1, &err);
if (!skb)
return err;
used = recv_actor(desc, skb, 0, skb->len);
if (used <= 0) {
if (!copied)
copied = used;
break;
} else if (used <= skb->len) {
copied += used;
}
if (!desc->count)
break;
}
return copied;
}
EXPORT_SYMBOL(udp_read_sock);
/*
* This should be easy, if there is something there we
* return it, otherwise we block.
......@@ -2849,6 +2878,9 @@ struct proto udp_prot = {
.unhash = udp_lib_unhash,
.rehash = udp_v4_rehash,
.get_port = udp_v4_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
......
......@@ -4,6 +4,68 @@
#include <linux/skmsg.h>
#include <net/sock.h>
#include <net/udp.h>
#include <net/inet_common.h>
#include "udp_impl.h"
static struct proto *udpv6_prot_saved __read_mostly;
static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int noblock, int flags, int *addr_len)
{
#if IS_ENABLED(CONFIG_IPV6)
if (sk->sk_family == AF_INET6)
return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags,
addr_len);
#endif
return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
}
static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len)
{
struct sk_psock *psock;
int copied, ret;
if (unlikely(flags & MSG_ERRQUEUE))
return inet_recv_error(sk, msg, len, addr_len);
psock = sk_psock_get(sk);
if (unlikely(!psock))
return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
lock_sock(sk);
if (sk_psock_queue_empty(psock)) {
ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
goto out;
}
msg_bytes_ready:
copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
if (!copied) {
int data, err = 0;
long timeo;
timeo = sock_rcvtimeo(sk, nonblock);
data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
if (data) {
if (!sk_psock_queue_empty(psock))
goto msg_bytes_ready;
ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
goto out;
}
if (err) {
ret = err;
goto out;
}
copied = -EAGAIN;
}
ret = copied;
out:
release_sock(sk);
sk_psock_put(sk, psock);
return ret;
}
enum {
UDP_BPF_IPV4,
......@@ -11,7 +73,6 @@ enum {
UDP_BPF_NUM_PROTS,
};
static struct proto *udpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(udpv6_prot_lock);
static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
......@@ -20,6 +81,7 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
*prot = *base;
prot->unhash = sock_map_unhash;
prot->close = sock_map_close;
prot->recvmsg = udp_bpf_recvmsg;
}
static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
......@@ -41,12 +103,23 @@ static int __init udp_bpf_v4_build_proto(void)
}
core_initcall(udp_bpf_v4_build_proto);
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
int udp_bpf_update_proto(struct sock *sk, bool restore)
{
int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
struct sk_psock *psock = sk_psock(sk);
if (restore) {
sk->sk_write_space = psock->saved_write_space;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
return 0;
}
if (sk->sk_family == AF_INET6)
udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
return &udp_bpf_prots[family];
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
return 0;
}
EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
......@@ -714,6 +714,7 @@ const struct proto_ops inet6_dgram_ops = {
.getsockopt = sock_common_getsockopt, /* ok */
.sendmsg = inet6_sendmsg, /* retpoline's sake */
.recvmsg = inet6_recvmsg, /* retpoline's sake */
.read_sock = udp_read_sock,
.mmap = sock_no_mmap,
.sendpage = sock_no_sendpage,
.set_peek_off = sk_set_peek_off,
......
......@@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = {
.hash = inet6_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = tcp_bpf_update_proto,
#endif
.enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free,
......
......@@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
.unhash = udp_lib_unhash,
.rehash = udp_v6_rehash,
.get_port = udp_v6_get_port,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = udp_bpf_update_proto,
#endif
.memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
......
......@@ -1789,8 +1789,8 @@ int tls_sw_recvmsg(struct sock *sk,
skb = tls_wait_data(sk, psock, flags, timeo, &err);
if (!skb) {
if (psock) {
int ret = __tcp_bpf_recvmsg(sk, psock,
msg, len, flags);
int ret = sk_msg_recvmsg(sk, psock, msg, len,
flags);
if (ret > 0) {
decrypted += ret;
......
......@@ -57,6 +57,7 @@ const char * const attach_type_name[__MAX_BPF_ATTACH_TYPE] = {
[BPF_SK_SKB_STREAM_PARSER] = "sk_skb_stream_parser",
[BPF_SK_SKB_STREAM_VERDICT] = "sk_skb_stream_verdict",
[BPF_SK_SKB_VERDICT] = "sk_skb_verdict",
[BPF_SK_MSG_VERDICT] = "sk_msg_verdict",
[BPF_LIRC_MODE2] = "lirc_mode2",
[BPF_FLOW_DISSECTOR] = "flow_dissector",
......
......@@ -76,6 +76,7 @@ enum dump_mode {
static const char * const attach_type_strings[] = {
[BPF_SK_SKB_STREAM_PARSER] = "stream_parser",
[BPF_SK_SKB_STREAM_VERDICT] = "stream_verdict",
[BPF_SK_SKB_VERDICT] = "skb_verdict",
[BPF_SK_MSG_VERDICT] = "msg_verdict",
[BPF_FLOW_DISSECTOR] = "flow_dissector",
[__MAX_BPF_ATTACH_TYPE] = NULL,
......
......@@ -957,6 +957,7 @@ enum bpf_attach_type {
BPF_XDP_CPUMAP,
BPF_SK_LOOKUP,
BPF_XDP,
BPF_SK_SKB_VERDICT,
__MAX_BPF_ATTACH_TYPE
};
......
......@@ -7,6 +7,7 @@
#include "test_skmsg_load_helpers.skel.h"
#include "test_sockmap_update.skel.h"
#include "test_sockmap_invalid_update.skel.h"
#include "test_sockmap_skb_verdict_attach.skel.h"
#include "bpf_iter_sockmap.skel.h"
#define TCP_REPAIR 19 /* TCP sock is under repair right now */
......@@ -281,6 +282,39 @@ static void test_sockmap_copy(enum bpf_map_type map_type)
bpf_iter_sockmap__destroy(skel);
}
static void test_sockmap_skb_verdict_attach(enum bpf_attach_type first,
enum bpf_attach_type second)
{
struct test_sockmap_skb_verdict_attach *skel;
int err, map, verdict;
skel = test_sockmap_skb_verdict_attach__open_and_load();
if (CHECK_FAIL(!skel)) {
perror("test_sockmap_skb_verdict_attach__open_and_load");
return;
}
verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
map = bpf_map__fd(skel->maps.sock_map);
err = bpf_prog_attach(verdict, map, first, 0);
if (CHECK_FAIL(err)) {
perror("bpf_prog_attach");
goto out;
}
err = bpf_prog_attach(verdict, map, second, 0);
assert(err == -1 && errno == EBUSY);
err = bpf_prog_detach2(verdict, map, first);
if (CHECK_FAIL(err)) {
perror("bpf_prog_detach2");
goto out;
}
out:
test_sockmap_skb_verdict_attach__destroy(skel);
}
void test_sockmap_basic(void)
{
if (test__start_subtest("sockmap create_update_free"))
......@@ -301,4 +335,10 @@ void test_sockmap_basic(void)
test_sockmap_copy(BPF_MAP_TYPE_SOCKMAP);
if (test__start_subtest("sockhash copy"))
test_sockmap_copy(BPF_MAP_TYPE_SOCKHASH);
if (test__start_subtest("sockmap skb_verdict attach")) {
test_sockmap_skb_verdict_attach(BPF_SK_SKB_VERDICT,
BPF_SK_SKB_STREAM_VERDICT);
test_sockmap_skb_verdict_attach(BPF_SK_SKB_STREAM_VERDICT,
BPF_SK_SKB_VERDICT);
}
}
......@@ -1603,6 +1603,141 @@ static void test_reuseport(struct test_sockmap_listen *skel,
}
}
static void udp_redir_to_connected(int family, int sotype, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
{
const char *log_prefix = redir_mode_str(mode);
struct sockaddr_storage addr;
int c0, c1, p0, p1;
unsigned int pass;
socklen_t len;
int err, n;
u64 value;
u32 key;
char b;
zero_verdict_count(verd_mapfd);
p0 = socket_loopback(family, sotype | SOCK_NONBLOCK);
if (p0 < 0)
return;
len = sizeof(addr);
err = xgetsockname(p0, sockaddr(&addr), &len);
if (err)
goto close_peer0;
c0 = xsocket(family, sotype | SOCK_NONBLOCK, 0);
if (c0 < 0)
goto close_peer0;
err = xconnect(c0, sockaddr(&addr), len);
if (err)
goto close_cli0;
err = xgetsockname(c0, sockaddr(&addr), &len);
if (err)
goto close_cli0;
err = xconnect(p0, sockaddr(&addr), len);
if (err)
goto close_cli0;
p1 = socket_loopback(family, sotype | SOCK_NONBLOCK);
if (p1 < 0)
goto close_cli0;
err = xgetsockname(p1, sockaddr(&addr), &len);
if (err)
goto close_cli0;
c1 = xsocket(family, sotype | SOCK_NONBLOCK, 0);
if (c1 < 0)
goto close_peer1;
err = xconnect(c1, sockaddr(&addr), len);
if (err)
goto close_cli1;
err = xgetsockname(c1, sockaddr(&addr), &len);
if (err)
goto close_cli1;
err = xconnect(p1, sockaddr(&addr), len);
if (err)
goto close_cli1;
key = 0;
value = p0;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
if (err)
goto close_cli1;
key = 1;
value = p1;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
if (err)
goto close_cli1;
n = write(c1, "a", 1);
if (n < 0)
FAIL_ERRNO("%s: write", log_prefix);
if (n == 0)
FAIL("%s: incomplete write", log_prefix);
if (n < 1)
goto close_cli1;
key = SK_PASS;
err = xbpf_map_lookup_elem(verd_mapfd, &key, &pass);
if (err)
goto close_cli1;
if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass);
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
if (n < 0)
FAIL_ERRNO("%s: read", log_prefix);
if (n == 0)
FAIL("%s: incomplete read", log_prefix);
close_cli1:
xclose(c1);
close_peer1:
xclose(p1);
close_cli0:
xclose(c0);
close_peer0:
xclose(p0);
}
static void udp_skb_redir_to_connected(struct test_sockmap_listen *skel,
struct bpf_map *inner_map, int family)
{
int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
int verdict_map = bpf_map__fd(skel->maps.verdict_map);
int sock_map = bpf_map__fd(inner_map);
int err;
err = xbpf_prog_attach(verdict, sock_map, BPF_SK_SKB_VERDICT, 0);
if (err)
return;
skel->bss->test_ingress = false;
udp_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_EGRESS);
skel->bss->test_ingress = true;
udp_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_INGRESS);
xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_VERDICT);
}
static void test_udp_redir(struct test_sockmap_listen *skel, struct bpf_map *map,
int family)
{
const char *family_name, *map_name;
char s[MAX_TEST_NAME];
family_name = family_str(family);
map_name = map_type_str(map);
snprintf(s, sizeof(s), "%s %s %s", map_name, family_name, __func__);
if (!test__start_subtest(s))
return;
udp_skb_redir_to_connected(skel, map, family);
}
static void run_tests(struct test_sockmap_listen *skel, struct bpf_map *map,
int family)
{
......@@ -1611,6 +1746,7 @@ static void run_tests(struct test_sockmap_listen *skel, struct bpf_map *map,
test_redir(skel, map, family, SOCK_STREAM);
test_reuseport(skel, map, family, SOCK_STREAM);
test_reuseport(skel, map, family, SOCK_DGRAM);
test_udp_redir(skel, map, family);
}
void test_sockmap_listen(void)
......
......@@ -29,6 +29,7 @@ struct {
} verdict_map SEC(".maps");
static volatile bool test_sockmap; /* toggled by user-space */
static volatile bool test_ingress; /* toggled by user-space */
SEC("sk_skb/stream_parser")
int prog_stream_parser(struct __sk_buff *skb)
......@@ -55,6 +56,27 @@ int prog_stream_verdict(struct __sk_buff *skb)
return verdict;
}
SEC("sk_skb/skb_verdict")
int prog_skb_verdict(struct __sk_buff *skb)
{
unsigned int *count;
__u32 zero = 0;
int verdict;
if (test_sockmap)
verdict = bpf_sk_redirect_map(skb, &sock_map, zero,
test_ingress ? BPF_F_INGRESS : 0);
else
verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero,
test_ingress ? BPF_F_INGRESS : 0);
count = bpf_map_lookup_elem(&verdict_map, &verdict);
if (count)
(*count)++;
return verdict;
}
SEC("sk_msg")
int prog_msg_verdict(struct sk_msg_md *msg)
{
......
// SPDX-License-Identifier: GPL-2.0
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
struct {
__uint(type, BPF_MAP_TYPE_SOCKMAP);
__uint(max_entries, 2);
__type(key, __u32);
__type(value, __u64);
} sock_map SEC(".maps");
SEC("sk_skb/skb_verdict")
int prog_skb_verdict(struct __sk_buff *skb)
{
return SK_DROP;
}
char _license[] SEC("license") = "GPL";
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment