// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <openssl/rsa.h>
#include <openssl/bn.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <endian.h>
#include <errno.h>

#include "rsaprivkey.asn1.h"
#include "rsapubkey.asn1.h"
#include "../utils/utils.h"
#include "udma_ulib.h"
#include "ycc_algs.h"
#include "pke.h"

/*
 * Strip the padding from a decrypted RSA block.
 *
 * The engine leaves the raw decrypted block in rsa_req->dst_vaddr; the
 * OpenSSL check matching req->padding extracts the message into req->dst.
 * Returns the message length reported by OpenSSL, or -1 when the check
 * fails or the padding mode has no check here.
 */
static int ycc_rsa_padding_retrieve(struct ycc_pke_req *rsa_req)
{
	struct akcipher_req *req = rsa_req->req;
	int klen = rsa_req->ctx.rsa_ctx->key_len;
	int blen = req->src_len;
	unsigned char *block = rsa_req->dst_vaddr;
	unsigned char *msg = req->dst;

	switch (req->padding) {
	case RSA_PKCS1_PADDING:
		return RSA_padding_check_PKCS1_type_2(msg, klen, block, blen, klen);
	case RSA_PKCS1_OAEP_PADDING:
		return RSA_padding_check_PKCS1_OAEP(msg, klen, block, blen, klen,
						    NULL, 0);
	case RSA_SSLV23_PADDING:
		return RSA_padding_check_SSLv23(msg, klen, block, blen, klen);
	case RSA_X931_PADDING:
		return RSA_padding_check_X931(msg, klen, block, blen, klen);
	default:
		/* No padding check for this mode. */
		return -1;
	}
}

/*
 * Completion handler for RSA requests. It is called with the ycc_pke_req
 * stored in the descriptor's ycc_flags and the command status.
 *
 * On success:
 *  - padded decryption: the padding is stripped into req->dst and
 *    req->dst_len is set to the recovered message length. 0 is a valid
 *    length (empty message). A missing req->dst is treated as a failure.
 *  - other non-verify requests: req->dst_len bytes of engine output are
 *    copied into req->dst when req->dst is set.
 *
 * The DMA buffers are always released and their pointers cleared, so a
 * reused request cannot free them twice. req->complete, if set, receives
 * 0 on success, or -EBADMSG on a hardware or padding failure.
 */
static int ycc_rsa_done_callback(void *ptr, uint16_t state)
{
	struct ycc_pke_req *rsa_req = (struct ycc_pke_req *)ptr;
	struct akcipher_req *req = rsa_req->req;
	int ret = 0;

	/* Engine output is meaningless if the command itself failed. */
	if (state == CMD_SUCCESS) {
		if (req->padding != RSA_NO_PADDING && rsa_req->type == YCC_PKE_DEC) {
			ret = req->dst ? ycc_rsa_padding_retrieve(rsa_req) : -1;
			/* The padding check returns the message length, or -1 on error. */
			if (ret >= 0) {
				req->dst_len = ret;
				ret = 0;
			}
		} else if (req->dst && rsa_req->type != YCC_PKE_VERIFY) {
			YCC_memcpy(req->dst, rsa_req->dst_vaddr, req->dst_len);
		}
	}

	ycc_udma_free(rsa_req->src_vaddr);
	rsa_req->src_vaddr = NULL;
	if (rsa_req->dst_vaddr) {
		ycc_udma_free(rsa_req->dst_vaddr);
		rsa_req->dst_vaddr = NULL;
	}

	if (req->complete)
		req->complete(req, (state == CMD_SUCCESS && ret == 0) ? 0 : -EBADMSG);
	return 0;
}

/*
 * Apply req->padding to the input message. The result is a key_len-byte
 * block written to the start of the source DMA buffer (rsa_req->src_vaddr).
 *
 * PKCS#1 v1.5 uses block type 2 for encryption and type 1 otherwise
 * (signing/verification). RSA_NO_PADDING only zero-extends the message on
 * the left to key_len bytes.
 *
 * Returns 0 on success, -EINVAL for an unsupported padding mode and -1
 * when OpenSSL rejects the message.
 */
static int ycc_rsa_padding(struct ycc_pke_req *rsa_req)
{
	struct akcipher_req *req = rsa_req->req;
	unsigned char *block = rsa_req->src_vaddr;
	unsigned char *msg = req->src;
	int klen = rsa_req->ctx.rsa_ctx->key_len;
	int mlen = req->src_len;
	int rc;

	switch (req->padding) {
	case RSA_NO_PADDING:
		rc = RSA_padding_check_none(block, klen, msg, mlen, 0);
		break;
	case RSA_PKCS1_PADDING:
		rc = (rsa_req->type == YCC_PKE_ENC) ?
		     RSA_padding_add_PKCS1_type_2(block, klen, msg, mlen) :
		     RSA_padding_add_PKCS1_type_1(block, klen, msg, mlen);
		break;
	case RSA_PKCS1_OAEP_PADDING:
		rc = RSA_padding_add_PKCS1_OAEP(block, klen, msg, mlen, NULL, 0);
		break;
	case RSA_SSLV23_PADDING:
		rc = RSA_padding_add_SSLv23(block, klen, msg, mlen);
		break;
	case RSA_X931_PADDING:
		rc = RSA_padding_add_X931(block, klen, msg, mlen);
		break;
	default:
		return -EINVAL;
	}

	return rc > 0 ? 0 : -1;
}

/*
 * Allocate and fill one zeroed DMA buffer for an RSA request.
 *
 * @rsa_req: request being prepared; rsa_req->type selects the layout.
 * @is_src:  true for the engine input buffer, false for the output buffer.
 *
 * Buffers are key_len bytes. The exception is verification, whose input
 * holds the (padded) message in the first key_len bytes and the signature
 * in the next key_len bytes.
 *
 * Operands are big-endian integers, so an operand shorter than its
 * key_len-byte field is right-aligned (zero-extended on the left). That
 * preserves its numeric value, the same alignment RSA_padding_check_none()
 * gives unpadded messages.
 *
 * Returns 0 on success, -EINVAL for an oversized input or signature, and
 * -1 on allocation or padding failure. On failure nothing stays allocated.
 */
static int ycc_prepare_dma_buf(struct ycc_pke_req *rsa_req, bool is_src)
{
	struct ycc_rsa_ctx *ctx = rsa_req->ctx.rsa_ctx;
	struct akcipher_req *req = rsa_req->req;
	unsigned int dma_length = ctx->key_len;
	int ret;

	if (rsa_req->type == YCC_PKE_VERIFY) {
		/* The signature must fit in the second key_len-byte half. */
		if (req->dst_len > ctx->key_len)
			return -EINVAL;
		dma_length = ctx->key_len * 2;
	}

	if (!is_src) {
		rsa_req->dst_vaddr = ycc_udma_malloc(ALIGN(dma_length, MEM_ALIGNMENT_64));
		if (!rsa_req->dst_vaddr)
			return -1;

		rsa_req->dst_paddr = virt_to_phys(rsa_req->dst_vaddr);
		memset((void *)rsa_req->dst_vaddr, 0, dma_length);
		return 0;
	}

	if (ctx->key_len < req->src_len)
		return -EINVAL;

	rsa_req->src_vaddr = ycc_udma_malloc(ALIGN(dma_length, MEM_ALIGNMENT_64));
	if (!rsa_req->src_vaddr)
		return -1;

	rsa_req->src_paddr = virt_to_phys(rsa_req->src_vaddr);
	memset((void *)rsa_req->src_vaddr, 0, dma_length);

	if (rsa_req->type != YCC_PKE_ENC &&
	    rsa_req->type != YCC_PKE_VERIFY &&
	    rsa_req->type != YCC_PKE_SIGN) {
		/* Decryption: raw ciphertext, right-aligned in key_len bytes. */
		YCC_memcpy(rsa_req->src_vaddr + (ctx->key_len - req->src_len),
			   req->src, req->src_len);
		return 0;
	}

	if (rsa_req->type == YCC_PKE_VERIFY) {
		/*
		 * The signature goes right-aligned into the second half. When
		 * req->dst is set it is the signature and req->src the message;
		 * otherwise req->src is |message (src_len)|signature (dst_len)|.
		 */
		unsigned char *sig_slot = (unsigned char *)rsa_req->src_vaddr +
					  ctx->key_len + (ctx->key_len - req->dst_len);

		if (req->dst)
			YCC_memcpy(sig_slot, req->dst, req->dst_len);
		else
			YCC_memcpy(sig_slot, req->src + req->src_len, req->dst_len);
	}

	/* The (padded) message fills the first key_len bytes. */
	ret = ycc_rsa_padding(rsa_req);
	if (ret < 0) {
		ycc_udma_free(rsa_req->src_vaddr);
		rsa_req->src_vaddr = NULL;
		return -1;
	}

	return 0;
}

/*
 * Public-key operation: RSA encryption (is_enc == true) or signature
 * verification (is_enc == false).
 *
 * Encryption gets a key_len-byte output buffer. Verification has none: the
 * engine receives the padded message and the signature together in one
 * 2 * key_len source buffer (see ycc_prepare_dma_buf()). The outcome is
 * reported only through the completion status.
 *
 * Returns -EINPROGRESS once queued; ycc_rsa_done_callback() finishes the
 * request. Otherwise returns:
 *  - -EINVAL when no public key is set;
 *  - -ENOMEM if the callback bookkeeping cannot be allocated;
 *  - the error from buffer preparation or ycc_enqueue().
 */
static int ycc_rsa_submit_pub(struct akcipher_req *req, bool is_enc)
{
	struct akcipher_ctx *cipher = req->ctx;
	struct ycc_rsa_ctx *ctx = (struct ycc_rsa_ctx *)cipher->__ctx;
	struct ycc_ring *ring = ctx->ring;
	struct ycc_pke_req *rsa_req = (struct ycc_pke_req *)req->__req;
	struct ycc_rsa_enc_cmd *rsa_enc_cmd;
	struct ycc_flags *aflags;
	int ret = -ENOMEM;

	rsa_req->ctx.rsa_ctx = ctx;
	rsa_req->req = req;
	/*
	 * Start with no DMA buffers attached. Verification never allocates a
	 * destination, and ycc_rsa_done_callback() frees dst_vaddr whenever it
	 * is non-NULL. A pointer left over from an earlier request on this
	 * req would otherwise be freed a second time.
	 */
	rsa_req->src_vaddr = NULL;
	rsa_req->dst_vaddr = NULL;

	if (!ctx->pub_key_vaddr)
		return -EINVAL;

	aflags = malloc(sizeof(struct ycc_flags));
	if (!aflags)
		goto out;

	aflags->ptr = (void *)rsa_req;
	aflags->ycc_done_callback = ycc_rsa_done_callback;

	memset(&rsa_req->desc, 0, sizeof(rsa_req->desc));
	rsa_req->desc.private_ptr = (uint64_t)(void *)aflags;
	rsa_req->type = is_enc ? YCC_PKE_ENC : YCC_PKE_VERIFY;

	rsa_enc_cmd         = &rsa_req->desc.cmd.rsa_enc_cmd;
	rsa_enc_cmd->cmd_id = is_enc ? YCC_CMD_RSA_ENC : YCC_CMD_RSA_VERIFY;
	rsa_enc_cmd->keyptr = ctx->pub_key_paddr;
	rsa_enc_cmd->elen   = ctx->e_len << 3;	/* lengths are given in bits */
	rsa_enc_cmd->nlen   = ctx->key_len << 3;

	ret = ycc_prepare_dma_buf(rsa_req, true);
	if (ret)
		goto free_aflags;

	rsa_enc_cmd->sptr = rsa_req->src_paddr;
	if (is_enc) {
		ret = ycc_prepare_dma_buf(rsa_req, false);
		if (ret)
			goto free_src;

		rsa_enc_cmd->dptr = rsa_req->dst_paddr;
	}

	ret = ycc_enqueue(ring, (void *)&rsa_req->desc);
	if (!ret)
		return -EINPROGRESS;

	if (is_enc) {
		ycc_udma_free(rsa_req->dst_vaddr);
		rsa_req->dst_vaddr = NULL;
	}
free_src:
	ycc_udma_free(rsa_req->src_vaddr);
	rsa_req->src_vaddr = NULL;
free_aflags:
	free(aflags);
out:
	return ret;
}

/*
 * Using private key to decrypt or signature
 */
static int ycc_rsa_submit_priv(struct akcipher_req *req, bool is_dec)
{
	struct akcipher_ctx *cipher = req->ctx;
	struct ycc_rsa_ctx *ctx = (struct ycc_rsa_ctx *)cipher->__ctx;
	struct ycc_ring *ring = ctx->ring;
	struct ycc_pke_req *rsa_req = (struct ycc_pke_req *)req->__req;
	struct ycc_rsa_dec_cmd *rsa_dec_cmd;
	struct ycc_cmd_desc *desc;
	struct ycc_flags *aflags;
	int ret = -ENOMEM;

	rsa_req->ctx.rsa_ctx = ctx;
	rsa_req->req = req;

	if (!ctx->priv_key_vaddr)
		return -EINVAL;

	aflags = malloc(sizeof(struct ycc_flags));
	if (!aflags)
		goto out;

	aflags->ptr = (void *)rsa_req;
	aflags->ycc_done_callback = ycc_rsa_done_callback;

	desc = &rsa_req->desc;
	memset(desc, 0, sizeof(struct ycc_cmd_desc));
	desc->private_ptr = (uint64_t)(void *)aflags;

	rsa_dec_cmd         = &desc->cmd.rsa_dec_cmd;
	rsa_dec_cmd->keyptr = ctx->priv_key_paddr;
	rsa_dec_cmd->elen   = ctx->e_len << 3;
	rsa_dec_cmd->nlen   = ctx->key_len << 3;

	rsa_req->type = is_dec ? YCC_PKE_DEC : YCC_PKE_SIGN;
	if (ctx->crt_mode)
		rsa_dec_cmd->cmd_id = is_dec ? YCC_CMD_RSA_CRT_DEC : YCC_CMD_RSA_CRT_SIGN;
	else
		rsa_dec_cmd->cmd_id = is_dec ? YCC_CMD_RSA_DEC : YCC_CMD_RSA_SIGN;

	ret = ycc_prepare_dma_buf(rsa_req, true);
	if (ret)
		goto free_aflags;

	ret = ycc_prepare_dma_buf(rsa_req, false);
	if (ret)
		goto free_src;

	rsa_dec_cmd->sptr = rsa_req->src_paddr;
	rsa_dec_cmd->dptr = rsa_req->dst_paddr;

	ret = ycc_enqueue(ring, (void *)desc);
	if (!ret)
		return -EINPROGRESS;

	ycc_udma_free(rsa_req->dst_vaddr);
free_src:
	ycc_udma_free(rsa_req->src_vaddr);
free_aflags:
	free(aflags);
out:
	return ret;
}

/* akcipher .encrypt hook: RSA encryption with the public key. */
static int ycc_rsa_encrypt(struct akcipher_req *req)
{
	const bool encrypt = true;

	return ycc_rsa_submit_pub(req, encrypt);
}

/* akcipher .decrypt hook: RSA decryption with the private key. */
static int ycc_rsa_decrypt(struct akcipher_req *req)
{
	const bool decrypt = true;

	return ycc_rsa_submit_priv(req, decrypt);
}

/* akcipher .verify hook: RSA signature verification with the public key. */
static int ycc_rsa_verify(struct akcipher_req *req)
{
	const bool encrypt = false;

	return ycc_rsa_submit_pub(req, encrypt);
}

/* akcipher .sign hook: RSA signing with the private key. */
static int ycc_rsa_sign(struct akcipher_req *req)
{
	const bool decrypt = false;

	return ycc_rsa_submit_priv(req, decrypt);
}

/*
 * Check that an RSA modulus of @len bytes has a size the engine supports:
 * 512, 1024, 1536, 2048, 3072 or 4096 bits.
 *
 * The comparison uses the byte length directly. Converting to bits first
 * (len << 3) wraps for huge lengths and would let, e.g., 0x20000040 bytes
 * pass as 512 bits.
 *
 * Returns 0 when supported, -EINVAL otherwise.
 */
static int ycc_rsa_validate_n(unsigned int len)
{
	switch (len) {
	case 512 / 8:
	case 1024 / 8:
	case 1536 / 8:
	case 2048 / 8:
	case 3072 / 8:
	case 4096 / 8:
		return 0;
	default:
		return -EINVAL;
	}
}

/*
 * Advance *ptr past leading zero bytes, shrinking *len to match.
 *
 * The length is tested before each byte is read, so an all-zero buffer
 * ends with *len == 0 and is never read past its end. A NULL pointer or an
 * empty buffer is left untouched.
 */
static void __ycc_rsa_drop_leading_zeros(const uint8_t **ptr, uint32_t *len)
{
	if (!*ptr || !*len)
		return;

	while (*len && !**ptr) {
		(*ptr)++;
		(*len)--;
	}
}

/*
 * Store the modulus n in the key buffers.
 *
 * n always goes right after e in the public key buffer. For a non-CRT
 * private key it also goes after e | PIN | d in the private key buffer.
 * e must already be set (ctx->e_len). If ctx->key_len is still 0 it is
 * taken from @value_len.
 *
 * Returns 0 on success. Returns -EINVAL if e is unset, @value_len is 0, or
 * n is longer than the key_len bytes reserved for it.
 */
static int ycc_rsa_set_n(struct ycc_rsa_ctx *ctx, const unsigned char *value,
			 uint32_t value_len, bool private)
{
	const unsigned char *ptr = value;

	if (!ctx->e_len || !value_len)
		return -EINVAL;

	if (!ctx->key_len)
		ctx->key_len = value_len;

	/* The key buffers reserve exactly key_len bytes for n. */
	if (value_len > ctx->key_len)
		return -EINVAL;

	if (private && !ctx->crt_mode) {
		YCC_memcpy(ctx->priv_key_vaddr + ctx->e_len + YCC_PIN_SZ +
			   ctx->rsa_key->d_sz, ptr, value_len);
	}

	YCC_memcpy(ctx->pub_key_vaddr + ctx->e_len, ptr, value_len);
	return 0;
}

/*
 * Store the public exponent e at the start of the public key buffer. For
 * private keys it is also stored at the start of the private key buffer.
 * Its length is recorded in ctx->e_len, which later offsets depend on.
 *
 * Returns 0 on success. Returns -EINVAL if key_len is unset, or e is empty
 * or longer than YCC_RSA_E_SZ_MAX.
 */
static int ycc_rsa_set_e(struct ycc_rsa_ctx *ctx, const unsigned char *value,
			 uint32_t value_len, bool private)
{
	bool bad_len = !value_len || value_len > YCC_RSA_E_SZ_MAX;

	if (!ctx->key_len || bad_len)
		return -EINVAL;

	ctx->e_len = value_len;
	YCC_memcpy(ctx->pub_key_vaddr, value, value_len);
	if (private)
		YCC_memcpy(ctx->priv_key_vaddr, value, value_len);

	return 0;
}

/*
 * Store the private exponent d after e and the YCC_PIN_SZ-byte PIN area of
 * the private key buffer (non-CRT layout: e | PIN | d | n). ctx->key_len
 * and ctx->e_len must already be set.
 *
 * NOTE(review): d is copied as value_len bytes with no padding to key_len,
 * and the descriptor only carries elen/nlen. If the engine reads d as a
 * key_len-byte field, a d shorter than key_len would overlap n. Confirm
 * the expected layout.
 *
 * Returns 0 on success. Returns -EINVAL if key_len is unset, or d is empty
 * or longer than key_len.
 */
static int ycc_rsa_set_d(struct ycc_rsa_ctx *ctx, const unsigned char *value,
			 uint32_t value_len)
{
	if (!ctx->key_len)
		return -EINVAL;
	if (!value_len || value_len > ctx->key_len)
		return -EINVAL;

	YCC_memcpy(ctx->priv_key_vaddr + ctx->e_len + YCC_PIN_SZ, value, value_len);
	return 0;
}

/*
 * Copy one CRT component (p, q, dp, dq or qinv) into its fixed-size slot.
 *
 * @param:        start of the half_key_len-byte slot
 * @half_key_len: slot size (key_len / 2)
 * @value:        big-endian component bytes
 * @value_len:    component length, 1..half_key_len
 *
 * The component is right-aligned in the slot with explicit leading zero
 * bytes. As a big-endian integer, a shorter value (e.g. dp or qinv after
 * leading-zero stripping) must be zero-extended on the left; left-aligning
 * it would change its value.
 *
 * Returns 0 on success, or -EINVAL for an empty or oversized component.
 */
static int ycc_rsa_set_crt_param(char *param, uint32_t half_key_len,
				 const unsigned char *value, uint32_t value_len)
{
	uint32_t pad;

	if (!value_len || value_len > half_key_len)
		return -EINVAL;

	pad = half_key_len - value_len;
	memset(param, 0, pad);
	YCC_memcpy(param + pad, value, value_len);
	return 0;
}

static int ycc_rsa_setkey_crt(struct ycc_rsa_ctx *ctx, struct rsa_key *rsa_key)
{
	unsigned int half_key_len = ctx->key_len >> 1;
	char *tmp = (char *)ctx->priv_key_vaddr;
	int ret;

	tmp += ctx->rsa_key->e_sz + 16;
	ret = ycc_rsa_set_crt_param(tmp, half_key_len, rsa_key->p,
				    rsa_key->p_sz);
	if (ret)
		goto err;

	tmp += half_key_len;
	ret = ycc_rsa_set_crt_param(tmp, half_key_len, rsa_key->q,
				    rsa_key->q_sz);
	if (ret)
		goto err;

	tmp += half_key_len;
	ret = ycc_rsa_set_crt_param(tmp, half_key_len, rsa_key->dp,
				    rsa_key->dp_sz);
	if (ret)
		goto err;

	tmp += half_key_len;
	ret = ycc_rsa_set_crt_param(tmp, half_key_len, rsa_key->dq,
				    rsa_key->dq_sz);
	if (ret)
		goto err;

	tmp += half_key_len;
	ret = ycc_rsa_set_crt_param(tmp, half_key_len, rsa_key->qinv,
				    rsa_key->qinv_sz);
	if (ret)
		goto err;

	ctx->crt_mode = true;
	return 0;

err:
	ctx->crt_mode = false;
	return ret;
}

/*
 * Reset an RSA context to the "no key" state.
 *
 * Both DMA key buffers are freed, the rsa_key descriptor is wiped and
 * freed, and the length and CRT fields are zeroed. This does not free the
 * key material rsa_key points to; ycc_rsa_free_key() handles that for
 * keys exported from an OpenSSL RSA object.
 */
static void ycc_rsa_clear_ctx(struct ycc_rsa_ctx *ctx)
{
	struct rsa_key *key = ctx->rsa_key;

	if (ctx->pub_key_vaddr != NULL)
		ycc_udma_free(ctx->pub_key_vaddr);
	ctx->pub_key_vaddr = NULL;

	if (ctx->priv_key_vaddr != NULL)
		ycc_udma_free(ctx->priv_key_vaddr);
	ctx->priv_key_vaddr = NULL;

	if (key != NULL) {
		/* Scrub the descriptor's pointers and lengths before freeing it. */
		memset(key, 0, sizeof(struct rsa_key));
		free(key);
	}
	ctx->rsa_key = NULL;

	ctx->crt_mode = false;
	ctx->e_len = 0;
	ctx->key_len = 0;
}

/*
 * Strip leading zero bytes from every component of @rsa_key
 * (n, e, d, p, q, dp, dq, qinv), updating each pointer and length.
 */
static void ycc_rsa_drop_leading_zeros(struct rsa_key *rsa_key)
{
	struct {
		const uint8_t **ptr;
		uint32_t *len;
	} fields[] = {
		{ &rsa_key->n,    &rsa_key->n_sz },
		{ &rsa_key->e,    &rsa_key->e_sz },
		{ &rsa_key->d,    &rsa_key->d_sz },
		{ &rsa_key->p,    &rsa_key->p_sz },
		{ &rsa_key->q,    &rsa_key->q_sz },
		{ &rsa_key->dp,   &rsa_key->dp_sz },
		{ &rsa_key->dq,   &rsa_key->dq_sz },
		{ &rsa_key->qinv, &rsa_key->qinv_sz },
	};
	size_t i;

	for (i = 0; i < sizeof(fields) / sizeof(fields[0]); i++)
		__ycc_rsa_drop_leading_zeros(fields[i].ptr, fields[i].len);
}

#if OPENSSL_VERSION_NUMBER < 0x10100000L
/*
 * OpenSSL before 1.1.0 exposes struct RSA directly and lacks the
 * RSA_get0_* accessors used by ycc_rsa_convert_key(). These weak fallbacks
 * read the fields directly. Each output pointer is optional; pass NULL to
 * skip it.
 */
void __attribute__((weak)) RSA_get0_key(const RSA *r,
					const BIGNUM **n,
					const BIGNUM **e,
					const BIGNUM **d)
{
	if (n != NULL)
		*n = r->n;
	if (e != NULL)
		*e = r->e;
	if (d != NULL)
		*d = r->d;
}

/* Prime factors p and q. */
void __attribute__((weak)) RSA_get0_factors(const RSA *r,
					    const BIGNUM **p, const BIGNUM **q)
{
	if (p != NULL)
		*p = r->p;
	if (q != NULL)
		*q = r->q;
}

/* CRT exponents d mod (p-1), d mod (q-1) and coefficient q^-1 mod p. */
void __attribute__((weak)) RSA_get0_crt_params(const RSA *r,
					       const BIGNUM **dmp1,
					       const BIGNUM **dmq1,
					       const BIGNUM **iqmp)
{
	if (dmp1 != NULL)
		*dmp1 = r->dmp1;
	if (dmq1 != NULL)
		*dmq1 = r->dmq1;
	if (iqmp != NULL)
		*iqmp = r->iqmp;
}
#endif

/*
 * Export the components of an OpenSSL RSA object into @rsa_key.
 *
 * Every present component is serialized big-endian (BN_bn2bin) into one
 * malloc'd buffer laid out as n | e | d | p | q | dp | dq | qinv. Absent
 * private components are skipped and keep size 0. rsa_key->n points at the
 * start of that buffer, so the whole buffer is released with
 * free(rsa_key->n) (see ycc_rsa_free_key()).
 *
 * Returns:
 *  - 0 on success, or when @rsa is NULL (nothing exported);
 *  - -EINVAL if the key lacks n or e;
 *  - -ENOMEM if the buffer cannot be allocated.
 */
static int ycc_rsa_convert_key(struct rsa_key *rsa_key, RSA *rsa)
{
	const BIGNUM *n = NULL, *e = NULL, *d = NULL;
	const BIGNUM *p = NULL, *q = NULL;
	const BIGNUM *dp = NULL, *dq = NULL, *qinv = NULL;
	unsigned char *buf;

	if (!rsa)
		return 0;

	RSA_get0_key(rsa, &n, &e, &d);
	RSA_get0_factors((const RSA *)rsa, &p, &q);
	RSA_get0_crt_params((const RSA *)rsa, &dp, &dq, &qinv);

	/* n and e are mandatory; BN_num_bytes() must not be passed NULL. */
	if (!n || !e)
		return -EINVAL;

	rsa_key->n_sz = BN_num_bytes(n);
	rsa_key->e_sz = BN_num_bytes(e);
	if (d)
		rsa_key->d_sz = BN_num_bytes(d);
	if (p)
		rsa_key->p_sz = BN_num_bytes(p);
	if (q)
		rsa_key->q_sz = BN_num_bytes(q);
	if (dp)
		rsa_key->dp_sz = BN_num_bytes(dp);
	if (dq)
		rsa_key->dq_sz = BN_num_bytes(dq);
	if (qinv)
		rsa_key->qinv_sz = BN_num_bytes(qinv);

	buf = malloc(rsa_key->n_sz + rsa_key->e_sz + rsa_key->d_sz + rsa_key->p_sz +
		     rsa_key->q_sz + rsa_key->dp_sz + rsa_key->dq_sz + rsa_key->qinv_sz);
	if (!buf)
		return -ENOMEM;

	BN_bn2bin(n, buf);
	rsa_key->n = buf;
	buf += rsa_key->n_sz;
	BN_bn2bin(e, buf);
	rsa_key->e = buf;
	buf += rsa_key->e_sz;
	if (d) {
		BN_bn2bin(d, buf);
		rsa_key->d = buf;
		buf += rsa_key->d_sz;
	}
	if (p) {
		BN_bn2bin(p, buf);
		rsa_key->p = buf;
		buf += rsa_key->p_sz;
	}
	if (q) {
		BN_bn2bin(q, buf);
		rsa_key->q = buf;
		buf += rsa_key->q_sz;
	}
	if (dp) {
		BN_bn2bin(dp, buf);
		rsa_key->dp = buf;
		buf += rsa_key->dp_sz;
	}
	if (dq) {
		BN_bn2bin(dq, buf);
		rsa_key->dq = buf;
		buf += rsa_key->dq_sz;
	}
	if (qinv) {
		BN_bn2bin(qinv, buf);
		rsa_key->qinv = buf;
	}

	return 0;
}

/*
 * Free the key material exported by ycc_rsa_convert_key(). That is the
 * single buffer starting at rsa_key->n, used only for keys that came from
 * an OpenSSL RSA object (!ctx->is_asn1). The pointer is cleared afterwards
 * so a repeated call cannot free it twice. The rsa_key descriptor itself
 * is left for ycc_rsa_clear_ctx().
 */
static inline void ycc_rsa_free_key(struct ycc_rsa_ctx *ctx)
{
	if (ctx->is_asn1 || !ctx->rsa_key)
		return;

	free((void *)ctx->rsa_key->n);
	ctx->rsa_key->n = NULL;
}

/*
 * Build the engine key buffers from ctx->rsa_key.
 *
 * @priv: true to build the private key buffer as well as the public one.
 * @rsa:  OpenSSL RSA object to export first, or NULL when ctx->rsa_key was
 *        already filled by the ASN.1 parser.
 *
 * Buffers are zeroed and sized up to YCC_CMD_DATA_ALIGN_SZ:
 *   public:           e | n (key_len)
 *   private, non-CRT: e | PIN (YCC_PIN_SZ) | d | n (key_len)
 *   private, CRT:     e | PIN (YCC_PIN_SZ) | YCC_RSA_CRT_PARAMS slots of
 *                     key_len / 2
 * CRT mode is chosen when any of p, q, dp, dq, qinv is present.
 *
 * Returns 0 on success. On any failure the context is fully reset and a
 * negative error code is returned; the reset covers key buffers, the
 * descriptor and any exported OpenSSL key material.
 */
static int ycc_rsa_alloc_key(struct ycc_rsa_ctx *ctx, bool priv, RSA *rsa)
{
	struct rsa_key *rsa_key = ctx->rsa_key;
	const char *kind = priv ? "private" : "public";
	unsigned int half_key_len;
	uint32_t size = 0;
	int ret;

	/* Nothing has been exported yet if conversion fails. */
	ret = ycc_rsa_convert_key(rsa_key, rsa);
	if (ret)
		goto out;

	ycc_rsa_drop_leading_zeros(rsa_key);
	ctx->key_len = rsa_key->n_sz;

	ret = ycc_rsa_validate_n(ctx->key_len);
	if (ret) {
		ycc_err("Invalid n size:%d bits\n", ctx->key_len << 3);
		goto free_key;
	}

	ret = -ENOMEM;
	if (priv) {
		if (!(rsa_key->p_sz + rsa_key->q_sz + rsa_key->dp_sz +
		      rsa_key->dq_sz + rsa_key->qinv_sz)) {
			size = ALIGN(rsa_key->e_sz + YCC_PIN_SZ + rsa_key->d_sz +
				     ctx->key_len, YCC_CMD_DATA_ALIGN_SZ);
		} else {
			half_key_len = ctx->key_len >> 1;
			size = ALIGN(rsa_key->e_sz + YCC_PIN_SZ + half_key_len *
				     YCC_RSA_CRT_PARAMS, YCC_CMD_DATA_ALIGN_SZ);
			ctx->crt_mode = true;
		}

		ctx->priv_key_vaddr = ycc_udma_malloc(size);
		if (!ctx->priv_key_vaddr)
			goto free_key;

		ctx->priv_key_paddr = virt_to_phys(ctx->priv_key_vaddr);
		memset(ctx->priv_key_vaddr, 0, size);
	}

	if (!ctx->pub_key_vaddr) {
		size = ALIGN(ctx->key_len + rsa_key->e_sz, YCC_CMD_DATA_ALIGN_SZ);
		ctx->pub_key_vaddr = ycc_udma_malloc(size);
		if (!ctx->pub_key_vaddr)
			goto free_key;

		ctx->pub_key_paddr = virt_to_phys(ctx->pub_key_vaddr);
		memset(ctx->pub_key_vaddr, 0, size);
	}

	ret = ycc_rsa_set_e(ctx, rsa_key->e, rsa_key->e_sz, priv);
	if (ret) {
		ycc_err("Failed to set e for rsa %s key\n", kind);
		goto free_key;
	}

	ret = ycc_rsa_set_n(ctx, rsa_key->n, rsa_key->n_sz, priv);
	if (ret) {
		ycc_err("Failed to set n for rsa %s key\n", kind);
		goto free_key;
	}

	if (priv) {
		if (ctx->crt_mode) {
			ret = ycc_rsa_setkey_crt(ctx, rsa_key);
			if (ret) {
				ycc_err("Failed to set private key for rsa crt key\n");
				goto free_key;
			}
		} else {
			ret = ycc_rsa_set_d(ctx, rsa_key->d, rsa_key->d_sz);
			if (ret) {
				ycc_err("Failed to set d for rsa private key\n");
				goto free_key;
			}
		}
	}

	return 0;

	/*
	 * Release exported OpenSSL key material (a no-op for ASN.1 keys) before
	 * ycc_rsa_clear_ctx() frees the descriptor that points at it.
	 */
free_key:
	ycc_rsa_free_key(ctx);
out:
	ycc_rsa_clear_ctx(ctx);
	return ret;
}

extern int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key, unsigned int key_len);
extern int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key, unsigned int key_len);

/*
 * Install an RSA key into the context, replacing any previous key.
 *
 * @key/@keylen: with keylen != 0, @key is an ASN.1 encoded key parsed by
 *               rsa_parse_priv_key()/rsa_parse_pub_key(). With keylen == 0,
 *               @key is an OpenSSL RSA object exported by
 *               ycc_rsa_alloc_key().
 * @priv:        true for a private key, false for a public key.
 *
 * Returns 0 on success or a negative error code.
 */
static int ycc_rsa_setkey(struct akcipher_ctx *cipher, void *key,
			  unsigned int keylen, bool priv)
{
	struct ycc_rsa_ctx *ctx = (struct ycc_rsa_ctx *)cipher->__ctx;
	struct rsa_key *rsa_key;
	RSA *rsa = NULL;
	int ret = -1;

	/*
	 * Drop the previous key. Material exported from an earlier OpenSSL RSA
	 * key is owned separately from ctx->rsa_key, so free it before
	 * ycc_rsa_clear_ctx() releases the descriptor pointing at it.
	 */
	ycc_rsa_free_key(ctx);
	ycc_rsa_clear_ctx(ctx);

	rsa_key = calloc(1, sizeof(struct rsa_key));
	if (!rsa_key)
		return -ENOMEM;

	if (keylen) {
		if (priv)
			ret = rsa_parse_priv_key(rsa_key, key, keylen);
		else if (!ctx->pub_key_vaddr)
			ret = rsa_parse_pub_key(rsa_key, key, keylen);
		if (ret) {
			ycc_err("Failed to parse %s key\n", priv ? "private" : "public");
			free(rsa_key);
			return ret;
		}
		ctx->is_asn1 = true;
	} else {
		ctx->is_asn1 = false;
		rsa = (RSA *)key;
	}

	ctx->rsa_key = rsa_key;
	return ycc_rsa_alloc_key(ctx, priv, rsa);
}

/* akcipher .set_pub_key hook: install an RSA public key (see ycc_rsa_setkey). */
static int ycc_rsa_setpubkey(struct akcipher_ctx *cipher, void *key,
			     unsigned int keylen)
{
	const bool is_private = false;

	return ycc_rsa_setkey(cipher, key, keylen, is_private);
}

/* akcipher .set_priv_key hook: install an RSA private key (see ycc_rsa_setkey). */
static int ycc_rsa_setprivkey(struct akcipher_ctx *cipher, void *key,
			      unsigned int keylen)
{
	const bool is_private = true;

	return ycc_rsa_setkey(cipher, key, keylen, is_private);
}

static unsigned int ycc_rsa_max_size(struct akcipher_ctx *cipher)
{
	struct ycc_rsa_ctx *ctx = (struct ycc_rsa_ctx *)cipher->__ctx;

	return ctx->rsa_key ? ctx->key_len : 0;
}

/*
 * akcipher .init hook for RSA.
 *
 * Puts the context into a known "no key" state, then takes a ring. The key
 * fields are reset first because ycc_rsa_setkey() and ycc_rsa_exit() free
 * whatever these pointers hold.
 *
 * Returns 0 on success, or -EINVAL when no ring is available.
 */
static int ycc_rsa_init(struct akcipher_ctx *cipher)
{
	struct ycc_rsa_ctx *ctx = (struct ycc_rsa_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	ctx->pub_key_vaddr = NULL;
	ctx->priv_key_vaddr = NULL;
	ctx->rsa_key = NULL;
	ctx->key_len = 0;
	ctx->e_len = 0;
	ctx->crt_mode = false;
	ctx->is_asn1 = false;

	ring = ycc_crypto_get_ring();
	if (!ring)
		return -EINVAL;

	ctx->ring = ring;
	return 0;
}

/*
 * akcipher .exit hook for RSA: return the ring, then release the key.
 * Exported OpenSSL key material goes first, then the buffers and the
 * descriptor.
 */
static void ycc_rsa_exit(struct akcipher_ctx *cipher)
{
	struct ycc_rsa_ctx *ctx = (struct ycc_rsa_ctx *)cipher->__ctx;
	struct ycc_ring *ring = ctx->ring;

	if (ring != NULL)
		ycc_crypto_free_ring(ring);

	ycc_rsa_free_key(ctx);
	ycc_rsa_clear_ctx(ctx);
}

#define YCC_ECDSA_INIT(curve, Curve)						\
static int ycc_ecdsa_##curve##_init(struct akcipher_ctx *cipher)		\
{										\
	struct ycc_ecdsa_ctx *ctx = (struct ycc_ecdsa_ctx *)cipher->__ctx;	\
	struct ycc_ring *ring;							\
	ring = ycc_crypto_get_ring();						\
	if (!ring)								\
		return -EINVAL;							\
	ctx->ring = ring;							\
	ctx->curve_id = YCC_EC_##Curve;						\
	ctx->ndigits = YCC_EC_DIGITS_##Curve;					\
	return 0;								\
}

YCC_ECDSA_INIT(p192, P192)
YCC_ECDSA_INIT(p224, P224)
YCC_ECDSA_INIT(p256, P256)
YCC_ECDSA_INIT(p384, P384)
YCC_ECDSA_INIT(p521, P521)

/*
 * YCC_ECDSA_EXIT(curve) defines ycc_ecdsa_<curve>_exit(), the akcipher
 * .exit hook: it returns the ring and frees both key buffers.
 *
 * The private key buffer holds a 16-byte header plus the ndigits-byte key,
 * as written by ycc_ecdsa_setprivkey(). It is zeroed before going back to
 * the DMA allocator so the key does not linger in reusable memory. All
 * pointers are cleared after release.
 */
#define YCC_ECDSA_EXIT(curve)							\
static void ycc_ecdsa_##curve##_exit(struct akcipher_ctx *cipher)		\
{										\
	struct ycc_ecdsa_ctx *ctx = (struct ycc_ecdsa_ctx *)cipher->__ctx;	\
	if (ctx->ring) {							\
		ycc_crypto_free_ring(ctx->ring);				\
		ctx->ring = NULL;						\
	}									\
	if (ctx->priv_key_vaddr) {						\
		memset(ctx->priv_key_vaddr, 0, ctx->ndigits + 16);		\
		ycc_udma_free(ctx->priv_key_vaddr);				\
		ctx->priv_key_vaddr = NULL;					\
	}									\
	if (ctx->pub_key_vaddr) {						\
		ycc_udma_free(ctx->pub_key_vaddr);				\
		ctx->pub_key_vaddr = NULL;					\
	}									\
}

YCC_ECDSA_EXIT(p192)
YCC_ECDSA_EXIT(p224)
YCC_ECDSA_EXIT(p256)
YCC_ECDSA_EXIT(p384)
YCC_ECDSA_EXIT(p521)

/*
 * Completion handler for ECDSA requests.
 *
 * Only signing allocates dst_vaddr. For it, the engine output is copied to
 * req->dst when the command succeeded. At most ndigits << 1 bytes (r + s,
 * the size of the output buffer) are copied even if req->dst_len is
 * larger, so the copy never reads past that buffer.
 *
 * The DMA buffers are released and their pointers cleared, so a reused
 * request cannot free them twice. req->complete, if set, receives 0 on
 * success or -EBADMSG otherwise.
 */
static int ycc_ecdsa_done_callback(void *ptr, uint16_t state)
{
	struct ycc_pke_req *ecdsa_req = (struct ycc_pke_req *)ptr;
	struct akcipher_req *req = ecdsa_req->req;

	if (ecdsa_req->dst_vaddr) {
		if (state == CMD_SUCCESS) {
			unsigned int out_len = ecdsa_req->ctx.ecdsa_ctx->ndigits << 1;

			if (req->dst_len < out_len)
				out_len = req->dst_len;
			YCC_memcpy(req->dst, ecdsa_req->dst_vaddr, out_len);
		}
		ycc_udma_free(ecdsa_req->dst_vaddr);
		ecdsa_req->dst_vaddr = NULL;
	}

	ycc_udma_free(ecdsa_req->src_vaddr);
	ecdsa_req->src_vaddr = NULL;

	if (req->complete)
		req->complete(req, state == CMD_SUCCESS ? 0 : -EBADMSG);
	return 0;
}

/*
 * akcipher .sign hook for ECDSA.
 *
 * req->src holds the message digest. If it is longer than ndigits bytes,
 * only its left-most ndigits bytes are signed. The engine writes the
 * signature into an ndigits << 1 byte DMA buffer, which
 * ycc_ecdsa_done_callback() copies to req->dst.
 *
 * Returns -EINPROGRESS once queued. Otherwise returns:
 *  - -EINVAL if the private key, source or destination is missing;
 *  - -ENOMEM on allocation failure;
 *  - the error from ycc_enqueue().
 */
static int ycc_ecdsa_sign(struct akcipher_req *req)
{
	struct akcipher_ctx *cipher = req->ctx;
	struct ycc_ecdsa_ctx *ctx = (struct ycc_ecdsa_ctx *)cipher->__ctx;
	struct ycc_pke_req *ecdsa_req = (struct ycc_pke_req *)req->__req;
	struct ycc_cmd_desc *desc = &ecdsa_req->desc;
	struct ycc_ecdsa_sign_cmd *cmd = &desc->cmd.ecdsa_sign_cmd;
	struct ycc_flags *cb;
	unsigned int digest_len;
	int err = -ENOMEM;

	ecdsa_req->ctx.ecdsa_ctx = ctx;
	ecdsa_req->req = req;

	if (!ctx->priv_key_vaddr || !req->src || !req->dst)
		return -EINVAL;

	cb = malloc(sizeof(struct ycc_flags));
	if (!cb)
		return -ENOMEM;

	cb->ptr = (void *)ecdsa_req;
	cb->ycc_done_callback = ycc_ecdsa_done_callback;

	memset(desc, 0, sizeof(struct ycc_cmd_desc));
	desc->private_ptr = (uint64_t)(void *)cb;

	ecdsa_req->src_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits, MEM_ALIGNMENT_64));
	if (!ecdsa_req->src_vaddr)
		goto release_cb;
	ecdsa_req->src_paddr = virt_to_phys(ecdsa_req->src_vaddr);

	ecdsa_req->dst_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits << 1, MEM_ALIGNMENT_64));
	if (!ecdsa_req->dst_vaddr)
		goto release_src;
	ecdsa_req->dst_paddr = virt_to_phys(ecdsa_req->dst_vaddr);

	/* Over-long digests are cut to their left-most ndigits bytes. */
	digest_len = (req->src_len > ctx->ndigits) ? ctx->ndigits : req->src_len;
	YCC_memcpy(ecdsa_req->src_vaddr, req->src, digest_len);

	cmd->cmd_id   = YCC_CMD_ECDSA_SIGN;
	cmd->curve_id = ctx->curve_id;
	cmd->sig_len  = digest_len;	/* the command's sig_len carries the digest length */
	cmd->sptr     = ecdsa_req->src_paddr;
	cmd->keyptr   = ctx->priv_key_paddr;
	cmd->dptr     = ecdsa_req->dst_paddr;

	err = ycc_enqueue(ctx->ring, (void *)desc);
	if (!err)
		return -EINPROGRESS;

	ycc_udma_free(ecdsa_req->dst_vaddr);
release_src:
	ycc_udma_free(ecdsa_req->src_vaddr);
release_cb:
	free(cb);
	return err;
}

/*
 * akcipher .verify hook for ECDSA.
 *
 * Input layouts:
 *  - req->dst != NULL: req->src is the digest and req->dst the signature
 *    r + s, which must be exactly ndigits << 1 bytes.
 *  - req->dst == NULL: req->src is digest + r + s, so src_len must be
 *    greater than ndigits << 1 (the digest may not be empty).
 *
 * In both layouts a digest longer than ndigits bytes is cut to its
 * left-most ndigits bytes. The engine input therefore never exceeds the
 * ndigits * 3 bytes allocated for it. The verification result is reported
 * through req->complete.
 *
 * Returns -EINPROGRESS once queued. Otherwise returns:
 *  - -EINVAL for a missing key/source or a bad length;
 *  - -ENOMEM on allocation failure;
 *  - the error from ycc_enqueue().
 */
static int ycc_ecdsa_verify(struct akcipher_req *req)
{
	struct akcipher_ctx *cipher = req->ctx;
	struct ycc_ecdsa_ctx *ctx = (struct ycc_ecdsa_ctx *)cipher->__ctx;
	struct ycc_ring *ring = ctx->ring;
	struct ycc_pke_req *ecdsa_req = (struct ycc_pke_req *)req->__req;
	struct ycc_ecdsa_verify_cmd *ecdsa_verify_cmd;
	const unsigned char *digest, *sig;
	unsigned int digest_len;
	struct ycc_cmd_desc *desc;
	struct ycc_flags *aflags;
	int ret = -ENOMEM;

	ecdsa_req->ctx.ecdsa_ctx = ctx;
	ecdsa_req->req = req;
	/*
	 * Verify has no output buffer. Clear any pointer left from an earlier
	 * request on this req, so ycc_ecdsa_done_callback() neither frees it
	 * again nor copies stale data over req->dst (the signature).
	 */
	ecdsa_req->dst_vaddr = NULL;

	if (!ctx->pub_key_vaddr || !req->src)
		return -EINVAL;

	aflags = malloc(sizeof(struct ycc_flags));
	if (!aflags)
		goto out;

	aflags->ptr = (void *)ecdsa_req;
	aflags->ycc_done_callback = ycc_ecdsa_done_callback;

	desc = &ecdsa_req->desc;

	memset(desc, 0, sizeof(struct ycc_cmd_desc));
	desc->private_ptr = (uint64_t)(void *)aflags;

	/*
	 * If dst is set, dst_len must be exactly ndigits << 1 (r + s).
	 * Otherwise src_len must be greater than ndigits << 1, since src holds
	 * digest + r + s.
	 */
	if ((!req->dst && (req->src_len <= (ctx->ndigits << 1))) ||
	     (req->dst && req->dst_len != (ctx->ndigits << 1))) {
		ret = -EINVAL;
		goto free_aflags;
	}

	/* digest (at most ndigits bytes) + r + s */
	ecdsa_req->src_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits * 3, MEM_ALIGNMENT_64));
	if (!ecdsa_req->src_vaddr)
		goto free_aflags;
	ecdsa_req->src_paddr = virt_to_phys(ecdsa_req->src_vaddr);

	digest = (const unsigned char *)req->src;
	if (req->dst) {
		digest_len = req->src_len;
		sig = (const unsigned char *)req->dst;
	} else {
		digest_len = req->src_len - (ctx->ndigits << 1);
		sig = digest + digest_len;
	}

	/* Take only the left-most ndigits bytes of an over-long digest. */
	if (digest_len > ctx->ndigits)
		digest_len = ctx->ndigits;

	YCC_memcpy(ecdsa_req->src_vaddr, digest, digest_len);
	YCC_memcpy(ecdsa_req->src_vaddr + digest_len, sig, ctx->ndigits << 1);

	ecdsa_verify_cmd = &desc->cmd.ecdsa_verify_cmd;
	ecdsa_verify_cmd->sig_len  = digest_len;	/* carries the digest length */
	ecdsa_verify_cmd->cmd_id   = YCC_CMD_ECDSA_VERIFY;
	ecdsa_verify_cmd->curve_id = ctx->curve_id;
	ecdsa_verify_cmd->sptr     = ecdsa_req->src_paddr;
	ecdsa_verify_cmd->keyptr   = ctx->pub_key_paddr;

	ret = ycc_enqueue(ring, (void *)desc);
	if (!ret)
		return -EINPROGRESS;

	ycc_udma_free(ecdsa_req->src_vaddr);
	ecdsa_req->src_vaddr = NULL;
