Commit 9bdf75cc authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller

tls: rx: don't report text length from the bowels of decrypt

We plumb pointer to chunk all the way to the decryption method.
It's set to the length of the text when decrypt_skb_update()
returns.

I think the code is written this way because original TLS
implementation passed &chunk to zerocopy_from_iter() and this
was carried forward as the code gotten more complex, without
any refactoring.

The fix for peek() introduced a new variable - to_decrypt
which for all practical purposes is what chunk is going to
get set to. Spare ourselves the pointer passing, use to_decrypt.

Use this opportunity to clean things up a little further.

Note that chunk / to_decrypt was mostly needed for the async
path, since the sync path would access rxm->full_len (decryption
transforms full_len from record size to text size). Use the
right source of truth more explicitly.

We have three cases:
 - async - it's TLS 1.2 only, so chunk == to_decrypt, but we
           need the min() because to_decrypt is a whole record
	   and we don't want to underflow len. Note that we can't
	   handle partial record by falling back to sync as it
	   would introduce reordering against records in flight.
 - zc - again, TLS 1.2 only for now, so chunk == to_decrypt,
        we don't do zc if len < to_decrypt, no need to check again.
 - normal - it already handles chunk > len, we can factor out the
            assignment to rxm->full_len and share it with zc.
Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent d4bd88e6
...@@ -1412,7 +1412,7 @@ static int tls_setup_from_iter(struct iov_iter *from, ...@@ -1412,7 +1412,7 @@ static int tls_setup_from_iter(struct iov_iter *from,
static int decrypt_internal(struct sock *sk, struct sk_buff *skb, static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct iov_iter *out_iov, struct iov_iter *out_iov,
struct scatterlist *out_sg, struct scatterlist *out_sg,
int *chunk, bool *zc, bool async) bool *zc, bool async)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
...@@ -1526,7 +1526,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1526,7 +1526,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
(n_sgout - 1)); (n_sgout - 1));
if (err < 0) if (err < 0)
goto fallback_to_reg_recv; goto fallback_to_reg_recv;
*chunk = data_len;
} else if (out_sg) { } else if (out_sg) {
memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
} else { } else {
...@@ -1536,7 +1535,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1536,7 +1535,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
fallback_to_reg_recv: fallback_to_reg_recv:
sgout = sgin; sgout = sgin;
pages = 0; pages = 0;
*chunk = data_len;
*zc = false; *zc = false;
} }
...@@ -1555,8 +1553,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1555,8 +1553,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
} }
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
struct iov_iter *dest, int *chunk, bool *zc, struct iov_iter *dest, bool *zc, bool async)
bool async)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_prot_info *prot = &tls_ctx->prot_info;
...@@ -1580,7 +1577,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1580,7 +1577,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
} }
} }
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async); err = decrypt_internal(sk, skb, dest, NULL, zc, async);
if (err < 0) { if (err < 0) {
if (err == -EINPROGRESS) if (err == -EINPROGRESS)
tls_advance_record_sn(sk, prot, &tls_ctx->rx); tls_advance_record_sn(sk, prot, &tls_ctx->rx);
...@@ -1607,9 +1604,8 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb, ...@@ -1607,9 +1604,8 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout) struct scatterlist *sgout)
{ {
bool zc = true; bool zc = true;
int chunk;
return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false); return decrypt_internal(sk, skb, NULL, sgout, &zc, false);
} }
static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
...@@ -1799,9 +1795,8 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1799,9 +1795,8 @@ int tls_sw_recvmsg(struct sock *sk,
num_async = 0; num_async = 0;
while (len && (decrypted + copied < target || ctx->recv_pkt)) { while (len && (decrypted + copied < target || ctx->recv_pkt)) {
bool retain_skb = false; bool retain_skb = false;
int to_decrypt, chunk;
bool zc = false; bool zc = false;
int to_decrypt;
int chunk = 0;
bool async_capable; bool async_capable;
bool async = false; bool async = false;
...@@ -1838,7 +1833,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1838,7 +1833,7 @@ int tls_sw_recvmsg(struct sock *sk,
async_capable = false; async_capable = false;
err = decrypt_skb_update(sk, skb, &msg->msg_iter, err = decrypt_skb_update(sk, skb, &msg->msg_iter,
&chunk, &zc, async_capable); &zc, async_capable);
if (err < 0 && err != -EINPROGRESS) { if (err < 0 && err != -EINPROGRESS) {
tls_err_abort(sk, -EBADMSG); tls_err_abort(sk, -EBADMSG);
goto recv_end; goto recv_end;
...@@ -1876,8 +1871,13 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1876,8 +1871,13 @@ int tls_sw_recvmsg(struct sock *sk,
} }
} }
if (async) if (async) {
/* TLS 1.2-only, to_decrypt must be text length */
chunk = min_t(int, to_decrypt, len);
goto pick_next_record; goto pick_next_record;
}
/* TLS 1.3 may have updated the length by more than overhead */
chunk = rxm->full_len;
if (!zc) { if (!zc) {
if (bpf_strp_enabled) { if (bpf_strp_enabled) {
...@@ -1893,11 +1893,9 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1893,11 +1893,9 @@ int tls_sw_recvmsg(struct sock *sk,
} }
} }
if (rxm->full_len > len) { if (chunk > len) {
retain_skb = true; retain_skb = true;
chunk = len; chunk = len;
} else {
chunk = rxm->full_len;
} }
err = skb_copy_datagram_msg(skb, rxm->offset, err = skb_copy_datagram_msg(skb, rxm->offset,
...@@ -1912,9 +1910,6 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1912,9 +1910,6 @@ int tls_sw_recvmsg(struct sock *sk,
} }
pick_next_record: pick_next_record:
if (chunk > len)
chunk = len;
decrypted += chunk; decrypted += chunk;
len -= chunk; len -= chunk;
...@@ -2016,7 +2011,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2016,7 +2011,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
if (!skb) if (!skb)
goto splice_read_end; goto splice_read_end;
err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); err = decrypt_skb_update(sk, skb, NULL, &zc, false);
if (err < 0) { if (err < 0) {
tls_err_abort(sk, -EBADMSG); tls_err_abort(sk, -EBADMSG);
goto splice_read_end; goto splice_read_end;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment