// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <endian.h>

#include "../utils/utils.h"
#include "udma_ulib.h"
#include "ycc_algs.h"
#include "kpp.h"

/*
 * Install our own private key.
 *
 * @buffer holds the raw private key and @len must equal ctx->ndigits. The
 * key is staged in a UDMA buffer (size rounded up with MEM_ALIGNMENT_64)
 * laid out as a 16-byte pin area followed by the key. The buffer's physical
 * address is what the public-key and shared-secret commands hand to the
 * device as priv_ptr.
 *
 * Installing a key while one is already present replaces it. The old
 * buffer is zeroed and released only after the new one is ready, so on
 * -ENOMEM the previously installed key remains in place.
 *
 * Returns 0 on success, -EINVAL for a NULL buffer or wrong length, and
 * -ENOMEM when the UDMA buffer cannot be allocated.
 */
static int ycc_ecdh_set_secret(struct kpp_ctx *cipher, void *buffer, unsigned int len)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	size_t buf_size;
	void *key;

	if (len != ctx->ndigits || !buffer)
		return -EINVAL;

	/* 16 bytes, pin + private key */
	buf_size = ALIGN(len + 16, MEM_ALIGNMENT_64);
	key = ycc_udma_malloc(buf_size);
	if (!key)
		return -ENOMEM;

	/* Nothing else writes the pin area; don't hand the device stale bytes */
	memset(key, 0, 16);
	YCC_memcpy((uint8_t *)key + 16, buffer, len);

	if (ctx->private_key_vaddr) {
		/* Old buffer has the same size: len == ctx->ndigits for every key */
		memset(ctx->private_key_vaddr, 0, buf_size);
		ycc_udma_free(ctx->private_key_vaddr);
	}

	ctx->private_key_vaddr = key;
	ctx->private_key_paddr = virt_to_phys(key);
	return 0;
}

static int ycc_kpp_done_callback(void *ptr, uint16_t state)
{
	struct ycc_kpp_req *kpp_req = (struct ycc_kpp_req *)ptr;
	struct ycc_kpp_ctx *ctx = kpp_req->ctx;
	struct kpp_req *req = kpp_req->req;
	int ret = 0;

	if (kpp_req->desc.cmd.kpp_cmd.gen_pub_cmd.cmd_id == YCC_CMD_KPP_GEN_KEY) {
		YCC_memcpy(req->dst, kpp_req->dst_vaddr, ctx->ndigits << 1);
		memset(kpp_req->dst_vaddr, 0, ALIGN(ctx->ndigits << 1, MEM_ALIGNMENT_64));
		ycc_udma_free(kpp_req->dst_vaddr);
	} else if (kpp_req->desc.cmd.kpp_cmd.gen_key_pair_cmd.cmd_id == YCC_CMD_KPP_GEN_KEY_PAIR) {
		YCC_memcpy(req->src, kpp_req->src_vaddr + 80, ctx->ndigits);
		YCC_memcpy(req->dst, kpp_req->dst_vaddr, ctx->ndigits << 1);
		memset(kpp_req->dst_vaddr, 0, ALIGN(ctx->ndigits << 1, MEM_ALIGNMENT_64));
		memset(kpp_req->src_vaddr, 0, ALIGN(ctx->ndigits + 80, MEM_ALIGNMENT_64));
		ycc_udma_free(kpp_req->src_vaddr);
		ycc_udma_free(kpp_req->dst_vaddr);
	} else if (kpp_req->desc.cmd.kpp_cmd.gen_shared_cmd.cmd_id == YCC_CMD_KPP_ECDH_SS) {
		YCC_memcpy(req->dst, kpp_req->dst_vaddr, ctx->ndigits);
		memset(kpp_req->dst_vaddr, 0, ALIGN(ctx->ndigits, MEM_ALIGNMENT_64));
		memset(kpp_req->src_vaddr, 0, ALIGN(ctx->ndigits << 1, MEM_ALIGNMENT_64));
		ycc_udma_free(kpp_req->src_vaddr);
		ycc_udma_free(kpp_req->dst_vaddr);
	}

	if (req->complete)
		req->complete(req, (state == CMD_SUCCESS && ret == 0) ? 0 : -EBADMSG);
	return 0;
}

/*
 * Submit a command deriving the public key of the private key installed by
 * ycc_ecdh_set_secret().
 *
 * req->dst must be non-NULL with req->dst_len >= ctx->ndigits * 2. The
 * operation is asynchronous. Once the descriptor is queued this returns
 * -EINPROGRESS, and ycc_kpp_done_callback() later copies the key into
 * req->dst, frees the DMA buffer and completes the request.
 *
 * Returns -EINPROGRESS when queued, -EINVAL for bad arguments or when no
 * private key has been set, -ENOMEM on allocation failure, and otherwise
 * the error returned by ycc_enqueue(). On every error path, everything
 * allocated here is released.
 */
static int ycc_ecdh_generate_public_key(struct kpp_req *req)
{
	struct kpp_ctx *cipher = req->ctx;
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring = ctx->ring;
	struct ycc_kpp_req *kpp_req = (struct ycc_kpp_req *)req->__req;
	struct ycc_ecc_gen_pub_cmd *kpp_gen_pub_cmd;
	struct ycc_flags *aflags;
	int ret = -ENOMEM;

	/*
	 * priv_ptr below is ctx->private_key_paddr, so refuse to run before
	 * set_secret has installed a key (compute_shared_secret checks the same).
	 */
	if (!ctx->private_key_vaddr || !req->dst || req->dst_len < (ctx->ndigits << 1))
		return -EINVAL;

	kpp_req->ctx = ctx;
	kpp_req->req = req;
	memset(&kpp_req->desc, 0, sizeof(kpp_req->desc));

	aflags = malloc(sizeof(*aflags));
	if (!aflags)
		goto out;

	aflags->ptr = (void *)kpp_req;
	aflags->ycc_done_callback = ycc_kpp_done_callback;
	kpp_req->desc.private_ptr = (uint64_t)(void *)aflags;

	kpp_req->dst_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits << 1, MEM_ALIGNMENT_64));
	if (!kpp_req->dst_vaddr)
		goto free_aflags;

	kpp_req->dst_paddr = virt_to_phys(kpp_req->dst_vaddr);

	kpp_gen_pub_cmd           = &kpp_req->desc.cmd.kpp_cmd.gen_pub_cmd;
	kpp_gen_pub_cmd->cmd_id   = YCC_CMD_KPP_GEN_KEY;
	kpp_gen_pub_cmd->curve_id = ctx->curve_id;
	kpp_gen_pub_cmd->priv_ptr = ctx->private_key_paddr;
	kpp_gen_pub_cmd->pub_ptr  = kpp_req->dst_paddr;

	ret = ycc_enqueue(ring, (uint8_t *)&kpp_req->desc);
	if (!ret)
		return -EINPROGRESS;

	/* Enqueue failed: release everything this function allocated */
	ycc_udma_free(kpp_req->dst_vaddr);
free_aflags:
	free(aflags);
out:
	return ret;
}

/*
 * Submit a command generating a fresh key pair on the device.
 *
 * req->src (>= ctx->ndigits bytes) receives the generated private key and
 * req->dst (>= ctx->ndigits * 2 bytes) the public key. Both are copied out
 * by ycc_kpp_done_callback(), which also frees the DMA buffers.
 *
 * The src DMA buffer holds pin_ptr at offset 0 and priv_ptr at offset 64.
 * The callback reads the key bytes back from offset 80.
 *
 * NOTE(review): the generated private key is only returned to the caller.
 * It is not installed as ctx's own key for later shared-secret computation.
 *
 * Returns -EINPROGRESS when queued, -EINVAL for bad arguments, -ENOMEM on
 * allocation failure, and otherwise the error returned by ycc_enqueue().
 * On every error path, everything allocated here is released.
 */
static int ycc_ecdh_generate_key_pair(struct kpp_req *req)
{
	struct kpp_ctx *cipher = req->ctx;
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring = ctx->ring;
	struct ycc_kpp_req *kpp_req = (struct ycc_kpp_req *)req->__req;
	struct ycc_ecc_gen_key_pair_cmd *kpp_gen_key_pair_cmd;
	struct ycc_flags *aflags;
	int ret = -ENOMEM;

	if (!req->src || !req->dst
	    || req->dst_len < (ctx->ndigits << 1)
	    || req->src_len < (ctx->ndigits))
		return -EINVAL;

	kpp_req->ctx = ctx;
	kpp_req->req = req;
	memset(&kpp_req->desc, 0, sizeof(kpp_req->desc));

	aflags = malloc(sizeof(*aflags));
	if (!aflags)
		goto out;

	aflags->ptr = (void *)kpp_req;
	aflags->ycc_done_callback = ycc_kpp_done_callback;
	kpp_req->desc.private_ptr = (uint64_t)(void *)aflags;

	kpp_req->dst_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits << 1, MEM_ALIGNMENT_64));
	if (!kpp_req->dst_vaddr)
		goto free_aflags;

	kpp_req->dst_paddr = virt_to_phys(kpp_req->dst_vaddr);

	kpp_req->src_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits + 80, MEM_ALIGNMENT_64));
	if (!kpp_req->src_vaddr)
		goto free_dst;

	kpp_req->src_paddr = virt_to_phys(kpp_req->src_vaddr);

	kpp_gen_key_pair_cmd           = &kpp_req->desc.cmd.kpp_cmd.gen_key_pair_cmd;
	kpp_gen_key_pair_cmd->cmd_id   = YCC_CMD_KPP_GEN_KEY_PAIR;
	kpp_gen_key_pair_cmd->curve_id = ctx->curve_id;
	kpp_gen_key_pair_cmd->pin_ptr  = kpp_req->src_paddr;
	kpp_gen_key_pair_cmd->priv_ptr = kpp_req->src_paddr + 64;
	kpp_gen_key_pair_cmd->pub_ptr  = kpp_req->dst_paddr;

	ret = ycc_enqueue(ring, (uint8_t *)&kpp_req->desc);
	if (!ret)
		return -EINPROGRESS;

	/* Enqueue failed: the src buffer used to leak here */
	ycc_udma_free(kpp_req->src_vaddr);
free_dst:
	ycc_udma_free(kpp_req->dst_vaddr);
free_aflags:
	free(aflags);
out:
	return ret;
}

static int ycc_ecdh_compute_shared_secret(struct kpp_req *req)
{
	struct kpp_ctx *cipher = req->ctx;
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring = ctx->ring;
	struct ycc_kpp_req *kpp_req = (struct ycc_kpp_req *)req->__req;
	struct ycc_ecdh_gen_shared_cmd *kpp_gen_shared_cmd;
	struct ycc_flags *aflags;
	int ret = -ENOMEM;

	if (!ctx->private_key_vaddr || !req->src || !req->dst
	    || req->src_len < (ctx->ndigits << 1)
	    || req->dst_len < (ctx->ndigits))
		return -EINVAL;

	kpp_req->ctx = ctx;
	kpp_req->req = req;

	memset(&kpp_req->desc, 0, sizeof(kpp_req->desc));

	aflags = malloc(sizeof(struct ycc_flags));
	if (!aflags)
		goto out;

	aflags->ptr = (void *)kpp_req;
	aflags->ycc_done_callback = ycc_kpp_done_callback;
	kpp_req->desc.private_ptr = (uint64_t)(void *)aflags;

	kpp_req->dst_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits, MEM_ALIGNMENT_64));
	if (!kpp_req->dst_vaddr)
		goto free_aflags;

	kpp_req->dst_paddr = virt_to_phys(kpp_req->dst_vaddr);

	kpp_req->src_vaddr = ycc_udma_malloc(ALIGN(ctx->ndigits << 1, MEM_ALIGNMENT_64));
	if (!kpp_req->src_vaddr)
		goto free_dst;

	kpp_req->src_paddr = virt_to_phys(kpp_req->src_vaddr);
	YCC_memcpy(kpp_req->src_vaddr, req->src, req->src_len);

	kpp_gen_shared_cmd             = &kpp_req->desc.cmd.kpp_cmd.gen_shared_cmd;
	kpp_gen_shared_cmd->cmd_id     = YCC_CMD_KPP_ECDH_SS;
	kpp_gen_shared_cmd->curve_id   = ctx->curve_id;
	kpp_gen_shared_cmd->priv_ptr   = ctx->private_key_paddr;
	kpp_gen_shared_cmd->pub_ptr    = kpp_req->src_paddr;
	kpp_gen_shared_cmd->shared_ptr = kpp_req->dst_paddr;

	ret = ycc_enqueue(ring, (uint8_t *)&kpp_req->desc);
	if (!ret)
		return -EINPROGRESS;

free_dst:
	ycc_udma_free(kpp_req->dst_vaddr);
free_aflags:
	free(aflags);
out:
	return ret;
}

static unsigned int ycc_ecdh_max_size(struct kpp_ctx *cipher)
{
	return 0;
}

static int ycc_ecdh_p192_init(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	ring = ycc_crypto_get_ring();
	if (!ring)
		return -EINVAL;

	ctx->ring = ring;
	ctx->ctx = cipher;

	ctx->curve_id = YCC_EC_P192;
	ctx->ndigits =  YCC_EC_DIGITS_P192;
	cipher->reqsize = sizeof(struct ycc_kpp_req);
	return 0;
}

static void ycc_ecdh_p192_exit(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;

	if (ctx->ring)
		ycc_crypto_free_ring(ctx->ring);

	/* TODO: explicitly clear key memory */
	if (ctx->private_key_vaddr) {
		ycc_udma_free(ctx->private_key_vaddr);
		ctx->private_key_vaddr = NULL;
	}
}

static int ycc_ecdh_p224_init(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	ring = ycc_crypto_get_ring();
	if (!ring)
		return -EINVAL;

	ctx->ring = ring;
	ctx->ctx = cipher;

	ctx->curve_id = YCC_EC_P224;
	ctx->ndigits =  YCC_EC_DIGITS_P224;
	cipher->reqsize = sizeof(struct ycc_kpp_req);
	return 0;
}

static void ycc_ecdh_p224_exit(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;

	if (ctx->ring)
		ycc_crypto_free_ring(ctx->ring);

	/* TODO: explicitly clear key memory */
	if (ctx->private_key_vaddr) {
		ycc_udma_free(ctx->private_key_vaddr);
		ctx->private_key_vaddr = NULL;
	}
}

static int ycc_ecdh_p256_init(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	ring = ycc_crypto_get_ring();
	if (!ring)
		return -EINVAL;

	ctx->ring = ring;
	ctx->ctx = cipher;

	ctx->curve_id = YCC_EC_P256;
	ctx->ndigits =  YCC_EC_DIGITS_P256;
	cipher->reqsize = sizeof(struct ycc_kpp_req);
	return 0;
}

static void ycc_ecdh_p256_exit(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;

	if (ctx->ring)
		ycc_crypto_free_ring(ctx->ring);

	/* TODO: explicitly clear key memory */
	if (ctx->private_key_vaddr) {
		ycc_udma_free(ctx->private_key_vaddr);
		ctx->private_key_vaddr = NULL;
	}
}

static int ycc_ecdh_p384_init(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	ring = ycc_crypto_get_ring();
	if (!ring)
		return -EINVAL;

	ctx->ring = ring;
	ctx->ctx = cipher;

	ctx->curve_id = YCC_EC_P384;
	ctx->ndigits =  YCC_EC_DIGITS_P384;
	cipher->reqsize = sizeof(struct ycc_kpp_req);
	return 0;
}

static void ycc_ecdh_p384_exit(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;

	if (ctx->ring)
		ycc_crypto_free_ring(ctx->ring);

	/* TODO: explicitly clear key memory */
	if (ctx->private_key_vaddr) {
		ycc_udma_free(ctx->private_key_vaddr);
		ctx->private_key_vaddr = NULL;
	}
}

static int ycc_ecdh_p521_init(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;
	struct ycc_ring *ring;

	ring = ycc_crypto_get_ring();
	if (!ring)
		return -EINVAL;

	ctx->ring = ring;
	ctx->ctx = cipher;

	ctx->curve_id = YCC_EC_P521;
	ctx->ndigits =  YCC_EC_DIGITS_P521;
	cipher->reqsize = sizeof(struct ycc_kpp_req);
	return 0;
}

static void ycc_ecdh_p521_exit(struct kpp_ctx *cipher)
{
	struct ycc_kpp_ctx *ctx = (struct ycc_kpp_ctx *)cipher->__ctx;

	if (ctx->ring)
		ycc_crypto_free_ring(ctx->ring);

	/* TODO: explicitly clear key memory */
	if (ctx->private_key_vaddr) {
		ycc_udma_free(ctx->private_key_vaddr);
		ctx->private_key_vaddr = NULL;
	}
}

/*
 * Build one kpp_alg entry for NIST curve P-<bits>.
 *
 * Every curve shares the same operation handlers. Only the algorithm name
 * and the init/exit pair differ; init selects the curve id and key size.
 * "ecdh-nist-p" #bits expands to e.g. "ecdh-nist-p256".
 */
#define YCC_ECDH_NIST_ALG(bits)						\
	{								\
		.name = "ecdh-nist-p" #bits,				\
		.set_secret = ycc_ecdh_set_secret,			\
		.generate_public_key = ycc_ecdh_generate_public_key,	\
		.generate_key_pair = ycc_ecdh_generate_key_pair,	\
		.compute_shared_secret = ycc_ecdh_compute_shared_secret, \
		.max_size = ycc_ecdh_max_size,				\
		.init = ycc_ecdh_p##bits##_init,			\
		.exit = ycc_ecdh_p##bits##_exit,			\
		.ctxsize = sizeof(struct ycc_kpp_ctx),			\
	}

/* Supported algorithms; kpp_register_algs() registers them in this order. */
static struct kpp_alg kpp_algs[] = {
	YCC_ECDH_NIST_ALG(192),
	YCC_ECDH_NIST_ALG(224),
	YCC_ECDH_NIST_ALG(256),
	YCC_ECDH_NIST_ALG(384),
	YCC_ECDH_NIST_ALG(521),
};

#undef YCC_ECDH_NIST_ALG

extern struct kpp_alg_entry *kpp_alg_entries;

/*
 * Append one kpp_alg_entry per element of kpp_algs[] to the global
 * kpp_alg_entries list, preserving array order.
 *
 * New nodes are linked after the list's current tail. Earlier code linked
 * them after the head, which unlinked and leaked any entries beyond the
 * first when the list was non-empty. If a node allocation fails, the error
 * is logged and registration stops; nodes added so far stay registered.
 */
void kpp_register_algs(void)
{
	struct kpp_alg_entry *tail = kpp_alg_entries;
	struct kpp_alg_entry *cur;
	size_t i;

	/* Walk to the current tail so existing entries stay linked */
	while (tail && tail->next)
		tail = tail->next;

	for (i = 0; i < sizeof(kpp_algs) / sizeof(kpp_algs[0]); i++) {
		cur = malloc(sizeof(*cur));
		if (!cur) {
			ycc_err("Failed to alloc memory for alg:%s\n", kpp_algs[i].name);
			break;
		}
		cur->alg = &kpp_algs[i];
		cur->next = NULL;

		if (tail)
			tail->next = cur;
		else
			kpp_alg_entries = cur;
		tail = cur;
	}
}

/*
 * Free every node of the global kpp_alg_entries list and reset the head to
 * NULL. The kpp_alg objects the nodes point to are not freed here.
 */
void kpp_unregister_algs(void)
{
	struct kpp_alg_entry *node = kpp_alg_entries;

	while (node != NULL) {
		struct kpp_alg_entry *victim = node;

		node = victim->next;
		free(victim);
	}
	kpp_alg_entries = NULL;
}
