Commit d3b18ad3 authored by John Fastabend's avatar John Fastabend Committed by Alexei Starovoitov

tls: add bpf support to sk_msg handling

This work adds BPF sk_msg verdict program support to kTLS
allowing BPF and kTLS to be combined together. Previously kTLS
and sk_msg verdict programs were mutually exclusive in the
ULP layer which created challenges for the orchestrator when
trying to apply TCP based policy, for example. To resolve this,
leveraging the work from previous patches that consolidates
the use of sk_msg, we can finally enable BPF sk_msg verdict
programs so they continue to run after the kTLS socket is
created. No change in behavior when kTLS is not used in
combination with BPF, the kselftest suite for kTLS also runs
successfully.

Joint work with Daniel.
Signed-off-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent 924ad65e
...@@ -29,7 +29,11 @@ struct sk_msg_sg { ...@@ -29,7 +29,11 @@ struct sk_msg_sg {
u32 size; u32 size;
u32 copybreak; u32 copybreak;
bool copy[MAX_MSG_FRAGS]; bool copy[MAX_MSG_FRAGS];
struct scatterlist data[MAX_MSG_FRAGS]; /* The extra element is used for chaining the front and sections when
* the list becomes partitioned (e.g. end < start). The crypto APIs
* require the chaining.
*/
struct scatterlist data[MAX_MSG_FRAGS + 1];
}; };
struct sk_msg { struct sk_msg {
...@@ -112,6 +116,7 @@ void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, ...@@ -112,6 +116,7 @@ void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
u32 bytes); u32 bytes);
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes); void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes); struct sk_msg *msg, u32 bytes);
...@@ -161,8 +166,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg) ...@@ -161,8 +166,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)
static inline void sk_msg_init(struct sk_msg *msg) static inline void sk_msg_init(struct sk_msg *msg)
{ {
BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
memset(msg, 0, sizeof(*msg)); memset(msg, 0, sizeof(*msg));
sg_init_marker(msg->sg.data, ARRAY_SIZE(msg->sg.data)); sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
} }
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src, static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
...@@ -174,6 +180,12 @@ static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src, ...@@ -174,6 +180,12 @@ static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
src->sg.data[which].offset += size; src->sg.data[which].offset += size;
} }
static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
{
memcpy(dst, src, sizeof(*src));
sk_msg_init(src);
}
static inline u32 sk_msg_elem_used(const struct sk_msg *msg) static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{ {
return msg->sg.end >= msg->sg.start ? return msg->sg.end >= msg->sg.start ?
...@@ -229,6 +241,26 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page, ...@@ -229,6 +241,26 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
sk_msg_iter_next(msg, end); sk_msg_iter_next(msg, end);
} }
static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
{
do {
msg->sg.copy[i] = copy_state;
sk_msg_iter_var_next(i);
if (i == msg->sg.end)
break;
} while (1);
}
static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
{
sk_msg_sg_copy(msg, start, true);
}
static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
{
sk_msg_sg_copy(msg, start, false);
}
static inline struct sk_psock *sk_psock(const struct sock *sk) static inline struct sk_psock *sk_psock(const struct sock *sk)
{ {
return rcu_dereference_sk_user_data(sk); return rcu_dereference_sk_user_data(sk);
...@@ -245,6 +277,11 @@ static inline void sk_psock_queue_msg(struct sk_psock *psock, ...@@ -245,6 +277,11 @@ static inline void sk_psock_queue_msg(struct sk_psock *psock,
list_add_tail(&msg->list, &psock->ingress_msg); list_add_tail(&msg->list, &psock->ingress_msg);
} }
static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
{
return psock ? list_empty(&psock->ingress_msg) : true;
}
static inline void sk_psock_report_error(struct sk_psock *psock, int err) static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{ {
struct sock *sk = psock->sk; struct sock *sk = psock->sk;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved. * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
* Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved. * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
* Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved. * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
* Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
* *
* This software is available to you under a choice of one of two * This software is available to you under a choice of one of two
* licenses. You may choose to be licensed under the terms of the GNU * licenses. You may choose to be licensed under the terms of the GNU
...@@ -258,21 +259,58 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) ...@@ -258,21 +259,58 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
return sk_msg_clone(sk, msg_pl, msg_en, skip, len); return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
} }
static void tls_free_open_rec(struct sock *sk) static struct tls_rec *tls_get_rec(struct sock *sk)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_pl, *msg_en;
struct tls_rec *rec;
int mem_size;
/* Return if there is no open record */ mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
rec = kzalloc(mem_size, sk->sk_allocation);
if (!rec) if (!rec)
return; return NULL;
msg_pl = &rec->msg_plaintext;
msg_en = &rec->msg_encrypted;
sk_msg_init(msg_pl);
sk_msg_init(msg_en);
sg_init_table(rec->sg_aead_in, 2);
sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
sizeof(rec->aad_space));
sg_unmark_end(&rec->sg_aead_in[1]);
sg_init_table(rec->sg_aead_out, 2);
sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
sizeof(rec->aad_space));
sg_unmark_end(&rec->sg_aead_out[1]);
return rec;
}
static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
{
sk_msg_free(sk, &rec->msg_encrypted); sk_msg_free(sk, &rec->msg_encrypted);
sk_msg_free(sk, &rec->msg_plaintext); sk_msg_free(sk, &rec->msg_plaintext);
kfree(rec); kfree(rec);
} }
static void tls_free_open_rec(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec;
if (rec) {
tls_free_rec(sk, rec);
ctx->open_rec = NULL;
}
}
int tls_tx_records(struct sock *sk, int flags) int tls_tx_records(struct sock *sk, int flags)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
...@@ -439,16 +477,135 @@ static int tls_do_encryption(struct sock *sk, ...@@ -439,16 +477,135 @@ static int tls_do_encryption(struct sock *sk,
return rc; return rc;
} }
static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
struct tls_rec **to, struct sk_msg *msg_opl,
struct sk_msg *msg_oen, u32 split_point,
u32 tx_overhead_size, u32 *orig_end)
{
u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
struct scatterlist *sge, *osge, *nsge;
u32 orig_size = msg_opl->sg.size;
struct scatterlist tmp = { };
struct sk_msg *msg_npl;
struct tls_rec *new;
int ret;
new = tls_get_rec(sk);
if (!new)
return -ENOMEM;
ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
tx_overhead_size, 0);
if (ret < 0) {
tls_free_rec(sk, new);
return ret;
}
*orig_end = msg_opl->sg.end;
i = msg_opl->sg.start;
sge = sk_msg_elem(msg_opl, i);
while (apply && sge->length) {
if (sge->length > apply) {
u32 len = sge->length - apply;
get_page(sg_page(sge));
sg_set_page(&tmp, sg_page(sge), len,
sge->offset + apply);
sge->length = apply;
bytes += apply;
apply = 0;
} else {
apply -= sge->length;
bytes += sge->length;
}
sk_msg_iter_var_next(i);
if (i == msg_opl->sg.end)
break;
sge = sk_msg_elem(msg_opl, i);
}
msg_opl->sg.end = i;
msg_opl->sg.curr = i;
msg_opl->sg.copybreak = 0;
msg_opl->apply_bytes = 0;
msg_opl->sg.size = bytes;
msg_npl = &new->msg_plaintext;
msg_npl->apply_bytes = apply;
msg_npl->sg.size = orig_size - bytes;
j = msg_npl->sg.start;
nsge = sk_msg_elem(msg_npl, j);
if (tmp.length) {
memcpy(nsge, &tmp, sizeof(*nsge));
sk_msg_iter_var_next(j);
nsge = sk_msg_elem(msg_npl, j);
}
osge = sk_msg_elem(msg_opl, i);
while (osge->length) {
memcpy(nsge, osge, sizeof(*nsge));
sg_unmark_end(nsge);
sk_msg_iter_var_next(i);
sk_msg_iter_var_next(j);
if (i == *orig_end)
break;
osge = sk_msg_elem(msg_opl, i);
nsge = sk_msg_elem(msg_npl, j);
}
msg_npl->sg.end = j;
msg_npl->sg.curr = j;
msg_npl->sg.copybreak = 0;
*to = new;
return 0;
}
static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
struct tls_rec *from, u32 orig_end)
{
struct sk_msg *msg_npl = &from->msg_plaintext;
struct sk_msg *msg_opl = &to->msg_plaintext;
struct scatterlist *osge, *nsge;
u32 i, j;
i = msg_opl->sg.end;
sk_msg_iter_var_prev(i);
j = msg_npl->sg.start;
osge = sk_msg_elem(msg_opl, i);
nsge = sk_msg_elem(msg_npl, j);
if (sg_page(osge) == sg_page(nsge) &&
osge->offset + osge->length == nsge->offset) {
osge->length += nsge->length;
put_page(sg_page(nsge));
}
msg_opl->sg.end = orig_end;
msg_opl->sg.curr = orig_end;
msg_opl->sg.copybreak = 0;
msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
msg_opl->sg.size += msg_npl->sg.size;
sk_msg_free(sk, &to->msg_encrypted);
sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
kfree(from);
}
static int tls_push_record(struct sock *sk, int flags, static int tls_push_record(struct sock *sk, int flags,
unsigned char record_type) unsigned char record_type)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec; struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
u32 i, split_point, uninitialized_var(orig_end);
struct sk_msg *msg_pl, *msg_en; struct sk_msg *msg_pl, *msg_en;
struct aead_request *req; struct aead_request *req;
bool split;
int rc; int rc;
u32 i;
if (!rec) if (!rec)
return 0; return 0;
...@@ -456,6 +613,18 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -456,6 +613,18 @@ static int tls_push_record(struct sock *sk, int flags,
msg_pl = &rec->msg_plaintext; msg_pl = &rec->msg_plaintext;
msg_en = &rec->msg_encrypted; msg_en = &rec->msg_encrypted;
split_point = msg_pl->apply_bytes;
split = split_point && split_point < msg_pl->sg.size;
if (split) {
rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
split_point, tls_ctx->tx.overhead_size,
&orig_end);
if (rc < 0)
return rc;
sk_msg_trim(sk, msg_en, msg_pl->sg.size +
tls_ctx->tx.overhead_size);
}
rec->tx_flags = flags; rec->tx_flags = flags;
req = &rec->aead_req; req = &rec->aead_req;
...@@ -487,57 +656,139 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -487,57 +656,139 @@ static int tls_push_record(struct sock *sk, int flags,
rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i); rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i);
if (rc < 0) { if (rc < 0) {
if (rc != -EINPROGRESS) if (rc != -EINPROGRESS) {
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
if (split) {
tls_ctx->pending_open_record_frags = true;
tls_merge_open_record(sk, rec, tmp, orig_end);
}
}
return rc; return rc;
} else if (split) {
msg_pl = &tmp->msg_plaintext;
msg_en = &tmp->msg_encrypted;
sk_msg_trim(sk, msg_en, msg_pl->sg.size +
tls_ctx->tx.overhead_size);
tls_ctx->pending_open_record_frags = true;
ctx->open_rec = tmp;
} }
return tls_tx_records(sk, flags); return tls_tx_records(sk, flags);
} }
static int tls_sw_push_pending_record(struct sock *sk, int flags) static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
{ bool full_record, u8 record_type,
return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA); size_t *copied, int flags)
}
static struct tls_rec *get_rec(struct sock *sk)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct sk_msg *msg_pl, *msg_en; struct sk_msg msg_redir = { };
struct sk_psock *psock;
struct sock *sk_redir;
struct tls_rec *rec; struct tls_rec *rec;
int mem_size; int err = 0, send;
bool enospc;
psock = sk_psock_get(sk);
if (!psock)
return tls_push_record(sk, flags, record_type);
more_data:
enospc = sk_msg_full(msg);
if (psock->eval == __SK_NONE)
psock->eval = sk_psock_msg_verdict(sk, psock, msg);
if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
!enospc && !full_record) {
err = -ENOSPC;
goto out_err;
}
msg->cork_bytes = 0;
send = msg->sg.size;
if (msg->apply_bytes && msg->apply_bytes < send)
send = msg->apply_bytes;
switch (psock->eval) {
case __SK_PASS:
err = tls_push_record(sk, flags, record_type);
if (err < 0) {
*copied -= sk_msg_free(sk, msg);
tls_free_open_rec(sk);
goto out_err;
}
break;
case __SK_REDIRECT:
sk_redir = psock->sk_redir;
memcpy(&msg_redir, msg, sizeof(*msg));
if (msg->apply_bytes < send)
msg->apply_bytes = 0;
else
msg->apply_bytes -= send;
sk_msg_return_zero(sk, msg, send);
msg->sg.size -= send;
release_sock(sk);
err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
lock_sock(sk);
if (err < 0) {
*copied -= sk_msg_free_nocharge(sk, &msg_redir);
msg->sg.size = 0;
}
if (msg->sg.size == 0)
tls_free_open_rec(sk);
break;
case __SK_DROP:
default:
sk_msg_free_partial(sk, msg, send);
if (msg->apply_bytes < send)
msg->apply_bytes = 0;
else
msg->apply_bytes -= send;
if (msg->sg.size == 0)
tls_free_open_rec(sk);
*copied -= send;
err = -EACCES;
}
/* Return if we already have an open record */ if (likely(!err)) {
if (ctx->open_rec) bool reset_eval = !ctx->open_rec;
return ctx->open_rec;
mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); rec = ctx->open_rec;
if (rec) {
msg = &rec->msg_plaintext;
if (!msg->apply_bytes)
reset_eval = true;
}
if (reset_eval) {
psock->eval = __SK_NONE;
if (psock->sk_redir) {
sock_put(psock->sk_redir);
psock->sk_redir = NULL;
}
}
if (rec)
goto more_data;
}
out_err:
sk_psock_put(sk, psock);
return err;
}
static int tls_sw_push_pending_record(struct sock *sk, int flags)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec;
struct sk_msg *msg_pl;
size_t copied;
rec = kzalloc(mem_size, sk->sk_allocation);
if (!rec) if (!rec)
return NULL; return 0;
msg_pl = &rec->msg_plaintext; msg_pl = &rec->msg_plaintext;
msg_en = &rec->msg_encrypted; copied = msg_pl->sg.size;
if (!copied)
sk_msg_init(msg_pl); return 0;
sk_msg_init(msg_en);
sg_init_table(rec->sg_aead_in, 2);
sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
sizeof(rec->aad_space));
sg_unmark_end(&rec->sg_aead_in[1]);
sg_init_table(rec->sg_aead_out, 2);
sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
sizeof(rec->aad_space));
sg_unmark_end(&rec->sg_aead_out[1]);
ctx->open_rec = rec;
rec->inplace_crypto = 1;
return rec; return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
&copied, flags);
} }
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
...@@ -589,7 +840,10 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -589,7 +840,10 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
goto send_end; goto send_end;
} }
rec = get_rec(sk); if (ctx->open_rec)
rec = ctx->open_rec;
else
rec = ctx->open_rec = tls_get_rec(sk);
if (!rec) { if (!rec) {
ret = -ENOMEM; ret = -ENOMEM;
goto send_end; goto send_end;
...@@ -628,6 +882,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -628,6 +882,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
} }
if (!is_kvec && (full_record || eor) && !async_capable) { if (!is_kvec && (full_record || eor) && !async_capable) {
u32 first = msg_pl->sg.end;
ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter, ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
msg_pl, try_to_copy); msg_pl, try_to_copy);
if (ret) if (ret)
...@@ -637,15 +893,27 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -637,15 +893,27 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
num_zc++; num_zc++;
copied += try_to_copy; copied += try_to_copy;
ret = tls_push_record(sk, msg->msg_flags, record_type);
sk_msg_sg_copy_set(msg_pl, first);
ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
record_type, &copied,
msg->msg_flags);
if (ret) { if (ret) {
if (ret == -EINPROGRESS) if (ret == -EINPROGRESS)
num_async++; num_async++;
else if (ret == -ENOMEM)
goto wait_for_memory;
else if (ret == -ENOSPC)
goto rollback_iter;
else if (ret != -EAGAIN) else if (ret != -EAGAIN)
goto send_end; goto send_end;
} }
continue; continue;
rollback_iter:
copied -= try_to_copy;
sk_msg_sg_copy_clear(msg_pl, first);
iov_iter_revert(&msg->msg_iter,
msg_pl->sg.size - orig_size);
fallback_to_reg_send: fallback_to_reg_send:
sk_msg_trim(sk, msg_pl, orig_size); sk_msg_trim(sk, msg_pl, orig_size);
} }
...@@ -678,14 +946,21 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -678,14 +946,21 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
tls_ctx->pending_open_record_frags = true; tls_ctx->pending_open_record_frags = true;
copied += try_to_copy; copied += try_to_copy;
if (full_record || eor) { if (full_record || eor) {
ret = tls_push_record(sk, msg->msg_flags, record_type); ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
record_type, &copied,
msg->msg_flags);
if (ret) { if (ret) {
if (ret == -EINPROGRESS) if (ret == -EINPROGRESS)
num_async++; num_async++;
else if (ret != -EAGAIN) else if (ret == -ENOMEM)
goto wait_for_memory;
else if (ret != -EAGAIN) {
if (ret == -ENOSPC)
ret = 0;
goto send_end; goto send_end;
} }
} }
}
continue; continue;
...@@ -742,10 +1017,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, ...@@ -742,10 +1017,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
unsigned char record_type = TLS_RECORD_TYPE_DATA; unsigned char record_type = TLS_RECORD_TYPE_DATA;
size_t orig_size = size;
struct sk_msg *msg_pl; struct sk_msg *msg_pl;
struct tls_rec *rec; struct tls_rec *rec;
int num_async = 0; int num_async = 0;
size_t copied = 0;
bool full_record; bool full_record;
int record_room; int record_room;
int ret = 0; int ret = 0;
...@@ -778,7 +1053,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, ...@@ -778,7 +1053,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
goto sendpage_end; goto sendpage_end;
} }
rec = get_rec(sk); if (ctx->open_rec)
rec = ctx->open_rec;
else
rec = ctx->open_rec = tls_get_rec(sk);
if (!rec) { if (!rec) {
ret = -ENOMEM; ret = -ENOMEM;
goto sendpage_end; goto sendpage_end;
...@@ -788,6 +1066,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, ...@@ -788,6 +1066,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
full_record = false; full_record = false;
record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
copied = 0;
copy = size; copy = size;
if (copy >= record_room) { if (copy >= record_room) {
copy = record_room; copy = record_room;
...@@ -818,18 +1097,25 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, ...@@ -818,18 +1097,25 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
offset += copy; offset += copy;
size -= copy; size -= copy;
copied += copy;
tls_ctx->pending_open_record_frags = true; tls_ctx->pending_open_record_frags = true;
if (full_record || eor || sk_msg_full(msg_pl)) { if (full_record || eor || sk_msg_full(msg_pl)) {
rec->inplace_crypto = 0; rec->inplace_crypto = 0;
ret = tls_push_record(sk, flags, record_type); ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
record_type, &copied, flags);
if (ret) { if (ret) {
if (ret == -EINPROGRESS) if (ret == -EINPROGRESS)
num_async++; num_async++;
else if (ret != -EAGAIN) else if (ret == -ENOMEM)
goto wait_for_memory;
else if (ret != -EAGAIN) {
if (ret == -ENOSPC)
ret = 0;
goto sendpage_end; goto sendpage_end;
} }
} }
}
continue; continue;
wait_for_sndbuf: wait_for_sndbuf:
set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
...@@ -851,24 +1137,20 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, ...@@ -851,24 +1137,20 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
} }
} }
sendpage_end: sendpage_end:
if (orig_size > size)
ret = orig_size - size;
else
ret = sk_stream_error(sk, flags, ret); ret = sk_stream_error(sk, flags, ret);
release_sock(sk); release_sock(sk);
return ret; return copied ? copied : ret;
} }
static struct sk_buff *tls_wait_data(struct sock *sk, int flags, static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
long timeo, int *err) int flags, long timeo, int *err)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb; struct sk_buff *skb;
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
while (!(skb = ctx->recv_pkt)) { while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
if (sk->sk_err) { if (sk->sk_err) {
*err = sock_error(sk); *err = sock_error(sk);
return NULL; return NULL;
...@@ -887,7 +1169,10 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, ...@@ -887,7 +1169,10 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
add_wait_queue(sk_sleep(sk), &wait); add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); sk_wait_event(sk, &timeo,
ctx->recv_pkt != skb ||
!sk_psock_queue_empty(psock),
&wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait); remove_wait_queue(sk_sleep(sk), &wait);
...@@ -1164,6 +1449,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1164,6 +1449,7 @@ int tls_sw_recvmsg(struct sock *sk,
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock;
unsigned char control; unsigned char control;
struct strp_msg *rxm; struct strp_msg *rxm;
struct sk_buff *skb; struct sk_buff *skb;
...@@ -1179,6 +1465,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1179,6 +1465,7 @@ int tls_sw_recvmsg(struct sock *sk,
if (unlikely(flags & MSG_ERRQUEUE)) if (unlikely(flags & MSG_ERRQUEUE))
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
psock = sk_psock_get(sk);
lock_sock(sk); lock_sock(sk);
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
...@@ -1188,9 +1475,19 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1188,9 +1475,19 @@ int tls_sw_recvmsg(struct sock *sk,
bool async = false; bool async = false;
int chunk = 0; int chunk = 0;
skb = tls_wait_data(sk, flags, timeo, &err); skb = tls_wait_data(sk, psock, flags, timeo, &err);
if (!skb) if (!skb) {
if (psock) {
int ret = __tcp_bpf_recvmsg(sk, psock, msg, len);
if (ret > 0) {
copied += ret;
len -= ret;
continue;
}
}
goto recv_end; goto recv_end;
}
rxm = strp_msg(skb); rxm = strp_msg(skb);
...@@ -1296,6 +1593,8 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1296,6 +1593,8 @@ int tls_sw_recvmsg(struct sock *sk,
} }
release_sock(sk); release_sock(sk);
if (psock)
sk_psock_put(sk, psock);
return copied ? : err; return copied ? : err;
} }
...@@ -1318,7 +1617,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -1318,7 +1617,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
skb = tls_wait_data(sk, flags, timeo, &err); skb = tls_wait_data(sk, NULL, flags, timeo, &err);
if (!skb) if (!skb)
goto splice_read_end; goto splice_read_end;
...@@ -1356,11 +1655,16 @@ bool tls_sw_stream_read(const struct sock *sk) ...@@ -1356,11 +1655,16 @@ bool tls_sw_stream_read(const struct sock *sk)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
bool ingress_empty = true;
struct sk_psock *psock;
if (ctx->recv_pkt) rcu_read_lock();
return true; psock = sk_psock(sk);
if (psock)
ingress_empty = list_empty(&psock->ingress_msg);
rcu_read_unlock();
return false; return !ingress_empty || ctx->recv_pkt;
} }
static int tls_read_size(struct strparser *strp, struct sk_buff *skb) static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
...@@ -1439,8 +1743,15 @@ static void tls_data_ready(struct sock *sk) ...@@ -1439,8 +1743,15 @@ static void tls_data_ready(struct sock *sk)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock;
strp_data_ready(&ctx->strp); strp_data_ready(&ctx->strp);
psock = sk_psock_get(sk);
if (psock && !list_empty(&psock->ingress_msg)) {
ctx->saved_data_ready(sk);
sk_psock_put(sk, psock);
}
} }
void tls_sw_free_resources_tx(struct sock *sk) void tls_sw_free_resources_tx(struct sock *sk)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment