Commit b16ed2b1 authored by David S. Miller's avatar David S. Miller

Merge branch 'rework-inet_csk_get_port'

Josef Bacik says:

====================
Rework inet_csk_get_port

V3->V4:
-Removed the random include of addrconf.h that is no longer needed.

V2->V3:
-Dropped the fastsock from the tb and instead just carry the saddrs, family, and
 ipv6 only flag.
-Reworked the helper functions to deal with this change so I could still use
 them when checking the fast path.
-Killed tb->num_owners as per Eric's request.
-Attached a reproducer to the bottom of this email.

V1->V2:
-Added a new patch 'inet: collapse ipv4/v6 rcv_saddr_equal functions into one'
 at Hannes' suggestion.
-Dropped ->bind_conflict and just use the new helper.
-Fixed a compile bug from the original ->bind_conflict patch.

The original description of the series follows:

At some point recently the guys working on our load balancer added the ability
to use SO_REUSEPORT.  When they restarted their app with this option enabled
they immediately hit a softlockup on what appeared to be the
inet_bind_bucket->lock.  Eventually what all of our debugging and discussion led
us to was the fact that the application comes up without SO_REUSEPORT, shuts
down which creates around 100k twsk's, and then comes up and tries to open a
bunch of sockets using SO_REUSEPORT, which meant traversing the inet_bind_bucket
owners list under the lock.  Since this lock is needed for dealing with the
twsk's and basically anything else related to connections we would softlockup,
and sometimes not ever recover.

To solve this problem I did what you see in Patch 5/5.  Once we have a
SO_REUSEPORT socket on the tb->owners list we know that the socket has no
conflicts with any of the other sockets on that list.  So we can add a copy of
the sock_common (really all we need is the rcv_saddr but it seemed ugly to copy
just the ipv6, ipv4, and flag to indicate if we were ipv6 only in there so I've
copied the whole common) in order to check subsequent SO_REUSEPORT sockets.  If
they match the previous one then we can skip the expensive
inet_csk_bind_conflict check.  This is what eliminated the soft lockup that we
were seeing.

Patches 1-4 are cleanups and re-workings.  For instance when we specify port ==
0 we need to find an open port, but we would do two passes through
inet_csk_bind_conflict every time we found a possible port.  We would also keep
track of the smallest_port value in order to try and use it if we found no
port on our first run through.  This however made no sense as it would have had to
fail the first pass through inet_csk_bind_conflict, so would not actually pass
the second pass through either.  Finally I split the function into two functions
in order to make it easier to read and to distinguish between the two behaviors.

