diff mbox series

[3/8] crypto: arm64/aes-ccm - Pass short inputs via stack buffer

Message ID 20240111123302.589910-13-ardb+git@google.com
State Superseded
Headers show
Series crypto: Clean up arm64 AES-CCM code | expand

Commit Message

Ard Biesheuvel Jan. 11, 2024, 12:33 p.m. UTC
From: Ard Biesheuvel <ardb@kernel.org>

In preparation for optimizing the CCM core asm code using permutation
vectors and overlapping loads and stores, ensure that inputs shorter
than the size of a AES block are passed via a buffer on the stack, in a
way that positions the data at the end of a 16 byte buffer. This removes
the need for the asm code to reason about a rare corner case where the
tail of the data cannot be read/written using a single NEON load/store
instruction.

While at it, tweak the copyright header and authorship to bring it up to
date.

Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
---
 arch/arm64/crypto/aes-ce-ccm-glue.c | 57 ++++++++++++++------
 1 file changed, 40 insertions(+), 17 deletions(-)
diff mbox series

Patch

diff --git a/arch/arm64/crypto/aes-ce-ccm-glue.c b/arch/arm64/crypto/aes-ce-ccm-glue.c
index b177ebea7d09..2f4e6a318fcd 100644
--- a/arch/arm64/crypto/aes-ce-ccm-glue.c
+++ b/arch/arm64/crypto/aes-ce-ccm-glue.c
@@ -1,8 +1,11 @@ 
 // SPDX-License-Identifier: GPL-2.0-only
 /*
- * aes-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
+ * aes-ce-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
  *
- * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2013 - 2017 Linaro Ltd.
+ * Copyright (C) 2024 Google LLC
+ *
+ * Author: Ard Biesheuvel <ardb@kernel.org>
  */
 
 #include <asm/neon.h>
@@ -149,7 +152,7 @@  static int ccm_encrypt(struct aead_request *req)
 	struct crypto_aes_ctx *ctx = crypto_aead_ctx(aead);
 	struct skcipher_walk walk;
 	u8 __aligned(8) mac[AES_BLOCK_SIZE];
-	u8 buf[AES_BLOCK_SIZE];
+	u8 orig_iv[AES_BLOCK_SIZE];
 	u32 len = req->cryptlen;
 	int err;
 
@@ -158,7 +161,7 @@  static int ccm_encrypt(struct aead_request *req)
 		return err;
 
 	/* preserve the original iv for the final round */
-	memcpy(buf, req->iv, AES_BLOCK_SIZE);
+	memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
 
 	err = skcipher_walk_aead_encrypt(&walk, req, false);
 	if (unlikely(err))
@@ -171,16 +174,26 @@  static int ccm_encrypt(struct aead_request *req)
 
 	do {
 		u32 tail = walk.nbytes % AES_BLOCK_SIZE;
+		const u8 *src = walk.src.virt.addr;
+		u8 *dst = walk.dst.virt.addr;
+		u8 buf[AES_BLOCK_SIZE];
 
 		if (walk.nbytes == walk.total)
 			tail = 0;
 
-		ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   walk.nbytes - tail, ctx->key_enc,
-				   num_rounds(ctx), mac, walk.iv);
+		if (unlikely(walk.total < AES_BLOCK_SIZE))
+			src = dst = memcpy(buf + sizeof(buf) - walk.total,
+					   src, walk.total);
+
+		ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail,
+				   ctx->key_enc, num_rounds(ctx),
+				   mac, walk.iv);
+
+		if (unlikely(walk.total < AES_BLOCK_SIZE))
+			memcpy(walk.dst.virt.addr, dst, walk.total);
 
 		if (walk.nbytes == walk.total)
-			ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+			ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
 
 		if (walk.nbytes) {
 			err = skcipher_walk_done(&walk, tail);
@@ -206,7 +219,7 @@  static int ccm_decrypt(struct aead_request *req)
 	unsigned int authsize = crypto_aead_authsize(aead);
 	struct skcipher_walk walk;
 	u8 __aligned(8) mac[AES_BLOCK_SIZE];
-	u8 buf[AES_BLOCK_SIZE];
+	u8 orig_iv[AES_BLOCK_SIZE];
 	u32 len = req->cryptlen - authsize;
 	int err;
 
@@ -215,7 +228,7 @@  static int ccm_decrypt(struct aead_request *req)
 		return err;
 
 	/* preserve the original iv for the final round */
-	memcpy(buf, req->iv, AES_BLOCK_SIZE);
+	memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
 
 	err = skcipher_walk_aead_decrypt(&walk, req, false);
 	if (unlikely(err))
@@ -228,16 +241,26 @@  static int ccm_decrypt(struct aead_request *req)
 
 	do {
 		u32 tail = walk.nbytes % AES_BLOCK_SIZE;
+		const u8 *src = walk.src.virt.addr;
+		u8 *dst = walk.dst.virt.addr;
+		u8 buf[AES_BLOCK_SIZE];
 
 		if (walk.nbytes == walk.total)
 			tail = 0;
 
-		ce_aes_ccm_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				   walk.nbytes - tail, ctx->key_enc,
-				   num_rounds(ctx), mac, walk.iv);
+		if (unlikely(walk.total < AES_BLOCK_SIZE))
+			src = dst = memcpy(buf + sizeof(buf) - walk.total,
+					   src, walk.total);
+
+		ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail,
+				   ctx->key_enc, num_rounds(ctx),
+				   mac, walk.iv);
+
+		if (unlikely(walk.total < AES_BLOCK_SIZE))
+			memcpy(walk.dst.virt.addr, dst, walk.total);
 
 		if (walk.nbytes == walk.total)
-			ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+			ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
 
 		if (walk.nbytes) {
 			err = skcipher_walk_done(&walk, tail);
@@ -250,11 +273,11 @@  static int ccm_decrypt(struct aead_request *req)
 		return err;
 
 	/* compare calculated auth tag with the stored one */
-	scatterwalk_map_and_copy(buf, req->src,
+	scatterwalk_map_and_copy(orig_iv, req->src,
 				 req->assoclen + req->cryptlen - authsize,
 				 authsize, 0);
 
-	if (crypto_memneq(mac, buf, authsize))
+	if (crypto_memneq(mac, orig_iv, authsize))
 		return -EBADMSG;
 	return 0;
 }
@@ -293,6 +316,6 @@  module_init(aes_mod_init);
 module_exit(aes_mod_exit);
 
 MODULE_DESCRIPTION("Synchronous AES in CCM mode using ARMv8 Crypto Extensions");
-MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
+MODULE_AUTHOR("Ard Biesheuvel <ardb@kernel.org>");
 MODULE_LICENSE("GPL v2");
 MODULE_ALIAS_CRYPTO("ccm(aes)");