Commit 78e563f2 authored by David S. Miller

Merge branch 'tls-fixes'

Jakub Kicinski says:

====================
net: tls: fix some issues with async encryption

valis reported a race on socket close, so I sat down to try to fix it.
I used Sabrina's async crypto debug patch to test... and in the process
ran into some of the same issues and created very similar fixes :(
I didn't realize how many of those patches had never been applied. Once I
found Sabrina's code [1], mine turned out to be so similar that I added
her Signed-off-by and Co-developed-by tags in a semi-haphazard way.

With this series in place, all expected tests pass with async crypto.
Sabrina had a few more fixes, but I'll leave those to her; things are
not crashing anymore.

[1] https://lore.kernel.org/netdev/cover.1694018970.git.sd@queasysnail.net/
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents 2fbdc5c6 ac437a51
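The recurring change in the diff below replaces the encrypt_compl_lock/decrypt_compl_lock
spinlocks and the async_notify flag with a pending counter that is initialized to 1:
submitters take one extra reference per in-flight request, completion callbacks drop their
reference with atomic_dec_and_test(), and a waiter temporarily gives up the initial bias
before sleeping on the completion, restoring it afterwards. The stand-alone userspace sketch
below mirrors that counting scheme with C11 atomics and pthreads purely for illustration;
only the idea (and the pending/complete/wait roles) corresponds to the kernel code, all
names and scaffolding in the sketch are hypothetical and not part of the patch.

/* Userspace sketch of the "pending counter biased to 1" pattern (illustrative
 * only): pthreads + C11 atomics stand in for the kernel's atomic_t,
 * complete() and crypto_wait_req().
 */
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>
#include <unistd.h>

struct completion {                     /* minimal stand-in for struct completion */
        pthread_mutex_t lock;
        pthread_cond_t cond;
        int done;
};

static struct completion comp = {
        PTHREAD_MUTEX_INITIALIZER, PTHREAD_COND_INITIALIZER, 0
};

/* Biased to 1: the counter can only hit zero while a waiter has dropped
 * the bias, so callbacks never signal a completion nobody is waiting for.
 */
static atomic_int pending = 1;

static void complete_one(struct completion *c)
{
        pthread_mutex_lock(&c->lock);
        c->done = 1;
        pthread_cond_signal(&c->cond);
        pthread_mutex_unlock(&c->lock);
}

static void wait_for_completion(struct completion *c)
{
        pthread_mutex_lock(&c->lock);
        while (!c->done)
                pthread_cond_wait(&c->cond, &c->lock);
        c->done = 0;                    /* re-arm, like reinit_completion() */
        pthread_mutex_unlock(&c->lock);
}

static void *async_callback(void *arg)  /* plays the role of the crypto callback */
{
        (void)arg;
        usleep(1000);                           /* pretend to encrypt */
        if (atomic_fetch_sub(&pending, 1) == 1) /* last reference dropped */
                complete_one(&comp);
        return NULL;
}

int main(void)
{
        pthread_t t[8];

        for (int i = 0; i < 8; i++) {
                atomic_fetch_add(&pending, 1);  /* submit: one ref per request */
                pthread_create(&t[i], NULL, async_callback, NULL);
        }

        /* waiter: drop the bias, sleep only if requests are still in
         * flight, then restore the bias so the context can be reused.
         */
        if (atomic_fetch_sub(&pending, 1) != 1)
                wait_for_completion(&comp);
        atomic_fetch_add(&pending, 1);

        printf("all pending requests completed\n");
        for (int i = 0; i < 8; i++)
                pthread_join(t[i], NULL);
        return 0;
}

The bias is also why init_ctx_tx() and init_ctx_rx() now start encrypt_pending and
decrypt_pending at 1, and why the new tls_encrypt_async_wait()/tls_decrypt_async_wait()
helpers re-increment the counter after waking up.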
@@ -97,9 +97,6 @@ struct tls_sw_context_tx {
         struct tls_rec *open_rec;
         struct list_head tx_list;
         atomic_t encrypt_pending;
-        /* protect crypto_wait with encrypt_pending */
-        spinlock_t encrypt_compl_lock;
-        int async_notify;
         u8 async_capable:1;
 #define BIT_TX_SCHEDULED        0
@@ -136,8 +133,6 @@ struct tls_sw_context_rx {
         struct tls_strparser strp;
         atomic_t decrypt_pending;
-        /* protect crypto_wait with decrypt_pending*/
-        spinlock_t decrypt_compl_lock;
         struct sk_buff_head async_hold;
         struct wait_queue_head wq;
 };
...
@@ -63,6 +63,7 @@ struct tls_decrypt_ctx {
         u8 iv[TLS_MAX_IV_SIZE];
         u8 aad[TLS_MAX_AAD_SIZE];
         u8 tail;
+        bool free_sgout;
         struct scatterlist sg[];
 };
@@ -187,7 +188,6 @@ static void tls_decrypt_done(void *data, int err)
         struct aead_request *aead_req = data;
         struct crypto_aead *aead = crypto_aead_reqtfm(aead_req);
         struct scatterlist *sgout = aead_req->dst;
-        struct scatterlist *sgin = aead_req->src;
         struct tls_sw_context_rx *ctx;
         struct tls_decrypt_ctx *dctx;
         struct tls_context *tls_ctx;
@@ -196,6 +196,17 @@ static void tls_decrypt_done(void *data, int err)
         struct sock *sk;
         int aead_size;
+        /* If requests get too backlogged crypto API returns -EBUSY and calls
+         * ->complete(-EINPROGRESS) immediately followed by ->complete(0)
+         * to make waiting for backlog to flush with crypto_wait_req() easier.
+         * First wait converts -EBUSY -> -EINPROGRESS, and the second one
+         * -EINPROGRESS -> 0.
+         * We have a single struct crypto_async_request per direction, this
+         * scheme doesn't help us, so just ignore the first ->complete().
+         */
+        if (err == -EINPROGRESS)
+                return;
         aead_size = sizeof(*aead_req) + crypto_aead_reqsize(aead);
         aead_size = ALIGN(aead_size, __alignof__(*dctx));
         dctx = (void *)((u8 *)aead_req + aead_size);
@@ -213,7 +224,7 @@ static void tls_decrypt_done(void *data, int err)
         }
         /* Free the destination pages if skb was not decrypted inplace */
-        if (sgout != sgin) {
+        if (dctx->free_sgout) {
                 /* Skip the first S/G entry as it points to AAD */
                 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
                         if (!sg)
@@ -224,10 +235,17 @@ static void tls_decrypt_done(void *data, int err)
         kfree(aead_req);
-        spin_lock_bh(&ctx->decrypt_compl_lock);
-        if (!atomic_dec_return(&ctx->decrypt_pending))
+        if (atomic_dec_and_test(&ctx->decrypt_pending))
                 complete(&ctx->async_wait.completion);
-        spin_unlock_bh(&ctx->decrypt_compl_lock);
+}
+static int tls_decrypt_async_wait(struct tls_sw_context_rx *ctx)
+{
+        if (!atomic_dec_and_test(&ctx->decrypt_pending))
+                crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+        atomic_inc(&ctx->decrypt_pending);
+        return ctx->async_wait.err;
 }
 static int tls_do_decryption(struct sock *sk,
@@ -253,6 +271,7 @@ static int tls_do_decryption(struct sock *sk,
                 aead_request_set_callback(aead_req,
                                           CRYPTO_TFM_REQ_MAY_BACKLOG,
                                           tls_decrypt_done, aead_req);
+                DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->decrypt_pending) < 1);
                 atomic_inc(&ctx->decrypt_pending);
         } else {
                 aead_request_set_callback(aead_req,
@@ -261,6 +280,10 @@ static int tls_do_decryption(struct sock *sk,
         }
         ret = crypto_aead_decrypt(aead_req);
+        if (ret == -EBUSY) {
+                ret = tls_decrypt_async_wait(ctx);
+                ret = ret ?: -EINPROGRESS;
+        }
         if (ret == -EINPROGRESS) {
                 if (darg->async)
                         return 0;
@@ -439,9 +462,10 @@ static void tls_encrypt_done(void *data, int err)
         struct tls_rec *rec = data;
         struct scatterlist *sge;
         struct sk_msg *msg_en;
-        bool ready = false;
         struct sock *sk;
-        int pending;
+        if (err == -EINPROGRESS) /* see the comment in tls_decrypt_done() */
+                return;
         msg_en = &rec->msg_encrypted;
@@ -476,23 +500,25 @@ static void tls_encrypt_done(void *data, int err)
                 /* If received record is at head of tx_list, schedule tx */
                 first_rec = list_first_entry(&ctx->tx_list,
                                              struct tls_rec, list);
-                if (rec == first_rec)
-                        ready = true;
+                if (rec == first_rec) {
+                        /* Schedule the transmission */
+                        if (!test_and_set_bit(BIT_TX_SCHEDULED,
+                                              &ctx->tx_bitmask))
+                                schedule_delayed_work(&ctx->tx_work.work, 1);
+                }
         }
-        spin_lock_bh(&ctx->encrypt_compl_lock);
-        pending = atomic_dec_return(&ctx->encrypt_pending);
-        if (!pending && ctx->async_notify)
+        if (atomic_dec_and_test(&ctx->encrypt_pending))
                 complete(&ctx->async_wait.completion);
-        spin_unlock_bh(&ctx->encrypt_compl_lock);
-        if (!ready)
-                return;
-        /* Schedule the transmission */
-        if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
-                schedule_delayed_work(&ctx->tx_work.work, 1);
+}
+static int tls_encrypt_async_wait(struct tls_sw_context_tx *ctx)
+{
+        if (!atomic_dec_and_test(&ctx->encrypt_pending))
+                crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+        atomic_inc(&ctx->encrypt_pending);
+        return ctx->async_wait.err;
 }
 static int tls_do_encryption(struct sock *sk,
@@ -541,9 +567,14 @@ static int tls_do_encryption(struct sock *sk,
         /* Add the record in tx_list */
         list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
+        DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->encrypt_pending) < 1);
         atomic_inc(&ctx->encrypt_pending);
         rc = crypto_aead_encrypt(aead_req);
+        if (rc == -EBUSY) {
+                rc = tls_encrypt_async_wait(ctx);
+                rc = rc ?: -EINPROGRESS;
+        }
         if (!rc || rc != -EINPROGRESS) {
                 atomic_dec(&ctx->encrypt_pending);
                 sge->offset -= prot->prepend_size;
@@ -984,7 +1015,6 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
         int num_zc = 0;
         int orig_size;
         int ret = 0;
-        int pending;
         if (!eor && (msg->msg_flags & MSG_EOR))
                 return -EINVAL;
@@ -1163,24 +1193,12 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
         if (!num_async) {
                 goto send_end;
         } else if (num_zc) {
-                /* Wait for pending encryptions to get completed */
-                spin_lock_bh(&ctx->encrypt_compl_lock);
-                ctx->async_notify = true;
-                pending = atomic_read(&ctx->encrypt_pending);
-                spin_unlock_bh(&ctx->encrypt_compl_lock);
-                if (pending)
-                        crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
-                else
-                        reinit_completion(&ctx->async_wait.completion);
-                /* There can be no concurrent accesses, since we have no
-                 * pending encrypt operations
-                 */
-                WRITE_ONCE(ctx->async_notify, false);
-                if (ctx->async_wait.err) {
-                        ret = ctx->async_wait.err;
+                int err;
+                /* Wait for pending encryptions to get completed */
+                err = tls_encrypt_async_wait(ctx);
+                if (err) {
+                        ret = err;
                         copied = 0;
                 }
         }
@@ -1229,7 +1247,6 @@ void tls_sw_splice_eof(struct socket *sock)
         ssize_t copied = 0;
         bool retrying = false;
         int ret = 0;
-        int pending;
         if (!ctx->open_rec)
                 return;
@@ -1264,22 +1281,7 @@ void tls_sw_splice_eof(struct socket *sock)
         }
         /* Wait for pending encryptions to get completed */
-        spin_lock_bh(&ctx->encrypt_compl_lock);
-        ctx->async_notify = true;
-        pending = atomic_read(&ctx->encrypt_pending);
-        spin_unlock_bh(&ctx->encrypt_compl_lock);
-        if (pending)
-                crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
-        else
-                reinit_completion(&ctx->async_wait.completion);
-        /* There can be no concurrent accesses, since we have no pending
-         * encrypt operations
-         */
-        WRITE_ONCE(ctx->async_notify, false);
-        if (ctx->async_wait.err)
+        if (tls_encrypt_async_wait(ctx))
                 goto unlock;
         /* Transmit if any encryptions have completed */
@@ -1581,6 +1583,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
         } else if (out_sg) {
                 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
         }
+        dctx->free_sgout = !!pages;
         /* Prepare and submit AEAD request */
         err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
@@ -2109,16 +2112,10 @@ int tls_sw_recvmsg(struct sock *sk,
 recv_end:
         if (async) {
-                int ret, pending;
+                int ret;
                 /* Wait for all previously submitted records to be decrypted */
-                spin_lock_bh(&ctx->decrypt_compl_lock);
-                reinit_completion(&ctx->async_wait.completion);
-                pending = atomic_read(&ctx->decrypt_pending);
-                spin_unlock_bh(&ctx->decrypt_compl_lock);
-                ret = 0;
-                if (pending)
-                        ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+                ret = tls_decrypt_async_wait(ctx);
                 __skb_queue_purge(&ctx->async_hold);
                 if (ret) {
@@ -2135,7 +2132,6 @@ int tls_sw_recvmsg(struct sock *sk,
                 else
                         err = process_rx_list(ctx, msg, &control, 0,
                                               async_copy_bytes, is_peek);
-                decrypted += max(err, 0);
         }
         copied += decrypted;
@@ -2435,16 +2431,9 @@ void tls_sw_release_resources_tx(struct sock *sk)
         struct tls_context *tls_ctx = tls_get_ctx(sk);
         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
         struct tls_rec *rec, *tmp;
-        int pending;
         /* Wait for any pending async encryptions to complete */
-        spin_lock_bh(&ctx->encrypt_compl_lock);
-        ctx->async_notify = true;
-        pending = atomic_read(&ctx->encrypt_pending);
-        spin_unlock_bh(&ctx->encrypt_compl_lock);
-        if (pending)
-                crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+        tls_encrypt_async_wait(ctx);
         tls_tx_records(sk, -1);
@@ -2607,7 +2596,7 @@ static struct tls_sw_context_tx *init_ctx_tx(struct tls_context *ctx, struct soc
         }
         crypto_init_wait(&sw_ctx_tx->async_wait);
-        spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
+        atomic_set(&sw_ctx_tx->encrypt_pending, 1);
         INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
         INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
         sw_ctx_tx->tx_work.sk = sk;
@@ -2628,7 +2617,7 @@ static struct tls_sw_context_rx *init_ctx_rx(struct tls_context *ctx)
         }
         crypto_init_wait(&sw_ctx_rx->async_wait);
-        spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
+        atomic_set(&sw_ctx_rx->decrypt_pending, 1);
         init_waitqueue_head(&sw_ctx_rx->wq);
         skb_queue_head_init(&sw_ctx_rx->rx_list);
         skb_queue_head_init(&sw_ctx_rx->async_hold);
...
@@ -1002,12 +1002,12 @@ TEST_F(tls, recv_partial)
         memset(recv_mem, 0, sizeof(recv_mem));
         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
-        EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
-                       MSG_WAITALL), -1);
+        EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
+                       MSG_WAITALL), strlen(test_str_first));
         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
         memset(recv_mem, 0, sizeof(recv_mem));
-        EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
-                       MSG_WAITALL), -1);
+        EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
+                       MSG_WAITALL), strlen(test_str_second));
         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
                   0);
 }