// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>

#include "../utils/utils.h"
#include "kpp.h"


/*
 * Head of the singly linked list of KPP algorithm entries that
 * find_kpp_alg() walks via ->next. This file never modifies the list;
 * it is presumably populated by kpp_register_algs() - TODO confirm.
 */
struct kpp_alg_entry *kpp_alg_entries;

/*
 * Per-request state used on CRYPTO_SYNC contexts. kpp_alloc_req() allocates
 * one and stores it in req->data; kpp_complete() fills it in and the
 * synchronous entry points read it back after waiting.
 */
struct kpp_result {
	struct completion completion;	/* signalled by kpp_complete() */
	int err;			/* status handed to kpp_complete() */
};

/*
 * Yield a pointer to the completion embedded in the struct kpp_result that
 * (req)->data points at. Only meaningful when (req)->data really is a
 * struct kpp_result, i.e. for requests allocated on a CRYPTO_SYNC context.
 *
 * The final line deliberately has no trailing backslash, so the macro ends
 * here and cannot absorb whatever line follows it.
 */
#define KPP_REQ_GET_COMPLETION(req)					\
	(&((struct kpp_result *)(req)->data)->completion)

/*
 * Completion callback installed by kpp_alloc_req() for CRYPTO_SYNC requests.
 * Records @err in the request's struct kpp_result and then signals its
 * completion.
 */
static void kpp_complete(struct kpp_req *req, int err)
{
	struct kpp_result *res = (struct kpp_result *)req->data;

	/* Store the status first; the completion is signalled afterwards. */
	res->err = err;
	complete(&res->completion);
}

/*
 * Look up an algorithm in kpp_alg_entries by name.
 *
 * Returns the first entry whose name matches @alg_name, or NULL when the
 * list is exhausted.
 *
 * NOTE(review): the comparison is bounded by the length of the listed name,
 * so this is a prefix match - an @alg_name that merely starts with a listed
 * name also matches. Confirm whether callers rely on that.
 */
static inline struct kpp_alg *find_kpp_alg(const char *alg_name)
{
	struct kpp_alg_entry *entry;

	for (entry = kpp_alg_entries; entry; entry = entry->next) {
		const char *listed = entry->alg->name;

		if (strncmp(alg_name, listed, strlen(listed)) == 0)
			return entry->alg;
	}

	return NULL;
}

/*
 * Allocate a zeroed context for the algorithm named @alg_name.
 *
 * The allocation is sizeof(struct kpp_ctx) plus the algorithm's ctxsize.
 * If the algorithm provides an init hook it is run on the new context, and
 * a non-zero result from it aborts the allocation. @flag is OR-ed into
 * ctx->flags after init has run.
 *
 * Returns the new context, or NULL when the algorithm is unknown, memory
 * is exhausted or init fails. Release with kpp_free_ctx().
 */
struct kpp_ctx *kpp_alloc_ctx(const char *alg_name, uint32_t flag)
{
	struct kpp_alg *alg = find_kpp_alg(alg_name);
	struct kpp_ctx *ctx;

	if (!alg) {
		ycc_err("Failed to find alg:%s\n", alg_name);
		return NULL;
	}

	/* calloc hands back the header and the trailing ctxsize bytes zeroed. */
	ctx = calloc(1, sizeof(*ctx) + alg->ctxsize);
	if (!ctx) {
		ycc_err("Failed to alloc ctx for alg:%s\n", alg_name);
		return NULL;
	}

	ctx->alg = alg;

	if (alg->init) {
		if (alg->init(ctx)) {
			ycc_err("Failed to do init for alg:%s\n", alg_name);
			free(ctx);
			return NULL;
		}
	}

	/* reqsize and __ctx are left for the algorithm driver to set up. */
	ctx->flags |= flag;
	return ctx;
}

/*
 * Release a context obtained from kpp_alloc_ctx(). Runs the algorithm's
 * exit hook, when one is set, before freeing the memory. A NULL @ctx is
 * ignored.
 */
void kpp_free_ctx(struct kpp_ctx *ctx)
{
	if (!ctx)
		return;

	if (ctx->alg->exit)
		ctx->alg->exit(ctx);

	free(ctx);
}

/*
 * Allocate a zeroed request bound to @ctx, sized as sizeof(struct kpp_req)
 * plus ctx->reqsize.
 *
 * On a CRYPTO_SYNC context the request additionally gets a struct
 * kpp_result in req->data, kpp_complete() as its completion callback and an
 * initialised completion. Otherwise data and complete stay zeroed.
 *
 * Returns the request, or NULL if either allocation fails. Release with
 * kpp_free_req().
 */
struct kpp_req *kpp_alloc_req(struct kpp_ctx *ctx)
{
	struct kpp_req *req = calloc(1, sizeof(*req) + ctx->reqsize);

	if (!req)
		return NULL;

	req->ctx = ctx;

	if (ctx->flags & CRYPTO_SYNC) {
		req->data = malloc(sizeof(struct kpp_result));
		if (!req->data) {
			free(req);
			return NULL;
		}

		req->complete = kpp_complete;
		init_completion(KPP_REQ_GET_COMPLETION(req));
	}

	return req;
}

/*
 * Release a request obtained from kpp_alloc_req(). A NULL @req is ignored.
 *
 * req->data is owned by this module only on CRYPTO_SYNC contexts, where
 * kpp_alloc_req() stored a struct kpp_result there. On other contexts it is
 * whatever pointer the caller handed to kpp_req_set_callback(), so it must
 * be neither interpreted as a struct kpp_result nor freed here.
 */
void kpp_free_req(struct kpp_req *req)
{
	if (!req)
		return;

	if ((req->ctx->flags & CRYPTO_SYNC) && req->data) {
		destroy_completion(KPP_REQ_GET_COMPLETION(req));
		free(req->data);
	}

	free(req);
}

/*
 * Install a caller-supplied completion callback and its opaque @data on
 * @req. Does nothing on a CRYPTO_SYNC context, where kpp_alloc_req() has
 * already set both fields.
 */
void kpp_req_set_callback(struct kpp_req *req, kpp_completion_t complete, void *data)
{
	if (!(req->ctx->flags & CRYPTO_SYNC)) {
		req->complete = complete;
		req->data = data;
	}
}

/*
 * Record the source and destination buffers, with their lengths, on @req.
 * The pointers are stored as given; nothing is copied.
 */
void kpp_set_req(struct kpp_req *req, unsigned char *src, uint32_t src_len,
		 unsigned char *dst, uint32_t dst_len)
{
	/* Buffer pointers first, then their lengths. */
	req->src = src;
	req->dst = dst;

	req->src_len = src_len;
	req->dst_len = dst_len;
}

/*
 * Set our own private key: forward @buffer and @len to the set_secret hook
 * of the algorithm bound to @ctx and return whatever the hook returns.
 */
int kpp_set_secret(struct kpp_ctx *ctx, void *buffer, unsigned int len)
{
	return ctx->alg->set_secret(ctx, buffer, len);
}

/*
 * Run the algorithm's generate_public_key hook on @req.
 *
 * Any result other than -EINPROGRESS is returned as is. For -EINPROGRESS,
 * a non-CRYPTO_SYNC context hands that value straight back to the caller,
 * while a CRYPTO_SYNC context waits on the request's completion: a negative
 * wait result is logged as a timeout and returned, otherwise the status
 * recorded by kpp_complete() is returned.
 */
int kpp_generate_public_key(struct kpp_req *req)
{
	int rc = req->ctx->alg->generate_public_key(req);

	if (rc != -EINPROGRESS)
		return rc;

	if (!(req->ctx->flags & CRYPTO_SYNC))
		return rc;

	rc = wait_for_completion_timeout(KPP_REQ_GET_COMPLETION(req),
					 COMPLETION_TIMEOUT_SECS);
	if (rc < 0) {
		ycc_err("Generate pub key result timeout\n");
		return rc;
	}

	return ((struct kpp_result *)req->data)->err;
}

/*
 * Run the algorithm's generate_key_pair hook on @req.
 *
 * Any result other than -EINPROGRESS is returned as is. For -EINPROGRESS,
 * a non-CRYPTO_SYNC context hands that value straight back to the caller,
 * while a CRYPTO_SYNC context waits on the request's completion: a negative
 * wait result is logged as a timeout and returned, otherwise the status
 * recorded by kpp_complete() is returned.
 */
int kpp_generate_key_pair(struct kpp_req *req)
{
	int rc = req->ctx->alg->generate_key_pair(req);

	if (rc != -EINPROGRESS)
		return rc;

	if (!(req->ctx->flags & CRYPTO_SYNC))
		return rc;

	rc = wait_for_completion_timeout(KPP_REQ_GET_COMPLETION(req),
					 COMPLETION_TIMEOUT_SECS);
	if (rc < 0) {
		ycc_err("Generate key pair result timeout\n");
		return rc;
	}

	return ((struct kpp_result *)req->data)->err;
}

/*
 * Run the algorithm's compute_shared_secret hook on @req.
 *
 * Any result other than -EINPROGRESS is returned as is, matching
 * kpp_generate_public_key() and kpp_generate_key_pair(). For -EINPROGRESS,
 * a non-CRYPTO_SYNC context hands that value back to the caller, while a
 * CRYPTO_SYNC context waits on the request's completion: a negative wait
 * result is logged as a timeout and returned, otherwise the status recorded
 * by kpp_complete() is returned.
 */
int kpp_compute_shared_secret(struct kpp_req *req)
{
	struct completion *comp;
	struct kpp_alg *alg;
	int ret;

	alg = req->ctx->alg;
	ret = alg->compute_shared_secret(req);

	/*
	 * Only -EINPROGRESS says the request is still outstanding. For any
	 * other result, waiting could block on a completion that is never
	 * signalled and would replace the driver's return code.
	 */
	if (ret != -EINPROGRESS)
		return ret;

	if (req->ctx->flags & CRYPTO_SYNC) {
		comp = KPP_REQ_GET_COMPLETION(req);
		ret = wait_for_completion_timeout(comp, COMPLETION_TIMEOUT_SECS);
		if (ret < 0)
			ycc_err("Compute shared key result timeout\n");
		else
			ret = ((struct kpp_result *)req->data)->err;
	}

	return ret;
}

unsigned int kpp_maxsize(struct kpp_ctx *ctx)
{
	return ctx->alg->max_size(ctx);
}

/* Weak no-op default; a non-weak definition elsewhere takes precedence. */
__attribute__((weak)) void kpp_register_algs(void)
{
}
/* Weak no-op default; a non-weak definition elsewhere takes precedence. */
__attribute__((weak)) void kpp_unregister_algs(void)
{
}
