// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>

#include "../utils/utils.h"
#include "ske.h"

/*
 * Head of the singly linked list of algorithm entries that
 * find_skcipher_alg() walks (via ->next) when resolving a name.
 * NOTE(review): nothing in this file adds entries; presumably the list is
 * populated by the driver's skcipher_register_algs() - confirm.
 */
struct skcipher_alg_entry *skcipher_alg_entries;

/*
 * Per-request bookkeeping for CRYPTO_SYNC contexts. One instance is
 * allocated by skcipher_alloc_req() and stored in req->data; the internal
 * completion callback fills in @err and signals @completion.
 */
struct skcipher_result {
	/* signalled by skcipher_complete(), waited on by encrypt/decrypt */
	struct completion completion;
	/* status handed to skcipher_complete() */
	int err;
};

/*
 * Address of the completion embedded in the struct skcipher_result that
 * req->data points to. Only meaningful when req->data really holds a
 * struct skcipher_result, i.e. for requests allocated on a CRYPTO_SYNC
 * context. The final line deliberately has no trailing backslash so the
 * definition cannot swallow whatever line follows it.
 */
#define SKE_REQ_GET_COMPLETION(req)					\
	(&((struct skcipher_result *)(req)->data)->completion)

/*
 * Completion callback installed on synchronous requests.
 *
 * @req: request whose ->data points to a struct skcipher_result
 * @err: status to hand back to the waiter
 *
 * The status is stored before the completion is signalled, so the waiter
 * in skcipher_encrypt()/skcipher_decrypt() reads it after being woken.
 */
static void skcipher_complete(struct skcipher_req *req, int err)
{
	((struct skcipher_result *)req->data)->err = err;
	complete(SKE_REQ_GET_COMPLETION(req));
}

/*
 * Look up an algorithm by name in the skcipher_alg_entries list.
 *
 * @alg_name: NUL-terminated name to resolve (must not be NULL)
 *
 * Returns the first algorithm whose name matches, or NULL if the list
 * holds no match.
 *
 * NOTE(review): the comparison only covers strlen(<entry name>) bytes, so
 * any @alg_name that merely *starts with* an entry's name is accepted as
 * well. Confirm whether prefix matching is intended before tightening it
 * to an exact comparison.
 */
static inline struct skcipher_alg *find_skcipher_alg(const char *alg_name)
{
	struct skcipher_alg_entry *entry;

	for (entry = skcipher_alg_entries; entry; entry = entry->next) {
		if (strncmp(alg_name, entry->alg->name,
			    strlen(entry->alg->name)) == 0)
			return entry->alg;
	}

	return NULL;
}

/*
 * Allocate a cipher context for the algorithm called @alg_name.
 *
 * @alg_name: name passed to find_skcipher_alg()
 * @flag:     bits OR-ed into ctx->flags once the algorithm's init hook
 *            (if any) has succeeded
 *
 * The context is allocated zero-filled with alg->ctxsize extra bytes
 * appended to it. Returns the new context, or NULL when the algorithm is
 * unknown, the allocation fails, or the init hook reports an error.
 * Release it with skcipher_free_ctx().
 */
struct skcipher_ctx *skcipher_alloc_ctx(const char *alg_name, uint32_t flag)
{
	struct skcipher_alg *alg = find_skcipher_alg(alg_name);
	struct skcipher_ctx *ctx;

	if (!alg) {
		ycc_err("Failed to find alg:%s\n", alg_name);
		return NULL;
	}

	/* calloc gives the same zero-filled block as malloc + memset */
	ctx = calloc(1, sizeof(struct skcipher_ctx) + alg->ctxsize);
	if (!ctx) {
		ycc_err("Failed to alloc ctx for alg:%s\n", alg_name);
		return NULL;
	}

	ctx->alg = alg;

	if (alg->init) {
		if (alg->init(ctx)) {
			ycc_err("Failed to do init for alg:%s\n", alg_name);
			free(ctx);
			return NULL;
		}
	}

	/* Caller flags are merged only after init has run. */
	ctx->flags |= flag;

	/* reqsize and the private area are not set here; the algorithm driver fills them in */
	return ctx;
}

/*
 * Release a context obtained from skcipher_alloc_ctx().
 *
 * @ctx: context to destroy; NULL is accepted and ignored
 *
 * The algorithm's exit hook, when present, runs before the memory is
 * freed.
 */
void skcipher_free_ctx(struct skcipher_ctx *ctx)
{
	if (ctx) {
		if (ctx->alg->exit)
			ctx->alg->exit(ctx);

		free(ctx);
	}
}

/*
 * Return ctx->reqsize, which skcipher_alloc_req() adds to
 * sizeof(struct skcipher_req) when sizing a request allocation.
 */
static inline uint32_t skcipher_reqsize(struct skcipher_ctx *ctx)
{
	return ctx->reqsize;
}

/*
 * Allocate a request bound to @ctx.
 *
 * @ctx: context the request will run on (must not be NULL)
 *
 * The request is zero-filled and sized as sizeof(struct skcipher_req) plus
 * skcipher_reqsize(ctx). For CRYPTO_SYNC contexts a struct skcipher_result
 * is additionally allocated into req->data, its completion is initialised
 * and skcipher_complete() is installed as the callback. For other contexts
 * req->data and req->complete stay zero until the caller sets them.
 *
 * Returns the request, or NULL if either allocation fails. Release it
 * with skcipher_free_req().
 */
struct skcipher_req *skcipher_alloc_req(struct skcipher_ctx *ctx)
{
	struct skcipher_req *req;

	req = calloc(1, sizeof(struct skcipher_req) + skcipher_reqsize(ctx));
	if (!req)
		return NULL;

	req->ctx = ctx;

	if (ctx->flags & CRYPTO_SYNC) {
		req->data = malloc(sizeof(struct skcipher_result));
		if (!req->data) {
			free(req);
			return NULL;
		}

		req->complete = skcipher_complete;
		init_completion(SKE_REQ_GET_COMPLETION(req));
	}

	return req;
}

/*
 * Release a request obtained from skcipher_alloc_req().
 *
 * @req: request to destroy; NULL is accepted and ignored
 *
 * Only requests on a CRYPTO_SYNC context own req->data: there it is the
 * struct skcipher_result allocated by skcipher_alloc_req(), so its
 * completion is destroyed and the memory freed. On any other context
 * req->data is the caller's opaque pointer installed through
 * skcipher_req_set_callback(); it is neither a completion nor ours to
 * free, so it is left untouched.
 */
void skcipher_free_req(struct skcipher_req *req)
{
	if (!req)
		return;

	if ((req->ctx->flags & CRYPTO_SYNC) && req->data) {
		destroy_completion(SKE_REQ_GET_COMPLETION(req));
		free(req->data);
	}

	free(req);
}

/*
 * Install a key on a cipher context.
 *
 * @ctx:    context to configure
 * @key:    key bytes
 * @keylen: number of bytes in @key; must lie within
 *          [alg->min_keysize, alg->max_keysize]
 *
 * Returns 0 on success. Returns -1 when @ctx or @key is NULL, @keylen is
 * out of range, the algorithm has no setkey hook, or the hook returns a
 * negative value (the hook's own code is not propagated).
 */
int skcipher_setkey(struct skcipher_ctx *ctx, const unsigned char *key,
		    unsigned int keylen)
{
	struct skcipher_alg *alg;
	int ret;

	if (!ctx || !key)
		return -1;

	alg = ctx->alg;
	if (keylen < alg->min_keysize || keylen > alg->max_keysize) {
		/* keylen is unsigned int, so it is printed with %u */
		ycc_err("Invalid key length:%u for alg:%s\n", keylen, alg->name);
		return -1;
	}

	if (!alg->setkey) {
		ycc_err("No setkey function for alg:%s\n", alg->name);
		return -1;
	}

	ret = alg->setkey(ctx, key, keylen);
	if (ret < 0) {
		ycc_err("Failed to setkey for alg:%s\n", alg->name);
		return -1;
	}

	return 0;
}

