Commit 637bc8bb authored by Josef Bacik's avatar Josef Bacik Committed by David S. Miller

inet: reset tb->fastreuseport when adding a reuseport sk

If we have non reuseport sockets on a tb we will set tb->fastreuseport to 0 and
never set it again.  Which means that in the future if we end up adding a bunch
of reuseport sk's to that tb we'll have to do the expensive scan every time.
Instead add the ipv4/ipv6 saddr fields to the bind bucket, as well as the family
so we know what comparison to make, and the ipv6 only setting so we can make
sure to compare with new sockets appropriately.  Once one sk has made it onto
the list we know that there are no potential bind conflicts on the owners list
that match that sk's rcv_addr.  So copy the sk's information into our bind
bucket and set tb->fastruseport to FASTREUSESOCK_STRICT so we know we have to do
an extra check for subsequent reuseport sockets and skip the expensive bind
conflict check.
Signed-off-by: default avatarJosef Bacik <jbacik@fb.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 289141b7
...@@ -74,12 +74,21 @@ struct inet_ehash_bucket { ...@@ -74,12 +74,21 @@ struct inet_ehash_bucket {
* users logged onto your box, isn't it nice to know that new data * users logged onto your box, isn't it nice to know that new data
* ports are created in O(1) time? I thought so. ;-) -DaveM * ports are created in O(1) time? I thought so. ;-) -DaveM
*/ */
#define FASTREUSEPORT_ANY 1
#define FASTREUSEPORT_STRICT 2
struct inet_bind_bucket { struct inet_bind_bucket {
possible_net_t ib_net; possible_net_t ib_net;
unsigned short port; unsigned short port;
signed char fastreuse; signed char fastreuse;
signed char fastreuseport; signed char fastreuseport;
kuid_t fastuid; kuid_t fastuid;
#if IS_ENABLED(CONFIG_IPV6)
struct in6_addr fast_v6_rcv_saddr;
#endif
__be32 fast_rcv_saddr;
unsigned short fast_sk_family;
bool fast_ipv6_only;
struct hlist_node node; struct hlist_node node;
struct hlist_head owners; struct hlist_head owners;
}; };
......
...@@ -38,20 +38,21 @@ EXPORT_SYMBOL(inet_csk_timer_bug_msg); ...@@ -38,20 +38,21 @@ EXPORT_SYMBOL(inet_csk_timer_bug_msg);
* IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY, * IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
* and 0.0.0.0 equals to 0.0.0.0 only * and 0.0.0.0 equals to 0.0.0.0 only
*/ */
static int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, static int ipv6_rcv_saddr_equal(const struct in6_addr *sk1_rcv_saddr6,
const struct in6_addr *sk2_rcv_saddr6,
__be32 sk1_rcv_saddr, __be32 sk2_rcv_saddr,
bool sk1_ipv6only, bool sk2_ipv6only,
bool match_wildcard) bool match_wildcard)
{ {
const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2); int addr_type = ipv6_addr_type(sk1_rcv_saddr6);
int sk2_ipv6only = inet_v6_ipv6only(sk2);
int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED; int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
/* if both are mapped, treat as IPv4 */ /* if both are mapped, treat as IPv4 */
if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) { if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
if (!sk2_ipv6only) { if (!sk2_ipv6only) {
if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr) if (sk1_rcv_saddr == sk2_rcv_saddr)
return 1; return 1;
if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr) if (!sk1_rcv_saddr || !sk2_rcv_saddr)
return match_wildcard; return match_wildcard;
} }
return 0; return 0;
...@@ -65,11 +66,11 @@ static int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, ...@@ -65,11 +66,11 @@ static int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
return 1; return 1;
if (addr_type == IPV6_ADDR_ANY && match_wildcard && if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
!(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED)) !(sk1_ipv6only && addr_type2 == IPV6_ADDR_MAPPED))
return 1; return 1;
if (sk2_rcv_saddr6 && if (sk2_rcv_saddr6 &&
ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6)) ipv6_addr_equal(sk1_rcv_saddr6, sk2_rcv_saddr6))
return 1; return 1;
return 0; return 0;
...@@ -80,13 +81,13 @@ static int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, ...@@ -80,13 +81,13 @@ static int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
* match_wildcard == false: addresses must be exactly the same, i.e. * match_wildcard == false: addresses must be exactly the same, i.e.
* 0.0.0.0 only equals to 0.0.0.0 * 0.0.0.0 only equals to 0.0.0.0
*/ */
static int ipv4_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, static int ipv4_rcv_saddr_equal(__be32 sk1_rcv_saddr, __be32 sk2_rcv_saddr,
bool match_wildcard) bool sk2_ipv6only, bool match_wildcard)
{ {
if (!ipv6_only_sock(sk2)) { if (!sk2_ipv6only) {
if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr) if (sk1_rcv_saddr == sk2_rcv_saddr)
return 1; return 1;
if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr) if (!sk1_rcv_saddr || !sk2_rcv_saddr)
return match_wildcard; return match_wildcard;
} }
return 0; return 0;
...@@ -97,9 +98,16 @@ int inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, ...@@ -97,9 +98,16 @@ int inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
{ {
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
if (sk->sk_family == AF_INET6) if (sk->sk_family == AF_INET6)
return ipv6_rcv_saddr_equal(sk, sk2, match_wildcard); return ipv6_rcv_saddr_equal(&sk->sk_v6_rcv_saddr,
&sk2->sk_v6_rcv_saddr,
sk->sk_rcv_saddr,
sk2->sk_rcv_saddr,
ipv6_only_sock(sk),
ipv6_only_sock(sk2),
match_wildcard);
#endif #endif
return ipv4_rcv_saddr_equal(sk, sk2, match_wildcard); return ipv4_rcv_saddr_equal(sk->sk_rcv_saddr, sk2->sk_rcv_saddr,
ipv6_only_sock(sk2), match_wildcard);
} }
EXPORT_SYMBOL(inet_rcv_saddr_equal); EXPORT_SYMBOL(inet_rcv_saddr_equal);
...@@ -234,6 +242,39 @@ inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int * ...@@ -234,6 +242,39 @@ inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *
return head; return head;
} }
static inline int sk_reuseport_match(struct inet_bind_bucket *tb,
struct sock *sk)
{
kuid_t uid = sock_i_uid(sk);
if (tb->fastreuseport <= 0)
return 0;
if (!sk->sk_reuseport)
return 0;
if (rcu_access_pointer(sk->sk_reuseport_cb))
return 0;
if (!uid_eq(tb->fastuid, uid))
return 0;
/* We only need to check the rcv_saddr if this tb was once marked
* without fastreuseport and then was reset, as we can only know that
* the fast_*rcv_saddr doesn't have any conflicts with the socks on the
* owners list.
*/
if (tb->fastreuseport == FASTREUSEPORT_ANY)
return 1;
#if IS_ENABLED(CONFIG_IPV6)
if (tb->fast_sk_family == AF_INET6)
return ipv6_rcv_saddr_equal(&tb->fast_v6_rcv_saddr,
&sk->sk_v6_rcv_saddr,
tb->fast_rcv_saddr,
sk->sk_rcv_saddr,
tb->fast_ipv6_only,
ipv6_only_sock(sk), true);
#endif
return ipv4_rcv_saddr_equal(tb->fast_rcv_saddr, sk->sk_rcv_saddr,
ipv6_only_sock(sk), true);
}
/* Obtain a reference to a local port for the given sock, /* Obtain a reference to a local port for the given sock,
* if snum is zero it means select any available local port. * if snum is zero it means select any available local port.
* We try to allocate an odd port (and leave even ports for connect()) * We try to allocate an odd port (and leave even ports for connect())
...@@ -273,9 +314,7 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) ...@@ -273,9 +314,7 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
goto success; goto success;
if ((tb->fastreuse > 0 && reuse) || if ((tb->fastreuse > 0 && reuse) ||
(tb->fastreuseport > 0 && sk_reuseport_match(tb, sk))
!rcu_access_pointer(sk->sk_reuseport_cb) &&
sk->sk_reuseport && uid_eq(tb->fastuid, uid)))
goto success; goto success;
if (inet_csk_bind_conflict(sk, tb, true, true)) if (inet_csk_bind_conflict(sk, tb, true, true))
goto fail_unlock; goto fail_unlock;
...@@ -284,16 +323,43 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) ...@@ -284,16 +323,43 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
if (!hlist_empty(&tb->owners)) { if (!hlist_empty(&tb->owners)) {
tb->fastreuse = reuse; tb->fastreuse = reuse;
if (sk->sk_reuseport) { if (sk->sk_reuseport) {
tb->fastreuseport = 1; tb->fastreuseport = FASTREUSEPORT_ANY;
tb->fastuid = uid; tb->fastuid = uid;
tb->fast_rcv_saddr = sk->sk_rcv_saddr;
tb->fast_ipv6_only = ipv6_only_sock(sk);
#if IS_ENABLED(CONFIG_IPV6)
tb->fast_v6_rcv_saddr = sk->sk_v6_rcv_saddr;
#endif
} else { } else {
tb->fastreuseport = 0; tb->fastreuseport = 0;
} }
} else { } else {
if (!reuse) if (!reuse)
tb->fastreuse = 0; tb->fastreuse = 0;
if (!sk->sk_reuseport || !uid_eq(tb->fastuid, uid)) if (sk->sk_reuseport) {
/* We didn't match or we don't have fastreuseport set on
* the tb, but we have sk_reuseport set on this socket
* and we know that there are no bind conflicts with
* this socket in this tb, so reset our tb's reuseport
* settings so that any subsequent sockets that match
* our current socket will be put on the fast path.
*
* If we reset we need to set FASTREUSEPORT_STRICT so we
* do extra checking for all subsequent sk_reuseport
* socks.
*/
if (!sk_reuseport_match(tb, sk)) {
tb->fastreuseport = FASTREUSEPORT_STRICT;
tb->fastuid = uid;
tb->fast_rcv_saddr = sk->sk_rcv_saddr;
tb->fast_ipv6_only = ipv6_only_sock(sk);
#if IS_ENABLED(CONFIG_IPV6)
tb->fast_v6_rcv_saddr = sk->sk_v6_rcv_saddr;
#endif
}
} else {
tb->fastreuseport = 0; tb->fastreuseport = 0;
}
} }
if (!inet_csk(sk)->icsk_bind_hash) if (!inet_csk(sk)->icsk_bind_hash)
inet_bind_hash(sk, tb, port); inet_bind_hash(sk, tb, port);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment