// SPDX-License-Identifier: GPL-2.0-or-later
/*
 *  group_balancer.c - Group Balancer module
 *
 *  Copyright (C) 2021 Cruz Zhao <CruzZhao@linux.alibaba.com>
 */

#include <linux/cgroup-defs.h>
#include <linux/cgroup.h>
#include <linux/ctype.h>
#include <linux/cpumask.h>
#include <linux/cpuset.h>
#include <linux/delay.h>
#include <linux/fs.h>
#include <linux/kthread.h>
#include <linux/kernel_stat.h>
#include <linux/list.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/vmstat.h>

static int param_set_period(const char *val, const struct kernel_param *kp)
{
	unsigned int period;
	int err;

	err = kstrtouint(val, 0, &period);
	if (err)
		return err;

	if (period == 0)
		return -EINVAL;

	return 0;
}

static unsigned int scan_period = 10000;
module_param_call(scan_period, param_set_period, NULL, NULL, 0644);
MODULE_PARM_DESC(scan_period, "scan and adjust period, unit ms");

static unsigned int sampling_period = 2000;
module_param_call(sampling_period, param_set_period, NULL, NULL, 0644);
MODULE_PARM_DESC(sampling_period, "cgroup cpu usage sampling interval, unit: ms");

/*
 * Relative threshold used by group_balancer_workfn(): balancing is only
 * considered when
 *   max_usage_percent * 100 > avg_usage_percent * relative_unbalance_watermark.
 */
static unsigned int relative_unbalance_watermark = 120;
module_param(relative_unbalance_watermark, uint, 0644);
MODULE_PARM_DESC(relative_unbalance_watermark, "unbalance watermark");

/*
 * Absolute threshold used by group_balancer_workfn(): in addition to the
 * relative check, max_usage_percent must exceed
 *   avg_usage_percent + absolute_unbalance_watermark.
 */
static unsigned int absolute_unbalance_watermark = 5;
module_param(absolute_unbalance_watermark, uint, 0644);
MODULE_PARM_DESC(absolute_unbalance_watermark, "unbalance watermark");

enum {
	NR_TOTAL_PARTS_TITLE,
	NR_TOTAL_PARTS,
	NR_TOTAL_LEVELS_TITLE,
	NR_TOTAL_LEVELS,
	LEVEL,
	CONFIG,
	PARENT_ID,
	CURRENT_ID,
	CPUMASK,
	PART_COMPLETE,
};

/*
 * Per-partition CPU time accounting, refreshed by update_partition_info().
 *
 * cputime               total cpu time growth since the previous update
 * cputime_history       accumulated total cpu time at the previous update
 * idle                  idle time growth since the previous update
 * idle_history          accumulated idle time at the previous update
 * last_exectime_percent not referenced by the code visible in this file
 * idle_percent          idle * 100 / (cputime + 1)
 * cpu_usage_percent     100 - idle_percent
 */
struct partition_info {
	u64				cputime;
	u64				cputime_history;
	u64				idle;
	u64				idle_history;
	unsigned int			last_exectime_percent;
	unsigned int			idle_percent;
	unsigned int			cpu_usage_percent;
};

/*
 * One CPU partition of the balancing topology.
 *
 * part_cpus       CPUs that belong to this partition
 * cgroup_list     head of the list of group_balancer_private entries
 *                 (linked through part_cgroup_list) placed on this partition
 * pi              usage statistics, see struct partition_info
 * parent          parent partition, NULL when none was configured
 * children        head of the list of child partitions
 * siblings        link in the parent's children list
 * level_siblings  link in the per-level list gb->levels[level]; the bubble
 *                 sort helpers keep that list ordered by nr_cgroups
 * nr_part_cpus    number of CPUs set in part_cpus
 * nr_cgroups      number of cgroups currently placed on this partition
 * level           level this partition was defined on
 * id              partition id
 */
struct partition {
	struct cpumask			part_cpus;
	struct list_head		cgroup_list;
	struct partition_info		*pi;
	struct partition		*parent;
	struct list_head		children;
	struct list_head		siblings;
	struct list_head		level_siblings;
	unsigned int			nr_part_cpus;
	unsigned int			nr_cgroups;
	unsigned int			level;
	unsigned int			id;
};

/*
 * Per-cgroup balancer state, stored in cgroup->gb_private.
 *
 * soft_cpus_allowed      mask handed to set_soft_cpus_allowed() for the cgroup
 * part_cgroup_list       link in the owning partition's cgroup_list
 * part                   partition chosen by group_balancer_init_handler()
 * cg                     back pointer to the cgroup
 * jstamp                 jiffies value recorded by the last usage sample
 * cpuacct_cpuusage_last  summed get_cpuacct_cpuusage() value of the last sample
 * cpu_usage_percent      usage figure computed by update_cgroup_cpuusage()
 */
