// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>

#include "ycc_uio.h"
#include "utils.h"
#include "pke.h"

struct akcipher_test {
	const char *name;
	const char *key;
	const char *priv;
	const char *msg;
	const char *sig;
	unsigned int key_len;
	unsigned int priv_len;
	unsigned int msg_len;
	unsigned int sig_len;
};

/*
 * Input for the sign/verify sample in main(): an 8-byte message. The fields
 * not listed here are zero-initialized; the private key is read from a PEM
 * file at run time instead.
 */
static const struct akcipher_test sign = {
	.name = "rsa-sign",
	.msg = "\x54\x85\x9b\x34\x2c\x49\xea\x2a",
	.msg_len = 8,
};

/* Set by main() once ycc_drv_init() has succeeded; checked by drv_exit(). */
static bool inited;

/*
 * Destructor: releases the driver via ycc_drv_exit(). It does nothing unless
 * main() recorded a successful ycc_drv_init() in the 'inited' flag.
 */
static void __attribute__((destructor)) drv_exit(void)
{
	if (!inited)
		return;

	ycc_drv_exit();
}

/*
 * Load the RSA private key stored in ./test_private.pem.
 *
 * @rsa: in/out pointer handed to PEM_read_RSAPrivateKey(); on success *rsa
 *       refers to the parsed key.
 *
 * Returns 0 on success. Returns the errno value left by fopen() when the file
 * cannot be opened, or EINVAL when @rsa is NULL or the file does not contain
 * a readable RSA private key.
 */
static int rsa_read_privkey(RSA **rsa)
{
	RSA *key;
	FILE *fp;

	if (!rsa)
		return EINVAL;

	fp = fopen("./test_private.pem", "rb");
	if (!fp)
		return errno ? errno : EINVAL;	/* never report success here */

	/* A NULL return is the only failure indication, so it must be checked. */
	key = PEM_read_RSAPrivateKey(fp, rsa, NULL, NULL);
	fclose(fp);

	return key ? 0 : EINVAL;
}

/*
 * Sample entry point: initializes the driver, reads the RSA private key from
 * disk, signs the fixed message in 'sign' with PKCS#1 padding and then runs
 * verify on the same request.
 *
 * Returns 0 when both sign and verify succeed, a non-zero value otherwise.
 * Driver teardown is left to the drv_exit() destructor.
 */
int main(int argc, char *argv[])
{
	struct akcipher_ctx *cipher;
	struct akcipher_req *req;
	uint8_t src[128] = {0};
	uint8_t dst[128] = {0};
	/*
	 * Start from NULL so the PEM reader allocates the key itself; a failed
	 * read then leaves rsa NULL and is caught by the check below instead
	 * of continuing with an empty pre-allocated key.
	 */
	RSA *rsa = NULL;
	int ret;

	ret = ycc_drv_init(1);
	if (ret < 0)
		return ret;

	inited = true;

	ret = rsa_read_privkey(&rsa);
	if (ret || !rsa) {
		printf("Failed to read rsa private key\n");
		/* Never exit with status 0 when no key was loaded. */
		if (!ret)
			ret = -1;
		goto out;
	}

	ret = -1;
	cipher = akcipher_alloc_ctx("rsa", CRYPTO_SYNC);
	if (!cipher)
		goto out;

	req = akcipher_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	ret = akcipher_set_priv_key(cipher, (void *)rsa, 0);
	if (ret < 0) {
		printf("Faile to set key\n");
		goto free_req;
	}

	/*
	 * akcipher_maxsize() is handed to akcipher_set_req() as the length of
	 * dst, so it must not exceed the stack buffer (128 bytes, i.e. a
	 * 1024-bit modulus).
	 * NOTE(review): assumes akcipher_maxsize() reports bytes - confirm.
	 */
	if (akcipher_maxsize(cipher) > sizeof(dst)) {
		printf("RSA key too large for the %u-byte sample buffer\n",
		       (unsigned int)sizeof(dst));
		ret = -1;
		goto free_req;
	}

	memcpy(src, sign.msg, sign.msg_len);
	akcipher_set_req(req, src, dst, sign.msg_len, akcipher_maxsize(cipher), RSA_PKCS1_PADDING);
	hex_dump("Dump rsa sign src:", src, sign.msg_len);

	ret = akcipher_sign(req);
	if (ret) {
		/* Do not verify (and overwrite ret) after a failed sign. */
		printf("Sample code rsa sign failed\n");
		goto free_req;
	}
	printf("Sample code rsa sign passed\n");
	hex_dump("Dump rsa sign Dst", dst, sizeof(dst));

	ret = akcipher_verify(req);
	if (!ret)
		printf("Sample code rsa verify passed\n");
	else
		printf("Sample code rsa verify failed\n");

free_req:
	akcipher_free_req(req);
free_ctx:
	akcipher_free_ctx(cipher);
out:
	/*
	 * NOTE(review): rsa is never released with RSA_free(). Whether the
	 * cipher context keeps a reference to it after akcipher_set_priv_key()
	 * is not visible here - confirm ownership before adding the free.
	 */
	return ret;
}
