// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>

#include "ycc_uio.h"
#include "utils.h"
#include "pke.h"

/*
 * One asymmetric-cipher test vector.
 *
 * The byte-string fields hold binary data rather than text, so every one
 * of them is paired with an explicit length field instead of relying on
 * NUL termination.
 */
struct akcipher_test {
	const char *name;	/* human-readable vector label */
	const char *key;	/* key material; not used by this sample */
	const char *priv;	/* private key material; not used by this sample */
	const char *msg;	/* message bytes */
	const char *sig;	/* signature/ciphertext bytes */
	unsigned int key_len;	/* length of @key in bytes */
	unsigned int priv_len;	/* length of @priv in bytes */
	unsigned int msg_len;	/* length of @msg in bytes */
	unsigned int sig_len;	/* length of @sig in bytes */
};

/*
 * RSA PKCS#1 v1.5 test vector.
 *
 * main() copies @sig (128 bytes) into the request source buffer and
 * decrypts it with RSA_PKCS1_PADDING using the private key loaded from
 * ./test_private.pem. @msg is presumably the expected plaintext, but this
 * sample does not compare against it. A 128-byte input suggests a
 * 1024-bit modulus (TODO confirm against test_private.pem).
 */
static const struct akcipher_test pkcs = {
	.name = "rsa-pkcs",
	.msg = "\x54\x85\x9b\x34\x2c\x49\xea\x2a",
	.msg_len = 8,
	.sig = "\x9e\xa7\xdf\x1e\xb4\xa1\x25\xdf"
	       "\x6b\xa5\xec\xa5\xc8\x4f\x30\x94"
	       "\xb2\x91\xea\x3d\x0e\x75\x39\xc7"
	       "\x5e\x8b\x0c\x3e\x7e\x76\x54\x47"
	       "\xa8\x27\x62\xc6\x25\xa7\x33\x29"
	       "\xa8\x9b\x94\xa3\xe6\xfa\x8b\x7c"
	       "\x37\x37\x43\x20\xaf\xe1\x71\x6a"
	       "\x32\xfb\x0a\x60\x47\xb9\x30\xec"
	       "\xc1\xa8\xa5\x95\x82\x03\xee\x3d"
	       "\x05\xc8\xb3\x7d\x10\xae\x4b\xcf"
	       "\x66\x07\x29\x6c\xec\x5d\x09\x4c"
	       "\xb3\xd7\xc8\xff\x9a\x8d\x73\xe7"
	       "\x82\x16\xfa\x9d\x0b\xd0\xe9\x7c"
	       "\xba\x2a\x1c\x5b\xb2\xba\x88\xef"
	       "\xd8\x99\x4b\xbd\x7d\x35\xa4\x20"
	       "\xee\x5a\x81\x99\xb5\x1a\xa2\x51",
	.sig_len = 128,
};

/* Set by main() once ycc_drv_init() has succeeded. */
static bool inited = false;

/*
 * Runs automatically at process exit. The driver must be torn down
 * whenever it was initialized successfully, and only then, so the
 * @inited flag gates the call to ycc_drv_exit().
 */
static void __attribute__((destructor)) drv_exit(void)
{
	if (!inited)
		return;

	ycc_drv_exit();
}

/*
 * Load the RSA private key from ./test_private.pem.
 *
 * @rsa: in/out key handle, passed straight to PEM_read_RSAPrivateKey().
 *       It may point to NULL, in which case the key is allocated for the
 *       caller.
 *
 * Returns 0 on success, the errno set by fopen() if the file cannot be
 * opened, or EINVAL if the file does not contain a parsable RSA private
 * key.
 */
static int rsa_read_privkey(RSA **rsa)
{
	FILE *fp;
	RSA *key;

	fp = fopen("./test_private.pem", "rb");
	if (!fp)
		return errno;

	/*
	 * PEM_read_RSAPrivateKey() returns NULL on failure. Ignoring that
	 * would report success and leave the caller with whatever *rsa held
	 * before, e.g. an empty key from RSA_new().
	 */
	key = PEM_read_RSAPrivateKey(fp, rsa, NULL, NULL);
	fclose(fp);

	return key ? 0 : EINVAL;
}

/*
 * Sample: RSA PKCS#1 v1.5 private-key decryption through the YCC driver.
 *
 * Initializes the driver, loads ./test_private.pem, decrypts the
 * pkcs.sig test vector with RSA_PKCS1_PADDING, and dumps the source and
 * destination buffers.
 *
 * Returns 0 on success. On failure it returns non-zero: a negative driver
 * error, a positive errno from the key loader, or -1.
 */
int main(int argc, char *argv[])
{
	struct akcipher_ctx *cipher;
	struct akcipher_req *req;
	uint8_t src[128] = {0};
	uint8_t dst[128] = {0};
	/*
	 * Start from NULL so PEM_read_RSAPrivateKey() allocates the key. On a
	 * parse failure rsa then stays NULL and the check below catches it;
	 * a key pre-allocated with RSA_new() would mask the failure.
	 */
	RSA *rsa = NULL;
	int maxsize;
	int ret;

	ret = ycc_drv_init(1);
	if (ret < 0)
		return ret;

	/* Tells the destructor that ycc_drv_exit() must run at exit. */
	inited = true;

	ret = rsa_read_privkey(&rsa);
	if (ret || !rsa) {
		printf("Failed to read rsa private key\n");
		/* Never report success when no key was loaded. */
		if (!ret)
			ret = -1;
		goto free_rsa;
	}

	ret = -1;
	cipher = akcipher_alloc_ctx("rsa", CRYPTO_SYNC);
	if (!cipher)
		goto free_rsa;

	req = akcipher_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	ret = akcipher_set_priv_key(cipher, (void *)rsa, 0);
	if (ret < 0) {
		printf("Failed to set key\n");
		goto free_req;
	}

	/*
	 * maxsize is handed to the driver as the capacity of dst. Refuse keys
	 * whose output would not fit in the fixed-size buffer; otherwise a
	 * larger key in the PEM file would let the driver overrun dst.
	 */
	maxsize = akcipher_maxsize(cipher);
	if (maxsize <= 0 || (size_t)maxsize > sizeof(dst)) {
		printf("Unsupported rsa key size\n");
		ret = -1;
		goto free_req;
	}

	memcpy(src, pkcs.sig, pkcs.sig_len);
	akcipher_set_req(req, src, dst, pkcs.sig_len, maxsize, RSA_PKCS1_PADDING);

	ret = akcipher_decrypt(req);
	hex_dump("Dump rsa-pkcs-decrypt src:", src, pkcs.sig_len);
	hex_dump("Dump rsa-pkcs-decrypt dst:", dst, sizeof(dst));
	if (!ret)
		printf("Sample code rsa pkcs decrypt passed\n");
	else
		printf("Sample code rsa pkcs decrypt failed\n");

free_req:
	akcipher_free_req(req);
free_ctx:
	akcipher_free_ctx(cipher);
free_rsa:
	/*
	 * The key is released only after the ctx, in case the driver still
	 * references it until then. RSA_free(NULL) is a no-op.
	 */
	RSA_free(rsa);
	return ret;
}
