diff mbox series

[v4,08/11] crypto: chacha20poly1305 - Use lib/crypto poly1305

Message ID 0babdb56d14256b44249dc2bf3190ec200d9d738.1745815528.git.herbert@gondor.apana.org.au
State New
Headers show
Series crypto: lib - Add partial block helper | expand

Commit Message

Herbert Xu April 28, 2025, 4:56 a.m. UTC
Since the poly1305 algorithm is fixed, there is no point in going
through the Crypto API for it.  Use the lib/crypto poly1305 interface
instead.

For compatiblity keep the poly1305 parameter in the algorithm name.

Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
---
 crypto/Kconfig            |   2 +-
 crypto/chacha20poly1305.c | 323 ++++++++------------------------------
 2 files changed, 67 insertions(+), 258 deletions(-)
diff mbox series

Patch

diff --git a/crypto/Kconfig b/crypto/Kconfig
index 9878286d1d68..f87e2a26d2dd 100644
--- a/crypto/Kconfig
+++ b/crypto/Kconfig
@@ -784,8 +784,8 @@  config CRYPTO_AEGIS128_SIMD
 config CRYPTO_CHACHA20POLY1305
 	tristate "ChaCha20-Poly1305"
 	select CRYPTO_CHACHA20
-	select CRYPTO_POLY1305
 	select CRYPTO_AEAD
+	select CRYPTO_LIB_POLY1305
 	select CRYPTO_MANAGER
 	help
 	  ChaCha20 stream cipher and Poly1305 authenticator combined
diff --git a/crypto/chacha20poly1305.c b/crypto/chacha20poly1305.c
index d740849f1c19..b29f66ba1e2f 100644
--- a/crypto/chacha20poly1305.c
+++ b/crypto/chacha20poly1305.c
@@ -12,36 +12,23 @@ 
 #include <crypto/chacha.h>
 #include <crypto/poly1305.h>
 #include <linux/err.h>
-#include <linux/init.h>
 #include <linux/kernel.h>
+#include <linux/mm.h>
 #include <linux/module.h>
+#include <linux/string.h>
 
 struct chachapoly_instance_ctx {
 	struct crypto_skcipher_spawn chacha;
-	struct crypto_ahash_spawn poly;
 	unsigned int saltlen;
 };
 
 struct chachapoly_ctx {
 	struct crypto_skcipher *chacha;
-	struct crypto_ahash *poly;
 	/* key bytes we use for the ChaCha20 IV */
 	unsigned int saltlen;
 	u8 salt[] __counted_by(saltlen);
 };
 
-struct poly_req {
-	/* zero byte padding for AD/ciphertext, as needed */
-	u8 pad[POLY1305_BLOCK_SIZE];
-	/* tail data with AD/ciphertext lengths */
-	struct {
-		__le64 assoclen;
-		__le64 cryptlen;
-	} tail;
-	struct scatterlist src[1];
-	struct ahash_request req; /* must be last member */
-};
-
 struct chacha_req {
 	u8 iv[CHACHA_IV_SIZE];
 	struct scatterlist src[1];
@@ -62,7 +49,6 @@  struct chachapoly_req_ctx {
 	/* request flags, with MAY_SLEEP cleared if needed */
 	u32 flags;
 	union {
-		struct poly_req poly;
 		struct chacha_req chacha;
 	} u;
 };
@@ -105,16 +91,6 @@  static int poly_verify_tag(struct aead_request *req)
 	return 0;
 }
 
-static int poly_copy_tag(struct aead_request *req)
-{
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-
-	scatterwalk_map_and_copy(rctx->tag, req->dst,
-				 req->assoclen + rctx->cryptlen,
-				 sizeof(rctx->tag), 1);
-	return 0;
-}
-
 static void chacha_decrypt_done(void *data, int err)
 {
 	async_done_continue(data, err, poly_verify_tag);
@@ -151,210 +127,76 @@  static int chacha_decrypt(struct aead_request *req)
 	return poly_verify_tag(req);
 }
 
-static int poly_tail_continue(struct aead_request *req)
+static int poly_hash(struct aead_request *req)
 {
 	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
+	const void *zp = page_address(ZERO_PAGE(0));
+	struct scatterlist *sg = req->src;
+	struct poly1305_desc_ctx desc;
+	struct scatter_walk walk;
+	struct {
+		union {
+			struct {
+				__le64 assoclen;
+				__le64 cryptlen;
+			};
+			u8 u8[16];
+		};
+	} tail;
+	unsigned int padlen;
+	unsigned int total;
+
+	if (sg != req->dst)
+		memcpy_sglist(req->dst, sg, req->assoclen);
 
 	if (rctx->cryptlen == req->cryptlen) /* encrypting */
-		return poly_copy_tag(req);
+		sg = req->dst;
 
-	return chacha_decrypt(req);
-}
+	poly1305_init(&desc, rctx->key);
+	scatterwalk_start(&walk, sg);
 
-static void poly_tail_done(void *data, int err)
-{
-	async_done_continue(data, err, poly_tail_continue);
-}
+	total = rctx->assoclen;
+	while (total) {
+		unsigned int n = scatterwalk_next(&walk, total);
 
-static int poly_tail(struct aead_request *req)
-{
-	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(tfm);
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct poly_req *preq = &rctx->u.poly;
-	int err;
-
-	preq->tail.assoclen = cpu_to_le64(rctx->assoclen);
-	preq->tail.cryptlen = cpu_to_le64(rctx->cryptlen);
-	sg_init_one(preq->src, &preq->tail, sizeof(preq->tail));
-
-	ahash_request_set_callback(&preq->req, rctx->flags,
-				   poly_tail_done, req);
-	ahash_request_set_tfm(&preq->req, ctx->poly);
-	ahash_request_set_crypt(&preq->req, preq->src,
-				rctx->tag, sizeof(preq->tail));
-
-	err = crypto_ahash_finup(&preq->req);
-	if (err)
-		return err;
-
-	return poly_tail_continue(req);
-}
-
-static void poly_cipherpad_done(void *data, int err)
-{
-	async_done_continue(data, err, poly_tail);
-}
-
-static int poly_cipherpad(struct aead_request *req)
-{
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct poly_req *preq = &rctx->u.poly;
-	unsigned int padlen;
-	int err;
-
-	padlen = -rctx->cryptlen % POLY1305_BLOCK_SIZE;
-	memset(preq->pad, 0, sizeof(preq->pad));
-	sg_init_one(preq->src, preq->pad, padlen);
-
-	ahash_request_set_callback(&preq->req, rctx->flags,
-				   poly_cipherpad_done, req);
-	ahash_request_set_tfm(&preq->req, ctx->poly);
-	ahash_request_set_crypt(&preq->req, preq->src, NULL, padlen);
-
-	err = crypto_ahash_update(&preq->req);
-	if (err)
-		return err;
-
-	return poly_tail(req);
-}
-
-static void poly_cipher_done(void *data, int err)
-{
-	async_done_continue(data, err, poly_cipherpad);
-}
-
-static int poly_cipher(struct aead_request *req)
-{
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct poly_req *preq = &rctx->u.poly;
-	struct scatterlist *crypt = req->src;
-	int err;
-
-	if (rctx->cryptlen == req->cryptlen) /* encrypting */
-		crypt = req->dst;
-
-	crypt = scatterwalk_ffwd(rctx->src, crypt, req->assoclen);
-
-	ahash_request_set_callback(&preq->req, rctx->flags,
-				   poly_cipher_done, req);
-	ahash_request_set_tfm(&preq->req, ctx->poly);
-	ahash_request_set_crypt(&preq->req, crypt, NULL, rctx->cryptlen);
-
-	err = crypto_ahash_update(&preq->req);
-	if (err)
-		return err;
-
-	return poly_cipherpad(req);
-}
-
-static void poly_adpad_done(void *data, int err)
-{
-	async_done_continue(data, err, poly_cipher);
-}
-
-static int poly_adpad(struct aead_request *req)
-{
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct poly_req *preq = &rctx->u.poly;
-	unsigned int padlen;
-	int err;
+		poly1305_update(&desc, walk.addr, n);
+		scatterwalk_done_src(&walk, n);
+		total -= n;
+	}
 
 	padlen = -rctx->assoclen % POLY1305_BLOCK_SIZE;
-	memset(preq->pad, 0, sizeof(preq->pad));
-	sg_init_one(preq->src, preq->pad, padlen);
+	poly1305_update(&desc, zp, padlen);
 
-	ahash_request_set_callback(&preq->req, rctx->flags,
-				   poly_adpad_done, req);
-	ahash_request_set_tfm(&preq->req, ctx->poly);
-	ahash_request_set_crypt(&preq->req, preq->src, NULL, padlen);
+	scatterwalk_skip(&walk, req->assoclen - rctx->assoclen);
 
-	err = crypto_ahash_update(&preq->req);
-	if (err)
-		return err;
+	total = rctx->cryptlen;
+	while (total) {
+		unsigned int n = scatterwalk_next(&walk, total);
 
-	return poly_cipher(req);
-}
+		poly1305_update(&desc, walk.addr, n);
+		scatterwalk_done_src(&walk, n);
+		total -= n;
+	}
 
-static void poly_ad_done(void *data, int err)
-{
-	async_done_continue(data, err, poly_adpad);
-}
+	padlen = -rctx->cryptlen % POLY1305_BLOCK_SIZE;
+	poly1305_update(&desc, zp, padlen);
 
-static int poly_ad(struct aead_request *req)
-{
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct poly_req *preq = &rctx->u.poly;
-	int err;
+	tail.assoclen = cpu_to_le64(rctx->assoclen);
+	tail.cryptlen = cpu_to_le64(rctx->cryptlen);
+	poly1305_update(&desc, tail.u8, sizeof(tail));
+	memzero_explicit(&tail, sizeof(tail));
+	poly1305_final(&desc, rctx->tag);
 
-	ahash_request_set_callback(&preq->req, rctx->flags,
-				   poly_ad_done, req);
-	ahash_request_set_tfm(&preq->req, ctx->poly);
-	ahash_request_set_crypt(&preq->req, req->src, NULL, rctx->assoclen);
+	if (rctx->cryptlen != req->cryptlen)
+		return chacha_decrypt(req);
 
-	err = crypto_ahash_update(&preq->req);
-	if (err)
-		return err;
-
-	return poly_adpad(req);
-}
-
-static void poly_setkey_done(void *data, int err)
-{
-	async_done_continue(data, err, poly_ad);
-}
-
-static int poly_setkey(struct aead_request *req)
-{
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct poly_req *preq = &rctx->u.poly;
-	int err;
-
-	sg_init_one(preq->src, rctx->key, sizeof(rctx->key));
-
-	ahash_request_set_callback(&preq->req, rctx->flags,
-				   poly_setkey_done, req);
-	ahash_request_set_tfm(&preq->req, ctx->poly);
-	ahash_request_set_crypt(&preq->req, preq->src, NULL, sizeof(rctx->key));
-
-	err = crypto_ahash_update(&preq->req);
-	if (err)
-		return err;
-
-	return poly_ad(req);
-}
-
-static void poly_init_done(void *data, int err)
-{
-	async_done_continue(data, err, poly_setkey);
-}
-
-static int poly_init(struct aead_request *req)
-{
-	struct chachapoly_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
-	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
-	struct poly_req *preq = &rctx->u.poly;
-	int err;
-
-	ahash_request_set_callback(&preq->req, rctx->flags,
-				   poly_init_done, req);
-	ahash_request_set_tfm(&preq->req, ctx->poly);
-
-	err = crypto_ahash_init(&preq->req);
-	if (err)
-		return err;
-
-	return poly_setkey(req);
+	memcpy_to_scatterwalk(&walk, rctx->tag, sizeof(rctx->tag));
+	return 0;
 }
 
 static void poly_genkey_done(void *data, int err)
 {
-	async_done_continue(data, err, poly_init);
+	async_done_continue(data, err, poly_hash);
 }
 
 static int poly_genkey(struct aead_request *req)
@@ -388,7 +230,7 @@  static int poly_genkey(struct aead_request *req)
 	if (err)
 		return err;
 
-	return poly_init(req);
+	return poly_hash(req);
 }
 
 static void chacha_encrypt_done(void *data, int err)
@@ -437,14 +279,7 @@  static int chachapoly_encrypt(struct aead_request *req)
 	/* encrypt call chain:
 	 * - chacha_encrypt/done()
 	 * - poly_genkey/done()
-	 * - poly_init/done()
-	 * - poly_setkey/done()
-	 * - poly_ad/done()
-	 * - poly_adpad/done()
-	 * - poly_cipher/done()
-	 * - poly_cipherpad/done()
-	 * - poly_tail/done/continue()
-	 * - poly_copy_tag()
+	 * - poly_hash()
 	 */
 	return chacha_encrypt(req);
 }
@@ -458,13 +293,7 @@  static int chachapoly_decrypt(struct aead_request *req)
 
 	/* decrypt call chain:
 	 * - poly_genkey/done()
-	 * - poly_init/done()
-	 * - poly_setkey/done()
-	 * - poly_ad/done()
-	 * - poly_adpad/done()
-	 * - poly_cipher/done()
-	 * - poly_cipherpad/done()
-	 * - poly_tail/done/continue()
+	 * - poly_hash()
 	 * - chacha_decrypt/done()
 	 * - poly_verify_tag()
 	 */
@@ -503,21 +332,13 @@  static int chachapoly_init(struct crypto_aead *tfm)
 	struct chachapoly_instance_ctx *ictx = aead_instance_ctx(inst);
 	struct chachapoly_ctx *ctx = crypto_aead_ctx(tfm);
 	struct crypto_skcipher *chacha;
-	struct crypto_ahash *poly;
 	unsigned long align;
 
-	poly = crypto_spawn_ahash(&ictx->poly);
-	if (IS_ERR(poly))
-		return PTR_ERR(poly);
-
 	chacha = crypto_spawn_skcipher(&ictx->chacha);
-	if (IS_ERR(chacha)) {
-		crypto_free_ahash(poly);
+	if (IS_ERR(chacha))
 		return PTR_ERR(chacha);
-	}
 
 	ctx->chacha = chacha;
-	ctx->poly = poly;
 	ctx->saltlen = ictx->saltlen;
 
 	align = crypto_aead_alignmask(tfm);
@@ -525,12 +346,9 @@  static int chachapoly_init(struct crypto_aead *tfm)
 	crypto_aead_set_reqsize(
 		tfm,
 		align + offsetof(struct chachapoly_req_ctx, u) +
-		max(offsetof(struct chacha_req, req) +
-		    sizeof(struct skcipher_request) +
-		    crypto_skcipher_reqsize(chacha),
-		    offsetof(struct poly_req, req) +
-		    sizeof(struct ahash_request) +
-		    crypto_ahash_reqsize(poly)));
+		offsetof(struct chacha_req, req) +
+		sizeof(struct skcipher_request) +
+		crypto_skcipher_reqsize(chacha));
 
 	return 0;
 }
@@ -539,7 +357,6 @@  static void chachapoly_exit(struct crypto_aead *tfm)
 {
 	struct chachapoly_ctx *ctx = crypto_aead_ctx(tfm);
 
-	crypto_free_ahash(ctx->poly);
 	crypto_free_skcipher(ctx->chacha);
 }
 
@@ -548,7 +365,6 @@  static void chachapoly_free(struct aead_instance *inst)
 	struct chachapoly_instance_ctx *ctx = aead_instance_ctx(inst);
 
 	crypto_drop_skcipher(&ctx->chacha);
-	crypto_drop_ahash(&ctx->poly);
 	kfree(inst);
 }
 
@@ -559,7 +375,6 @@  static int chachapoly_create(struct crypto_template *tmpl, struct rtattr **tb,
 	struct aead_instance *inst;
 	struct chachapoly_instance_ctx *ctx;
 	struct skcipher_alg_common *chacha;
-	struct hash_alg_common *poly;
 	int err;
 
 	if (ivsize > CHACHAPOLY_IV_SIZE)
@@ -581,14 +396,9 @@  static int chachapoly_create(struct crypto_template *tmpl, struct rtattr **tb,
 		goto err_free_inst;
 	chacha = crypto_spawn_skcipher_alg_common(&ctx->chacha);
 
-	err = crypto_grab_ahash(&ctx->poly, aead_crypto_instance(inst),
-				crypto_attr_alg_name(tb[2]), 0, mask);
-	if (err)
-		goto err_free_inst;
-	poly = crypto_spawn_ahash_alg(&ctx->poly);
-
 	err = -EINVAL;
-	if (poly->digestsize != POLY1305_DIGEST_SIZE)
+	if (strcmp(crypto_attr_alg_name(tb[2]), "poly1305") &&
+	    strcmp(crypto_attr_alg_name(tb[2]), "poly1305-generic"))
 		goto err_free_inst;
 	/* Need 16-byte IV size, including Initial Block Counter value */
 	if (chacha->ivsize != CHACHA_IV_SIZE)
@@ -599,16 +409,15 @@  static int chachapoly_create(struct crypto_template *tmpl, struct rtattr **tb,
 
 	err = -ENAMETOOLONG;
 	if (snprintf(inst->alg.base.cra_name, CRYPTO_MAX_ALG_NAME,
-		     "%s(%s,%s)", name, chacha->base.cra_name,
-		     poly->base.cra_name) >= CRYPTO_MAX_ALG_NAME)
+		     "%s(%s,poly1305)", name,
+		     chacha->base.cra_name) >= CRYPTO_MAX_ALG_NAME)
 		goto err_free_inst;
 	if (snprintf(inst->alg.base.cra_driver_name, CRYPTO_MAX_ALG_NAME,
-		     "%s(%s,%s)", name, chacha->base.cra_driver_name,
-		     poly->base.cra_driver_name) >= CRYPTO_MAX_ALG_NAME)
+		     "%s(%s,poly1305-generic)", name,
+		     chacha->base.cra_driver_name) >= CRYPTO_MAX_ALG_NAME)
 		goto err_free_inst;
 
-	inst->alg.base.cra_priority = (chacha->base.cra_priority +
-				       poly->base.cra_priority) / 2;
+	inst->alg.base.cra_priority = chacha->base.cra_priority;
 	inst->alg.base.cra_blocksize = 1;
 	inst->alg.base.cra_alignmask = chacha->base.cra_alignmask;
 	inst->alg.base.cra_ctxsize = sizeof(struct chachapoly_ctx) +