struct group_balancer_private {
	struct cpumask			soft_cpus_allowed;
	struct list_head		part_cgroup_list;
	struct partition		*part;
	struct cgroup			*cg;
	unsigned long			jstamp;
	u64				cpuacct_cpuusage_last;
	unsigned int			cpu_usage_percent;
};

/*
 * Global balancer state (single instance, see the gb pointer).
 *
 * scan_dwork         periodic work running group_balancer_workfn()
 * sampling_dwork     not initialised or queued by the code visible in this file
 * gb_list            list head filled in through
 *                    register_group_balancer_callback(); iterated as a list
 *                    of cgroups linked through their gb_list member
 * parts              array of nr_parts partitions
 * levels             array of nr_levels list heads, one per level, linking
 *                    partitions through level_siblings
 * min_part/max_part  partitions with the lowest/highest cpu usage found by
 *                    the last update_partition_info() run
 * min/max/avg_usage_percent
 *                    usage statistics of the last update_partition_info() run
 * nr_parts           number of entries in parts
 * nr_levels          number of entries in levels
 * current_level      level being balanced; set to nr_levels - 1 at init
 */
struct group_balancer {
	struct delayed_work		scan_dwork;
	struct delayed_work		sampling_dwork;
	struct list_head		*gb_list;
	struct partition		*parts;
	struct list_head		*levels;
	struct partition		*min_part;
	struct partition		*max_part;
	u64				min_usage_percent;
	u64				max_usage_percent;
	u64				avg_usage_percent;
	unsigned int			nr_parts;
	unsigned int			nr_levels;
	unsigned int			current_level;
};

/* The single balancer instance; allocated in group_balancer_init(). */
static struct group_balancer *gb;
/* Workqueue on which the periodic scan work is queued. */
static struct workqueue_struct *group_balancer_wq;

/*
 * Unlink @p from its list and re-insert it directly after @q.
 * When @p is the immediate predecessor of @q this swaps the two neighbours.
 */
static void exchange_adjacent_list_node(struct list_head *p, struct list_head *q)
{
	list_move(p, q);
}

static void bubble_sort_level_list_up(struct partition *part)
{
	struct partition *next_part;
	struct list_head *curr, *next;
	struct list_head *current_level_head;

	curr = &part->level_siblings;
	current_level_head = gb->levels + gb->current_level;
	while (curr->next != current_level_head) {
		next_part = list_next_entry(part, level_siblings);
		next = &next_part->level_siblings;
		if (next_part->nr_cgroups >= part->nr_cgroups)
			break;
		//exchange curr and next
		exchange_adjacent_list_node(curr, next);
	}
}

static void bubble_sort_level_list_down(struct partition *part)
{
	struct partition *prev_part;
	struct list_head *curr, *prev;
	struct list_head *current_level_head;

	curr = &part->level_siblings;
	current_level_head = gb->levels + gb->current_level;
	while (curr->prev != current_level_head) {
		prev_part = list_prev_entry(part, level_siblings);
		prev = &prev_part->level_siblings;
		if (prev_part->nr_cgroups <= part->nr_cgroups)
			break;
		//exchange curr and prev
		exchange_adjacent_list_node(prev, curr);
	}
}

/*
 * update_partition_info - refresh usage statistics of the current level.
 *
 * For every partition on gb->levels[gb->current_level] this sums the
 * per-cpu kcpustat counters, derives the growth since the previous call and
 * computes idle_percent / cpu_usage_percent. It also records the partitions
 * with the lowest and highest usage (gb->min_part / gb->max_part) and the
 * average usage over the partitions that were actually visited.
 */
static void update_partition_info(void)
{
	struct partition *part;
	struct list_head *current_level_head;
	unsigned int sum_usage_percent = 0;
	unsigned int nr_level_parts = 0;

	gb->max_usage_percent = 0;
	gb->min_usage_percent = ~0ULL;	/* u64 max */
	current_level_head = gb->levels + gb->current_level;
	list_for_each_entry(part, current_level_head, level_siblings) {
		int cpu;
		u64 idle_curr, cputime_curr;
		struct partition_info *pip = part->pi;

		/*
		 * cputime_curr		-- accumulated cputime right now
		 * idle_curr		-- accumulated idle time right now
		 * pip->cputime		-- cputime growth since last update
		 * pip->idle		-- idle time growth since last update
		 * pip->cputime_history	-- accumulated cputime on last update
		 * pip->idle_history	-- accumulated idle time on last update
		 */
		idle_curr = cputime_curr = 0;
		for_each_cpu(cpu, &part->part_cpus) {
			u64 *cstat = kcpustat_cpu(cpu).cpustat;

			idle_curr += cstat[CPUTIME_IDLE];
			cputime_curr +=
				cstat[CPUTIME_USER] + cstat[CPUTIME_NICE] +
				cstat[CPUTIME_SYSTEM] + cstat[CPUTIME_IDLE] +
				cstat[CPUTIME_IOWAIT] + cstat[CPUTIME_IRQ] +
				cstat[CPUTIME_SOFTIRQ] + cstat[CPUTIME_STEAL] +
				cstat[CPUTIME_GUEST] + cstat[CPUTIME_GUEST_NICE];
		}

		pip->cputime = cputime_curr - pip->cputime_history;

		if (pip->cputime) {
			pip->cputime_history = cputime_curr;
			pip->idle = idle_curr - pip->idle_history;
		} else {
			/*
			 * The counters did not move at all since the last
			 * update. Treat the whole scan window as idle on
			 * every cpu of the partition and advance both
			 * histories by that amount.
			 *
			 * The product is computed in 64 bits so that it
			 * cannot overflow where long is 32 bits wide.
			 */
			pip->cputime = pip->idle =
				(u64)scan_period * NSEC_PER_MSEC * part->nr_part_cpus;
			pip->cputime_history += pip->cputime;
			idle_curr += pip->cputime;
		}

		pip->idle_history = idle_curr;
		/* "+ 1" keeps the divisor non-zero. */
		pip->idle_percent = pip->idle * 100 / (pip->cputime + 1);
		/* Clamp so that the subtraction below cannot wrap around. */
		if (pip->idle_percent > 100)
			pip->idle_percent = 100;
		pip->cpu_usage_percent = 100 - pip->idle_percent;

		if (pip->cpu_usage_percent > gb->max_usage_percent) {
			gb->max_usage_percent = pip->cpu_usage_percent;
			gb->max_part = part;
		}
		if (pip->cpu_usage_percent < gb->min_usage_percent) {
			gb->min_usage_percent = pip->cpu_usage_percent;
			gb->min_part = part;
		}

		sum_usage_percent += pip->cpu_usage_percent;
		nr_level_parts++;
	}

	/*
	 * Average over the partitions of the current level only. gb->nr_parts
	 * counts the partitions of all levels and would bias the average
	 * towards zero.
	 */
	if (nr_level_parts)
		gb->avg_usage_percent = sum_usage_percent / nr_level_parts;
}

static u64 cgroup_cpuusage_sampling(struct cgroup *cg)
{
	struct group_balancer_private *gb_priv;
	int cpu;
	u64 cpuacct_cpuusage_curr;

	if (!cg) {
		pr_err("sampling error: cgroup not found\n");
		return 0;
	}
	gb_priv = cg->gb_private;
	if (!gb_priv) {
		pr_err("sampling error: cgroup gb_priv not found\n");
		return 0;
	}
	cpuacct_cpuusage_curr = 0;
	for_each_cpu(cpu, &gb_priv->soft_cpus_allowed) {
		cpuacct_cpuusage_curr += get_cpuacct_cpuusage(cg, cpu);
	}
	gb_priv->cpuacct_cpuusage_last = cpuacct_cpuusage_curr;
	gb_priv->jstamp = jiffies;
	return cpuacct_cpuusage_curr;
}

/*
 * update_cgroup_cpuusage - recompute cpu_usage_percent of every cgroup.
 *
 * For each cgroup on gb->gb_list this takes a fresh sample with
 * cgroup_cpuusage_sampling() and divides the usage growth since the
 * previous sample by the number of jiffies elapsed in between.
 *
 * Cgroups without private data are skipped. A cgroup whose previous sample
 * was taken within the same jiffy gets a usage of 0, which
 * migrate_cgroup() treats as "do not migrate".
 *
 * NOTE(review): the usage delta comes from get_cpuacct_cpuusage() while the
 * divisor is in jiffies; confirm the units yield the intended percentage.
 */
static void update_cgroup_cpuusage(void)
{
	struct cgroup *cg;

	list_for_each_entry(cg, gb->gb_list, gb_list) {
		struct group_balancer_private *gb_priv = cg->gb_private;
		unsigned long period_in_jiffies;
		u64 cpuacct_cpuusage_curr, cpuacct_cpuusage_last;

		if (!gb_priv)
			continue;

		/* Read the old stamp and sample before they are overwritten. */
		period_in_jiffies = jiffies - gb_priv->jstamp;
		cpuacct_cpuusage_last = gb_priv->cpuacct_cpuusage_last;
		cpuacct_cpuusage_curr = cgroup_cpuusage_sampling(cg);

		if (!period_in_jiffies) {
			/* Avoid dividing by zero. */
			gb_priv->cpu_usage_percent = 0;
			continue;
		}

		gb_priv->cpu_usage_percent =
			(cpuacct_cpuusage_curr - cpuacct_cpuusage_last) * 100 /
			period_in_jiffies;
	}
}

/*
 * migrate_cgroup - move cgroups from the busiest to the idlest partition.
 * @target: amount of usage that should be moved.
 *
 * Walks the cgroups placed on gb->max_part and migrates every cgroup whose
 * cpu_usage_percent is non-zero and smaller than the remaining @target to
 * gb->min_part: the cgroup is re-linked, its soft cpu mask is recomputed
 * from the destination partition, both partitions are re-sorted in the
 * level list and @target is reduced by the moved usage.
 */
static void migrate_cgroup(unsigned int target)
{
	struct partition *max_part = gb->max_part;
	struct partition *min_part = gb->min_part;
	struct group_balancer_private *gb_priv, *n;

	/* Nothing to do without two distinct partitions. */
	if (!max_part || !min_part || max_part == min_part)
		return;

	list_for_each_entry_safe(gb_priv, n, &max_part->cgroup_list, part_cgroup_list) {
		struct cgroup *cg = gb_priv->cg;
		struct cpumask cpus_allowed;

		if (gb_priv->cpu_usage_percent == 0 ||
		    gb_priv->cpu_usage_percent >= target)
			continue;

		/*
		 * Leave the cgroup where it is when its allowed cpus do not
		 * overlap the destination; moving it would produce an empty
		 * soft cpu mask.
		 */
		get_cpus_allowed(&cpus_allowed, cg);
		if (!cpumask_empty(&cpus_allowed) &&
		    !cpumask_intersects(&min_part->part_cpus, &cpus_allowed))
			continue;

		list_del_init(&gb_priv->part_cgroup_list);
		max_part->nr_cgroups -= 1;
		bubble_sort_level_list_down(max_part);

		if (cpumask_empty(&cpus_allowed))
			cpumask_copy(&gb_priv->soft_cpus_allowed,
				     &min_part->part_cpus);
		else
			cpumask_and(&gb_priv->soft_cpus_allowed,
				    &min_part->part_cpus,
				    &cpus_allowed);
		/* call soft bind interface */
		set_soft_cpus_allowed(&gb_priv->soft_cpus_allowed, cg);
		list_add_tail(&gb_priv->part_cgroup_list,
			      &min_part->cgroup_list);
		/*
		 * Record the new owner; the exit handler uses gb_priv->part
		 * to decrement the right partition's nr_cgroups.
		 */
		gb_priv->part = min_part;
		min_part->nr_cgroups += 1;
		bubble_sort_level_list_up(min_part);

		target -= gb_priv->cpu_usage_percent;
	}
}

static void group_balance(void)
{
	unsigned int target;
	struct partition *max_part, *min_part;

	max_part = gb->max_part;
	min_part = gb->min_part;

	target = (gb->max_usage_percent * max_part->nr_part_cpus -
		  gb->min_usage_percent * min_part->nr_part_cpus) / 2;
	migrate_cgroup(target);
}

static void group_balancer_workfn(struct work_struct *work)
{
	struct cgroup *cg;

	update_partition_info();
	if (gb->max_usage_percent * 100 >
	    gb->avg_usage_percent * relative_unbalance_watermark &&
	    gb->max_usage_percent > gb->avg_usage_percent +
	    absolute_unbalance_watermark) {
		lock_gb_list();
		list_for_each_entry(cg, gb->gb_list, gb_list) {
			cgroup_cpuusage_sampling(cg);
		}

		msleep(sampling_period);
		update_cgroup_cpuusage();
		group_balance();
		unlock_gb_list();

	}

	queue_delayed_work(group_balancer_wq, &gb->scan_dwork, msecs_to_jiffies(scan_period));
}

/*
 * group_balancer_init_handler - attach a cgroup to the balancer.
 * @cg: cgroup being attached.
 *
 * Allocates the cgroup's private data when it has none, places the cgroup
 * on the first partition of the current level list (the list is kept in
 * ascending nr_cgroups order), soft-binds the cgroup to that partition's
 * cpus (restricted to the cgroup's own allowed cpus when those are set) and
 * re-sorts the level list.
 *
 * The handler cannot report errors; on allocation failure or when the
 * current level has no partition it logs and leaves the cgroup unmanaged.
 */
static void group_balancer_init_handler(struct cgroup *cg)
{
	struct group_balancer_private *gb_priv = cg->gb_private;
	struct partition *part;
	struct cpumask cpus_allowed;
	struct list_head *current_level_head;

	current_level_head = gb->levels + gb->current_level;
	/* list_first_entry() must not be used on an empty list. */
	if (list_empty(current_level_head)) {
		pr_err("group balancer: no partition on the current level\n");
		return;
	}

	if (!gb_priv) {
		gb_priv = kzalloc(sizeof(*gb_priv), GFP_KERNEL);
		if (!gb_priv) {
			pr_err("group balancer: alloc cgroup private data failed\n");
			return;
		}
	}

	INIT_LIST_HEAD(&gb_priv->part_cgroup_list);
	part = list_first_entry(current_level_head, struct partition, level_siblings);
	gb_priv->part = part;

	get_cpus_allowed(&cpus_allowed, cg);

	if (cpumask_empty(&cpus_allowed))
		cpumask_copy(&gb_priv->soft_cpus_allowed, &part->part_cpus);
	else
		cpumask_and(&gb_priv->soft_cpus_allowed, &part->part_cpus,
			    &cpus_allowed);
	set_soft_cpus_allowed(&gb_priv->soft_cpus_allowed, cg);
	list_add_tail(&gb_priv->part_cgroup_list, &part->cgroup_list);
	part->nr_cgroups += 1;
	bubble_sort_level_list_up(part);
	cg->gb_private = gb_priv;
	gb_priv->cg = cg;
}

static void group_balancer_exit_handler(struct cgroup *cg)
{
	struct group_balancer_private *gb_priv = cg->gb_private;
	struct partition *part;
	struct cpumask tmp_cpumask;
	struct cpumask cpus_allowed;

	if (!gb_priv)
		return;

	part = gb_priv->part;
	get_cpus_allowed(&cpus_allowed, cg);
	if (cpumask_empty(&cpus_allowed)) {
		cpumask_copy(&tmp_cpumask, cpu_online_mask);
		set_soft_cpus_allowed(&tmp_cpumask, cg);
	} else
		set_soft_cpus_allowed(&cpus_allowed, cg);
	list_del_init(&gb_priv->part_cgroup_list);
	part->nr_cgroups -= 1;
	bubble_sort_level_list_down(part);
	kfree(gb_priv);
	cg->gb_private = NULL;
}

/*
 * Callbacks handed to register_group_balancer_callback() at module init and
 * to unregister_group_balancer_callback() at module exit.
 */
struct group_balancer_callback cb = {
	.init = group_balancer_init_handler,
	.exit = group_balancer_exit_handler
};


/*
 * isspaceline - test whether a buffer holds only whitespace.
 * @buf:    start of the buffer.
 * @length: number of characters to examine.
 *
 * Return: true when the first @length characters of @buf are all
 * whitespace (also when @length is 0), false otherwise.
 */
static bool isspaceline(char *buf, unsigned int length)
{
	char *end = buf + length;

	while (buf < end) {
		if (!isspace(*buf))
			return false;
		buf++;
	}

	return true;
}

/*
 * check_nr_parts_title - test whether a line starts with "[nr_parts]".
 * @buf: NUL-terminated line.
 *
 * Return: 0 when @buf starts with the title, otherwise the non-zero
 * strncmp() result.
 */
static int check_nr_parts_title(char *buf)
{
	static const char title[] = "[nr_parts]";

	return strncmp(buf, title, sizeof(title) - 1);
}

/*
 * check_nr_levels_title - test whether a line starts with "[nr_levels]".
 * @buf: NUL-terminated line.
 *
 * Return: 0 when @buf starts with the title, otherwise the non-zero
 * strncmp() result.
 */
static int check_nr_levels_title(char *buf)
{
	static const char title[] = "[nr_levels]";

	return strncmp(buf, title, sizeof(title) - 1);
}

/*
 * parse_level_line - parse a "[level N]" line of the configuration file.
 * @buf:         NUL-terminated line.
 * @line_length: length of the line including the terminating NUL.
 * @level:       receives the parsed level on success.
 *
 * The number between "[level" and the closing ']' is parsed with
 * kstrtouint() and must be smaller than gb->nr_levels. The number is copied
 * into a small stack buffer first, so over-long numbers are rejected
 * instead of overflowing that buffer.
 *
 * Return: 0 on success, -EINVAL or the kstrtouint() error otherwise.
 */
static int parse_level_line(char *buf, unsigned int line_length, unsigned int *level)
{
	static const char string_level_title[] = "[level";
	char string_level[5];
	unsigned int start_pos = sizeof(string_level_title) - 1;
	unsigned int end_pos, length;
	int ret;

	if (strncmp(buf, string_level_title, sizeof(string_level_title) - 1)) {
		pr_err("level title wrong\n");
		return -EINVAL;
	}

	/* Skip the blanks between the title and the number. */
	while (start_pos < line_length && isspace(buf[start_pos]))
		start_pos++;

	/* The number ends at the closing bracket (or at the end of line). */
	end_pos = start_pos;
	while (end_pos < line_length && buf[end_pos] != ']')
		end_pos++;

	length = end_pos - start_pos;
	if (length >= sizeof(string_level)) {
		pr_err("level wrong\n");
		return -EINVAL;
	}
	memcpy(string_level, buf + start_pos, length);
	string_level[length] = '\0';

	ret = kstrtouint(string_level, 0, level);
	if (ret) {
		pr_err("level wrong\n");
		return ret;
	}

	if (*level >= gb->nr_levels) {
		pr_err("level shouldn't be bigger than nr_levels\n");
		return -EINVAL;
	}

	return 0;
}


static int parse_part_line(char *buf, unsigned int line_length, unsigned int level)
{
	int ret = 0;
	char *start_buf;
	loff_t start_pos, end_pos;
	char string_current_id[5];
	char string_parent_id[5];
	char string_cpus_allowed[256];
	unsigned int current_id;
	int parent_id;
	struct cpumask cpus_allowed;
	unsigned int stage = CURRENT_ID;
	unsigned int length;
	struct partition *current_part, *parent_part;

	start_pos = end_pos = 0;
	current_part = parent_part = NULL;
	for (end_pos = 0; end_pos < line_length; end_pos++) {
		if (isspace(buf[end_pos]) || buf[end_pos] == '\0') {
			start_buf = buf + start_pos;
			length = end_pos - start_pos;
			if (length < 0) {
				ret = -EINVAL;
				pr_err("Internal error: line length < 0\n");
				goto out;
			}
			switch (stage) {
			case CURRENT_ID:
				strncpy(string_current_id, start_buf, length);
				string_current_id[length] = '\0';
				ret = kstrtouint(string_current_id, 0, &current_id);
				if (ret) {
					pr_err("current id wrong\n");
					goto out;
				}
				if (current_id < 0 || current_id >= gb->nr_parts) {
					pr_err("current id out of range\n");
					ret = -EINVAL;
					goto out;
				}
				current_part = gb->parts + current_id;
				current_part->pi = kzalloc(sizeof(struct partition_info),
							   GFP_KERNEL);
				INIT_LIST_HEAD(&current_part->children);
				INIT_LIST_HEAD(&current_part->cgroup_list);
				current_part->level = level;
				current_part->id = current_id;
				list_add(&current_part->level_siblings, gb->levels + level);
				memset(string_current_id, 0, length);
				stage = PARENT_ID;
				break;
			case PARENT_ID:
				strncpy(string_parent_id, start_buf, length);
				string_parent_id[length] = '\0';
				ret = kstrtoint(string_parent_id, 0, &parent_id);
				if (ret) {
					pr_err("parent id wrong\n");
					goto out;
				}
				if (parent_id == -1)
					goto break_parent_id;
				if (parent_id < -1 || parent_id >= gb->nr_parts) {
					pr_err("parent id out of range\n");
					ret = -EINVAL;
					goto out;
				}
				parent_part = gb->parts + parent_id;
				if (!current_part) {
					pr_err("internal error: current_part not found\n");
					ret = -EINVAL;
					goto out;
				}
				current_part->parent = parent_part;
				list_add(&current_part->siblings, &parent_part->children);
break_parent_id:
				memset(string_parent_id, 0, length);
				stage = CPUMASK;
				break;
			case CPUMASK:
				strncpy(string_cpus_allowed, start_buf, length);
				string_cpus_allowed[length] = '\0';
				ret = cpulist_parse(string_cpus_allowed, &cpus_allowed);
				if (ret) {
					pr_err("cpumaks wrong\n");
					goto out;
				}
				if (cpumask_empty(&cpus_allowed)) {
					pr_err("the cpu range of part %d shouldn't be empty",
					       current_id);
					ret = -EINVAL;
					goto out;
				}
				if (!cpumask_subset(&cpus_allowed, cpu_online_mask)) {
					pr_err("the cpu range of part %d out of range",
					       current_id);
					ret = -EINVAL;
					goto out;
				}
				cpumask_copy(&current_part->part_cpus, &cpus_allowed);
				current_part->nr_part_cpus =
					cpumask_weight(&current_part->part_cpus);
				if (parent_part &&
				    !cpumask_subset(&cpus_allowed, &parent_part->part_cpus)) {
					pr_err("The cpumask of part %d isn't subset of that of part %d\n",
						current_id, parent_id);
					ret = -EINVAL;
					goto out;
				}
				memset(string_cpus_allowed, 0, length);
				stage = PART_COMPLETE;
				break;
			default:
				break;
			}
			start_pos = end_pos + 1;
			while (start_pos < line_length && isspace(buf[start_pos]))
				start_pos++;
			end_pos = start_pos - 1;
		}
	}

out:
	return ret;
}


