crypto: arm64/aes-xctr - Add accelerated implementation of XCTR

Add a hardware-accelerated version of XCTR for ARM64 CPUs with ARMv8
Crypto Extensions support.  This XCTR implementation is based on the CTR
implementation in aes-modes.S.

More information on XCTR can be found in the HCTR2 paper,
"Length-preserving encryption with HCTR2":
https://eprint.iacr.org/2021/1441.pdf
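
For reference, XCTR replaces CTR's big-endian counter increment with an XOR
of a little-endian block number (starting at 1) into the IV, which is what
lets the assembly below skip CTR's carry handling.  A rough C model of the
keystream follows; it is not part of this patch, and the one-block helper
aes_encrypt_block() is hypothetical:

/*
 * Sketch of the XCTR keystream from the HCTR2 paper: keystream block i
 * (1-based) is AES_K(IV xor le(i)), and the output is plaintext XOR
 * keystream.  Partial final blocks are omitted for brevity.
 */
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define AES_BLOCK_SIZE 16

/* assumed one-block AES primitive, standing in for the real cipher */
void aes_encrypt_block(const void *key, const uint8_t in[AES_BLOCK_SIZE],
		       uint8_t out[AES_BLOCK_SIZE]);

static void xctr_crypt(const void *key, const uint8_t iv[AES_BLOCK_SIZE],
		       const uint8_t *src, uint8_t *dst, size_t nblocks)
{
	uint8_t ctrblk[AES_BLOCK_SIZE], ks[AES_BLOCK_SIZE];

	for (uint64_t i = 1; i <= nblocks; i++) {
		memcpy(ctrblk, iv, AES_BLOCK_SIZE);
		/* XOR the little-endian block number into the low bytes */
		for (int b = 0; b < 8; b++)
			ctrblk[b] ^= (uint8_t)(i >> (8 * b));

		aes_encrypt_block(key, ctrblk, ks);
		for (int b = 0; b < AES_BLOCK_SIZE; b++)
			dst[b] = src[b] ^ ks[b];
		src += AES_BLOCK_SIZE;
		dst += AES_BLOCK_SIZE;
	}
}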

Signed-off-by: Nathan Huckleberry <nhuck@google.com>
Reviewed-by: Ard Biesheuvel <ardb@kernel.org>
Reviewed-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
Author:     Nathan Huckleberry <nhuck@google.com>
Date:       2022-05-20 18:14:57 +00:00
Committer:  Herbert Xu
Commit:     23a251cc16
Parent:     fd94fcf099

3 changed files with 166 additions and 64 deletions

--- a/arch/arm64/crypto/Kconfig
+++ b/arch/arm64/crypto/Kconfig

@@ -96,13 +96,13 @@ config CRYPTO_AES_ARM64_CE_CCM
 	select CRYPTO_LIB_AES
 
 config CRYPTO_AES_ARM64_CE_BLK
-	tristate "AES in ECB/CBC/CTR/XTS modes using ARMv8 Crypto Extensions"
+	tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using ARMv8 Crypto Extensions"
 	depends on KERNEL_MODE_NEON
 	select CRYPTO_SKCIPHER
 	select CRYPTO_AES_ARM64_CE
 
 config CRYPTO_AES_ARM64_NEON_BLK
-	tristate "AES in ECB/CBC/CTR/XTS modes using NEON instructions"
+	tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using NEON instructions"
 	depends on KERNEL_MODE_NEON
 	select CRYPTO_SKCIPHER
 	select CRYPTO_LIB_AES

--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c

@@ -34,10 +34,11 @@
 #define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
 #define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
 #define aes_ctr_encrypt		ce_aes_ctr_encrypt
+#define aes_xctr_encrypt	ce_aes_xctr_encrypt
 #define aes_xts_encrypt		ce_aes_xts_encrypt
 #define aes_xts_decrypt		ce_aes_xts_decrypt
 #define aes_mac_update		ce_aes_mac_update
-MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
+MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
 #else
 #define MODE			"neon"
 #define PRIO			200
@@ -50,16 +51,18 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
 #define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
 #define aes_ctr_encrypt		neon_aes_ctr_encrypt
+#define aes_xctr_encrypt	neon_aes_xctr_encrypt
 #define aes_xts_encrypt		neon_aes_xts_encrypt
 #define aes_xts_decrypt		neon_aes_xts_decrypt
 #define aes_mac_update		neon_aes_mac_update
-MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
+MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
 #endif
 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
 MODULE_ALIAS_CRYPTO("ecb(aes)");
 MODULE_ALIAS_CRYPTO("cbc(aes)");
 MODULE_ALIAS_CRYPTO("ctr(aes)");
 MODULE_ALIAS_CRYPTO("xts(aes)");
+MODULE_ALIAS_CRYPTO("xctr(aes)");
 #endif
 MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
 MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
@@ -89,6 +92,9 @@ asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int bytes, u8 ctr[]);
 
+asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
+				 int rounds, int bytes, u8 ctr[], int byte_ctr);
+
 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
 				int rounds, int bytes, u32 const rk2[], u8 iv[],
 				int first);
@@ -442,6 +448,44 @@ static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
 	return err ?: cbc_decrypt_walk(req, &walk);
 }
 
+static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int err, rounds = 6 + ctx->key_length / 4;
+	struct skcipher_walk walk;
+	unsigned int byte_ctr = 0;
+
+	err = skcipher_walk_virt(&walk, req, false);
+
+	while (walk.nbytes > 0) {
+		const u8 *src = walk.src.virt.addr;
+		unsigned int nbytes = walk.nbytes;
+		u8 *dst = walk.dst.virt.addr;
+		u8 buf[AES_BLOCK_SIZE];
+
+		if (unlikely(nbytes < AES_BLOCK_SIZE))
+			src = dst = memcpy(buf + sizeof(buf) - nbytes,
+					   src, nbytes);
+		else if (nbytes < walk.total)
+			nbytes &= ~(AES_BLOCK_SIZE - 1);
+
+		kernel_neon_begin();
+		aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
+				 walk.iv, byte_ctr);
+		kernel_neon_end();
+
+		if (unlikely(nbytes < AES_BLOCK_SIZE))
+			memcpy(walk.dst.virt.addr,
+			       buf + sizeof(buf) - nbytes, nbytes);
+		byte_ctr += nbytes;
+
+		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+	}
+
+	return err;
+}
+
 static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
 {
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -669,6 +713,22 @@ static struct skcipher_alg aes_algs[] = { {
 	.setkey		= skcipher_aes_setkey,
 	.encrypt	= ctr_encrypt,
 	.decrypt	= ctr_encrypt,
+}, {
+	.base = {
+		.cra_name		= "xctr(aes)",
+		.cra_driver_name	= "xctr-aes-" MODE,
+		.cra_priority		= PRIO,
+		.cra_blocksize		= 1,
+		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
+		.cra_module		= THIS_MODULE,
+	},
+	.min_keysize	= AES_MIN_KEY_SIZE,
+	.max_keysize	= AES_MAX_KEY_SIZE,
+	.ivsize		= AES_BLOCK_SIZE,
+	.chunksize	= AES_BLOCK_SIZE,
+	.setkey		= skcipher_aes_setkey,
+	.encrypt	= xctr_encrypt,
+	.decrypt	= xctr_encrypt,
 }, {
 	.base = {
 		.cra_name		= "xts(aes)",
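
The "xctr(aes)" entry registered above is what the crypto core matches when a
user (HCTR2 in particular) asks for this mode.  A hedged usage sketch against
the standard skcipher API, not part of this patch and with a made-up function
name, would be:

/*
 * Minimal sketch: allocate the highest-priority "xctr(aes)" skcipher,
 * which resolves to "xctr-aes-ce" or "xctr-aes-neon" on arm64 when this
 * module is loaded.  Key/request setup is elided.
 */
#include <linux/err.h>
#include <crypto/skcipher.h>

static int xctr_aes_demo(void)
{
	struct crypto_skcipher *tfm;

	tfm = crypto_alloc_skcipher("xctr(aes)", 0, 0);
	if (IS_ERR(tfm))
		return PTR_ERR(tfm);

	/*
	 * crypto_skcipher_setkey(), skcipher_request_alloc() and
	 * crypto_skcipher_encrypt() follow, exactly as for ctr(aes).
	 */

	crypto_free_skcipher(tfm);
	return 0;
}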

--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S

@@ -318,79 +318,102 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
 	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
 	.previous
 
 	/*
-	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-	 *		   int bytes, u8 ctr[])
+	 * This macro generates the code for CTR and XCTR mode.
 	 */
 
-AES_FUNC_START(aes_ctr_encrypt)
+.macro ctr_encrypt xctr
 	stp		x29, x30, [sp, #-16]!
 	mov		x29, sp
 
 	enc_prepare	w3, x2, x12
 	ld1		{vctr.16b}, [x5]
 
-	umov		x12, vctr.d[1] /* keep swabbed ctr in reg */
-	rev		x12, x12
+	.if \xctr
+		umov		x12, vctr.d[0]
+		lsr		w11, w6, #4
+	.else
+		umov		x12, vctr.d[1] /* keep swabbed ctr in reg */
+		rev		x12, x12
+	.endif
 
-.LctrloopNx:
+.LctrloopNx\xctr:
 	add		w7, w4, #15
 	sub		w4, w4, #MAX_STRIDE << 4
 	lsr		w7, w7, #4
 	mov		w8, #MAX_STRIDE
 	cmp		w7, w8
 	csel		w7, w7, w8, lt
-	adds		x12, x12, x7
+
+	.if \xctr
+		add		x11, x11, x7
+	.else
+		adds		x12, x12, x7
+	.endif
 
 	mov		v0.16b, vctr.16b
 	mov		v1.16b, vctr.16b
 	mov		v2.16b, vctr.16b
 	mov		v3.16b, vctr.16b
 ST5(	mov		v4.16b, vctr.16b	)
-	bcs		0f
-
-	.subsection	1
-	/* apply carry to outgoing counter */
-0:	umov		x8, vctr.d[0]
-	rev		x8, x8
-	add		x8, x8, #1
-	rev		x8, x8
-	ins		vctr.d[0], x8
-
-	/* apply carry to N counter blocks for N := x12 */
-	cbz		x12, 2f
-	adr		x16, 1f
-	sub		x16, x16, x12, lsl #3
-	br		x16
-	bti		c
-	mov		v0.d[0], vctr.d[0]
-	bti		c
-	mov		v1.d[0], vctr.d[0]
-	bti		c
-	mov		v2.d[0], vctr.d[0]
-	bti		c
-	mov		v3.d[0], vctr.d[0]
-ST5(	bti		c		)
-ST5(	mov		v4.d[0], vctr.d[0]	)
-1:	b		2f
-	.previous
-
-2:	rev		x7, x12
-	ins		vctr.d[1], x7
-	sub		x7, x12, #MAX_STRIDE - 1
-	sub		x8, x12, #MAX_STRIDE - 2
-	sub		x9, x12, #MAX_STRIDE - 3
-	rev		x7, x7
-	rev		x8, x8
-	mov		v1.d[1], x7
-	rev		x9, x9
-ST5(	sub		x10, x12, #MAX_STRIDE - 4	)
-	mov		v2.d[1], x8
-ST5(	rev		x10, x10	)
-	mov		v3.d[1], x9
-ST5(	mov		v4.d[1], x10	)
-	tbnz		w4, #31, .Lctrtail
+	.if \xctr
+		sub		x6, x11, #MAX_STRIDE - 1
+		sub		x7, x11, #MAX_STRIDE - 2
+		sub		x8, x11, #MAX_STRIDE - 3
+		sub		x9, x11, #MAX_STRIDE - 4
+ST5(		sub		x10, x11, #MAX_STRIDE - 5	)
+
+		eor		x6, x6, x12
+		eor		x7, x7, x12
+		eor		x8, x8, x12
+		eor		x9, x9, x12
+ST5(		eor		x10, x10, x12	)
+
+		mov		v0.d[0], x6
+		mov		v1.d[0], x7
+		mov		v2.d[0], x8
+		mov		v3.d[0], x9
+ST5(		mov		v4.d[0], x10	)
+	.else
+		bcs		0f
+
+		.subsection	1
+		/* apply carry to outgoing counter */
+0:		umov		x8, vctr.d[0]
+		rev		x8, x8
+		add		x8, x8, #1
+		rev		x8, x8
+		ins		vctr.d[0], x8
+
+		/* apply carry to N counter blocks for N := x12 */
+		cbz		x12, 2f
+		adr		x16, 1f
+		sub		x16, x16, x12, lsl #3
+		br		x16
+		bti		c
+		mov		v0.d[0], vctr.d[0]
+		bti		c
+		mov		v1.d[0], vctr.d[0]
+		bti		c
+		mov		v2.d[0], vctr.d[0]
+		bti		c
+		mov		v3.d[0], vctr.d[0]
+ST5(		bti		c		)
+ST5(		mov		v4.d[0], vctr.d[0]	)
+1:		b		2f
+		.previous
+
+2:		rev		x7, x12
+		ins		vctr.d[1], x7
+		sub		x7, x12, #MAX_STRIDE - 1
+		sub		x8, x12, #MAX_STRIDE - 2
+		sub		x9, x12, #MAX_STRIDE - 3
+		rev		x7, x7
+		rev		x8, x8
+		mov		v1.d[1], x7
+		rev		x9, x9
+ST5(		sub		x10, x12, #MAX_STRIDE - 4	)
+		mov		v2.d[1], x8
+ST5(		rev		x10, x10	)
+		mov		v3.d[1], x9
+ST5(		mov		v4.d[1], x10	)
+	.endif
+
+	tbnz		w4, #31, .Lctrtail\xctr
 	ld1		{v5.16b-v7.16b}, [x1], #48
 ST4(	bl		aes_encrypt_block4x		)
 ST5(	bl		aes_encrypt_block5x		)
@@ -403,16 +426,17 @@ ST5(	ld1		{v5.16b-v6.16b}, [x1], #32	)
 ST5(	eor		v4.16b, v6.16b, v4.16b		)
 	st1		{v0.16b-v3.16b}, [x0], #64
 ST5(	st1		{v4.16b}, [x0], #16		)
-	cbz		w4, .Lctrout
-	b		.LctrloopNx
+	cbz		w4, .Lctrout\xctr
+	b		.LctrloopNx\xctr
 
-.Lctrout:
-	st1		{vctr.16b}, [x5]	/* return next CTR value */
+.Lctrout\xctr:
+	.if !\xctr
+		st1		{vctr.16b}, [x5]	/* return next CTR value */
+	.endif
 	ldp		x29, x30, [sp], #16
 	ret
 
-.Lctrtail:
-	/* XOR up to MAX_STRIDE * 16 - 1 bytes of in/output with v0 ... v3/v4 */
+.Lctrtail\xctr:
 	mov		x16, #16
 	ands		x6, x4, #0xf
 	csel		x13, x6, x16, ne
@@ -427,7 +451,7 @@ ST5(	csel		x14, x16, xzr, gt	)
 
 	adr_l		x12, .Lcts_permute_table
 	add		x12, x12, x13
-	ble		.Lctrtail1x
+	ble		.Lctrtail1x\xctr
 
 ST5(	ld1		{v5.16b}, [x1], x14		)
 	ld1		{v6.16b}, [x1], x15
add x13, x13, x0 add x13, x13, x0
st1 {v9.16b}, [x13] // overlapping stores st1 {v9.16b}, [x13] // overlapping stores
st1 {v8.16b}, [x0] st1 {v8.16b}, [x0]
b .Lctrout b .Lctrout\xctr
.Lctrtail1x: .Lctrtail1x\xctr:
sub x7, x6, #16 sub x7, x6, #16
csel x6, x6, x7, eq csel x6, x6, x7, eq
add x1, x1, x6 add x1, x1, x6
@@ -476,9 +500,27 @@ ST5(	mov		v3.16b, v4.16b	)
 	eor		v5.16b, v5.16b, v3.16b
 	bif		v5.16b, v6.16b, v11.16b
 	st1		{v5.16b}, [x0]
-	b		.Lctrout
+	b		.Lctrout\xctr
+.endm
+
+	/*
+	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+	 *		   int bytes, u8 ctr[])
+	 */
+
+AES_FUNC_START(aes_ctr_encrypt)
+	ctr_encrypt 0
 AES_FUNC_END(aes_ctr_encrypt)
+
+	/*
+	 * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+	 *		   int bytes, u8 const iv[], int byte_ctr)
+	 */
+
+AES_FUNC_START(aes_xctr_encrypt)
+	ctr_encrypt 1
+AES_FUNC_END(aes_xctr_encrypt)
 
 	/*
 	 * aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[], int rounds,
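
To make the .if \xctr branches above easier to follow: the XCTR path never
propagates a carry, it only XORs the 1-based block number into the low 64
bits of the IV for each of the MAX_STRIDE blocks (x12 holds the IV's low
half, x11 the running block count derived from byte_ctr).  A C model of one
counter block, a sketch only and assuming a little-endian host as on arm64
kernels, is:

/*
 * Sketch (not kernel code): build the block that the XCTR path feeds into
 * AES for block n (counted from 0) of a call that started at byte_ctr.
 */
#include <stdint.h>
#include <string.h>

static void xctr_counter_block(const uint8_t iv[16], uint64_t byte_ctr,
			       uint64_t n, uint8_t out[16])
{
	uint64_t iv_lo;
	uint64_t blk = byte_ctr / 16 + n + 1;	/* XCTR block numbers start at 1 */

	memcpy(&iv_lo, iv, sizeof(iv_lo));	/* vctr.d[0] on a little-endian CPU */
	iv_lo ^= blk;				/* the "eor xN, xN, x12" step */
	memcpy(out, &iv_lo, sizeof(iv_lo));
	memcpy(out + 8, iv + 8, 8);		/* high half of the IV is copied as-is */
}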