diff mbox series

[RFC/RFT] crypto: arm64/aes-ce - add support for CTS-CBC mode

Message ID 20180908114213.9839-1-ard.biesheuvel@linaro.org
State New
Headers show
Series [RFC/RFT] crypto: arm64/aes-ce - add support for CTS-CBC mode | expand

Commit Message

Ard Biesheuvel Sept. 8, 2018, 11:42 a.m. UTC
Currently, we rely on the generic CTS chaining mode wrapper to
instantiate the cts(cbc(aes)) skcipher. Due to the high performance
of the ARMv8 Crypto Extensions AES instructions (~1 cycles per byte),
any overhead in the chaining mode layers is amplified, and so it pays
off considerably to fold the CTS handling into the core algorithm.

On Cortex-A53, this results in a ~50% speedup for smaller block sizes.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>

---
Raw performance numbers after the patch.

 arch/arm64/crypto/aes-glue.c  | 142 ++++++++++++++++++++
 arch/arm64/crypto/aes-modes.S |  73 ++++++++++
 2 files changed, 215 insertions(+)

-- 
2.18.0

Cortex-A53 @ 1 GHz

BEFORE:

testing speed of async cts(cbc(aes)) (cts(cbc-aes-ce)) encryption
 0 (128 bit key,   16 byte blocks): 1407866 ops in 1 secs ( 22525856 bytes)
 1 (128 bit key,   64 byte blocks):  466814 ops in 1 secs ( 29876096 bytes)
 2 (128 bit key,  256 byte blocks):  401023 ops in 1 secs (102661888 bytes)
 3 (128 bit key, 1024 byte blocks):  258238 ops in 1 secs (264435712 bytes)
 4 (128 bit key, 8192 byte blocks):   57905 ops in 1 secs (474357760 bytes)
 5 (192 bit key,   16 byte blocks): 1388333 ops in 1 secs ( 22213328 bytes)
 6 (192 bit key,   64 byte blocks):  448595 ops in 1 secs ( 28710080 bytes)
 7 (192 bit key,  256 byte blocks):  376951 ops in 1 secs ( 96499456 bytes)
 8 (192 bit key, 1024 byte blocks):  231635 ops in 1 secs (237194240 bytes)
 9 (192 bit key, 8192 byte blocks):   43345 ops in 1 secs (355082240 bytes)
10 (256 bit key,   16 byte blocks): 1370820 ops in 1 secs ( 21933120 bytes)
11 (256 bit key,   64 byte blocks):  452352 ops in 1 secs ( 28950528 bytes)
12 (256 bit key,  256 byte blocks):  376506 ops in 1 secs ( 96385536 bytes)
13 (256 bit key, 1024 byte blocks):  223219 ops in 1 secs (228576256 bytes)
14 (256 bit key, 8192 byte blocks):   44874 ops in 1 secs (367607808 bytes)

testing speed of async cts(cbc(aes)) (cts(cbc-aes-ce)) decryption
 0 (128 bit key,   16 byte blocks): 1402795 ops in 1 secs ( 22444720 bytes)
 1 (128 bit key,   64 byte blocks):  403300 ops in 1 secs ( 25811200 bytes)
 2 (128 bit key,  256 byte blocks):  367710 ops in 1 secs ( 94133760 bytes)
 3 (128 bit key, 1024 byte blocks):  269118 ops in 1 secs (275576832 bytes)
 4 (128 bit key, 8192 byte blocks):   74706 ops in 1 secs (611991552 bytes)
 5 (192 bit key,   16 byte blocks): 1381390 ops in 1 secs ( 22102240 bytes)
 6 (192 bit key,   64 byte blocks):  388555 ops in 1 secs ( 24867520 bytes)
 7 (192 bit key,  256 byte blocks):  350282 ops in 1 secs ( 89672192 bytes)
 8 (192 bit key, 1024 byte blocks):  251268 ops in 1 secs (257298432 bytes)
 9 (192 bit key, 8192 byte blocks):   56535 ops in 1 secs (463134720 bytes)
10 (256 bit key,   16 byte blocks): 1364334 ops in 1 secs ( 21829344 bytes)
11 (256 bit key,   64 byte blocks):  392610 ops in 1 secs ( 25127040 bytes)
12 (256 bit key,  256 byte blocks):  351150 ops in 1 secs ( 89894400 bytes)
13 (256 bit key, 1024 byte blocks):  247455 ops in 1 secs (253393920 bytes)
14 (256 bit key, 8192 byte blocks):   62530 ops in 1 secs (512245760 bytes)

AFTER:

testing speed of async cts(cbc(aes)) (cts-cbc-aes-ce) encryption
 0 (128 bit key,   16 byte blocks): 1380568 ops in 1 secs ( 22089088 bytes)
 1 (128 bit key,   64 byte blocks):  692731 ops in 1 secs ( 44334784 bytes)
 2 (128 bit key,  256 byte blocks):  556393 ops in 1 secs (142436608 bytes)
 3 (128 bit key, 1024 byte blocks):  314635 ops in 1 secs (322186240 bytes)
 4 (128 bit key, 8192 byte blocks):   57550 ops in 1 secs (471449600 bytes)
 5 (192 bit key,   16 byte blocks): 1367027 ops in 1 secs ( 21872432 bytes)
 6 (192 bit key,   64 byte blocks):  675058 ops in 1 secs ( 43203712 bytes)
 7 (192 bit key,  256 byte blocks):  523177 ops in 1 secs (133933312 bytes)
 8 (192 bit key, 1024 byte blocks):  279235 ops in 1 secs (285936640 bytes)
 9 (192 bit key, 8192 byte blocks):   46316 ops in 1 secs (379420672 bytes)
10 (256 bit key,   16 byte blocks): 1353576 ops in 1 secs ( 21657216 bytes)
11 (256 bit key,   64 byte blocks):  664523 ops in 1 secs ( 42529472 bytes)
12 (256 bit key,  256 byte blocks):  508141 ops in 1 secs (130084096 bytes)
13 (256 bit key, 1024 byte blocks):  264386 ops in 1 secs (270731264 bytes)
14 (256 bit key, 8192 byte blocks):   47224 ops in 1 secs (386859008 bytes)

testing speed of async cts(cbc(aes)) (cts-cbc-aes-ce) decryption
 0 (128 bit key,   16 byte blocks): 1388553 ops in 1 secs ( 22216848 bytes)
 1 (128 bit key,   64 byte blocks):  688402 ops in 1 secs ( 44057728 bytes)
 2 (128 bit key,  256 byte blocks):  589268 ops in 1 secs (150852608 bytes)
 3 (128 bit key, 1024 byte blocks):  372238 ops in 1 secs (381171712 bytes)
 4 (128 bit key, 8192 byte blocks):   75691 ops in 1 secs (620060672 bytes)
 5 (192 bit key,   16 byte blocks): 1366220 ops in 1 secs ( 21859520 bytes)
 6 (192 bit key,   64 byte blocks):  666889 ops in 1 secs ( 42680896 bytes)
 7 (192 bit key,  256 byte blocks):  561809 ops in 1 secs (143823104 bytes)
 8 (192 bit key, 1024 byte blocks):  344117 ops in 1 secs (352375808 bytes)
 9 (192 bit key, 8192 byte blocks):   63150 ops in 1 secs (517324800 bytes)
10 (256 bit key,   16 byte blocks): 1349266 ops in 1 secs ( 21588256 bytes)
11 (256 bit key,   64 byte blocks):  661056 ops in 1 secs ( 42307584 bytes)
12 (256 bit key,  256 byte blocks):  550261 ops in 1 secs (140866816 bytes)
13 (256 bit key, 1024 byte blocks):  332947 ops in 1 secs (340937728 bytes)
14 (256 bit key, 8192 byte blocks):   68759 ops in 1 secs (563273728 bytes)
diff mbox series

Patch

diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index adcb83eb683c..0860feedbafe 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -15,6 +15,7 @@ 
 #include <crypto/internal/hash.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <linux/module.h>
 #include <linux/cpufeature.h>
 #include <crypto/xts.h>
@@ -31,6 +32,8 @@ 
 #define aes_ecb_decrypt		ce_aes_ecb_decrypt
 #define aes_cbc_encrypt		ce_aes_cbc_encrypt
 #define aes_cbc_decrypt		ce_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		ce_aes_ctr_encrypt
 #define aes_xts_encrypt		ce_aes_xts_encrypt
 #define aes_xts_decrypt		ce_aes_xts_decrypt
@@ -45,6 +48,8 @@  MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #define aes_ecb_decrypt		neon_aes_ecb_decrypt
 #define aes_cbc_encrypt		neon_aes_cbc_encrypt
 #define aes_cbc_decrypt		neon_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		neon_aes_ctr_encrypt
 #define aes_xts_encrypt		neon_aes_xts_encrypt
 #define aes_xts_decrypt		neon_aes_xts_decrypt
@@ -73,6 +78,11 @@  asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
 				int rounds, int blocks, u8 iv[]);
 
+asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u8 const rk[],
+				int rounds, int bytes, u8 iv[]);
+asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u8 const rk[],
+				int rounds, int bytes, u8 iv[]);
+
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
 				int rounds, int blocks, u8 ctr[]);
 
@@ -209,6 +219,120 @@  static int cbc_decrypt(struct skcipher_request *req)
 	return err;
 }
 
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct skcipher_request subreq = *req;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+	unsigned int blocks;
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		err = skcipher_walk_virt(&walk, &subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					(u8 *)ctx->key_enc, rounds, blocks,
+					walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+		dst = scatterwalk_ffwd(sg_dst, req->dst, subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    (u8 *)ctx->key_enc, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct skcipher_request subreq = *req;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+	unsigned int blocks;
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		err = skcipher_walk_virt(&walk, &subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					(u8 *)ctx->key_dec, rounds, blocks,
+					walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+		dst = scatterwalk_ffwd(sg_dst, req->dst, subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    (u8 *)ctx->key_dec, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
 static int ctr_encrypt(struct skcipher_request *req)
 {
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -334,6 +458,24 @@  static struct skcipher_alg aes_algs[] = { {
 	.setkey		= skcipher_aes_setkey,
 	.encrypt	= cbc_encrypt,
 	.decrypt	= cbc_decrypt,
+}, {
+	.base = {
+		.cra_name		= "__cts(cbc(aes))",
+		.cra_driver_name	= "__cts-cbc-aes-" MODE,
+		.cra_priority		= PRIO,
+		.cra_flags		= CRYPTO_ALG_INTERNAL,
+		.cra_blocksize		= 1,
+		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
+		.cra_module		= THIS_MODULE,
+	},
+	.min_keysize	= AES_MIN_KEY_SIZE,
+	.max_keysize	= AES_MAX_KEY_SIZE,
+	.ivsize		= AES_BLOCK_SIZE,
+	.chunksize	= AES_BLOCK_SIZE,
+	.walksize	= 2 * AES_BLOCK_SIZE,
+	.setkey		= skcipher_aes_setkey,
+	.encrypt	= cts_cbc_encrypt,
+	.decrypt	= cts_cbc_decrypt,
 }, {
 	.base = {
 		.cra_name		= "__ctr(aes)",
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index 483a7130cf0e..61bab20de8da 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -205,6 +205,79 @@  AES_ENTRY(aes_cbc_decrypt)
 	ret
 AES_ENDPROC(aes_cbc_decrypt)
 
+	/*
+	 * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u8 const rk[],
+	 *		       int rounds, int bytes, u8 iv[])
+	 * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u8 const rk[],
+	 *		       int rounds, int bytes, u8 iv[])
+	 */
+
+AES_ENTRY(aes_cbc_cts_encrypt)
+	adr		x8, .Lcts_permute_table + 48
+	sub		x9, x8, x4
+	sub		x4, x4, #16
+	sub		x8, x8, #48
+	add		x8, x8, x4
+	ld1		{v6.16b}, [x9]
+	ld1		{v7.16b}, [x8]
+
+	ld1		{v4.16b}, [x5]			/* get iv */
+	enc_prepare	w3, x2, x6
+
+	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
+	ld1		{v1.16b}, [x1]
+
+	eor		v0.16b, v0.16b, v4.16b		/* xor with iv */
+	tbl		v1.16b, {v1.16b}, v6.16b
+	encrypt_block	v0, w3, x2, x6, w7
+
+	eor		v1.16b, v1.16b, v0.16b
+	tbl		v0.16b, {v0.16b}, v7.16b
+	encrypt_block	v1, w3, x2, x6, w7
+
+	add		x4, x0, x4
+	st1		{v0.16b}, [x4]			/* overlapping stores */
+	st1		{v1.16b}, [x0]
+	ret
+AES_ENDPROC(aes_cbc_cts_encrypt)
+
+AES_ENTRY(aes_cbc_cts_decrypt)
+	adr		x8, .Lcts_permute_table + 48
+	sub		x9, x8, x4
+	sub		x4, x4, #16
+	sub		x8, x8, #48
+	add		x8, x8, x4
+	ld1		{v6.16b}, [x9]
+	ld1		{v7.16b}, [x8]
+
+	ld1		{v4.16b}, [x5]			/* get iv */
+	dec_prepare	w3, x2, x6
+
+	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
+	ld1		{v1.16b}, [x1]
+
+	tbl		v2.16b, {v1.16b}, v6.16b
+	decrypt_block	v0, w3, x2, x6, w7
+	eor		v2.16b, v2.16b, v0.16b
+
+	tbx		v0.16b, {v1.16b}, v6.16b
+	tbl		v2.16b, {v2.16b}, v7.16b
+	decrypt_block	v0, w3, x2, x6, w7
+	eor		v0.16b, v0.16b, v4.16b		/* xor with iv */
+
+	add		x4, x0, x4
+	st1		{v2.16b}, [x4]			/* overlapping stores */
+	st1		{v0.16b}, [x0]
+	ret
+AES_ENDPROC(aes_cbc_cts_decrypt)
+
+.Lcts_permute_table:
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
+	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
 
 	/*
 	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,