Commit 13475a30 authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller

tcp: connect() race with timewait reuse

Its currently possible that several threads issuing a connect() find
the same timewait socket and try to reuse it, leading to list
corruptions.

Condition for bug is that these threads bound their socket on same
address/port of to-be-find timewait socket, and connected to same
target. (SO_REUSEADDR needed)

To fix this problem, we could unhash timewait socket while holding
ehash lock, to make sure lookups/changes will be serialized. Only
first thread finds the timewait socket, other ones find the
established socket and return an EADDRNOTAVAIL error.

This second version takes into account Evgeniy's review and makes sure
inet_twsk_put() is called outside of locked sections.
Signed-off-by: default avatarEric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent ff33a6e2
...@@ -199,6 +199,8 @@ static inline __be32 inet_rcv_saddr(const struct sock *sk) ...@@ -199,6 +199,8 @@ static inline __be32 inet_rcv_saddr(const struct sock *sk)
extern void inet_twsk_put(struct inet_timewait_sock *tw); extern void inet_twsk_put(struct inet_timewait_sock *tw);
extern int inet_twsk_unhash(struct inet_timewait_sock *tw);
extern struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, extern struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk,
const int state); const int state);
......
...@@ -286,6 +286,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, ...@@ -286,6 +286,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
struct sock *sk2; struct sock *sk2;
const struct hlist_nulls_node *node; const struct hlist_nulls_node *node;
struct inet_timewait_sock *tw; struct inet_timewait_sock *tw;
int twrefcnt = 0;
spin_lock(lock); spin_lock(lock);
...@@ -318,20 +319,23 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, ...@@ -318,20 +319,23 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
sk->sk_hash = hash; sk->sk_hash = hash;
WARN_ON(!sk_unhashed(sk)); WARN_ON(!sk_unhashed(sk));
__sk_nulls_add_node_rcu(sk, &head->chain); __sk_nulls_add_node_rcu(sk, &head->chain);
if (tw) {
twrefcnt = inet_twsk_unhash(tw);
NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
}
spin_unlock(lock); spin_unlock(lock);
if (twrefcnt)
inet_twsk_put(tw);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
if (twp) { if (twp) {
*twp = tw; *twp = tw;
NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
} else if (tw) { } else if (tw) {
/* Silly. Should hash-dance instead... */ /* Silly. Should hash-dance instead... */
inet_twsk_deschedule(tw, death_row); inet_twsk_deschedule(tw, death_row);
NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
inet_twsk_put(tw); inet_twsk_put(tw);
} }
return 0; return 0;
not_unique: not_unique:
......
...@@ -14,22 +14,33 @@ ...@@ -14,22 +14,33 @@
#include <net/inet_timewait_sock.h> #include <net/inet_timewait_sock.h>
#include <net/ip.h> #include <net/ip.h>
/*
* unhash a timewait socket from established hash
* lock must be hold by caller
*/
int inet_twsk_unhash(struct inet_timewait_sock *tw)
{
if (hlist_nulls_unhashed(&tw->tw_node))
return 0;
hlist_nulls_del_rcu(&tw->tw_node);
sk_nulls_node_init(&tw->tw_node);
return 1;
}
/* Must be called with locally disabled BHs. */ /* Must be called with locally disabled BHs. */
static void __inet_twsk_kill(struct inet_timewait_sock *tw, static void __inet_twsk_kill(struct inet_timewait_sock *tw,
struct inet_hashinfo *hashinfo) struct inet_hashinfo *hashinfo)
{ {
struct inet_bind_hashbucket *bhead; struct inet_bind_hashbucket *bhead;
struct inet_bind_bucket *tb; struct inet_bind_bucket *tb;
int refcnt;
/* Unlink from established hashes. */ /* Unlink from established hashes. */
spinlock_t *lock = inet_ehash_lockp(hashinfo, tw->tw_hash); spinlock_t *lock = inet_ehash_lockp(hashinfo, tw->tw_hash);
spin_lock(lock); spin_lock(lock);
if (hlist_nulls_unhashed(&tw->tw_node)) { refcnt = inet_twsk_unhash(tw);
spin_unlock(lock);
return;
}
hlist_nulls_del_rcu(&tw->tw_node);
sk_nulls_node_init(&tw->tw_node);
spin_unlock(lock); spin_unlock(lock);
/* Disassociate with bind bucket. */ /* Disassociate with bind bucket. */
...@@ -37,9 +48,12 @@ static void __inet_twsk_kill(struct inet_timewait_sock *tw, ...@@ -37,9 +48,12 @@ static void __inet_twsk_kill(struct inet_timewait_sock *tw,
hashinfo->bhash_size)]; hashinfo->bhash_size)];
spin_lock(&bhead->lock); spin_lock(&bhead->lock);
tb = tw->tw_tb; tb = tw->tw_tb;
__hlist_del(&tw->tw_bind_node); if (tb) {
tw->tw_tb = NULL; __hlist_del(&tw->tw_bind_node);
inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); tw->tw_tb = NULL;
inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
refcnt++;
}
spin_unlock(&bhead->lock); spin_unlock(&bhead->lock);
#ifdef SOCK_REFCNT_DEBUG #ifdef SOCK_REFCNT_DEBUG
if (atomic_read(&tw->tw_refcnt) != 1) { if (atomic_read(&tw->tw_refcnt) != 1) {
...@@ -47,7 +61,10 @@ static void __inet_twsk_kill(struct inet_timewait_sock *tw, ...@@ -47,7 +61,10 @@ static void __inet_twsk_kill(struct inet_timewait_sock *tw,
tw->tw_prot->name, tw, atomic_read(&tw->tw_refcnt)); tw->tw_prot->name, tw, atomic_read(&tw->tw_refcnt));
} }
#endif #endif
inet_twsk_put(tw); while (refcnt) {
inet_twsk_put(tw);
refcnt--;
}
} }
static noinline void inet_twsk_free(struct inet_timewait_sock *tw) static noinline void inet_twsk_free(struct inet_timewait_sock *tw)
...@@ -92,6 +109,7 @@ void __inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk, ...@@ -92,6 +109,7 @@ void __inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk,
tw->tw_tb = icsk->icsk_bind_hash; tw->tw_tb = icsk->icsk_bind_hash;
WARN_ON(!icsk->icsk_bind_hash); WARN_ON(!icsk->icsk_bind_hash);
inet_twsk_add_bind_node(tw, &tw->tw_tb->owners); inet_twsk_add_bind_node(tw, &tw->tw_tb->owners);
atomic_inc(&tw->tw_refcnt);
spin_unlock(&bhead->lock); spin_unlock(&bhead->lock);
spin_lock(lock); spin_lock(lock);
......
...@@ -223,6 +223,7 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row, ...@@ -223,6 +223,7 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
struct sock *sk2; struct sock *sk2;
const struct hlist_nulls_node *node; const struct hlist_nulls_node *node;
struct inet_timewait_sock *tw; struct inet_timewait_sock *tw;
int twrefcnt = 0;
spin_lock(lock); spin_lock(lock);
...@@ -250,19 +251,23 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row, ...@@ -250,19 +251,23 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
* in hash table socket with a funny identity. */ * in hash table socket with a funny identity. */
inet->inet_num = lport; inet->inet_num = lport;
inet->inet_sport = htons(lport); inet->inet_sport = htons(lport);
sk->sk_hash = hash;
WARN_ON(!sk_unhashed(sk)); WARN_ON(!sk_unhashed(sk));
__sk_nulls_add_node_rcu(sk, &head->chain); __sk_nulls_add_node_rcu(sk, &head->chain);
sk->sk_hash = hash; if (tw) {
twrefcnt = inet_twsk_unhash(tw);
NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
}
spin_unlock(lock); spin_unlock(lock);
if (twrefcnt)
inet_twsk_put(tw);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
if (twp != NULL) { if (twp) {
*twp = tw; *twp = tw;
NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED); } else if (tw) {
} else if (tw != NULL) {
/* Silly. Should hash-dance instead... */ /* Silly. Should hash-dance instead... */
inet_twsk_deschedule(tw, death_row); inet_twsk_deschedule(tw, death_row);
NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
inet_twsk_put(tw); inet_twsk_put(tw);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment