// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <endian.h>
#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "../utils/utils.h"
#include "udma_ulib.h"
#include "ycc_algs.h"
#include "aead.h"
#include "ske.h"

/*
 * ycc_skcipher_setkey() - cache a symmetric key in the transform context.
 * @cipher:       skcipher transform whose private context receives the key
 * @key:          raw key bytes
 * @key_size:     length of @key in bytes
 * @mode:         YCC_* algorithm/mode identifier programmed into each command
 * @key_dma_size: size of the key/IV area the engine reads per request
 *
 * If a key of a different length is already cached, it is wiped and
 * released first. This way re-keying (e.g. AES-128 -> AES-256) never
 * writes past the end of the old allocation.
 *
 * Return: 0 on success, -ENOMEM if the key copy cannot be allocated.
 */
static int ycc_skcipher_setkey(struct skcipher_ctx *cipher, const uint8_t *key,
			       unsigned int key_size, int mode,
			       unsigned int key_dma_size)
{
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;

	if (ctx->cipher_key && ctx->keysize != key_size) {
		/* volatile stores so the wipe is not elided before free() */
		volatile uint8_t *p = (volatile uint8_t *)ctx->cipher_key;
		unsigned int i;

		for (i = 0; i < ctx->keysize; i++)
			p[i] = 0;
		free(ctx->cipher_key);
		ctx->cipher_key = NULL;
		ctx->keysize = 0;
	}

	if (!ctx->cipher_key) {
		ctx->cipher_key = malloc(key_size);
		if (!ctx->cipher_key) {
			ycc_err("Failed to alloc memory for key\n");
			return -ENOMEM;
		}
	}

	/* Same-size buffer: the copy overwrites every byte of the old key. */
	YCC_memcpy(ctx->cipher_key, key, key_size);
	ctx->mode = mode;
	ctx->keysize = key_size;
	ctx->key_dma_size = key_dma_size;
	return 0;
}

/*
 * DEFINE_YCC_SKE_AES_SETKEY() - emit ycc_skcipher_aes_<name>_setkey().
 * @name: lower-case mode suffix used in the generated function name
 * @mode: upper-case mode suffix pasted onto YCC_AES_<bits>_
 * @size: key DMA area size forwarded to ycc_skcipher_setkey()
 *
 * The generated function maps a 16/24/32-byte AES key onto the matching
 * YCC_AES_{128,192,256}_<mode> identifier and rejects any other length
 * with -EINVAL.
 */
#define DEFINE_YCC_SKE_AES_SETKEY(name, mode, size)			\
int ycc_skcipher_aes_##name##_setkey(struct skcipher_ctx *cipher,	\
				     const uint8_t *key,		\
				     unsigned int key_size)		\
{									\
	int aes_mode;							\
									\
	if (key_size == AES_KEYSIZE_128)				\
		aes_mode = YCC_AES_128_##mode;				\
	else if (key_size == AES_KEYSIZE_192)				\
		aes_mode = YCC_AES_192_##mode;				\
	else if (key_size == AES_KEYSIZE_256)				\
		aes_mode = YCC_AES_256_##mode;				\
	else								\
		return -EINVAL;						\
									\
	return ycc_skcipher_setkey(cipher, key, key_size,		\
				   aes_mode, size);			\
}

/*
 * DEFINE_YCC_SKE_SM4_SETKEY() - emit ycc_skcipher_sm4_<name>_setkey().
 * @name: lower-case mode suffix used in the generated function name
 * @mode: upper-case mode suffix pasted onto YCC_SM4_
 * @size: key DMA area size forwarded to ycc_skcipher_setkey()
 *
 * SM4 has a single key length; anything else is rejected with -EINVAL.
 */
#define DEFINE_YCC_SKE_SM4_SETKEY(name, mode, size)			\
int ycc_skcipher_sm4_##name##_setkey(struct skcipher_ctx *cipher,	\
				     const uint8_t *key,		\
				     unsigned int key_size)		\
{									\
	if (key_size != SM4_KEY_SIZE)					\
		return -EINVAL;						\
									\
	return ycc_skcipher_setkey(cipher, key, key_size,		\
				   YCC_SM4_##mode, size);		\
}

/*
 * DEFINE_YCC_SKE_DES_SETKEY() - emit ycc_skcipher_des_<name>_setkey().
 * @name: lower-case mode suffix used in the generated function name
 * @mode: upper-case mode suffix pasted onto YCC_DES_
 * @size: key DMA area size forwarded to ycc_skcipher_setkey()
 *
 * The key must be exactly DES_KEY_SIZE bytes and must pass
 * verify_skcipher_des_key(); that helper's error is returned unchanged.
 */
#define DEFINE_YCC_SKE_DES_SETKEY(name, mode, size)			\
int ycc_skcipher_des_##name##_setkey(struct skcipher_ctx *cipher,	\
				     const uint8_t *key,		\
				     unsigned int key_size)		\
{									\
	int err;							\
									\
	if (key_size != DES_KEY_SIZE)					\
		return -EINVAL;						\
									\
	err = verify_skcipher_des_key(cipher, key);			\
	return err ? err :						\
		ycc_skcipher_setkey(cipher, key, key_size,		\
				    YCC_DES_##mode, size);		\
}

/*
 * DEFINE_YCC_SKE_3DES_SETKEY() - emit ycc_skcipher_3des_<name>_setkey().
 * @name: lower-case mode suffix used in the generated function name
 * @mode: upper-case mode suffix pasted onto YCC_TDES_192_
 * @size: key DMA area size forwarded to ycc_skcipher_setkey()
 *
 * Only three-key (DES3_EDE_KEY_SIZE) keys are accepted. The key must also
 * pass verify_skcipher_des3_key(), whose error is returned unchanged.
 */
#define DEFINE_YCC_SKE_3DES_SETKEY(name, mode, size)			\
int ycc_skcipher_3des_##name##_setkey(struct skcipher_ctx *cipher,	\
				      const uint8_t *key,		\
				      unsigned int key_size)		\
{									\
	int err;							\
									\
	if (key_size != DES3_EDE_KEY_SIZE)				\
		return -EINVAL;						\
									\
	err = verify_skcipher_des3_key(cipher, key);			\
	if (err != 0)							\
		return err;						\
									\
	return ycc_skcipher_setkey(cipher, key, key_size,		\
				   YCC_TDES_192_##mode, size);		\
}

/*
 * Key DMA area size passed to ycc_skcipher_setkey() for each mode:
 *  - ECB: key only, right-aligned in a 32-byte area; no IV.
 *  - every other mode except XTS: |key|iv| in 48 bytes. The key is
 *    right-aligned in the first 32 bytes, and the IV occupies the last
 *    8 (DES/3DES) or 16 bytes.
 *  - XTS uses an 80-byte |key1|key2|iv| area (see the *_xts_setkey()
 *    functions).
 *
 * Each DEFINE_* line expands to a complete function definition, so no
 * semicolon follows it: a stray ';' at file scope is not valid ISO C.
 */
DEFINE_YCC_SKE_AES_SETKEY(ecb, ECB, 32)
DEFINE_YCC_SKE_AES_SETKEY(cbc, CBC, 48)
DEFINE_YCC_SKE_AES_SETKEY(ctr, CTR, 48)
DEFINE_YCC_SKE_AES_SETKEY(cfb, CFB, 48)
DEFINE_YCC_SKE_AES_SETKEY(ofb, OFB, 48)

DEFINE_YCC_SKE_SM4_SETKEY(ecb, ECB, 32)
DEFINE_YCC_SKE_SM4_SETKEY(cbc, CBC, 48)
DEFINE_YCC_SKE_SM4_SETKEY(ctr, CTR, 48)
DEFINE_YCC_SKE_SM4_SETKEY(cfb, CFB, 48)
DEFINE_YCC_SKE_SM4_SETKEY(ofb, OFB, 48)

DEFINE_YCC_SKE_DES_SETKEY(ecb, ECB, 32)
DEFINE_YCC_SKE_DES_SETKEY(cbc, CBC, 48)
DEFINE_YCC_SKE_DES_SETKEY(ctr, CTR, 48)
DEFINE_YCC_SKE_DES_SETKEY(cfb, CFB, 48)
DEFINE_YCC_SKE_DES_SETKEY(ofb, OFB, 48)

DEFINE_YCC_SKE_3DES_SETKEY(ecb, ECB, 32)
DEFINE_YCC_SKE_3DES_SETKEY(cbc, CBC, 48)
DEFINE_YCC_SKE_3DES_SETKEY(ctr, CTR, 48)
DEFINE_YCC_SKE_3DES_SETKEY(cfb, CFB, 48)
DEFINE_YCC_SKE_3DES_SETKEY(ofb, OFB, 48)

/*
 * ycc_skcipher_aes_xts_setkey() - install an AES-XTS key pair.
 * @cipher:   skcipher transform
 * @key:      key1 || key2
 * @key_size: total length of both keys (32 or 64 bytes)
 *
 * Return: 0 on success, the xts_verify_key() error, or -EINVAL for any
 * length other than 2x128 or 2x256 bits.
 */
int ycc_skcipher_aes_xts_setkey(struct skcipher_ctx *cipher,
				const uint8_t *key,
				unsigned int key_size)
{
	int err = xts_verify_key(cipher, key, key_size);
	int xts_mode;

	if (err)
		return err;

	if (key_size == AES_KEYSIZE_128 * 2)
		xts_mode = YCC_AES_128_XTS;
	else if (key_size == AES_KEYSIZE_256 * 2)
		xts_mode = YCC_AES_256_XTS;
	else
		return -EINVAL;

	/* XTS key DMA area: |key1|key2|iv| = 32 + 32 + 16 bytes */
	return ycc_skcipher_setkey(cipher, key, key_size, xts_mode, 80);
}

/*
 * ycc_skcipher_sm4_xts_setkey() - install an SM4-XTS key pair.
 * @cipher:   skcipher transform
 * @key:      key1 || key2
 * @key_size: must be exactly 2 * SM4_KEY_SIZE
 *
 * Return: 0 on success, the xts_verify_key() error, or -EINVAL.
 */
int ycc_skcipher_sm4_xts_setkey(struct skcipher_ctx *cipher,
				const uint8_t *key,
				unsigned int key_size)
{
	int err = xts_verify_key(cipher, key, key_size);

	if (err)
		return err;

	if (key_size != SM4_KEY_SIZE * 2)
		return -EINVAL;

	/* Same 80-byte |key1|key2|iv| DMA layout as AES-XTS. */
	return ycc_skcipher_setkey(cipher, key, key_size, YCC_SM4_XTS, 80);
}

static int ycc_skcipher_fill_key(struct ycc_crypto_req *req)
{
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)req->ske_req->ctx->__ctx;
	uint32_t ivsize = req->ske_req->ctx->alg->ivsize;

	if (!req->key_vaddr) {
		ycc_err("Key address is not initialized\n");
		return -1;
	}

	memset(req->key_vaddr, 0, ALIGN(ctx->key_dma_size, MEM_ALIGNMENT_64));
	/* XTS mode has 2 keys & 1 iv */
	if (ctx->key_dma_size == 80) {
		YCC_memcpy(req->key_vaddr + (32 - ctx->keysize / 2),
		       ctx->cipher_key, ctx->keysize / 2);
		YCC_memcpy(req->key_vaddr + (64 - ctx->keysize / 2),
		       ctx->cipher_key + ctx->keysize / 2, ctx->keysize / 2);
	} else {
		YCC_memcpy(req->key_vaddr + (32 - ctx->keysize), ctx->cipher_key,
		       ctx->keysize);
	}

	if (ivsize) {
		if (ctx->mode == YCC_DES_ECB ||
		    ctx->mode == YCC_TDES_128_ECB ||
		    ctx->mode == YCC_TDES_192_ECB ||
		    ctx->mode == YCC_AES_128_ECB ||
		    ctx->mode == YCC_AES_192_ECB ||
		    ctx->mode == YCC_AES_256_ECB ||
		    ctx->mode == YCC_SM4_ECB) {
			ycc_err("Illegal ivsize for ECB mode, should be zero");
			return -EINVAL;
		}
		/* DES or 3DES */
		if (ctx->mode >= YCC_DES_ECB && ctx->mode <= YCC_TDES_192_CTR) {
			if (ivsize > 8)
				return -1;
			YCC_memcpy(req->key_vaddr + ctx->key_dma_size - 8,
			       req->ske_req->iv, ivsize);
		} else {
			YCC_memcpy(req->key_vaddr + ctx->key_dma_size - 16,
			       req->ske_req->iv, ivsize);
		}
	}

	return 0;
}

/*
 * ycc_skcipher_iv_out() - update the request IV after a CBC or CTR op.
 * @req: completed request
 * @dst: DMA buffer holding the operation's output
 *
 * CBC: the new IV is the last ciphertext block. On decrypt it was saved in
 *      req->last_block before the in-place operation; on encrypt it is
 *      read from @dst.
 * CTR: the counter is advanced once per (partial) block processed.
 * Other modes leave the IV untouched. So does an empty request, whose
 * CBC-encrypt branch would otherwise read before the start of @dst.
 */
