Commit 488b6d91 authored by Vadim Fedorenko's avatar Vadim Fedorenko Committed by Paolo Abeni

net-timestamp: make sk_tskey more predictable in error path

When SOF_TIMESTAMPING_OPT_ID is used to ambiguate timestamped datagrams,
the sk_tskey can become unpredictable in case of any error happened
during sendmsg(). Move increment later in the code and make decrement of
sk_tskey in error path. This solution is still racy in case of multiple
threads doing snedmsg() over the very same socket in parallel, but still
makes error path much more predictable.

Fixes: 09c2d251 ("net-timestamp: add key to disambiguate concurrent datagrams")
Reported-by: default avatarAndy Lutomirski <luto@amacapital.net>
Signed-off-by: default avatarVadim Fedorenko <vadfed@meta.com>
Reviewed-by: default avatarWillem de Bruijn <willemb@google.com>
Link: https://lore.kernel.org/r/20240213110428.1681540-1-vadfed@meta.comSigned-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
parent 2ec197fd
...@@ -972,8 +972,8 @@ static int __ip_append_data(struct sock *sk, ...@@ -972,8 +972,8 @@ static int __ip_append_data(struct sock *sk,
unsigned int maxfraglen, fragheaderlen, maxnonfragsize; unsigned int maxfraglen, fragheaderlen, maxnonfragsize;
int csummode = CHECKSUM_NONE; int csummode = CHECKSUM_NONE;
struct rtable *rt = (struct rtable *)cork->dst; struct rtable *rt = (struct rtable *)cork->dst;
bool paged, hold_tskey, extra_uref = false;
unsigned int wmem_alloc_delta = 0; unsigned int wmem_alloc_delta = 0;
bool paged, extra_uref = false;
u32 tskey = 0; u32 tskey = 0;
skb = skb_peek_tail(queue); skb = skb_peek_tail(queue);
...@@ -982,10 +982,6 @@ static int __ip_append_data(struct sock *sk, ...@@ -982,10 +982,6 @@ static int __ip_append_data(struct sock *sk,
mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize; mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize;
paged = !!cork->gso_size; paged = !!cork->gso_size;
if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID)
tskey = atomic_inc_return(&sk->sk_tskey) - 1;
hh_len = LL_RESERVED_SPACE(rt->dst.dev); hh_len = LL_RESERVED_SPACE(rt->dst.dev);
fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0); fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0);
...@@ -1052,6 +1048,11 @@ static int __ip_append_data(struct sock *sk, ...@@ -1052,6 +1048,11 @@ static int __ip_append_data(struct sock *sk,
cork->length += length; cork->length += length;
hold_tskey = cork->tx_flags & SKBTX_ANY_TSTAMP &&
READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID;
if (hold_tskey)
tskey = atomic_inc_return(&sk->sk_tskey) - 1;
/* So, what's going on in the loop below? /* So, what's going on in the loop below?
* *
* We use calculated fragment length to generate chained skb, * We use calculated fragment length to generate chained skb,
...@@ -1274,6 +1275,8 @@ static int __ip_append_data(struct sock *sk, ...@@ -1274,6 +1275,8 @@ static int __ip_append_data(struct sock *sk,
cork->length -= length; cork->length -= length;
IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS); IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS);
refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc); refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
if (hold_tskey)
atomic_dec(&sk->sk_tskey);
return err; return err;
} }
......
...@@ -1424,11 +1424,11 @@ static int __ip6_append_data(struct sock *sk, ...@@ -1424,11 +1424,11 @@ static int __ip6_append_data(struct sock *sk,
bool zc = false; bool zc = false;
u32 tskey = 0; u32 tskey = 0;
struct rt6_info *rt = (struct rt6_info *)cork->dst; struct rt6_info *rt = (struct rt6_info *)cork->dst;
bool paged, hold_tskey, extra_uref = false;
struct ipv6_txoptions *opt = v6_cork->opt; struct ipv6_txoptions *opt = v6_cork->opt;
int csummode = CHECKSUM_NONE; int csummode = CHECKSUM_NONE;
unsigned int maxnonfragsize, headersize; unsigned int maxnonfragsize, headersize;
unsigned int wmem_alloc_delta = 0; unsigned int wmem_alloc_delta = 0;
bool paged, extra_uref = false;
skb = skb_peek_tail(queue); skb = skb_peek_tail(queue);
if (!skb) { if (!skb) {
...@@ -1440,10 +1440,6 @@ static int __ip6_append_data(struct sock *sk, ...@@ -1440,10 +1440,6 @@ static int __ip6_append_data(struct sock *sk,
mtu = cork->gso_size ? IP6_MAX_MTU : cork->fragsize; mtu = cork->gso_size ? IP6_MAX_MTU : cork->fragsize;
orig_mtu = mtu; orig_mtu = mtu;
if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID)
tskey = atomic_inc_return(&sk->sk_tskey) - 1;
hh_len = LL_RESERVED_SPACE(rt->dst.dev); hh_len = LL_RESERVED_SPACE(rt->dst.dev);
fragheaderlen = sizeof(struct ipv6hdr) + rt->rt6i_nfheader_len + fragheaderlen = sizeof(struct ipv6hdr) + rt->rt6i_nfheader_len +
...@@ -1538,6 +1534,11 @@ static int __ip6_append_data(struct sock *sk, ...@@ -1538,6 +1534,11 @@ static int __ip6_append_data(struct sock *sk,
flags &= ~MSG_SPLICE_PAGES; flags &= ~MSG_SPLICE_PAGES;
} }
hold_tskey = cork->tx_flags & SKBTX_ANY_TSTAMP &&
READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID;
if (hold_tskey)
tskey = atomic_inc_return(&sk->sk_tskey) - 1;
/* /*
* Let's try using as much space as possible. * Let's try using as much space as possible.
* Use MTU if total length of the message fits into the MTU. * Use MTU if total length of the message fits into the MTU.
...@@ -1794,6 +1795,8 @@ static int __ip6_append_data(struct sock *sk, ...@@ -1794,6 +1795,8 @@ static int __ip6_append_data(struct sock *sk,
cork->length -= length; cork->length -= length;
IP6_INC_STATS(sock_net(sk), rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS); IP6_INC_STATS(sock_net(sk), rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS);
refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc); refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
if (hold_tskey)
atomic_dec(&sk->sk_tskey);
return err; return err;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment