Commit 57ebc623 authored by Daniel Borkmann

Merge branch 'bpf-sockmap-tls-fixes'

Jakub Kicinski says:

====================
John says:

Resolve a series of splats discovered by syzbot and an unhash
TLS issue noted by Eric Dumazet.

The main issues revolved around the interaction between TLS and
sockmap tear down. TLS and sockmap could both reset the sk->sk_prot
ops, creating a condition where a close or unhash op could end up
calling itself forever. A rare race condition resulting from a
missing RCU sync operation was causing a use after free. Then, on
the TLS side, dropping the sock lock and re-acquiring it during the
close op could hang. Finally, sockmap must be deployed before TLS
for the current stack assumptions to be met. This is now enforced;
enabling the other ordering is left to a future feature series.

To fix this, first refactor the TLS code so the lock is held for the
entire teardown operation. Then add an unhash callback to ensure TLS
cannot transition from the ESTABLISHED to the LISTEN state; this
transition is a bug similar to the one found and fixed previously in
sockmap. Then apply three fixes to sockmap to close races on tear
down around map free and close. Finally, if sockmap is destroyed
before TLS, add a new ULP op, update, to inform the TLS stack that
it should not call sockmap ops anymore. This last one appears to be
the issue most commonly found by syzbot.

v4:
 - fix some use after frees;
 - disable disconnect work for offload (ctx lifetime is much
   more complex);
 - remove some of the dead code which made it hard to understand
   (for me) that things work correctly (e.g. the checks that TLS is
   the top ULP);
 - add selftests.
====================
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
parents 1d4126c4 d4d34185
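Before reading the diffs, a self-contained userspace model may help. It
reproduces the ops-restore problem described above and shows how the new
ULP "update" op resolves it. All names below are illustrative stand-ins
for the kernel structures; this is a sketch, not kernel code.

#include <stdio.h>

struct proto { void (*close)(void); };

static void tcp_close(void) { puts("tcp close: done"); }
static struct proto tcp_proto = { .close = tcp_close };

static struct proto *sk_prot = &tcp_proto;      /* the socket's live ops */

/* sockmap layer: saves the ops it replaced, restores them on teardown */
static struct proto *psock_saved;
static void sockmap_close(void) { puts("sockmap close"); psock_saved->close(); }
static struct proto sockmap_proto = { .close = sockmap_close };

/* TLS layer: likewise saves the ops it stacked on top of */
static struct proto *tls_saved;
static void tls_close(void) { puts("tls close"); tls_saved->close(); }
static struct proto tls_proto = { .close = tls_close };

/* Equivalent of the new ULP ->update() op: keep the TLS fallback
 * pointer coherent when the layer underneath swaps protos away. */
static void tls_update(struct proto *p) { tls_saved = p; }

int main(void)
{
        /* sockmap first (the ordering this series enforces)... */
        psock_saved = sk_prot;
        sk_prot = &sockmap_proto;

        /* ...then TLS stacks on top, saving the sockmap ops */
        tls_saved = sk_prot;
        sk_prot = &tls_proto;

        /* sockmap teardown. The old code effectively did
         * "sk_prot = psock_saved", clobbering the TLS callbacks while
         * tls_saved still pointed at the soon-to-be-freed sockmap ops.
         * Routing the restore through the update op fixes both: */
        tls_update(psock_saved);

        sk_prot->close();       /* tls close -> tcp close, terminates */
        return 0;
}

In the kernel the same dispatch runs sk_psock_restore_proto() ->
tcp_update_ulp() -> the ULP's ->update() callback (tls_update() in this
series), as the hunks below show.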
@@ -513,3 +513,9 @@ Redirects leak clear text
 In the RX direction, if segment has already been decrypted by the device
 and it gets redirected or mirrored - clear text will be transmitted out.
+
+shutdown() doesn't clear TLS state
+----------------------------------
+
+shutdown() system call allows for a TLS socket to be reused as a different
+connection. Offload doesn't currently handle that.
@@ -354,6 +354,12 @@ static inline void sk_psock_restore_proto(struct sock *sk,
 	sk->sk_write_space = psock->saved_write_space;

 	if (psock->sk_proto) {
-		sk->sk_prot = psock->sk_proto;
+		struct inet_connection_sock *icsk = inet_csk(sk);
+		bool has_ulp = !!icsk->icsk_ulp_data;
+
+		if (has_ulp)
+			tcp_update_ulp(sk, psock->sk_proto);
+		else
+			sk->sk_prot = psock->sk_proto;
 		psock->sk_proto = NULL;
 	}
@@ -2103,6 +2103,8 @@ struct tcp_ulp_ops {
 	/* initialize ulp */
 	int (*init)(struct sock *sk);
+	/* update ulp */
+	void (*update)(struct sock *sk, struct proto *p);
 	/* cleanup ulp */
 	void (*release)(struct sock *sk);

@@ -2114,6 +2116,7 @@ void tcp_unregister_ulp(struct tcp_ulp_ops *type);
 int tcp_set_ulp(struct sock *sk, const char *name);
 void tcp_get_available_ulp(char *buf, size_t len);
 void tcp_cleanup_ulp(struct sock *sk);
+void tcp_update_ulp(struct sock *sk, struct proto *p);

 #define MODULE_ALIAS_TCP_ULP(name)				\
 	__MODULE_INFO(alias, alias_userspace, name);		\
@@ -107,9 +107,7 @@ struct tls_device {
 enum {
 	TLS_BASE,
 	TLS_SW,
-#ifdef CONFIG_TLS_DEVICE
 	TLS_HW,
-#endif
 	TLS_HW_RECORD,
 	TLS_NUM_CONFIG,
 };

@@ -162,6 +160,7 @@ struct tls_sw_context_tx {
 	int async_capable;

 #define BIT_TX_SCHEDULED	0
+#define BIT_TX_CLOSING		1
 	unsigned long tx_bitmask;
 };

@@ -272,6 +271,8 @@ struct tls_context {
 	unsigned long flags;

 	/* cache cold stuff */
+	struct proto *sk_proto;
+
 	void (*sk_destruct)(struct sock *sk);
 	void (*sk_proto_close)(struct sock *sk, long timeout);

@@ -289,6 +290,8 @@ struct tls_context {
 	struct list_head list;
 	refcount_t refcount;
+
+	struct work_struct gc;
 };

 enum tls_offload_ctx_dir {

@@ -355,13 +358,17 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 		  unsigned int optlen);

 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
+void tls_sw_strparser_done(struct tls_context *tls_ctx);
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
 		    int offset, size_t size, int flags);
-void tls_sw_close(struct sock *sk, long timeout);
-void tls_sw_free_resources_tx(struct sock *sk);
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
+void tls_sw_release_resources_tx(struct sock *sk);
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
 void tls_sw_free_resources_rx(struct sock *sk);
 void tls_sw_release_resources_rx(struct sock *sk);
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
 int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		   int nonblock, int flags, int *addr_len);
 bool tls_sw_stream_read(const struct sock *sk);
@@ -585,12 +585,12 @@ EXPORT_SYMBOL_GPL(sk_psock_destroy);
 void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
 {
-	rcu_assign_sk_user_data(sk, NULL);
 	sk_psock_cork_free(psock);
 	sk_psock_zap_ingress(psock);
-	sk_psock_restore_proto(sk, psock);

 	write_lock_bh(&sk->sk_callback_lock);
+	sk_psock_restore_proto(sk, psock);
+	rcu_assign_sk_user_data(sk, NULL);
 	if (psock->progs.skb_parser)
 		sk_psock_stop_strp(sk, psock);
 	write_unlock_bh(&sk->sk_callback_lock);
@@ -247,6 +247,8 @@ static void sock_map_free(struct bpf_map *map)
 	raw_spin_unlock_bh(&stab->lock);
 	rcu_read_unlock();

+	synchronize_rcu();
+
 	bpf_map_area_free(stab->sks);
 	kfree(stab);
 }

@@ -276,16 +278,20 @@ static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
 			     struct sock **psk)
 {
 	struct sock *sk;
+	int err = 0;

 	raw_spin_lock_bh(&stab->lock);
 	sk = *psk;
 	if (!sk_test || sk_test == sk)
-		*psk = NULL;
-	raw_spin_unlock_bh(&stab->lock);
+		sk = xchg(psk, NULL);

-	if (unlikely(!sk))
-		return -EINVAL;
-	sock_map_unref(sk, psk);
-	return 0;
+	if (likely(sk))
+		sock_map_unref(sk, psk);
+	else
+		err = -EINVAL;
+
+	raw_spin_unlock_bh(&stab->lock);
+	return err;
 }

 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
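The shape of the __sock_map_delete() fix above can be modeled in plain
C11: the entry is claimed with an exchange and its reference dropped
while the bucket lock is still held, so two racing deleters can never
both observe and release the same entry. This is only an analog with
illustrative names, not the kernel code:

#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

static pthread_mutex_t bucket_lock = PTHREAD_MUTEX_INITIALIZER;
static _Atomic(int *) slot;             /* one map slot */

static void put_entry(int *e)
{
        printf("released entry %p\n", (void *)e);
}

static int map_delete(void)
{
        int *sk;
        int err = 0;

        pthread_mutex_lock(&bucket_lock);
        sk = atomic_exchange(&slot, NULL);      /* claim the slot */
        if (sk)
                put_entry(sk);          /* drop the ref under the lock */
        else
                err = -1;               /* lost the race: already deleted */
        pthread_mutex_unlock(&bucket_lock);
        return err;
}

int main(void)
{
        static int entry;

        atomic_store(&slot, &entry);
        map_delete();   /* releases the entry */
        map_delete();   /* safely returns an error */
        return 0;
}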
@@ -328,6 +334,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
 				  struct sock *sk, u64 flags)
 {
 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct sk_psock_link *link;
 	struct sk_psock *psock;
 	struct sock *osk;

@@ -338,6 +345,8 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
 		return -EINVAL;
 	if (unlikely(idx >= map->max_entries))
 		return -E2BIG;
+	if (unlikely(icsk->icsk_ulp_data))
+		return -EINVAL;

 	link = sk_psock_init_link();
 	if (!link)
@@ -96,6 +96,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen)
 	rcu_read_unlock();
 }

+void tcp_update_ulp(struct sock *sk, struct proto *proto)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+
+	if (!icsk->icsk_ulp_ops) {
+		sk->sk_prot = proto;
+		return;
+	}
+
+	if (icsk->icsk_ulp_ops->update)
+		icsk->icsk_ulp_ops->update(sk, proto);
+}
+
 void tcp_cleanup_ulp(struct sock *sk)
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);
@@ -261,24 +261,36 @@ void tls_ctx_free(struct tls_context *ctx)
 	kfree(ctx);
 }

-static void tls_sk_proto_close(struct sock *sk, long timeout)
+static void tls_ctx_free_deferred(struct work_struct *gc)
 {
-	struct tls_context *ctx = tls_get_ctx(sk);
-	long timeo = sock_sndtimeo(sk, 0);
-	void (*sk_proto_close)(struct sock *sk, long timeout);
-	bool free_ctx = false;
-
-	lock_sock(sk);
-	sk_proto_close = ctx->sk_proto_close;
+	struct tls_context *ctx = container_of(gc, struct tls_context, gc);

-	if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
-		goto skip_tx_cleanup;
+	/* Ensure any remaining work items are completed. The sk will
+	 * already have lost its tls_ctx reference by the time we get
+	 * here so no xmit operation will actually be performed.
+	 */
+	if (ctx->tx_conf == TLS_SW) {
+		tls_sw_cancel_work_tx(ctx);
+		tls_sw_free_ctx_tx(ctx);
+	}

-	if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
-		free_ctx = true;
-		goto skip_tx_cleanup;
+	if (ctx->rx_conf == TLS_SW) {
+		tls_sw_strparser_done(ctx);
+		tls_sw_free_ctx_rx(ctx);
 	}

+	tls_ctx_free(ctx);
+}
+
+static void tls_ctx_free_wq(struct tls_context *ctx)
+{
+	INIT_WORK(&ctx->gc, tls_ctx_free_deferred);
+	schedule_work(&ctx->gc);
+}
+
+static void tls_sk_proto_cleanup(struct sock *sk,
+				 struct tls_context *ctx, long timeo)
+{
 	if (unlikely(sk->sk_write_pending) &&
 	    !wait_on_pending_writer(sk, &timeo))
 		tls_handle_open_record(sk, 0);

@@ -287,7 +299,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 	if (ctx->tx_conf == TLS_SW) {
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
-		tls_sw_free_resources_tx(sk);
+		tls_sw_release_resources_tx(sk);
 #ifdef CONFIG_TLS_DEVICE
 	} else if (ctx->tx_conf == TLS_HW) {
 		tls_device_free_resources_tx(sk);

@@ -295,26 +307,67 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 	}

 	if (ctx->rx_conf == TLS_SW)
-		tls_sw_free_resources_rx(sk);
+		tls_sw_release_resources_rx(sk);

 #ifdef CONFIG_TLS_DEVICE
 	if (ctx->rx_conf == TLS_HW)
 		tls_device_offload_cleanup_rx(sk);
-
-	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
-#else
-	{
 #endif
-		tls_ctx_free(ctx);
-		ctx = NULL;
+}
+
+static void tls_sk_proto_unhash(struct sock *sk)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	long timeo = sock_sndtimeo(sk, 0);
+	struct tls_context *ctx;
+
+	if (unlikely(!icsk->icsk_ulp_data)) {
+		if (sk->sk_prot->unhash)
+			sk->sk_prot->unhash(sk);
 	}

-skip_tx_cleanup:
+	ctx = tls_get_ctx(sk);
+	tls_sk_proto_cleanup(sk, ctx, timeo);
+	write_lock_bh(&sk->sk_callback_lock);
+	icsk->icsk_ulp_data = NULL;
+	sk->sk_prot = ctx->sk_proto;
+	write_unlock_bh(&sk->sk_callback_lock);
+
+	if (ctx->sk_proto->unhash)
+		ctx->sk_proto->unhash(sk);
+	tls_ctx_free_wq(ctx);
+}
+
+static void tls_sk_proto_close(struct sock *sk, long timeout)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	struct tls_context *ctx = tls_get_ctx(sk);
+	long timeo = sock_sndtimeo(sk, 0);
+	bool free_ctx;
+
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_cancel_work_tx(ctx);
+
+	lock_sock(sk);
+	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
+
+	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
+		tls_sk_proto_cleanup(sk, ctx, timeo);
+
+	write_lock_bh(&sk->sk_callback_lock);
+	if (free_ctx)
+		icsk->icsk_ulp_data = NULL;
+	sk->sk_prot = ctx->sk_proto;
+	write_unlock_bh(&sk->sk_callback_lock);
 	release_sock(sk);
-	sk_proto_close(sk, timeout);
-	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
-	 * for sk->sk_prot->unhash [tls_hw_unhash]
-	 */
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_free_ctx_tx(ctx);
+	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+		tls_sw_strparser_done(ctx);
+	if (ctx->rx_conf == TLS_SW)
+		tls_sw_free_ctx_rx(ctx);
+	ctx->sk_proto_close(sk, timeout);
+
 	if (free_ctx)
 		tls_ctx_free(ctx);
 }
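The reworked close path above also encodes a classic ordering rule:
flush a worker that itself takes the sock lock *before* acquiring that
lock, instead of dropping and re-taking the lock mid-teardown (the old
tls_sw_free_resources_tx() dance, which could hang). A small pthread
analog of the rule, with illustrative names:

#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>
#include <unistd.h>

static pthread_mutex_t sk_lock = PTHREAD_MUTEX_INITIALIZER;
static atomic_bool closing;

static void *tx_worker(void *arg)       /* like tx_work_handler() */
{
        (void)arg;
        while (!atomic_load(&closing)) {
                pthread_mutex_lock(&sk_lock);   /* lock_sock() */
                /* transmit pending records */
                pthread_mutex_unlock(&sk_lock); /* release_sock() */
                usleep(1000);
        }
        return NULL;
}

int main(void)
{
        pthread_t w;

        pthread_create(&w, NULL, tx_worker, NULL);
        usleep(10000);

        /* Close path: stop the worker *before* taking the lock,
         * mirroring tls_sw_cancel_work_tx() running ahead of
         * lock_sock() in tls_sk_proto_close(). */
        atomic_store(&closing, true);
        pthread_join(w, NULL);

        pthread_mutex_lock(&sk_lock);
        /* safe teardown: the worker can no longer sneak in */
        pthread_mutex_unlock(&sk_lock);
        puts("teardown complete");
        return 0;
}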
@@ -526,6 +579,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 		{
 #endif
 			rc = tls_set_sw_offload(sk, ctx, 1);
+			if (rc)
+				goto err_crypto_info;
 			conf = TLS_SW;
 		}
 	} else {

@@ -537,13 +592,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 		{
 #endif
 			rc = tls_set_sw_offload(sk, ctx, 0);
+			if (rc)
+				goto err_crypto_info;
 			conf = TLS_SW;
 		}
+		tls_sw_strparser_arm(sk, ctx);
 	}

-	if (rc)
-		goto err_crypto_info;
-
 	if (tx)
 		ctx->tx_conf = conf;
 	else

@@ -607,6 +662,7 @@ static struct tls_context *create_ctx(struct sock *sk)
 	ctx->setsockopt = sk->sk_prot->setsockopt;
 	ctx->getsockopt = sk->sk_prot->getsockopt;
 	ctx->sk_proto_close = sk->sk_prot->close;
+	ctx->unhash = sk->sk_prot->unhash;
 	return ctx;
 }

@@ -730,6 +786,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
 	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
 	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;
+	prot[TLS_BASE][TLS_BASE].unhash		= tls_sk_proto_unhash;

 	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
 	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;

@@ -747,16 +804,20 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 #ifdef CONFIG_TLS_DEVICE
 	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+	prot[TLS_HW][TLS_BASE].unhash		= base->unhash;
 	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
 	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;

 	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+	prot[TLS_HW][TLS_SW].unhash		= base->unhash;
 	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
 	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;

 	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+	prot[TLS_BASE][TLS_HW].unhash		= base->unhash;

 	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+	prot[TLS_SW][TLS_HW].unhash		= base->unhash;

 	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 #endif

@@ -764,7 +825,6 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_hw_hash;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_hw_unhash;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].close	= tls_sk_proto_close;
 }

@@ -773,7 +833,7 @@ static int tls_init(struct sock *sk)
 	int rc = 0;

 	if (tls_hw_prot(sk))
-		goto out;
+		return 0;

 	/* The TLS ulp is currently supported only for TCP sockets
 	 * in ESTABLISHED state.

@@ -784,21 +844,38 @@ static int tls_init(struct sock *sk)
 	if (sk->sk_state != TCP_ESTABLISHED)
 		return -ENOTSUPP;

+	tls_build_proto(sk);
+
 	/* allocate tls context */
+	write_lock_bh(&sk->sk_callback_lock);
 	ctx = create_ctx(sk);
 	if (!ctx) {
 		rc = -ENOMEM;
 		goto out;
 	}
-	tls_build_proto(sk);
+
 	ctx->tx_conf = TLS_BASE;
 	ctx->rx_conf = TLS_BASE;
+	ctx->sk_proto = sk->sk_prot;
 	update_sk_prot(sk, ctx);
 out:
+	write_unlock_bh(&sk->sk_callback_lock);
 	return rc;
 }

+static void tls_update(struct sock *sk, struct proto *p)
+{
+	struct tls_context *ctx;
+
+	ctx = tls_get_ctx(sk);
+	if (likely(ctx)) {
+		ctx->sk_proto_close = p->close;
+		ctx->sk_proto = p;
+	} else {
+		sk->sk_prot = p;
+	}
+}
+
 void tls_register_device(struct tls_device *device)
 {
 	spin_lock_bh(&device_spinlock);

@@ -819,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 	.name			= "tls",
 	.owner			= THIS_MODULE,
 	.init			= tls_init,
+	.update			= tls_update,
 };

 static int __init tls_register(void)
@@ -2054,7 +2054,16 @@ static void tls_data_ready(struct sock *sk)
 	}
 }

-void tls_sw_free_resources_tx(struct sock *sk)
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+
+	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
+	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
+	cancel_delayed_work_sync(&ctx->tx_work.work);
+}
+
+void tls_sw_release_resources_tx(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

@@ -2065,11 +2074,6 @@ void tls_sw_free_resources_tx(struct sock *sk)
 	if (atomic_read(&ctx->encrypt_pending))
 		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);

-	release_sock(sk);
-	cancel_delayed_work_sync(&ctx->tx_work.work);
-	lock_sock(sk);
-
-	/* Tx whatever records we can transmit and abandon the rest */
 	tls_tx_records(sk, -1);

 	/* Free up un-sent records in tx_list. First, free

@@ -2092,6 +2096,11 @@ void tls_sw_free_resources_tx(struct sock *sk)
 	crypto_free_aead(ctx->aead_send);
 	tls_free_open_rec(sk);
+}
+
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

 	kfree(ctx);
 }

@@ -2110,25 +2119,40 @@ void tls_sw_release_resources_rx(struct sock *sk)
 		skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
+		/* If tls_sw_strparser_arm() was not called (cleanup paths)
+		 * we still want to strp_stop(), but sk->sk_data_ready was
+		 * never swapped.
+		 */
+		if (ctx->saved_data_ready) {
 			write_lock_bh(&sk->sk_callback_lock);
 			sk->sk_data_ready = ctx->saved_data_ready;
 			write_unlock_bh(&sk->sk_callback_lock);
-		release_sock(sk);
-		strp_done(&ctx->strp);
-		lock_sock(sk);
+		}
 	}
 }

-void tls_sw_free_resources_rx(struct sock *sk)
+void tls_sw_strparser_done(struct tls_context *tls_ctx)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

-	tls_sw_release_resources_rx(sk);
+	strp_done(&ctx->strp);
+}
+
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

 	kfree(ctx);
 }

+void tls_sw_free_resources_rx(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+
+	tls_sw_release_resources_rx(sk);
+	tls_sw_free_ctx_rx(tls_ctx);
+}
+
 /* The work handler to transmitt the encrypted records in tx_list */
 static void tx_work_handler(struct work_struct *work)
 {

@@ -2137,11 +2161,17 @@ static void tx_work_handler(struct work_struct *work)
 					       struct tx_work, work);
 	struct sock *sk = tx_work->sk;
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_sw_context_tx *ctx;

-	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+	if (unlikely(!tls_ctx))
+		return;
+
+	ctx = tls_sw_ctx_tx(tls_ctx);
+	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
 		return;

+	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+		return;
 	lock_sock(sk);
 	tls_tx_records(sk, -1);
 	release_sock(sk);

@@ -2160,6 +2190,18 @@ void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
 	}
 }

+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+	write_lock_bh(&sk->sk_callback_lock);
+	rx_ctx->saved_data_ready = sk->sk_data_ready;
+	sk->sk_data_ready = tls_data_ready;
+	write_unlock_bh(&sk->sk_callback_lock);
+
+	strp_check_rcv(&rx_ctx->strp);
+}
+
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);

@@ -2357,13 +2399,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		cb.parse_msg = tls_read_size;

 		strp_init(&sw_ctx_rx->strp, sk, &cb);
-
-		write_lock_bh(&sk->sk_callback_lock);
-		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
-		sk->sk_data_ready = tls_data_ready;
-		write_unlock_bh(&sk->sk_callback_lock);
-
-		strp_check_rcv(&sw_ctx_rx->strp);
 	}

 	goto out;
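The BIT_TX_CLOSING/BIT_TX_SCHEDULED handshake in the tls_sw.c hunks
above can be summarized with C11 atomics. Setting CLOSING and SCHEDULED
before the synchronous cancel guarantees that a worker which slips in
either sees CLOSING and bails out, or fails the test-and-clear of
SCHEDULED, so no TX work runs once teardown has begun. Helper names are
illustrative; the kernel uses set_bit(), test_and_clear_bit() and
cancel_delayed_work_sync():

#include <stdatomic.h>
#include <stdbool.h>

#define BIT_TX_SCHEDULED	0
#define BIT_TX_CLOSING		1

static atomic_ulong tx_bitmask;

static bool test_bit_(int nr)
{
        return atomic_load(&tx_bitmask) & (1UL << nr);
}

static void set_bit_(int nr)
{
        atomic_fetch_or(&tx_bitmask, 1UL << nr);
}

static bool test_and_clear_bit_(int nr)
{
        return atomic_fetch_and(&tx_bitmask, ~(1UL << nr)) & (1UL << nr);
}

/* worker side, as in tx_work_handler() above */
static void tx_work(void)
{
        if (test_bit_(BIT_TX_CLOSING))
                return;
        if (!test_and_clear_bit_(BIT_TX_SCHEDULED))
                return;
        /* lock_sock(); tls_tx_records(); release_sock(); */
}

/* close side, as in tls_sw_cancel_work_tx() above */
static void cancel_work_tx(void)
{
        set_bit_(BIT_TX_CLOSING);
        set_bit_(BIT_TX_SCHEDULED);
        /* cancel_delayed_work_sync(&ctx->tx_work.work); */
}

int main(void)
{
        cancel_work_tx();
        tx_work();      /* returns immediately: CLOSING is set */
        return 0;
}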
@@ -25,6 +25,80 @@
 #define TLS_PAYLOAD_MAX_LEN 16384
 #define SOL_TLS 282

+#ifndef ENOTSUPP
+#define ENOTSUPP 524
+#endif
+
+FIXTURE(tls_basic)
+{
+	int fd, cfd;
+	bool notls;
+};
+
+FIXTURE_SETUP(tls_basic)
+{
+	struct sockaddr_in addr;
+	socklen_t len;
+	int sfd, ret;
+
+	self->notls = false;
+	len = sizeof(addr);
+
+	addr.sin_family = AF_INET;
+	addr.sin_addr.s_addr = htonl(INADDR_ANY);
+	addr.sin_port = 0;
+
+	self->fd = socket(AF_INET, SOCK_STREAM, 0);
+	sfd = socket(AF_INET, SOCK_STREAM, 0);
+
+	ret = bind(sfd, &addr, sizeof(addr));
+	ASSERT_EQ(ret, 0);
+	ret = listen(sfd, 10);
+	ASSERT_EQ(ret, 0);
+
+	ret = getsockname(sfd, &addr, &len);
+	ASSERT_EQ(ret, 0);
+
+	ret = connect(self->fd, &addr, sizeof(addr));
+	ASSERT_EQ(ret, 0);
+
+	self->cfd = accept(sfd, &addr, &len);
+	ASSERT_GE(self->cfd, 0);
+
+	close(sfd);
+
+	ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
+	if (ret != 0) {
+		ASSERT_EQ(errno, ENOTSUPP);
+		self->notls = true;
+		printf("Failure setting TCP_ULP, testing without tls\n");
+		return;
+	}
+
+	ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
+	ASSERT_EQ(ret, 0);
+}
+
+FIXTURE_TEARDOWN(tls_basic)
+{
+	close(self->fd);
+	close(self->cfd);
+}
+
+/* Send some data through with ULP but no keys */
+TEST_F(tls_basic, base_base)
+{
+	char const *test_str = "test_read";
+	int send_len = 10;
+	char buf[10];
+
+	ASSERT_EQ(strlen(test_str) + 1, send_len);
+	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
+	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
+	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
+};
+
 FIXTURE(tls)
 {
 	int fd, cfd;

@@ -165,6 +239,16 @@ TEST_F(tls, msg_more)
 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 }

+TEST_F(tls, msg_more_unsent)
+{
+	char const *test_str = "test_read";
+	int send_len = 10;
+	char buf[10];
+
+	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
+	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
+}
+
 TEST_F(tls, sendmsg_single)
 {
 	struct msghdr msg;

@@ -610,6 +694,37 @@ TEST_F(tls, recv_lowat)
 	EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
 }

+TEST_F(tls, bidir)
+{
+	struct tls12_crypto_info_aes_gcm_128 tls12;
+	char const *test_str = "test_read";
+	int send_len = 10;
+	char buf[10];
+	int ret;
+
+	memset(&tls12, 0, sizeof(tls12));
+	tls12.info.version = TLS_1_3_VERSION;
+	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
+
+	ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
+	ASSERT_EQ(ret, 0);
+
+	ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
+	ASSERT_EQ(ret, 0);
+
+	ASSERT_EQ(strlen(test_str) + 1, send_len);
+
+	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
+	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
+	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
+
+	memset(buf, 0, sizeof(buf));
+
+	EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
+	EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
+	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
+};
+
 TEST_F(tls, pollin)
 {
 	char const *test_str = "test_poll";

@@ -837,6 +952,85 @@ TEST_F(tls, control_msg)
 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
 }

+TEST_F(tls, shutdown)
+{
+	char const *test_str = "test_read";
+	int send_len = 10;
+	char buf[10];
+
+	ASSERT_EQ(strlen(test_str) + 1, send_len);
+
+	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
+	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
+	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
+
+	shutdown(self->fd, SHUT_RDWR);
+	shutdown(self->cfd, SHUT_RDWR);
+}
+
+TEST_F(tls, shutdown_unsent)
+{
+	char const *test_str = "test_read";
+	int send_len = 10;
+
+	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
+
+	shutdown(self->fd, SHUT_RDWR);
+	shutdown(self->cfd, SHUT_RDWR);
+}
+
+TEST(non_established) {
+	struct tls12_crypto_info_aes_gcm_256 tls12;
+	struct sockaddr_in addr;
+	int sfd, ret, fd;
+	socklen_t len;
+
+	len = sizeof(addr);
+
+	memset(&tls12, 0, sizeof(tls12));
+	tls12.info.version = TLS_1_2_VERSION;
+	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
+
+	addr.sin_family = AF_INET;
+	addr.sin_addr.s_addr = htonl(INADDR_ANY);
+	addr.sin_port = 0;
+
+	fd = socket(AF_INET, SOCK_STREAM, 0);
+	sfd = socket(AF_INET, SOCK_STREAM, 0);
+
+	ret = bind(sfd, &addr, sizeof(addr));
+	ASSERT_EQ(ret, 0);
+	ret = listen(sfd, 10);
+	ASSERT_EQ(ret, 0);
+
+	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
+	EXPECT_EQ(ret, -1);
+	/* TLS ULP not supported */
+	if (errno == ENOENT)
+		return;
+	EXPECT_EQ(errno, ENOTSUPP);
+
+	ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
+	EXPECT_EQ(ret, -1);
+	EXPECT_EQ(errno, ENOTSUPP);
+
+	ret = getsockname(sfd, &addr, &len);
+	ASSERT_EQ(ret, 0);
+
+	ret = connect(fd, &addr, sizeof(addr));
+	ASSERT_EQ(ret, 0);
+
+	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
+	ASSERT_EQ(ret, 0);
+
+	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
+	EXPECT_EQ(ret, -1);
+	EXPECT_EQ(errno, EEXIST);
+
+	close(fd);
+	close(sfd);
+}
+
 TEST(keysizes) {
 	struct tls12_crypto_info_aes_gcm_256 tls12;
 	struct sockaddr_in addr;