// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>

#include "../utils/utils.h"
#include "pke.h"

/*
 * Head of the singly linked list of registered akcipher algorithms; each
 * entry holds one struct akcipher_alg and is chained through ->next.
 * find_akcipher_alg() walks this list by name. NOTE(review): the list is
 * populated outside this file — presumably by a strong
 * akcipher_register_algs() definition; confirm.
 */
struct akcipher_alg_entry *akcipher_alg_entries;

/*
 * Result slot for CRYPTO_SYNC requests. akcipher_alloc_req() allocates one
 * and stores it in req->data; akcipher_complete() records the driver's
 * status and signals the completion that the sign/verify/encrypt/decrypt
 * wrappers wait on.
 */
struct akcipher_result {
	struct completion completion;	/* signalled by akcipher_complete() */
	int err;			/* status handed to akcipher_complete() */
};

/*
 * Address of the completion inside the struct akcipher_result that
 * akcipher_alloc_req() attaches to req->data for CRYPTO_SYNC requests.
 * Only meaningful when req->data really points at such a result; in async
 * mode req->data belongs to the caller and must not be used with this.
 * (No trailing line continuation: the expansion ends on this line.)
 */
#define PKE_REQ_GET_COMPLETION(req)					\
	(&((struct akcipher_result *)(req)->data)->completion)

/*
 * Internal completion callback installed on CRYPTO_SYNC requests.
 * Stores @err into the request's akcipher_result and wakes the waiter.
 */
static void akcipher_complete(struct akcipher_req *req, int err)
{
	struct akcipher_result *res;

	res = (struct akcipher_result *)req->data;
	res->err = err;
	complete(&res->completion);
}

/*
 * Look up a registered algorithm by name.
 *
 * Returns the first entry on akcipher_alg_entries whose registered name is
 * a prefix of @alg_name, or NULL when nothing matches.
 * NOTE(review): this is a prefix match (e.g. "rsa-foo" matches "rsa"), so
 * the result depends on list order when one name prefixes another —
 * confirm whether an exact match was intended.
 */
static inline struct akcipher_alg *find_akcipher_alg(const char *alg_name)
{
	struct akcipher_alg_entry *entry;

	for (entry = akcipher_alg_entries; entry != NULL; entry = entry->next) {
		const char *name = entry->alg->name;

		if (strncmp(alg_name, name, strlen(name)) == 0)
			return entry->alg;
	}

	return NULL;
}

/*
 * Allocate and initialise a context for the algorithm named @alg_name.
 *
 * The context is allocated with alg->ctxsize bytes of trailing driver
 * space, zero-filled, bound to the algorithm, and passed to alg->init (if
 * any). @flag (e.g. CRYPTO_SYNC) is OR-ed into ctx->flags after init runs,
 * so any flags init itself sets are preserved.
 *
 * Returns the new context, or NULL if @alg_name is NULL, the algorithm is
 * not registered, allocation fails, or alg->init reports an error.
 */
struct akcipher_ctx *akcipher_alloc_ctx(const char *alg_name, uint32_t flag)
{
	struct akcipher_alg *alg;
	struct akcipher_ctx *ctx;

	/* Guard before strncmp() and the "%s" log below dereference it. */
	if (!alg_name) {
		ycc_err("Invalid NULL alg name\n");
		return NULL;
	}

	alg = find_akcipher_alg(alg_name);
	if (!alg) {
		ycc_err("Failed to find alg:%s\n", alg_name);
		return NULL;
	}

	/* calloc() zero-fills both the generic ctx and the driver tail. */
	ctx = calloc(1, sizeof(*ctx) + alg->ctxsize);
	if (!ctx) {
		ycc_err("Failed to alloc ctx for alg:%s\n", alg_name);
		return NULL;
	}

	ctx->alg = alg;
	if (alg->init && alg->init(ctx)) {
		ycc_err("Failed to do init for alg:%s\n", alg_name);
		free(ctx);
		return NULL;
	}

	ctx->flags |= flag;

	/* reqsize & __ctx are initialized by real algorithm driver */
	return ctx;
}

/*
 * Release a context from akcipher_alloc_ctx(). Runs the algorithm's exit
 * hook (when provided) before freeing. A NULL @ctx is ignored.
 */
void akcipher_free_ctx(struct akcipher_ctx *ctx)
{
	if (ctx == NULL)
		return;

	if (ctx->alg->exit != NULL)
		ctx->alg->exit(ctx);

	free(ctx);
}

/*
 * Allocate a zero-filled request (plus alg->reqsize bytes of driver space)
 * bound to @ctx.
 *
 * For CRYPTO_SYNC contexts the request additionally gets an internal
 * struct akcipher_result in req->data, the akcipher_complete() callback,
 * and an initialised completion. Async requests are returned with no
 * callback; the caller installs one via akcipher_req_set_callback().
 *
 * Returns NULL if any allocation fails.
 */