static void ycc_skcipher_iv_out(struct ycc_crypto_req *req, void *dst)
{
	struct skcipher_req *ske_req = req->ske_req;
	struct skcipher_ctx *cipher = ske_req->ctx;
	uint8_t bs = cipher->alg->blocksize;
	/* int, not uint8_t: don't truncate the mode before matching it */
	int mode = req->ctx->mode;
	uint8_t cmd = req->desc.cmd.ske_cmd.cmd_id;
	uint32_t nb;

	if (!ske_req->cryptlen || !bs)
		return;

	nb = (ske_req->cryptlen + bs - 1) / bs;

	switch (mode) {
	case YCC_DES_CBC:
	case YCC_TDES_128_CBC:
	case YCC_TDES_192_CBC:
	case YCC_AES_128_CBC:
	case YCC_AES_192_CBC:
	case YCC_AES_256_CBC:
	case YCC_SM4_CBC:
		if (cmd == YCC_CMD_SKE_DEC)
			YCC_memcpy(ske_req->iv, req->last_block, bs);
		else
			YCC_memcpy(ske_req->iv,
				   (uint8_t *)dst + ALIGN(ske_req->cryptlen, bs) - bs,
				   bs);
		break;
	case YCC_DES_CTR:
	case YCC_TDES_128_CTR:
	case YCC_TDES_192_CTR:
	case YCC_AES_128_CTR:
	case YCC_AES_192_CTR:
	case YCC_AES_256_CTR:
	case YCC_SM4_CTR:
		while (nb--)
			crypto_inc(ske_req->iv, bs);
		break;
	default:
		break;
	}
}

static int ycc_skcipher_alloc_mem(struct ycc_crypto_req *req)
{
	struct ycc_crypto_ctx *ctx = req->ctx;
	uint32_t cryptlen = ALIGN(req->ske_req->cryptlen, MEM_ALIGNMENT_64);
	uint32_t keylen = ALIGN(ctx->key_dma_size, MEM_ALIGNMENT_64);

	req->src_vaddr = ycc_udma_malloc(cryptlen);
	if (!req->src_vaddr)
		return -ENOMEM;

	req->src_paddr = virt_to_phys(req->src_vaddr);
	req->dst_vaddr = req->src_vaddr;
	req->dst_paddr = req->src_paddr;
	YCC_memcpy(req->src_vaddr, req->ske_req->src, req->ske_req->cryptlen);

	req->key_vaddr = ycc_udma_malloc(keylen);
	if (!req->key_vaddr)
		return -ENOMEM;

	req->key_paddr = virt_to_phys(req->key_vaddr);
	return 0;
}

static void ycc_skcipher_free_mem(struct ycc_crypto_req *req)
{
	ycc_udma_free(req->src_vaddr);
	ycc_udma_free(req->key_vaddr);
}

/*
 * ycc_skcipher_callback() - completion handler for skcipher descriptors.
 * @ptr:   the struct ycc_crypto_req stored in ycc_flags->ptr at submit time
 * @state: engine completion status (CMD_SUCCESS on success)
 *
 * On success the result is copied to the caller's dst buffer and the IV
 * is updated; otherwise the raw state is logged. DMA memory is released
 * before ->complete() runs, because the woken-up request thread may free
 * the request structure.
 *
 * Return: always 0.
 */
int ycc_skcipher_callback(void *ptr, uint16_t state)
{
	struct ycc_crypto_req *req = ptr;
	struct skcipher_req *ske_req = req->ske_req;
	int ok = (state == CMD_SUCCESS);

	if (!ok) {
		ycc_err("Back state is:%x\n", state);
	} else {
		YCC_memcpy(ske_req->dst, req->src_vaddr, ske_req->cryptlen);
		ycc_skcipher_iv_out(req, req->dst_vaddr);
	}

	/* req must not be touched after ->complete(); release memory first */
	ycc_skcipher_free_mem(req);

	if (ske_req->complete)
		ske_req->complete(ske_req, ok ? 0 : -EBADMSG);

	return 0;
}

/*
 * ycc_skcipher_submit_desc() - build and enqueue one skcipher descriptor.
 * @ske_req: caller's request (its __req area holds our ycc_crypto_req)
 * @cmd:     YCC_CMD_SKE_ENC or YCC_CMD_SKE_DEC
 *
 * Return: -EINPROGRESS once the descriptor is queued; completion is
 * reported through ycc_skcipher_callback(). Otherwise a negative error
 * (-EINVAL for an empty request, allocation/key errors), or the non-zero
 * value returned by ycc_enqueue().
 */
static int ycc_skcipher_submit_desc(struct skcipher_req *ske_req, uint8_t cmd)
{
	struct skcipher_ctx *cipher = ske_req->ctx;
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;
	struct ycc_crypto_req *req = (struct ycc_crypto_req *)ske_req->__req;
	struct ycc_skcipher_cmd *ycc_ske_cmd = &req->desc.cmd.ske_cmd;
	struct ycc_flags *aflags;
	uint8_t bs = cipher->alg->blocksize;
	int ret;

	/*
	 * An empty request would make the last-block bookkeeping below (and
	 * the CBC IV update on completion) index before the buffer start.
	 */
	if (!ske_req->cryptlen)
		return -EINVAL;

	memset(req, 0, sizeof(*req));
	req->ctx = ctx;
	req->ske_req = ske_req;

	/* On failure alloc_mem leaves nothing allocated. */
	ret = ycc_skcipher_alloc_mem(req);
	if (ret < 0)
		goto out;

	ret = ycc_skcipher_fill_key(req);
	if (ret)
		goto free_mem;

	aflags = malloc(sizeof(*aflags));
	if (!aflags) {
		ret = -ENOMEM;
		goto free_mem;
	}

	aflags->ptr = (void *)req;
	aflags->ycc_done_callback = ycc_skcipher_callback;
	req->desc.private_ptr = (uint64_t)(uintptr_t)aflags;
	ycc_ske_cmd->cmd_id = cmd;

	/*
	 * The operation runs in place (dst == src), so save the last input
	 * block now; for CBC decrypt it becomes the next IV.
	 */
	if (cmd == YCC_CMD_SKE_DEC)
		YCC_memcpy(req->last_block,
			   (uint8_t *)req->src_vaddr + ALIGN(ske_req->cryptlen, bs) - bs,
			   bs);

	ycc_ske_cmd->mode = ctx->mode;
	ycc_ske_cmd->sptr = req->src_paddr;
	ycc_ske_cmd->dptr = req->dst_paddr;

	/*
	 * HW requires cryptlen%16==0 if HW padding is not enabled. However it appears
	 * that real cryptlen also works as we reserved enough space for DMA rd/wr
	 */
	ycc_ske_cmd->dlen   = ske_req->cryptlen;
	ycc_ske_cmd->keyptr = req->key_paddr;

	/*
	 * Not support HW padding now. If cryptlen is n*blocksize, when padding is
	 * enabled, output length will become n*blocksize + blocksize
	 */
	ycc_ske_cmd->padding = 0;
	ret = ycc_enqueue(ctx->ring, &req->desc);
	if (!ret)
		return -EINPROGRESS;

	free(aflags);
free_mem:
	ycc_skcipher_free_mem(req);
out:
	return ret;
}

/* skcipher ->encrypt hook: queue @ske_req as an encryption (-EINPROGRESS on success). */
static int ycc_skcipher_encrypt(struct skcipher_req *ske_req)
{
	const uint8_t op = YCC_CMD_SKE_ENC;

	return ycc_skcipher_submit_desc(ske_req, op);
}

