@@ -12,188 +12,299 @@
#include <linux/net.h>
#include <linux/vmalloc.h>
#include <linux/zstd.h>
-#include <crypto/internal/scompress.h>
+#include <crypto/internal/acompress.h>
+#include <crypto/scatterwalk.h>
-#define ZSTD_DEF_LEVEL 3
+#define ZSTD_DEF_LEVEL 3
+#define ZSTD_MAX_WINDOWLOG 18
+#define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG)
struct zstd_ctx {
zstd_cctx *cctx;
zstd_dctx *dctx;
- void *cwksp;
- void *dwksp;
+ size_t wksp_size;
+ zstd_parameters params;
+ u8 wksp[0] __attribute__((aligned(8)));
};
-static zstd_parameters zstd_params(void)
+static DEFINE_MUTEX(zstd_stream_lock);
+
+static void *zstd_alloc_stream(void)
{
- return zstd_get_params(ZSTD_DEF_LEVEL, 0);
-}
-
-static int zstd_comp_init(struct zstd_ctx *ctx)
-{
- int ret = 0;
- const zstd_parameters params = zstd_params();
- const size_t wksp_size = zstd_cctx_workspace_bound(¶ms.cParams);
-
- ctx->cwksp = vzalloc(wksp_size);
- if (!ctx->cwksp) {
- ret = -ENOMEM;
- goto out;
- }
-
- ctx->cctx = zstd_init_cctx(ctx->cwksp, wksp_size);
- if (!ctx->cctx) {
- ret = -EINVAL;
- goto out_free;
- }
-out:
- return ret;
-out_free:
- vfree(ctx->cwksp);
- goto out;
-}
-
-static int zstd_decomp_init(struct zstd_ctx *ctx)
-{
- int ret = 0;
- const size_t wksp_size = zstd_dctx_workspace_bound();
-
- ctx->dwksp = vzalloc(wksp_size);
- if (!ctx->dwksp) {
- ret = -ENOMEM;
- goto out;
- }
-
- ctx->dctx = zstd_init_dctx(ctx->dwksp, wksp_size);
- if (!ctx->dctx) {
- ret = -EINVAL;
- goto out_free;
- }
-out:
- return ret;
-out_free:
- vfree(ctx->dwksp);
- goto out;
-}
-
-static void zstd_comp_exit(struct zstd_ctx *ctx)
-{
- vfree(ctx->cwksp);
- ctx->cwksp = NULL;
- ctx->cctx = NULL;
-}
-
-static void zstd_decomp_exit(struct zstd_ctx *ctx)
-{
- vfree(ctx->dwksp);
- ctx->dwksp = NULL;
- ctx->dctx = NULL;
-}
-
-static int __zstd_init(void *ctx)
-{
- int ret;
-
- ret = zstd_comp_init(ctx);
- if (ret)
- return ret;
- ret = zstd_decomp_init(ctx);
- if (ret)
- zstd_comp_exit(ctx);
- return ret;
-}
-
-static void *zstd_alloc_ctx(void)
-{
- int ret;
+ zstd_parameters params;
struct zstd_ctx *ctx;
+ size_t wksp_size;
- ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+ params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
+
+ wksp_size = max_t(size_t,
+ zstd_cstream_workspace_bound(¶ms.cParams),
+ zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
+ if (!wksp_size)
+ return ERR_PTR(-EINVAL);
+
+ ctx = kvmalloc(sizeof(*ctx) + wksp_size, GFP_KERNEL);
if (!ctx)
return ERR_PTR(-ENOMEM);
- ret = __zstd_init(ctx);
- if (ret) {
- kfree(ctx);
- return ERR_PTR(ret);
- }
+ ctx->params = params;
+ ctx->wksp_size = wksp_size;
return ctx;
}
-static void __zstd_exit(void *ctx)
+static struct crypto_acomp_streams zstd_streams = {
+ .alloc_ctx = zstd_alloc_stream,
+ .cfree_ctx = kvfree,
+};
+
+static int zstd_init(struct crypto_acomp *acomp_tfm)
{
- zstd_comp_exit(ctx);
- zstd_decomp_exit(ctx);
+ int ret = 0;
+
+ mutex_lock(&zstd_stream_lock);
+ ret = crypto_acomp_alloc_streams(&zstd_streams);
+ mutex_unlock(&zstd_stream_lock);
+
+ return ret;
}
-static void zstd_free_ctx(void *ctx)
+static void zstd_exit(struct crypto_acomp *acomp_tfm)
{
- __zstd_exit(ctx);
- kfree_sensitive(ctx);
+ crypto_acomp_free_streams(&zstd_streams);
}
-static int __zstd_compress(const u8 *src, unsigned int slen,
- u8 *dst, unsigned int *dlen, void *ctx)
+static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx, unsigned int *dlen)
{
- size_t out_len;
- struct zstd_ctx *zctx = ctx;
- const zstd_parameters params = zstd_params();
+ unsigned int out_len;
- out_len = zstd_compress_cctx(zctx->cctx, dst, *dlen, src, slen, ¶ms);
+ ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
+ if (!ctx->cctx)
+ return -EINVAL;
+
+ out_len = zstd_compress_cctx(ctx->cctx, sg_virt(req->dst),
+ req->dlen, sg_virt(req->src),
+ req->slen, &ctx->params);
if (zstd_is_error(out_len))
return -EINVAL;
+
*dlen = out_len;
+
return 0;
}
-static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
- unsigned int slen, u8 *dst, unsigned int *dlen,
- void *ctx)
+static int zstd_compress(struct acomp_req *req)
{
- return __zstd_compress(src, slen, dst, dlen, ctx);
-}
+ struct crypto_acomp_stream *s;
+ unsigned int pos, scur, dcur;
+ unsigned int total_out = 0;
+ bool data_available = true;
+ zstd_out_buffer outbuf;
+ struct acomp_walk walk;
+ zstd_in_buffer inbuf;
+ struct zstd_ctx *ctx;
+ size_t pending_bytes;
+ size_t num_bytes;
+ int ret;
-static int __zstd_decompress(const u8 *src, unsigned int slen,
- u8 *dst, unsigned int *dlen, void *ctx)
-{
- size_t out_len;
- struct zstd_ctx *zctx = ctx;
+ s = crypto_acomp_lock_stream_bh(&zstd_streams);
+ ctx = s->ctx;
- out_len = zstd_decompress_dctx(zctx->dctx, dst, *dlen, src, slen);
- if (zstd_is_error(out_len))
- return -EINVAL;
- *dlen = out_len;
- return 0;
-}
+ ret = acomp_walk_virt(&walk, req, true);
+ if (ret)
+ goto out;
-static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
- unsigned int slen, u8 *dst, unsigned int *dlen,
- void *ctx)
-{
- return __zstd_decompress(src, slen, dst, dlen, ctx);
-}
-
-static struct scomp_alg scomp = {
- .alloc_ctx = zstd_alloc_ctx,
- .free_ctx = zstd_free_ctx,
- .compress = zstd_scompress,
- .decompress = zstd_sdecompress,
- .base = {
- .cra_name = "zstd",
- .cra_driver_name = "zstd-scomp",
- .cra_module = THIS_MODULE,
+ ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
+ if (!ctx->cctx) {
+ ret = -EINVAL;
+ goto out;
}
+
+ do {
+ dcur = acomp_walk_next_dst(&walk);
+ if (!dcur) {
+ ret = -ENOSPC;
+ goto out;
+ }
+
+ outbuf.pos = 0;
+ outbuf.dst = (u8 *)walk.dst.virt.addr;
+ outbuf.size = dcur;
+
+ do {
+ scur = acomp_walk_next_src(&walk);
+ if (dcur == req->dlen && scur == req->slen) {
+ ret = zstd_compress_one(req, ctx, &total_out);
+ acomp_walk_done_src(&walk, scur);
+ acomp_walk_done_dst(&walk, dcur);
+ goto out;
+ }
+
+ if (scur) {
+ inbuf.pos = 0;
+ inbuf.src = walk.src.virt.addr;
+ inbuf.size = scur;
+ } else {
+ data_available = false;
+ break;
+ }
+
+ num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
+ if (ZSTD_isError(num_bytes)) {
+ ret = -EIO;
+ goto out;
+ }
+
+ pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
+ if (ZSTD_isError(pending_bytes)) {
+ ret = -EIO;
+ goto out;
+ }
+ acomp_walk_done_src(&walk, inbuf.pos);
+ } while (dcur != outbuf.pos);
+
+ total_out += outbuf.pos;
+ acomp_walk_done_dst(&walk, dcur);
+ } while (data_available);
+
+ pos = outbuf.pos;
+ num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
+ if (ZSTD_isError(num_bytes))
+ ret = -EIO;
+ else
+ total_out += (outbuf.pos - pos);
+
+out:
+ if (ret)
+ req->dlen = 0;
+ else
+ req->dlen = total_out;
+
+ crypto_acomp_unlock_stream_bh(s);
+
+ return ret;
+}
+
+static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx, unsigned int *dlen)
+{
+ size_t out_len;
+
+ ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
+ if (!ctx->dctx)
+ return -EINVAL;
+
+ out_len = zstd_decompress_dctx(ctx->dctx, sg_virt(req->dst),
+ req->dlen, sg_virt(req->src),
+ req->slen);
+ if (zstd_is_error(out_len))
+ return -EINVAL;
+
+ *dlen = out_len;
+
+ return 0;
+}
+
+static int zstd_decompress(struct acomp_req *req)
+{
+ struct crypto_acomp_stream *s;
+ unsigned int total_out = 0;
+ unsigned int scur, dcur;
+ zstd_out_buffer outbuf;
+ struct acomp_walk walk;
+ zstd_in_buffer inbuf;
+ struct zstd_ctx *ctx;
+ size_t pending_bytes;
+ int ret;
+
+ s = crypto_acomp_lock_stream_bh(&zstd_streams);
+ ctx = s->ctx;
+
+ ret = acomp_walk_virt(&walk, req, true);
+ if (ret)
+ goto out;
+
+ ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
+ if (!ctx->dctx) {
+ ret = -EINVAL;
+ goto out;
+ }
+
+ do {
+ scur = acomp_walk_next_src(&walk);
+ if (scur) {
+ inbuf.pos = 0;
+ inbuf.size = scur;
+ inbuf.src = walk.src.virt.addr;
+ } else {
+ break;
+ }
+
+ do {
+ dcur = acomp_walk_next_dst(&walk);
+ if (dcur == req->dlen && scur == req->slen) {
+ ret = zstd_decompress_one(req, ctx, &total_out);
+ acomp_walk_done_dst(&walk, dcur);
+ acomp_walk_done_src(&walk, scur);
+ goto out;
+ }
+
+ if (!dcur) {
+ ret = -ENOSPC;
+ goto out;
+ }
+
+ outbuf.pos = 0;
+ outbuf.dst = (u8 *)walk.dst.virt.addr;
+ outbuf.size = dcur;
+
+ pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
+ if (ZSTD_isError(pending_bytes)) {
+ ret = -EIO;
+ goto out;
+ }
+
+ total_out += outbuf.pos;
+
+ acomp_walk_done_dst(&walk, outbuf.pos);
+ } while (scur != inbuf.pos);
+
+ if (scur)
+ acomp_walk_done_src(&walk, scur);
+ } while (ret == 0);
+
+out:
+ if (ret)
+ req->dlen = 0;
+ else
+ req->dlen = total_out;
+
+ crypto_acomp_unlock_stream_bh(s);
+
+ return ret;
+}
+
+static struct acomp_alg zstd_acomp = {
+ .base = {
+ .cra_name = "zstd",
+ .cra_driver_name = "zstd-generic",
+ .cra_flags = CRYPTO_ALG_REQ_VIRT,
+ .cra_module = THIS_MODULE,
+ },
+ .init = zstd_init,
+ .exit = zstd_exit,
+ .compress = zstd_compress,
+ .decompress = zstd_decompress,
};
static int __init zstd_mod_init(void)
{
- return crypto_register_scomp(&scomp);
+ return crypto_register_acomp(&zstd_acomp);
}
static void __exit zstd_mod_fini(void)
{
- crypto_unregister_scomp(&scomp);
+ crypto_unregister_acomp(&zstd_acomp);
}
module_init(zstd_mod_init);