Commit a9836336 authored by David S. Miller's avatar David S. Miller

Merge branch 'tls-Fix-issues-in-tls_device'

Boris Pismenny says:

====================
tls: Fix issues in tls_device

This series fixes issues encountered in tls_device code paths,
which were introduced recently.

Additionally, this series includes a fix for tls software only receive flow,
which causes corruption of payload received by user space applications.

This series was tested using the OpenSSL integration of KTLS -
https://github.com/mellan
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 7d827379 d069b780
...@@ -199,10 +199,6 @@ struct tls_offload_context_tx { ...@@ -199,10 +199,6 @@ struct tls_offload_context_tx {
(ALIGN(sizeof(struct tls_offload_context_tx), sizeof(void *)) + \ (ALIGN(sizeof(struct tls_offload_context_tx), sizeof(void *)) + \
TLS_DRIVER_STATE_SIZE) TLS_DRIVER_STATE_SIZE)
enum {
TLS_PENDING_CLOSED_RECORD
};
struct cipher_context { struct cipher_context {
char *iv; char *iv;
char *rec_seq; char *rec_seq;
...@@ -335,17 +331,14 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx, ...@@ -335,17 +331,14 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx,
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
int flags); int flags);
int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
int flags, long *timeo);
static inline struct tls_msg *tls_msg(struct sk_buff *skb) static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{ {
return (struct tls_msg *)strp_msg(skb); return (struct tls_msg *)strp_msg(skb);
} }
static inline bool tls_is_pending_closed_record(struct tls_context *ctx) static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
{ {
return test_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); return !!ctx->partially_sent_record;
} }
static inline int tls_complete_pending_work(struct sock *sk, static inline int tls_complete_pending_work(struct sock *sk,
...@@ -357,17 +350,12 @@ static inline int tls_complete_pending_work(struct sock *sk, ...@@ -357,17 +350,12 @@ static inline int tls_complete_pending_work(struct sock *sk,
if (unlikely(sk->sk_write_pending)) if (unlikely(sk->sk_write_pending))
rc = wait_on_pending_writer(sk, timeo); rc = wait_on_pending_writer(sk, timeo);
if (!rc && tls_is_pending_closed_record(ctx)) if (!rc && tls_is_partially_sent_record(ctx))
rc = tls_push_pending_closed_record(sk, ctx, flags, timeo); rc = tls_push_partial_record(sk, ctx, flags);
return rc; return rc;
} }
static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
{
return !!ctx->partially_sent_record;
}
static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx) static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
{ {
return tls_ctx->pending_open_record_frags; return tls_ctx->pending_open_record_frags;
...@@ -531,6 +519,9 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk) ...@@ -531,6 +519,9 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk)
return !!tls_sw_ctx_tx(ctx); return !!tls_sw_ctx_tx(ctx);
} }
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
static inline struct tls_offload_context_rx * static inline struct tls_offload_context_rx *
tls_offload_ctx_rx(const struct tls_context *tls_ctx) tls_offload_ctx_rx(const struct tls_context *tls_ctx)
{ {
......
...@@ -271,7 +271,6 @@ static int tls_push_record(struct sock *sk, ...@@ -271,7 +271,6 @@ static int tls_push_record(struct sock *sk,
list_add_tail(&record->list, &offload_ctx->records_list); list_add_tail(&record->list, &offload_ctx->records_list);
spin_unlock_irq(&offload_ctx->lock); spin_unlock_irq(&offload_ctx->lock);
offload_ctx->open_record = NULL; offload_ctx->open_record = NULL;
set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version); tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version);
for (i = 0; i < record->num_frags; i++) { for (i = 0; i < record->num_frags; i++) {
...@@ -368,9 +367,11 @@ static int tls_push_data(struct sock *sk, ...@@ -368,9 +367,11 @@ static int tls_push_data(struct sock *sk,
return -sk->sk_err; return -sk->sk_err;
timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); if (tls_is_partially_sent_record(tls_ctx)) {
rc = tls_push_partial_record(sk, tls_ctx, flags);
if (rc < 0) if (rc < 0)
return rc; return rc;
}
pfrag = sk_page_frag(sk); pfrag = sk_page_frag(sk);
...@@ -545,6 +546,23 @@ static int tls_device_push_pending_record(struct sock *sk, int flags) ...@@ -545,6 +546,23 @@ static int tls_device_push_pending_record(struct sock *sk, int flags)
return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
} }
void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
{
int rc = 0;
if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) {
gfp_t sk_allocation = sk->sk_allocation;
sk->sk_allocation = GFP_ATOMIC;
rc = tls_push_partial_record(sk, ctx,
MSG_DONTWAIT | MSG_NOSIGNAL);
sk->sk_allocation = sk_allocation;
}
if (!rc)
ctx->sk_write_space(sk);
}
void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn) void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
......
...@@ -209,23 +209,9 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, ...@@ -209,23 +209,9 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
return tls_push_sg(sk, ctx, sg, offset, flags); return tls_push_sg(sk, ctx, sg, offset, flags);
} }
int tls_push_pending_closed_record(struct sock *sk,
struct tls_context *tls_ctx,
int flags, long *timeo)
{
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
if (tls_is_partially_sent_record(tls_ctx) ||
!list_empty(&ctx->tx_list))
return tls_tx_records(sk, flags);
else
return tls_ctx->push_pending_record(sk, flags);
}
static void tls_write_space(struct sock *sk) static void tls_write_space(struct sock *sk)
{ {
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
/* If in_tcp_sendpages call lower protocol write space handler /* If in_tcp_sendpages call lower protocol write space handler
* to ensure we wake up any waiting operations there. For example * to ensure we wake up any waiting operations there. For example
...@@ -236,14 +222,12 @@ static void tls_write_space(struct sock *sk) ...@@ -236,14 +222,12 @@ static void tls_write_space(struct sock *sk)
return; return;
} }
/* Schedule the transmission if tx list is ready */ #ifdef CONFIG_TLS_DEVICE
if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) { if (ctx->tx_conf == TLS_HW)
/* Schedule the transmission */ tls_device_write_space(sk, ctx);
if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) else
schedule_delayed_work(&tx_ctx->tx_work.work, 0); #endif
} tls_sw_write_space(sk, ctx);
ctx->sk_write_space(sk);
} }
static void tls_ctx_free(struct tls_context *ctx) static void tls_ctx_free(struct tls_context *ctx)
......
...@@ -1467,13 +1467,16 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1467,13 +1467,16 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
int err = 0; int err = 0;
if (!ctx->decrypted) {
#ifdef CONFIG_TLS_DEVICE #ifdef CONFIG_TLS_DEVICE
err = tls_device_decrypted(sk, skb); err = tls_device_decrypted(sk, skb);
if (err < 0) if (err < 0)
return err; return err;
#endif #endif
/* Still not decrypted after tls_device */
if (!ctx->decrypted) { if (!ctx->decrypted) {
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async); err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
async);
if (err < 0) { if (err < 0) {
if (err == -EINPROGRESS) if (err == -EINPROGRESS)
tls_advance_record_sn(sk, &tls_ctx->rx, tls_advance_record_sn(sk, &tls_ctx->rx,
...@@ -1481,9 +1484,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1481,9 +1484,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
return err; return err;
} }
}
rxm->full_len -= padding_length(ctx, tls_ctx, skb); rxm->full_len -= padding_length(ctx, tls_ctx, skb);
rxm->offset += prot->prepend_size; rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size; rxm->full_len -= prot->overhead_size;
tls_advance_record_sn(sk, &tls_ctx->rx, version); tls_advance_record_sn(sk, &tls_ctx->rx, version);
...@@ -1693,7 +1696,8 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1693,7 +1696,8 @@ int tls_sw_recvmsg(struct sock *sk,
bool zc = false; bool zc = false;
int to_decrypt; int to_decrypt;
int chunk = 0; int chunk = 0;
bool async; bool async_capable;
bool async = false;
skb = tls_wait_data(sk, psock, flags, timeo, &err); skb = tls_wait_data(sk, psock, flags, timeo, &err);
if (!skb) { if (!skb) {
...@@ -1727,21 +1731,23 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1727,21 +1731,23 @@ int tls_sw_recvmsg(struct sock *sk,
/* Do not use async mode if record is non-data */ /* Do not use async mode if record is non-data */
if (ctx->control == TLS_RECORD_TYPE_DATA) if (ctx->control == TLS_RECORD_TYPE_DATA)
async = ctx->async_capable; async_capable = ctx->async_capable;
else else
async = false; async_capable = false;
err = decrypt_skb_update(sk, skb, &msg->msg_iter, err = decrypt_skb_update(sk, skb, &msg->msg_iter,
&chunk, &zc, async); &chunk, &zc, async_capable);
if (err < 0 && err != -EINPROGRESS) { if (err < 0 && err != -EINPROGRESS) {
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
goto recv_end; goto recv_end;
} }
if (err == -EINPROGRESS) if (err == -EINPROGRESS) {
async = true;
num_async++; num_async++;
else if (prot->version == TLS_1_3_VERSION) } else if (prot->version == TLS_1_3_VERSION) {
tlm->control = ctx->control; tlm->control = ctx->control;
}
/* If the type of records being processed is not known yet, /* If the type of records being processed is not known yet,
* set it to record type just dequeued. If it is already known, * set it to record type just dequeued. If it is already known,
...@@ -2126,6 +2132,19 @@ static void tx_work_handler(struct work_struct *work) ...@@ -2126,6 +2132,19 @@ static void tx_work_handler(struct work_struct *work)
release_sock(sk); release_sock(sk);
} }
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
{
struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
/* Schedule the transmission if tx list is ready */
if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) {
/* Schedule the transmission */
if (!test_and_set_bit(BIT_TX_SCHEDULED,
&tx_ctx->tx_bitmask))
schedule_delayed_work(&tx_ctx->tx_work.work, 0);
}
}
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment