Commit 36397a18 authored by Martin KaFai Lau

Merge branch 'Add SO_REUSEPORT support for TC bpf_sk_assign'

Lorenz Bauer says:

====================
We want to replace iptables TPROXY with a BPF program at TC ingress.
To make this work in all cases we need to assign a SO_REUSEPORT socket
to an skb, which is currently prohibited. This series adds support for
such sockets to bpf_sk_assign.
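
A rough sketch of the intended TC-side use (illustrative only; the
selftest at the end of this series is the authoritative version, and the
usual BPF includes are assumed):

	struct {
		__uint(type, BPF_MAP_TYPE_SOCKMAP);
		__uint(max_entries, 1);
		__type(key, __u32);
		__type(value, __u64);
	} sk_map SEC(".maps");

	SEC("tc")
	int steer(struct __sk_buff *skb)
	{
		const int zero = 0;
		struct bpf_sock *sk;
		long err;

		sk = bpf_map_lookup_elem(&sk_map, &zero);
		if (!sk)
			return TC_ACT_SHOT;

		/* With this series, sk may have SO_REUSEPORT set. */
		err = bpf_sk_assign(skb, sk, 0);
		bpf_sk_release(sk);

		return err ? TC_ACT_SHOT : TC_ACT_OK;
	}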

I did some refactoring to cut down on the amount of duplicated code. The
key to this is using INDIRECT_CALL in the reuseport helpers. To show
that this approach is not just beneficial to TC sk_assign, I removed the
duplicate code for bpf_sk_lookup as well.
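
For readers unfamiliar with the macro: under retpolines, INDIRECT_CALL_2()
compares a function pointer against the expected targets and calls them
directly, avoiding a slow indirect branch. A simplified sketch of the idea
(not the exact macro from include/linux/indirect_call_wrapper.h):

	#define INDIRECT_CALL_1(f, f1, ...)				\
		(likely(f == f1) ? f1(__VA_ARGS__) : f(__VA_ARGS__))
	#define INDIRECT_CALL_2(f, f2, f1, ...)				\
		(likely(f == f2) ? f2(__VA_ARGS__) :			\
				   INDIRECT_CALL_1(f, f1, __VA_ARGS__))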

Joint work with Daniel Borkmann.
Signed-off-by: Lorenz Bauer <lmb@isovalent.com>
---
Changes in v6:
- Reject unhashed UDP sockets in bpf_sk_assign to avoid ref leak
- Link to v5: https://lore.kernel.org/r/20230613-so-reuseport-v5-0-f6686a0dbce0@isovalent.com

Changes in v5:
- Drop reuse_sk == sk check in inet[6]_steal_sock (Kuniyuki)
- Link to v4: https://lore.kernel.org/r/20230613-so-reuseport-v4-0-4ece76708bba@isovalent.com

Changes in v4:
- WARN_ON_ONCE if reuseport socket is refcounted (Kuniyuki)
- Use inet[6]_ehashfn_t to shorten function declarations (Kuniyuki)
- Shuffle documentation patch around (Kuniyuki)
- Update commit message to explain why IPv6 needs EXPORT_SYMBOL
- Link to v3: https://lore.kernel.org/r/20230613-so-reuseport-v3-0-907b4cbb7b99@isovalent.com

Changes in v3:
- Fix warning re udp_ehashfn and udp6_ehashfn (Simon)
- Return higher scoring connected UDP reuseport sockets (Kuniyuki)
- Fix ipv6 module builds
- Link to v2: https://lore.kernel.org/r/20230613-so-reuseport-v2-0-b7c69a342613@isovalent.com

Changes in v2:
- Correct commit abbrev length (Kuniyuki)
- Reduce duplication (Kuniyuki)
- Add checks on sk_state (Martin)
- Split exporting inet[6]_lookup_reuseport into separate patch (Eric)

---
Daniel Borkmann (1):
      selftests/bpf: Test that SO_REUSEPORT can be used with sk_assign helper
====================
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
parents 7b2b2012 22408d58
......@@ -48,6 +48,22 @@ struct sock *__inet6_lookup_established(struct net *net,
const u16 hnum, const int dif,
const int sdif);
typedef u32 (inet6_ehashfn_t)(const struct net *net,
const struct in6_addr *laddr, const u16 lport,
const struct in6_addr *faddr, const __be16 fport);
inet6_ehashfn_t inet6_ehashfn;
INDIRECT_CALLABLE_DECLARE(inet6_ehashfn_t udp6_ehashfn);
struct sock *inet6_lookup_reuseport(struct net *net, struct sock *sk,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr,
__be16 sport,
const struct in6_addr *daddr,
unsigned short hnum,
inet6_ehashfn_t *ehashfn);
struct sock *inet6_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
......@@ -57,6 +73,15 @@ struct sock *inet6_lookup_listener(struct net *net,
const unsigned short hnum,
const int dif, const int sdif);
struct sock *inet6_lookup_run_sk_lookup(struct net *net,
int protocol,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr,
const __be16 sport,
const struct in6_addr *daddr,
const u16 hnum, const int dif,
inet6_ehashfn_t *ehashfn);
static inline struct sock *__inet6_lookup(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
......@@ -78,6 +103,46 @@ static inline struct sock *__inet6_lookup(struct net *net,
daddr, hnum, dif, sdif);
}
static inline
struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
const struct in6_addr *saddr, const __be16 sport,
const struct in6_addr *daddr, const __be16 dport,
bool *refcounted, inet6_ehashfn_t *ehashfn)
{
struct sock *sk, *reuse_sk;
bool prefetched;
sk = skb_steal_sock(skb, refcounted, &prefetched);
if (!sk)
return NULL;
if (!prefetched)
return sk;
if (sk->sk_protocol == IPPROTO_TCP) {
if (sk->sk_state != TCP_LISTEN)
return sk;
} else if (sk->sk_protocol == IPPROTO_UDP) {
if (sk->sk_state != TCP_CLOSE)
return sk;
} else {
return sk;
}
reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
saddr, sport, daddr, ntohs(dport),
ehashfn);
if (!reuse_sk)
return sk;
/* We've chosen a new reuseport sock which is never refcounted. This
* implies that sk also isn't refcounted.
*/
WARN_ON_ONCE(*refcounted);
return reuse_sk;
}
static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const __be16 sport,
......@@ -85,14 +150,20 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
int iif, int sdif,
bool *refcounted)
{
-	struct sock *sk = skb_steal_sock(skb, refcounted);
+	struct net *net = dev_net(skb_dst(skb)->dev);
+	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+	struct sock *sk;
+
+	sk = inet6_steal_sock(net, skb, doff, &ip6h->saddr, sport, &ip6h->daddr, dport,
+			      refcounted, inet6_ehashfn);
+	if (IS_ERR(sk))
+		return NULL;
 	if (sk)
 		return sk;
 
-	return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
-			      doff, &ipv6_hdr(skb)->saddr, sport,
-			      &ipv6_hdr(skb)->daddr, ntohs(dport),
+	return __inet6_lookup(net, hashinfo, skb,
+			      doff, &ip6h->saddr, sport,
+			      &ip6h->daddr, ntohs(dport),
 			      iif, sdif, refcounted);
}
......
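The intended calling pattern for inet6_steal_sock() mirrors the
__udp6_lib_rcv() change later in this diff; roughly (handle_with_sock()
is a hypothetical consumer, not a kernel function):

	sk = inet6_steal_sock(net, skb, sizeof(struct udphdr),
			      saddr, uh->source, daddr, uh->dest,
			      &refcounted, udp6_ehashfn);
	if (IS_ERR(sk))
		goto no_sk;	/* an SK_REUSEPORT program returned an error */
	if (sk)
		return handle_with_sock(sk, skb);	/* hypothetical consumer */
	/* otherwise fall back to the regular socket lookup */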
......@@ -379,6 +379,27 @@ struct sock *__inet_lookup_established(struct net *net,
const __be32 daddr, const u16 hnum,
const int dif, const int sdif);
typedef u32 (inet_ehashfn_t)(const struct net *net,
const __be32 laddr, const __u16 lport,
const __be32 faddr, const __be16 fport);
inet_ehashfn_t inet_ehashfn;
INDIRECT_CALLABLE_DECLARE(inet_ehashfn_t udp_ehashfn);
struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk,
struct sk_buff *skb, int doff,
__be32 saddr, __be16 sport,
__be32 daddr, unsigned short hnum,
inet_ehashfn_t *ehashfn);
struct sock *inet_lookup_run_sk_lookup(struct net *net,
int protocol,
struct sk_buff *skb, int doff,
__be32 saddr, __be16 sport,
__be32 daddr, u16 hnum, const int dif,
inet_ehashfn_t *ehashfn);
static inline struct sock *
inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo,
const __be32 saddr, const __be16 sport,
......@@ -428,6 +449,46 @@ static inline struct sock *inet_lookup(struct net *net,
return sk;
}
static inline
struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff,
const __be32 saddr, const __be16 sport,
const __be32 daddr, const __be16 dport,
bool *refcounted, inet_ehashfn_t *ehashfn)
{
struct sock *sk, *reuse_sk;
bool prefetched;
sk = skb_steal_sock(skb, refcounted, &prefetched);
if (!sk)
return NULL;
if (!prefetched)
return sk;
if (sk->sk_protocol == IPPROTO_TCP) {
if (sk->sk_state != TCP_LISTEN)
return sk;
} else if (sk->sk_protocol == IPPROTO_UDP) {
if (sk->sk_state != TCP_CLOSE)
return sk;
} else {
return sk;
}
reuse_sk = inet_lookup_reuseport(net, sk, skb, doff,
saddr, sport, daddr, ntohs(dport),
ehashfn);
if (!reuse_sk)
return sk;
/* We've chosen a new reuseport sock which is never refcounted. This
* implies that sk also isn't refcounted.
*/
WARN_ON_ONCE(*refcounted);
return reuse_sk;
}
static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
struct sk_buff *skb,
int doff,
......@@ -436,22 +497,23 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
const int sdif,
bool *refcounted)
{
-	struct sock *sk = skb_steal_sock(skb, refcounted);
+	struct net *net = dev_net(skb_dst(skb)->dev);
 	const struct iphdr *iph = ip_hdr(skb);
+	struct sock *sk;
+
+	sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport,
+			     refcounted, inet_ehashfn);
+	if (IS_ERR(sk))
+		return NULL;
 	if (sk)
 		return sk;
 
-	return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
+	return __inet_lookup(net, hashinfo, skb,
 			     doff, iph->saddr, sport,
 			     iph->daddr, dport, inet_iif(skb), sdif,
 			     refcounted);
}
u32 inet6_ehashfn(const struct net *net,
const struct in6_addr *laddr, const u16 lport,
const struct in6_addr *faddr, const __be16 fport);
static inline void sk_daddr_set(struct sock *sk, __be32 addr)
{
sk->sk_daddr = addr; /* alias of inet_daddr */
......
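With the inet_ehashfn_t pointer, one helper now serves both protocols; the
two call shapes used later in this diff are, roughly:

	/* TCP listener lookup passes the TCP hash function ... */
	reuse_sk = inet_lookup_reuseport(net, sk, skb, doff,
					 saddr, sport, daddr, hnum, inet_ehashfn);

	/* ... while the UDP path passes its own. */
	reuse_sk = inet_lookup_reuseport(net, sk, skb, sizeof(struct udphdr),
					 saddr, sport, daddr, hnum, udp_ehashfn);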
......@@ -2815,20 +2815,23 @@ sk_is_refcounted(struct sock *sk)
* skb_steal_sock - steal a socket from an sk_buff
* @skb: sk_buff to steal the socket from
* @refcounted: is set to true if the socket is reference-counted
* @prefetched: is set to true if the socket was assigned from bpf
*/
static inline struct sock *
-skb_steal_sock(struct sk_buff *skb, bool *refcounted)
+skb_steal_sock(struct sk_buff *skb, bool *refcounted, bool *prefetched)
 {
 	if (skb->sk) {
 		struct sock *sk = skb->sk;
 
 		*refcounted = true;
-		if (skb_sk_is_prefetched(skb))
+		*prefetched = skb_sk_is_prefetched(skb);
+		if (*prefetched)
 			*refcounted = sk_is_refcounted(sk);
 		skb->destructor = NULL;
 		skb->sk = NULL;
 		return sk;
 	}
+	*prefetched = false;
 	*refcounted = false;
 	return NULL;
 }
......
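A caller-side sketch of the extended skb_steal_sock() contract
(illustrative, not from this diff):

	bool refcounted, prefetched;
	struct sock *sk;

	sk = skb_steal_sock(skb, &refcounted, &prefetched);
	if (sk && prefetched) {
		/* sk was attached by bpf_sk_assign(); the steal_sock
		 * helpers above only trust it as-is for TCP_LISTEN (TCP)
		 * or TCP_CLOSE (UDP), and may re-run reuseport selection.
		 */
	}
	/* ... use the socket ... */
	if (sk && refcounted)
		sock_put(sk);	/* caller owns the stolen reference */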
......@@ -4198,9 +4198,6 @@ union bpf_attr {
* **-EOPNOTSUPP** if the operation is not supported, for example
* a call from outside of TC ingress.
*
- *		**-ESOCKTNOSUPPORT** if the socket type is not supported
- *		(reuseport).
- *
* long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
* Description
* Helper is overloaded depending on BPF program type. This
......
......@@ -7351,8 +7351,8 @@ BPF_CALL_3(bpf_sk_assign, struct sk_buff *, skb, struct sock *, sk, u64, flags)
return -EOPNOTSUPP;
if (unlikely(dev_net(skb->dev) != sock_net(sk)))
return -ENETUNREACH;
-	if (unlikely(sk_fullsock(sk) && sk->sk_reuseport))
-		return -ESOCKTNOSUPPORT;
+	if (sk_unhashed(sk))
+		return -EOPNOTSUPP;
if (sk_is_refcounted(sk) &&
unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
return -ENOENT;
......
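Per the v6 changelog, unhashed sockets are now rejected with -EOPNOTSUPP
instead of leaking a reference. A userspace sketch (assumed, for
illustration; error handling elided) of a socket that passes the new
checks:

	#include <arpa/inet.h>
	#include <netinet/in.h>
	#include <sys/socket.h>

	/* Returns a bound SO_REUSEPORT UDP socket. */
	static int make_assignable_socket(void)
	{
		struct sockaddr_in addr = {
			.sin_family = AF_INET,
			.sin_port = htons(4443),
			.sin_addr.s_addr = htonl(INADDR_LOOPBACK),
		};
		int one = 1;
		int fd = socket(AF_INET, SOCK_DGRAM, 0);

		setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
		/* bind() hashes the socket; an unbound UDP socket would
		 * now fail the sk_unhashed() check in bpf_sk_assign().
		 */
		bind(fd, (struct sockaddr *)&addr, sizeof(addr));
		return fd;
	}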
......@@ -28,7 +28,7 @@
#include <net/tcp.h>
#include <net/sock_reuseport.h>
-static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
+u32 inet_ehashfn(const struct net *net, const __be32 laddr,
const __u16 lport, const __be32 faddr,
const __be16 fport)
{
......@@ -39,6 +39,7 @@ static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
return __inet_ehashfn(laddr, lport, faddr, fport,
inet_ehash_secret + net_hash_mix(net));
}
EXPORT_SYMBOL_GPL(inet_ehashfn);
/* This function handles inet_sock, but also timewait and request sockets
* for IPv4/IPv6.
......@@ -332,20 +333,40 @@ static inline int compute_score(struct sock *sk, struct net *net,
return score;
}
-static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-					    struct sk_buff *skb, int doff,
-					    __be32 saddr, __be16 sport,
-					    __be32 daddr, unsigned short hnum)
+INDIRECT_CALLABLE_DECLARE(inet_ehashfn_t udp_ehashfn);
+
+/**
+ * inet_lookup_reuseport() - execute reuseport logic on AF_INET socket if necessary.
+ * @net: network namespace.
+ * @sk: AF_INET socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
+ * @skb: context for a potential SK_REUSEPORT program.
+ * @doff: header offset.
+ * @saddr: source address.
+ * @sport: source port.
+ * @daddr: destination address.
+ * @hnum: destination port in host byte order.
+ * @ehashfn: hash function used to generate the fallback hash.
+ *
+ * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
+ *         the selected sock or an error.
+ */
+struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk,
+				   struct sk_buff *skb, int doff,
+				   __be32 saddr, __be16 sport,
+				   __be32 daddr, unsigned short hnum,
+				   inet_ehashfn_t *ehashfn)
 {
 	struct sock *reuse_sk = NULL;
 	u32 phash;
 
 	if (sk->sk_reuseport) {
-		phash = inet_ehashfn(net, daddr, hnum, saddr, sport);
+		phash = INDIRECT_CALL_2(ehashfn, udp_ehashfn, inet_ehashfn,
+					net, daddr, hnum, saddr, sport);
 		reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
 	}
 	return reuse_sk;
 }
+EXPORT_SYMBOL_GPL(inet_lookup_reuseport);
/*
* Here are some nice properties to exploit here. The BSD API
......@@ -369,8 +390,8 @@ static struct sock *inet_lhash2_lookup(struct net *net,
sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
score = compute_score(sk, net, hnum, daddr, dif, sdif);
if (score > hiscore) {
-			result = lookup_reuseport(net, sk, skb, doff,
-						  saddr, sport, daddr, hnum);
+			result = inet_lookup_reuseport(net, sk, skb, doff,
+						       saddr, sport, daddr, hnum, inet_ehashfn);
if (result)
return result;
......@@ -382,24 +403,23 @@ static struct sock *inet_lhash2_lookup(struct net *net,
return result;
}
-static inline struct sock *inet_lookup_run_bpf(struct net *net,
-					       struct inet_hashinfo *hashinfo,
-					       struct sk_buff *skb, int doff,
-					       __be32 saddr, __be16 sport,
-					       __be32 daddr, u16 hnum, const int dif)
+struct sock *inet_lookup_run_sk_lookup(struct net *net,
+				       int protocol,
+				       struct sk_buff *skb, int doff,
+				       __be32 saddr, __be16 sport,
+				       __be32 daddr, u16 hnum, const int dif,
+				       inet_ehashfn_t *ehashfn)
 {
 	struct sock *sk, *reuse_sk;
 	bool no_reuseport;
 
-	if (hashinfo != net->ipv4.tcp_death_row.hashinfo)
-		return NULL; /* only TCP is supported */
-
-	no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport,
+	no_reuseport = bpf_sk_lookup_run_v4(net, protocol, saddr, sport,
 					    daddr, hnum, dif, &sk);
 	if (no_reuseport || IS_ERR_OR_NULL(sk))
 		return sk;
 
-	reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum);
+	reuse_sk = inet_lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum,
+					 ehashfn);
 	if (reuse_sk)
 		sk = reuse_sk;
 	return sk;
......@@ -417,9 +437,11 @@ struct sock *__inet_lookup_listener(struct net *net,
unsigned int hash2;
/* Lookup redirect from BPF */
-	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-		result = inet_lookup_run_bpf(net, hashinfo, skb, doff,
-					     saddr, sport, daddr, hnum, dif);
+	if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+	    hashinfo == net->ipv4.tcp_death_row.hashinfo) {
+		result = inet_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
+						   saddr, sport, daddr, hnum, dif,
+						   inet_ehashfn);
if (result)
goto done;
}
......
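inet_lookup_run_sk_lookup() consults SK_LOOKUP programs attached to the
netns, such as the following sketch (hypothetical; redir_map is an assumed
BPF_MAP_TYPE_SOCKMAP and the usual BPF includes are assumed):

	SEC("sk_lookup")
	int select_sock(struct bpf_sk_lookup *ctx)
	{
		const int zero = 0;
		struct bpf_sock *sk;
		long err;

		sk = bpf_map_lookup_elem(&redir_map, &zero);
		if (!sk)
			return SK_PASS;	/* fall back to the regular lookup */

		err = bpf_sk_assign(ctx, sk, 0);	/* sk_lookup flavor of the helper */
		bpf_sk_release(sk);
		return err ? SK_DROP : SK_PASS;
	}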
......@@ -406,9 +406,9 @@ static int compute_score(struct sock *sk, struct net *net,
return score;
}
-static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
-		       const __u16 lport, const __be32 faddr,
-		       const __be16 fport)
+INDIRECT_CALLABLE_SCOPE
+u32 udp_ehashfn(const struct net *net, const __be32 laddr, const __u16 lport,
+		const __be32 faddr, const __be16 fport)
{
static u32 udp_ehash_secret __read_mostly;
......@@ -418,22 +418,6 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
udp_ehash_secret + net_hash_mix(net));
}
-static struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-				     struct sk_buff *skb,
-				     __be32 saddr, __be16 sport,
-				     __be32 daddr, unsigned short hnum)
-{
-	struct sock *reuse_sk = NULL;
-	u32 hash;
-
-	if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) {
-		hash = udp_ehashfn(net, daddr, hnum, saddr, sport);
-		reuse_sk = reuseport_select_sock(sk, hash, skb,
-						 sizeof(struct udphdr));
-	}
-	return reuse_sk;
-}
/* called with rcu_read_lock() */
static struct sock *udp4_lib_lookup2(struct net *net,
__be32 saddr, __be16 sport,
......@@ -451,40 +435,34 @@ static struct sock *udp4_lib_lookup2(struct net *net,
score = compute_score(sk, net, saddr, sport,
daddr, hnum, dif, sdif);
if (score > badness) {
-			result = lookup_reuseport(net, sk, skb,
-						  saddr, sport, daddr, hnum);
-			/* Fall back to scoring if group has connections */
-			if (result && !reuseport_has_conns(sk))
-				return result;
-			result = result ? : sk;
 			badness = score;
+
+			if (sk->sk_state == TCP_ESTABLISHED) {
+				result = sk;
+				continue;
+			}
+
+			result = inet_lookup_reuseport(net, sk, skb, sizeof(struct udphdr),
+						       saddr, sport, daddr, hnum, udp_ehashfn);
+			if (!result) {
+				result = sk;
+				continue;
+			}
+
+			/* Fall back to scoring if group has connections */
+			if (!reuseport_has_conns(sk))
+				return result;
+
+			/* Reuseport logic returned an error, keep original score. */
+			if (IS_ERR(result))
+				continue;
+
+			badness = compute_score(result, net, saddr, sport,
+						daddr, hnum, dif, sdif);
 		}
 	}
 	return result;
 }
 
-static struct sock *udp4_lookup_run_bpf(struct net *net,
-					struct udp_table *udptable,
-					struct sk_buff *skb,
-					__be32 saddr, __be16 sport,
-					__be32 daddr, u16 hnum, const int dif)
-{
-	struct sock *sk, *reuse_sk;
-	bool no_reuseport;
-
-	if (udptable != net->ipv4.udp_table)
-		return NULL; /* only UDP is supported */
-
-	no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_UDP, saddr, sport,
-					    daddr, hnum, dif, &sk);
-	if (no_reuseport || IS_ERR_OR_NULL(sk))
-		return sk;
-
-	reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
-	if (reuse_sk)
-		sk = reuse_sk;
-	return sk;
-}
/* UDP is nearly always wildcards out the wazoo, it makes no sense to try
......@@ -511,9 +489,11 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
goto done;
/* Lookup redirect from BPF */
-	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-		sk = udp4_lookup_run_bpf(net, udptable, skb,
-					 saddr, sport, daddr, hnum, dif);
+	if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+	    udptable == net->ipv4.udp_table) {
+		sk = inet_lookup_run_sk_lookup(net, IPPROTO_UDP, skb, sizeof(struct udphdr),
+					       saddr, sport, daddr, hnum, dif,
+					       udp_ehashfn);
if (sk) {
result = sk;
goto done;
......@@ -2408,7 +2388,11 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
if (udp4_csum_init(skb, uh, proto))
goto csum_error;
-	sk = skb_steal_sock(skb, &refcounted);
+	sk = inet_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
+			     &refcounted, udp_ehashfn);
+	if (IS_ERR(sk))
+		goto no_sk;
if (sk) {
struct dst_entry *dst = skb_dst(skb);
int ret;
......@@ -2429,7 +2413,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
if (sk)
return udp_unicast_rcv_skb(sk, skb, uh);
+no_sk:
if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
goto drop;
nf_reset_ct(skb);
......
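An editorial summary of the reworked udp4_lib_lookup2() selection loop
(paraphrase, not kernel text):

	/*
	 * for each sk with score > badness:
	 *	badness = score;
	 *	connected (TCP_ESTABLISHED) socket: remember it, keep scanning;
	 *	otherwise run reuseport selection on sk's group:
	 *	- nothing selected: remember sk, keep scanning;
	 *	- group has no connected sockets: return the selection at once;
	 *	- selection failed with an error: keep the current score,
	 *	  keep scanning;
	 *	- otherwise: remember the selected socket and re-score it,
	 *	  since it may differ from the socket that matched.
	 */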
......@@ -39,6 +39,7 @@ u32 inet6_ehashfn(const struct net *net,
return __inet6_ehashfn(lhash, lport, fhash, fport,
inet6_ehash_secret + net_hash_mix(net));
}
EXPORT_SYMBOL_GPL(inet6_ehashfn);
/*
* Sockets in TCP_CLOSE state are _always_ taken out of the hash, so
......@@ -111,22 +112,42 @@ static inline int compute_score(struct sock *sk, struct net *net,
return score;
}
-static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-					    struct sk_buff *skb, int doff,
-					    const struct in6_addr *saddr,
-					    __be16 sport,
-					    const struct in6_addr *daddr,
-					    unsigned short hnum)
+INDIRECT_CALLABLE_DECLARE(inet6_ehashfn_t udp6_ehashfn);
+
+/**
+ * inet6_lookup_reuseport() - execute reuseport logic on AF_INET6 socket if necessary.
+ * @net: network namespace.
+ * @sk: AF_INET6 socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
+ * @skb: context for a potential SK_REUSEPORT program.
+ * @doff: header offset.
+ * @saddr: source address.
+ * @sport: source port.
+ * @daddr: destination address.
+ * @hnum: destination port in host byte order.
+ * @ehashfn: hash function used to generate the fallback hash.
+ *
+ * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
+ *         the selected sock or an error.
+ */
+struct sock *inet6_lookup_reuseport(struct net *net, struct sock *sk,
+				    struct sk_buff *skb, int doff,
+				    const struct in6_addr *saddr,
+				    __be16 sport,
+				    const struct in6_addr *daddr,
+				    unsigned short hnum,
+				    inet6_ehashfn_t *ehashfn)
 {
 	struct sock *reuse_sk = NULL;
 	u32 phash;
 
 	if (sk->sk_reuseport) {
-		phash = inet6_ehashfn(net, daddr, hnum, saddr, sport);
+		phash = INDIRECT_CALL_INET(ehashfn, udp6_ehashfn, inet6_ehashfn,
+					   net, daddr, hnum, saddr, sport);
 		reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
 	}
 	return reuse_sk;
 }
+EXPORT_SYMBOL_GPL(inet6_lookup_reuseport);
/* called with rcu_read_lock() */
static struct sock *inet6_lhash2_lookup(struct net *net,
......@@ -143,8 +164,8 @@ static struct sock *inet6_lhash2_lookup(struct net *net,
sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
score = compute_score(sk, net, hnum, daddr, dif, sdif);
if (score > hiscore) {
-			result = lookup_reuseport(net, sk, skb, doff,
-						  saddr, sport, daddr, hnum);
+			result = inet6_lookup_reuseport(net, sk, skb, doff,
+							saddr, sport, daddr, hnum, inet6_ehashfn);
if (result)
return result;
......@@ -156,30 +177,30 @@ static struct sock *inet6_lhash2_lookup(struct net *net,
return result;
}
-static inline struct sock *inet6_lookup_run_bpf(struct net *net,
-						struct inet_hashinfo *hashinfo,
-						struct sk_buff *skb, int doff,
-						const struct in6_addr *saddr,
-						const __be16 sport,
-						const struct in6_addr *daddr,
-						const u16 hnum, const int dif)
+struct sock *inet6_lookup_run_sk_lookup(struct net *net,
+					int protocol,
+					struct sk_buff *skb, int doff,
+					const struct in6_addr *saddr,
+					const __be16 sport,
+					const struct in6_addr *daddr,
+					const u16 hnum, const int dif,
+					inet6_ehashfn_t *ehashfn)
 {
 	struct sock *sk, *reuse_sk;
 	bool no_reuseport;
 
-	if (hashinfo != net->ipv4.tcp_death_row.hashinfo)
-		return NULL; /* only TCP is supported */
-
-	no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_TCP, saddr, sport,
+	no_reuseport = bpf_sk_lookup_run_v6(net, protocol, saddr, sport,
 					    daddr, hnum, dif, &sk);
 	if (no_reuseport || IS_ERR_OR_NULL(sk))
 		return sk;
 
-	reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum);
+	reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
+					  saddr, sport, daddr, hnum, ehashfn);
 	if (reuse_sk)
 		sk = reuse_sk;
 	return sk;
 }
+EXPORT_SYMBOL_GPL(inet6_lookup_run_sk_lookup);
struct inet_hashinfo *hashinfo,
......@@ -193,9 +214,11 @@ struct sock *inet6_lookup_listener(struct net *net,
unsigned int hash2;
/* Lookup redirect from BPF */
-	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-		result = inet6_lookup_run_bpf(net, hashinfo, skb, doff,
-					      saddr, sport, daddr, hnum, dif);
+	if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+	    hashinfo == net->ipv4.tcp_death_row.hashinfo) {
+		result = inet6_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
+						    saddr, sport, daddr, hnum, dif,
+						    inet6_ehashfn);
if (result)
goto done;
}
......
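On why the IPv6 side needs EXPORT_SYMBOL_GPL (per the v4 changelog): these
helpers are referenced from static inlines in include/net/inet6_hashtables.h,
which other modules expand. The INDIRECT_CALL_INET() wrapper used above
roughly follows this shape (abridged from include/linux/indirect_call_wrapper.h):

	#if IS_BUILTIN(CONFIG_IPV6)
	#define INDIRECT_CALL_INET(f, f2, f1, ...) \
		INDIRECT_CALL_2(f, f2, f1, __VA_ARGS__)
	#elif IS_ENABLED(CONFIG_IPV6)
	#define INDIRECT_CALL_INET(f, f2, f1, ...) \
		INDIRECT_CALL_1(f, f1, __VA_ARGS__)
	#else
	#define INDIRECT_CALL_INET(f, f2, f1, ...) f(__VA_ARGS__)
	#endif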
......@@ -71,7 +71,8 @@ int udpv6_init_sock(struct sock *sk)
return 0;
}
-static u32 udp6_ehashfn(const struct net *net,
+INDIRECT_CALLABLE_SCOPE
+u32 udp6_ehashfn(const struct net *net,
const struct in6_addr *laddr,
const u16 lport,
const struct in6_addr *faddr,
......@@ -160,24 +161,6 @@ static int compute_score(struct sock *sk, struct net *net,
return score;
}
-static struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-				     struct sk_buff *skb,
-				     const struct in6_addr *saddr,
-				     __be16 sport,
-				     const struct in6_addr *daddr,
-				     unsigned int hnum)
-{
-	struct sock *reuse_sk = NULL;
-	u32 hash;
-
-	if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) {
-		hash = udp6_ehashfn(net, daddr, hnum, saddr, sport);
-		reuse_sk = reuseport_select_sock(sk, hash, skb,
-						 sizeof(struct udphdr));
-	}
-	return reuse_sk;
-}
/* called with rcu_read_lock() */
static struct sock *udp6_lib_lookup2(struct net *net,
const struct in6_addr *saddr, __be16 sport,
......@@ -194,42 +177,33 @@ static struct sock *udp6_lib_lookup2(struct net *net,
score = compute_score(sk, net, saddr, sport,
daddr, hnum, dif, sdif);
if (score > badness) {
-			result = lookup_reuseport(net, sk, skb,
-						  saddr, sport, daddr, hnum);
-			/* Fall back to scoring if group has connections */
-			if (result && !reuseport_has_conns(sk))
-				return result;
-			result = result ? : sk;
 			badness = score;
+
+			if (sk->sk_state == TCP_ESTABLISHED) {
+				result = sk;
+				continue;
+			}
+
+			result = inet6_lookup_reuseport(net, sk, skb, sizeof(struct udphdr),
+							saddr, sport, daddr, hnum, udp6_ehashfn);
+			if (!result) {
+				result = sk;
+				continue;
+			}
+
+			/* Fall back to scoring if group has connections */
+			if (!reuseport_has_conns(sk))
+				return result;
+
+			/* Reuseport logic returned an error, keep original score. */
+			if (IS_ERR(result))
+				continue;
+
+			badness = compute_score(sk, net, saddr, sport,
+						daddr, hnum, dif, sdif);
 		}
 	}
 	return result;
 }
 
-static inline struct sock *udp6_lookup_run_bpf(struct net *net,
-					       struct udp_table *udptable,
-					       struct sk_buff *skb,
-					       const struct in6_addr *saddr,
-					       __be16 sport,
-					       const struct in6_addr *daddr,
-					       u16 hnum, const int dif)
-{
-	struct sock *sk, *reuse_sk;
-	bool no_reuseport;
-
-	if (udptable != net->ipv4.udp_table)
-		return NULL; /* only UDP is supported */
-
-	no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_UDP, saddr, sport,
-					    daddr, hnum, dif, &sk);
-	if (no_reuseport || IS_ERR_OR_NULL(sk))
-		return sk;
-
-	reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
-	if (reuse_sk)
-		sk = reuse_sk;
-	return sk;
-}
/* rcu_read_lock() must be held */
......@@ -256,9 +230,11 @@ struct sock *__udp6_lib_lookup(struct net *net,
goto done;
/* Lookup redirect from BPF */
-	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-		sk = udp6_lookup_run_bpf(net, udptable, skb,
-					 saddr, sport, daddr, hnum, dif);
+	if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+	    udptable == net->ipv4.udp_table) {
+		sk = inet6_lookup_run_sk_lookup(net, IPPROTO_UDP, skb, sizeof(struct udphdr),
+						saddr, sport, daddr, hnum, dif,
+						udp6_ehashfn);
if (sk) {
result = sk;
goto done;
......@@ -988,7 +964,11 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
goto csum_error;
/* Check if the socket is already available, e.g. due to early demux */
sk = skb_steal_sock(skb, &refcounted);
sk = inet6_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
&refcounted, udp6_ehashfn);
if (IS_ERR(sk))
goto no_sk;
if (sk) {
struct dst_entry *dst = skb_dst(skb);
int ret;
......@@ -1022,7 +1002,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
goto report_csum_error;
return udp6_unicast_rcv_skb(sk, skb, uh);
}
+no_sk:
reason = SKB_DROP_REASON_NO_SOCKET;
if (!uh->check)
......
......@@ -4198,9 +4198,6 @@ union bpf_attr {
* **-EOPNOTSUPP** if the operation is not supported, for example
* a call from outside of TC ingress.
*
- *		**-ESOCKTNOSUPPORT** if the socket type is not supported
- *		(reuseport).
- *
* long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
* Description
* Helper is overloaded depending on BPF program type. This
......
......@@ -423,6 +423,9 @@ struct nstoken *open_netns(const char *name)
void close_netns(struct nstoken *token)
{
if (!token)
return;
ASSERT_OK(setns(token->orig_netns_fd, CLONE_NEWNET), "setns");
close(token->orig_netns_fd);
free(token);
......
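The NULL-token guard added above lets error paths fall straight through to
cleanup, which the new test below relies on (sketch):

	struct nstoken *tok = NULL;

	SYS(cleanup, "ip -net %s link set dev lo up", NS_TEST);
	tok = open_netns(NS_TEST);
	/* ... */
cleanup:
	close_netns(tok);	/* now a no-op if open_netns() never ran or failed */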
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2023 Isovalent */
#include <uapi/linux/if_link.h>
#include <test_progs.h>
#include <netinet/tcp.h>
#include <netinet/udp.h>
#include "network_helpers.h"
#include "test_assign_reuse.skel.h"
#define NS_TEST "assign_reuse"
#define LOOPBACK 1
#define PORT 4443
static int attach_reuseport(int sock_fd, int prog_fd)
{
return setsockopt(sock_fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF,
&prog_fd, sizeof(prog_fd));
}
static __u64 cookie(int fd)
{
__u64 cookie = 0;
socklen_t cookie_len = sizeof(cookie);
int ret;
ret = getsockopt(fd, SOL_SOCKET, SO_COOKIE, &cookie, &cookie_len);
ASSERT_OK(ret, "cookie");
ASSERT_GT(cookie, 0, "cookie_invalid");
return cookie;
}
static int echo_test_udp(int fd_sv)
{
struct sockaddr_storage addr = {};
socklen_t len = sizeof(addr);
char buff[1] = {};
int fd_cl = -1, ret;
fd_cl = connect_to_fd(fd_sv, 100);
ASSERT_GT(fd_cl, 0, "create_client");
ASSERT_EQ(getsockname(fd_cl, (void *)&addr, &len), 0, "getsockname");
ASSERT_EQ(send(fd_cl, buff, sizeof(buff), 0), 1, "send_client");
ret = recv(fd_sv, buff, sizeof(buff), 0);
if (ret < 0) {
close(fd_cl);
return errno;
}
ASSERT_EQ(ret, 1, "recv_server");
ASSERT_EQ(sendto(fd_sv, buff, sizeof(buff), 0, (void *)&addr, len), 1, "send_server");
ASSERT_EQ(recv(fd_cl, buff, sizeof(buff), 0), 1, "recv_client");
close(fd_cl);
return 0;
}
static int echo_test_tcp(int fd_sv)
{
char buff[1] = {};
int fd_cl = -1, fd_sv_cl = -1;
fd_cl = connect_to_fd(fd_sv, 100);
if (fd_cl < 0)
return errno;
fd_sv_cl = accept(fd_sv, NULL, NULL);
ASSERT_GE(fd_sv_cl, 0, "accept_fd");
ASSERT_EQ(send(fd_cl, buff, sizeof(buff), 0), 1, "send_client");
ASSERT_EQ(recv(fd_sv_cl, buff, sizeof(buff), 0), 1, "recv_server");
ASSERT_EQ(send(fd_sv_cl, buff, sizeof(buff), 0), 1, "send_server");
ASSERT_EQ(recv(fd_cl, buff, sizeof(buff), 0), 1, "recv_client");
close(fd_sv_cl);
close(fd_cl);
return 0;
}
void run_assign_reuse(int family, int sotype, const char *ip, __u16 port)
{
DECLARE_LIBBPF_OPTS(bpf_tc_hook, tc_hook,
.ifindex = LOOPBACK,
.attach_point = BPF_TC_INGRESS,
);
DECLARE_LIBBPF_OPTS(bpf_tc_opts, tc_opts,
.handle = 1,
.priority = 1,
);
bool hook_created = false, tc_attached = false;
int ret, fd_tc, fd_accept, fd_drop, fd_map;
int *fd_sv = NULL;
__u64 fd_val;
struct test_assign_reuse *skel;
const int zero = 0;
skel = test_assign_reuse__open();
if (!ASSERT_OK_PTR(skel, "skel_open"))
goto cleanup;
skel->rodata->dest_port = port;
ret = test_assign_reuse__load(skel);
if (!ASSERT_OK(ret, "skel_load"))
goto cleanup;
ASSERT_EQ(skel->bss->sk_cookie_seen, 0, "cookie_init");
fd_tc = bpf_program__fd(skel->progs.tc_main);
fd_accept = bpf_program__fd(skel->progs.reuse_accept);
fd_drop = bpf_program__fd(skel->progs.reuse_drop);
fd_map = bpf_map__fd(skel->maps.sk_map);
fd_sv = start_reuseport_server(family, sotype, ip, port, 100, 1);
if (!ASSERT_NEQ(fd_sv, NULL, "start_reuseport_server"))
goto cleanup;
ret = attach_reuseport(*fd_sv, fd_drop);
if (!ASSERT_OK(ret, "attach_reuseport"))
goto cleanup;
fd_val = *fd_sv;
ret = bpf_map_update_elem(fd_map, &zero, &fd_val, BPF_NOEXIST);
if (!ASSERT_OK(ret, "bpf_sk_map"))
goto cleanup;
ret = bpf_tc_hook_create(&tc_hook);
if (ret == 0)
hook_created = true;
ret = ret == -EEXIST ? 0 : ret;
if (!ASSERT_OK(ret, "bpf_tc_hook_create"))
goto cleanup;
tc_opts.prog_fd = fd_tc;
ret = bpf_tc_attach(&tc_hook, &tc_opts);
if (!ASSERT_OK(ret, "bpf_tc_attach"))
goto cleanup;
tc_attached = true;
if (sotype == SOCK_STREAM)
ASSERT_EQ(echo_test_tcp(*fd_sv), ECONNREFUSED, "drop_tcp");
else
ASSERT_EQ(echo_test_udp(*fd_sv), EAGAIN, "drop_udp");
ASSERT_EQ(skel->bss->reuseport_executed, 1, "program executed once");
skel->bss->sk_cookie_seen = 0;
skel->bss->reuseport_executed = 0;
ASSERT_OK(attach_reuseport(*fd_sv, fd_accept), "attach_reuseport(accept)");
if (sotype == SOCK_STREAM)
ASSERT_EQ(echo_test_tcp(*fd_sv), 0, "echo_tcp");
else
ASSERT_EQ(echo_test_udp(*fd_sv), 0, "echo_udp");
ASSERT_EQ(skel->bss->sk_cookie_seen, cookie(*fd_sv),
"cookie_mismatch");
ASSERT_EQ(skel->bss->reuseport_executed, 1, "program executed once");
cleanup:
if (tc_attached) {
tc_opts.flags = tc_opts.prog_fd = tc_opts.prog_id = 0;
ret = bpf_tc_detach(&tc_hook, &tc_opts);
ASSERT_OK(ret, "bpf_tc_detach");
}
if (hook_created) {
tc_hook.attach_point = BPF_TC_INGRESS | BPF_TC_EGRESS;
bpf_tc_hook_destroy(&tc_hook);
}
test_assign_reuse__destroy(skel);
free_fds(fd_sv, 1);
}
void test_assign_reuse(void)
{
struct nstoken *tok = NULL;
SYS(out, "ip netns add %s", NS_TEST);
SYS(cleanup, "ip -net %s link set dev lo up", NS_TEST);
tok = open_netns(NS_TEST);
if (!ASSERT_OK_PTR(tok, "netns token"))
goto cleanup;
if (test__start_subtest("tcpv4"))
run_assign_reuse(AF_INET, SOCK_STREAM, "127.0.0.1", PORT);
if (test__start_subtest("tcpv6"))
run_assign_reuse(AF_INET6, SOCK_STREAM, "::1", PORT);
if (test__start_subtest("udpv4"))
run_assign_reuse(AF_INET, SOCK_DGRAM, "127.0.0.1", PORT);
if (test__start_subtest("udpv6"))
run_assign_reuse(AF_INET6, SOCK_DGRAM, "::1", PORT);
cleanup:
close_netns(tok);
SYS_NOFAIL("ip netns delete %s", NS_TEST);
out:
return;
}
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2023 Isovalent */
#include <stdbool.h>
#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/tcp.h>
#include <linux/udp.h>
#include <bpf/bpf_endian.h>
#include <bpf/bpf_helpers.h>
#include <linux/pkt_cls.h>
char LICENSE[] SEC("license") = "GPL";
__u64 sk_cookie_seen;
__u64 reuseport_executed;
union {
struct tcphdr tcp;
struct udphdr udp;
} headers;
const volatile __u16 dest_port;
struct {
__uint(type, BPF_MAP_TYPE_SOCKMAP);
__uint(max_entries, 1);
__type(key, __u32);
__type(value, __u64);
} sk_map SEC(".maps");
SEC("sk_reuseport")
int reuse_accept(struct sk_reuseport_md *ctx)
{
reuseport_executed++;
if (ctx->ip_protocol == IPPROTO_TCP) {
if (ctx->data + sizeof(headers.tcp) > ctx->data_end)
return SK_DROP;
if (__builtin_memcmp(&headers.tcp, ctx->data, sizeof(headers.tcp)) != 0)
return SK_DROP;
} else if (ctx->ip_protocol == IPPROTO_UDP) {
if (ctx->data + sizeof(headers.udp) > ctx->data_end)
return SK_DROP;
if (__builtin_memcmp(&headers.udp, ctx->data, sizeof(headers.udp)) != 0)
return SK_DROP;
} else {
return SK_DROP;
}
sk_cookie_seen = bpf_get_socket_cookie(ctx->sk);
return SK_PASS;
}
SEC("sk_reuseport")
int reuse_drop(struct sk_reuseport_md *ctx)
{
reuseport_executed++;
sk_cookie_seen = 0;
return SK_DROP;
}
static int
assign_sk(struct __sk_buff *skb)
{
int zero = 0, ret = 0;
struct bpf_sock *sk;
sk = bpf_map_lookup_elem(&sk_map, &zero);
if (!sk)
return TC_ACT_SHOT;
ret = bpf_sk_assign(skb, sk, 0);
bpf_sk_release(sk);
return ret ? TC_ACT_SHOT : TC_ACT_OK;
}
static int
maybe_assign_tcp(struct __sk_buff *skb, struct tcphdr *th)
{
if (th + 1 > (void *)(long)(skb->data_end))
return TC_ACT_SHOT;
if (!th->syn || th->ack || th->dest != bpf_htons(dest_port))
return TC_ACT_OK;
__builtin_memcpy(&headers.tcp, th, sizeof(headers.tcp));
return assign_sk(skb);
}
static int
maybe_assign_udp(struct __sk_buff *skb, struct udphdr *uh)
{
if (uh + 1 > (void *)(long)(skb->data_end))
return TC_ACT_SHOT;
if (uh->dest != bpf_htons(dest_port))
return TC_ACT_OK;
__builtin_memcpy(&headers.udp, uh, sizeof(headers.udp));
return assign_sk(skb);
}
SEC("tc")
int tc_main(struct __sk_buff *skb)
{
void *data_end = (void *)(long)skb->data_end;
void *data = (void *)(long)skb->data;
struct ethhdr *eth;
eth = (struct ethhdr *)(data);
if (eth + 1 > data_end)
return TC_ACT_SHOT;
if (eth->h_proto == bpf_htons(ETH_P_IP)) {
struct iphdr *iph = (struct iphdr *)(data + sizeof(*eth));
if (iph + 1 > data_end)
return TC_ACT_SHOT;
if (iph->protocol == IPPROTO_TCP)
return maybe_assign_tcp(skb, (struct tcphdr *)(iph + 1));
else if (iph->protocol == IPPROTO_UDP)
return maybe_assign_udp(skb, (struct udphdr *)(iph + 1));
else
return TC_ACT_SHOT;
} else {
struct ipv6hdr *ip6h = (struct ipv6hdr *)(data + sizeof(*eth));
if (ip6h + 1 > data_end)
return TC_ACT_SHOT;
if (ip6h->nexthdr == IPPROTO_TCP)
return maybe_assign_tcp(skb, (struct tcphdr *)(ip6h + 1));
else if (ip6h->nexthdr == IPPROTO_UDP)
return maybe_assign_udp(skb, (struct udphdr *)(ip6h + 1));
else
return TC_ACT_SHOT;
}
}