struct akcipher_req *akcipher_alloc_req(struct akcipher_ctx *ctx)
{
	struct akcipher_req *req;

	req = calloc(1, sizeof(struct akcipher_req) + ctx->alg->reqsize);
	if (req == NULL)
		return NULL;

	req->ctx = ctx;

	if (ctx->flags & CRYPTO_SYNC) {
		req->data = malloc(sizeof(struct akcipher_result));
		if (req->data == NULL) {
			free(req);
			return NULL;
		}

		req->complete = akcipher_complete;
		init_completion(PKE_REQ_GET_COMPLETION(req));
	}

	return req;
}

/*
 * Release a request from akcipher_alloc_req(). A NULL @req is ignored.
 *
 * Only CRYPTO_SYNC requests own req->data: it is the internal
 * struct akcipher_result, whose completion is destroyed and which is
 * freed here. In async mode req->data is the caller's opaque pointer from
 * akcipher_req_set_callback() and is left untouched. Previously
 * destroy_completion() was run on it, treating caller memory as an
 * akcipher_result.
 */
void akcipher_free_req(struct akcipher_req *req)
{
	if (!req)
		return;

	if ((req->ctx->flags & CRYPTO_SYNC) && req->data) {
		destroy_completion(PKE_REQ_GET_COMPLETION(req));
		free(req->data);
	}

	free(req);
}

/*
 * Install a caller completion callback and opaque @data on an async
 * request. Ignored for CRYPTO_SYNC requests, which keep the internal
 * akcipher_complete() callback and their akcipher_result in req->data.
 */
void akcipher_req_set_callback(struct akcipher_req *req,
			       akcipher_completion_t complete, void *data)
{
	if (!(req->ctx->flags & CRYPTO_SYNC)) {
		req->complete = complete;
		req->data = data;
	}
}

/*
 * Fill in the input/output buffers, their lengths, and the padding mode
 * for the next operation on @req. The buffers are stored by reference,
 * not copied.
 */
void akcipher_set_req(struct akcipher_req *req, unsigned char *src,
		      unsigned char *dst, unsigned int src_len,
		      unsigned int dst_len, int padding)
{
	/* input */
	req->src = src;
	req->src_len = src_len;

	/* output */
	req->dst = dst;
	req->dst_len = dst_len;

	req->padding = padding;
}

/*
 * Submit a sign operation for @req.
 *
 * Returns -EOPNOTSUPP if the algorithm has no sign implementation.
 * Otherwise it returns whatever alg->sign returns, unless that is
 * -EINPROGRESS on a CRYPTO_SYNC context. In that case it waits up to
 * COMPLETION_TIMEOUT_SECS and returns either the status delivered to
 * akcipher_complete() or the negative value from the timed-out wait.
 * Async callers get -EINPROGRESS back and are notified via req->complete.
 *
 * NOTE(review): after a timeout the request may still be in flight; if the
 * driver later calls req->complete it writes into req->data — confirm that
 * callers do not free @req immediately after a timeout.
 */
int akcipher_sign(struct akcipher_req *req)
{
	struct akcipher_alg *alg = req->ctx->alg;
	int ret;

	if (!alg->sign)
		return -EOPNOTSUPP;

	ret = alg->sign(req);
	if (ret != -EINPROGRESS || !(req->ctx->flags & CRYPTO_SYNC))
		return ret;

	/* Sync mode only: req->data is our akcipher_result here. */
	ret = wait_for_completion_timeout(PKE_REQ_GET_COMPLETION(req),
					  COMPLETION_TIMEOUT_SECS);
	if (ret < 0) {
		ycc_err("Sign result timeout\n");
		return ret;
	}

	return ((struct akcipher_result *)req->data)->err;
}

/*
 * Submit a verify operation for @req.
 *
 * Returns -EOPNOTSUPP if the algorithm has no verify implementation.
 * Otherwise it returns whatever alg->verify returns, unless that is
 * -EINPROGRESS on a CRYPTO_SYNC context. In that case it waits up to
 * COMPLETION_TIMEOUT_SECS and returns either the status delivered to
 * akcipher_complete() or the negative value from the timed-out wait.
 * Async callers get -EINPROGRESS back and are notified via req->complete.
 *
 * NOTE(review): after a timeout the request may still be in flight; if the
 * driver later calls req->complete it writes into req->data — confirm that
 * callers do not free @req immediately after a timeout.
 */
int akcipher_verify(struct akcipher_req *req)
{
	struct akcipher_alg *alg = req->ctx->alg;
	int ret;

	if (!alg->verify)
		return -EOPNOTSUPP;

	ret = alg->verify(req);
	if (ret != -EINPROGRESS || !(req->ctx->flags & CRYPTO_SYNC))
		return ret;

	/* Sync mode only: req->data is our akcipher_result here. */
	ret = wait_for_completion_timeout(PKE_REQ_GET_COMPLETION(req),
					  COMPLETION_TIMEOUT_SECS);
	if (ret < 0) {
		ycc_err("Verify result timeout\n");
		return ret;
	}

	return ((struct akcipher_result *)req->data)->err;
}

/*
 * Load a public key of @keylen bytes into @ctx via the algorithm driver.
 * Returns the driver's result, or -EOPNOTSUPP if the algorithm provides no
 * set_pub_key hook (previously a NULL-pointer call).
 */
int akcipher_set_pub_key(struct akcipher_ctx *ctx, void *key,
			 unsigned int keylen)
{
	struct akcipher_alg *alg = ctx->alg;

	if (!alg->set_pub_key)
		return -EOPNOTSUPP;

	return alg->set_pub_key(ctx, key, keylen);
}

/*
 * Load a private key of @keylen bytes into @ctx via the algorithm driver.
 * Returns the driver's result, or -EOPNOTSUPP if the algorithm provides no
 * set_priv_key hook (previously a NULL-pointer call).
 */
int akcipher_set_priv_key(struct akcipher_ctx *ctx, void *key,
			  unsigned int keylen)
{
	struct akcipher_alg *alg = ctx->alg;

	if (!alg->set_priv_key)
		return -EOPNOTSUPP;

	return alg->set_priv_key(ctx, key, keylen);
}

unsigned int akcipher_maxsize(struct akcipher_ctx *ctx)
{
	return ctx->alg->max_size(ctx);
}

/*
 * Submit an encrypt operation for @req.
 *
 * Returns -EOPNOTSUPP if the algorithm has no encrypt implementation.
 * Otherwise it returns whatever alg->encrypt returns, unless that is
 * -EINPROGRESS on a CRYPTO_SYNC context. In that case it waits up to
 * COMPLETION_TIMEOUT_SECS and returns either the status delivered to
 * akcipher_complete() or the negative value from the timed-out wait.
 * Async callers get -EINPROGRESS back and are notified via req->complete.
 *
 * NOTE(review): after a timeout the request may still be in flight; if the
 * driver later calls req->complete it writes into req->data — confirm that
 * callers do not free @req immediately after a timeout.
 */
int akcipher_encrypt(struct akcipher_req *req)
{
	struct akcipher_alg *alg = req->ctx->alg;
	int ret;

	if (!alg->encrypt)
		return -EOPNOTSUPP;

	ret = alg->encrypt(req);
	if (ret != -EINPROGRESS || !(req->ctx->flags & CRYPTO_SYNC))
		return ret;

	/* Sync mode only: req->data is our akcipher_result here. */
	ret = wait_for_completion_timeout(PKE_REQ_GET_COMPLETION(req),
					  COMPLETION_TIMEOUT_SECS);
	if (ret < 0) {
		ycc_err("Encrypt result timeout\n");
		return ret;
	}

	return ((struct akcipher_result *)req->data)->err;
}

/*
 * Submit a decrypt operation for @req.
 *
 * Returns -EOPNOTSUPP if the algorithm has no decrypt implementation.
 * Otherwise it returns whatever alg->decrypt returns, unless that is
 * -EINPROGRESS on a CRYPTO_SYNC context. In that case it waits up to
 * COMPLETION_TIMEOUT_SECS and returns either the status delivered to
 * akcipher_complete() or the negative value from the timed-out wait.
 * Async callers get -EINPROGRESS back and are notified via req->complete.
 *
 * NOTE(review): after a timeout the request may still be in flight; if the
 * driver later calls req->complete it writes into req->data — confirm that
 * callers do not free @req immediately after a timeout.
 */
int akcipher_decrypt(struct akcipher_req *req)
{
	struct akcipher_alg *alg = req->ctx->alg;
	int ret;

	if (!alg->decrypt)
		return -EOPNOTSUPP;

	ret = alg->decrypt(req);
	if (ret != -EINPROGRESS || !(req->ctx->flags & CRYPTO_SYNC))
		return ret;

	/* Sync mode only: req->data is our akcipher_result here. */
	ret = wait_for_completion_timeout(PKE_REQ_GET_COMPLETION(req),
					  COMPLETION_TIMEOUT_SECS);
	if (ret < 0) {
		ycc_err("Decrypt result timeout\n");
		return ret;
	}

	return ((struct akcipher_result *)req->data)->err;
}

/*
 * Weak no-op defaults for algorithm (un)registration. A strong definition
 * of either symbol elsewhere in the link overrides these; without one,
 * calling them does nothing. NOTE(review): presumably the strong versions
 * add/remove entries on akcipher_alg_entries — confirm.
 */
void __attribute__((weak)) akcipher_register_algs(void) {}
void __attribute__((weak)) akcipher_unregister_algs(void) {}
