Commit 8c73b263 authored by Dmitry Safonov's avatar Dmitry Safonov Committed by David S. Miller

net/tcp: Prepare tcp_md5sig_pool for TCP-AO

TCP-AO, similarly to TCP-MD5, needs to allocate tfms on a slow-path,
which is setsockopt() and use crypto ahash requests on fast paths,
which are RX/TX softirqs. Also, it needs a temporary/scratch buffer
for preparing the hash.

Rework tcp_md5sig_pool in order to support other hashing algorithms
than MD5. It will make it possible to share pre-allocated crypto_ahash
descriptors and scratch area between all TCP hash users.

Internally tcp_sigpool calls crypto_clone_ahash() API over pre-allocated
crypto ahash tfm. Kudos to Herbert, who provided this new crypto API.

I was a little concerned over GFP_ATOMIC allocations of ahash and
crypto_request in RX/TX (see tcp_sigpool_start()), so I benchmarked both
"backends" with different algorithms, using patched version of iperf3[2].
On my laptop with i7-7600U @ 2.80GHz:

                         clone-tfm                per-CPU-requests
TCP-MD5                  2.25 Gbits/sec           2.30 Gbits/sec
TCP-AO(hmac(sha1))       2.53 Gbits/sec           2.54 Gbits/sec
TCP-AO(hmac(sha512))     1.67 Gbits/sec           1.64 Gbits/sec
TCP-AO(hmac(sha384))     1.77 Gbits/sec           1.80 Gbits/sec
TCP-AO(hmac(sha224))     1.29 Gbits/sec           1.30 Gbits/sec
TCP-AO(hmac(sha3-512))    481 Mbits/sec            480 Mbits/sec
TCP-AO(hmac(md5))        2.07 Gbits/sec           2.12 Gbits/sec
TCP-AO(hmac(rmd160))     1.01 Gbits/sec            995 Mbits/sec
TCP-AO(cmac(aes128))     [not supporetd yet]      2.11 Gbits/sec

So, it seems that my concerns don't have strong grounds and per-CPU
crypto_request allocation can be dropped/removed from tcp_sigpool once
ciphers get crypto_clone_ahash() support.

