// SPDX-License-Identifier: GPL-2.0
/*
 * Testmgr: Testcases for YCC supported algorithms.
 *
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#define _GNU_SOURCE
#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <sys/sysinfo.h>
#include <pthread.h>
#include <string.h>
#include <errno.h>

#include "testmgr.h"
#include "ycc_uio.h"
#include "utils.h"
#include "pke.h"
#include "aead.h"
#include "ske.h"
#include "rng.h"
#include "kpp.h"

/* True once main() has brought the driver up successfully. */
static bool inited;

/*
 * Destructor hook: shut the driver down on process exit, but only when
 * it was successfully initialized beforehand.
 */
static void __attribute__((destructor)) drv_exit(void)
{
	if (!inited)
		return;

	ycc_drv_exit();
}

/*
 * Verify the signature stored in the "ecdsa-nist-p521" test vector.
 *
 * Returns -1 when the vector is missing or the context/request cannot be
 * allocated, the negative value from akcipher_set_pub_key() when the key
 * is rejected, and otherwise whatever akcipher_verify() returns (the
 * caller compares it against 0).
 */
static int test_ecdsa_nist_p521(void)
{
	const struct akcipher_testvec *tv = NULL;
	struct akcipher_ctx *ctx;
	struct akcipher_req *request;
	unsigned char msg[256];
	unsigned char sig[256];
	unsigned char pubkey[256];
	int rc = -1;
	int idx;

	/* The scan does not stop early: the last entry with this name wins. */
	for (idx = 0; idx < ARRAY_SIZE(akcipher_tv); idx++) {
		if (strcmp(akcipher_tv[idx].name, "ecdsa-nist-p521") == 0)
			tv = &akcipher_tv[idx];
	}

	if (tv == NULL)
		return rc;

	ctx = akcipher_alloc_ctx("ecdsa-nist-p521", CRYPTO_SYNC);
	if (ctx == NULL)
		return rc;

	request = akcipher_alloc_req(ctx);
	if (request == NULL)
		goto release_ctx;

	/* Work on local copies of the vector's key, signature and message. */
	memcpy(pubkey, tv->key, tv->key_len);
	memcpy(sig, tv->sig, tv->sig_len);
	memcpy(msg, tv->msg, tv->msg_len);

	akcipher_set_req(request, msg, sig, tv->msg_len, tv->sig_len, 0);

	rc = akcipher_set_pub_key(ctx, pubkey, tv->key_len);
	if (rc >= 0) {
		/* The verification result is judged by the caller. */
		rc = akcipher_verify(request);
	}

	akcipher_free_req(request);
release_ctx:
	akcipher_free_ctx(ctx);
	return rc;
}

/*
 * Sign the "ecdsa-nist-p256" test vector's message with its private key,
 * then switch the context to the public key and verify the request.
 *
 * Returns -1 when the vector is missing or an allocation fails, the
 * failing call's return value when key setup or signing fails, and
 * otherwise the result of akcipher_verify() (the caller compares it
 * against 0).
 */
static int test_ecdsa_nist_p256(void)
{
	const struct akcipher_testvec *ecdsa = NULL;
	struct akcipher_ctx *cipher;
	struct akcipher_req *req;
	unsigned char src[64];
	unsigned char dst[64];
	unsigned char pub[64];
	unsigned char priv[64];
	int ret = -1;
	int i;

	for (i = 0; i < ARRAY_SIZE(akcipher_tv); i++)
		if (!strcmp(akcipher_tv[i].name, "ecdsa-nist-p256"))
			ecdsa = &akcipher_tv[i];

	if (!ecdsa)
		goto out;

	cipher = akcipher_alloc_ctx("ecdsa-nist-p256", CRYPTO_SYNC);
	if (!cipher)
		goto out;

	req = akcipher_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	memcpy(priv, ecdsa->priv, ecdsa->priv_len);
	ret = akcipher_set_priv_key(cipher, priv, ecdsa->priv_len);
	if (ret < 0)
		goto free_req;

	memcpy(src, ecdsa->msg, ecdsa->msg_len);
	/*
	 * NOTE(review): the destination length is taken from key_len rather
	 * than sig_len; presumably both are 64 for P-256 -- confirm.
	 */
	akcipher_set_req(req, src, dst, ecdsa->msg_len, ecdsa->key_len, 0);

	ret = akcipher_sign(req);
	if (ret) {
		/*
		 * Stop here: continuing would overwrite ret with the results
		 * of the key setup/verify below and hide the sign failure.
		 */
		goto free_req;
	}
	hex_dump("Dump ecdsa-nist-p256 dst:", dst, ecdsa->key_len);

	memcpy(pub, ecdsa->key, ecdsa->key_len);
	ret = akcipher_set_pub_key(cipher, pub, ecdsa->key_len);
	if (ret < 0)
		goto free_req;

	/* Check ret outside */
	ret = akcipher_verify(req);

free_req:
	akcipher_free_req(req);
free_ctx:
	akcipher_free_ctx(cipher);
out:
	return ret;
}

/*
 * Encrypt the "rsa-crt" test vector's message (no padding) with the
 * private key read from ./crt_private.pem.
 *
 * Returns errno when the key file cannot be opened, -1 when the vector is
 * missing, the key cannot be parsed or an allocation fails, the negative
 * value from akcipher_set_priv_key() when the key is rejected, and
 * otherwise the result of akcipher_encrypt() (the caller compares it
 * against 0).
 */
static int test_rsa_crt_key(void)
{
	const struct akcipher_testvec *crt = NULL;
	struct akcipher_ctx *cipher;
	struct akcipher_req *req;
	uint8_t src[128] = {0};
	uint8_t dst[128] = {0};
	RSA *rsa = NULL;
	bool key_passed = false;
	FILE *fp;
	int ret = -1;
	int i;

	fp = fopen("./crt_private.pem", "rb");
	if (!fp) {
		/* Save errno first: printf() is allowed to modify it. */
		ret = errno;
		printf("Failed to open crt_private.pem\n");
		return ret;
	}

	for (i = 0; i < ARRAY_SIZE(akcipher_tv); i++)
		if (!strcmp(akcipher_tv[i].name, "rsa-crt"))
			crt = &akcipher_tv[i];

	if (!crt)
		goto out;

	/*
	 * Test the return value: on failure the PEM reader returns NULL but
	 * leaves a pre-existing *rsa untouched, so checking the pointer alone
	 * cannot detect a parse error.
	 */
	if (!PEM_read_RSAPrivateKey(fp, &rsa, NULL, NULL)) {
		printf("Failed to read crt_private.pem\n");
		goto out;
	}

	cipher = akcipher_alloc_ctx("rsa", CRYPTO_SYNC);
	if (!cipher)
		goto out;

	req = akcipher_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	/*
	 * NOTE(review): whether the context keeps a reference to the RSA
	 * object after this call is not visible here, so once it has been
	 * handed over it is deliberately not freed below -- confirm the
	 * ownership rules and free it here if the library copies the key.
	 */
	key_passed = true;
	ret = akcipher_set_priv_key(cipher, (void *)rsa, 0);
	if (ret < 0) {
		printf("Failed to set key\n");
		goto free_req;
	}

	memcpy(src, crt->msg, crt->msg_len);
	akcipher_set_req(req, src, dst, crt->msg_len, akcipher_maxsize(cipher), RSA_NO_PADDING);
	hex_dump("Dump rsa-crt src:", src, crt->msg_len);

	/* Check ret outside */
	ret = akcipher_encrypt(req);
	hex_dump("Dump rsa-crt dst:", dst, 128);

free_req:
	akcipher_free_req(req);
free_ctx:
	akcipher_free_ctx(cipher);
out:
	/* The key is still exclusively ours if it never reached the context. */
	if (!key_passed)
		RSA_free(rsa);
	fclose(fp);
	return ret;
}

/*
 * Load the RSA public key stored in ./test_public.pem into *rsa.
 *
 * Returns 0 on success, errno when the file cannot be opened, and -1 when
 * the PEM data cannot be parsed. On a parse failure *rsa is left exactly
 * as the caller passed it in.
 */
static int rsa_read_pubkey(RSA **rsa)
{
	FILE *fp;
	int ret = 0;

	fp = fopen("./test_public.pem", "rb");
	if (!fp)
		return errno;

	/* The reader returns NULL on error; *rsa alone does not reveal it. */
	if (!PEM_read_RSA_PUBKEY(fp, rsa, NULL, NULL))
		ret = -1;

	fclose(fp);
	return ret;
}

/*
 * Load the RSA private key stored in ./test_private.pem into *rsa.
 *
 * Returns 0 on success, errno when the file cannot be opened, and -1 when
 * the PEM data cannot be parsed. On a parse failure *rsa is left exactly
 * as the caller passed it in.
 */
static int rsa_read_privkey(RSA **rsa)
{
	FILE *fp;
	int ret = 0;

	fp = fopen("./test_private.pem", "rb");
	if (!fp)
		return errno;

	/* The reader returns NULL on error; *rsa alone does not reveal it. */
	if (!PEM_read_RSAPrivateKey(fp, rsa, NULL, NULL))
		ret = -1;

	fclose(fp);
	return ret;
}

/*
 * Encrypt the "rsa-pkcs" test vector's message with PKCS#1 padding using
 * the public key from rsa_read_pubkey().
 *
 * @clear: when true, the source/destination hex dumps are suppressed.
 *
 * Returns a non-zero value when the key cannot be loaded, -1 when the
 * vector is missing or an allocation fails, the negative value from
 * akcipher_set_pub_key() when the key is rejected, and otherwise the
 * result of akcipher_encrypt() (the caller compares it against 0).
 */
static int test_rsa_pkcs_encrypt(bool clear)
{
	const struct akcipher_testvec *pkcs = NULL;
	struct akcipher_ctx *cipher;
	struct akcipher_req *req;
	uint8_t src[128] = {0};
	uint8_t dst[128] = {0};
	RSA *rsa = NULL;
	bool key_passed = false;
	int ret;
	int i;

	/*
	 * Start from NULL rather than a pre-allocated object: the PEM reader
	 * allocates the key itself, and a NULL result then reliably signals
	 * that nothing was loaded.
	 */
	ret = rsa_read_pubkey(&rsa);
	if (ret || !rsa) {
		printf("Failed to read rsa public key\n");
		/* Never report success when no key was loaded. */
		if (!ret)
			ret = -1;
		goto out;
	}

	ret = -1;
	for (i = 0; i < ARRAY_SIZE(akcipher_tv); i++)
		if (!strcmp(akcipher_tv[i].name, "rsa-pkcs"))
			pkcs = &akcipher_tv[i];

	if (!pkcs)
		goto out;

	cipher = akcipher_alloc_ctx("rsa", CRYPTO_SYNC);
	if (!cipher)
		goto out;

	req = akcipher_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	/*
	 * NOTE(review): whether the context keeps a reference to the RSA
	 * object after this call is not visible here, so once it has been
	 * handed over it is deliberately not freed below -- confirm the
	 * ownership rules and free it here if the library copies the key.
	 */
	key_passed = true;
	ret = akcipher_set_pub_key(cipher, (void *)rsa, 0);
	if (ret < 0) {
		printf("failed to set key\n");
		goto free_req;
	}

	memcpy(src, pkcs->msg, pkcs->msg_len);
	akcipher_set_req(req, src, dst, pkcs->msg_len, akcipher_maxsize(cipher), RSA_PKCS1_PADDING);

	/* Check ret outside */
	ret = akcipher_encrypt(req);

	if (!clear) {
		hex_dump("Dump rsa-pkcs-encrypt src:", src, pkcs->msg_len);
		hex_dump("Dump rsa-pkcs-encrypt dst:", dst, 128);
	}

free_req:
	akcipher_free_req(req);
free_ctx:
	akcipher_free_ctx(cipher);
out:
	/* The key is still exclusively ours if it never reached the context. */
	if (!key_passed)
		RSA_free(rsa);
	return ret;
}

/*
 * Decrypt the "rsa-pkcs" test vector's ciphertext (stored in its sig
 * field) with PKCS#1 padding using the private key from
 * rsa_read_privkey().
 *
 * @clear: when true, the source/destination hex dumps are suppressed.
 *
 * Returns a non-zero value when the key cannot be loaded, -1 when the
 * vector is missing or an allocation fails, the negative value from
 * akcipher_set_priv_key() when the key is rejected, and otherwise the
 * result of akcipher_decrypt() (the caller compares it against 0).
 */
static int test_rsa_pkcs_decrypt(bool clear)
{
	const struct akcipher_testvec *pkcs = NULL;
	struct akcipher_ctx *cipher;
	struct akcipher_req *req;
	uint8_t src[128] = {0};
	uint8_t dst[128] = {0};
	RSA *rsa = NULL;
	bool key_passed = false;
	int ret;
	int i;

	/*
	 * Start from NULL rather than a pre-allocated object: the PEM reader
	 * allocates the key itself, and a NULL result then reliably signals
	 * that nothing was loaded.
	 */
	ret = rsa_read_privkey(&rsa);
	if (ret || !rsa) {
		printf("Failed to read rsa private key\n");
		/* Never report success when no key was loaded. */
		if (!ret)
			ret = -1;
		goto out;
	}

	ret = -1;
	for (i = 0; i < ARRAY_SIZE(akcipher_tv); i++)
		if (!strcmp(akcipher_tv[i].name, "rsa-pkcs"))
			pkcs = &akcipher_tv[i];

	if (!pkcs)
		goto out;

	cipher = akcipher_alloc_ctx("rsa", CRYPTO_SYNC);
	if (!cipher)
		goto out;

	req = akcipher_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	/*
	 * NOTE(review): whether the context keeps a reference to the RSA
	 * object after this call is not visible here, so once it has been
	 * handed over it is deliberately not freed below -- confirm the
	 * ownership rules and free it here if the library copies the key.
	 */
	key_passed = true;
	ret = akcipher_set_priv_key(cipher, (void *)rsa, 0);
	if (ret < 0) {
		printf("Failed to set key\n");
		goto free_req;
	}

	memcpy(src, pkcs->sig, pkcs->sig_len);
	akcipher_set_req(req, src, dst, pkcs->sig_len, akcipher_maxsize(cipher), RSA_PKCS1_PADDING);

	/* Check ret outside */
	ret = akcipher_decrypt(req);
	if (!clear) {
		hex_dump("Dump rsa-pkcs-decrypt src:", src, pkcs->sig_len);
		hex_dump("Dump rsa-pkcs-decrypt dst:", dst, 128);
	}

free_req:
	akcipher_free_req(req);
free_ctx:
	akcipher_free_ctx(cipher);
out:
	/* The key is still exclusively ours if it never reached the context. */
	if (!key_passed)
		RSA_free(rsa);
	return ret;
}

/*
 * Run an unpadded RSA encrypt over the "rsa-pke-unit" test vector, using
 * the vector's raw key bytes as the private key.
 *
 * Returns -1 when the vector is missing or an allocation fails, the
 * negative value from akcipher_set_priv_key() when the key is rejected,
 * and otherwise the result of akcipher_encrypt() (the caller compares it
 * against 0).
 */
