diff mbox series

[14/14] crypto: sahara - add support for crypto_engine

Message ID 20231223181108.3819741-14-ovidiu.panait@windriver.com
State Superseded
Headers show
Series [01/14] crypto: sahara - handle zero-length aes requests | expand

Commit Message

Ovidiu Panait Dec. 23, 2023, 6:11 p.m. UTC
From: Ovidiu Panait <ovidiu.panait@windriver.com>

Convert sahara driver to use crypto_engine, rather than doing manual queue
management.

Signed-off-by: Ovidiu Panait <ovidiu.panait@windriver.com>
---
 drivers/crypto/Kconfig  |   1 +
 drivers/crypto/sahara.c | 326 ++++++++++++++++++----------------------
 2 files changed, 147 insertions(+), 180 deletions(-)
diff mbox series

Patch

diff --git a/drivers/crypto/Kconfig b/drivers/crypto/Kconfig
index 79c3bb9c99c3..0991f026cb07 100644
--- a/drivers/crypto/Kconfig
+++ b/drivers/crypto/Kconfig
@@ -306,6 +306,7 @@  config CRYPTO_DEV_SAHARA
 	select CRYPTO_SKCIPHER
 	select CRYPTO_AES
 	select CRYPTO_ECB
+	select CRYPTO_ENGINE
 	help
 	  This option enables support for the SAHARA HW crypto accelerator
 	  found in some Freescale i.MX chips.
diff --git a/drivers/crypto/sahara.c b/drivers/crypto/sahara.c
index 886395603a3a..6f3935f450ac 100644
--- a/drivers/crypto/sahara.c
+++ b/drivers/crypto/sahara.c
@@ -15,6 +15,7 @@ 
 #include <crypto/internal/hash.h>
 #include <crypto/internal/skcipher.h>
 #include <crypto/scatterwalk.h>
+#include <crypto/engine.h>
 #include <crypto/sha1.h>
 #include <crypto/sha2.h>
 
@@ -24,7 +25,6 @@ 
 #include <linux/io.h>
 #include <linux/irq.h>
 #include <linux/kernel.h>
-#include <linux/kthread.h>
 #include <linux/module.h>
 #include <linux/of.h>
 #include <linux/platform_device.h>
@@ -73,9 +73,6 @@ 
 #define SAHARA_HDR_MDHA_MAC_FULL		BIT(9)
 #define SAHARA_HDR_MDHA_SSL			BIT(10)
 
-/* SAHARA can only process one request at a time */
-#define SAHARA_QUEUE_LENGTH			1
-
 #define SAHARA_REG_VERSION			0x00
 #define SAHARA_REG_DAR				0x04
 #define SAHARA_REG_CONTROL			0x08
@@ -191,12 +188,9 @@  struct sahara_dev {
 	void __iomem		*regs_base;
 	struct clk		*clk_ipg;
 	struct clk		*clk_ahb;
-	spinlock_t		queue_spinlock;
-	struct task_struct	*kthread;
 	struct completion	dma_completion;
 
 	struct sahara_ctx	*ctx;
-	struct crypto_queue	queue;
 	unsigned long		flags;
 
 	struct sahara_hw_desc	*hw_desc[SAHARA_MAX_HW_DESC];
@@ -219,6 +213,8 @@  struct sahara_dev {
 	int		nb_in_sg;
 	struct scatterlist	*out_sg;
 	int		nb_out_sg;
+
+	struct crypto_engine *engine;
 };
 
 static struct sahara_dev *dev_ptr;
@@ -671,7 +667,6 @@  static int sahara_aes_crypt(struct skcipher_request *req, unsigned long mode)
 	struct sahara_ctx *ctx = crypto_skcipher_ctx(
 		crypto_skcipher_reqtfm(req));
 	struct sahara_dev *dev = dev_ptr;
-	int err = 0;
 
 	if (!req->cryptlen)
 		return 0;
@@ -687,13 +682,7 @@  static int sahara_aes_crypt(struct skcipher_request *req, unsigned long mode)
 
 	rctx->mode = mode;
 
-	spin_lock_bh(&dev->queue_spinlock);
-	err = crypto_enqueue_request(&dev->queue, &req->base);
-	spin_unlock_bh(&dev->queue_spinlock);
-
-	wake_up_process(dev->kthread);
-
-	return err;
+	return crypto_transfer_skcipher_request_to_engine(dev->engine, req);
 }
 
 static int sahara_aes_ecb_encrypt(struct skcipher_request *req)
@@ -994,45 +983,26 @@  static int sahara_sha_process(struct ahash_request *req)
 	return 0;
 }
 
-static int sahara_queue_manage(void *data)
+static int sahara_do_one_request(struct crypto_engine *engine, void *areq)
 {
-	struct sahara_dev *dev = data;
-	struct crypto_async_request *async_req;
-	struct crypto_async_request *backlog;
-	int ret = 0;
-
-	do {
-		__set_current_state(TASK_INTERRUPTIBLE);
-
-		spin_lock_bh(&dev->queue_spinlock);
-		backlog = crypto_get_backlog(&dev->queue);
-		async_req = crypto_dequeue_request(&dev->queue);
-		spin_unlock_bh(&dev->queue_spinlock);
-
-		if (backlog)
-			crypto_request_complete(backlog, -EINPROGRESS);
-
-		if (async_req) {
-			if (crypto_tfm_alg_type(async_req->tfm) ==
-			    CRYPTO_ALG_TYPE_AHASH) {
-				struct ahash_request *req =
-					ahash_request_cast(async_req);
-
-				ret = sahara_sha_process(req);
-			} else {
-				struct skcipher_request *req =
-					skcipher_request_cast(async_req);
-
-				ret = sahara_aes_process(req);
-			}
+	struct crypto_async_request *async_req = areq;
+	int err;
 
-			crypto_request_complete(async_req, ret);
+	if (crypto_tfm_alg_type(async_req->tfm) == CRYPTO_ALG_TYPE_AHASH) {
+		struct ahash_request *req = ahash_request_cast(async_req);
 
-			continue;
-		}
+		err = sahara_sha_process(req);
+		local_bh_disable();
+		crypto_finalize_hash_request(engine, req, err);
+		local_bh_enable();
+	} else {
+		struct skcipher_request *req = skcipher_request_cast(async_req);
 
-		schedule();
-	} while (!kthread_should_stop());
+		err = sahara_aes_process(skcipher_request_cast(async_req));
+		local_bh_disable();
+		crypto_finalize_skcipher_request(engine, req, err);
+		local_bh_enable();
+	}
 
 	return 0;
 }
@@ -1041,20 +1011,13 @@  static int sahara_sha_enqueue(struct ahash_request *req, int last)
 {
 	struct sahara_sha_reqctx *rctx = ahash_request_ctx(req);
 	struct sahara_dev *dev = dev_ptr;
-	int ret;
 
 	if (!req->nbytes && !last)
 		return 0;
 
 	rctx->last = last;
 
-	spin_lock_bh(&dev->queue_spinlock);
-	ret = crypto_enqueue_request(&dev->queue, &req->base);
-	spin_unlock_bh(&dev->queue_spinlock);
-
-	wake_up_process(dev->kthread);
-
-	return ret;
+	return crypto_transfer_hash_request_to_engine(dev->engine, req);
 }
 
 static int sahara_sha_init(struct ahash_request *req)
@@ -1132,94 +1095,114 @@  static int sahara_sha_cra_init(struct crypto_tfm *tfm)
 	return 0;
 }
 
