// SPDX-License-Identifier: GPL-2.0
/*
 * UDMA: User space dma memory management for Alibaba YCC
 *   (Yitian Crypto Complex) crypto accelerator.
 *
 * Copyright (C) 2020-2022 Alibaba Corporation. All rights reserved.
 * Author: Zelin Deng <zelin.deng@linux.alibaba.com>
 * Author: Guanjun <guanjun@linux.alibaba.com>
 * Author: Jiayu Ni <jiayu.ni@linux.alibaba.com>
 */

#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <signal.h>
#include <stdio.h>
#include <stdbool.h>
#include <stdlib.h>
#include <errno.h>
#include <sys/mman.h>
#include <sys/ioctl.h>
#include <string.h>
#include <pthread.h>
#include <sys/syscall.h>

#include "../utils/utils.h"
#include "udma_ulib.h"

/*
 * Silences -Wint-conversion for the whole file.
 * NOTE(review): presumably added for places that mix pointers and integers
 * (e.g. storing an address in an integer field); it also hides genuine type
 * mistakes — prefer fixing the types and dropping this pragma.
 */
#pragma GCC diagnostic ignored "-Wint-conversion"

/* Misc device node opened by ycc_udma_open(). */
#define YCC_UDMA_DEV		"/dev/ycc_udma"

/* Number of bits (blocks) tracked by one bitmap word: 64. */
#define QWORD_WIDTH		(8 * sizeof(__u64))
/* A bitmap word with every block marked as used. */
#define QWORD_ALL_ONE		0xFFFFFFFFFFFFFFFFULL

/*
 * Number of UNIT_SIZE blocks at the start of each slab that are occupied by
 * the struct ycc_udma_ctrl header; user allocations never overlap them.
 */
#define RESERVED_UNITS	(DIV_ROUND_UP(sizeof(struct ycc_udma_ctrl), UNIT_SIZE))

/* Number of 64-bit slots in one page-sized level of the page table. */
#define PAGE_TABLE_ENTRY	(PAGE_SIZE / sizeof(__u64))
/* Clears the in-page offset bits of an address. */
#define PAGE_MASK		(~(PAGE_SIZE - 1))

/*
 * Interface version of this library; the kernel driver's version must not
 * be newer (see is_match_version()).
 */
#define MAJOR_VERSION		(0)
#define MINOR_VERSION		(1)

/*
 * Pointer to the bitmap word that holds bit @offset (despite the name this
 * is a word address, not a byte address).
 */
#define QWORD_BYTE_POS(bitmap, offset)	((bitmap) + (offset) / QWORD_WIDTH)
/* Index of bit @offset within its bitmap word. */
#define QWORD_BIT_POS(offset)		((offset) % QWORD_WIDTH)

/*
 * View of a 64-bit virtual address as a 12-bit in-page offset followed by
 * five 9-bit page-table indices (lv0 is the lowest, lv4 the highest).
 * 12 + 5 * 9 = 57 bits; the remaining high bits are not represented.
 * NOTE(review): the 12-bit offset assumes 4 KiB pages (confirm PAGE_SIZE),
 * and the layout relies on the compiler allocating bit-fields from the least
 * significant bit (GCC on little-endian targets does) — confirm for any
 * other toolchain.
 */
struct page_entry {
	__u64 offset : 12;
	__u64 lv0 : 9;
	__u64 lv1 : 9;
	__u64 lv2 : 9;
	__u64 lv3 : 9;
	__u64 lv4 : 9;
};

/*
 * Type-pun a virtual address into its page_entry fields; written through
 * virt_addr and read through pe by store_addr() and load_addr().
 */
union page_index {
	__u64 virt_addr;
	struct page_entry pe;
};

/*
 * One level of the software virt->phys page table: PAGE_TABLE_ENTRY slots,
 * each either a pointer to the next level (pt) or, in the last level, the
 * physical address stored for a page (pa). Lower levels are created on
 * demand by next_level() with mmap().
 */
struct page_table {
	union {
		__u64 pa;
		struct page_table *pt;
	} next[PAGE_TABLE_ENTRY];
};

/*
 * Process environment captured by init_env_var(): the number of online CPUs
 * (from sysconf(_SC_NPROCESSORS_ONLN)) and the driver information returned
 * by the YCC_IOC_GET_UIO_INFO ioctl.
 */
struct env_var {
	int cpu_cores;
	struct ycc_uio_info driver_info;
};

/*
 * Per-CPU cache of slabs: slab_list_array[i] and lock_array[i] belong to the
 * i-th CPU slot, for i in [0, env->cpu_cores). A zeroed pool means "not
 * initialised".
 */
struct percpu_slab_pool {
	struct slab_list **slab_list_array;

	/* For every elements of slab_list_array */
	pthread_spinlock_t *lock_array;
	struct env_var *env;

	/*
	 * Descriptor of /dev/ycc_udma (a copy of the global 'fd'). It is
	 * assigned from an int and handed to real_alloc() as an int, so it is
	 * stored as one; as a pointer, "dev_fd < 0" could never be true.
	 */
	int dev_fd;
};

/* Debug switch; presumably consulted by the udma_dbg() macro (defined elsewhere). */
unsigned int udma_debug_mode;

/* Descriptor of /dev/ycc_udma; -1 while the device is not open. */
static int fd = -1;
/* Maximum number of slabs lookup_percpu_pool() tries per allocation. */
static int g_max_try_num = 10;
/* Root of the virt->phys lookup table filled by real_alloc(). */
static struct page_table g_page_table;
/* Allocated in ycc_udma_init(), freed in ycc_udma_exit(); not otherwise used in this file. */
static struct slab_list *g_user_slab_list;

/* The single large DMA buffer managed by ycc_udma_large_malloc()/ycc_udma_large_free(). */
struct ycc_udma_large_info large_mem;

/* CPU count and driver info, filled by init_env_var(). */
static struct env_var g_env;
/* Per-CPU slab cache used by ycc_udma_malloc(). */
static struct percpu_slab_pool g_slab_pool;

/*
 * Used for bit operations: __bitmask[n] has the low n bits set, for
 * n = 0..64. A lookup table sidesteps the undefined shift (1ULL << 64)
 * that computing the n == 64 mask would need.
 */
static const __u64 __bitmask[65] = {
	0x0000000000000000ULL, 0x0000000000000001ULL, 0x0000000000000003ULL,
	0x0000000000000007ULL, 0x000000000000000fULL, 0x000000000000001fULL,
	0x000000000000003fULL, 0x000000000000007fULL, 0x00000000000000ffULL,
	0x00000000000001ffULL, 0x00000000000003ffULL, 0x00000000000007ffULL,
	0x0000000000000fffULL, 0x0000000000001fffULL, 0x0000000000003fffULL,
	0x0000000000007fffULL, 0x000000000000ffffULL, 0x000000000001ffffULL,
	0x000000000003ffffULL, 0x000000000007ffffULL, 0x00000000000fffffULL,
	0x00000000001fffffULL, 0x00000000003fffffULL, 0x00000000007fffffULL,
	0x0000000000ffffffULL, 0x0000000001ffffffULL, 0x0000000003ffffffULL,
	0x0000000007ffffffULL, 0x000000000fffffffULL, 0x000000001fffffffULL,
	0x000000003fffffffULL, 0x000000007fffffffULL, 0x00000000ffffffffULL,
	0x00000001ffffffffULL, 0x00000003ffffffffULL, 0x00000007ffffffffULL,
	0x0000000fffffffffULL, 0x0000001fffffffffULL, 0x0000003fffffffffULL,
	0x0000007fffffffffULL, 0x000000ffffffffffULL, 0x000001ffffffffffULL,
	0x000003ffffffffffULL, 0x000007ffffffffffULL, 0x00000fffffffffffULL,
	0x00001fffffffffffULL, 0x00003fffffffffffULL, 0x00007fffffffffffULL,
	0x0000ffffffffffffULL, 0x0001ffffffffffffULL, 0x0003ffffffffffffULL,
	0x0007ffffffffffffULL, 0x000fffffffffffffULL, 0x001fffffffffffffULL,
	0x003fffffffffffffULL, 0x007fffffffffffffULL, 0x00ffffffffffffffULL,
	0x01ffffffffffffffULL, 0x03ffffffffffffffULL, 0x07ffffffffffffffULL,
	0x0fffffffffffffffULL, 0x1fffffffffffffffULL, 0x3fffffffffffffffULL,
	0x7fffffffffffffffULL, 0xffffffffffffffffULL,
};

/*
 * Hash key for @ptr: its virtual address taken at 2 MiB granularity
 * (shifted right by 21) and masked with HASH_BUCKETS_MASK, so which address
 * bits form the key depends on that mask's width.
 * Considering using physical address instead.
 */
static inline unsigned int get_key(void *ptr)
{
	unsigned long addr = (unsigned long)ptr;

	return (addr >> 21) & HASH_BUCKETS_MASK;
}

/*
 * Return the page-table level that *@pt points to, creating it if missing.
 *
 * A new level is an anonymous (hence zero-filled) private mapping of one
 * struct page_table. It is published with a compare-and-swap so concurrent
 * callers agree on a single table; a caller that loses the race unmaps its
 * own copy and returns the winner's. Returns NULL only if mmap() fails.
 */
static void *next_level(struct page_table **pt)
{
	struct page_table *existing = *pt;
	struct page_table *fresh;

	if (existing)
		return existing;

	fresh = mmap(NULL, sizeof(*fresh), PROT_READ | PROT_WRITE,
		     MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (unlikely(fresh == MAP_FAILED))
		return NULL;

	/* Another thread installed a table first: drop ours, use theirs. */
	if (!__sync_bool_compare_and_swap(pt, NULL, fresh))
		munmap(fresh, sizeof(*fresh));

	return *pt;
}

/*
 * Record that the page containing @virt_addr maps to @phys_addr.
 *
 * Walks (and creates as needed, via next_level()) the four pointer levels
 * selected by the lv4..lv1 fields of @virt_addr, then stores @phys_addr in
 * the lv0 slot of the last level. If a level cannot be created the mapping
 * is silently not recorded.
 */
static void store_addr(struct page_table *pt,
		       __u64 virt_addr,
		       __u64 phys_addr)
{
	union page_index idx = { .virt_addr = virt_addr };
	const unsigned int path[] = {
		idx.pe.lv4, idx.pe.lv3, idx.pe.lv2, idx.pe.lv1,
	};
	size_t level;

	for (level = 0; level < sizeof(path) / sizeof(path[0]); level++) {
		pt = next_level(&pt->next[path[level]].pt);
		if (!pt)
			return;
	}

	pt->next[idx.pe.lv0].pa = phys_addr;
}

/*
 * Load the mapping of virtual and physical address from page table.
 *
 * Walks the four pointer levels selected by @virt_addr (lv4..lv1) and reads
 * the leaf slot lv0. Returns the physical address of @virt_addr (the stored
 * page address plus the in-page offset), or 0 if any level is missing or the
 * leaf slot was never filled by store_addr().
 */
static __u64 load_addr(struct page_table *pt, __u64 virt_addr)
{
	union page_index idx;
	__u64 phys_addr;

	idx.virt_addr = virt_addr;
	pt = pt->next[idx.pe.lv4].pt;
	if (!pt)
		return 0;

	pt = pt->next[idx.pe.lv3].pt;
	if (!pt)
		return 0;

	pt = pt->next[idx.pe.lv2].pt;
	if (!pt)
		return 0;

	pt = pt->next[idx.pe.lv1].pt;
	if (!pt)
		return 0;

	phys_addr = pt->next[idx.pe.lv0].pa;

	/*
	 * A leaf table covers every page in its range, so a lookup can reach
	 * an existing table but a slot store_addr() never filled. Without
	 * this check the result would be just the in-page offset: a bogus
	 * non-zero "physical address" that get_slab_from_pgt() would trust.
	 */
	if (!phys_addr)
		return 0;

	return (phys_addr & PAGE_MASK) | idx.pe.offset;
}

/*
 * Set bits [pos, pos + len) of @bitmap, i.e. mark those blocks as used.
 *
 * When the range ends exactly on a word boundary, the word "stop" points to
 * lies outside the range (possibly one past the end of the bitmap) and is
 * left untouched.
 */
static void set_bitmap(__u64 *bitmap, unsigned int pos, unsigned int len)
{
	unsigned int start_pos, stop_pos;
	__u64 *start, *stop, *word;

	start     = QWORD_BYTE_POS(bitmap, pos);
	stop      = QWORD_BYTE_POS(bitmap, pos + len);
	start_pos = QWORD_BIT_POS(pos);
	stop_pos  = QWORD_BIT_POS(pos + len);

	if (start == stop) {
		/* Whole range within one word: bits [start_pos, stop_pos). */
		*start |= __bitmask[len] << start_pos;
		return;
	}

	/* Head: bits [start_pos, 63] of the first word. */
	*start |= __bitmask[QWORD_WIDTH - start_pos] << start_pos;

	/* Full words strictly between head and tail. */
	for (word = start + 1; word != stop; word++)
		*word = QWORD_ALL_ONE;

	/* Tail: bits [0, stop_pos) of the last word, if any. */
	if (stop_pos)
		*stop |= __bitmask[stop_pos];
}

/*
 * Clear bits [pos, pos + len) of @bitmap, i.e. mark those blocks as free.
 *
 * When the range ends exactly on a word boundary, the word "stop" points to
 * lies outside the range (possibly one past the end of the bitmap) and is
 * left untouched.
 */
static void clear_bitmap(__u64 *bitmap, unsigned int pos, unsigned int len)
{
	unsigned int start_pos, stop_pos;
	__u64 *start, *stop, *word;

	start     = QWORD_BYTE_POS(bitmap, pos);
	stop      = QWORD_BYTE_POS(bitmap, pos + len);
	start_pos = QWORD_BIT_POS(pos);
	stop_pos  = QWORD_BIT_POS(pos + len);

	if (start == stop) {
		/* Whole range within one word: bits [start_pos, stop_pos). */
		*start &= ~(__bitmask[len] << start_pos);
		return;
	}

	/* Head: bits [start_pos, 63] of the first word. */
	*start &= ~(__bitmask[QWORD_WIDTH - start_pos] << start_pos);

	/* Full words strictly between head and tail. */
	for (word = start + 1; word != stop; word++)
		*word = 0ULL;

	/* Tail: bits [0, stop_pos) of the last word, if any. */
	if (stop_pos)
		*stop &= ~__bitmask[stop_pos];
}

/*
 * Read, from the slab header's 'sizes' helper array, the length in blocks
 * recorded for the allocation that starts at block @pos (0 if none).
 */
static inline size_t get_sizes(__u16 *sizes, unsigned int pos)
{
	size_t blocks = *(sizes + pos);

	return blocks;
}

/*
 * Record, in the slab header's 'sizes' helper array, that the allocation
 * starting at block @pos is @size blocks long (0 clears the record).
 * NOTE: slots are __u16, so values above 65535 would be truncated.
 */
static inline void set_sizes(__u16 *sizes, unsigned int pos, size_t size)
{
	__u16 *slot = sizes + pos;

	*slot = size;
}

/*
 * Reads a 64-bit window from a BITMAP_LEN * 64 bits bitmap.
 *
 * @map: address to the given bitmap memory region
 * @window_pos: bit offset of the window; any value (it need not be a
 *              multiple of 64)
 *
 * Returns the 64 bits starting at bit @window_pos of @map, bit 0 of the
 * result being bit @window_pos. Bits that lie beyond the end of the bitmap
 * read as 1 (used), so they can never be handed out.
 * Each bit represents one UNIT_SIZE block of the slab.
 */
static __u64 bitmap_read(__u64 *map, size_t window_pos)
{
	size_t quad_word_pos = window_pos / QWORD_WIDTH;
	size_t bit_pos = QWORD_BIT_POS(window_pos);
	__u64 quad_word_window, next_quad_word;

	if (quad_word_pos >= BITMAP_LEN)
		return QWORD_ALL_ONE;

	quad_word_window = map[quad_word_pos];
	if (!bit_pos)
		return quad_word_window;

	/*
	 * The window straddles two words. Never read past map[BITMAP_LEN - 1]:
	 * treat the missing word as fully used instead of relying on whatever
	 * follows the bitmap in memory.
	 */
	if (quad_word_pos + 1 < BITMAP_LEN)
		next_quad_word = map[quad_word_pos + 1];
	else
		next_quad_word = QWORD_ALL_ONE;

	return (quad_word_window >> bit_pos) |
	       (next_quad_word << (QWORD_WIDTH - bit_pos));
}

/*
 * Count trailing zero bits of @bitmap_window, i.e. the number of contiguous
 * free blocks at the start of the window. An all-zero window yields
 * QWORD_WIDTH (64), which __builtin_ctzll() cannot be asked for.
 */
static inline int32_t mem_ctzll(__u64 bitmap_window)
{
	if (!bitmap_window)
		return QWORD_WIDTH;

#ifdef __GNUC__
	return __builtin_ctzll(bitmap_window);
#else
#error "Undefined built-in function"
#endif
}

/*
 * Allocate @size bytes from the slab described by @udma_ctrl.
 *
 * @udma_ctrl: slab header; its bitmap tracks UNIT_SIZE blocks, 1 bit each
 * @size:      bytes requested, rounded up to whole UNIT_SIZE blocks
 * @align:     required alignment of the first block, in blocks (0: none)
 *
 * Slides a 64-bit window over the bitmap looking for a run of free (0)
 * blocks long enough for the request. On success the run is marked used,
 * its length is recorded in udma_ctrl->sizes and alloc_cnts is incremented
 * (mem_free() decrements it). Callers in this file hold the slab's
 * slab_lock around this call.
 *
 * Returns the virtual address of the first block of the run, or NULL if
 * the slab has no suitable free run.
 */
static void *mem_alloc(struct ycc_udma_ctrl *udma_ctrl,
		       size_t size,
		       size_t align)
{
	const size_t total_blocks = BITMAP_LEN * QWORD_WIDTH;
	size_t blocks_required, blocks_found = 0;
	size_t window_pos = 0, first_block = 0;
	size_t width, width_ones;
	__u64 bitmap_window;
	__u64 *bitmap;

	if (!udma_ctrl || !size) {
		udma_err("invalid udma_ctrl or size provided!\n");
		return NULL;
	}

	bitmap = udma_ctrl->bitmap;
	blocks_required = DIV_ROUND_UP(size, UNIT_SIZE);

	while (window_pos < total_blocks) {
		/* Read 64-bit bitmap value from window_pos */
		bitmap_window = bitmap_read(bitmap, window_pos);

		/* Contiguous free blocks starting at window_pos */
		width = mem_ctzll(bitmap_window);
		blocks_found += width;

		if (blocks_found >= blocks_required) {
			/* The run must not extend past the last tracked block */
			if (first_block + blocks_required > total_blocks)
				return NULL;

			set_bitmap(bitmap, first_block, blocks_required);
			set_sizes(udma_ctrl->sizes, first_block, blocks_required);
			atomic_inc(&udma_ctrl->alloc_cnts);

			/* Found and claimed: stop scanning. */
			return (void *)(udma_ctrl) + first_block * UNIT_SIZE;
		}

		if (!bitmap_window) {
			/*
			 * The whole window is free but the run is still too
			 * short: keep extending it into the next window.
			 */
			window_pos += width;
			continue;
		}

		/*
		 * The free run ends at a used block. Skip the free blocks, that
		 * used block and the used blocks right after it, then start a
		 * new run at the next (aligned) position. width can be 63, and
		 * shifting a 64-bit value by 64 is undefined, so guard it.
		 */
		if (width + 1 < QWORD_WIDTH)
			bitmap_window >>= width + 1;
		else
			bitmap_window = 0;
		width_ones = mem_ctzll(~bitmap_window);

		window_pos += width + 1 + width_ones;
		if (align && window_pos % align)
			window_pos += align - window_pos % align;

		first_block = window_pos;
		blocks_found = 0;
	}

	return NULL;
}

/*
 * The file description of misc device should be accessed mutual exclusively, so
 * just open once and store it in a global variable 'fd'.
 *
 * Returns 0 on success, -EBUSY if the device is already open, or the
 * negated errno of the failed open(2) — the same negative-errno convention
 * used by the rest of this file. On failure 'fd' stays -1.
 */
static int ycc_udma_open(void)
{
	int err;

	if (fd != -1)
		return -EBUSY;

	fd = open(YCC_UDMA_DEV, O_RDWR);
	if (fd < 0) {
		err = errno;
		fd = -1;
		return -err;
	}

	return 0;
}

/*
 * Whether @pool has been set up by init_percpu_slab_pool(): both per-CPU
 * arrays are allocated and the device descriptor is not negative.
 */
static bool percpu_pool_ready(struct percpu_slab_pool *pool)
{
	if (!pool)
		return false;

	return pool->slab_list_array &&
	       pool->lock_array &&
	       !(pool->dev_fd < 0);
}

static struct ycc_udma_ctrl *lookup_percpu_pool(size_t size,
						void **block,
						const size_t align,
						int cpu)
{
	struct ycc_udma_ctrl *pcurr = NULL;
	int try_num = 0;

	if (unlikely(!percpu_pool_ready(&g_slab_pool))) {
		udma_warn("percpu pool is not ready.\n");
		return NULL;
	}

	for (pcurr = g_slab_pool.slab_list_array[cpu]->head;
			pcurr;
			pcurr = pcurr->next_user_percpu) {
		pthread_spin_lock((pthread_spinlock_t *)(pcurr->slab_lock));
		*block = mem_alloc(pcurr, size, align);
		pthread_spin_unlock((pthread_spinlock_t *)(pcurr->slab_lock));
		if (*block)
			return pcurr;

		if (++try_num >= g_max_try_num)
			break;
	}

	return NULL;
}

/*
 * Fast path: try to satisfy the request from a slab that is already linked
 * in @cpu's cache list. Returns the slab used (with *@virt_addr set), or
 * NULL when a new slab has to be allocated.
 */
static struct ycc_udma_ctrl *find_slab(size_t size,
				       void **virt_addr,
				       const size_t align,
				       int cpu)
{
	struct ycc_udma_ctrl *found = lookup_percpu_pool(size, virt_addr,
							 align, cpu);

	udma_dbg("find slab %p\n", found);
	return found;
}

/*
 * No suitable slab found. We should do the real action
 * to allocate memory block which is mem_type.
 *
 * Asks the driver (YCC_IOC_MEM_ALLOC) for a new memory block, maps it into
 * the process, initialises the slab header at its start, claims the first
 * @size bytes after the header and records every page of the block in the
 * global virt->phys page table.
 *
 * @fd:       descriptor of the ycc_udma device
 * @size:     bytes to claim from the new slab; 0 only reserves the header
 * @node:     passed to the driver, which does not use it
 * @mem_type: slab type requested from the driver
 *
 * Returns the slab header, or an ERR_PTR() holding a negative errno.
 * The slab is not linked into any list; the caller does that.
 */
static struct ycc_udma_ctrl *real_alloc(int fd,
					size_t size,
					int node,
					enum slab_type mem_type)
{
	struct ycc_udma_info *mem_info = NULL, params = {0};
	struct ycc_udma_ctrl *mem_ctrl;
	unsigned int blocks_required = 0, slab_offset;
	int err, bits;

	params.node = node;			/* node is actually unused */
	params.type = mem_type;
	params.size = MEMBLOCK_SIZE_NORMAL; /* size here doesn't matter */

	/* ioctl() returns -1 and sets errno on failure; report the real cause. */
	if (unlikely(ioctl(fd, YCC_IOC_MEM_ALLOC, &params) < 0))
		return ERR_PTR(-errno);

	mem_info = mmap(NULL,
			params.size,
			PROT_READ|PROT_WRITE,
			MAP_SHARED,
			fd,
			params.dma_addr);
	if (unlikely(mem_info == MAP_FAILED)) {
		err = errno;
		/*
		 * Hand the block back to the driver. Whatever happens,
		 * mem_info is MAP_FAILED and must not be dereferenced.
		 */
		if (unlikely(ioctl(fd, YCC_IOC_MEM_FREE, &params) < 0))
			udma_warn("failed to release memory block after mmap failure\n");
		return ERR_PTR(-err);
	}

	mem_info->virt_addr = mem_info;
	mem_ctrl = (struct ycc_udma_ctrl *)mem_info;

	pthread_spin_init((pthread_spinlock_t *)(mem_ctrl->slab_lock), 0);

	/* Set bits for metadata and allocated memory */
	blocks_required = DIV_ROUND_UP(size, UNIT_SIZE);
	bits = RESERVED_UNITS + blocks_required;
	/*
	 * We don't hold the slab_lock here because this
	 * slab is not in percpu pool link list now.
	 */
	set_bitmap(mem_ctrl->bitmap, 0, bits);

	/* Only set sizes for allocated memory */
	set_sizes(mem_ctrl->sizes, RESERVED_UNITS, blocks_required);

	/*
	 * NOTE(review): this is incremented even when size == 0 (the reserve
	 * slabs from prepare_slab_pool()), although nothing was handed out —
	 * confirm that is intended.
	 */
	atomic_inc(&mem_ctrl->alloc_cnts);

	/* Store all map table, then pagetable can remove lock */
	for (slab_offset = 0; slab_offset < MEMBLOCK_SIZE_NORMAL; slab_offset += PAGE_SIZE)
		store_addr(&g_page_table,
			   mem_info->virt_addr + slab_offset,
			   mem_info->dma_addr + slab_offset);

	return mem_ctrl;
}

/*
 * malloc api used by user-space applications.
 * Return virtual address just like traditional malloc.
 */
void *ycc_udma_malloc(size_t size)
{
	struct ycc_udma_ctrl *mem_ctrl = NULL;
	size_t phys_align_unit, phys_alignment_byte = 64;
	enum slab_type mem_type = NORMAL_type;
	int ret = -1, node = 0, cpu;
	void *virt_addr = NULL;

	if (fd < 0) {
		udma_warn("fd=%d is unexpected\n", fd);
		return ERR_PTR(-EFAULT);
	}

	if (!size) {
		udma_warn("size=%lu is unexpected\n", size);
		return ERR_PTR(-EFAULT);
	}

	if (size > (MEMBLOCK_SIZE_NORMAL - RESERVED_UNITS * UNIT_SIZE)) {
		udma_warn("size=%lu, not support > 32K size\n", size);
		return ERR_PTR(-ENOMEM);
	}

	udma_dbg("lib size %lu\n", size);
	ret = syscall(SYS_getcpu, &cpu, NULL, NULL);
	if (ret)
		cpu = 0;

	phys_align_unit = phys_alignment_byte / UNIT_SIZE;
	mem_ctrl = find_slab(size, &virt_addr, phys_align_unit, cpu);
	if (!mem_ctrl) {
		udma_dbg("enter real alloc path...\n");
		mem_ctrl = real_alloc(fd, size, node, mem_type);
		if (IS_ERR_OR_NULL(mem_ctrl)) {
			udma_warn("failed to get system memory, %s\n",
				  strerror(PTR_ERR(mem_ctrl)));
			virt_addr = mem_ctrl;
		}

		pthread_spin_lock(&(g_slab_pool.lock_array[cpu]));
		ADD_ELEMENT_TO_HEAD_LIST(mem_ctrl,
			(g_slab_pool.slab_list_array[cpu])->head,
			(g_slab_pool.slab_list_array[cpu])->tail,
			_user_percpu);
		pthread_spin_unlock(&(g_slab_pool.lock_array[cpu]));

		/* skip metadata bytes */
		virt_addr = ((struct ycc_udma_info *)mem_ctrl)->virt_addr +
			    RESERVED_UNITS * UNIT_SIZE;
	}

	return virt_addr;
}

static void alloc_large_mem(void)
{
	struct ycc_udma_large_info params = {0};
	int ret;

	if (large_mem.virt_addr) {
		udma_warn("large mem is already allocated."
			  "large_mem virt_addr:%p\n", large_mem.virt_addr);
		return;
	}

	ret = ioctl(fd, YCC_IOC_LARGE_MEM_ALLOC, &params);
	if (unlikely(ret < 0)) {
		udma_warn("YCC_IOC_LARGE_MEM_ALLOC failed\n");
		return;
	}

	large_mem.virt_addr = mmap(NULL,
				   LARGE_ALLOC_SIZE,
				   PROT_READ | PROT_WRITE,
				   MAP_SHARED,
				   fd,
				   params.dma_addr);
	if (unlikely(large_mem.virt_addr == MAP_FAILED)) {
		udma_warn("mmap large mem failed!\n");
		large_mem.virt_addr = NULL;
	}

	if (large_mem.virt_addr)
		large_mem.dma_addr = params.dma_addr;
}

/*
 * Malloc 4M memory for special use.
 * NOTE: Only can be called once.
 */
void *ycc_udma_large_malloc(void)
{
	if (fd < 0) {
		udma_warn("udma: fd=%d is unexpected\n", fd);
		return ERR_PTR(-EFAULT);
	}

	alloc_large_mem();
	if (!large_mem.virt_addr)
		udma_warn("alloc large mem failed!\n");

	return large_mem.virt_addr;
}

static void free_large_mem(int fd)
{
	int ret;

	if (!large_mem.virt_addr)
		return;

	ret = ioctl(fd, YCC_IOC_LARGE_MEM_FREE, NULL);
	if (unlikely(ret < 0)) {
		udma_warn("YCC_IOC_LARGE_MEM_FREE failed\n");
		return;
	}

	large_mem.virt_addr = NULL;
}

/*
 * Free 4M large memory
 */
void ycc_udma_large_free(void)
{
	if (fd < 0) {
		udma_warn("udma: fd=%d is unexpected\n", fd);
		return;
	}

	free_large_mem(fd);
	if (large_mem.virt_addr)
		udma_warn("large memory is not NULL after free.\n");
}

/*
 * virt_to_phys api used by user_space applications.
 *
 * Translates @ptr, which must lie in memory returned by ycc_udma_malloc(),
 * to its physical/DMA address via the global page table. A return value of
 * 0 means @ptr is NULL or unknown to the table.
 */
unsigned long virt_to_phys(void *ptr)
{
	return ptr ? load_addr(&g_page_table, (__u64)ptr) : 0;
}

/*
 * Translate @ptr inside the large DMA buffer to its DMA address.
 *
 * Returns 0 if @ptr is NULL, the large buffer is not mapped, or @ptr lies
 * outside [virt_addr, virt_addr + LARGE_ALLOC_SIZE).
 */
unsigned long virt_to_phys_large(void *ptr)
{
	unsigned long addr = (unsigned long)ptr;
	unsigned long base = (unsigned long)large_mem.virt_addr;

	if (!addr || !base)
		return 0;

	if (addr < base || addr - base >= LARGE_ALLOC_SIZE)
		return 0;

	return large_mem.dma_addr + (addr - base);
}

/*
 * Return the allocation that starts at @ptr to @slab. Its blocks are cleared
 * in the bitmap, its recorded size is erased and the slab's allocation count
 * is decremented. Callers in this file hold the slab's slab_lock.
 */
static void mem_free(struct ycc_udma_ctrl *slab, void *ptr)
{
	unsigned int pos = (ptr - (void *)slab) / UNIT_SIZE;
	size_t blocks = get_sizes(slab->sizes, pos);

	/*
	 * Every live allocation records a non-zero length at its first block.
	 * Zero means @ptr is not the start of a live allocation (double free or
	 * an interior pointer). Going on would drive alloc_cnts below the real
	 * number of live allocations.
	 */
	if (unlikely(!blocks)) {
		udma_err("invalid or double free of %p\n", ptr);
		return;
	}

	set_sizes(slab->sizes, pos, 0);
	clear_bitmap(slab->bitmap, pos, blocks);
	atomic_dec(&slab->alloc_cnts);
}

struct ycc_udma_ctrl *get_slab_from_pgt(void *ptr)
{
	__u64 ptr_pa = 0;

	ptr_pa = load_addr(&g_page_table, (__u64)ptr);
	if (!ptr_pa)
		return NULL;
	/*
	 * Two conditions must be satisfied.
	 * 1. PA of the slab must be 2M aligned.
	 * 2. Memory mapping is linear mapping.
	 */
	return (struct ycc_udma_ctrl *)((__u64)ptr - (ptr_pa & MASK_SIZE_2M));
}

/*
 * free api used by user_space applications.
 */
void ycc_udma_free(void *ptr)
{
	struct ycc_udma_ctrl *slab_pgt = NULL;

	if (!ptr) {
		udma_err("unable to free NULL pointer\n");
		return;
	}

	slab_pgt = get_slab_from_pgt(ptr);
	if (!slab_pgt) {
		udma_err("unable to find slab from pgt\n");
		return;
	}

	pthread_spin_lock((pthread_spinlock_t *)(slab_pgt->slab_lock));
	mem_free(slab_pgt, ptr);
	pthread_spin_unlock((pthread_spinlock_t *)(slab_pgt->slab_lock));
}

/*
 * Recursively unmap every table reachable from @level, descending @iter more
 * levels. At iter == 0 the entries hold physical addresses and are not
 * followed. @level itself is not unmapped, and its slots are not cleared.
 */
static void free_page_level(struct page_table *level, const int iter)
{
	struct page_table *child;
	size_t idx;

	if (iter == 0)
		return;

	for (idx = 0; idx < PAGE_TABLE_ENTRY; idx++) {
		child = level->next[idx].pt;
		if (!child)
			continue;

		free_page_level(child, iter - 1);
		munmap(child, sizeof(*child));
	}
}

/*
 * Unmap every lower level below the root @pt and zero the root itself. The
 * root is not unmapped; the only caller passes the static g_page_table.
 */
static void free_page_table(struct page_table *pt)
{
	const int depth = PAGE_TABLE_LEVEL - 1;

	free_page_level(pt, depth);
	memset(pt, 0, sizeof(*pt));
}

/*
 * Check that the kernel driver's interface version is one this library
 * understands. The driver's major.minor must not be newer than
 * MAJOR_VERSION.MINOR_VERSION. Returns 1 if compatible, and 0 otherwise,
 * including when @uio_info is NULL.
 */
static int is_match_version(struct ycc_uio_info *uio_info)
{
	if (!uio_info)
		return 0;

	if (uio_info->major != MAJOR_VERSION)
		return uio_info->major < MAJOR_VERSION;

	return uio_info->minor <= MINOR_VERSION;
}

/*
 * Fill @env with the number of online CPUs and the driver information from
 * YCC_IOC_GET_UIO_INFO, then check the driver's version.
 *
 * Returns 0 on success. If the ioctl fails, its non-zero return value is
 * passed back. On a version mismatch, -1 is returned.
 */
static int init_env_var(struct env_var *env)
{
	struct ycc_uio_info *info = &env->driver_info;
	int rc;

	env->cpu_cores = sysconf(_SC_NPROCESSORS_ONLN);

	rc = ioctl(fd, YCC_IOC_GET_UIO_INFO, info);
	if (rc != 0) {
		udma_warn("get version failed.\n");
		return rc;
	}

	if (is_match_version(info))
		return 0;

	udma_warn("version %d.%d, but kernel version %d.%d\n"
		  "Please update udma lib.\n",
		  MAJOR_VERSION, MINOR_VERSION,
		  info->major, info->minor);
	return -1;
}

/*
 * Just alloc memory for spinlock and slab_list_array, but have not
 * alloc memory for every slab_list element.
 */
static int init_percpu_slab_pool(struct percpu_slab_pool *pool,
				struct env_var *env)
{
	int ret, core_cnt;

	if (!pool || !env)
		return -EINVAL;

	pool->env = env;
	pool->dev_fd = fd;

	if (!env->cpu_cores)
		return 0;

	pool->slab_list_array = calloc(env->cpu_cores, sizeof(struct slab_list *));
	if (!pool->slab_list_array)
		return -ENOMEM;

	for (core_cnt = 0; core_cnt < env->cpu_cores; ++core_cnt) {
		pool->slab_list_array[core_cnt] = calloc(1, sizeof(struct slab_list));
		if (!(pool->slab_list_array[core_cnt])) {
			ret = -ENOMEM;
			goto free_core_pool;
		}
	}

	pool->lock_array = malloc(sizeof(pthread_spinlock_t) * (env->cpu_cores));
	if (!pool->lock_array) {
		ret = -ENOMEM;
		goto free_core_pool;
	}
	for (core_cnt = 0; core_cnt < env->cpu_cores; ++core_cnt) {
		ret = pthread_spin_init(pool->lock_array + core_cnt, 0);
		if (ret) {
			ret = -EIO;
			goto free_lock_array;
		}
	}

	return 0;

free_lock_array:
	free((void *)pool->lock_array);
	core_cnt = env->cpu_cores;
free_core_pool:
	for (core_cnt--; core_cnt >= 0; core_cnt--)
		free(pool->slab_list_array[core_cnt]);

	free(pool->slab_list_array);
	return ret;
}

/* Number of empty slabs pre-allocated per CPU by ycc_udma_init(). */
#define RESERVE_PERCPU_SLAB_NR	(1)

/*
 * Really allocate memory for the slab lists: put @reserve_nr fresh slabs on
 * every per-CPU list of @pool. Each comes from real_alloc() with size 0, so
 * only its header blocks are marked used.
 *
 * Returns 0 on success, -EINVAL if @pool is not set up, or -ENOMEM if a slab
 * could not be obtained.
 * NOTE(review): the "!pool->dev_fd" check also rejects descriptor 0, which
 * is a valid fd if stdin was closed — confirm this is acceptable.
 */
static int prepare_slab_pool(struct percpu_slab_pool *pool, int reserve_nr)
{
	struct ycc_udma_ctrl *mem_ctrl = NULL;
	int slab_cnt, core_cnt;

	if (!pool || !pool->env || !pool->dev_fd)
		return -EINVAL;

	for (core_cnt = 0; core_cnt < pool->env->cpu_cores; ++core_cnt) {
		for (slab_cnt = 0; slab_cnt < reserve_nr; ++slab_cnt) {
			/*
			 * Here we invoke real_alloc without any lock
			 * because it is in initial flow at this stage
			 * which is execute by a single thread.
			 * But ycc_udma_init function will not cover the
			 * condition that multi-process or multi-thread
			 * invokes ycc_udma_init.
			 */
			mem_ctrl = real_alloc(pool->dev_fd, 0, 0, NORMAL_type);
			if (IS_ERR_OR_NULL(mem_ctrl)) {
				/* PTR_ERR() is a negative errno; strerror() wants it positive. */
				udma_warn("failed to get system memory, %s\n",
						strerror(-PTR_ERR(mem_ctrl)));
				goto free_slab_memory;
			}

			ADD_ELEMENT_TO_HEAD_LIST(mem_ctrl,
						(pool->slab_list_array[core_cnt])->head,
						(pool->slab_list_array[core_cnt])->tail,
						_user_percpu);
		}
	}
	return 0;

free_slab_memory:
	/*
	 * TODO: Add free action in current function, but now we
	 * just return error then flow will fallback to close_fd
	 * tag, memory will release in kernel space.
	 */
	return -ENOMEM;
}

/*
 * Release the bookkeeping of @pool and zero it. This covers each per-CPU
 * struct slab_list, the array holding them, and the lock array. The slabs
 * linked in the lists are not unmapped here. Safe to call on a pool that
 * was never (or only partly) initialised.
 */
static void free_percpu_pool(struct percpu_slab_pool *pool)
{
	int core_cnt;

	if (!pool || !pool->env || !pool->env->cpu_cores)
		return;

	/* A pool whose initialisation failed may have no list array at all. */
	if (pool->slab_list_array) {
		/* free(NULL) is a no-op, so missing entries need no special case. */
		for (core_cnt = 0; core_cnt < pool->env->cpu_cores; ++core_cnt)
			free(pool->slab_list_array[core_cnt]);
		free(pool->slab_list_array);
	}

	free((void *)pool->lock_array);
	memset(pool, 0, sizeof(*pool));
}

/*
 * Init global variables and environment.
 *
 * The steps are:
 * - open the ycc_udma device;
 * - check that the kernel driver's version is supported;
 * - reset the virt->phys page table;
 * - build the per-CPU slab pool and put RESERVE_PERCPU_SLAB_NR slab(s) on
 *   each CPU list.
 *
 * Returns 0 on success, non-zero on failure. On any failure after the
 * device was opened, it is closed again and 'fd' is reset to -1, so the
 * library is left "not initialised" and a later call may retry.
 */
int ycc_udma_init(void)
{
	int ret;

	ret = ycc_udma_open();
	if (ret) {
		udma_err("failed to open /dev/ycc_udma...\n");
		goto out;
	}

	/* init cpu cores and reject an unsupported driver version */
	ret = init_env_var(&g_env);
	if (ret)
		goto close_fd;

	/* Init page table */
	free_page_table(&g_page_table);

	ret = init_percpu_slab_pool(&g_slab_pool, &g_env);
	if (ret)
		goto close_fd;

	ret = prepare_slab_pool(&g_slab_pool, RESERVE_PERCPU_SLAB_NR);
	if (ret)
		goto free_slab_pool;

	g_user_slab_list = calloc(1, sizeof(*g_user_slab_list));
	if (!g_user_slab_list) {
		udma_warn("failed to malloc memory for slab_list\n");
		ret = -ENOMEM;
		goto free_slab_pool;
	}

	return 0;

free_slab_pool:
	/*
	 * Only the pool bookkeeping is released here. Per prepare_slab_pool(),
	 * the slab memory itself is left to the kernel once the fd is closed.
	 */
	free_percpu_pool(&g_slab_pool);
close_fd:
	close(fd);
	/* Don't leave a stale descriptor behind: it would block a retry. */
	fd = -1;
out:
	return ret;
}

/*
 * Undo ycc_udma_init(). This frees g_user_slab_list, closes the device if
 * it is open, and releases the per-CPU pool bookkeeping. Slab mappings, the
 * page table and the large buffer are not touched here.
 */
void ycc_udma_exit(void)
{
	/* free(NULL) is a no-op, so no guard is needed. */
	free(g_user_slab_list);
	g_user_slab_list = NULL;

	if (fd != -1) {
		close(fd);
		fd = -1;
	}

	free_percpu_pool(&g_slab_pool);
}