static int test_rsa_pke_unit(void)
{
	const struct akcipher_testvec *tv = NULL;
	struct akcipher_ctx *ctx;
	struct akcipher_req *request;
	uint8_t input[1024];
	uint8_t output[1024];
	uint8_t keybuf[1024];
	int rc = -1;
	int idx;

	/* The scan does not stop early: the last entry with this name wins. */
	for (idx = 0; idx < ARRAY_SIZE(akcipher_tv); idx++) {
		if (strcmp(akcipher_tv[idx].name, "rsa-pke-unit") == 0)
			tv = &akcipher_tv[idx];
	}

	if (tv == NULL)
		return rc;

	ctx = akcipher_alloc_ctx("rsa", CRYPTO_SYNC);
	if (ctx == NULL)
		return rc;

	request = akcipher_alloc_req(ctx);
	if (request == NULL)
		goto release_ctx;

	memcpy(keybuf, tv->key, tv->key_len);
	rc = akcipher_set_priv_key(ctx, keybuf, tv->key_len);
	if (rc < 0) {
		printf("Faile to set key\n");
		goto release_req;
	}

	memcpy(input, tv->msg, tv->msg_len);
	akcipher_set_req(request, input, output, tv->msg_len,
			 akcipher_maxsize(ctx), RSA_NO_PADDING);

	hex_dump("Dump rsa pke unit src:", input, tv->msg_len);

	/* The result is judged by the caller; the output is dumped regardless. */
	rc = akcipher_encrypt(request);
	hex_dump("Dump rsa pke unit dst:", output, tv->sig_len);

release_req:
	akcipher_free_req(request);
release_ctx:
	akcipher_free_ctx(ctx);
	return rc;
}

/*
 * Sign the "rsa-sign" test vector's message with PKCS#1 padding using the
 * private key from rsa_read_privkey(), then verify the same request.
 *
 * Returns a non-zero value when the key cannot be loaded, -1 when the
 * vector is missing or an allocation fails, the failing call's return
 * value when key setup or signing fails, and otherwise the result of
 * akcipher_verify() (the caller compares it against 0).
 */
static int test_rsa_sign(void)
{
	const struct akcipher_testvec *sign = NULL;
	struct akcipher_ctx *cipher;
	struct akcipher_req *req;
	uint8_t src[128] = {0};
	uint8_t dst[128] = {0};
	RSA *rsa = NULL;
	bool key_passed = false;
	int ret;
	int i;

	/*
	 * Start from NULL rather than a pre-allocated object: the PEM reader
	 * allocates the key itself, and a NULL result then reliably signals
	 * that nothing was loaded.
	 */
	ret = rsa_read_privkey(&rsa);
	if (ret || !rsa) {
		printf("Failed to read rsa private key\n");
		/* Never report success when no key was loaded. */
		if (!ret)
			ret = -1;
		goto out;
	}

	ret = -1;
	for (i = 0; i < ARRAY_SIZE(akcipher_tv); i++)
		if (!strcmp(akcipher_tv[i].name, "rsa-sign"))
			sign = &akcipher_tv[i];

	if (!sign)
		goto out;

	cipher = akcipher_alloc_ctx("rsa", CRYPTO_SYNC);
	if (!cipher)
		goto out;

	req = akcipher_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	/*
	 * NOTE(review): whether the context keeps a reference to the RSA
	 * object after this call is not visible here, so once it has been
	 * handed over it is deliberately not freed below -- confirm the
	 * ownership rules and free it here if the library copies the key.
	 */
	key_passed = true;
	ret = akcipher_set_priv_key(cipher, (void *)rsa, 0);
	if (ret < 0) {
		printf("Faile to set key\n");
		goto free_req;
	}

	memcpy(src, sign->msg, sign->msg_len);
	akcipher_set_req(req, src, dst, sign->msg_len, akcipher_maxsize(cipher), RSA_PKCS1_PADDING);
	hex_dump("Dump rsa sign src:", src, sign->msg_len);
	ret = akcipher_sign(req);

	TEST_ASSERT(ret == 0, "rsa-sign");
	if (ret) {
		/*
		 * Stop here: the verify result below would otherwise replace
		 * ret and hide the sign failure from the caller.
		 */
		goto free_req;
	}
	hex_dump("Dump rsa sign Dst", dst, 128);

	ret = akcipher_verify(req);
	TEST_ASSERT(ret == 0, "rsa-verify");

free_req:
	akcipher_free_req(req);
free_ctx:
	akcipher_free_ctx(cipher);
out:
	/* The key is still exclusively ours if it never reached the context. */
	if (!key_passed)
		RSA_free(rsa);
	return ret;
}

/*
 * Run the quiet (no hex dump) RSA PKCS#1 encrypt test @count times.
 *
 * Returns the sum of the individual return values, so any non-zero
 * result means at least one round did not return 0.
 */
static int test_rsa_multi_encrypt(unsigned long long count)
{
	unsigned long long remaining;
	int total = 0;

	for (remaining = count; remaining > 0; remaining--)
		total += test_rsa_pkcs_encrypt(true);

	return total;
}

/*
 * Run the quiet (no hex dump) RSA PKCS#1 decrypt test @count times.
 *
 * Returns the sum of the individual return values, so any non-zero
 * result means at least one round did not return 0.
 */
static int test_rsa_multi_decrypt(unsigned long long count)
{
	unsigned long long remaining;
	int total = 0;

	for (remaining = count; remaining > 0; remaining--)
		total += test_rsa_pkcs_decrypt(true);

	return total;
}