-static struct skcipher_alg aes_algs[] = {
+static struct skcipher_engine_alg aes_algs[] = {
 {
-	.base.cra_name		= "ecb(aes)",
-	.base.cra_driver_name	= "sahara-ecb-aes",
-	.base.cra_priority	= 300,
-	.base.cra_flags		= CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK,
-	.base.cra_blocksize	= AES_BLOCK_SIZE,
-	.base.cra_ctxsize	= sizeof(struct sahara_ctx),
-	.base.cra_alignmask	= 0x0,
-	.base.cra_module	= THIS_MODULE,
-
-	.init			= sahara_aes_init_tfm,
-	.exit			= sahara_aes_exit_tfm,
-	.min_keysize		= AES_MIN_KEY_SIZE ,
-	.max_keysize		= AES_MAX_KEY_SIZE,
-	.setkey			= sahara_aes_setkey,
-	.encrypt		= sahara_aes_ecb_encrypt,
-	.decrypt		= sahara_aes_ecb_decrypt,
+	.base = {
+		.base.cra_name		= "ecb(aes)",
+		.base.cra_driver_name	= "sahara-ecb-aes",
+		.base.cra_priority	= 300,
+		.base.cra_flags		= CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK,
+		.base.cra_blocksize	= AES_BLOCK_SIZE,
+		.base.cra_ctxsize	= sizeof(struct sahara_ctx),
+		.base.cra_alignmask	= 0x0,
+		.base.cra_module	= THIS_MODULE,
+
+		.init			= sahara_aes_init_tfm,
+		.exit			= sahara_aes_exit_tfm,
+		.min_keysize		= AES_MIN_KEY_SIZE,
+		.max_keysize		= AES_MAX_KEY_SIZE,
+		.setkey			= sahara_aes_setkey,
+		.encrypt		= sahara_aes_ecb_encrypt,
+		.decrypt		= sahara_aes_ecb_decrypt,
+	},
+	.op = {
+		.do_one_request = sahara_do_one_request,
+	},
 }, {
-	.base.cra_name		= "cbc(aes)",
-	.base.cra_driver_name	= "sahara-cbc-aes",
-	.base.cra_priority	= 300,
-	.base.cra_flags		= CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK,
-	.base.cra_blocksize	= AES_BLOCK_SIZE,
-	.base.cra_ctxsize	= sizeof(struct sahara_ctx),
-	.base.cra_alignmask	= 0x0,
-	.base.cra_module	= THIS_MODULE,
-
-	.init			= sahara_aes_init_tfm,
-	.exit			= sahara_aes_exit_tfm,
-	.min_keysize		= AES_MIN_KEY_SIZE ,
-	.max_keysize		= AES_MAX_KEY_SIZE,
-	.ivsize			= AES_BLOCK_SIZE,
-	.setkey			= sahara_aes_setkey,
-	.encrypt		= sahara_aes_cbc_encrypt,
-	.decrypt		= sahara_aes_cbc_decrypt,
+	.base = {
+		.base.cra_name		= "cbc(aes)",
+		.base.cra_driver_name	= "sahara-cbc-aes",
+		.base.cra_priority	= 300,
+		.base.cra_flags		= CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK,
+		.base.cra_blocksize	= AES_BLOCK_SIZE,
+		.base.cra_ctxsize	= sizeof(struct sahara_ctx),
+		.base.cra_alignmask	= 0x0,
+		.base.cra_module	= THIS_MODULE,
+
+		.init			= sahara_aes_init_tfm,
+		.exit			= sahara_aes_exit_tfm,
+		.min_keysize		= AES_MIN_KEY_SIZE,
+		.max_keysize		= AES_MAX_KEY_SIZE,
+		.ivsize			= AES_BLOCK_SIZE,
+		.setkey			= sahara_aes_setkey,
+		.encrypt		= sahara_aes_cbc_encrypt,
+		.decrypt		= sahara_aes_cbc_decrypt,
+	},
+	.op = {
+		.do_one_request = sahara_do_one_request,
+	},
 }
 };
 
-static struct ahash_alg sha_v3_algs[] = {
+static struct ahash_engine_alg sha_v3_algs[] = {
 {
-	.init		= sahara_sha_init,
-	.update		= sahara_sha_update,
-	.final		= sahara_sha_final,
-	.finup		= sahara_sha_finup,
-	.digest		= sahara_sha_digest,
-	.export		= sahara_sha_export,
-	.import		= sahara_sha_import,
-	.halg.digestsize	= SHA1_DIGEST_SIZE,
-	.halg.statesize         = sizeof(struct sahara_sha_reqctx),
-	.halg.base	= {
-		.cra_name		= "sha1",
-		.cra_driver_name	= "sahara-sha1",
-		.cra_priority		= 300,
-		.cra_flags		= CRYPTO_ALG_ASYNC |
-						CRYPTO_ALG_NEED_FALLBACK,
-		.cra_blocksize		= SHA1_BLOCK_SIZE,
-		.cra_ctxsize		= sizeof(struct sahara_ctx),
-		.cra_alignmask		= 0,
-		.cra_module		= THIS_MODULE,
-		.cra_init		= sahara_sha_cra_init,
-	}
+	.base = {
+		.init		= sahara_sha_init,
+		.update		= sahara_sha_update,
+		.final		= sahara_sha_final,
+		.finup		= sahara_sha_finup,
+		.digest		= sahara_sha_digest,
+		.export		= sahara_sha_export,
+		.import		= sahara_sha_import,
+		.halg.digestsize	= SHA1_DIGEST_SIZE,
+		.halg.statesize         = sizeof(struct sahara_sha_reqctx),
+		.halg.base	= {
+			.cra_name		= "sha1",
+			.cra_driver_name	= "sahara-sha1",
+			.cra_priority		= 300,
+			.cra_flags		= CRYPTO_ALG_ASYNC |
+							CRYPTO_ALG_NEED_FALLBACK,
+			.cra_blocksize		= SHA1_BLOCK_SIZE,
+			.cra_ctxsize		= sizeof(struct sahara_ctx),
+			.cra_alignmask		= 0,
+			.cra_module		= THIS_MODULE,
+			.cra_init		= sahara_sha_cra_init,
+		}
+	},
+	.op = {
+		.do_one_request = sahara_do_one_request,
+	},
 },
 };
 
-static struct ahash_alg sha_v4_algs[] = {
+static struct ahash_engine_alg sha_v4_algs[] = {
 {
-	.init		= sahara_sha_init,
-	.update		= sahara_sha_update,
-	.final		= sahara_sha_final,
-	.finup		= sahara_sha_finup,
-	.digest		= sahara_sha_digest,
-	.export		= sahara_sha_export,
-	.import		= sahara_sha_import,
-	.halg.digestsize	= SHA256_DIGEST_SIZE,
-	.halg.statesize         = sizeof(struct sahara_sha_reqctx),
-	.halg.base	= {
-		.cra_name		= "sha256",
-		.cra_driver_name	= "sahara-sha256",
-		.cra_priority		= 300,
-		.cra_flags		= CRYPTO_ALG_ASYNC |
-						CRYPTO_ALG_NEED_FALLBACK,
-		.cra_blocksize		= SHA256_BLOCK_SIZE,
-		.cra_ctxsize		= sizeof(struct sahara_ctx),
-		.cra_alignmask		= 0,
-		.cra_module		= THIS_MODULE,
-		.cra_init		= sahara_sha_cra_init,
-	}
+	.base = {
+		.init		= sahara_sha_init,
+		.update		= sahara_sha_update,
+		.final		= sahara_sha_final,
+		.finup		= sahara_sha_finup,
+		.digest		= sahara_sha_digest,
+		.export		= sahara_sha_export,
+		.import		= sahara_sha_import,
+		.halg.digestsize	= SHA256_DIGEST_SIZE,
+		.halg.statesize         = sizeof(struct sahara_sha_reqctx),
+		.halg.base	= {
+			.cra_name		= "sha256",
+			.cra_driver_name	= "sahara-sha256",
+			.cra_priority		= 300,
+			.cra_flags		= CRYPTO_ALG_ASYNC |
+							CRYPTO_ALG_NEED_FALLBACK,
+			.cra_blocksize		= SHA256_BLOCK_SIZE,
+			.cra_ctxsize		= sizeof(struct sahara_ctx),
+			.cra_alignmask		= 0,
+			.cra_module		= THIS_MODULE,
+			.cra_init		= sahara_sha_cra_init,
+		}
+	},
+	.op = {
+		.do_one_request = sahara_do_one_request,
+	},
 },
 };
 
@@ -1249,57 +1232,39 @@  static irqreturn_t sahara_irq_handler(int irq, void *data)
 static int sahara_register_algs(struct sahara_dev *dev)
 {
 	int err;
-	unsigned int i, j, k, l;
 
-	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
-		err = crypto_register_skcipher(&aes_algs[i]);
-		if (err)
-			goto err_aes_algs;
-	}
+	err = crypto_engine_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
+	if (err)
+		return err;
 
-	for (k = 0; k < ARRAY_SIZE(sha_v3_algs); k++) {
-		err = crypto_register_ahash(&sha_v3_algs[k]);
+	err = crypto_engine_register_ahashes(sha_v3_algs,
+					     ARRAY_SIZE(sha_v3_algs));
+	if (err)
+		goto err_aes_algs;
+
+	if (dev->version > SAHARA_VERSION_3) {
+		err = crypto_engine_register_ahashes(sha_v4_algs,
+						     ARRAY_SIZE(sha_v4_algs));
 		if (err)
 			goto err_sha_v3_algs;
 	}
 
-	if (dev->version > SAHARA_VERSION_3)
-		for (l = 0; l < ARRAY_SIZE(sha_v4_algs); l++) {
-			err = crypto_register_ahash(&sha_v4_algs[l]);
-			if (err)
-				goto err_sha_v4_algs;
-		}
-
 	return 0;
 
-err_sha_v4_algs:
-	for (j = 0; j < l; j++)
-		crypto_unregister_ahash(&sha_v4_algs[j]);
-
 err_sha_v3_algs:
-	for (j = 0; j < k; j++)
-		crypto_unregister_ahash(&sha_v3_algs[j]);
+	crypto_engine_unregister_ahashes(sha_v3_algs, ARRAY_SIZE(sha_v3_algs));
 
 err_aes_algs:
-	for (j = 0; j < i; j++)
-		crypto_unregister_skcipher(&aes_algs[j]);
+	crypto_engine_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 
 	return err;
 }
 
 static void sahara_unregister_algs(struct sahara_dev *dev)
 {
-	unsigned int i;
-
-	for (i = 0; i < ARRAY_SIZE(aes_algs); i++)
-		crypto_unregister_skcipher(&aes_algs[i]);
-
-	for (i = 0; i < ARRAY_SIZE(sha_v3_algs); i++)
-		crypto_unregister_ahash(&sha_v3_algs[i]);
-
-	if (dev->version > SAHARA_VERSION_3)
-		for (i = 0; i < ARRAY_SIZE(sha_v4_algs); i++)
-			crypto_unregister_ahash(&sha_v4_algs[i]);
+	crypto_engine_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
+	crypto_engine_unregister_ahashes(sha_v3_algs, ARRAY_SIZE(sha_v3_algs));
+	crypto_engine_unregister_ahashes(sha_v4_algs, ARRAY_SIZE(sha_v4_algs));
 }
 
 static const struct of_device_id sahara_dt_ids[] = {
@@ -1388,15 +1353,17 @@  static int sahara_probe(struct platform_device *pdev)
 		dev->hw_link[i] = dev->hw_link[i - 1] + 1;
 	}
 
-	crypto_init_queue(&dev->queue, SAHARA_QUEUE_LENGTH);
-
-	spin_lock_init(&dev->queue_spinlock);
-
 	dev_ptr = dev;
 
-	dev->kthread = kthread_run(sahara_queue_manage, dev, "sahara_crypto");
-	if (IS_ERR(dev->kthread)) {
-		return PTR_ERR(dev->kthread);
+	dev->engine = crypto_engine_alloc_init(&pdev->dev, true);
+	if (!dev->engine)
+		return -ENOMEM;
+
+	err = crypto_engine_start(dev->engine);
+	if (err) {
+		crypto_engine_exit(dev->engine);
+		return dev_err_probe(&pdev->dev, err,
+				     "Could not start crypto engine\n");
 	}
 
 	init_completion(&dev->dma_completion);
@@ -1436,7 +1403,7 @@  static int sahara_probe(struct platform_device *pdev)
 	return 0;
 
 err_algs:
-	kthread_stop(dev->kthread);
+	crypto_engine_exit(dev->engine);
 
 	return err;
 }
@@ -1445,8 +1412,7 @@  static void sahara_remove(struct platform_device *pdev)
 {
 	struct sahara_dev *dev = platform_get_drvdata(pdev);
 
-	kthread_stop(dev->kthread);
-
+	crypto_engine_exit(dev->engine);
 	sahara_unregister_algs(dev);
 }