I have tested this on one of our load balancing boxes during peak traffic and it
hasn't fallen over.  But this is not my area, so obviously feel free to point
out where I'm being stupid and I'll get it fixed up and retested.  Thanks,
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents ab70e586 637bc8bb
...@@ -88,9 +88,7 @@ int __ipv6_get_lladdr(struct inet6_dev *idev, struct in6_addr *addr, ...@@ -88,9 +88,7 @@ int __ipv6_get_lladdr(struct inet6_dev *idev, struct in6_addr *addr,
u32 banned_flags); u32 banned_flags);
int ipv6_get_lladdr(struct net_device *dev, struct in6_addr *addr, int ipv6_get_lladdr(struct net_device *dev, struct in6_addr *addr,
u32 banned_flags); u32 banned_flags);
int ipv4_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, int inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard);
int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard); bool match_wildcard);
void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr); void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr);
void addrconf_leave_solict(struct inet6_dev *idev, const struct in6_addr *addr); void addrconf_leave_solict(struct inet6_dev *idev, const struct in6_addr *addr);
......
...@@ -15,16 +15,11 @@ ...@@ -15,16 +15,11 @@
#include <linux/types.h> #include <linux/types.h>
struct inet_bind_bucket;
struct request_sock; struct request_sock;
struct sk_buff; struct sk_buff;
struct sock; struct sock;
struct sockaddr; struct sockaddr;
int inet6_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb, bool relax,
bool soreuseport_ok);
struct dst_entry *inet6_csk_route_req(const struct sock *sk, struct flowi6 *fl6, struct dst_entry *inet6_csk_route_req(const struct sock *sk, struct flowi6 *fl6,
const struct request_sock *req, u8 proto); const struct request_sock *req, u8 proto);
......
...@@ -62,9 +62,6 @@ struct inet_connection_sock_af_ops { ...@@ -62,9 +62,6 @@ struct inet_connection_sock_af_ops {
char __user *optval, int __user *optlen); char __user *optval, int __user *optlen);
#endif #endif
void (*addr2sockaddr)(struct sock *sk, struct sockaddr *); void (*addr2sockaddr)(struct sock *sk, struct sockaddr *);
int (*bind_conflict)(const struct sock *sk,
const struct inet_bind_bucket *tb,
bool relax, bool soreuseport_ok);
void (*mtu_reduced)(struct sock *sk); void (*mtu_reduced)(struct sock *sk);
}; };
...@@ -263,9 +260,6 @@ inet_csk_rto_backoff(const struct inet_connection_sock *icsk, ...@@ -263,9 +260,6 @@ inet_csk_rto_backoff(const struct inet_connection_sock *icsk,
struct sock *inet_csk_accept(struct sock *sk, int flags, int *err); struct sock *inet_csk_accept(struct sock *sk, int flags, int *err);
int inet_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb, bool relax,
bool soreuseport_ok);
int inet_csk_get_port(struct sock *sk, unsigned short snum); int inet_csk_get_port(struct sock *sk, unsigned short snum);
struct dst_entry *inet_csk_route_req(const struct sock *sk, struct flowi4 *fl4, struct dst_entry *inet_csk_route_req(const struct sock *sk, struct flowi4 *fl4,
......
...@@ -74,13 +74,21 @@ struct inet_ehash_bucket { ...@@ -74,13 +74,21 @@ struct inet_ehash_bucket {
* users logged onto your box, isn't it nice to know that new data * users logged onto your box, isn't it nice to know that new data
* ports are created in O(1) time? I thought so. ;-) -DaveM * ports are created in O(1) time? I thought so. ;-) -DaveM
*/ */
#define FASTREUSEPORT_ANY 1
#define FASTREUSEPORT_STRICT 2
struct inet_bind_bucket { struct inet_bind_bucket {
possible_net_t ib_net; possible_net_t ib_net;
unsigned short port; unsigned short port;
signed char fastreuse; signed char fastreuse;
signed char fastreuseport; signed char fastreuseport;
kuid_t fastuid; kuid_t fastuid;
int num_owners; #if IS_ENABLED(CONFIG_IPV6)
struct in6_addr fast_v6_rcv_saddr;
#endif
__be32 fast_rcv_saddr;
unsigned short fast_sk_family;
bool fast_ipv6_only;
struct hlist_node node; struct hlist_node node;
struct hlist_head owners; struct hlist_head owners;
}; };
...@@ -203,10 +211,7 @@ void inet_hashinfo_init(struct inet_hashinfo *h); ...@@ -203,10 +211,7 @@ void inet_hashinfo_init(struct inet_hashinfo *h);
bool inet_ehash_insert(struct sock *sk, struct sock *osk); bool inet_ehash_insert(struct sock *sk, struct sock *osk);
bool inet_ehash_nolisten(struct sock *sk, struct sock *osk); bool inet_ehash_nolisten(struct sock *sk, struct sock *osk);
int __inet_hash(struct sock *sk, struct sock *osk, int __inet_hash(struct sock *sk, struct sock *osk);
int (*saddr_same)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard));
int inet_hash(struct sock *sk); int inet_hash(struct sock *sk);
void inet_unhash(struct sock *sk); void inet_unhash(struct sock *sk);
......
...@@ -204,7 +204,6 @@ static inline void udp_lib_close(struct sock *sk, long timeout) ...@@ -204,7 +204,6 @@ static inline void udp_lib_close(struct sock *sk, long timeout)
} }
int udp_lib_get_port(struct sock *sk, unsigned short snum, int udp_lib_get_port(struct sock *sk, unsigned short snum,
int (*)(const struct sock *, const struct sock *, bool),
unsigned int hash2_nulladdr); unsigned int hash2_nulladdr);
u32 udp_flow_hashrnd(void); u32 udp_flow_hashrnd(void);
......
...@@ -904,7 +904,6 @@ static const struct inet_connection_sock_af_ops dccp_ipv4_af_ops = { ...@@ -904,7 +904,6 @@ static const struct inet_connection_sock_af_ops dccp_ipv4_af_ops = {
.getsockopt = ip_getsockopt, .getsockopt = ip_getsockopt,
.addr2sockaddr = inet_csk_addr2sockaddr, .addr2sockaddr = inet_csk_addr2sockaddr,
.sockaddr_len = sizeof(struct sockaddr_in), .sockaddr_len = sizeof(struct sockaddr_in),
.bind_conflict = inet_csk_bind_conflict,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_ip_setsockopt, .compat_setsockopt = compat_ip_setsockopt,
.compat_getsockopt = compat_ip_getsockopt, .compat_getsockopt = compat_ip_getsockopt,
......
...@@ -937,7 +937,6 @@ static const struct inet_connection_sock_af_ops dccp_ipv6_af_ops = { ...@@ -937,7 +937,6 @@ static const struct inet_connection_sock_af_ops dccp_ipv6_af_ops = {
.getsockopt = ipv6_getsockopt, .getsockopt = ipv6_getsockopt,
.addr2sockaddr = inet6_csk_addr2sockaddr, .addr2sockaddr = inet6_csk_addr2sockaddr,
.sockaddr_len = sizeof(struct sockaddr_in6), .sockaddr_len = sizeof(struct sockaddr_in6),
.bind_conflict = inet6_csk_bind_conflict,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_ipv6_setsockopt, .compat_setsockopt = compat_ipv6_setsockopt,
.compat_getsockopt = compat_ipv6_getsockopt, .compat_getsockopt = compat_ipv6_getsockopt,
...@@ -958,7 +957,6 @@ static const struct inet_connection_sock_af_ops dccp_ipv6_mapped = { ...@@ -958,7 +957,6 @@ static const struct inet_connection_sock_af_ops dccp_ipv6_mapped = {
.getsockopt = ipv6_getsockopt, .getsockopt = ipv6_getsockopt,
.addr2sockaddr = inet6_csk_addr2sockaddr, .addr2sockaddr = inet6_csk_addr2sockaddr,
.sockaddr_len = sizeof(struct sockaddr_in6), .sockaddr_len = sizeof(struct sockaddr_in6),
.bind_conflict = inet6_csk_bind_conflict,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_ipv6_setsockopt, .compat_setsockopt = compat_ipv6_setsockopt,
.compat_getsockopt = compat_ipv6_getsockopt, .compat_getsockopt = compat_ipv6_getsockopt,
......
...@@ -31,6 +31,86 @@ const char inet_csk_timer_bug_msg[] = "inet_csk BUG: unknown timer value\n"; ...@@ -31,6 +31,86 @@ const char inet_csk_timer_bug_msg[] = "inet_csk BUG: unknown timer value\n";
EXPORT_SYMBOL(inet_csk_timer_bug_msg); EXPORT_SYMBOL(inet_csk_timer_bug_msg);
#endif #endif
#if IS_ENABLED(CONFIG_IPV6)
/* match_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6
* only, and any IPv4 addresses if not IPv6 only
* match_wildcard == false: addresses must be exactly the same, i.e.
* IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
* and 0.0.0.0 equals to 0.0.0.0 only
*/
static int ipv6_rcv_saddr_equal(const struct in6_addr *sk1_rcv_saddr6,
const struct in6_addr *sk2_rcv_saddr6,
__be32 sk1_rcv_saddr, __be32 sk2_rcv_saddr,
bool sk1_ipv6only, bool sk2_ipv6only,
bool match_wildcard)
{
int addr_type = ipv6_addr_type(sk1_rcv_saddr6);
int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
/* if both are mapped, treat as IPv4 */
if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
if (!sk2_ipv6only) {
if (sk1_rcv_saddr == sk2_rcv_saddr)
return 1;
if (!sk1_rcv_saddr || !sk2_rcv_saddr)
return match_wildcard;
}
return 0;
}
if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
return 1;
if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
!(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
return 1;
if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
!(sk1_ipv6only && addr_type2 == IPV6_ADDR_MAPPED))
return 1;
if (sk2_rcv_saddr6 &&
ipv6_addr_equal(sk1_rcv_saddr6, sk2_rcv_saddr6))
return 1;
return 0;
}
#endif
/* match_wildcard == true: 0.0.0.0 equals to any IPv4 addresses
* match_wildcard == false: addresses must be exactly the same, i.e.
* 0.0.0.0 only equals to 0.0.0.0
*/
static int ipv4_rcv_saddr_equal(__be32 sk1_rcv_saddr, __be32 sk2_rcv_saddr,
bool sk2_ipv6only, bool match_wildcard)
{
if (!sk2_ipv6only) {
if (sk1_rcv_saddr == sk2_rcv_saddr)
return 1;
if (!sk1_rcv_saddr || !sk2_rcv_saddr)
return match_wildcard;
}
return 0;
}
int inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard)
{
#if IS_ENABLED(CONFIG_IPV6)
if (sk->sk_family == AF_INET6)
return ipv6_rcv_saddr_equal(&sk->sk_v6_rcv_saddr,
&sk2->sk_v6_rcv_saddr,
sk->sk_rcv_saddr,
sk2->sk_rcv_saddr,
ipv6_only_sock(sk),
ipv6_only_sock(sk2),
match_wildcard);
#endif
return ipv4_rcv_saddr_equal(sk->sk_rcv_saddr, sk2->sk_rcv_saddr,
ipv6_only_sock(sk2), match_wildcard);
}
EXPORT_SYMBOL(inet_rcv_saddr_equal);
void inet_get_local_port_range(struct net *net, int *low, int *high) void inet_get_local_port_range(struct net *net, int *low, int *high)
{ {
unsigned int seq; unsigned int seq;
...@@ -44,9 +124,9 @@ void inet_get_local_port_range(struct net *net, int *low, int *high) ...@@ -44,9 +124,9 @@ void inet_get_local_port_range(struct net *net, int *low, int *high)
} }
EXPORT_SYMBOL(inet_get_local_port_range); EXPORT_SYMBOL(inet_get_local_port_range);
int inet_csk_bind_conflict(const struct sock *sk, static int inet_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb, bool relax, const struct inet_bind_bucket *tb,
bool reuseport_ok) bool relax, bool reuseport_ok)
{ {
struct sock *sk2; struct sock *sk2;
bool reuse = sk->sk_reuse; bool reuse = sk->sk_reuse;
...@@ -62,7 +142,6 @@ int inet_csk_bind_conflict(const struct sock *sk, ...@@ -62,7 +142,6 @@ int inet_csk_bind_conflict(const struct sock *sk,
sk_for_each_bound(sk2, &tb->owners) { sk_for_each_bound(sk2, &tb->owners) {
if (sk != sk2 && if (sk != sk2 &&
!inet_v6_ipv6only(sk2) &&
(!sk->sk_bound_dev_if || (!sk->sk_bound_dev_if ||
!sk2->sk_bound_dev_if || !sk2->sk_bound_dev_if ||
sk->sk_bound_dev_if == sk2->sk_bound_dev_if)) { sk->sk_bound_dev_if == sk2->sk_bound_dev_if)) {
...@@ -72,54 +151,34 @@ int inet_csk_bind_conflict(const struct sock *sk, ...@@ -72,54 +151,34 @@ int inet_csk_bind_conflict(const struct sock *sk,
rcu_access_pointer(sk->sk_reuseport_cb) || rcu_access_pointer(sk->sk_reuseport_cb) ||
(sk2->sk_state != TCP_TIME_WAIT && (sk2->sk_state != TCP_TIME_WAIT &&
!uid_eq(uid, sock_i_uid(sk2))))) { !uid_eq(uid, sock_i_uid(sk2))))) {
if (inet_rcv_saddr_equal(sk, sk2, true))
if (!sk2->sk_rcv_saddr || !sk->sk_rcv_saddr ||
sk2->sk_rcv_saddr == sk->sk_rcv_saddr)
break; break;
} }
if (!relax && reuse && sk2->sk_reuse && if (!relax && reuse && sk2->sk_reuse &&
sk2->sk_state != TCP_LISTEN) { sk2->sk_state != TCP_LISTEN) {
if (inet_rcv_saddr_equal(sk, sk2, true))
if (!sk2->sk_rcv_saddr || !sk->sk_rcv_saddr ||
sk2->sk_rcv_saddr == sk->sk_rcv_saddr)
break; break;
} }
} }
} }
return sk2 != NULL; return sk2 != NULL;
} }
EXPORT_SYMBOL_GPL(inet_csk_bind_conflict);
/* Obtain a reference to a local port for the given sock, /*
* if snum is zero it means select any available local port. * Find an open port number for the socket. Returns with the
* We try to allocate an odd port (and leave even ports for connect()) * inet_bind_hashbucket lock held.
*/ */
int inet_csk_get_port(struct sock *sk, unsigned short snum) static struct inet_bind_hashbucket *
inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *port_ret)
{ {
bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN;
struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
int ret = 1, attempts = 5, port = snum; int port = 0;
int smallest_size = -1, smallest_port;
struct inet_bind_hashbucket *head; struct inet_bind_hashbucket *head;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
int i, low, high, attempt_half; int i, low, high, attempt_half;
struct inet_bind_bucket *tb; struct inet_bind_bucket *tb;
kuid_t uid = sock_i_uid(sk);
u32 remaining, offset; u32 remaining, offset;
bool reuseport_ok = !!snum;
if (port) {
have_port:
head = &hinfo->bhash[inet_bhashfn(net, port,
hinfo->bhash_size)];
spin_lock_bh(&head->lock);
inet_bind_bucket_for_each(tb, &head->chain)
if (net_eq(ib_net(tb), net) && tb->port == port)
goto tb_found;
goto tb_not_found;
}
again:
attempt_half = (sk->sk_reuse == SK_CAN_REUSE) ? 1 : 0; attempt_half = (sk->sk_reuse == SK_CAN_REUSE) ? 1 : 0;
other_half_scan: other_half_scan:
inet_get_local_port_range(net, &low, &high); inet_get_local_port_range(net, &low, &high);
...@@ -143,8 +202,6 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) ...@@ -143,8 +202,6 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
* We do the opposite to not pollute connect() users. * We do the opposite to not pollute connect() users.
*/ */
offset |= 1U; offset |= 1U;
smallest_size = -1;
smallest_port = low; /* avoid compiler warning */
other_parity_scan: other_parity_scan:
port = low + offset; port = low + offset;
...@@ -158,30 +215,17 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) ...@@ -158,30 +215,17 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
spin_lock_bh(&head->lock); spin_lock_bh(&head->lock);
inet_bind_bucket_for_each(tb, &head->chain) inet_bind_bucket_for_each(tb, &head->chain)
if (net_eq(ib_net(tb), net) && tb->port == port) { if (net_eq(ib_net(tb), net) && tb->port == port) {
if (((tb->fastreuse > 0 && reuse) || if (!inet_csk_bind_conflict(sk, tb, false, false))
(tb->fastreuseport > 0 && goto success;
sk->sk_reuseport &&
!rcu_access_pointer(sk->sk_reuseport_cb) &&
uid_eq(tb->fastuid, uid))) &&
(tb->num_owners < smallest_size || smallest_size == -1)) {
smallest_size = tb->num_owners;
smallest_port = port;
}
if (!inet_csk(sk)->icsk_af_ops->bind_conflict(sk, tb, false,
reuseport_ok))
goto tb_found;
goto next_port; goto next_port;
} }
goto tb_not_found; tb = NULL;
goto success;
next_port: next_port:
spin_unlock_bh(&head->lock); spin_unlock_bh(&head->lock);
cond_resched(); cond_resched();
} }
if (smallest_size != -1) {
port = smallest_port;
goto have_port;
}
offset--; offset--;
if (!(offset & 1)) if (!(offset & 1))
goto other_parity_scan; goto other_parity_scan;
...@@ -191,8 +235,74 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) ...@@ -191,8 +235,74 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
attempt_half = 2; attempt_half = 2;
goto other_half_scan; goto other_half_scan;
} }
return ret; return NULL;
success:
*port_ret = port;
*tb_ret = tb;
return head;
}
static inline int sk_reuseport_match(struct inet_bind_bucket *tb,
struct sock *sk)
{
kuid_t uid = sock_i_uid(sk);
if (tb->fastreuseport <= 0)
return 0;
if (!sk->sk_reuseport)
return 0;
if (rcu_access_pointer(sk->sk_reuseport_cb))
return 0;
if (!uid_eq(tb->fastuid, uid))
return 0;
/* We only need to check the rcv_saddr if this tb was once marked
* without fastreuseport and then was reset, as we can only know that
* the fast_*rcv_saddr doesn't have any conflicts with the socks on the
* owners list.
*/
if (tb->fastreuseport == FASTREUSEPORT_ANY)
return 1;
#if IS_ENABLED(CONFIG_IPV6)
if (tb->fast_sk_family == AF_INET6)
return ipv6_rcv_saddr_equal(&tb->fast_v6_rcv_saddr,
&sk->sk_v6_rcv_saddr,
tb->fast_rcv_saddr,
sk->sk_rcv_saddr,
tb->fast_ipv6_only,
ipv6_only_sock(sk), true);
#endif
return ipv4_rcv_saddr_equal(tb->fast_rcv_saddr, sk->sk_rcv_saddr,
ipv6_only_sock(sk), true);
}
/* Obtain a reference to a local port for the given sock,
* if snum is zero it means select any available local port.
* We try to allocate an odd port (and leave even ports for connect())
*/
int inet_csk_get_port(struct sock *sk, unsigned short snum)
{
bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN;
struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
int ret = 1, port = snum;
struct inet_bind_hashbucket *head;
struct net *net = sock_net(sk);
struct inet_bind_bucket *tb = NULL;
kuid_t uid = sock_i_uid(sk);
if (!port) {
head = inet_csk_find_open_port(sk, &tb, &port);
if (!head)
return ret;
if (!tb)
goto tb_not_found;
goto success;
}
head = &hinfo->bhash[inet_bhashfn(net, port,
hinfo->bhash_size)];
spin_lock_bh(&head->lock);
inet_bind_bucket_for_each(tb, &head->chain)
if (net_eq(ib_net(tb), net) && tb->port == port)
goto tb_found;
tb_not_found: tb_not_found:
tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
net, head, port); net, head, port);
...@@ -203,39 +313,54 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) ...@@ -203,39 +313,54 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
if (sk->sk_reuse == SK_FORCE_REUSE) if (sk->sk_reuse == SK_FORCE_REUSE)
goto success; goto success;
if (((tb->fastreuse > 0 && reuse) || if ((tb->fastreuse > 0 && reuse) ||
(tb->fastreuseport > 0 && sk_reuseport_match(tb, sk))
!rcu_access_pointer(sk->sk_reuseport_cb) &&
sk->sk_reuseport && uid_eq(tb->fastuid, uid))) &&
smallest_size == -1)
goto success; goto success;
if (inet_csk(sk)->icsk_af_ops->bind_conflict(sk, tb, true, if (inet_csk_bind_conflict(sk, tb, true, true))
reuseport_ok)) {
if ((reuse ||
(tb->fastreuseport > 0 &&
sk->sk_reuseport &&
!rcu_access_pointer(sk->sk_reuseport_cb) &&
uid_eq(tb->fastuid, uid))) &&
!snum && smallest_size != -1 && --attempts >= 0) {
spin_unlock_bh(&head->lock);
goto again;
}
goto fail_unlock; goto fail_unlock;
} }
if (!reuse) success:
tb->fastreuse = 0; if (!hlist_empty(&tb->owners)) {
if (!sk->sk_reuseport || !uid_eq(tb->fastuid, uid)) tb->fastreuse = reuse;
if (sk->sk_reuseport) {
tb->fastreuseport = FASTREUSEPORT_ANY;
tb->fastuid = uid;
tb->fast_rcv_saddr = sk->sk_rcv_saddr;
tb->fast_ipv6_only = ipv6_only_sock(sk);
#if IS_ENABLED(CONFIG_IPV6)
tb->fast_v6_rcv_saddr = sk->sk_v6_rcv_saddr;
#endif
} else {
tb->fastreuseport = 0; tb->fastreuseport = 0;
}
} else { } else {
tb->fastreuse = reuse; if (!reuse)
tb->fastreuse = 0;
if (sk->sk_reuseport) { if (sk->sk_reuseport) {
tb->fastreuseport = 1; /* We didn't match or we don't have fastreuseport set on
* the tb, but we have sk_reuseport set on this socket
* and we know that there are no bind conflicts with
* this socket in this tb, so reset our tb's reuseport
* settings so that any subsequent sockets that match
* our current socket will be put on the fast path.
*
* If we reset we need to set FASTREUSEPORT_STRICT so we
* do extra checking for all subsequent sk_reuseport
* socks.
*/
if (!sk_reuseport_match(tb, sk)) {
tb->fastreuseport = FASTREUSEPORT_STRICT;
tb->fastuid = uid; tb->fastuid = uid;
tb->fast_rcv_saddr = sk->sk_rcv_saddr;
tb->fast_ipv6_only = ipv6_only_sock(sk);
#if IS_ENABLED(CONFIG_IPV6)
tb->fast_v6_rcv_saddr = sk->sk_v6_rcv_saddr;
#endif
}
} else { } else {
tb->fastreuseport = 0; tb->fastreuseport = 0;
} }
} }
success:
if (!inet_csk(sk)->icsk_bind_hash) if (!inet_csk(sk)->icsk_bind_hash)
inet_bind_hash(sk, tb, port); inet_bind_hash(sk, tb, port);
WARN_ON(inet_csk(sk)->icsk_bind_hash != tb); WARN_ON(inet_csk(sk)->icsk_bind_hash != tb);
......
...@@ -73,7 +73,6 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, ...@@ -73,7 +73,6 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
tb->port = snum; tb->port = snum;
tb->fastreuse = 0; tb->fastreuse = 0;
tb->fastreuseport = 0; tb->fastreuseport = 0;
tb->num_owners = 0;
INIT_HLIST_HEAD(&tb->owners); INIT_HLIST_HEAD(&tb->owners);
hlist_add_head(&tb->node, &head->chain); hlist_add_head(&tb->node, &head->chain);
} }
...@@ -96,7 +95,6 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, ...@@ -96,7 +95,6 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
{ {
inet_sk(sk)->inet_num = snum; inet_sk(sk)->inet_num = snum;
sk_add_bind_node(sk, &tb->owners); sk_add_bind_node(sk, &tb->owners);
tb->num_owners++;
inet_csk(sk)->icsk_bind_hash = tb; inet_csk(sk)->icsk_bind_hash = tb;
} }
...@@ -114,7 +112,6 @@ static void __inet_put_port(struct sock *sk) ...@@ -114,7 +112,6 @@ static void __inet_put_port(struct sock *sk)
spin_lock(&head->lock); spin_lock(&head->lock);
tb = inet_csk(sk)->icsk_bind_hash; tb = inet_csk(sk)->icsk_bind_hash;
__sk_del_bind_node(sk); __sk_del_bind_node(sk);
tb->num_owners--;
inet_csk(sk)->icsk_bind_hash = NULL; inet_csk(sk)->icsk_bind_hash = NULL;
inet_sk(sk)->inet_num = 0; inet_sk(sk)->inet_num = 0;
inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
...@@ -435,10 +432,7 @@ bool inet_ehash_nolisten(struct sock *sk, struct sock *osk) ...@@ -435,10 +432,7 @@ bool inet_ehash_nolisten(struct sock *sk, struct sock *osk)
EXPORT_SYMBOL_GPL(inet_ehash_nolisten); EXPORT_SYMBOL_GPL(inet_ehash_nolisten);
static int inet_reuseport_add_sock(struct sock *sk, static int inet_reuseport_add_sock(struct sock *sk,
struct inet_listen_hashbucket *ilb, struct inet_listen_hashbucket *ilb)
int (*saddr_same)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard))
{ {
struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash; struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash;
struct sock *sk2; struct sock *sk2;
...@@ -451,7 +445,7 @@ static int inet_reuseport_add_sock(struct sock *sk, ...@@ -451,7 +445,7 @@ static int inet_reuseport_add_sock(struct sock *sk,
sk2->sk_bound_dev_if == sk->sk_bound_dev_if && sk2->sk_bound_dev_if == sk->sk_bound_dev_if &&
inet_csk(sk2)->icsk_bind_hash == tb && inet_csk(sk2)->icsk_bind_hash == tb &&
sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
saddr_same(sk, sk2, false)) inet_rcv_saddr_equal(sk, sk2, false))
return reuseport_add_sock(sk, sk2); return reuseport_add_sock(sk, sk2);
} }
...@@ -461,10 +455,7 @@ static int inet_reuseport_add_sock(struct sock *sk, ...@@ -461,10 +455,7 @@ static int inet_reuseport_add_sock(struct sock *sk,
return 0; return 0;
} }
int __inet_hash(struct sock *sk, struct sock *osk, int __inet_hash(struct sock *sk, struct sock *osk)
int (*saddr_same)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard))
{ {
struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
struct inet_listen_hashbucket *ilb; struct inet_listen_hashbucket *ilb;
...@@ -479,7 +470,7 @@ int __inet_hash(struct sock *sk, struct sock *osk, ...@@ -479,7 +470,7 @@ int __inet_hash(struct sock *sk, struct sock *osk,
spin_lock(&ilb->lock); spin_lock(&ilb->lock);
if (sk->sk_reuseport) { if (sk->sk_reuseport) {
err = inet_reuseport_add_sock(sk, ilb, saddr_same); err = inet_reuseport_add_sock(sk, ilb);
if (err) if (err)
goto unlock; goto unlock;
} }
...@@ -503,7 +494,7 @@ int inet_hash(struct sock *sk) ...@@ -503,7 +494,7 @@ int inet_hash(struct sock *sk)
if (sk->sk_state != TCP_CLOSE) { if (sk->sk_state != TCP_CLOSE) {
local_bh_disable(); local_bh_disable();
err = __inet_hash(sk, NULL, ipv4_rcv_saddr_equal); err = __inet_hash(sk, NULL);
local_bh_enable(); local_bh_enable();
} }
......
...@@ -1817,7 +1817,6 @@ const struct inet_connection_sock_af_ops ipv4_specific = { ...@@ -1817,7 +1817,6 @@ const struct inet_connection_sock_af_ops ipv4_specific = {
.getsockopt = ip_getsockopt, .getsockopt = ip_getsockopt,
.addr2sockaddr = inet_csk_addr2sockaddr, .addr2sockaddr = inet_csk_addr2sockaddr,
.sockaddr_len = sizeof(struct sockaddr_in), .sockaddr_len = sizeof(struct sockaddr_in),
.bind_conflict = inet_csk_bind_conflict,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_ip_setsockopt, .compat_setsockopt = compat_ip_setsockopt,
.compat_getsockopt = compat_ip_getsockopt, .compat_getsockopt = compat_ip_getsockopt,
......
...@@ -137,11 +137,7 @@ EXPORT_SYMBOL(udp_memory_allocated); ...@@ -137,11 +137,7 @@ EXPORT_SYMBOL(udp_memory_allocated);
static int udp_lib_lport_inuse(struct net *net, __u16 num, static int udp_lib_lport_inuse(struct net *net, __u16 num,
const struct udp_hslot *hslot, const struct udp_hslot *hslot,
unsigned long *bitmap, unsigned long *bitmap,
struct sock *sk, struct sock *sk, unsigned int log)
int (*saddr_comp)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard),
unsigned int log)
{ {
struct sock *sk2; struct sock *sk2;
kuid_t uid = sock_i_uid(sk); kuid_t uid = sock_i_uid(sk);
...@@ -153,7 +149,7 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num, ...@@ -153,7 +149,7 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
(!sk2->sk_reuse || !sk->sk_reuse) && (!sk2->sk_reuse || !sk->sk_reuse) &&
(!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
saddr_comp(sk, sk2, true)) { inet_rcv_saddr_equal(sk, sk2, true)) {
if (sk2->sk_reuseport && sk->sk_reuseport && if (sk2->sk_reuseport && sk->sk_reuseport &&
!rcu_access_pointer(sk->sk_reuseport_cb) && !rcu_access_pointer(sk->sk_reuseport_cb) &&
uid_eq(uid, sock_i_uid(sk2))) { uid_eq(uid, sock_i_uid(sk2))) {
...@@ -176,10 +172,7 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num, ...@@ -176,10 +172,7 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
*/ */
static int udp_lib_lport_inuse2(struct net *net, __u16 num, static int udp_lib_lport_inuse2(struct net *net, __u16 num,
struct udp_hslot *hslot2, struct udp_hslot *hslot2,
struct sock *sk, struct sock *sk)
int (*saddr_comp)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard))
{ {
struct sock *sk2; struct sock *sk2;
kuid_t uid = sock_i_uid(sk); kuid_t uid = sock_i_uid(sk);
...@@ -193,7 +186,7 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num, ...@@ -193,7 +186,7 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
(!sk2->sk_reuse || !sk->sk_reuse) && (!sk2->sk_reuse || !sk->sk_reuse) &&
(!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
saddr_comp(sk, sk2, true)) { inet_rcv_saddr_equal(sk, sk2, true)) {
if (sk2->sk_reuseport && sk->sk_reuseport && if (sk2->sk_reuseport && sk->sk_reuseport &&
!rcu_access_pointer(sk->sk_reuseport_cb) && !rcu_access_pointer(sk->sk_reuseport_cb) &&
uid_eq(uid, sock_i_uid(sk2))) { uid_eq(uid, sock_i_uid(sk2))) {
...@@ -208,10 +201,7 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num, ...@@ -208,10 +201,7 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
return res; return res;
} }
static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot, static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot)
int (*saddr_same)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard))
{ {
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
kuid_t uid = sock_i_uid(sk); kuid_t uid = sock_i_uid(sk);
...@@ -225,7 +215,7 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot, ...@@ -225,7 +215,7 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot,
(udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) && (udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) &&
(sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
(*saddr_same)(sk, sk2, false)) { inet_rcv_saddr_equal(sk, sk2, false)) {
return reuseport_add_sock(sk, sk2); return reuseport_add_sock(sk, sk2);
} }
} }
...@@ -241,14 +231,10 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot, ...@@ -241,14 +231,10 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot,
* *
* @sk: socket struct in question * @sk: socket struct in question
* @snum: port number to look up * @snum: port number to look up
* @saddr_comp: AF-dependent comparison of bound local IP addresses
* @hash2_nulladdr: AF-dependent hash value in secondary hash chains, * @hash2_nulladdr: AF-dependent hash value in secondary hash chains,
* with NULL address * with NULL address
*/ */
int udp_lib_get_port(struct sock *sk, unsigned short snum, int udp_lib_get_port(struct sock *sk, unsigned short snum,
int (*saddr_comp)(const struct sock *sk1,
const struct sock *sk2,
bool match_wildcard),
unsigned int hash2_nulladdr) unsigned int hash2_nulladdr)
{ {
struct udp_hslot *hslot, *hslot2; struct udp_hslot *hslot, *hslot2;
...@@ -277,7 +263,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -277,7 +263,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
bitmap_zero(bitmap, PORTS_PER_CHAIN); bitmap_zero(bitmap, PORTS_PER_CHAIN);
spin_lock_bh(&hslot->lock); spin_lock_bh(&hslot->lock);
udp_lib_lport_inuse(net, snum, hslot, bitmap, sk, udp_lib_lport_inuse(net, snum, hslot, bitmap, sk,
saddr_comp, udptable->log); udptable->log);
snum = first; snum = first;
/* /*
...@@ -310,12 +296,11 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -310,12 +296,11 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
if (hslot->count < hslot2->count) if (hslot->count < hslot2->count)
goto scan_primary_hash; goto scan_primary_hash;
exist = udp_lib_lport_inuse2(net, snum, hslot2, exist = udp_lib_lport_inuse2(net, snum, hslot2, sk);
sk, saddr_comp);
if (!exist && (hash2_nulladdr != slot2)) { if (!exist && (hash2_nulladdr != slot2)) {
hslot2 = udp_hashslot2(udptable, hash2_nulladdr); hslot2 = udp_hashslot2(udptable, hash2_nulladdr);
exist = udp_lib_lport_inuse2(net, snum, hslot2, exist = udp_lib_lport_inuse2(net, snum, hslot2,
sk, saddr_comp); sk);
} }
if (exist) if (exist)
goto fail_unlock; goto fail_unlock;
...@@ -323,8 +308,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -323,8 +308,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
goto found; goto found;
} }
scan_primary_hash: scan_primary_hash:
if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, 0))
saddr_comp, 0))
goto fail_unlock; goto fail_unlock;
} }
found: found:
...@@ -333,7 +317,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -333,7 +317,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
udp_sk(sk)->udp_portaddr_hash ^= snum; udp_sk(sk)->udp_portaddr_hash ^= snum;
if (sk_unhashed(sk)) { if (sk_unhashed(sk)) {
if (sk->sk_reuseport && if (sk->sk_reuseport &&
udp_reuseport_add_sock(sk, hslot, saddr_comp)) { udp_reuseport_add_sock(sk, hslot)) {
inet_sk(sk)->inet_num = 0; inet_sk(sk)->inet_num = 0;
udp_sk(sk)->udp_port_hash = 0; udp_sk(sk)->udp_port_hash = 0;
udp_sk(sk)->udp_portaddr_hash ^= snum; udp_sk(sk)->udp_portaddr_hash ^= snum;
...@@ -365,24 +349,6 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -365,24 +349,6 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
} }
EXPORT_SYMBOL(udp_lib_get_port); EXPORT_SYMBOL(udp_lib_get_port);
/* match_wildcard == true: 0.0.0.0 equals to any IPv4 addresses
* match_wildcard == false: addresses must be exactly the same, i.e.
* 0.0.0.0 only equals to 0.0.0.0
*/
int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2,
bool match_wildcard)
{
struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
if (!ipv6_only_sock(sk2)) {
if (inet1->inet_rcv_saddr == inet2->inet_rcv_saddr)
return 1;
if (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr)
return match_wildcard;
}
return 0;
}
static u32 udp4_portaddr_hash(const struct net *net, __be32 saddr, static u32 udp4_portaddr_hash(const struct net *net, __be32 saddr,
unsigned int port) unsigned int port)
{ {
...@@ -398,7 +364,7 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum) ...@@ -398,7 +364,7 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum)
/* precompute partial secondary hash */ /* precompute partial secondary hash */
udp_sk(sk)->udp_portaddr_hash = hash2_partial; udp_sk(sk)->udp_portaddr_hash = hash2_partial;
return udp_lib_get_port(sk, snum, ipv4_rcv_saddr_equal, hash2_nulladdr); return udp_lib_get_port(sk, snum, hash2_nulladdr);
} }
static int compute_score(struct sock *sk, struct net *net, static int compute_score(struct sock *sk, struct net *net,
......
...@@ -28,46 +28,6 @@ ...@@ -28,46 +28,6 @@
#include <net/inet6_connection_sock.h> #include <net/inet6_connection_sock.h>
#include <net/sock_reuseport.h> #include <net/sock_reuseport.h>
int inet6_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb, bool relax,
bool reuseport_ok)
{
const struct sock *sk2;
bool reuse = !!sk->sk_reuse;
bool reuseport = !!sk->sk_reuseport && reuseport_ok;
kuid_t uid = sock_i_uid((struct sock *)sk);
/* We must walk the whole port owner list in this case. -DaveM */
/*
* See comment in inet_csk_bind_conflict about sock lookup
* vs net namespaces issues.
*/
sk_for_each_bound(sk2, &tb->owners) {
if (sk != sk2 &&
(!sk->sk_bound_dev_if ||
!sk2->sk_bound_dev_if ||
sk->sk_bound_dev_if == sk2->sk_bound_dev_if)) {
if ((!reuse || !sk2->sk_reuse ||
sk2->sk_state == TCP_LISTEN) &&
(!reuseport || !sk2->sk_reuseport ||
rcu_access_pointer(sk->sk_reuseport_cb) ||
(sk2->sk_state != TCP_TIME_WAIT &&
!uid_eq(uid,
sock_i_uid((struct sock *)sk2))))) {
if (ipv6_rcv_saddr_equal(sk, sk2, true))
break;
}
if (!relax && reuse && sk2->sk_reuse &&
sk2->sk_state != TCP_LISTEN &&
ipv6_rcv_saddr_equal(sk, sk2, true))
break;
}
}
return sk2 != NULL;
}
EXPORT_SYMBOL_GPL(inet6_csk_bind_conflict);
struct dst_entry *inet6_csk_route_req(const struct sock *sk, struct dst_entry *inet6_csk_route_req(const struct sock *sk,
struct flowi6 *fl6, struct flowi6 *fl6,
const struct request_sock *req, const struct request_sock *req,
......
...@@ -268,54 +268,10 @@ int inet6_hash(struct sock *sk) ...@@ -268,54 +268,10 @@ int inet6_hash(struct sock *sk)
if (sk->sk_state != TCP_CLOSE) { if (sk->sk_state != TCP_CLOSE) {
local_bh_disable(); local_bh_disable();
err = __inet_hash(sk, NULL, ipv6_rcv_saddr_equal); err = __inet_hash(sk, NULL);
local_bh_enable(); local_bh_enable();
} }
return err; return err;
} }
EXPORT_SYMBOL_GPL(inet6_hash); EXPORT_SYMBOL_GPL(inet6_hash);
/* match_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6
* only, and any IPv4 addresses if not IPv6 only
* match_wildcard == false: addresses must be exactly the same, i.e.
* IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
* and 0.0.0.0 equals to 0.0.0.0 only
*/
int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
bool match_wildcard)
{
const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
int sk2_ipv6only = inet_v6_ipv6only(sk2);
int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
/* if both are mapped, treat as IPv4 */
if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
if (!sk2_ipv6only) {
if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
return 1;
if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
return match_wildcard;
}
return 0;
}
if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
return 1;
if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
!(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
return 1;
if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
!(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED))
return 1;
if (sk2_rcv_saddr6 &&
ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6))
return 1;
return 0;
}
EXPORT_SYMBOL_GPL(ipv6_rcv_saddr_equal);
...@@ -1621,7 +1621,6 @@ static const struct inet_connection_sock_af_ops ipv6_specific = { ...@@ -1621,7 +1621,6 @@ static const struct inet_connection_sock_af_ops ipv6_specific = {
.getsockopt = ipv6_getsockopt, .getsockopt = ipv6_getsockopt,
.addr2sockaddr = inet6_csk_addr2sockaddr, .addr2sockaddr = inet6_csk_addr2sockaddr,
.sockaddr_len = sizeof(struct sockaddr_in6), .sockaddr_len = sizeof(struct sockaddr_in6),
.bind_conflict = inet6_csk_bind_conflict,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_ipv6_setsockopt, .compat_setsockopt = compat_ipv6_setsockopt,
.compat_getsockopt = compat_ipv6_getsockopt, .compat_getsockopt = compat_ipv6_getsockopt,
...@@ -1652,7 +1651,6 @@ static const struct inet_connection_sock_af_ops ipv6_mapped = { ...@@ -1652,7 +1651,6 @@ static const struct inet_connection_sock_af_ops ipv6_mapped = {
.getsockopt = ipv6_getsockopt, .getsockopt = ipv6_getsockopt,
.addr2sockaddr = inet6_csk_addr2sockaddr, .addr2sockaddr = inet6_csk_addr2sockaddr,
.sockaddr_len = sizeof(struct sockaddr_in6), .sockaddr_len = sizeof(struct sockaddr_in6),
.bind_conflict = inet6_csk_bind_conflict,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_ipv6_setsockopt, .compat_setsockopt = compat_ipv6_setsockopt,
.compat_getsockopt = compat_ipv6_getsockopt, .compat_getsockopt = compat_ipv6_getsockopt,
......
...@@ -103,7 +103,7 @@ int udp_v6_get_port(struct sock *sk, unsigned short snum) ...@@ -103,7 +103,7 @@ int udp_v6_get_port(struct sock *sk, unsigned short snum)
/* precompute partial secondary hash */ /* precompute partial secondary hash */
udp_sk(sk)->udp_portaddr_hash = hash2_partial; udp_sk(sk)->udp_portaddr_hash = hash2_partial;
return udp_lib_get_port(sk, snum, ipv6_rcv_saddr_equal, hash2_nulladdr); return udp_lib_get_port(sk, snum, hash2_nulladdr);
} }
static void udp_v6_rehash(struct sock *sk) static void udp_v6_rehash(struct sock *sk)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment