Commit 5318d3db authored by Ard Biesheuvel, committed by Herbert Xu

crypto: arm64/aes-ctr - improve tail handling

Counter mode is a stream cipher chaining mode that is typically used
with inputs of arbitrary length, so a tail block smaller than a full
AES block is the rule rather than the exception.
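
As a plain-C illustration of that point (a toy sketch, not the kernel code: toy_block_encrypt and toy_ctr_crypt are made-up names and the block cipher is a dummy), the keystream for the last block is simply truncated to however many bytes remain:

#include <stdint.h>
#include <stdio.h>

#define BLOCK_SIZE 16

/* dummy single-block cipher standing in for AES; not cryptographically meaningful */
static void toy_block_encrypt(const uint8_t in[BLOCK_SIZE], uint8_t out[BLOCK_SIZE])
{
	for (int i = 0; i < BLOCK_SIZE; i++)
		out[i] = (uint8_t)(in[i] ^ 0xa5);
}

/* encrypt (or decrypt) len bytes in CTR mode; len need not be a block multiple */
static void toy_ctr_crypt(uint8_t *dst, const uint8_t *src, size_t len,
			  uint8_t ctr[BLOCK_SIZE])
{
	while (len > 0) {
		uint8_t ks[BLOCK_SIZE];
		size_t n = len < BLOCK_SIZE ? len : BLOCK_SIZE;	/* tail may be short */

		toy_block_encrypt(ctr, ks);
		for (size_t i = 0; i < n; i++)
			dst[i] = src[i] ^ ks[i];

		/* increment the 64-bit big-endian counter held in bytes 8..15 */
		for (int i = BLOCK_SIZE - 1; i >= 8 && ++ctr[i] == 0; i--)
			;

		src += n;
		dst += n;
		len -= n;
	}
}

int main(void)
{
	uint8_t ctr[BLOCK_SIZE] = { 0 };
	uint8_t buf[] = "nineteen byte input";	/* 19 bytes: one full block + 3-byte tail */

	toy_ctr_crypt(buf, buf, 19, ctr);
	printf("%02x .. %02x\n", buf[0], buf[18]);
	return 0;
}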

The current ctr(aes) implementation for arm64 always makes a separate
call into the assembler routine to process this tail block, which is
suboptimal, given that it requires reloading of the AES round keys,
and prevents us from handling this tail block using the 5-way stride
that we use for better performance on deep pipelines.

So let's update the assembler routine so that it can handle any input
size, using NEON permutation instructions and overlapping loads and
stores to handle the tail block. This results in a ~16% speedup for
1420 byte blocks on cores with deep pipelines such as ThunderX2.
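
The following is a rough C sketch of the overlapping-store idea under simplifying assumptions: ks_byte and xor_stream are invented for the example, the keystream is indexed by absolute byte offset rather than generated in AES-sized blocks, and the input is at least one block long. The real routine has to permute the block-based keystream (the tbl against .Lcts_permute_table in the assembler diff below) so the right keystream bytes line up with the final overlapping load and store, and inputs shorter than one block still take a separate path (finalbuf).

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define BLK 16

/* toy keystream: byte i of the stream depends only on its absolute offset */
static uint8_t ks_byte(size_t off)
{
	return (uint8_t)(0x5a ^ (off * 37));
}

/*
 * XOR a keystream into len >= 16 bytes using only full 16-byte "stores".
 * The final (possibly partial) block is covered by a store that ends at
 * dst + len and overlaps the previous block, so no per-byte tail loop is
 * needed.  With this per-offset keystream the overlapped bytes are written
 * with identical values either way; the kernel's block-based keystream
 * instead relies on the later of the two overlapping stores winning.
 */
static void xor_stream(uint8_t *dst, const uint8_t *src, size_t len)
{
	size_t full = (len / BLK) * BLK;
	size_t pos;
	uint8_t prev[BLK], last[BLK];

	/* every full block except the final one */
	for (pos = 0; pos + BLK < full; pos += BLK) {
		for (size_t i = 0; i < BLK; i++)
			dst[pos + i] = src[pos + i] ^ ks_byte(pos + i);
	}

	for (size_t i = 0; i < BLK; i++)
		prev[i] = src[pos + i] ^ ks_byte(pos + i);
	for (size_t i = 0; i < BLK; i++)
		last[i] = src[len - BLK + i] ^ ks_byte(len - BLK + i);

	memcpy(dst + len - BLK, last, BLK);	/* wide store ending at dst + len */
	memcpy(dst + pos, prev, BLK);		/* full final block, may overlap  */
}

int main(void)
{
	uint8_t in[45], out[45];

	for (size_t i = 0; i < sizeof(in); i++)
		in[i] = (uint8_t)i;
	xor_stream(out, in, sizeof(in));	/* 2 full blocks + 13-byte tail */
	printf("%02x .. %02x\n", out[0], out[44]);
	return 0;
}
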
Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent 15deb433
@@ -24,6 +24,7 @@
 #ifdef USE_V8_CRYPTO_EXTENSIONS
 #define MODE "ce"
 #define PRIO 300
+#define STRIDE 5
 #define aes_expandkey ce_aes_expandkey
 #define aes_ecb_encrypt ce_aes_ecb_encrypt
 #define aes_ecb_decrypt ce_aes_ecb_decrypt
@@ -41,6 +42,7 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #else
 #define MODE "neon"
 #define PRIO 200
+#define STRIDE 4
 #define aes_ecb_encrypt neon_aes_ecb_encrypt
 #define aes_ecb_decrypt neon_aes_ecb_decrypt
 #define aes_cbc_encrypt neon_aes_cbc_encrypt
@@ -87,7 +89,7 @@ asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
 				    int rounds, int bytes, u8 const iv[]);
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
-				int rounds, int blocks, u8 ctr[]);
+				int rounds, int bytes, u8 ctr[], u8 finalbuf[]);
 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
 				int rounds, int bytes, u32 const rk2[], u8 iv[],
@@ -448,34 +450,36 @@ static int ctr_encrypt(struct skcipher_request *req)
 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 	int err, rounds = 6 + ctx->key_length / 4;
 	struct skcipher_walk walk;
-	int blocks;
 
 	err = skcipher_walk_virt(&walk, req, false);
 
-	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+	while (walk.nbytes > 0) {
+		const u8 *src = walk.src.virt.addr;
+		unsigned int nbytes = walk.nbytes;
+		u8 *dst = walk.dst.virt.addr;
+		u8 buf[AES_BLOCK_SIZE];
+		unsigned int tail;
+
+		if (unlikely(nbytes < AES_BLOCK_SIZE))
+			src = memcpy(buf, src, nbytes);
+		else if (nbytes < walk.total)
+			nbytes &= ~(AES_BLOCK_SIZE - 1);
+
 		kernel_neon_begin();
-		aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				ctx->key_enc, rounds, blocks, walk.iv);
+		aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
+				walk.iv, buf);
 		kernel_neon_end();
-		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
-	}
-	if (walk.nbytes) {
-		u8 __aligned(8) tail[AES_BLOCK_SIZE];
-		unsigned int nbytes = walk.nbytes;
-		u8 *tdst = walk.dst.virt.addr;
-		u8 *tsrc = walk.src.virt.addr;
-
-		/*
-		 * Tell aes_ctr_encrypt() to process a tail block.
-		 */
-		blocks = -1;
-
-		kernel_neon_begin();
-		aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
-				blocks, walk.iv);
-		kernel_neon_end();
-		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
-		err = skcipher_walk_done(&walk, 0);
+
+		tail = nbytes % (STRIDE * AES_BLOCK_SIZE);
+		if (tail > 0 && tail < AES_BLOCK_SIZE)
+			/*
+			 * The final partial block could not be returned using
+			 * an overlapping store, so it was passed via buf[]
+			 * instead.
+			 */
+			memcpy(dst + nbytes - tail, buf, tail);
+
+		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
 	}
 	return err;
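
To make the buf[] path above concrete (a standalone sketch rather than kernel code, and assuming the whole request is handled in a single walk step): the assembler covers everything it can with overlapping stores, and only a trailing stretch shorter than one block, which has no earlier output of its own to overlap with, comes back through buf[]:

#include <stdio.h>

#define AES_BLOCK_SIZE	16
#define STRIDE		5	/* Crypto Extensions build; 4 for the NEON fallback */

/* bytes the caller has to copy out of buf[] after aes_ctr_encrypt() returns */
static unsigned int buf_bytes(unsigned int nbytes)
{
	unsigned int tail = nbytes % (STRIDE * AES_BLOCK_SIZE);

	return (tail > 0 && tail < AES_BLOCK_SIZE) ? tail : 0;
}

int main(void)
{
	/* 1420 is the packet-sized length quoted in the commit message */
	unsigned int sizes[] = { 16, 73, 92, 1420 };

	for (int i = 0; i < 4; i++)
		printf("%4u bytes -> %u byte(s) via buf[]\n",
		       sizes[i], buf_bytes(sizes[i]));
	return 0;
}
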
@@ -321,42 +321,76 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
 	/*
 	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-	 *		   int blocks, u8 ctr[])
+	 *		   int bytes, u8 ctr[], u8 finalbuf[])
 	 */
 
 AES_FUNC_START(aes_ctr_encrypt)
 	stp		x29, x30, [sp, #-16]!
 	mov		x29, sp
 
-	enc_prepare	w3, x2, x6
+	enc_prepare	w3, x2, x12
 	ld1		{vctr.16b}, [x5]
 
-	umov		x6, vctr.d[1]		/* keep swabbed ctr in reg */
-	rev		x6, x6
-	cmn		w6, w4			/* 32 bit overflow? */
-	bcs		.Lctrloop
+	umov		x12, vctr.d[1]		/* keep swabbed ctr in reg */
+	rev		x12, x12
+
 .LctrloopNx:
-	subs		w4, w4, #MAX_STRIDE
-	bmi		.Lctr1x
-	add		w7, w6, #1
+	add		w7, w4, #15
+	sub		w4, w4, #MAX_STRIDE << 4
+	lsr		w7, w7, #4
+	mov		w8, #MAX_STRIDE
+	cmp		w7, w8
+	csel		w7, w7, w8, lt
+	adds		x12, x12, x7
+
 	mov		v0.16b, vctr.16b
-	add		w8, w6, #2
 	mov		v1.16b, vctr.16b
-	add		w9, w6, #3
 	mov		v2.16b, vctr.16b
-	add		w9, w6, #3
-	rev		w7, w7
 	mov		v3.16b, vctr.16b
-	rev		w8, w8
 ST5(	mov		v4.16b, vctr.16b	)
-	mov		v1.s[3], w7
-	rev		w9, w9
-ST5(	add		w10, w6, #4		)
-	mov		v2.s[3], w8
-ST5(	rev		w10, w10		)
-	mov		v3.s[3], w9
-ST5(	mov		v4.s[3], w10		)
-	ld1		{v5.16b-v7.16b}, [x1], #48	/* get 3 input blocks */
+	bcs		0f
+
+	.subsection	1
+	/* apply carry to outgoing counter */
+0:	umov		x8, vctr.d[0]
+	rev		x8, x8
+	add		x8, x8, #1
+	rev		x8, x8
+	ins		vctr.d[0], x8
+
+	/* apply carry to N counter blocks for N := x12 */
+	adr		x16, 1f
+	sub		x16, x16, x12, lsl #3
+	br		x16
+	hint		34			// bti c
+	mov		v0.d[0], vctr.d[0]
+	hint		34			// bti c
+	mov		v1.d[0], vctr.d[0]
+	hint		34			// bti c
+	mov		v2.d[0], vctr.d[0]
+	hint		34			// bti c
+	mov		v3.d[0], vctr.d[0]
+ST5(	hint		34			)
+ST5(	mov		v4.d[0], vctr.d[0]	)
+1:	b		2f
+	.previous
+
+2:	rev		x7, x12
+	ins		vctr.d[1], x7
+	sub		x7, x12, #MAX_STRIDE - 1
+	sub		x8, x12, #MAX_STRIDE - 2
+	sub		x9, x12, #MAX_STRIDE - 3
+	rev		x7, x7
+	rev		x8, x8
+	mov		v1.d[1], x7
+	rev		x9, x9
+ST5(	sub		x10, x12, #MAX_STRIDE - 4	)
+	mov		v2.d[1], x8
+ST5(	rev		x10, x10		)
+	mov		v3.d[1], x9
+ST5(	mov		v4.d[1], x10		)
+	tbnz		w4, #31, .Lctrtail
+	ld1		{v5.16b-v7.16b}, [x1], #48
 ST4(	bl		aes_encrypt_block4x		)
 ST5(	bl		aes_encrypt_block5x		)
 	eor		v0.16b, v5.16b, v0.16b
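
For reference, the carry path in this hunk first patches the outgoing counter's high word, then branches into a table of two-instruction (8-byte) slots: sub x16, x16, x12, lsl #3 followed by br x16 executes only the last x12 slots, so exactly the counter blocks whose low word wrapped receive the carried high word. A loose C equivalent of that computed branch (illustrative only, written for the 5-way build; apply_carry and struct ctr_block are not kernel names) is a switch that falls through:

#include <stdint.h>

#define MAX_STRIDE 5	/* Crypto Extensions build */

struct ctr_block {
	uint64_t hi;	/* what the assembler keeps in d[0] */
	uint64_t lo;	/* what the assembler keeps in d[1] */
};

/*
 * Patch the carried high word into the last n counter blocks of the stride,
 * i.e. the ones whose low word wrapped past zero.  n is at most
 * MAX_STRIDE - 1: the first block reuses the previous counter value, so it
 * can never be the one that wraps (case 5 mirrors the v0 slot in the table
 * and, like it, is never reached in practice).
 */
void apply_carry(struct ctr_block blk[MAX_STRIDE], uint64_t ctr_hi,
		 unsigned int n)
{
	switch (n) {
	case 5: blk[0].hi = ctr_hi;	/* fallthrough */
	case 4: blk[1].hi = ctr_hi;	/* fallthrough */
	case 3: blk[2].hi = ctr_hi;	/* fallthrough */
	case 2: blk[3].hi = ctr_hi;	/* fallthrough */
	case 1: blk[4].hi = ctr_hi;	/* fallthrough */
	case 0: break;
	}
}

The hint 34 (bti c) landing pads exist so that the computed br remains valid when branch target identification is enabled.
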
@@ -368,47 +402,72 @@ ST5(	ld1		{v5.16b-v6.16b}, [x1], #32	)
 ST5(	eor		v4.16b, v6.16b, v4.16b	)
 	st1		{v0.16b-v3.16b}, [x0], #64
 ST5(	st1		{v4.16b}, [x0], #16	)
-	add		x6, x6, #MAX_STRIDE
-	rev		x7, x6
-	ins		vctr.d[1], x7
 	cbz		w4, .Lctrout
 	b		.LctrloopNx
-.Lctr1x:
-	adds		w4, w4, #MAX_STRIDE
-	beq		.Lctrout
-.Lctrloop:
-	mov		v0.16b, vctr.16b
-	encrypt_block	v0, w3, x2, x8, w7
-
-	adds		x6, x6, #1		/* increment BE ctr */
-	rev		x7, x6
-	ins		vctr.d[1], x7
-	bcs		.Lctrcarry		/* overflow? */
-
-.Lctrcarrydone:
-	subs		w4, w4, #1
-	bmi		.Lctrtailblock		/* blocks <0 means tail block */
-	ld1		{v3.16b}, [x1], #16
-	eor		v3.16b, v0.16b, v3.16b
-	st1		{v3.16b}, [x0], #16
-	bne		.Lctrloop
 
 .Lctrout:
 	st1		{vctr.16b}, [x5]	/* return next CTR value */
 	ldp		x29, x30, [sp], #16
 	ret
 
-.Lctrtailblock:
-	st1		{v0.16b}, [x0]
+.Lctrtail:
+	/* XOR up to MAX_STRIDE * 16 - 1 bytes of in/output with v0 ... v3/v4 */
+	mov		x16, #16
+	ands		x13, x4, #0xf
+	csel		x13, x13, x16, ne
+
+ST5(	cmp		w4, #64 - (MAX_STRIDE << 4)	)
+ST5(	csel		x14, x16, xzr, gt	)
+	cmp		w4, #48 - (MAX_STRIDE << 4)
+	csel		x15, x16, xzr, gt
+	cmp		w4, #32 - (MAX_STRIDE << 4)
+	csel		x16, x16, xzr, gt
+	cmp		w4, #16 - (MAX_STRIDE << 4)
+	ble		.Lctrtail1x
+
+	adr_l		x12, .Lcts_permute_table
+	add		x12, x12, x13
+
+ST5(	ld1		{v5.16b}, [x1], x14	)
+	ld1		{v6.16b}, [x1], x15
+	ld1		{v7.16b}, [x1], x16
+
+ST4(	bl		aes_encrypt_block4x	)
+ST5(	bl		aes_encrypt_block5x	)
+
+	ld1		{v8.16b}, [x1], x13
+	ld1		{v9.16b}, [x1]
+	ld1		{v10.16b}, [x12]
+
+ST4(	eor		v6.16b, v6.16b, v0.16b	)
+ST4(	eor		v7.16b, v7.16b, v1.16b	)
+ST4(	tbl		v3.16b, {v3.16b}, v10.16b	)
+ST4(	eor		v8.16b, v8.16b, v2.16b	)
+ST4(	eor		v9.16b, v9.16b, v3.16b	)
+
+ST5(	eor		v5.16b, v5.16b, v0.16b	)
+ST5(	eor		v6.16b, v6.16b, v1.16b	)
+ST5(	tbl		v4.16b, {v4.16b}, v10.16b	)
+ST5(	eor		v7.16b, v7.16b, v2.16b	)
+ST5(	eor		v8.16b, v8.16b, v3.16b	)
+ST5(	eor		v9.16b, v9.16b, v4.16b	)
+
+ST5(	st1		{v5.16b}, [x0], x14	)
+	st1		{v6.16b}, [x0], x15
+	st1		{v7.16b}, [x0], x16
+	add		x13, x13, x0
+	st1		{v9.16b}, [x13]		// overlapping stores
+	st1		{v8.16b}, [x0]
 	b		.Lctrout
 
-.Lctrcarry:
-	umov		x7, vctr.d[0]		/* load upper word of ctr */
-	rev		x7, x7			/* ... to handle the carry */
-	add		x7, x7, #1
-	rev		x7, x7
-	ins		vctr.d[0], x7
-	b		.Lctrcarrydone
+.Lctrtail1x:
+	csel		x0, x0, x6, eq		// use finalbuf if less than a full block
+	ld1		{v5.16b}, [x1]
+ST5(	mov		v3.16b, v4.16b	)
+	encrypt_block	v3, w3, x2, x8, w7
+	eor		v5.16b, v5.16b, v3.16b
+	st1		{v5.16b}, [x0]
+	b		.Lctrout
 AES_FUNC_END(aes_ctr_encrypt)
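
Summarising the control flow of the routine above in C terms (illustrative only; classify() and the enum are not kernel code, and MAX_STRIDE is 5 for the Crypto Extensions build, 4 for the NEON fallback): each loop iteration subtracts MAX_STRIDE * 16 from the byte count and then takes one of the following paths for what is left.

#define MAX_STRIDE 5	/* Crypto Extensions build; 4 for the NEON fallback */

enum ctr_path {
	FULL_STRIDE,		 /* main .LctrloopNx body               */
	OVERLAPPING_STORES,	 /* .Lctrtail                           */
	SINGLE_BLOCK_TO_OUT,	 /* .Lctrtail1x, "eq": exactly 16 bytes */
	SINGLE_BLOCK_TO_FINALBUF /* .Lctrtail1x: partial block          */
};

/* bytes_left is what this iteration of the loop still has to process (> 0) */
enum ctr_path classify(int bytes_left)
{
	if (bytes_left >= MAX_STRIDE * 16)
		return FULL_STRIDE;
	if (bytes_left > 16)
		return OVERLAPPING_STORES;
	if (bytes_left == 16)
		return SINGLE_BLOCK_TO_OUT;
	return SINGLE_BLOCK_TO_FINALBUF;
}

The last case is the one the glue code's memcpy from buf[] cleans up: with less than one block left there is no earlier output from this call to overlap with, so the result is written to finalbuf (x6) instead of out (x0).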