/*
 * Look up a symmetric-cipher test vector by name.
 *
 * The whole table is scanned; if several entries carry @name the last
 * one is returned. Returns NULL when no entry matches.
 */
static const struct skcipher_testvec *find_skcipher_tv(const char *name)
{
	const struct skcipher_testvec *match = NULL;
	int idx = 0;

	while (idx < ARRAY_SIZE(skcipher_tv)) {
		if (strcmp(skcipher_tv[idx].name, name) == 0)
			match = &skcipher_tv[idx];
		idx++;
	}

	return match;
}

/*
 * Synchronous "cbc(aes)" encrypt of the "ske-unit" test vector.
 *
 * Returns -1 when the vector is missing or an allocation fails, and
 * otherwise the result of skcipher_encrypt() (the caller compares it
 * against 0). The ciphertext is dumped only when that result is 0.
 */
static int test_ske_unit(void)
{
	const struct skcipher_testvec *tv;
	struct skcipher_ctx *cipher;
	struct skcipher_req *request;
	unsigned char plain[16];
	unsigned char encrypted[16];
	unsigned char ivbuf[16];
	unsigned char keybuf[16];
	int rc = -1;

	tv = find_skcipher_tv("ske-unit");
	if (tv == NULL)
		return rc;

	cipher = skcipher_alloc_ctx("cbc(aes)", CRYPTO_SYNC);
	if (cipher == NULL)
		return rc;

	request = skcipher_alloc_req(cipher);
	if (request != NULL) {
		memcpy(keybuf, tv->key, tv->key_len);
		skcipher_setkey(cipher, keybuf, tv->key_len);

		memcpy(plain, tv->ptext, tv->len);
		/* One byte less than the vector's iv field is copied. */
		memcpy(ivbuf, tv->iv, sizeof(tv->iv) - 1);
		skcipher_set_req(request, plain, encrypted, tv->len, ivbuf);

		hex_dump("Dump ske unit src:", plain, tv->len);
		rc = skcipher_encrypt(request);
		if (rc == 0)
			hex_dump("Dump ske unit dst:", encrypted, 16);

		skcipher_free_req(request);
	}

	skcipher_free_ctx(cipher);
	return rc;
}

/*
 * Completion flag for the asynchronous cipher test. It is written by the
 * completion callback and polled by test_async_ske_unit(), so every
 * access goes through the compiler's __atomic builtins: a plain load in
 * an empty loop may legally be hoisted out of the loop, turning the wait
 * into an endless spin.
 */
static unsigned int done;

/*
 * Completion callback for the asynchronous cipher request: dump the
 * output when @err is 0, release the request and signal the waiter.
 */
static void test_ske_complete(struct skcipher_req *req, int err)
{
	if (!err)
		hex_dump("Dump async ske dst:", req->dst, req->cryptlen);

	skcipher_free_req(req);
	/* Release: everything done above is visible once the flag is seen. */
	__atomic_store_n(&done, 1, __ATOMIC_RELEASE);
}

/*
 * Asynchronous "cbc(aes)" encrypt of the "ske-unit" test vector.
 *
 * When skcipher_encrypt() returns -EINPROGRESS the function waits for
 * test_ske_complete() to run (which also frees the request) and returns
 * 0. Any other return value is passed back after freeing the request
 * here. Returns -1 when the vector is missing or an allocation fails.
 */
static int test_async_ske_unit(void)
{
	const struct skcipher_testvec *ske = NULL;
	struct skcipher_ctx *ctx;
	struct skcipher_req *req;
	unsigned char ptext[16];
	unsigned char ctext[16];
	unsigned char iv[16];
	unsigned char key[16];
	int ret = -1;

	ske = find_skcipher_tv("ske-unit");
	if (!ske)
		goto out;

	ctx = skcipher_alloc_ctx("cbc(aes)", CRYPTO_ASYNC);
	if (!ctx)
		goto out;

	req = skcipher_alloc_req(ctx);
	if (!req)
		goto free_ctx;

	memcpy(key, ske->key, ske->key_len);
	skcipher_setkey(ctx, key, ske->key_len);

	memcpy(ptext, ske->ptext, ske->len);
	memcpy(iv, ske->iv, sizeof(ske->iv) - 1);
	skcipher_req_set_callback(req, test_ske_complete, NULL);
	skcipher_set_req(req, ptext, ctext, ske->len, iv);

	/* Re-arm the flag so a stale value from an earlier run cannot match. */
	__atomic_store_n(&done, 0, __ATOMIC_RELAXED);

	ret = skcipher_encrypt(req);
	if (ret == -EINPROGRESS) {
		/* Wait for completion */
		while (!__atomic_load_n(&done, __ATOMIC_ACQUIRE))
			;

		/* Test passed, reset ret to 0 */
		ret = 0;
	} else {
		skcipher_free_req(req);
	}

free_ctx:
	skcipher_free_ctx(ctx);
out:
	return ret;
}

/*
 * Shared body of the two worker threads: issue 100000 back-to-back
 * encrypt (or, when @decrypt is true, decrypt) operations on one request
 * allocated from the shared context @ctx, using the "ske-unit" vector.
 *
 * Returns -1 when the vector is missing or the request cannot be
 * allocated, otherwise the sum of the per-operation return values.
 */
static int ske_thread_run(struct skcipher_ctx *ctx, bool decrypt)
{
	const struct skcipher_testvec *ske;
	struct skcipher_req *req;
	unsigned char ptext[16];
	unsigned char ctext[16];
	unsigned char iv[16];
	int i = 100000;
	int ret = 0;

	ske = find_skcipher_tv("ske-unit");
	if (!ske)
		return -1;

	req = skcipher_alloc_req(ctx);
	if (!req)
		return -1;

	memcpy(iv, ske->iv, sizeof(ske->iv) - 1);
	if (decrypt) {
		memcpy(ctext, ske->ctext, ske->len);
		skcipher_set_req(req, ctext, ptext, ske->len, iv);
	} else {
		memcpy(ptext, ske->ptext, ske->len);
		skcipher_set_req(req, ptext, ctext, ske->len, iv);
	}

	while (i--)
		ret += decrypt ? skcipher_decrypt(req) : skcipher_encrypt(req);

	skcipher_free_req(req);
	return ret;
}

/*
 * Thread entry: encrypt worker. @arg is the shared struct skcipher_ctx.
 * The thread's exit value is NULL on success and non-NULL on failure.
 */
static void *ske_test0(void *arg)
{
	int ret = ske_thread_run((struct skcipher_ctx *)arg, false);

	TEST_ASSERT(ret == 0, "ske thread0 encrypt");
	return (void *)(intptr_t)ret;
}

/*
 * Thread entry: decrypt worker. @arg is the shared struct skcipher_ctx.
 * The thread's exit value is NULL on success and non-NULL on failure.
 */
static void *ske_test1(void *arg)
{
	int ret = ske_thread_run((struct skcipher_ctx *)arg, true);

	TEST_ASSERT(ret == 0, "ske thread1 decrypt");
	return (void *)(intptr_t)ret;
}

/*
 * Run the encrypt and decrypt workers concurrently on one shared
 * "cbc(aes)" context.
 *
 * The workers are joined instead of being polled through flags, so a
 * worker that bails out early can no longer leave this function spinning
 * forever, and the context is only released after both have finished.
 *
 * Returns -1 when the vector is missing, the context cannot be allocated
 * or a worker reported a failure; the pthread_create() error code when a
 * thread could not be started; 0 otherwise.
 */
static int test_multi_thread_ske_unit(void)
{
	const struct skcipher_testvec *ske = NULL;
	pthread_t thread0, thread1;
	struct skcipher_ctx *ctx;
	unsigned char key[16];
	void *res0 = NULL;
	void *res1 = NULL;
	int err0, err1;
	int ret = -1;

	ske = find_skcipher_tv("ske-unit");
	if (!ske)
		goto out;

	ctx = skcipher_alloc_ctx("cbc(aes)", CRYPTO_SYNC);
	if (!ctx)
		goto out;

	memcpy(key, ske->key, ske->key_len);
	skcipher_setkey(ctx, key, ske->key_len);

	err0 = pthread_create(&thread0, NULL, ske_test0, (void *)ctx);
	err1 = pthread_create(&thread1, NULL, ske_test1, (void *)ctx);

	/* Wait for both two thread completion; only join what was started. */
	if (!err0)
		pthread_join(thread0, &res0);
	if (!err1)
		pthread_join(thread1, &res1);

	skcipher_free_ctx(ctx);

	if (err0)
		ret = err0;
	else if (err1)
		ret = err1;
	else if (res0 || res1)
		ret = -1;
	else
		ret = 0;
out:
	return ret;
}

static const struct aead_testvec *find_aead_tv(const char *name)
{
	const struct aead_testvec *aead = NULL;
	int i;

	for (i = 0; i < ARRAY_SIZE(aead_tv); i++)
		if (!strcmp(aead_tv[i].name, name))
			aead = &aead_tv[i];

	return aead;
}

/*
 * "gcm(aes)" encrypt followed by decrypt over the "at" AEAD test vector.
 *
 * The source buffer is laid out as associated data followed by the
 * payload; the authentication tag size is set to c_len - p_len.
 *
 * Returns -1 when the vector is missing or an allocation fails, the
 * failing call's return value when key/tag setup or encryption fails,
 * and otherwise the result of aead_decrypt() (the caller compares it
 * against 0).
 */
static int test_aead_unit(void)
{
	const struct aead_testvec *aead = NULL;
	struct aead_ctx *cipher;
	struct aead_req *req;
	unsigned char iv[100];
	unsigned char src[100];
	unsigned char dst[100];
	int ret = -1;

	aead = find_aead_tv("at");
	if (!aead)
		goto out;

	cipher = aead_alloc_ctx("gcm(aes)", 0);
	if (!cipher)
		goto out;

	req = aead_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	ret = aead_setkey(cipher, aead->key, aead->key_len);
	if (ret < 0)
		goto free_req;

	ret = aead_setauthsize(cipher, aead->c_len - aead->p_len);
	if (ret < 0)
		goto free_req;

	/* Encrypt pass: associated data || plaintext. */
	memcpy(src, aead->ass, aead->ass_len);
	memcpy(src + aead->ass_len, aead->ptext, aead->p_len);
	hex_dump("Dump aead src:", src, aead->ass_len + aead->p_len);

	memcpy(iv, aead->iv, sizeof(aead->iv) - 1);
	aead_set_req_ad(req, aead->ass_len);
	aead_set_req(req, src, dst, aead->p_len, iv);

	ret = aead_encrypt(req);
	hex_dump("Dump aead encrypt dst:", dst, aead->c_len);
	TEST_ASSERT(ret == 0, "aead-encrypt-unit");
	if (ret) {
		/*
		 * Stop here: the decrypt result below would otherwise replace
		 * ret and hide the encrypt failure from the caller.
		 */
		goto free_req;
	}

	/* Decrypt pass: associated data || reference ciphertext. */
	memcpy(src, aead->ass, aead->ass_len);
	memcpy(src + aead->ass_len, aead->ctext, aead->c_len);
	hex_dump("Dump aead src:", src, aead->ass_len + aead->c_len);

	/* Reload the IV from the vector before the second operation. */
	memcpy(iv, aead->iv, sizeof(aead->iv) - 1);
	aead_set_req(req, src, dst, aead->c_len, iv);

	ret = aead_decrypt(req);
	hex_dump("Dump aead decrypt dst:", dst, aead->p_len);
	TEST_ASSERT(ret == 0, "aead-decrypt-unit");

free_req:
	aead_free_req(req);
free_ctx:
	aead_free_ctx(cipher);
out:
	return ret;
}

/*
 * Request 32 bytes from the "trng" generator and dump them.
 *
 * Returns -1 when the context cannot be allocated, the negative value
 * from rng_generate() on failure, and 0 for any non-negative result.
 */
static int test_rng_unit(void)
{
	struct rng_ctx *rng;
	uint8_t random[32];
	int rc;

	rng = rng_alloc_ctx("trng", 0);
	if (rng == NULL)
		return -1;

	rc = rng_generate(rng, NULL, 0, random, 32);
	if (rc >= 0)
		hex_dump("Random:", random, 32);

	rng_free_ctx(rng);

	if (rc < 0)
		return rc;

	return 0;
}

/*
 * Look up a key-agreement (KPP) test vector by name.
 *
 * The whole table is scanned; if several entries carry @name the last
 * one is returned. Returns NULL when no entry matches.
 */
static const struct kpp_testvec *find_kpp_tv(const char *name)
{
	const struct kpp_testvec *match = NULL;
	int idx = 0;

	while (idx < ARRAY_SIZE(kpp_tv)) {
		if (strcmp(kpp_tv[idx].name, name) == 0)
			match = &kpp_tv[idx];
		idx++;
	}

	return match;
}

/*
 * "ecdh-nist-p256" key generation test: generate a key pair, install the
 * generated private key as the secret and regenerate the public key.
 *
 * ./ecdh_priv_key.pem must exist and contain a parsable EC private key;
 * the parsed key itself is not used by the operations below.
 *
 * Returns -1 when the key file cannot be opened or parsed or an
 * allocation fails, the negative value of a failing kpp call, and
 * otherwise the result of kpp_generate_public_key() (the caller compares
 * it against 0).
 */
static int test_kpp_key_unit(void)
{
	EC_KEY *ecPriv = NULL;
	struct kpp_ctx *cipher;
	struct kpp_req *req;
	unsigned char priv[128];
	unsigned char pub[256];
	FILE *fp;
	int ret = -1;

	fp = fopen("./ecdh_priv_key.pem", "rb");
	if (!fp)
		goto out;

	/*
	 * Test the return value: the PEM reader returns NULL on error, which
	 * a pointer check on a pre-allocated key object could not detect.
	 */
	if (!PEM_read_ECPrivateKey(fp, &ecPriv, NULL, NULL))
		goto close_fp;

	cipher = kpp_alloc_ctx("ecdh-nist-p256", CRYPTO_SYNC);
	if (!cipher)
		goto close_fp;

	req = kpp_alloc_req(cipher);
	if (!req)
		goto free_ctx;

	/* private key length is 32 bytes, public key length is 64 bytes */
	kpp_set_req(req, priv, 32, pub, 64);
	ret = kpp_generate_key_pair(req);
	if (ret < 0)
		goto free_req;

	hex_dump("KPP generated private key:", priv, 32);
	hex_dump("KPP generated public key:", pub, 64);

	ret = kpp_set_secret(cipher, (void *)priv, 32);
	if (ret < 0)
		goto free_req;

	ret = kpp_generate_public_key(req);
	if (!ret)
		hex_dump("KPP verified public key:", pub, 64);

free_req:
	kpp_free_req(req);
free_ctx:
	kpp_free_ctx(cipher);
close_fp:
	fclose(fp);
out:
	/* EC_KEY_free() accepts NULL, so this is safe on every path. */
	EC_KEY_free(ecPriv);
	return ret;
}

/*
 * Compute the ECDH shared secret for the "ecdh-nist-p256" test vector
 * from its 32-byte secret and 64-byte peer public key.
 *
 * Returns -1 when the vector is missing or an allocation fails, the
 * negative value from kpp_set_secret() when the secret is rejected, and
 * otherwise the result of kpp_compute_shared_secret() (the caller
 * compares it against 0). The shared key is dumped only on a 0 result.
 */
static int test_kpp_shared_key(void)
{
	const struct kpp_testvec *tv;
	struct kpp_ctx *ctx;
	struct kpp_req *request;
	unsigned char secret[128];
	unsigned char peer_pub[256];
	unsigned char shared_key[128];
	int rc = -1;

	tv = find_kpp_tv("ecdh-nist-p256");
	if (tv == NULL)
		return rc;

	memcpy(secret, tv->secret, 32);
	memcpy(peer_pub, tv->pp, 64);

	ctx = kpp_alloc_ctx("ecdh-nist-p256", CRYPTO_SYNC);
	if (ctx == NULL)
		return rc;

	request = kpp_alloc_req(ctx);
	if (request == NULL)
		goto release_ctx;

	rc = kpp_set_secret(ctx, (void *)secret, 32);
	if (rc >= 0) {
		/* 64-byte public key in, 32-byte shared key out. */
		kpp_set_req(request, peer_pub, 64, shared_key, 32);

		rc = kpp_compute_shared_secret(request);
		if (rc == 0)
			hex_dump("KPP shared key:", shared_key, 32);
	}

	kpp_free_req(request);
release_ctx:
	kpp_free_ctx(ctx);
	return rc;
}

/*
 * Round-trip check of the polling affinity API, repeated 100 times: pin
 * polling to one randomly chosen configured CPU, read the affinity back
 * and require the two CPU sets to be equal.
 *
 * Returns 0 when every round matches, the non-zero value of a failing
 * set/get call, or -1 when the set read back differs.
 */
static int test_set_and_get_polling_affinity(void)
{
	cpu_set_t wanted, actual;
	unsigned int cpu_count, cpu;
	int round, rc;

	cpu_count = get_nprocs_conf();

	for (round = 0; round < 100; round++) {
		CPU_ZERO(&wanted);
		cpu = rand() % cpu_count;
		CPU_SET(cpu, &wanted);

		rc = ycc_set_polling_affinity(&wanted);
		if (rc != 0)
			return rc;

		rc = ycc_get_polling_affinity(&actual);
		if (rc != 0)
			return rc;

		if (!CPU_EQUAL(&wanted, &actual))
			return -1;

		printf("set and get polling affinity successfully, cpu=%u\n", cpu);
	}

	return 0;
}

/*
 * Round-trip check of the polling interval API: set the interval to 1000
 * and require the getter to report 1000 afterwards.
 *
 * Returns 0 on a match, -1 when the setter returns non-zero (the getter
 * is then not called) or the value read back differs.
 */
static int test_set_and_get_polling_interval(void)
{
	if (ycc_set_polling_interval(1000) != 0)
		return -1;

	if (ycc_get_polling_interval() != 1000)
		return -1;

	return 0;
}

/*
 * Test driver entry point: initialize the YCC driver, run every test case
 * and report how many of them failed.
 *
 * Returns the negative ycc_drv_init() result when initialization fails,
 * 1 when at least one test case failed and 0 when all of them passed, so
 * scripts can rely on the exit status.
 */
int main(int argc, char *argv[])
{
	/* Kept apart from the init result so the count always starts at 0. */
	int failed = 0;
	int ret;

	ret = ycc_drv_init(1);
	if (ret < 0)
		return ret;

	/* From here on the destructor is responsible for ycc_drv_exit(). */
	inited = true;
	failed += TEST_ASSERT(test_set_and_get_polling_affinity() == 0, "polling-affinity");
	failed += TEST_ASSERT(test_set_and_get_polling_interval() == 0, "polling-interval");
	failed += TEST_ASSERT(test_ecdsa_nist_p521() == 0, "ecdsa-nist-p521");
	failed += TEST_ASSERT(test_ecdsa_nist_p256() == 0, "ecdsa-nist-p256");
	failed += TEST_ASSERT(test_rsa_crt_key() == 0, "rsa-crt");
	failed += TEST_ASSERT(test_rsa_pkcs_encrypt(false) == 0, "rsa-pkcs-encrypt");
	failed += TEST_ASSERT(test_rsa_pkcs_decrypt(false) == 0, "rsa-pkcs-decrypt");
	failed += TEST_ASSERT(test_rsa_pke_unit() == 0, "rsa-pke-unit");
	failed += TEST_ASSERT(test_rsa_multi_encrypt(500000) == 0, "rsa-multi-encrypt");
	failed += TEST_ASSERT(test_rsa_multi_decrypt(500000) == 0, "rsa-multi-decrypt");
	failed += TEST_ASSERT(test_rsa_sign() == 0, "rsa-sign-and-verify");
	failed += TEST_ASSERT(test_ske_unit() == 0, "ske-unit");
	failed += TEST_ASSERT(test_async_ske_unit() == 0, "async-ske-unit");
	failed += TEST_ASSERT(test_multi_thread_ske_unit() == 0, "multi-thread-ske-unit");
	failed += TEST_ASSERT(test_aead_unit() == 0, "aead-unit");
	failed += TEST_ASSERT(test_rng_unit() == 0, "rng-unit");
	failed += TEST_ASSERT(test_kpp_key_unit() == 0, "kpp-key-unit");
	failed += TEST_ASSERT(test_kpp_shared_key() == 0, "kpp-shared-key");
	if (failed)
		printf("%d testcases failed!\n", failed);

	return failed ? 1 : 0;
}