[1]: https://lore.kernel.org/all/ZDefxOq6Ax0JeTRH@gondor.apana.org.au/T/#u
[2]: https://github.com/0x7f454c46/iperf/tree/tcp-md5-aoSigned-off-by: default avatarDmitry Safonov <dima@arista.com>
Reviewed-by: default avatarSteen Hegelund <Steen.Hegelund@microchip.com>
Acked-by: default avatarDavid Ahern <dsahern@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent cc54d2e2
...@@ -1737,12 +1737,39 @@ union tcp_md5sum_block { ...@@ -1737,12 +1737,39 @@ union tcp_md5sum_block {
#endif #endif
}; };
/* - pool: digest algorithm, hash description and scratch buffer */ /*
struct tcp_md5sig_pool { * struct tcp_sigpool - per-CPU pool of ahash_requests
struct ahash_request *md5_req; * @scratch: per-CPU temporary area, that can be used between
* tcp_sigpool_start() and tcp_sigpool_end() to perform
* crypto request
* @req: pre-allocated ahash request
*/
struct tcp_sigpool {
void *scratch; void *scratch;
struct ahash_request *req;
}; };
int tcp_sigpool_alloc_ahash(const char *alg, size_t scratch_size);
void tcp_sigpool_get(unsigned int id);
void tcp_sigpool_release(unsigned int id);
int tcp_sigpool_hash_skb_data(struct tcp_sigpool *hp,
const struct sk_buff *skb,
unsigned int header_len);
/**
* tcp_sigpool_start - disable bh and start using tcp_sigpool_ahash
* @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
* @c: returned tcp_sigpool for usage (uninitialized on failure)
*
* Returns 0 on success, error otherwise.
*/
int tcp_sigpool_start(unsigned int id, struct tcp_sigpool *c);
/**
* tcp_sigpool_end - enable bh and stop using tcp_sigpool
* @c: tcp_sigpool context that was returned by tcp_sigpool_start()
*/
void tcp_sigpool_end(struct tcp_sigpool *c);
size_t tcp_sigpool_algo(unsigned int id, char *buf, size_t buf_len);
/* - functions */ /* - functions */
int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key, int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
const struct sock *sk, const struct sk_buff *skb); const struct sock *sk, const struct sk_buff *skb);
...@@ -1798,17 +1825,12 @@ tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb, ...@@ -1798,17 +1825,12 @@ tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
#define tcp_twsk_md5_key(twsk) NULL #define tcp_twsk_md5_key(twsk) NULL
#endif #endif
bool tcp_alloc_md5sig_pool(void); int tcp_md5_alloc_sigpool(void);
void tcp_md5_release_sigpool(void);
struct tcp_md5sig_pool *tcp_get_md5sig_pool(void); void tcp_md5_add_sigpool(void);
static inline void tcp_put_md5sig_pool(void) extern int tcp_md5_sigpool_id;
{
local_bh_enable();
}
int tcp_md5_hash_skb_data(struct tcp_md5sig_pool *, const struct sk_buff *, int tcp_md5_hash_key(struct tcp_sigpool *hp,
unsigned int header_len);
int tcp_md5_hash_key(struct tcp_md5sig_pool *hp,
const struct tcp_md5sig_key *key); const struct tcp_md5sig_key *key);
/* From tcp_fastopen.c */ /* From tcp_fastopen.c */
......
...@@ -741,10 +741,14 @@ config DEFAULT_TCP_CONG ...@@ -741,10 +741,14 @@ config DEFAULT_TCP_CONG
default "bbr" if DEFAULT_BBR default "bbr" if DEFAULT_BBR
default "cubic" default "cubic"
config TCP_SIGPOOL
tristate
config TCP_MD5SIG config TCP_MD5SIG
bool "TCP: MD5 Signature Option support (RFC2385)" bool "TCP: MD5 Signature Option support (RFC2385)"
select CRYPTO select CRYPTO
select CRYPTO_MD5 select CRYPTO_MD5
select TCP_SIGPOOL
help help
RFC2385 specifies a method of giving MD5 protection to TCP sessions. RFC2385 specifies a method of giving MD5 protection to TCP sessions.
Its main (only?) use is to protect BGP sessions between core routers Its main (only?) use is to protect BGP sessions between core routers
......
...@@ -62,6 +62,7 @@ obj-$(CONFIG_TCP_CONG_SCALABLE) += tcp_scalable.o ...@@ -62,6 +62,7 @@ obj-$(CONFIG_TCP_CONG_SCALABLE) += tcp_scalable.o
obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o
obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o
obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o
obj-$(CONFIG_TCP_SIGPOOL) += tcp_sigpool.o
obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o
obj-$(CONFIG_BPF_SYSCALL) += udp_bpf.o obj-$(CONFIG_BPF_SYSCALL) += udp_bpf.o
obj-$(CONFIG_NETLABEL) += cipso_ipv4.o obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
......
...@@ -4305,141 +4305,52 @@ int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, ...@@ -4305,141 +4305,52 @@ int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
EXPORT_SYMBOL(tcp_getsockopt); EXPORT_SYMBOL(tcp_getsockopt);
#ifdef CONFIG_TCP_MD5SIG #ifdef CONFIG_TCP_MD5SIG
static DEFINE_PER_CPU(struct tcp_md5sig_pool, tcp_md5sig_pool); int tcp_md5_sigpool_id = -1;
static DEFINE_MUTEX(tcp_md5sig_mutex); EXPORT_SYMBOL_GPL(tcp_md5_sigpool_id);
static bool tcp_md5sig_pool_populated = false;
static void __tcp_alloc_md5sig_pool(void) int tcp_md5_alloc_sigpool(void)
{ {
struct crypto_ahash *hash; size_t scratch_size;
int cpu; int ret;
hash = crypto_alloc_ahash("md5", 0, CRYPTO_ALG_ASYNC);
if (IS_ERR(hash))
return;
for_each_possible_cpu(cpu) {
void *scratch = per_cpu(tcp_md5sig_pool, cpu).scratch;
struct ahash_request *req;
if (!scratch) {
scratch = kmalloc_node(sizeof(union tcp_md5sum_block) +
sizeof(struct tcphdr),
GFP_KERNEL,
cpu_to_node(cpu));
if (!scratch)
return;
per_cpu(tcp_md5sig_pool, cpu).scratch = scratch;
}
if (per_cpu(tcp_md5sig_pool, cpu).md5_req)
continue;
req = ahash_request_alloc(hash, GFP_KERNEL);
if (!req)
return;
ahash_request_set_callback(req, 0, NULL, NULL);
per_cpu(tcp_md5sig_pool, cpu).md5_req = req; scratch_size = sizeof(union tcp_md5sum_block) + sizeof(struct tcphdr);
} ret = tcp_sigpool_alloc_ahash("md5", scratch_size);
/* before setting tcp_md5sig_pool_populated, we must commit all writes if (ret >= 0) {
* to memory. See smp_rmb() in tcp_get_md5sig_pool() /* As long as any md5 sigpool was allocated, the return
* id would stay the same. Re-write the id only for the case
* when previously all MD5 keys were deleted and this call
* allocates the first MD5 key, which may return a different
* sigpool id than was used previously.
*/ */
smp_wmb(); WRITE_ONCE(tcp_md5_sigpool_id, ret); /* Avoids the compiler potentially being smart here */
/* Paired with READ_ONCE() from tcp_alloc_md5sig_pool() return 0;
* and tcp_get_md5sig_pool().
*/
WRITE_ONCE(tcp_md5sig_pool_populated, true);
}
bool tcp_alloc_md5sig_pool(void)
{
/* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */
if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) {
mutex_lock(&tcp_md5sig_mutex);
if (!tcp_md5sig_pool_populated)
__tcp_alloc_md5sig_pool();
mutex_unlock(&tcp_md5sig_mutex);
} }
/* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */ return ret;
return READ_ONCE(tcp_md5sig_pool_populated);
} }
EXPORT_SYMBOL(tcp_alloc_md5sig_pool);
void tcp_md5_release_sigpool(void)
/**
* tcp_get_md5sig_pool - get md5sig_pool for this user
*
* We use percpu structure, so if we succeed, we exit with preemption
* and BH disabled, to make sure another thread or softirq handling
* wont try to get same context.
*/
struct tcp_md5sig_pool *tcp_get_md5sig_pool(void)
{ {
local_bh_disable(); tcp_sigpool_release(READ_ONCE(tcp_md5_sigpool_id));
/* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */
if (READ_ONCE(tcp_md5sig_pool_populated)) {
/* coupled with smp_wmb() in __tcp_alloc_md5sig_pool() */
smp_rmb();
return this_cpu_ptr(&tcp_md5sig_pool);
}
local_bh_enable();
return NULL;
} }
EXPORT_SYMBOL(tcp_get_md5sig_pool);
int tcp_md5_hash_skb_data(struct tcp_md5sig_pool *hp, void tcp_md5_add_sigpool(void)
const struct sk_buff *skb, unsigned int header_len)
{ {
struct scatterlist sg; tcp_sigpool_get(READ_ONCE(tcp_md5_sigpool_id));
const struct tcphdr *tp = tcp_hdr(skb);
struct ahash_request *req = hp->md5_req;
unsigned int i;
const unsigned int head_data_len = skb_headlen(skb) > header_len ?
skb_headlen(skb) - header_len : 0;
const struct skb_shared_info *shi = skb_shinfo(skb);
struct sk_buff *frag_iter;
sg_init_table(&sg, 1);
sg_set_buf(&sg, ((u8 *) tp) + header_len, head_data_len);
ahash_request_set_crypt(req, &sg, NULL, head_data_len);
if (crypto_ahash_update(req))
return 1;
for (i = 0; i < shi->nr_frags; ++i) {
const skb_frag_t *f = &shi->frags[i];
unsigned int offset = skb_frag_off(f);
struct page *page = skb_frag_page(f) + (offset >> PAGE_SHIFT);
sg_set_page(&sg, page, skb_frag_size(f),
offset_in_page(offset));
ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f));
if (crypto_ahash_update(req))
return 1;
}
skb_walk_frags(skb, frag_iter)
if (tcp_md5_hash_skb_data(hp, frag_iter, 0))
return 1;
return 0;
} }
EXPORT_SYMBOL(tcp_md5_hash_skb_data);
int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *key) int tcp_md5_hash_key(struct tcp_sigpool *hp,
const struct tcp_md5sig_key *key)
{ {
u8 keylen = READ_ONCE(key->keylen); /* paired with WRITE_ONCE() in tcp_md5_do_add */ u8 keylen = READ_ONCE(key->keylen); /* paired with WRITE_ONCE() in tcp_md5_do_add */
struct scatterlist sg; struct scatterlist sg;
sg_init_one(&sg, key->key, keylen); sg_init_one(&sg, key->key, keylen);
ahash_request_set_crypt(hp->md5_req, &sg, NULL, keylen); ahash_request_set_crypt(hp->req, &sg, NULL, keylen);
/* We use data_race() because tcp_md5_do_add() might change key->key under us */ /* We use data_race() because tcp_md5_do_add() might change
return data_race(crypto_ahash_update(hp->md5_req)); * key->key under us
*/
return data_race(crypto_ahash_update(hp->req));
} }
EXPORT_SYMBOL(tcp_md5_hash_key); EXPORT_SYMBOL(tcp_md5_hash_key);
......
...@@ -1221,10 +1221,6 @@ static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, ...@@ -1221,10 +1221,6 @@ static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
key = sock_kmalloc(sk, sizeof(*key), gfp | __GFP_ZERO); key = sock_kmalloc(sk, sizeof(*key), gfp | __GFP_ZERO);
if (!key) if (!key)
return -ENOMEM; return -ENOMEM;
if (!tcp_alloc_md5sig_pool()) {
sock_kfree_s(sk, key, sizeof(*key));
return -ENOMEM;
}
memcpy(key->key, newkey, newkeylen); memcpy(key->key, newkey, newkeylen);
key->keylen = newkeylen; key->keylen = newkeylen;
...@@ -1246,15 +1242,21 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, ...@@ -1246,15 +1242,21 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
struct tcp_sock *tp = tcp_sk(sk); struct tcp_sock *tp = tcp_sk(sk);
if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
if (tcp_md5sig_info_add(sk, GFP_KERNEL)) if (tcp_md5_alloc_sigpool())
return -ENOMEM; return -ENOMEM;
if (tcp_md5sig_info_add(sk, GFP_KERNEL)) {
tcp_md5_release_sigpool();
return -ENOMEM;
}
if (!static_branch_inc(&tcp_md5_needed.key)) { if (!static_branch_inc(&tcp_md5_needed.key)) {
struct tcp_md5sig_info *md5sig; struct tcp_md5sig_info *md5sig;
md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk));
rcu_assign_pointer(tp->md5sig_info, NULL); rcu_assign_pointer(tp->md5sig_info, NULL);
kfree_rcu(md5sig, rcu); kfree_rcu(md5sig, rcu);
tcp_md5_release_sigpool();
return -EUSERS; return -EUSERS;
} }
} }
...@@ -1271,8 +1273,12 @@ int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr, ...@@ -1271,8 +1273,12 @@ int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
struct tcp_sock *tp = tcp_sk(sk); struct tcp_sock *tp = tcp_sk(sk);
if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC))) tcp_md5_add_sigpool();
if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC))) {
tcp_md5_release_sigpool();
return -ENOMEM; return -ENOMEM;
}
if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) { if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) {
struct tcp_md5sig_info *md5sig; struct tcp_md5sig_info *md5sig;
...@@ -1281,6 +1287,7 @@ int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr, ...@@ -1281,6 +1287,7 @@ int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
net_warn_ratelimited("Too many TCP-MD5 keys in the system\n"); net_warn_ratelimited("Too many TCP-MD5 keys in the system\n");
rcu_assign_pointer(tp->md5sig_info, NULL); rcu_assign_pointer(tp->md5sig_info, NULL);
kfree_rcu(md5sig, rcu); kfree_rcu(md5sig, rcu);
tcp_md5_release_sigpool();
return -EUSERS; return -EUSERS;
} }
} }
...@@ -1380,7 +1387,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, ...@@ -1380,7 +1387,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
cmd.tcpm_key, cmd.tcpm_keylen); cmd.tcpm_key, cmd.tcpm_keylen);
} }
static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp, static int tcp_v4_md5_hash_headers(struct tcp_sigpool *hp,
__be32 daddr, __be32 saddr, __be32 daddr, __be32 saddr,
const struct tcphdr *th, int nbytes) const struct tcphdr *th, int nbytes)
{ {
...@@ -1400,38 +1407,35 @@ static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp, ...@@ -1400,38 +1407,35 @@ static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
_th->check = 0; _th->check = 0;
sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th)); sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th));
ahash_request_set_crypt(hp->md5_req, &sg, NULL, ahash_request_set_crypt(hp->req, &sg, NULL,
sizeof(*bp) + sizeof(*th)); sizeof(*bp) + sizeof(*th));
return crypto_ahash_update(hp->md5_req); return crypto_ahash_update(hp->req);
} }
static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
__be32 daddr, __be32 saddr, const struct tcphdr *th) __be32 daddr, __be32 saddr, const struct tcphdr *th)
{ {
struct tcp_md5sig_pool *hp; struct tcp_sigpool hp;
struct ahash_request *req;
hp = tcp_get_md5sig_pool(); if (tcp_sigpool_start(tcp_md5_sigpool_id, &hp))
if (!hp) goto clear_hash_nostart;
goto clear_hash_noput;
req = hp->md5_req;
if (crypto_ahash_init(req)) if (crypto_ahash_init(hp.req))
goto clear_hash; goto clear_hash;
if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2)) if (tcp_v4_md5_hash_headers(&hp, daddr, saddr, th, th->doff << 2))
goto clear_hash; goto clear_hash;
if (tcp_md5_hash_key(hp, key)) if (tcp_md5_hash_key(&hp, key))
goto clear_hash; goto clear_hash;
ahash_request_set_crypt(req, NULL, md5_hash, 0); ahash_request_set_crypt(hp.req, NULL, md5_hash, 0);
if (crypto_ahash_final(req)) if (crypto_ahash_final(hp.req))
goto clear_hash; goto clear_hash;
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
return 0; return 0;
clear_hash: clear_hash:
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
clear_hash_noput: clear_hash_nostart:
memset(md5_hash, 0, 16); memset(md5_hash, 0, 16);
return 1; return 1;
} }
...@@ -1440,9 +1444,8 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key, ...@@ -1440,9 +1444,8 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
const struct sock *sk, const struct sock *sk,
const struct sk_buff *skb) const struct sk_buff *skb)
{ {
struct tcp_md5sig_pool *hp;
struct ahash_request *req;
const struct tcphdr *th = tcp_hdr(skb); const struct tcphdr *th = tcp_hdr(skb);
struct tcp_sigpool hp;
__be32 saddr, daddr; __be32 saddr, daddr;
if (sk) { /* valid for establish/request sockets */ if (sk) { /* valid for establish/request sockets */
...@@ -1454,30 +1457,28 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key, ...@@ -1454,30 +1457,28 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
daddr = iph->daddr; daddr = iph->daddr;
} }
hp = tcp_get_md5sig_pool(); if (tcp_sigpool_start(tcp_md5_sigpool_id, &hp))
if (!hp) goto clear_hash_nostart;
goto clear_hash_noput;
req = hp->md5_req;
if (crypto_ahash_init(req)) if (crypto_ahash_init(hp.req))
goto clear_hash; goto clear_hash;
if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, skb->len)) if (tcp_v4_md5_hash_headers(&hp, daddr, saddr, th, skb->len))
goto clear_hash; goto clear_hash;
if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2)) if (tcp_sigpool_hash_skb_data(&hp, skb, th->doff << 2))
goto clear_hash; goto clear_hash;
if (tcp_md5_hash_key(hp, key)) if (tcp_md5_hash_key(&hp, key))
goto clear_hash; goto clear_hash;
ahash_request_set_crypt(req, NULL, md5_hash, 0); ahash_request_set_crypt(hp.req, NULL, md5_hash, 0);
if (crypto_ahash_final(req)) if (crypto_ahash_final(hp.req))
goto clear_hash; goto clear_hash;
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
return 0; return 0;
clear_hash: clear_hash:
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
clear_hash_noput: clear_hash_nostart:
memset(md5_hash, 0, 16); memset(md5_hash, 0, 16);
return 1; return 1;
} }
...@@ -2296,6 +2297,18 @@ static int tcp_v4_init_sock(struct sock *sk) ...@@ -2296,6 +2297,18 @@ static int tcp_v4_init_sock(struct sock *sk)
return 0; return 0;
} }
#ifdef CONFIG_TCP_MD5SIG
static void tcp_md5sig_info_free_rcu(struct rcu_head *head)
{
struct tcp_md5sig_info *md5sig;
md5sig = container_of(head, struct tcp_md5sig_info, rcu);
kfree(md5sig);
static_branch_slow_dec_deferred(&tcp_md5_needed);
tcp_md5_release_sigpool();
}
#endif
void tcp_v4_destroy_sock(struct sock *sk) void tcp_v4_destroy_sock(struct sock *sk)
{ {
struct tcp_sock *tp = tcp_sk(sk); struct tcp_sock *tp = tcp_sk(sk);
...@@ -2320,10 +2333,12 @@ void tcp_v4_destroy_sock(struct sock *sk) ...@@ -2320,10 +2333,12 @@ void tcp_v4_destroy_sock(struct sock *sk)
#ifdef CONFIG_TCP_MD5SIG #ifdef CONFIG_TCP_MD5SIG
/* Clean up the MD5 key list, if any */ /* Clean up the MD5 key list, if any */
if (tp->md5sig_info) { if (tp->md5sig_info) {
struct tcp_md5sig_info *md5sig;
md5sig = rcu_dereference_protected(tp->md5sig_info, 1);
tcp_clear_md5_list(sk); tcp_clear_md5_list(sk);
kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu); call_rcu(&md5sig->rcu, tcp_md5sig_info_free_rcu);
tp->md5sig_info = NULL; rcu_assign_pointer(tp->md5sig_info, NULL);
static_branch_slow_dec_deferred(&tcp_md5_needed);
} }
#endif #endif
......
...@@ -261,10 +261,9 @@ static void tcp_time_wait_init(struct sock *sk, struct tcp_timewait_sock *tcptw) ...@@ -261,10 +261,9 @@ static void tcp_time_wait_init(struct sock *sk, struct tcp_timewait_sock *tcptw)
tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
if (!tcptw->tw_md5_key) if (!tcptw->tw_md5_key)
return; return;
if (!tcp_alloc_md5sig_pool())
goto out_free;
if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key))
goto out_free; goto out_free;
tcp_md5_add_sigpool();
} }
return; return;
out_free: out_free:
...@@ -349,16 +348,26 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) ...@@ -349,16 +348,26 @@ void tcp_time_wait(struct sock *sk, int state, int timeo)
} }
EXPORT_SYMBOL(tcp_time_wait); EXPORT_SYMBOL(tcp_time_wait);
#ifdef CONFIG_TCP_MD5SIG
static void tcp_md5_twsk_free_rcu(struct rcu_head *head)
{
struct tcp_md5sig_key *key;
key = container_of(head, struct tcp_md5sig_key, rcu);
kfree(key);
static_branch_slow_dec_deferred(&tcp_md5_needed);
tcp_md5_release_sigpool();
}
#endif
void tcp_twsk_destructor(struct sock *sk) void tcp_twsk_destructor(struct sock *sk)
{ {
#ifdef CONFIG_TCP_MD5SIG #ifdef CONFIG_TCP_MD5SIG
if (static_branch_unlikely(&tcp_md5_needed.key)) { if (static_branch_unlikely(&tcp_md5_needed.key)) {
struct tcp_timewait_sock *twsk = tcp_twsk(sk); struct tcp_timewait_sock *twsk = tcp_twsk(sk);
if (twsk->tw_md5_key) { if (twsk->tw_md5_key)
kfree_rcu(twsk->tw_md5_key, rcu); call_rcu(&twsk->tw_md5_key->rcu, tcp_md5_twsk_free_rcu);
static_branch_slow_dec_deferred(&tcp_md5_needed);
}
} }
#endif #endif
} }
......
// SPDX-License-Identifier: GPL-2.0-or-later
#include <crypto/hash.h>
#include <linux/cpu.h>
#include <linux/kref.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/percpu.h>
#include <linux/workqueue.h>
#include <net/tcp.h>
static size_t __scratch_size;
static DEFINE_PER_CPU(void __rcu *, sigpool_scratch);
struct sigpool_entry {
struct crypto_ahash *hash;
const char *alg;
struct kref kref;
uint16_t needs_key:1,
reserved:15;
};
#define CPOOL_SIZE (PAGE_SIZE / sizeof(struct sigpool_entry))
static struct sigpool_entry cpool[CPOOL_SIZE];
static unsigned int cpool_populated;
static DEFINE_MUTEX(cpool_mutex);
/* Slow-path */
struct scratches_to_free {
struct rcu_head rcu;
unsigned int cnt;
void *scratches[];
};
static void free_old_scratches(struct rcu_head *head)
{
struct scratches_to_free *stf;
stf = container_of(head, struct scratches_to_free, rcu);
while (stf->cnt--)
kfree(stf->scratches[stf->cnt]);
kfree(stf);
}
/**
* sigpool_reserve_scratch - re-allocates scratch buffer, slow-path
* @size: request size for the scratch/temp buffer
*/
static int sigpool_reserve_scratch(size_t size)
{
struct scratches_to_free *stf;
size_t stf_sz = struct_size(stf, scratches, num_possible_cpus());
int cpu, err = 0;
lockdep_assert_held(&cpool_mutex);
if (__scratch_size >= size)
return 0;
stf = kmalloc(stf_sz, GFP_KERNEL);
if (!stf)
return -ENOMEM;
stf->cnt = 0;
size = max(size, __scratch_size);
cpus_read_lock();
for_each_possible_cpu(cpu) {
void *scratch, *old_scratch;
scratch = kmalloc_node(size, GFP_KERNEL, cpu_to_node(cpu));
if (!scratch) {
err = -ENOMEM;
break;
}
old_scratch = rcu_replace_pointer(per_cpu(sigpool_scratch, cpu),
scratch, lockdep_is_held(&cpool_mutex));
if (!cpu_online(cpu) || !old_scratch) {
kfree(old_scratch);
continue;
}
stf->scratches[stf->cnt++] = old_scratch;
}
cpus_read_unlock();
if (!err)
__scratch_size = size;
call_rcu(&stf->rcu, free_old_scratches);
return err;
}
static void sigpool_scratch_free(void)
{
int cpu;
for_each_possible_cpu(cpu)
kfree(rcu_replace_pointer(per_cpu(sigpool_scratch, cpu),
NULL, lockdep_is_held(&cpool_mutex)));
__scratch_size = 0;
}
static int __cpool_try_clone(struct crypto_ahash *hash)
{
struct crypto_ahash *tmp;
tmp = crypto_clone_ahash(hash);
if (IS_ERR(tmp))
return PTR_ERR(tmp);
crypto_free_ahash(tmp);
return 0;
}
static int __cpool_alloc_ahash(struct sigpool_entry *e, const char *alg)
{
struct crypto_ahash *cpu0_hash;
int ret;
e->alg = kstrdup(alg, GFP_KERNEL);
if (!e->alg)
return -ENOMEM;
cpu0_hash = crypto_alloc_ahash(alg, 0, CRYPTO_ALG_ASYNC);
if (IS_ERR(cpu0_hash)) {
ret = PTR_ERR(cpu0_hash);
goto out_free_alg;
}
e->needs_key = crypto_ahash_get_flags(cpu0_hash) & CRYPTO_TFM_NEED_KEY;
ret = __cpool_try_clone(cpu0_hash);
if (ret)
goto out_free_cpu0_hash;
e->hash = cpu0_hash;
kref_init(&e->kref);
return 0;
out_free_cpu0_hash:
crypto_free_ahash(cpu0_hash);
out_free_alg:
kfree(e->alg);
e->alg = NULL;
return ret;
}
/**
* tcp_sigpool_alloc_ahash - allocates pool for ahash requests
* @alg: name of async hash algorithm
* @scratch_size: reserve a tcp_sigpool::scratch buffer of this size
*/
int tcp_sigpool_alloc_ahash(const char *alg, size_t scratch_size)
{
int i, ret;
/* slow-path */
mutex_lock(&cpool_mutex);
ret = sigpool_reserve_scratch(scratch_size);
if (ret)
goto out;
for (i = 0; i < cpool_populated; i++) {
if (!cpool[i].alg)
continue;
if (strcmp(cpool[i].alg, alg))
continue;
if (kref_read(&cpool[i].kref) > 0)
kref_get(&cpool[i].kref);
else
kref_init(&cpool[i].kref);
ret = i;
goto out;
}
for (i = 0; i < cpool_populated; i++) {
if (!cpool[i].alg)
break;
}
if (i >= CPOOL_SIZE) {
ret = -ENOSPC;
goto out;
}
ret = __cpool_alloc_ahash(&cpool[i], alg);
if (!ret) {
ret = i;
if (i == cpool_populated)
cpool_populated++;
}
out:
mutex_unlock(&cpool_mutex);
return ret;
}
EXPORT_SYMBOL_GPL(tcp_sigpool_alloc_ahash);
static void __cpool_free_entry(struct sigpool_entry *e)
{
crypto_free_ahash(e->hash);
kfree(e->alg);
memset(e, 0, sizeof(*e));
}
static void cpool_cleanup_work_cb(struct work_struct *work)
{
bool free_scratch = true;
unsigned int i;
mutex_lock(&cpool_mutex);
for (i = 0; i < cpool_populated; i++) {
if (kref_read(&cpool[i].kref) > 0) {
free_scratch = false;
continue;
}
if (!cpool[i].alg)
continue;
__cpool_free_entry(&cpool[i]);
}
if (free_scratch)
sigpool_scratch_free();
mutex_unlock(&cpool_mutex);
}
static DECLARE_WORK(cpool_cleanup_work, cpool_cleanup_work_cb);
static void cpool_schedule_cleanup(struct kref *kref)
{
schedule_work(&cpool_cleanup_work);
}
/**
* tcp_sigpool_release - decreases number of users for a pool. If it was
* the last user of the pool, releases any memory that was consumed.
* @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
*/
void tcp_sigpool_release(unsigned int id)
{
if (WARN_ON_ONCE(id > cpool_populated || !cpool[id].alg))
return;
/* slow-path */
kref_put(&cpool[id].kref, cpool_schedule_cleanup);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_release);
/**
* tcp_sigpool_get - increases number of users (refcounter) for a pool
* @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
*/
void tcp_sigpool_get(unsigned int id)
{
if (WARN_ON_ONCE(id > cpool_populated || !cpool[id].alg))
return;
kref_get(&cpool[id].kref);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_get);
int tcp_sigpool_start(unsigned int id, struct tcp_sigpool *c) __cond_acquires(RCU_BH)
{
struct crypto_ahash *hash;
rcu_read_lock_bh();
if (WARN_ON_ONCE(id > cpool_populated || !cpool[id].alg)) {
rcu_read_unlock_bh();
return -EINVAL;
}
hash = crypto_clone_ahash(cpool[id].hash);
if (IS_ERR(hash)) {
rcu_read_unlock_bh();
return PTR_ERR(hash);
}
c->req = ahash_request_alloc(hash, GFP_ATOMIC);
if (!c->req) {
crypto_free_ahash(hash);
rcu_read_unlock_bh();
return -ENOMEM;
}
ahash_request_set_callback(c->req, 0, NULL, NULL);
/* Pairs with tcp_sigpool_reserve_scratch(), scratch area is
* valid (allocated) until tcp_sigpool_end().
*/
c->scratch = rcu_dereference_bh(*this_cpu_ptr(&sigpool_scratch));
return 0;
}
EXPORT_SYMBOL_GPL(tcp_sigpool_start);
void tcp_sigpool_end(struct tcp_sigpool *c) __releases(RCU_BH)
{
struct crypto_ahash *hash = crypto_ahash_reqtfm(c->req);
rcu_read_unlock_bh();
ahash_request_free(c->req);
crypto_free_ahash(hash);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_end);
/**
* tcp_sigpool_algo - return algorithm of tcp_sigpool
* @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
* @buf: buffer to return name of algorithm
* @buf_len: size of @buf
*/
size_t tcp_sigpool_algo(unsigned int id, char *buf, size_t buf_len)
{
if (WARN_ON_ONCE(id > cpool_populated || !cpool[id].alg))
return -EINVAL;
return strscpy(buf, cpool[id].alg, buf_len);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_algo);
/**
* tcp_sigpool_hash_skb_data - hash data in skb with initialized tcp_sigpool
* @hp: tcp_sigpool pointer
* @skb: buffer to add sign for
* @header_len: TCP header length for this segment
*/
int tcp_sigpool_hash_skb_data(struct tcp_sigpool *hp,
const struct sk_buff *skb,
unsigned int header_len)
{
const unsigned int head_data_len = skb_headlen(skb) > header_len ?
skb_headlen(skb) - header_len : 0;
const struct skb_shared_info *shi = skb_shinfo(skb);
const struct tcphdr *tp = tcp_hdr(skb);
struct ahash_request *req = hp->req;
struct sk_buff *frag_iter;
struct scatterlist sg;
unsigned int i;
sg_init_table(&sg, 1);
sg_set_buf(&sg, ((u8 *)tp) + header_len, head_data_len);
ahash_request_set_crypt(req, &sg, NULL, head_data_len);
if (crypto_ahash_update(req))
return 1;
for (i = 0; i < shi->nr_frags; ++i) {
const skb_frag_t *f = &shi->frags[i];
unsigned int offset = skb_frag_off(f);
struct page *page;
page = skb_frag_page(f) + (offset >> PAGE_SHIFT);
sg_set_page(&sg, page, skb_frag_size(f), offset_in_page(offset));
ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f));
if (crypto_ahash_update(req))
return 1;
}
skb_walk_frags(skb, frag_iter)
if (tcp_sigpool_hash_skb_data(hp, frag_iter, 0))
return 1;
return 0;
}
EXPORT_SYMBOL(tcp_sigpool_hash_skb_data);
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Per-CPU pool of crypto requests");
...@@ -671,7 +671,7 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname, ...@@ -671,7 +671,7 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
cmd.tcpm_key, cmd.tcpm_keylen); cmd.tcpm_key, cmd.tcpm_keylen);
} }
static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp, static int tcp_v6_md5_hash_headers(struct tcp_sigpool *hp,
const struct in6_addr *daddr, const struct in6_addr *daddr,
const struct in6_addr *saddr, const struct in6_addr *saddr,
const struct tcphdr *th, int nbytes) const struct tcphdr *th, int nbytes)
...@@ -692,39 +692,36 @@ static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp, ...@@ -692,39 +692,36 @@ static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
_th->check = 0; _th->check = 0;
sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th)); sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th));
ahash_request_set_crypt(hp->md5_req, &sg, NULL, ahash_request_set_crypt(hp->req, &sg, NULL,
sizeof(*bp) + sizeof(*th)); sizeof(*bp) + sizeof(*th));
return crypto_ahash_update(hp->md5_req); return crypto_ahash_update(hp->req);
} }
static int tcp_v6_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, static int tcp_v6_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
const struct in6_addr *daddr, struct in6_addr *saddr, const struct in6_addr *daddr, struct in6_addr *saddr,
const struct tcphdr *th) const struct tcphdr *th)
{ {
struct tcp_md5sig_pool *hp; struct tcp_sigpool hp;
struct ahash_request *req;
hp = tcp_get_md5sig_pool(); if (tcp_sigpool_start(tcp_md5_sigpool_id, &hp))
if (!hp) goto clear_hash_nostart;
goto clear_hash_noput;
req = hp->md5_req;
if (crypto_ahash_init(req)) if (crypto_ahash_init(hp.req))
goto clear_hash; goto clear_hash;
if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2)) if (tcp_v6_md5_hash_headers(&hp, daddr, saddr, th, th->doff << 2))
goto clear_hash; goto clear_hash;
if (tcp_md5_hash_key(hp, key)) if (tcp_md5_hash_key(&hp, key))
goto clear_hash; goto clear_hash;
ahash_request_set_crypt(req, NULL, md5_hash, 0); ahash_request_set_crypt(hp.req, NULL, md5_hash, 0);
if (crypto_ahash_final(req)) if (crypto_ahash_final(hp.req))
goto clear_hash; goto clear_hash;
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
return 0; return 0;
clear_hash: clear_hash:
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
clear_hash_noput: clear_hash_nostart:
memset(md5_hash, 0, 16); memset(md5_hash, 0, 16);
return 1; return 1;
} }
...@@ -734,10 +731,9 @@ static int tcp_v6_md5_hash_skb(char *md5_hash, ...@@ -734,10 +731,9 @@ static int tcp_v6_md5_hash_skb(char *md5_hash,
const struct sock *sk, const struct sock *sk,
const struct sk_buff *skb) const struct sk_buff *skb)
{ {
const struct in6_addr *saddr, *daddr;
struct tcp_md5sig_pool *hp;
struct ahash_request *req;
const struct tcphdr *th = tcp_hdr(skb); const struct tcphdr *th = tcp_hdr(skb);
const struct in6_addr *saddr, *daddr;
struct tcp_sigpool hp;
if (sk) { /* valid for establish/request sockets */ if (sk) { /* valid for establish/request sockets */
saddr = &sk->sk_v6_rcv_saddr; saddr = &sk->sk_v6_rcv_saddr;
...@@ -748,30 +744,28 @@ static int tcp_v6_md5_hash_skb(char *md5_hash, ...@@ -748,30 +744,28 @@ static int tcp_v6_md5_hash_skb(char *md5_hash,
daddr = &ip6h->daddr; daddr = &ip6h->daddr;
} }
hp = tcp_get_md5sig_pool(); if (tcp_sigpool_start(tcp_md5_sigpool_id, &hp))
if (!hp) goto clear_hash_nostart;
goto clear_hash_noput;
req = hp->md5_req;
if (crypto_ahash_init(req)) if (crypto_ahash_init(hp.req))
goto clear_hash; goto clear_hash;
if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, skb->len)) if (tcp_v6_md5_hash_headers(&hp, daddr, saddr, th, skb->len))
goto clear_hash; goto clear_hash;
if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2)) if (tcp_sigpool_hash_skb_data(&hp, skb, th->doff << 2))
goto clear_hash; goto clear_hash;
if (tcp_md5_hash_key(hp, key)) if (tcp_md5_hash_key(&hp, key))
goto clear_hash; goto clear_hash;
ahash_request_set_crypt(req, NULL, md5_hash, 0); ahash_request_set_crypt(hp.req, NULL, md5_hash, 0);
if (crypto_ahash_final(req)) if (crypto_ahash_final(hp.req))
goto clear_hash; goto clear_hash;
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
return 0; return 0;
clear_hash: clear_hash:
tcp_put_md5sig_pool(); tcp_sigpool_end(&hp);
clear_hash_noput: clear_hash_nostart:
memset(md5_hash, 0, 16); memset(md5_hash, 0, 16);
return 1; return 1;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment