Commit 87952603 authored by Paolo Abeni, committed by Jakub Kicinski

mptcp: protect the rx path with the msk socket spinlock

That spinlock is currently used only to protect the 'owned'
flag inside the socket lock itself. With this patch, we extend
its scope to protect the whole msk receive path and
sk_forward_alloc.

Given the above, we can always move data into the msk receive
queue (and OoO queue) from the subflow.

We leverage the previous commit, so that we need to acquire the
spinlock in the tx path only when moving fwd memory.

recvmsg() must now explicitly acquire the socket spinlock
when moving skbs out of sk_receive_queue. To reduce the number of
lock operations required, we use a second rx queue and splice the
former into the latter in mptcp_lock_sock(). Additionally, rmem-allocated
memory is bulk-freed via release_cb().
Acked-by: Florian Westphal <fw@strlen.de>
Co-developed-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Reviewed-by: Mat Martineau <mathew.j.martineau@linux.intel.com>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
parent e93da928
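
Taken together, the hunks below implement the following recvmsg()-side flow. The sketch is illustrative only: the helper name mptcp_recvmsg_flow() and the trimmed error handling are not in the patch, but every callee shown (mptcp_lock_sock(), __mptcp_splice_receive_queue(), __mptcp_recvmsg_mskq(), __mptcp_update_rmem() via mptcp_release_cb()) is.

/* Illustrative sketch only -- condensed from the hunks below, not the
 * verbatim kernel code; mptcp_recvmsg_flow() is a hypothetical name.
 */
static int mptcp_recvmsg_flow(struct sock *sk, struct msghdr *msg, size_t len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	int copied;

	/* take ownership of the socket; while sk_lock.slock is held,
	 * splice sk_receive_queue into the private msk->receive_queue
	 */
	mptcp_lock_sock(sk, __mptcp_splice_receive_queue(sk));

	/* drain msk->receive_queue without touching the spinlock again;
	 * each freed skb only adds its truesize to msk->rmem_released
	 */
	copied = __mptcp_recvmsg_mskq(msk, msg, len);

	/* release_sock() ends up in mptcp_release_cb(), where
	 * __mptcp_update_rmem() returns the accumulated rmem in bulk
	 */
	release_sock(sk);
	return copied;
}
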
@@ -453,15 +453,15 @@ static bool mptcp_subflow_cleanup_rbuf(struct sock *ssk)
 static void mptcp_cleanup_rbuf(struct mptcp_sock *msk)
 {
+	struct sock *ack_hint = READ_ONCE(msk->ack_hint);
 	struct mptcp_subflow_context *subflow;
 
 	/* if the hinted ssk is still active, try to use it */
-	if (likely(msk->ack_hint)) {
+	if (likely(ack_hint)) {
 		mptcp_for_each_subflow(msk, subflow) {
 			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
 
-			if (msk->ack_hint == ssk &&
-			    mptcp_subflow_cleanup_rbuf(ssk))
+			if (ack_hint == ssk && mptcp_subflow_cleanup_rbuf(ssk))
 				return;
 		}
 	}
@@ -614,13 +614,13 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
 			break;
 		}
 	} while (more_data_avail);
 
-	msk->ack_hint = ssk;
+	WRITE_ONCE(msk->ack_hint, ssk);
 	*bytes += moved;
 	return done;
 }
 
-static bool mptcp_ofo_queue(struct mptcp_sock *msk)
+static bool __mptcp_ofo_queue(struct mptcp_sock *msk)
 {
 	struct sock *sk = (struct sock *)msk;
 	struct sk_buff *skb, *tail;
@@ -666,34 +666,27 @@ static bool mptcp_ofo_queue(struct mptcp_sock *msk)
 /* In most cases we will be able to lock the mptcp socket. If its already
  * owned, we need to defer to the work queue to avoid ABBA deadlock.
  */
-static bool move_skbs_to_msk(struct mptcp_sock *msk, struct sock *ssk)
+static void move_skbs_to_msk(struct mptcp_sock *msk, struct sock *ssk)
 {
 	struct sock *sk = (struct sock *)msk;
 	unsigned int moved = 0;
 
-	if (READ_ONCE(sk->sk_lock.owned))
-		return false;
-
-	if (unlikely(!spin_trylock_bh(&sk->sk_lock.slock)))
-		return false;
-
-	/* must re-check after taking the lock */
-	if (!READ_ONCE(sk->sk_lock.owned)) {
-		__mptcp_move_skbs_from_subflow(msk, ssk, &moved);
-		mptcp_ofo_queue(msk);
-
-		/* If the moves have caught up with the DATA_FIN sequence number
-		 * it's time to ack the DATA_FIN and change socket state, but
-		 * this is not a good place to change state. Let the workqueue
-		 * do it.
-		 */
-		if (mptcp_pending_data_fin(sk, NULL))
-			mptcp_schedule_work(sk);
-	}
+	if (inet_sk_state_load(sk) == TCP_CLOSE)
+		return;
+
+	mptcp_data_lock(sk);
 
-	spin_unlock_bh(&sk->sk_lock.slock);
+	__mptcp_move_skbs_from_subflow(msk, ssk, &moved);
+	__mptcp_ofo_queue(msk);
 
-	return moved > 0;
+	/* If the moves have caught up with the DATA_FIN sequence number
+	 * it's time to ack the DATA_FIN and change socket state, but
+	 * this is not a good place to change state. Let the workqueue
+	 * do it.
+	 */
+	if (mptcp_pending_data_fin(sk, NULL))
+		mptcp_schedule_work(sk);
+	mptcp_data_unlock(sk);
 }
 
 void mptcp_data_ready(struct sock *sk, struct sock *ssk)
@@ -937,17 +930,30 @@ static bool mptcp_wmem_alloc(struct sock *sk, int size)
 	if (msk->wmem_reserved >= size)
 		goto account;
 
-	if (!sk_wmem_schedule(sk, size))
+	mptcp_data_lock(sk);
+	if (!sk_wmem_schedule(sk, size)) {
+		mptcp_data_unlock(sk);
 		return false;
+	}
 
 	sk->sk_forward_alloc -= size;
 	msk->wmem_reserved += size;
+	mptcp_data_unlock(sk);
 
 account:
 	msk->wmem_reserved -= size;
 	return true;
 }
 
+static void mptcp_wmem_uncharge(struct sock *sk, int size)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+
+	if (msk->wmem_reserved < 0)
+		msk->wmem_reserved = 0;
+	msk->wmem_reserved += size;
+}
+
 static void dfrag_uncharge(struct sock *sk, int len)
 {
 	sk_mem_uncharge(sk, len);
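
With the hunk above, the tx side touches the shared forward-allocated memory only on the slow path: mptcp_wmem_alloc() serves requests from the per-msk wmem_reserved cache and takes mptcp_data_lock() only when that cache has to be refilled from sk_forward_alloc, while mptcp_wmem_uncharge() hands a failed reservation back without any locking. A hedged caller-side sketch (the helper name and signature are hypothetical; the real user is mptcp_sendmsg() in a hunk further below):

/* Hypothetical helper illustrating the reserve/copy/uncharge pattern */
static bool mptcp_copy_reserved(struct sock *sk, struct page *page, int offset,
				int psize, int frag_truesize, struct msghdr *msg)
{
	/* fast path: no spinlock; slow path: mptcp_data_lock() taken inside */
	if (!mptcp_wmem_alloc(sk, psize + frag_truesize))
		return false;		/* caller waits for memory and retries */

	if (copy_page_from_iter(page, offset, psize, &msg->msg_iter) != psize) {
		/* return the reservation to wmem_reserved, lock-free */
		mptcp_wmem_uncharge(sk, psize + frag_truesize);
		return false;
	}
	return true;
}
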
@@ -976,6 +982,7 @@ static void mptcp_clean_una(struct sock *sk)
 	if (__mptcp_check_fallback(msk))
 		atomic64_set(&msk->snd_una, msk->snd_nxt);
 
+	mptcp_data_lock(sk);
 	snd_una = atomic64_read(&msk->snd_una);
 
 	list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) {
@@ -1007,6 +1014,7 @@ static void mptcp_clean_una(struct sock *sk)
 out:
 	if (cleaned && tcp_under_memory_pressure(sk))
 		sk_mem_reclaim_partial(sk);
+	mptcp_data_unlock(sk);
 }
 
 static void mptcp_clean_una_wakeup(struct sock *sk)
@@ -1436,7 +1444,7 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 		if (copy_page_from_iter(dfrag->page, offset, psize,
 					&msg->msg_iter) != psize) {
-			msk->wmem_reserved += psize + frag_truesize;
+			mptcp_wmem_uncharge(sk, psize + frag_truesize);
 			ret = -EFAULT;
 			goto out;
 		}
@@ -1502,11 +1510,10 @@ static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk,
 				struct msghdr *msg,
 				size_t len)
 {
-	struct sock *sk = (struct sock *)msk;
 	struct sk_buff *skb;
 	int copied = 0;
 
-	while ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) {
+	while ((skb = skb_peek(&msk->receive_queue)) != NULL) {
 		u32 offset = MPTCP_SKB_CB(skb)->offset;
 		u32 data_len = skb->len - offset;
 		u32 count = min_t(size_t, len - copied, data_len);
@@ -1526,7 +1533,10 @@ static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk,
 			break;
 		}
 
-		__skb_unlink(skb, &sk->sk_receive_queue);
+		/* we will bulk release the skb memory later */
+		skb->destructor = NULL;
+		msk->rmem_released += skb->truesize;
+		__skb_unlink(skb, &msk->receive_queue);
 		__kfree_skb(skb);
 
 		if (copied >= len)
@@ -1634,25 +1644,47 @@ static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)
 	msk->rcvq_space.time = mstamp;
 }
 
+static void __mptcp_update_rmem(struct sock *sk)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+
+	if (!msk->rmem_released)
+		return;
+
+	atomic_sub(msk->rmem_released, &sk->sk_rmem_alloc);
+	sk_mem_uncharge(sk, msk->rmem_released);
+	msk->rmem_released = 0;
+}
+
+static void __mptcp_splice_receive_queue(struct sock *sk)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+
+	skb_queue_splice_tail_init(&sk->sk_receive_queue, &msk->receive_queue);
+}
+
 static bool __mptcp_move_skbs(struct mptcp_sock *msk, unsigned int rcv)
 {
+	struct sock *sk = (struct sock *)msk;
 	unsigned int moved = 0;
-	bool done;
-
-	/* avoid looping forever below on racing close */
-	if (((struct sock *)msk)->sk_state == TCP_CLOSE)
-		return false;
+	bool ret, done;
 
 	__mptcp_flush_join_list(msk);
 	do {
 		struct sock *ssk = mptcp_subflow_recv_lookup(msk);
 		bool slowpath;
 
-		if (!ssk)
+		/* we can have data pending in the subflows only if the msk
+		 * receive buffer was full at subflow_data_ready() time,
+		 * that is an unlikely slow path.
+		 */
+		if (likely(!ssk))
 			break;
 
 		slowpath = lock_sock_fast(ssk);
+		mptcp_data_lock(sk);
 		done = __mptcp_move_skbs_from_subflow(msk, ssk, &moved);
+		mptcp_data_unlock(sk);
 		if (moved && rcv) {
 			WRITE_ONCE(msk->rmem_pending, min(rcv, moved));
 			tcp_cleanup_rbuf(ssk, 1);
@@ -1661,11 +1693,19 @@ static bool __mptcp_move_skbs(struct mptcp_sock *msk, unsigned int rcv)
 		unlock_sock_fast(ssk, slowpath);
 	} while (!done);
 
-	if (mptcp_ofo_queue(msk) || moved > 0) {
-		mptcp_check_data_fin((struct sock *)msk);
-		return true;
+	/* acquire the data lock only if some input data is pending */
+	ret = moved > 0;
+	if (!RB_EMPTY_ROOT(&msk->out_of_order_queue) ||
+	    !skb_queue_empty_lockless(&sk->sk_receive_queue)) {
+		mptcp_data_lock(sk);
+		__mptcp_update_rmem(sk);
+		ret |= __mptcp_ofo_queue(msk);
+		__mptcp_splice_receive_queue(sk);
+		mptcp_data_unlock(sk);
 	}
-	return false;
+	if (ret)
+		mptcp_check_data_fin((struct sock *)msk);
+	return !skb_queue_empty(&msk->receive_queue);
 }
 
 static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
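
Note the ordering above: __mptcp_update_rmem() runs only under mptcp_data_lock(), either here in __mptcp_move_skbs() before the out-of-order queue is drained, or later from mptcp_release_cb() (which the core socket code invokes with the spinlock already held). A minimal sketch of that pairing, with a hypothetical wrapper name:

/* Hypothetical wrapper; the real callers are __mptcp_move_skbs() above
 * and mptcp_release_cb() further below.
 */
static void mptcp_flush_deferred_rmem(struct sock *sk)
{
	mptcp_data_lock(sk);
	/* subtract msk->rmem_released from sk_rmem_alloc and uncharge the
	 * forward-allocated memory in a single bulk operation
	 */
	__mptcp_update_rmem(sk);
	mptcp_data_unlock(sk);
}
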
@@ -1679,7 +1719,7 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 	if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
 		return -EOPNOTSUPP;
 
-	lock_sock(sk);
+	mptcp_lock_sock(sk, __mptcp_splice_receive_queue(sk));
 	if (unlikely(sk->sk_state == TCP_LISTEN)) {
 		copied = -ENOTCONN;
 		goto out_err;
@@ -1689,7 +1729,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 	len = min_t(size_t, len, INT_MAX);
 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
 
-	__mptcp_flush_join_list(msk);
 	for (;;) {
 		int bytes_read, old_space;
@@ -1703,7 +1742,7 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		copied += bytes_read;
 
-		if (skb_queue_empty(&sk->sk_receive_queue) &&
+		if (skb_queue_empty(&msk->receive_queue) &&
 		    __mptcp_move_skbs(msk, len - copied))
 			continue;
@@ -1734,8 +1773,14 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
 			mptcp_check_for_eof(msk);
 
-		if (sk->sk_shutdown & RCV_SHUTDOWN)
+		if (sk->sk_shutdown & RCV_SHUTDOWN) {
+			/* race breaker: the shutdown could be after the
+			 * previous receive queue check
+			 */
+			if (__mptcp_move_skbs(msk, len - copied))
+				continue;
 			break;
+		}
 
 		if (sk->sk_state == TCP_CLOSE) {
 			copied = -ENOTCONN;
@@ -1757,7 +1802,8 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		mptcp_wait_data(sk, &timeo);
 	}
 
-	if (skb_queue_empty(&sk->sk_receive_queue)) {
+	if (skb_queue_empty_lockless(&sk->sk_receive_queue) &&
+	    skb_queue_empty(&msk->receive_queue)) {
 		/* entire backlog drained, clear DATA_READY. */
 		clear_bit(MPTCP_DATA_READY, &msk->flags);
@@ -1773,7 +1819,7 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 out_err:
 	pr_debug("msk=%p data_ready=%d rx queue empty=%d copied=%d",
 		 msk, test_bit(MPTCP_DATA_READY, &msk->flags),
-		 skb_queue_empty(&sk->sk_receive_queue), copied);
+		 skb_queue_empty_lockless(&sk->sk_receive_queue), copied);
 	mptcp_rcv_space_adjust(msk, copied);
 
 	release_sock(sk);
@@ -2076,9 +2122,11 @@ static int __mptcp_init_sock(struct sock *sk)
 	INIT_LIST_HEAD(&msk->join_list);
 	INIT_LIST_HEAD(&msk->rtx_queue);
 	INIT_WORK(&msk->work, mptcp_worker);
+	__skb_queue_head_init(&msk->receive_queue);
 	msk->out_of_order_queue = RB_ROOT;
 	msk->first_pending = NULL;
 	msk->wmem_reserved = 0;
+	msk->rmem_released = 0;
 	msk->ack_hint = NULL;
 
 	msk->first = NULL;
@@ -2274,6 +2322,7 @@ static void __mptcp_destroy_sock(struct sock *sk)
 	sk->sk_prot->destroy(sk);
 
 	WARN_ON_ONCE(msk->wmem_reserved);
+	WARN_ON_ONCE(msk->rmem_released);
 	sk_stream_kill_queues(sk);
 	xfrm_sk_free_policy(sk);
 	sk_refcnt_debug_release(sk);
@@ -2491,6 +2540,11 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
 
 void mptcp_destroy_common(struct mptcp_sock *msk)
 {
+	struct sock *sk = (struct sock *)msk;
+
+	/* move to sk_receive_queue, sk_stream_kill_queues will purge it */
+	skb_queue_splice_tail_init(&msk->receive_queue, &sk->sk_receive_queue);
+
 	skb_rbtree_purge(&msk->out_of_order_queue);
 	mptcp_token_destroy(msk);
 	mptcp_pm_free_anno_list(msk);
@@ -2626,6 +2680,7 @@ static void mptcp_release_cb(struct sock *sk)
 
 	/* clear any wmem reservation and errors */
 	__mptcp_update_wmem(sk);
+	__mptcp_update_rmem(sk);
 
 	do {
 		flags = sk->sk_tsq_flags;
......
@@ -227,6 +227,7 @@ struct mptcp_sock {
 	unsigned long	timer_ival;
 	u32		token;
 	int		rmem_pending;
+	int		rmem_released;
 	unsigned long	flags;
 	bool		can_ack;
 	bool		fully_established;
@@ -238,6 +239,7 @@ struct mptcp_sock {
 	struct work_struct work;
 	struct sk_buff  *ooo_last_skb;
 	struct rb_root  out_of_order_queue;
+	struct sk_buff_head receive_queue;
 	struct list_head conn_list;
 	struct list_head rtx_queue;
 	struct mptcp_data_frag *first_pending;
@@ -267,6 +269,9 @@ struct mptcp_sock {
 		local_bh_enable();			\
 	} while (0)
 
+#define mptcp_data_lock(sk) spin_lock_bh(&(sk)->sk_lock.slock)
+#define mptcp_data_unlock(sk) spin_unlock_bh(&(sk)->sk_lock.slock)
+
 #define mptcp_for_each_subflow(__msk, __subflow)			\
 	list_for_each_entry(__subflow, &((__msk)->conn_list), node)
......