static int part_from_file(void)
{
	int ret = 0, i;
	struct file *fp;
	int file_size;
	char *buf, *start_buf;
	unsigned int length;
	char line[512];
	unsigned int level;
	unsigned int stage = NR_TOTAL_PARTS_TITLE;
	loff_t start_pos, end_pos;
	struct partition *part;

	fp = filp_open("/etc/group_balancer.cfg", O_RDONLY, 0644);
	if (IS_ERR(fp)) {
		ret = PTR_ERR(fp);
		pr_err("/etc/group_balancer.cfg open failed,err = %d\n", ret);
		return ret;
	}

	file_size = fp->f_inode->i_size;
	buf = kzalloc(file_size * sizeof(char) + 1, GFP_KERNEL);
	if (!buf) {
		pr_err("Not enough memory for file buf\n");
		ret = -ENOMEM;
		goto out;
	}
	start_pos = 0;
	end_pos = 0;
	kernel_read(fp, buf, file_size + 1, &start_pos);

	start_pos = 0;
	while (isspace(buf[start_pos]))
		start_pos++;
	for (end_pos = start_pos; end_pos < file_size; end_pos++) {
		if (buf[end_pos] == '\n') {
			start_buf = buf + start_pos;
			length = end_pos - start_pos;
			if (length == 0)
				goto next_line;
			else if (length < 0) {
				pr_err("Internal error: line length < 0\n");
				ret = -EINVAL;
				goto out;
			}
			strncpy(line, start_buf, length);
			if (isspaceline(line, length))
				goto next_line;
			line[length] = '\0';
			length++;

			switch (stage) {
			case NR_TOTAL_PARTS_TITLE:
				ret = check_nr_parts_title(line);
				if (ret) {
					pr_err("nr_parts title wrong\n");
					goto out;
				}
				stage = NR_TOTAL_PARTS;
				break;
			case NR_TOTAL_PARTS:
				ret = kstrtouint(line, 0, &gb->nr_parts);
				if (ret) {
					pr_err("nr_parts wrong\n");
					goto out;
				}
				gb->parts = kzalloc(sizeof(struct partition) * gb->nr_parts,
						    GFP_KERNEL);
				stage = NR_TOTAL_LEVELS_TITLE;
				break;
			case NR_TOTAL_LEVELS_TITLE:
				ret = check_nr_levels_title(line);
				if (ret) {
					pr_err("nr_levels title wrong\n");
					goto free_parts;
				}
				stage = NR_TOTAL_LEVELS;
				break;
			case NR_TOTAL_LEVELS:
				ret = kstrtouint(line, 0, &gb->nr_levels);
				if (ret) {
					pr_err("nr_levels wrong\n");
					goto free_parts;
				}
				if (gb->nr_levels == 0) {
					pr_err("nr_levels should be bigger than zero\n");
					ret = -EINVAL;
					goto free_parts;
				}
				gb->levels = kzalloc(sizeof(struct list_head) * gb->nr_levels,
						     GFP_KERNEL);
				for (i = 0; i < gb->nr_levels; i++)
					INIT_LIST_HEAD(gb->levels + i);
				stage = CONFIG;
				break;
			case CONFIG:
				if (line[0] == '[')
					ret = parse_level_line(line, length, &level);
				else
					ret = parse_part_line(line, length, level);
				if (ret)
					goto free_levels;
				break;
			default:
				ret = -EINVAL;
				goto free_levels;
			}
			memset(line, 0, length);
next_line:
			start_pos = end_pos + 1;
			while (start_pos < file_size && isspace(buf[start_pos]))
				start_pos++;
			end_pos = start_pos - 1;
		}
	}

	goto out;

free_levels:
	kfree(gb->levels);
	gb->levels = NULL;
free_parts:
	for (i = 0; i < gb->nr_parts; i++) {
		part = gb->parts + i;
		kfree(part->pi);
		part->pi = NULL;
	}
	kfree(gb->parts);
	gb->parts = NULL;
out:
	kfree(buf);
	filp_close(fp, NULL);

	return ret;
}

/*
 * part_by_numa - build the default two-level topology from the NUMA layout.
 *
 * Level 0 holds a single partition spanning all online cpus. Level 1 holds
 * one partition per online node that has at least one cpu; nodes without
 * cpus are skipped because an empty partition can neither run a cgroup nor
 * produce meaningful usage figures.
 *
 * On failure everything allocated here is released again and the counters
 * are reset to zero.
 *
 * Return: 0 on success, -ENOMEM when an allocation fails, -ENODEV when no
 * online node has any cpu.
 */
static int part_by_numa(void)
{
	unsigned int nr_nodes = num_online_nodes();
	unsigned int i;
	int nid;
	int ret;
	struct partition *part;

	gb->parts = kcalloc(nr_nodes + 1, sizeof(*gb->parts), GFP_KERNEL);
	if (!gb->parts) {
		pr_err("Alloc struct partition failed\n");
		return -ENOMEM;
	}
	gb->nr_parts = nr_nodes + 1;

	gb->levels = kcalloc(2, sizeof(*gb->levels), GFP_KERNEL);
	if (!gb->levels) {
		pr_err("Alloc level lists failed\n");
		ret = -ENOMEM;
		goto free_all;
	}
	INIT_LIST_HEAD(gb->levels);
	INIT_LIST_HEAD(gb->levels + 1);
	gb->nr_levels = 2;

	/* level 0: all online cpus */
	part = gb->parts;
	cpumask_copy(&part->part_cpus, cpu_online_mask);
	part->nr_part_cpus = num_online_cpus();
	INIT_LIST_HEAD(&part->children);
	INIT_LIST_HEAD(&part->cgroup_list);
	part->pi = kzalloc(sizeof(*part->pi), GFP_KERNEL);
	if (!part->pi) {
		pr_err("Alloc struct partition_info failed\n");
		ret = -ENOMEM;
		goto free_all;
	}
	part->id = 0;
	part->parent = NULL;
	part->level = 0;
	list_add_tail(&part->level_siblings, gb->levels);

	/* level 1: one partition per online numa node that has cpus */
	for_each_online_node(nid) {
		if (cpumask_empty(cpumask_of_node(nid)))
			continue;

		part++;
		cpumask_copy(&part->part_cpus, cpumask_of_node(nid));
		part->nr_part_cpus = cpumask_weight(&part->part_cpus);
		INIT_LIST_HEAD(&part->children);
		INIT_LIST_HEAD(&part->cgroup_list);
		part->pi = kzalloc(sizeof(*part->pi), GFP_KERNEL);
		if (!part->pi) {
			pr_err("Alloc partition info failed\n");
			ret = -ENOMEM;
			goto free_all;
		}
		part->id = nid + 1;
		part->parent = gb->parts;
		part->level = 1;
		list_add_tail(&part->siblings, &gb->parts->children);
		list_add_tail(&part->level_siblings, gb->levels + 1);
	}

	if (list_empty(gb->levels + 1)) {
		pr_err("No online numa node with cpus found\n");
		ret = -ENODEV;
		goto free_all;
	}

	/* Only count the partitions that were actually set up. */
	gb->nr_parts = part - gb->parts + 1;

	return 0;

free_all:
	kfree(gb->levels);
	gb->levels = NULL;
	gb->nr_levels = 0;
	/* Unused array slots are zeroed, so their pi is NULL. */
	for (i = 0; i < gb->nr_parts; i++)
		kfree(gb->parts[i].pi);
	kfree(gb->parts);
	gb->parts = NULL;
	gb->nr_parts = 0;
	return ret;
}

static int __init group_balancer_init(void)
{
	int ret;

	gb = kzalloc(sizeof(struct group_balancer), GFP_KERNEL);

	group_balancer_wq = create_workqueue("group_balancer");
	if (!group_balancer_wq) {
		pr_err("Create workqueue failed\n");
		ret = -ENOMEM;
		goto free_gb;
	}

	ret = part_from_file();
	if (ret) {
		pr_err("part from /etc/group_balancer.cfg failed, now part by numa node\n");
		ret = part_by_numa();
		if (ret)
			goto destroy_wq;
	}
	gb->current_level = gb->nr_levels - 1;

	if (register_group_balancer_callback(&cb, &gb->gb_list) != 0) {
		pr_err("Register group balancer callback failed\n");
		ret = -EINVAL;
		goto destroy_wq;
	}

	INIT_DELAYED_WORK(&gb->scan_dwork, group_balancer_workfn);
	queue_delayed_work(group_balancer_wq, &gb->scan_dwork, msecs_to_jiffies(scan_period));
	pr_notice(KBUILD_MODNAME ": Initialization Done\n");
	return 0;

destroy_wq:
	destroy_workqueue(group_balancer_wq);
free_gb:
	kfree(gb);
	gb = NULL;
	return ret;
}

/*
 * group_balancer_exit - module exit point.
 *
 * Users are stopped first and memory is released last: the self re-queueing
 * scan work is cancelled, the cgroup callbacks are unregistered and the
 * workqueue is destroyed before the partitions and level lists those users
 * work on are freed.
 */
static void __exit group_balancer_exit(void)
{
	unsigned int i;

	cancel_delayed_work_sync(&gb->scan_dwork);
	unregister_group_balancer_callback(&cb);
	destroy_workqueue(group_balancer_wq);

	if (gb->parts) {
		for (i = 0; i < gb->nr_parts; i++) {
			kfree(gb->parts[i].pi);
			gb->parts[i].pi = NULL;
		}
	}
	kfree(gb->parts);
	gb->parts = NULL;
	kfree(gb->levels);
	gb->levels = NULL;

	kfree(gb);
	gb = NULL;
	pr_notice(KBUILD_MODNAME ": Exit\n");
}

/* Module entry/exit points and metadata. */
module_init(group_balancer_init);
module_exit(group_balancer_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Cruz Zhao <cruzzhao@linux.alibaba.com>");
