Commit 692d7b5d authored by Vakul Garg's avatar Vakul Garg Committed by David S. Miller

tls: Fix recvmsg() to be able to peek across multiple records

This fixes recvmsg() to be able to peek across multiple tls records.
Without this patch, the tls's selftests test case
'recv_peek_large_buf_mult_recs' fails. Each tls receive context now
maintains a 'rx_list' to retain incoming skb carrying tls records. If a
tls record needs to be retained e.g. for peek case or for the case when
the buffer passed to recvmsg() has a length smaller than decrypted
record length, then it is added to 'rx_list'. Additionally, records are
added in 'rx_list' if the crypto operation runs in async mode. The
records are dequeued from 'rx_list' after the decrypted data is consumed
by copying into the buffer passed to recvmsg(). In case, the MSG_PEEK
flag is used in recvmsg(), then records are not consumed or removed
from the 'rx_list'.
Signed-off-by: default avatarVakul Garg <vakul.garg@nxp.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent fb73d620
...@@ -145,12 +145,13 @@ struct tls_sw_context_tx { ...@@ -145,12 +145,13 @@ struct tls_sw_context_tx {
struct tls_sw_context_rx { struct tls_sw_context_rx {
struct crypto_aead *aead_recv; struct crypto_aead *aead_recv;
struct crypto_wait async_wait; struct crypto_wait async_wait;
struct strparser strp; struct strparser strp;
struct sk_buff_head rx_list; /* list of decrypted 'data' records */
void (*saved_data_ready)(struct sock *sk); void (*saved_data_ready)(struct sock *sk);
struct sk_buff *recv_pkt; struct sk_buff *recv_pkt;
u8 control; u8 control;
int async_capable;
bool decrypted; bool decrypted;
atomic_t decrypt_pending; atomic_t decrypt_pending;
bool async_notify; bool async_notify;
......
...@@ -124,6 +124,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) ...@@ -124,6 +124,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
{ {
struct aead_request *aead_req = (struct aead_request *)req; struct aead_request *aead_req = (struct aead_request *)req;
struct scatterlist *sgout = aead_req->dst; struct scatterlist *sgout = aead_req->dst;
struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx; struct tls_sw_context_rx *ctx;
struct tls_context *tls_ctx; struct tls_context *tls_ctx;
struct scatterlist *sg; struct scatterlist *sg;
...@@ -134,12 +135,16 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) ...@@ -134,12 +135,16 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
skb = (struct sk_buff *)req->data; skb = (struct sk_buff *)req->data;
tls_ctx = tls_get_ctx(skb->sk); tls_ctx = tls_get_ctx(skb->sk);
ctx = tls_sw_ctx_rx(tls_ctx); ctx = tls_sw_ctx_rx(tls_ctx);
pending = atomic_dec_return(&ctx->decrypt_pending);
/* Propagate if there was an err */ /* Propagate if there was an err */
if (err) { if (err) {
ctx->async_wait.err = err; ctx->async_wait.err = err;
tls_err_abort(skb->sk, err); tls_err_abort(skb->sk, err);
} else {
struct strp_msg *rxm = strp_msg(skb);
rxm->offset += tls_ctx->rx.prepend_size;
rxm->full_len -= tls_ctx->rx.overhead_size;
} }
/* After using skb->sk to propagate sk through crypto async callback /* After using skb->sk to propagate sk through crypto async callback
...@@ -147,18 +152,21 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) ...@@ -147,18 +152,21 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
*/ */
skb->sk = NULL; skb->sk = NULL;
/* Release the skb, pages and memory allocated for crypto req */
kfree_skb(skb);
/* Free the destination pages if skb was not decrypted inplace */
if (sgout != sgin) {
/* Skip the first S/G entry as it points to AAD */ /* Skip the first S/G entry as it points to AAD */
for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
if (!sg) if (!sg)
break; break;
put_page(sg_page(sg)); put_page(sg_page(sg));
} }
}
kfree(aead_req); kfree(aead_req);
pending = atomic_dec_return(&ctx->decrypt_pending);
if (!pending && READ_ONCE(ctx->async_notify)) if (!pending && READ_ONCE(ctx->async_notify))
complete(&ctx->async_wait.completion); complete(&ctx->async_wait.completion);
} }
...@@ -1271,7 +1279,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from, ...@@ -1271,7 +1279,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
static int decrypt_internal(struct sock *sk, struct sk_buff *skb, static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct iov_iter *out_iov, struct iov_iter *out_iov,
struct scatterlist *out_sg, struct scatterlist *out_sg,
int *chunk, bool *zc) int *chunk, bool *zc, bool async)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
...@@ -1371,13 +1379,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1371,13 +1379,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
fallback_to_reg_recv: fallback_to_reg_recv:
sgout = sgin; sgout = sgin;
pages = 0; pages = 0;
*chunk = 0; *chunk = data_len;
*zc = false; *zc = false;
} }
/* Prepare and submit AEAD request */ /* Prepare and submit AEAD request */
err = tls_do_decryption(sk, skb, sgin, sgout, iv, err = tls_do_decryption(sk, skb, sgin, sgout, iv,
data_len, aead_req, *zc); data_len, aead_req, async);
if (err == -EINPROGRESS) if (err == -EINPROGRESS)
return err; return err;
...@@ -1390,7 +1398,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1390,7 +1398,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
} }
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
struct iov_iter *dest, int *chunk, bool *zc) struct iov_iter *dest, int *chunk, bool *zc,
bool async)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
...@@ -1403,7 +1412,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1403,7 +1412,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
return err; return err;
#endif #endif
if (!ctx->decrypted) { if (!ctx->decrypted) {
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc); err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
if (err < 0) { if (err < 0) {
if (err == -EINPROGRESS) if (err == -EINPROGRESS)
tls_advance_record_sn(sk, &tls_ctx->rx); tls_advance_record_sn(sk, &tls_ctx->rx);
...@@ -1429,7 +1438,7 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb, ...@@ -1429,7 +1438,7 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb,
bool zc = true; bool zc = true;
int chunk; int chunk;
return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc); return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
} }
static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
...@@ -1456,6 +1465,77 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, ...@@ -1456,6 +1465,77 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
return true; return true;
} }
/* This function traverses the rx_list in tls receive context to copies the
* decrypted data records into the buffer provided by caller zero copy is not
* true. Further, the records are removed from the rx_list if it is not a peek
* case and the record has been consumed completely.
*/
static int process_rx_list(struct tls_sw_context_rx *ctx,
struct msghdr *msg,
size_t skip,
size_t len,
bool zc,
bool is_peek)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
ssize_t copied = 0;
while (skip && skb) {
struct strp_msg *rxm = strp_msg(skb);
if (skip < rxm->full_len)
break;
skip = skip - rxm->full_len;
skb = skb_peek_next(skb, &ctx->rx_list);
}
while (len && skb) {
struct sk_buff *next_skb;
struct strp_msg *rxm = strp_msg(skb);
int chunk = min_t(unsigned int, rxm->full_len - skip, len);
if (!zc || (rxm->full_len - skip) > len) {
int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk);
if (err < 0)
return err;
}
len = len - chunk;
copied = copied + chunk;
/* Consume the data from record if it is non-peek case*/
if (!is_peek) {
rxm->offset = rxm->offset + chunk;
rxm->full_len = rxm->full_len - chunk;
/* Return if there is unconsumed data in the record */
if (rxm->full_len - skip)
break;
}
/* The remaining skip-bytes must lie in 1st record in rx_list.
* So from the 2nd record, 'skip' should be 0.
*/
skip = 0;
if (msg)
msg->msg_flags |= MSG_EOR;
next_skb = skb_peek_next(skb, &ctx->rx_list);
if (!is_peek) {
skb_unlink(skb, &ctx->rx_list);
kfree_skb(skb);
}
skb = next_skb;
}
return copied;
}
int tls_sw_recvmsg(struct sock *sk, int tls_sw_recvmsg(struct sock *sk,
struct msghdr *msg, struct msghdr *msg,
size_t len, size_t len,
...@@ -1466,7 +1546,8 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1466,7 +1546,8 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock; struct sk_psock *psock;
unsigned char control; unsigned char control = 0;
ssize_t decrypted = 0;
struct strp_msg *rxm; struct strp_msg *rxm;
struct sk_buff *skb; struct sk_buff *skb;
ssize_t copied = 0; ssize_t copied = 0;
...@@ -1474,6 +1555,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1474,6 +1555,7 @@ int tls_sw_recvmsg(struct sock *sk,
int target, err = 0; int target, err = 0;
long timeo; long timeo;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK;
int num_async = 0; int num_async = 0;
flags |= nonblock; flags |= nonblock;
...@@ -1484,11 +1566,28 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1484,11 +1566,28 @@ int tls_sw_recvmsg(struct sock *sk,
psock = sk_psock_get(sk); psock = sk_psock_get(sk);
lock_sock(sk); lock_sock(sk);
/* Process pending decrypted records. It must be non-zero-copy */
err = process_rx_list(ctx, msg, 0, len, false, is_peek);
if (err < 0) {
tls_err_abort(sk, err);
goto end;
} else {
copied = err;
}
len = len - copied;
if (len) {
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
} else {
goto recv_end;
}
do { do {
bool zc = false; bool retain_skb = false;
bool async = false; bool async = false;
bool zc = false;
int to_decrypt;
int chunk = 0; int chunk = 0;
skb = tls_wait_data(sk, psock, flags, timeo, &err); skb = tls_wait_data(sk, psock, flags, timeo, &err);
...@@ -1498,7 +1597,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1498,7 +1597,7 @@ int tls_sw_recvmsg(struct sock *sk,
msg, len, flags); msg, len, flags);
if (ret > 0) { if (ret > 0) {
copied += ret; decrypted += ret;
len -= ret; len -= ret;
continue; continue;
} }
...@@ -1525,15 +1624,13 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1525,15 +1624,13 @@ int tls_sw_recvmsg(struct sock *sk,
goto recv_end; goto recv_end;
} }
if (!ctx->decrypted) { to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
if (!is_kvec && to_copy <= len && if (to_decrypt <= len && !is_kvec && !is_peek)
likely(!(flags & MSG_PEEK)))
zc = true; zc = true;
err = decrypt_skb_update(sk, skb, &msg->msg_iter, err = decrypt_skb_update(sk, skb, &msg->msg_iter,
&chunk, &zc); &chunk, &zc, ctx->async_capable);
if (err < 0 && err != -EINPROGRESS) { if (err < 0 && err != -EINPROGRESS) {
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
goto recv_end; goto recv_end;
...@@ -1543,29 +1640,39 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1543,29 +1640,39 @@ int tls_sw_recvmsg(struct sock *sk,
async = true; async = true;
num_async++; num_async++;
goto pick_next_record; goto pick_next_record;
} } else {
ctx->decrypted = true;
}
if (!zc) { if (!zc) {
chunk = min_t(unsigned int, rxm->full_len, len); if (rxm->full_len > len) {
retain_skb = true;
chunk = len;
} else {
chunk = rxm->full_len;
}
err = skb_copy_datagram_msg(skb, rxm->offset, msg, err = skb_copy_datagram_msg(skb, rxm->offset,
chunk); msg, chunk);
if (err < 0) if (err < 0)
goto recv_end; goto recv_end;
if (!is_peek) {
rxm->offset = rxm->offset + chunk;
rxm->full_len = rxm->full_len - chunk;
}
}
} }
pick_next_record: pick_next_record:
copied += chunk; if (chunk > len)
chunk = len;
decrypted += chunk;
len -= chunk; len -= chunk;
if (likely(!(flags & MSG_PEEK))) {
u8 control = ctx->control;
/* For async, drop current skb reference */ /* For async or peek case, queue the current skb */
if (async) if (async || is_peek || retain_skb) {
skb_queue_tail(&ctx->rx_list, skb);
skb = NULL; skb = NULL;
}
if (tls_sw_advance_skb(sk, skb, chunk)) { if (tls_sw_advance_skb(sk, skb, chunk)) {
/* Return full control message to /* Return full control message to
...@@ -1573,22 +1680,14 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1573,22 +1680,14 @@ int tls_sw_recvmsg(struct sock *sk,
* another message type * another message type
*/ */
msg->msg_flags |= MSG_EOR; msg->msg_flags |= MSG_EOR;
if (control != TLS_RECORD_TYPE_DATA) if (ctx->control != TLS_RECORD_TYPE_DATA)
goto recv_end; goto recv_end;
} else { } else {
break; break;
} }
} else {
/* MSG_PEEK right now cannot look beyond current skb
* from strparser, meaning we cannot advance skb here
* and thus unpause strparser since we'd loose original
* one.
*/
break;
}
/* If we have a new message from strparser, continue now. */ /* If we have a new message from strparser, continue now. */
if (copied >= target && !ctx->recv_pkt) if (decrypted >= target && !ctx->recv_pkt)
break; break;
} while (len); } while (len);
...@@ -1602,13 +1701,33 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1602,13 +1701,33 @@ int tls_sw_recvmsg(struct sock *sk,
/* one of async decrypt failed */ /* one of async decrypt failed */
tls_err_abort(sk, err); tls_err_abort(sk, err);
copied = 0; copied = 0;
decrypted = 0;
goto end;
} }
} else { } else {
reinit_completion(&ctx->async_wait.completion); reinit_completion(&ctx->async_wait.completion);
} }
WRITE_ONCE(ctx->async_notify, false); WRITE_ONCE(ctx->async_notify, false);
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
err = process_rx_list(ctx, msg, copied,
decrypted, false, is_peek);
else
err = process_rx_list(ctx, msg, 0,
decrypted, true, is_peek);
if (err < 0) {
tls_err_abort(sk, err);
copied = 0;
goto end;
}
WARN_ON(decrypted != err);
} }
copied += decrypted;
end:
release_sock(sk); release_sock(sk);
if (psock) if (psock)
sk_psock_put(sk, psock); sk_psock_put(sk, psock);
...@@ -1645,7 +1764,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -1645,7 +1764,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
} }
if (!ctx->decrypted) { if (!ctx->decrypted) {
err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc); err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
if (err < 0) { if (err < 0) {
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
...@@ -1832,6 +1951,7 @@ void tls_sw_release_resources_rx(struct sock *sk) ...@@ -1832,6 +1951,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
if (ctx->aead_recv) { if (ctx->aead_recv) {
kfree_skb(ctx->recv_pkt); kfree_skb(ctx->recv_pkt);
ctx->recv_pkt = NULL; ctx->recv_pkt = NULL;
skb_queue_purge(&ctx->rx_list);
crypto_free_aead(ctx->aead_recv); crypto_free_aead(ctx->aead_recv);
strp_stop(&ctx->strp); strp_stop(&ctx->strp);
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
...@@ -1881,6 +2001,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -1881,6 +2001,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
struct crypto_aead **aead; struct crypto_aead **aead;
struct strp_callbacks cb; struct strp_callbacks cb;
u16 nonce_size, tag_size, iv_size, rec_seq_size; u16 nonce_size, tag_size, iv_size, rec_seq_size;
struct crypto_tfm *tfm;
char *iv, *rec_seq; char *iv, *rec_seq;
int rc = 0; int rc = 0;
...@@ -1927,6 +2048,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -1927,6 +2048,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
crypto_init_wait(&sw_ctx_rx->async_wait); crypto_init_wait(&sw_ctx_rx->async_wait);
crypto_info = &ctx->crypto_recv.info; crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx; cctx = &ctx->rx;
skb_queue_head_init(&sw_ctx_rx->rx_list);
aead = &sw_ctx_rx->aead_recv; aead = &sw_ctx_rx->aead_recv;
} }
...@@ -1994,6 +2116,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -1994,6 +2116,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
goto free_aead; goto free_aead;
if (sw_ctx_rx) { if (sw_ctx_rx) {
tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
sw_ctx_rx->async_capable =
tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
/* Set up strparser */ /* Set up strparser */
memset(&cb, 0, sizeof(cb)); memset(&cb, 0, sizeof(cb));
cb.rcv_msg = tls_queue; cb.rcv_msg = tls_queue;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment