Commit eca9bfaf authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller

tls: rx: strp: preserve decryption status of skbs when needed

When receive buffer is small we try to copy out the data from
TCP into a skb maintained by TLS to prevent connection from
stalling. Unfortunately if a single record is made up of a mix
of decrypted and non-decrypted skbs combining them into a single
skb leads to loss of decryption status, resulting in decryption
errors or data corruption.

Similarly when trying to use TCP receive queue directly we need
to make sure that all the skbs within the record have the same
status. If we don't the mixed status will be detected correctly
but we'll CoW the anchor, again collapsing it into a single paged
skb without decrypted status preserved. So the "fixup" code will
not know which parts of skb to re-encrypt.

Fixes: 84c61fe1 ("tls: rx: do not use the standard strparser")
Tested-by: default avatarShai Amiram <samiram@nvidia.com>
Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
Reviewed-by: default avatarSimon Horman <simon.horman@corigine.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent c1c607b1
...@@ -126,6 +126,7 @@ struct tls_strparser { ...@@ -126,6 +126,7 @@ struct tls_strparser {
u32 mark : 8; u32 mark : 8;
u32 stopped : 1; u32 stopped : 1;
u32 copy_mode : 1; u32 copy_mode : 1;
u32 mixed_decrypted : 1;
u32 msg_ready : 1; u32 msg_ready : 1;
struct strp_msg stm; struct strp_msg stm;
......
...@@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx) ...@@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx)
return ctx->strp.msg_ready; return ctx->strp.msg_ready;
} }
static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx)
{
return ctx->strp.mixed_decrypted;
}
#ifdef CONFIG_TLS_DEVICE #ifdef CONFIG_TLS_DEVICE
int tls_device_init(void); int tls_device_init(void);
void tls_device_cleanup(void); void tls_device_cleanup(void);
......
...@@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx) ...@@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb = tls_strp_msg(sw_ctx); struct sk_buff *skb = tls_strp_msg(sw_ctx);
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
int is_decrypted = skb->decrypted; int is_decrypted, is_encrypted;
int is_encrypted = !is_decrypted;
struct sk_buff *skb_iter; if (!tls_strp_msg_mixed_decrypted(sw_ctx)) {
int left; is_decrypted = skb->decrypted;
is_encrypted = !is_decrypted;
left = rxm->full_len + rxm->offset - skb_pagelen(skb); } else {
/* Check if all the data is decrypted already */ is_decrypted = 0;
skb_iter = skb_shinfo(skb)->frag_list; is_encrypted = 0;
while (skb_iter && left > 0) {
is_decrypted &= skb_iter->decrypted;
is_encrypted &= !skb_iter->decrypted;
left -= skb_iter->len;
skb_iter = skb_iter->next;
} }
trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len, trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
......
...@@ -29,6 +29,7 @@ static void tls_strp_anchor_free(struct tls_strparser *strp) ...@@ -29,6 +29,7 @@ static void tls_strp_anchor_free(struct tls_strparser *strp)
struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
if (!strp->copy_mode)
shinfo->frag_list = NULL; shinfo->frag_list = NULL;
consume_skb(strp->anchor); consume_skb(strp->anchor);
strp->anchor = NULL; strp->anchor = NULL;
...@@ -195,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp) ...@@ -195,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
for (i = 0; i < shinfo->nr_frags; i++) for (i = 0; i < shinfo->nr_frags; i++)
__skb_frag_unref(&shinfo->frags[i], false); __skb_frag_unref(&shinfo->frags[i], false);
shinfo->nr_frags = 0; shinfo->nr_frags = 0;
if (strp->copy_mode) {
kfree_skb_list(shinfo->frag_list);
shinfo->frag_list = NULL;
}
strp->copy_mode = 0; strp->copy_mode = 0;
strp->mixed_decrypted = 0;
} }
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
unsigned int offset, size_t in_len) struct sk_buff *in_skb, unsigned int offset,
size_t in_len)
{ {
struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
struct sk_buff *skb;
skb_frag_t *frag;
size_t len, chunk; size_t len, chunk;
skb_frag_t *frag;
int sz; int sz;
if (strp->msg_ready)
return 0;
skb = strp->anchor;
frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
len = in_len; len = in_len;
...@@ -228,10 +229,8 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, ...@@ -228,10 +229,8 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
skb_frag_size_add(frag, chunk); skb_frag_size_add(frag, chunk);
sz = tls_rx_msg_size(strp, skb); sz = tls_rx_msg_size(strp, skb);
if (sz < 0) { if (sz < 0)
desc->error = sz; return sz;
return 0;
}
/* We may have over-read, sz == 0 is guaranteed under-read */ /* We may have over-read, sz == 0 is guaranteed under-read */
if (unlikely(sz && sz < skb->len)) { if (unlikely(sz && sz < skb->len)) {
...@@ -271,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, ...@@ -271,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
offset += chunk; offset += chunk;
} }
if (strp->stm.full_len == skb->len) { read_done:
return in_len - len;
}
static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb,
struct sk_buff *in_skb, unsigned int offset,
size_t in_len)
{
struct sk_buff *nskb, *first, *last;
struct skb_shared_info *shinfo;
size_t chunk;
int sz;
if (strp->stm.full_len)
chunk = strp->stm.full_len - skb->len;
else
chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
chunk = min(chunk, in_len);
nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk);
if (!nskb)
return -ENOMEM;
shinfo = skb_shinfo(skb);
if (!shinfo->frag_list) {
shinfo->frag_list = nskb;
nskb->prev = nskb;
} else {
first = shinfo->frag_list;
last = first->prev;
last->next = nskb;
first->prev = nskb;
}
skb->len += chunk;
skb->data_len += chunk;
if (!strp->stm.full_len) {
sz = tls_rx_msg_size(strp, skb);
if (sz < 0)
return sz;
/* We may have over-read, sz == 0 is guaranteed under-read */
if (unlikely(sz && sz < skb->len)) {
int over = skb->len - sz;
WARN_ON_ONCE(over > chunk);
skb->len -= over;
skb->data_len -= over;
__pskb_trim(nskb, nskb->len - over);
chunk -= over;
}
strp->stm.full_len = sz;
}
return chunk;
}
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
unsigned int offset, size_t in_len)
{
struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
struct sk_buff *skb;
int ret;
if (strp->msg_ready)
return 0;
skb = strp->anchor;
if (!skb->len)
skb_copy_decrypted(skb, in_skb);
else
strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb);
if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted)
ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len);
else
ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len);
if (ret < 0) {
desc->error = ret;
ret = 0;
}
if (strp->stm.full_len && strp->stm.full_len == skb->len) {
desc->count = 0; desc->count = 0;
strp->msg_ready = 1; strp->msg_ready = 1;
tls_rx_msg_ready(strp); tls_rx_msg_ready(strp);
} }
read_done: return ret;
return in_len - len;
} }
static int tls_strp_read_copyin(struct tls_strparser *strp) static int tls_strp_read_copyin(struct tls_strparser *strp)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment