Commit 6bd116c8 authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller

tls: rx: return the decrypted skb via darg

Instead of using ctx->recv_pkt after decryption read the skb
from darg.skb. This moves the decision of what the "output skb"
is to the decrypt handlers. For now after decrypt handler returns
successfully ctx->recv_pkt is simply moved to darg.skb, but it
will change soon.

Note that tls_decrypt_sg() cannot clear the ctx->recv_pkt
because it gets called to re-encrypt (i.e. by the device offload).
So we need an awkward temporary if() in tls_rx_one_record().
Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 541cc48b
...@@ -47,9 +47,13 @@ ...@@ -47,9 +47,13 @@
#include "tls.h" #include "tls.h"
struct tls_decrypt_arg { struct tls_decrypt_arg {
struct_group(inargs,
bool zc; bool zc;
bool async; bool async;
u8 tail; u8 tail;
);
struct sk_buff *skb;
}; };
struct tls_decrypt_ctx { struct tls_decrypt_ctx {
...@@ -1412,6 +1416,7 @@ static int tls_setup_from_iter(struct iov_iter *from, ...@@ -1412,6 +1416,7 @@ static int tls_setup_from_iter(struct iov_iter *from,
* ------------------------------------------------------------------- * -------------------------------------------------------------------
* zc | Zero-copy decrypt allowed | Zero-copy performed * zc | Zero-copy decrypt allowed | Zero-copy performed
* async | Async decrypt allowed | Async crypto used / in progress * async | Async decrypt allowed | Async crypto used / in progress
* skb | * | Output skb
*/ */
/* This function decrypts the input skb into either out_iov or in out_sg /* This function decrypts the input skb into either out_iov or in out_sg
...@@ -1551,12 +1556,17 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, ...@@ -1551,12 +1556,17 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
/* Prepare and submit AEAD request */ /* Prepare and submit AEAD request */
err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv, err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
data_len + prot->tail_size, aead_req, darg); data_len + prot->tail_size, aead_req, darg);
if (err)
goto exit_free_pages;
darg->skb = tls_strp_msg(ctx);
if (darg->async) if (darg->async)
return 0; return 0;
if (prot->tail_size) if (prot->tail_size)
darg->tail = dctx->tail; darg->tail = dctx->tail;
exit_free_pages:
/* Release the pages in case iov was mapped to pages */ /* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--) for (; pages > 0; pages--)
put_page(sg_page(&sgout[pages])); put_page(sg_page(&sgout[pages]));
...@@ -1569,6 +1579,7 @@ static int ...@@ -1569,6 +1579,7 @@ static int
tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
struct tls_decrypt_arg *darg) struct tls_decrypt_arg *darg)
{ {
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
int err; int err;
if (tls_ctx->rx_conf != TLS_HW) if (tls_ctx->rx_conf != TLS_HW)
...@@ -1580,6 +1591,8 @@ tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, ...@@ -1580,6 +1591,8 @@ tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
darg->zc = false; darg->zc = false;
darg->async = false; darg->async = false;
darg->skb = tls_strp_msg(ctx);
ctx->recv_pkt = NULL;
return 1; return 1;
} }
...@@ -1604,8 +1617,11 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, ...@@ -1604,8 +1617,11 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
return err; return err;
} }
if (darg->async) if (darg->async) {
if (darg->skb == ctx->recv_pkt)
ctx->recv_pkt = NULL;
goto decrypt_next; goto decrypt_next;
}
/* If opportunistic TLS 1.3 ZC failed retry without ZC */ /* If opportunistic TLS 1.3 ZC failed retry without ZC */
if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION && if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
darg->tail != TLS_RECORD_TYPE_DATA)) { darg->tail != TLS_RECORD_TYPE_DATA)) {
...@@ -1616,12 +1632,17 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, ...@@ -1616,12 +1632,17 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
return tls_rx_one_record(sk, dest, darg); return tls_rx_one_record(sk, dest, darg);
} }
if (darg->skb == ctx->recv_pkt)
ctx->recv_pkt = NULL;
decrypt_done: decrypt_done:
pad = tls_padding_length(prot, ctx->recv_pkt, darg); pad = tls_padding_length(prot, darg->skb, darg);
if (pad < 0) if (pad < 0) {
consume_skb(darg->skb);
return pad; return pad;
}
rxm = strp_msg(ctx->recv_pkt); rxm = strp_msg(darg->skb);
rxm->full_len -= pad; rxm->full_len -= pad;
rxm->offset += prot->prepend_size; rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size; rxm->full_len -= prot->overhead_size;
...@@ -1663,6 +1684,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, ...@@ -1663,6 +1684,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
static void tls_rx_rec_done(struct tls_sw_context_rx *ctx) static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
{ {
consume_skb(ctx->recv_pkt);
ctx->recv_pkt = NULL; ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp); __strp_unpause(&ctx->strp);
} }
...@@ -1872,7 +1894,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1872,7 +1894,7 @@ int tls_sw_recvmsg(struct sock *sk,
ctx->zc_capable; ctx->zc_capable;
decrypted = 0; decrypted = 0;
while (len && (decrypted + copied < target || ctx->recv_pkt)) { while (len && (decrypted + copied < target || ctx->recv_pkt)) {
struct tls_decrypt_arg darg = {}; struct tls_decrypt_arg darg;
int to_decrypt, chunk; int to_decrypt, chunk;
err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo); err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
...@@ -1889,9 +1911,10 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1889,9 +1911,10 @@ int tls_sw_recvmsg(struct sock *sk,
goto recv_end; goto recv_end;
} }
skb = ctx->recv_pkt; memset(&darg.inargs, 0, sizeof(darg.inargs));
rxm = strp_msg(skb);
tlm = tls_msg(skb); rxm = strp_msg(ctx->recv_pkt);
tlm = tls_msg(ctx->recv_pkt);
to_decrypt = rxm->full_len - prot->overhead_size; to_decrypt = rxm->full_len - prot->overhead_size;
...@@ -1911,6 +1934,10 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1911,6 +1934,10 @@ int tls_sw_recvmsg(struct sock *sk,
goto recv_end; goto recv_end;
} }
skb = darg.skb;
rxm = strp_msg(skb);
tlm = tls_msg(skb);
async |= darg.async; async |= darg.async;
/* If the type of records being processed is not known yet, /* If the type of records being processed is not known yet,
...@@ -2051,21 +2078,23 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2051,21 +2078,23 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
if (!skb_queue_empty(&ctx->rx_list)) { if (!skb_queue_empty(&ctx->rx_list)) {
skb = __skb_dequeue(&ctx->rx_list); skb = __skb_dequeue(&ctx->rx_list);
} else { } else {
struct tls_decrypt_arg darg = {}; struct tls_decrypt_arg darg;
err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
timeo); timeo);
if (err <= 0) if (err <= 0)
goto splice_read_end; goto splice_read_end;
memset(&darg.inargs, 0, sizeof(darg.inargs));
err = tls_rx_one_record(sk, NULL, &darg); err = tls_rx_one_record(sk, NULL, &darg);
if (err < 0) { if (err < 0) {
tls_err_abort(sk, -EBADMSG); tls_err_abort(sk, -EBADMSG);
goto splice_read_end; goto splice_read_end;
} }
skb = ctx->recv_pkt;
tls_rx_rec_done(ctx); tls_rx_rec_done(ctx);
skb = darg.skb;
} }
rxm = strp_msg(skb); rxm = strp_msg(skb);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment