// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>

#include "../utils/utils.h"
#include "aead.h"

/*
 * Head of the singly linked list of AEAD algorithm entries that
 * find_aead_alg() walks through each entry's ->next pointer.
 * NOTE(review): the code that populates this list is not in this file;
 * presumably aead_register_algs() in the algorithm driver - confirm.
 */
struct aead_alg_entry *aead_alg_entries;

/*
 * Per-request bookkeeping for synchronous operation. aead_alloc_req()
 * allocates one of these into req->data when the context has CRYPTO_SYNC set.
 */
struct aead_result {
	struct completion completion;	/* signalled by aead_complete() */
	int err;			/* status passed to aead_complete() */
};

/*
 * Yield a pointer to the completion embedded in the struct aead_result that
 * req->data points to. Only meaningful when req->data really is a
 * struct aead_result, i.e. for requests set up by aead_alloc_req() on a
 * CRYPTO_SYNC context.
 *
 * The definition must not end in a line-continuation backslash: a trailing
 * one would splice whatever follows the macro into its expansion.
 */
#define AEAD_REQ_GET_COMPLETION(req)					\
	(&((struct aead_result *)(req)->data)->completion)

/*
 * Completion callback installed on synchronous requests.
 *
 * Records the final status in the request's struct aead_result and then
 * signals its completion. The status is stored before complete() is called
 * so that it is already in place when a waiter is woken.
 */
static void aead_complete(struct aead_req *req, int err)
{
	struct aead_result *res = (struct aead_result *)req->data;

	res->err = err;
	complete(&res->completion);
}

/*
 * Look up an algorithm in the aead_alg_entries list.
 *
 * Returns the first entry whose name is a prefix of alg_name, or NULL when
 * no entry matches. alg_name must not be NULL.
 *
 * NOTE(review): only strlen(entry name) characters are compared, so a
 * request such as "<name>suffix" also matches "<name>" - confirm whether
 * exact matching was intended.
 */
static inline struct aead_alg *find_aead_alg(const char *alg_name)
{
	struct aead_alg_entry *entry;

	for (entry = aead_alg_entries; entry; entry = entry->next) {
		const char *name = entry->alg->name;

		if (strncmp(alg_name, name, strlen(name)) == 0)
			return entry->alg;
	}

	return NULL;
}

/*
 * Allocate a context for the algorithm found by find_aead_alg(alg_name).
 *
 * The context is allocated zero-filled together with alg->ctxsize extra
 * bytes, bound to the algorithm, passed to the algorithm's optional init()
 * hook, and finally has @flag OR-ed into ctx->flags.
 *
 * Returns the new context, or NULL if the algorithm is unknown, the
 * allocation fails, or init() returns non-zero. Release the context with
 * aead_free_ctx().
 */
struct aead_ctx *aead_alloc_ctx(const char *alg_name, uint32_t flag)
{
	struct aead_alg *alg = find_aead_alg(alg_name);
	struct aead_ctx *ctx;

	if (!alg) {
		ycc_err("Failed to find alg:%s\n", alg_name);
		return NULL;
	}

	/* calloc() returns the context and its trailing area already zeroed */
	ctx = calloc(1, sizeof(struct aead_ctx) + alg->ctxsize);
	if (!ctx) {
		ycc_err("Failed to alloc ctx for alg:%s\n", alg_name);
		return NULL;
	}

	ctx->alg = alg;

	if (alg->init) {
		if (alg->init(ctx)) {
			ycc_err("Failed to do init for alg:%s\n", alg_name);
			free(ctx);
			return NULL;
		}
	}

	/* Applied after init() so flags set by the hook are kept */
	ctx->flags |= flag;

	/* reqsize & __ctx are initialized by the real algorithm driver */
	return ctx;
}

/*
 * Release a context obtained from aead_alloc_ctx().
 *
 * Calls the algorithm's optional exit() hook before freeing the context.
 * A NULL ctx is ignored.
 */
void aead_free_ctx(struct aead_ctx *ctx)
{
	if (!ctx)
		return;

	if (ctx->alg->exit)
		ctx->alg->exit(ctx);

	free(ctx);
}

/*
 * Return ctx->reqsize, the number of extra bytes aead_alloc_req() allocates
 * after struct aead_req for each request on this context.
 */
static inline uint32_t aead_reqsize(struct aead_ctx *ctx)
{
	return ctx->reqsize;
}

/*
 * Allocate a zero-filled request bound to @ctx.
 *
 * The request is followed by aead_reqsize(ctx) extra bytes. When the context
 * has CRYPTO_SYNC set, a struct aead_result is also allocated into req->data,
 * aead_complete() is installed as the completion callback, and the embedded
 * completion is initialised so aead_encrypt()/aead_decrypt() can wait on it.
 *
 * Returns the new request, or NULL if @ctx is NULL or an allocation fails.
 * Release the request with aead_free_req().
 */
struct aead_req *aead_alloc_req(struct aead_ctx *ctx)
{
	struct aead_req *req;

	if (!ctx)
		return NULL;

	req = calloc(1, sizeof(struct aead_req) + aead_reqsize(ctx));
	/* Bail out before the SYNC setup below can touch a NULL request */
	if (!req)
		return NULL;

	req->ctx = ctx;

	if (!(ctx->flags & CRYPTO_SYNC))
		return req;

	/* Zeroed so that result->err starts from a defined value */
	req->data = calloc(1, sizeof(struct aead_result));
	if (!req->data) {
		free(req);
		return NULL;
	}

	req->complete = aead_complete;
	init_completion(AEAD_REQ_GET_COMPLETION(req));
	return req;
}

/*
 * Release a request obtained from aead_alloc_req().
 *
 * For CRYPTO_SYNC contexts, req->data is the struct aead_result allocated by
 * aead_alloc_req(); its completion is destroyed and the result is freed.
 * For asynchronous contexts, req->data is whatever the caller passed to
 * aead_req_set_callback(). It is owned by the caller and is not a
 * struct aead_result, so it is neither destroyed nor freed here.
 *
 * A NULL req is ignored.
 */
void aead_free_req(struct aead_req *req)
{
	if (!req)
		return;

	if ((req->ctx->flags & CRYPTO_SYNC) && req->data) {
		destroy_completion(AEAD_REQ_GET_COMPLETION(req));
		free(req->data);
	}

	free(req);
}

/*
 * Set the key on @ctx through the algorithm's setkey() hook.
 *
 * Returns 0 on success. Returns -1 if @ctx or @key is NULL, if the algorithm
 * has no setkey() hook, or if the hook returns a negative value.
 */
int aead_setkey(struct aead_ctx *ctx, const char *key, unsigned int keylen)
{
	struct aead_alg *alg;
	int rc;

	if (!ctx || !key)
		return -1;

	alg = ctx->alg;
	if (alg->setkey == NULL) {
		ycc_err("No setkey function for alg:%s\n", alg->name);
		return -1;
	}

	rc = alg->setkey(ctx, key, keylen);
	if (rc >= 0)
		return 0;

	ycc_err("Failed to setkey for alg:%s\n", alg->name);
	return -1;
}

int aead_setauthsize(struct aead_ctx *ctx, unsigned int authsize)
{
	if (authsize > ctx->alg->maxauthsize)
		return -1;

	if (ctx->alg->setauthsize)
		return ctx->alg->setauthsize(ctx, authsize);

	ctx->authsize = authsize;
	return 0;
}

/*
 * Install a completion callback and its opaque argument on an asynchronous
 * request.
 *
 * Does nothing for CRYPTO_SYNC contexts, where aead_alloc_req() has already
 * set req->complete and req->data for its own use.
 */
void aead_req_set_callback(struct aead_req *req, aead_completion_t complete, void *data)
{
	if (!(req->ctx->flags & CRYPTO_SYNC)) {
		req->complete = complete;
		req->data = data;
	}
}

/*
 * Record the source buffer, destination buffer, crypt length and IV on the
 * request. Only the pointers are stored; nothing is copied.
 */
void aead_set_req(struct aead_req *req, unsigned char *src, unsigned char *dst,
		  unsigned int cryptlen, unsigned char *iv)
{
	/* length first, then the three buffer pointers */
	req->cryptlen = cryptlen;

	req->src = src;
	req->dst = dst;
	req->iv = iv;
}

/* Record the associated-data length (assoclen) on the request. */
void aead_set_req_ad(struct aead_req *req, unsigned int assoclen)
{
	req->assoclen = assoclen;
}

/*
 * Submit @req to the algorithm's encrypt() hook.
 *
 * Any return value other than -EINPROGRESS is passed straight back. On
 * -EINPROGRESS:
 *   - asynchronous contexts (the default) return -EINPROGRESS;
 *   - CRYPTO_SYNC contexts wait up to COMPLETION_TIMEOUT_SECS on the
 *     request's completion. A negative wait result is logged as a timeout
 *     and returned; otherwise the status recorded by aead_complete() is
 *     returned.
 */
int aead_encrypt(struct aead_req *req)
{
	int rc;

	rc = req->ctx->alg->encrypt(req);
	if (rc != -EINPROGRESS)
		return rc;

	/* Default is ASYNC */
	if (!(req->ctx->flags & CRYPTO_SYNC))
		return rc;

	rc = wait_for_completion_timeout(AEAD_REQ_GET_COMPLETION(req),
					 COMPLETION_TIMEOUT_SECS);
	if (rc < 0) {
		ycc_err("Encrypt result timeout\n");
		return rc;
	}

	return ((struct aead_result *)req->data)->err;
}

/*
 * Submit @req to the algorithm's decrypt() hook.
 *
 * Any return value other than -EINPROGRESS is passed straight back. On
 * -EINPROGRESS:
 *   - asynchronous contexts (the default) return -EINPROGRESS;
 *   - CRYPTO_SYNC contexts wait up to COMPLETION_TIMEOUT_SECS on the
 *     request's completion. A negative wait result is logged as a timeout
 *     and returned; otherwise the status recorded by aead_complete() is
 *     returned.
 */
int aead_decrypt(struct aead_req *req)
{
	int rc;

	rc = req->ctx->alg->decrypt(req);
	if (rc != -EINPROGRESS)
		return rc;

	/* Default is ASYNC */
	if (!(req->ctx->flags & CRYPTO_SYNC))
		return rc;

	rc = wait_for_completion_timeout(AEAD_REQ_GET_COMPLETION(req),
					 COMPLETION_TIMEOUT_SECS);
	if (rc < 0) {
		ycc_err("Decrypt result timeout\n");
		return rc;
	}

	return ((struct aead_result *)req->data)->err;
}

/*
 * Weak no-op defaults. A non-weak definition of either symbol elsewhere in
 * the program overrides the stub at link time.
 * NOTE(review): presumably the algorithm driver supplies the real versions
 * that fill and empty aead_alg_entries - confirm.
 */
void __attribute__((weak)) aead_register_algs(void) {}
void __attribute__((weak)) aead_unregister_algs(void) {}