/*
 * Attach a completion callback and its opaque argument to a request.
 *
 * @req:      request to update
 * @complete: callback to store in req->complete
 * @data:     pointer to store in req->data
 *
 * Requests on a CRYPTO_SYNC context are left unchanged: their callback
 * and data slots were already filled in by skcipher_alloc_req().
 */
void skcipher_req_set_callback(struct skcipher_req *req,
			       skcipher_completion_t complete, void *data)
{
	if (!(req->ctx->flags & CRYPTO_SYNC)) {
		req->complete = complete;
		req->data = data;
	}
}

/*
 * Record the buffers and length for the next operation on @req.
 *
 * @req:      request to fill in
 * @src:      stored in req->src
 * @dst:      stored in req->dst
 * @cryptlen: stored in req->cryptlen
 * @iv:       stored in req->iv
 *
 * Only the pointers are stored; nothing is copied or validated here.
 */
void skcipher_set_req(struct skcipher_req *req, unsigned char *src,
		      unsigned char *dst, unsigned int cryptlen, unsigned char *iv)
{
	req->iv = iv;
	req->cryptlen = cryptlen;
	req->dst = dst;
	req->src = src;
}

/*
 * Submit @req to the algorithm's encrypt hook.
 *
 * Any hook result other than -EINPROGRESS is returned unchanged. When the
 * hook reports -EINPROGRESS:
 *  - on a context without CRYPTO_SYNC, -EINPROGRESS is returned as is;
 *  - on a CRYPTO_SYNC context, the call blocks in
 *    wait_for_completion_timeout(..., COMPLETION_TIMEOUT_SECS). A negative
 *    wait result is logged and returned; otherwise the status recorded by
 *    the completion callback is returned.
 */
int skcipher_encrypt(struct skcipher_req *req)
{
	int rc;

	rc = req->ctx->alg->encrypt(req);
	if (rc != -EINPROGRESS)
		return rc;

	/* Asynchronous contexts hand -EINPROGRESS straight back. */
	if (!(req->ctx->flags & CRYPTO_SYNC))
		return rc;

	rc = wait_for_completion_timeout(SKE_REQ_GET_COMPLETION(req),
					 COMPLETION_TIMEOUT_SECS);
	if (rc < 0) {
		ycc_err("Encrypt result timeout\n");
		return rc;
	}

	return ((struct skcipher_result *)req->data)->err;
}

/*
 * Submit @req to the algorithm's decrypt hook.
 *
 * Any hook result other than -EINPROGRESS is returned unchanged. When the
 * hook reports -EINPROGRESS:
 *  - on a context without CRYPTO_SYNC, -EINPROGRESS is returned as is;
 *  - on a CRYPTO_SYNC context, the call blocks in
 *    wait_for_completion_timeout(..., COMPLETION_TIMEOUT_SECS). A negative
 *    wait result is logged and returned; otherwise the status recorded by
 *    the completion callback is returned.
 */
int skcipher_decrypt(struct skcipher_req *req)
{
	int rc;

	rc = req->ctx->alg->decrypt(req);
	if (rc != -EINPROGRESS)
		return rc;

	/* Asynchronous contexts hand -EINPROGRESS straight back. */
	if (!(req->ctx->flags & CRYPTO_SYNC))
		return rc;

	rc = wait_for_completion_timeout(SKE_REQ_GET_COMPLETION(req),
					 COMPLETION_TIMEOUT_SECS);
	if (rc < 0) {
		ycc_err("Decrypt result timeout\n");
		return rc;
	}

	return ((struct skcipher_result *)req->data)->err;
}

/*
 * Weak no-op default. Because the symbol is weak, a non-weak definition
 * of skcipher_register_algs() elsewhere in the program takes precedence
 * at link time.
 * NOTE(review): presumably the algorithm driver provides that definition
 * to populate skcipher_alg_entries - confirm.
 */
void __attribute__((weak)) skcipher_register_algs(void) {}
/*
 * Weak no-op default. Because the symbol is weak, a non-weak definition
 * of skcipher_unregister_algs() elsewhere in the program takes precedence
 * at link time.
 * NOTE(review): presumably the algorithm driver provides that definition
 * to tear down what its register counterpart set up - confirm.
 */
void __attribute__((weak)) skcipher_unregister_algs(void) {}