/* skcipher ->decrypt hook: queue @ske_req as a decryption (-EINPROGRESS on success). */
static int ycc_skcipher_decrypt(struct skcipher_req *ske_req)
{
	const uint8_t op = YCC_CMD_SKE_DEC;

	return ycc_skcipher_submit_desc(ske_req, op);
}

static int ycc_skcipher_init(struct skcipher_ctx *cipher)
{
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	cipher->reqsize = sizeof(struct ycc_crypto_req);

	ring = ycc_crypto_get_ring();
	if (!ring) {
		ycc_err("Failed to get ring when doing skcipher init\n");
		return -EFAULT;
	}

	ctx->ring = ring;
	return 0;
}

/*
 * ycc_skcipher_exit() - transform destructor.
 * @cipher: skcipher transform being torn down
 *
 * Returns the ring and releases the cached key. The key is wiped first so
 * it does not linger in freed heap memory. The context pointers are
 * cleared so a second call does nothing.
 */
static void ycc_skcipher_exit(struct skcipher_ctx *cipher)
{
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;

	if (ctx->ring) {
		ycc_crypto_free_ring(ctx->ring);
		ctx->ring = NULL;
	}

	if (ctx->cipher_key) {
		/* volatile stores so the wipe is not elided before free() */
		volatile uint8_t *p = (volatile uint8_t *)ctx->cipher_key;
		unsigned int i;

		for (i = 0; i < ctx->keysize; i++)
			p[i] = 0;
		free(ctx->cipher_key);
		ctx->cipher_key = NULL;
	}
	ctx->keysize = 0;
}

/*
 * YCC_SKE_ALG() - initializer for one skcipher algorithm descriptor.
 *
 * Every entry shares the context size, init/exit and the encrypt/decrypt
 * handlers. Only the name, setkey routine, block size, key-size range and
 * IV size vary (IV size 0 for ECB).
 */
#define YCC_SKE_ALG(_name, _setkey, _bs, _minkey, _maxkey, _iv)	\
	{								\
		.name = _name,						\
		.blocksize = (_bs),					\
		.ctxsize = sizeof(struct ycc_crypto_ctx),		\
		.init = ycc_skcipher_init,				\
		.exit = ycc_skcipher_exit,				\
		.setkey = (_setkey),					\
		.encrypt = ycc_skcipher_encrypt,			\
		.decrypt = ycc_skcipher_decrypt,			\
		.min_keysize = (_minkey),				\
		.max_keysize = (_maxkey),				\
		.ivsize = (_iv),					\
	}

struct skcipher_alg skcipher_algs[22] = {
	/* AES */
	YCC_SKE_ALG("cbc(aes)", ycc_skcipher_aes_cbc_setkey, AES_BLOCK_SIZE,
		    AES_MIN_KEY_SIZE, AES_MAX_KEY_SIZE, AES_BLOCK_SIZE),
	YCC_SKE_ALG("ecb(aes)", ycc_skcipher_aes_ecb_setkey, AES_BLOCK_SIZE,
		    AES_MIN_KEY_SIZE, AES_MAX_KEY_SIZE, 0),
	YCC_SKE_ALG("ctr(aes)", ycc_skcipher_aes_ctr_setkey, AES_BLOCK_SIZE,
		    AES_MIN_KEY_SIZE, AES_MAX_KEY_SIZE, AES_BLOCK_SIZE),
	YCC_SKE_ALG("cfb(aes)", ycc_skcipher_aes_cfb_setkey, AES_BLOCK_SIZE,
		    AES_MIN_KEY_SIZE, AES_MAX_KEY_SIZE, AES_BLOCK_SIZE),
	YCC_SKE_ALG("ofb(aes)", ycc_skcipher_aes_ofb_setkey, AES_BLOCK_SIZE,
		    AES_MIN_KEY_SIZE, AES_MAX_KEY_SIZE, AES_BLOCK_SIZE),
	YCC_SKE_ALG("xts(aes)", ycc_skcipher_aes_xts_setkey, AES_BLOCK_SIZE,
		    AES_MIN_KEY_SIZE * 2, AES_MAX_KEY_SIZE * 2, AES_BLOCK_SIZE),

	/* SM4 */
	YCC_SKE_ALG("cbc(sm4)", ycc_skcipher_sm4_cbc_setkey, SM4_BLOCK_SIZE,
		    SM4_KEY_SIZE, SM4_KEY_SIZE, SM4_BLOCK_SIZE),
	YCC_SKE_ALG("ecb(sm4)", ycc_skcipher_sm4_ecb_setkey, SM4_BLOCK_SIZE,
		    SM4_KEY_SIZE, SM4_KEY_SIZE, 0),
	YCC_SKE_ALG("ctr(sm4)", ycc_skcipher_sm4_ctr_setkey, SM4_BLOCK_SIZE,
		    SM4_KEY_SIZE, SM4_KEY_SIZE, SM4_BLOCK_SIZE),
	YCC_SKE_ALG("cfb(sm4)", ycc_skcipher_sm4_cfb_setkey, SM4_BLOCK_SIZE,
		    SM4_KEY_SIZE, SM4_KEY_SIZE, SM4_BLOCK_SIZE),
	YCC_SKE_ALG("ofb(sm4)", ycc_skcipher_sm4_ofb_setkey, SM4_BLOCK_SIZE,
		    SM4_KEY_SIZE, SM4_KEY_SIZE, SM4_BLOCK_SIZE),
	YCC_SKE_ALG("xts(sm4)", ycc_skcipher_sm4_xts_setkey, SM4_BLOCK_SIZE,
		    SM4_KEY_SIZE * 2, SM4_KEY_SIZE * 2, SM4_BLOCK_SIZE),

	/* DES */
	YCC_SKE_ALG("cbc(des)", ycc_skcipher_des_cbc_setkey, DES_BLOCK_SIZE,
		    DES_KEY_SIZE, DES_KEY_SIZE, DES_BLOCK_SIZE),
	YCC_SKE_ALG("ecb(des)", ycc_skcipher_des_ecb_setkey, DES_BLOCK_SIZE,
		    DES_KEY_SIZE, DES_KEY_SIZE, 0),
	YCC_SKE_ALG("ctr(des)", ycc_skcipher_des_ctr_setkey, DES_BLOCK_SIZE,
		    DES_KEY_SIZE, DES_KEY_SIZE, DES_BLOCK_SIZE),
	YCC_SKE_ALG("cfb(des)", ycc_skcipher_des_cfb_setkey, DES_BLOCK_SIZE,
		    DES_KEY_SIZE, DES_KEY_SIZE, DES_BLOCK_SIZE),
	YCC_SKE_ALG("ofb(des)", ycc_skcipher_des_ofb_setkey, DES_BLOCK_SIZE,
		    DES_KEY_SIZE, DES_KEY_SIZE, DES_BLOCK_SIZE),

	/* Triple DES (EDE, three keys) */
	YCC_SKE_ALG("cbc(des3_ede)", ycc_skcipher_3des_cbc_setkey,
		    DES3_EDE_BLOCK_SIZE, DES3_EDE_KEY_SIZE, DES3_EDE_KEY_SIZE,
		    DES3_EDE_BLOCK_SIZE),
	YCC_SKE_ALG("ecb(des3_ede)", ycc_skcipher_3des_ecb_setkey,
		    DES3_EDE_BLOCK_SIZE, DES3_EDE_KEY_SIZE, DES3_EDE_KEY_SIZE,
		    0),
	YCC_SKE_ALG("ctr(des3_ede)", ycc_skcipher_3des_ctr_setkey,
		    DES3_EDE_BLOCK_SIZE, DES3_EDE_KEY_SIZE, DES3_EDE_KEY_SIZE,
		    DES3_EDE_BLOCK_SIZE),
	YCC_SKE_ALG("cfb(des3_ede)", ycc_skcipher_3des_cfb_setkey,
		    DES3_EDE_BLOCK_SIZE, DES3_EDE_KEY_SIZE, DES3_EDE_KEY_SIZE,
		    DES3_EDE_BLOCK_SIZE),
	YCC_SKE_ALG("ofb(des3_ede)", ycc_skcipher_3des_ofb_setkey,
		    DES3_EDE_BLOCK_SIZE, DES3_EDE_KEY_SIZE, DES3_EDE_KEY_SIZE,
		    DES3_EDE_BLOCK_SIZE),
};

#undef YCC_SKE_ALG

extern struct skcipher_alg_entry *skcipher_alg_entries;

/*
 * skcipher_register_algs() - append every skcipher_algs[] entry to the
 * global skcipher_alg_entries list.
 *
 * New nodes are linked after the current tail, so entries already on the
 * list stay intact. If a node allocation fails, registration stops there;
 * the algorithms linked so far remain registered.
 */
void skcipher_register_algs(void)
{
	struct skcipher_alg_entry *tail = skcipher_alg_entries;
	struct skcipher_alg_entry *cur;
	size_t i;

	/* Walk to the real tail; the head is the tail only for 1-entry lists. */
	while (tail && tail->next)
		tail = tail->next;

	for (i = 0; i < sizeof(skcipher_algs) / sizeof(skcipher_algs[0]); i++) {
		cur = malloc(sizeof(*cur));
		if (!cur) {
			ycc_err("Failed to alloc memory for alg:%s\n",
				skcipher_algs[i].name);
			break;
		}

		cur->alg = &skcipher_algs[i];
		cur->next = NULL;
		if (tail)
			tail->next = cur;
		else
			skcipher_alg_entries = cur;
		tail = cur;
	}
}

/*
 * skcipher_unregister_algs() - free every node of skcipher_alg_entries and
 * leave the list empty. The algorithm descriptors themselves are static
 * and are not freed.
 */
void skcipher_unregister_algs(void)
{
	struct skcipher_alg_entry *node = skcipher_alg_entries;

	while (node) {
		struct skcipher_alg_entry *victim = node;

		node = node->next;
		free(victim);
	}

	skcipher_alg_entries = NULL;
}

/*
 * ycc_aead_init() - AEAD transform constructor.
 * @cipher: AEAD transform being set up
 *
 * Advertises the per-request area size and binds the transform to a ring.
 *
 * Return: 0 on success, -EFAULT if ycc_crypto_get_ring() yields no ring.
 */
static int ycc_aead_init(struct aead_ctx *cipher)
{
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	cipher->reqsize = sizeof(struct ycc_crypto_req);

	ring = ycc_crypto_get_ring();
	if (ring)
		ctx->ring = ring;

	return ring ? 0 : -EFAULT;
}

/*
 * ycc_aead_exit() - AEAD transform destructor.
 * @cipher: AEAD transform being torn down
 *
 * Returns the ring and releases the cached key. The key is wiped first so
 * it does not linger in freed heap memory. The context pointers are
 * cleared so a second call does nothing.
 */
static void ycc_aead_exit(struct aead_ctx *cipher)
{
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;

	if (ctx->ring) {
		ycc_crypto_free_ring(ctx->ring);
		ctx->ring = NULL;
	}

	if (ctx->cipher_key) {
		/* volatile stores so the wipe is not elided before free() */
		volatile uint8_t *p = (volatile uint8_t *)ctx->cipher_key;
		unsigned int i;

		for (i = 0; i < ctx->keysize; i++)
			p[i] = 0;
		free(ctx->cipher_key);
		ctx->cipher_key = NULL;
	}
	ctx->keysize = 0;
}

/*
 * ycc_aead_setkey() - validate and cache an AEAD key.
 * @cipher:   AEAD transform; its algorithm name selects GCM/CCM and AES/SM4
 * @key:      raw key bytes
 * @key_size: length of @key in bytes
 *
 * SM4 accepts only SM4_KEY_SIZE, and AES accepts 128/192/256-bit keys.
 * ycc_aead_fill_key() places the key in a 32-byte slot, so every accepted
 * size fits there. The context (mode, key, sizes) is updated only after
 * the key is stored. A cached key of a different length is wiped and
 * released first, so re-keying never overruns the old buffer.
 *
 * Return: 0 on success, -EINVAL for a bad length or unknown algorithm,
 * -ENOMEM on allocation failure.
 */
static int ycc_aead_setkey(struct aead_ctx *cipher, const char *key,
			   unsigned int key_size)
{
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;
	const char *alg_name = cipher->alg->name;
	int mode;

	if (!strncmp("gcm(sm4)", alg_name, strlen("gcm(sm4)"))) {
		if (key_size != SM4_KEY_SIZE)
			return -EINVAL;
		mode = YCC_SM4_GCM;
	} else if (!strncmp("ccm(sm4)", alg_name, strlen("ccm(sm4)"))) {
		if (key_size != SM4_KEY_SIZE)
			return -EINVAL;
		mode = YCC_SM4_CCM;
	} else if (!strncmp("gcm(aes)", alg_name, strlen("gcm(aes)"))) {
		switch (key_size) {
		case AES_KEYSIZE_128:
			mode = YCC_AES_128_GCM;
			break;
		case AES_KEYSIZE_192:
			mode = YCC_AES_192_GCM;
			break;
		case AES_KEYSIZE_256:
			mode = YCC_AES_256_GCM;
			break;
		default:
			return -EINVAL;
		}
	} else if (!strncmp("ccm(aes)", alg_name, strlen("ccm(aes)"))) {
		switch (key_size) {
		case AES_KEYSIZE_128:
			mode = YCC_AES_128_CCM;
			break;
		case AES_KEYSIZE_192:
			mode = YCC_AES_192_CCM;
			break;
		case AES_KEYSIZE_256:
			mode = YCC_AES_256_CCM;
			break;
		default:
			return -EINVAL;
		}
	} else {
		return -EINVAL;
	}

	if (ctx->cipher_key && ctx->keysize != key_size) {
		/* volatile stores so the wipe is not elided before free() */
		volatile uint8_t *p = (volatile uint8_t *)ctx->cipher_key;
		unsigned int i;

		for (i = 0; i < ctx->keysize; i++)
			p[i] = 0;
		free(ctx->cipher_key);
		ctx->cipher_key = NULL;
		ctx->keysize = 0;
	}

	if (!ctx->cipher_key) {
		ctx->cipher_key = malloc(key_size);
		if (!ctx->cipher_key)
			return -ENOMEM;
	}

	YCC_memcpy(ctx->cipher_key, key, key_size);
	ctx->mode = mode;
	ctx->keysize = key_size;
	ctx->key_dma_size = 64;
	return 0;
}

/*
 * ycc_aead_fill_key() - write key and IV into the request's 64-byte key area.
 * @req: request whose key_vaddr points into the DMA buffer
 *
 * Layout: key right-aligned in bytes [0, 32), IV from byte 32. The IV is
 * 16 bytes for CCM and 12 bytes for GCM. All other bytes are zero.
 *
 * Return: 0 on success, -EINVAL if the buffer is missing or the cached key
 * is absent or does not fit the 32-byte slot.
 */
static int ycc_aead_fill_key(struct ycc_crypto_req *req)
{
	struct ycc_crypto_ctx *ctx = req->ctx;
	struct aead_req *aead_req = req->aead_req;
	struct aead_ctx *cipher = aead_req->ctx;
	const char *alg_name = cipher->alg->name;
	uint8_t *kbuf = (uint8_t *)req->key_vaddr;
	size_t iv_len = 12;

	if (!strncmp("ccm", alg_name, strlen("ccm")))
		iv_len = 16;

	if (!kbuf) {
		ycc_err("Key address is not initialized\n");
		return -EINVAL;
	}

	/* Without this, 32 - keysize could point before the buffer. */
	if (!ctx->cipher_key || !ctx->keysize || ctx->keysize > 32) {
		ycc_err("Invalid or missing AEAD key\n");
		return -EINVAL;
	}

	memset(kbuf, 0, 64);
	YCC_memcpy(kbuf + (32 - ctx->keysize), ctx->cipher_key, ctx->keysize);
	YCC_memcpy(kbuf + 32, aead_req->iv, iv_len);

	return 0;
}

/*
 * ycc_aead_callback() - completion handler for AEAD descriptors.
 * @ptr:   the struct ycc_crypto_req stored in ycc_flags->ptr at submit time
 * @state: engine completion status (CMD_SUCCESS on success)
 *
 * On success:
 *   encrypt - the tag is moved from the 16-byte aligned offset after the
 *             ciphertext to directly follow it, then ciphertext || tag
 *             (cryptlen + authsize bytes) is copied to dst.
 *   decrypt - the plaintext (cryptlen - authsize bytes) is copied to dst.
 * On failure nothing is copied, so unauthenticated plaintext is never
 * released.
 *
 * The DMA buffer is freed before ->complete() is called, because the
 * woken-up submitter may free the request structure that holds
 * src_vaddr.
 *
 * Return: always 0.
 */
static int ycc_aead_callback(void *ptr, uint16_t state)
{
	struct ycc_crypto_req *req = (struct ycc_crypto_req *)ptr;
	struct aead_req *aead_req = req->aead_req;
	struct ycc_aead_cmd *ycc_aead_cmd = &req->desc.cmd.aead_cmd;
	struct aead_ctx *cipher = aead_req->ctx;
	uint8_t *out = (uint8_t *)req->dst_vaddr;
	size_t taglen = cipher->authsize;
	size_t cryptlen = aead_req->cryptlen;
	int is_enc = (ycc_aead_cmd->cmd_id == YCC_CMD_GCM_ENC ||
		      ycc_aead_cmd->cmd_id == YCC_CMD_CCM_ENC);
	int ok = (state == CMD_SUCCESS);

	if (ok && is_enc) {
		/* Source and destination can overlap, hence memmove(). */
		if (cryptlen % 16 != 0)
			memmove(out + cryptlen, out + ALIGN(cryptlen, 16), taglen);
		YCC_memcpy(aead_req->dst, out, cryptlen + taglen);
	} else if (ok && cryptlen >= taglen) {
		YCC_memcpy(aead_req->dst, out, cryptlen - taglen);
	}

	ycc_udma_free(req->src_vaddr);
	req->src_vaddr = NULL;

	if (aead_req->complete)
		aead_req->complete(aead_req, ok ? 0 : -EBADMSG);

	return 0;
}

/*
 * __ycc_aead_format_data() - lay out an AEAD request in one DMA buffer.
 * @req:  request being built. in_len, aad_offset and the src/dst/key
 *        addresses are filled in; req->out_len must already be set.
 * @b0:   16-byte CCM B0 block, or NULL for GCM
 * @b1:   encoded CCM AAD length (@alen bytes), or NULL
 * @alen: number of valid bytes in @b1
 * @cmd:  YCC_CMD_{GCM,CCM}_{ENC,DEC}
 *
 * Buffer layout (each region zero-padded):
 *   |B0 (ccm)|AAD length (ccm)|AAD|pad to 16|payload|pad to 16|[tag|pad] (dec)|
 * The last key_dma_size bytes of the allocation form the key area.
 *
 * Return: 0 on success, -EINVAL if a decrypt input is shorter than the
 * tag, -ENOMEM if the buffer cannot be allocated.
 */
static inline int __ycc_aead_format_data(struct ycc_crypto_req *req, uint8_t *b0, uint8_t *b1,
					 int alen, uint8_t cmd)
{
	struct aead_req *aead_req = req->aead_req;
	const uint8_t *src = (const uint8_t *)aead_req->src;
	uint32_t aad_len = aead_req->assoclen;
	uint32_t cryptlen = aead_req->cryptlen;
	uint32_t taglen = aead_req->ctx->authsize;
	int is_dec = (cmd == YCC_CMD_GCM_DEC || cmd == YCC_CMD_CCM_DEC);
	uint32_t size, b0_len = 0, totallen, data_off;
	uint32_t src_len = cryptlen;
	uint8_t *buf;

	/* b0 stands for ccm */
	if (b0)
		b0_len = 16;

	/* Decrypt input is ciphertext || tag; a shorter one would underflow. */
	if (is_dec && cryptlen < taglen)
		return -EINVAL;

	/* The payload starts at the first 16-byte boundary after the AAD. */
	data_off = ALIGN((b0_len + alen + aad_len), 16);
	size = data_off;
	if (is_dec) {
		src_len = cryptlen - taglen;
		/* YCC requires cipher text aligned, tag aligned */
		size += ALIGN(src_len, 16) + ALIGN(taglen, 16);
	} else {
		size += ALIGN(cryptlen, 16);
	}

	req->in_len = size;
	req->aad_offset = b0_len + alen;

	totallen = ALIGN(max(req->in_len, req->out_len), MEM_ALIGNMENT_64) +
		   ALIGN(req->ctx->key_dma_size, MEM_ALIGNMENT_64);

	req->src_vaddr = ycc_udma_malloc(totallen);
	if (!req->src_vaddr)
		return -ENOMEM;

	buf = (uint8_t *)req->src_vaddr;
	memset(buf, 0, totallen);

	if (b0)
		YCC_memcpy(buf, b0, b0_len);
	if (b1 && alen)
		YCC_memcpy(buf + b0_len, b1, alen);
	if (aad_len)
		YCC_memcpy(buf + b0_len + alen, src, aad_len);
	if (src_len)
		YCC_memcpy(buf + data_off, src + aad_len, src_len);
	if (is_dec)
		YCC_memcpy(buf + data_off + ALIGN(src_len, 16),
			   src + aad_len + src_len, taglen);

	req->src_paddr = virt_to_phys(req->src_vaddr);
	req->dst_vaddr = req->src_vaddr;
	req->dst_paddr = req->src_paddr;
	req->key_vaddr = req->dst_vaddr + totallen - req->ctx->key_dma_size;
	req->key_paddr = req->dst_paddr + totallen - req->ctx->key_dma_size;

	return 0;
}

static inline int ycc_aead_format_ccm_data(struct ycc_crypto_req *req,
					   uint16_t *new_aad_len, uint8_t cmd)
{
	struct aead_req *aead_req = req->aead_req;
	int aad_len = aead_req->assoclen;
	int cryptlen = aead_req->cryptlen;
	int taglen = aead_req->ctx->authsize;
	uint8_t b0[16] = {0};
	uint8_t b1[10] = {0};
	uint8_t alen = 0;
	uint32_t msglen;
	int l;

	/* 1. check iv value aead_req->iv[0] = L - 1 */
	if (aead_req->iv[0] < 1 || aead_req->iv[0] > 7) {
		ycc_err("L value is not valid for CCM\n");
		return -EINVAL;
	}

	l = aead_req->iv[0] + 1;

	/* 2. format control infomration and nonce */
	YCC_memcpy(b0, aead_req->iv, 16); //iv max size is 15 - L
	b0[0] |= (((taglen - 2) / 2) << 3);
	if (aad_len) {
		b0[0] |= (1 << 6);
		if (aad_len < 65280) {
			*(uint16_t *)b1 = htobe16(aad_len);
			alen = 2;
		} else if (aad_len < (2UL << 31)) {
			*(uint16_t *)b1 = htobe16(0xfffe);
			*(uint32_t *)&b1[2] = htobe32(aad_len);
			alen = 6;
		} else {
			*(uint16_t *)b1 = htobe16(0xffff);
			*(uint64_t *)&b1[2] = htobe64(aad_len);
			alen = 10;
		}
		*new_aad_len = ALIGN((16 + alen + aad_len), 16);
	} else {
		*new_aad_len = 16;
	}
	b0[0] |= aead_req->iv[0];

	/* 3. set msg length. L - 1 Bytes store msg length. */
	if (l >= 4)
		l = 4;
	else if (cryptlen > (1 << (8 * l)))
		return -EINVAL;

	msglen = htobe32(cryptlen);

	YCC_memcpy(&b0[16 - l], (uint8_t *)&msglen + 4 - l, l);
	return __ycc_aead_format_data(req, b0, b1, alen, cmd);
}

/*
 * ycc_aead_format_data() - compute AEAD lengths and build the DMA buffer.
 * @req:          request being built (req->out_len is set here)
 * @new_aad_len:  out: AAD length to program into the command
 * @new_cryptlen: out: data length to program into the command
 * @cmd:          YCC_CMD_{GCM,CCM}_{ENC,DEC}
 *
 * For GCM/CCM encrypt, aead_req->cryptlen = len(plaintext).
 * For GCM/CCM decrypt, aead_req->cryptlen = len(ciphertext) + authsize.
 *
 * Return: 0 on success, -EINVAL if the GCM AAD length does not fit the
 * 16-bit @new_aad_len, or the error from the layout helpers.
 */
static inline int ycc_aead_format_data(struct ycc_crypto_req *req,
				       uint16_t *new_aad_len,
				       uint32_t *new_cryptlen, uint8_t cmd)
{
	struct aead_req *aead_req = req->aead_req;
	struct aead_ctx *cipher = aead_req->ctx;
	int taglen = cipher->authsize;

	if (cmd != YCC_CMD_GCM_ENC && cmd != YCC_CMD_GCM_DEC) {
		/* CCM: B0 and the encoded AAD length are built separately */
		*new_cryptlen = ALIGN(aead_req->cryptlen, 16);
		req->out_len = *new_cryptlen + taglen;
		return ycc_aead_format_ccm_data(req, new_aad_len, cmd);
	}

	/* GCM: *new_aad_len is 16 bits wide; refuse rather than truncate. */
	if (aead_req->assoclen > UINT16_MAX)
		return -EINVAL;

	*new_aad_len = aead_req->assoclen;
	*new_cryptlen = aead_req->cryptlen;
	req->out_len = ALIGN(*new_cryptlen, 16) + ALIGN(taglen, 16);
	return __ycc_aead_format_data(req, NULL, NULL, 0, cmd);
}

/*
 * ycc_aead_free_mem() - release the request's single DMA buffer.
 * @req: request built by __ycc_aead_format_data()
 *
 * dst_vaddr and key_vaddr point into the same allocation as src_vaddr,
 * so all three are cleared to avoid leaving dangling pointers behind.
 */
static inline void ycc_aead_free_mem(struct ycc_crypto_req *req)
{
	ycc_udma_free(req->src_vaddr);
	req->src_vaddr = NULL;
	req->dst_vaddr = NULL;
	req->key_vaddr = NULL;
}

/*
 * ycc_aead_submit_desc() - build and enqueue one AEAD descriptor.
 * @aead_req: caller's request (its __req area holds our ycc_crypto_req)
 * @cmd:      YCC_CMD_{GCM,CCM}_{ENC,DEC}
 *
 * Return: -EINPROGRESS once queued; completion is reported through
 * ycc_aead_callback(). Otherwise a negative error (-EINVAL for a decrypt
 * input shorter than the tag, formatting/key/allocation errors), or the
 * non-zero value returned by ycc_enqueue().
 */
static inline int ycc_aead_submit_desc(struct aead_req *aead_req, uint8_t cmd)
{
	struct aead_ctx *cipher = aead_req->ctx;
	struct ycc_crypto_ctx *ctx = (struct ycc_crypto_ctx *)cipher->__ctx;
	struct ycc_crypto_req *req = (struct ycc_crypto_req *)aead_req->__req;
	struct ycc_aead_cmd *ycc_aead_cmd = &req->desc.cmd.aead_cmd;
	struct ycc_flags *aflags;
	int taglen = cipher->authsize;
	int is_dec = (cmd == YCC_CMD_GCM_DEC || cmd == YCC_CMD_CCM_DEC);
	uint16_t new_aad_len;
	uint32_t new_cryptlen;
	int ret;

	/* Decrypt input is ciphertext || tag; cryptlen - taglen must not wrap. */
	if (is_dec && aead_req->cryptlen < (uint32_t)taglen)
		return -EINVAL;

	/* Clears the whole request, including desc.cmd. */
	memset(req, 0, sizeof(*req));
	req->ctx = ctx;
	req->aead_req = aead_req;

	ret = ycc_aead_format_data(req, &new_aad_len, &new_cryptlen, cmd);
	if (ret < 0)
		goto out;

	ret = ycc_aead_fill_key(req);
	if (ret < 0)
		goto free_mem;

	aflags = malloc(sizeof(*aflags));
	if (!aflags) {
		ret = -ENOMEM;
		goto free_mem;
	}

	aflags->ptr = (void *)req;
	aflags->ycc_done_callback = ycc_aead_callback;
	req->desc.private_ptr = (uint64_t)(uintptr_t)aflags;

	ycc_aead_cmd->cmd_id = cmd;
	ycc_aead_cmd->mode   = ctx->mode;
	ycc_aead_cmd->sptr   = req->src_paddr;
	ycc_aead_cmd->dptr   = req->dst_paddr;
	if (is_dec)
		new_cryptlen = aead_req->cryptlen - taglen;

	ycc_aead_cmd->dlen   = new_cryptlen;
	ycc_aead_cmd->keyptr = req->key_paddr;
	ycc_aead_cmd->aadlen = new_aad_len;
	ycc_aead_cmd->taglen = taglen;

	/* 4, submit desc to cmd queue */
	ret = ycc_enqueue(ctx->ring, &req->desc);
	if (!ret)
		return -EINPROGRESS;

	free(aflags);
free_mem:
	ycc_aead_free_mem(req);
out:
	return ret;
}

/* AEAD ->encrypt hook for GCM: queue @aead_req with YCC_CMD_GCM_ENC. */
static int ycc_aead_gcm_encrypt(struct aead_req *aead_req)
{
	const uint8_t op = YCC_CMD_GCM_ENC;

	return ycc_aead_submit_desc(aead_req, op);
}

/* AEAD ->decrypt hook for GCM: queue @aead_req with YCC_CMD_GCM_DEC. */
static int ycc_aead_gcm_decrypt(struct aead_req *aead_req)
{
	const uint8_t op = YCC_CMD_GCM_DEC;

	return ycc_aead_submit_desc(aead_req, op);
}

/* AEAD ->encrypt hook for CCM: queue @aead_req with YCC_CMD_CCM_ENC. */
static int ycc_aead_ccm_encrypt(struct aead_req *aead_req)
{
	const uint8_t op = YCC_CMD_CCM_ENC;

	return ycc_aead_submit_desc(aead_req, op);
}

/* AEAD ->decrypt hook for CCM: queue @aead_req with YCC_CMD_CCM_DEC. */
static int ycc_aead_ccm_decrypt(struct aead_req *aead_req)
{
	const uint8_t op = YCC_CMD_CCM_DEC;

	return ycc_aead_submit_desc(aead_req, op);
}

/*
 * YCC_AEAD_ALG() - initializer for one AEAD algorithm descriptor.
 *
 * All entries share blocksize 1, the context size, init/exit and setkey.
 * They differ in name, GCM vs CCM handlers, IV size and maximum tag size.
 */
#define YCC_AEAD_ALG(_name, _enc, _dec, _ivsize, _maxauth)		\
	{								\
		.name = _name,						\
		.blocksize = 1,						\
		.ctxsize = sizeof(struct ycc_crypto_ctx),		\
		.init = ycc_aead_init,					\
		.exit = ycc_aead_exit,					\
		.setkey = ycc_aead_setkey,				\
		.decrypt = (_dec),					\
		.encrypt = (_enc),					\
		.ivsize = (_ivsize),					\
		.maxauthsize = (_maxauth),				\
	}

static struct aead_alg aead_algs[] = {
	YCC_AEAD_ALG("gcm(aes)", ycc_aead_gcm_encrypt, ycc_aead_gcm_decrypt,
		     12, AES_BLOCK_SIZE),
	YCC_AEAD_ALG("ccm(aes)", ycc_aead_ccm_encrypt, ycc_aead_ccm_decrypt,
		     AES_BLOCK_SIZE, AES_BLOCK_SIZE),
	/*
	 * NOTE(review): gcm(sm4) advertises a 16-byte IV, while
	 * ycc_aead_fill_key() only consumes 12 IV bytes for GCM; confirm
	 * this is intended.
	 */
	YCC_AEAD_ALG("gcm(sm4)", ycc_aead_gcm_encrypt, ycc_aead_gcm_decrypt,
		     SM4_BLOCK_SIZE, SM4_BLOCK_SIZE),
	YCC_AEAD_ALG("ccm(sm4)", ycc_aead_ccm_encrypt, ycc_aead_ccm_decrypt,
		     SM4_BLOCK_SIZE, SM4_BLOCK_SIZE),
};

#undef YCC_AEAD_ALG

extern struct aead_alg_entry *aead_alg_entries;

/*
 * aead_register_algs() - append every aead_algs[] entry to the global
 * aead_alg_entries list.
 *
 * New nodes are linked after the current tail, so entries already on the
 * list stay intact. If a node allocation fails, registration stops there;
 * the algorithms linked so far remain registered.
 */
void aead_register_algs(void)
{
	struct aead_alg_entry *tail = aead_alg_entries;
	struct aead_alg_entry *cur;
	size_t i;

	/* Walk to the real tail; the head is the tail only for 1-entry lists. */
	while (tail && tail->next)
		tail = tail->next;

	for (i = 0; i < sizeof(aead_algs) / sizeof(aead_algs[0]); i++) {
		cur = malloc(sizeof(*cur));
		if (!cur) {
			ycc_err("Failed to alloc memory for alg:%s\n", aead_algs[i].name);
			break;
		}

		cur->alg = &aead_algs[i];
		cur->next = NULL;
		if (tail)
			tail->next = cur;
		else
			aead_alg_entries = cur;
		tail = cur;
	}
}

/*
 * aead_unregister_algs() - free every node of aead_alg_entries and leave
 * the list empty. The static aead_algs[] descriptors are not freed.
 */
void aead_unregister_algs(void)
{
	struct aead_alg_entry *next;

	for (struct aead_alg_entry *cur = aead_alg_entries; cur; cur = next) {
		next = cur->next;
		free(cur);
	}

	aead_alg_entries = NULL;
}