free_aflags:
	free(aflags);
out:
	return ret;
}

/*
 * akcipher .set_pub_key hook for ECDSA.
 *
 * The public key is the two coordinates x and y, ndigits bytes each, so
 * @keylen must be exactly ndigits << 1. The caller must zero-pad x and y
 * individually at their most significant end. The key is copied into a
 * DMA buffer that is allocated on first use and reused afterwards.
 *
 * Returns 0 on success, -EINVAL for a NULL key or a wrong length, or
 * -ENOMEM if the buffer cannot be allocated.
 */
static int ycc_ecdsa_setpubkey(struct akcipher_ctx *cipher, void *key,
			       unsigned int keylen)
{
	struct ycc_ecdsa_ctx *ctx = (struct ycc_ecdsa_ctx *)cipher->__ctx;

	if (!key || keylen != ctx->ndigits << 1)
		return -EINVAL;

	if (ctx->pub_key_vaddr == NULL) {
		ctx->pub_key_vaddr = ycc_udma_malloc(ALIGN(keylen, MEM_ALIGNMENT_64));
		if (ctx->pub_key_vaddr == NULL)
			return -ENOMEM;
	}

	YCC_memcpy(ctx->pub_key_vaddr, key, keylen);
	ctx->pub_key_paddr = virt_to_phys(ctx->pub_key_vaddr);
	return 0;
}

/*
 * akcipher .set_priv_key hook for ECDSA.
 *
 * @keylen must be exactly ndigits. The caller zero-pads a shorter key at
 * its most significant end. The key is stored after a zeroed 16-byte
 * header in a DMA buffer allocated on first use and reused afterwards.
 *
 * Returns 0 on success, -EINVAL for a NULL key or a wrong length, or
 * -ENOMEM if the buffer cannot be allocated.
 */
static int ycc_ecdsa_setprivkey(struct akcipher_ctx *cipher, void *key,
				unsigned int keylen)
{
	struct ycc_ecdsa_ctx *ctx = (struct ycc_ecdsa_ctx *)cipher->__ctx;
	const unsigned int hdr_len = 16;

	if (!key || keylen != ctx->ndigits)
		return -EINVAL;

	if (ctx->priv_key_vaddr == NULL) {
		ctx->priv_key_vaddr = ycc_udma_malloc(ALIGN(keylen + hdr_len, MEM_ALIGNMENT_64));
		if (ctx->priv_key_vaddr == NULL)
			return -ENOMEM;
	}

	memset(ctx->priv_key_vaddr, 0, hdr_len);
	YCC_memcpy(ctx->priv_key_vaddr + hdr_len, key, keylen);
	ctx->priv_key_paddr = virt_to_phys(ctx->priv_key_vaddr);
	return 0;
}

/* akcipher .max_size hook for ECDSA: no size is reported (always 0). */
static unsigned int ycc_ecdsa_max_size(struct akcipher_ctx *cipher)
{
	(void)cipher;
	return 0U;
}

/*
 * akcipher_alg initializer for one ECDSA curve. @alg becomes the algorithm
 * name string; @curve selects the ycc_ecdsa_<curve>_init/_exit pair. All
 * curves share the sign/verify/key hooks.
 */
#define YCC_ECDSA_ALG(alg, curve)				\
	{							\
		.name		= #alg,				\
		.init		= ycc_ecdsa_##curve##_init,	\
		.exit		= ycc_ecdsa_##curve##_exit,	\
		.ctxsize	= sizeof(struct ycc_ecdsa_ctx),	\
		.reqsize	= sizeof(struct ycc_pke_req),	\
		.set_pub_key	= ycc_ecdsa_setpubkey,		\
		.set_priv_key	= ycc_ecdsa_setprivkey,		\
		.sign		= ycc_ecdsa_sign,		\
		.verify		= ycc_ecdsa_verify,		\
		.max_size	= ycc_ecdsa_max_size,		\
	}

static struct akcipher_alg akcipher_algs[] = {
	{
		.name = "rsa",
		.ctxsize = sizeof(struct ycc_rsa_ctx),
		.sign = ycc_rsa_sign,
		.verify = ycc_rsa_verify,
		.encrypt = ycc_rsa_encrypt,
		.decrypt = ycc_rsa_decrypt,
		.set_pub_key = ycc_rsa_setpubkey,
		.set_priv_key = ycc_rsa_setprivkey,
		.max_size = ycc_rsa_max_size,
		.init = ycc_rsa_init,
		.exit = ycc_rsa_exit,
		.reqsize = sizeof(struct ycc_pke_req),
	},
	YCC_ECDSA_ALG(ecdsa-nist-p192, p192),
	YCC_ECDSA_ALG(ecdsa-nist-p224, p224),
	YCC_ECDSA_ALG(ecdsa-nist-p256, p256),
	YCC_ECDSA_ALG(ecdsa-nist-p384, p384),
	YCC_ECDSA_ALG(ecdsa-nist-p521, p521),
};

extern struct akcipher_alg_entry *akcipher_alg_entries;

/*
 * Append one akcipher_alg_entry per element of akcipher_algs[] to the
 * global akcipher_alg_entries list, in table order.
 *
 * Entries already on the list are kept: new ones are linked after the
 * current tail. Linking them after the head would cut off the rest of a
 * non-empty list. If an allocation fails, registration stops and the
 * entries added so far stay registered.
 */
void akcipher_register_algs(void)
{
	struct akcipher_alg_entry **tail = &akcipher_alg_entries;
	size_t i;

	while (*tail)
		tail = &(*tail)->next;

	for (i = 0; i < sizeof(akcipher_algs) / sizeof(akcipher_algs[0]); i++) {
		struct akcipher_alg_entry *cur = malloc(sizeof(*cur));

		if (!cur)
			break;

		cur->alg = &akcipher_algs[i];
		cur->next = NULL;
		*tail = cur;
		tail = &cur->next;
	}
}

/*
 * Free every entry on the global akcipher_alg_entries list and leave the
 * list empty. The akcipher_alg objects the entries point to are static and
 * are not freed.
 */
void akcipher_unregister_algs(void)
{
	struct akcipher_alg_entry *entry, *next;

	for (entry = akcipher_alg_entries; entry != NULL; entry = next) {
		next = entry->next;
		free(entry);
	}
	akcipher_alg_entries = NULL;
}
