// SPDX-License-Identifier: GPL-2.0
/*
 * vsock service for device management.
 *
 * Copyright (C) 2019 Alibaba, Inc
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version
 * 2 of the License, or (at your option) any later version.
 *
 */

#define pr_fmt(fmt) "db-dev-mgr: " fmt

#include <linux/kthread.h>
#include <linux/module.h>
#include <linux/platform_device.h>
#include <linux/slab.h>
#include <linux/virtio_mmio.h>
#include <linux/cpu.h>
#include <linux/cpumask.h>
#include <linux/cpuhotplug.h>
#include <asm/cpu.h>
#include <dragonball/vsock_srv.h>
#include <dragonball/device_manager.h>
#include <dragonball/dragonball.h>
#include <linux/percpu.h>
#include <linux/device.h>
#include <asm/numa.h>

/*
 * The following design choices are adopted to simplify the implementation:
 * 1) fixed-size messages with padding to ease the receiving logic.
 * 2) binary encoding instead of string encoding because both ends are on the
 *    same host.
 * 3) synchronous communication in ping-pong mode, at most one in-flight
 *    request.
 * 4) module unloading is not supported.
 */

/* These definitions are synchronized with dragonball */
/* Fixed size in bytes of every request/reply buffer sent or received here. */
#define DEVMGR_MSG_SIZE			0x400
/* Command byte under which this service registers with the vsock server. */
#define DEVMGR_CMD_BYTE			'd'
#define DEVMGR_MAGIC_VERSION		0x444D0100 /* 'DM' + Version 1.0 */
/* Base value from which the magic I/O port numbers below are derived. */
#define MAGIC_IOPORT_BASE		0xdbdb
/* I/O port written by the x86 CPU hotplug notification path. */
#define MAGIC_IOPORT_CPU_HOTPLUG	(MAGIC_IOPORT_BASE + 1)

/*
 * Type of request and reply messages.
 *
 * The ADD and DEL values are request types carried in msg_header.msg_type and
 * matched against the dispatch table; DEVMGR_CMD_OK and DEVMGR_CMD_ERR are the
 * reply types written back by this service. Only the CPU and MMIO requests
 * have handlers in this file; the MEM and PCI values are defined but not
 * dispatched here.
 */
enum devmgr_msg_type {
	DEVMGR_ADD_CPU			= 0xdb000000,
	DEVMGR_DEL_CPU			= 0xdb100000,
	DEVMGR_ADD_MEM			= 0xdb200000,
	DEVMGR_DEL_MEM			= 0xdb300000,
	DEVMGR_ADD_MMIO			= 0xdb400000,
	DEVMGR_DEL_MMIO			= 0xdb500000,
	DEVMGR_ADD_PCI			= 0xdb600000,
	DEVMGR_DEL_PCI			= 0xdb700000,

	/* Reply types. */
	DEVMGR_CMD_OK			= 0xdbe00000,
	DEVMGR_CMD_ERR			= 0xdbf00000,
};

/* Common header placed at the start of every request and reply message. */
struct devmgr_msg_header {
	/* Must equal DEVMGR_MAGIC_VERSION; requests with another value are rejected. */
	uint32_t	magic_version;
	/*
	 * Size of the meaningful payload. MMIO requests must set it to the size
	 * of the MMIO payload; replies use 0 or sizeof(struct cpu_dev_reply_info).
	 */
	uint32_t	msg_size;
	/* One of enum devmgr_msg_type. */
	uint32_t	msg_type;
	/* CPU replies store the originating request type here; otherwise 0. */
	uint32_t	msg_flags;
};

/*
 * Request message: header followed by a per-command payload. The pad member
 * makes the whole structure DEVMGR_MSG_SIZE bytes long.
 */
struct devmgr_req {
	struct devmgr_msg_header msg_header;
	union {
		char	pad[DEVMGR_MSG_SIZE - sizeof(struct devmgr_msg_header)];
#if defined(CONFIG_DRAGONBALL_HOTPLUG_VIRTIO_MMIO)
		/* Payload used by both the MMIO add and MMIO del requests. */
		struct {
			/* Start address of the device's MMIO window. */
			uint64_t mmio_base;
			/* Length of the MMIO window. */
			uint64_t mmio_size;
			/* IRQ number as given by the VMM; translated to a virq before use. */
			uint32_t mmio_irq;
		} add_mmio_dev;
#endif
#if defined(CONFIG_DRAGONBALL_HOTPLUG_CPU)
		/* Payload used by the CPU add and CPU del requests. */
		struct {
			/* Number of vCPUs to add or remove. */
			uint8_t count;
#ifdef CONFIG_X86_64
			/* APIC version handed to generic_processor_info() on add. */
			uint8_t apic_ver;
			/* APIC ids of the vCPUs; the first 'count' entries are used. */
			uint8_t apic_ids[256];
#endif
		} cpu_dev_info;
#endif
	} msg_load;
};

/*
 * Payload of a CPU hotplug reply. It carries the loop index at which the
 * add/del handler stopped (equal to the requested count on success).
 */
struct cpu_dev_reply_info {
#if defined(CONFIG_X86_64)
	/* Index into the request's apic_ids array. */
	uint32_t apic_index;
#elif defined(CONFIG_ARM64)
	/* Number of vCPUs processed before the handler stopped. */
	uint32_t cpu_id;
#endif
};

/*
 * Reply message: header followed by a per-command payload. The pad member
 * makes the whole structure DEVMGR_MSG_SIZE bytes long.
 */
struct devmgr_reply {
	struct devmgr_msg_header msg_header;
	union {
		char	pad[DEVMGR_MSG_SIZE - sizeof(struct devmgr_msg_header)];
#if defined(CONFIG_DRAGONBALL_HOTPLUG_VIRTIO_MMIO)
		/* MMIO replies carry no payload. */
		struct {
		} add_mmio_dev;
#endif
#if defined(CONFIG_DRAGONBALL_HOTPLUG_CPU)
	/* Filled in by cpu_event_notification(). */
	struct cpu_dev_reply_info cpu_dev_info;
#endif
	} msg_load;
};

/*
 * Per-connection state. Allocated by db_devmgr_handler() and freed by the
 * server thread (db_devmgr_server()) when the connection ends.
 */
struct task_res {
	/* Kernel thread serving this connection. */
	struct task_struct	*task;
	/* Connected socket; released by the server thread. */
	struct socket		*sock;
	/* Receive buffer for the current request. */
	struct devmgr_req	req;
	/* Send buffer for the current reply. */
	struct devmgr_reply	reply;
};

/*
 * Handler for one request type. It receives the request and the reply buffer
 * to fill, and returns 0 on success or a non-zero error, in which case the
 * dispatcher overwrites the reply header with DEVMGR_CMD_ERR.
 */
typedef int (*action_route_t) (struct devmgr_req *req,
			       struct devmgr_reply *rep);

#if defined(CONFIG_DRAGONBALL_HOTPLUG_VIRTIO_MMIO)
/*
 * Translate the MMIO payload of a request into the two resources that describe
 * a virtio-mmio device.
 *
 * @req: request whose msg_load.add_mmio_dev payload is decoded.
 * @res: array of at least two resources; on success res[0] is the memory
 *       window and res[1] is the interrupt.
 *
 * Returns 0 on success, or -EINVAL when the header's msg_size does not match
 * the MMIO payload, the window is empty or wraps around the address space, or
 * the IRQ cannot be translated to a virq.
 */
static int get_dev_resource(struct devmgr_req *req, struct resource *res)
{
	uint64_t base, size;
	uint32_t irq, virq;

	/* Validate the declared payload size before trusting the payload. */
	if (req->msg_header.msg_size != sizeof(req->msg_load.add_mmio_dev))
		return -EINVAL;

	base = req->msg_load.add_mmio_dev.mmio_base;
	size = req->msg_load.add_mmio_dev.mmio_size;
	irq  = req->msg_load.add_mmio_dev.mmio_irq;

	/*
	 * An empty window, or one whose last byte wraps past the end of the
	 * address space, would yield a resource with end < start.
	 */
	if (size == 0 || base + size - 1 < base)
		return -EINVAL;

	virq = get_device_virq(irq);
	if (!virq)
		return -EINVAL;

	res[0].flags = IORESOURCE_MEM;
	res[0].start = base;
	res[0].end   = base + size - 1;
	res[1].flags = IORESOURCE_IRQ;
	res[1].start = res[1].end = virq;

	/*
	 * Detect the irq sharing mode:
	 *
	 * If many devices are inserted into dragonball, the number of legacy
	 * irqs requested by these devices may exceed the number supported by
	 * the platform, so some devices have to share the same irq.
	 *
	 * Therefore, when deciding whether the parameters of a hot-plugged
	 * virtio device are legal, irq sharing must be taken into account: if
	 * irq sharing is enabled, multiple devices with the same irq are legal.
	 */
	if (irq == SHARED_IRQ_NO)
		res[1].flags |= IORESOURCE_IRQ_SHAREABLE;

	return 0;
}
#endif

static void _fill_msg_header(struct devmgr_msg_header *msg, uint32_t msg_size,
			     uint32_t msg_type, uint32_t msg_flags)
{
	msg->magic_version = DEVMGR_MAGIC_VERSION;
	msg->msg_size      = msg_size;
	msg->msg_type      = msg_type;
	msg->msg_flags     = msg_flags;
}

/*
 * TODO: At present, cpu hotplug has not been fully ported, so there is no
 * documentation for this feature. After the complete porting, add
 * documentation and delete this comment.
 */
#if defined(CONFIG_DRAGONBALL_HOTPLUG_CPU)
#if defined(CONFIG_X86_64)
static int get_cpu_id(int apic_id)
{
	int i;

	for (i = 0; i < num_processors; i++) {
		if (cpu_physical_id(i) == apic_id)
			return i;
	}
	return -1;
}

/**
 * Report the outcome of a CPU hotplug request to dragonball.
 *
 * The reply payload carries the index into apic_ids at which processing
 * stopped. If it is not equal to the count of vcpus requested, dragonball
 * rolls back the vcpus from apic_ids[0] to apic_ids[i-1].
 *
 * The same index is also written to the magic CPU hotplug I/O port, with
 * bit 8 set on success and clear on failure.
 */
static void cpu_event_notification(
	uint8_t apic_ids_index,
	int ret,
	uint32_t action_type,
	struct devmgr_reply *rep)
{
	const int succeeded = (ret == 0);

	pr_info("cpu event notification: apic ids index %d\n", apic_ids_index);
	rep->msg_load.cpu_dev_info.apic_index = apic_ids_index;

	/* The request type is echoed back in the header's flags field. */
	_fill_msg_header(&rep->msg_header, sizeof(struct cpu_dev_reply_info),
			 succeeded ? DEVMGR_CMD_OK : DEVMGR_CMD_ERR,
			 action_type);

	outw((succeeded ? 1 : 0) << 8 | apic_ids_index,
	     MAGIC_IOPORT_CPU_HOTPLUG);
}
#elif defined(CONFIG_ARM64)
/**
 * Report the outcome of a CPU hotplug request to dragonball.
 *
 * The reply payload carries the loop index at which processing stopped; when
 * hotplugging/hotunplugging succeeds it equals the expected cpu count.
 */
static void cpu_event_notification(
	uint8_t cpu_id,
	int ret,
	uint32_t action_type,
	struct devmgr_reply *rep)
{
	pr_info("cpu event notification: cpu_id %d\n", cpu_id);
	rep->msg_load.cpu_dev_info.cpu_id = cpu_id;

	/* The request type is echoed back in the header's flags field. */
	_fill_msg_header(&rep->msg_header, sizeof(struct cpu_dev_reply_info),
			 ret == 0 ? DEVMGR_CMD_OK : DEVMGR_CMD_ERR,
			 action_type);
}
#endif
#endif

/*
 * Virtio-mmio hotplug support.
 *
 * The VMM adds a device by:
 *  1. Connecting to the vsock server port (default 0xDB)
 *  2. Sending character 'd' to switch to system command mode
 *  3. Sending an add request with msg_header and virtio-mmio device info
 *
 * The VMM removes a device by:
 *  1. Connecting to the vsock server port (default 0xDB) if needed
 *  2. Sending character 'd' to switch to system command mode if needed
 *  3. Sending a del request with msg_header and virtio-mmio device info
 */
#if defined(CONFIG_DRAGONBALL_HOTPLUG_VIRTIO_MMIO)
/*
 * Handle DEVMGR_ADD_MMIO: decode the device resources from the request and
 * register the virtio-mmio device. The reply header is set to DEVMGR_CMD_OK
 * only on success; on failure the error is returned and the reply is left
 * for the caller to fill.
 */
static int add_mmio_dev(struct devmgr_req *req,
			struct devmgr_reply *rep)
{
	struct resource mmio_res[2] = {};
	int err;

	err = get_dev_resource(req, mmio_res);
	if (err)
		return err;

	err = virtio_mmio_add_device(mmio_res, ARRAY_SIZE(mmio_res));
	if (err)
		return err;

	_fill_msg_header(&rep->msg_header, 0, DEVMGR_CMD_OK, 0);
	return 0;
}

/*
 * Handle DEVMGR_DEL_MMIO: decode the device resources from the request and
 * remove the matching virtio-mmio device. The reply header is set to
 * DEVMGR_CMD_OK only on success; on failure the error is returned and the
 * reply is left for the caller to fill.
 */
static int del_mmio_dev(struct devmgr_req *req,
			struct devmgr_reply *rep)
{
	struct resource mmio_res[2] = {};
	int err;

	err = get_dev_resource(req, mmio_res);
	if (err)
		return err;

	err = virtio_mmio_del_device(mmio_res, ARRAY_SIZE(mmio_res));
	if (err)
		return err;

	_fill_msg_header(&rep->msg_header, 0, DEVMGR_CMD_OK, 0);
	return 0;
}
#endif


#if defined(CONFIG_DRAGONBALL_HOTPLUG_CPU)
#if defined(CONFIG_X86_64)
/*
 * Hot-add one vCPU identified by @apic_id and bring it online.
 *
 * The CPU is first made known to the kernel (generic_processor_info() and
 * arch_register_cpu()) under the hotplug locks, then the locks are dropped
 * and the CPU device is brought up. Any failure undoes the steps already
 * taken.
 *
 * @apic_id:  APIC id of the vCPU to add.
 * @apic_ver: APIC version passed to generic_processor_info().
 *
 * Returns 0 on success or a negative errno on failure.
 */
static int db_add_cpu(int apic_id, uint8_t apic_ver)
{
	int cpu_id, node_id;
	int ret;
	struct device *cpu_dev;

	pr_info("adding vcpu apic_id %d\n", apic_id);

	/**
	 * Get the mutex lock for hotplug and cpu update and cpu write lock.
	 * So that other threads won't influence the hotplug process.
	 */
	lock_device_hotplug();
	cpu_maps_update_begin();
	cpu_hotplug_begin();

	cpu_id = generic_processor_info(apic_id, apic_ver);
	if (cpu_id < 0) {
		pr_err("cpu (apic id %d) can't be added, generic processor info failed\n", apic_id);
		ret = -EINVAL;
		goto rollback_generic_cpu;
	}

	/* update numa mapping for hot-plugged cpus. */
	node_id = numa_cpu_node(cpu_id);
	if (node_id != NUMA_NO_NODE)
		numa_set_node(cpu_id, node_id);

	ret = arch_register_cpu(cpu_id);
	if (ret) {
		pr_err("cpu %d cannot be added, register cpu failed %d\n", cpu_id, ret);
		goto rollback_register_cpu;
	}

	cpu_hotplug_done();
	cpu_maps_update_done();
	unlock_device_hotplug();

	cpu_dev = get_cpu_device(cpu_id);
	if (!cpu_dev) {
		pr_err("Cannot get cpu device with cpu_id: %d\n", cpu_id);
		/*
		 * ret is still 0 from arch_register_cpu(); without an explicit
		 * error the rollback below would be reported as success.
		 */
		ret = -ENODEV;
		goto rollback_cpu_up;
	}
	ret = cpu_device_up(cpu_dev);
	if (ret) {
		pr_err("cpu %d cannot be added, cpu up failed: %d\n", cpu_id, ret);
		goto rollback_cpu_up;
	}
	cpu_dev->offline = false;
	return ret;

	/*
	 * NOTE(review): this rollback runs after the hotplug locks were
	 * dropped, whereas db_del_cpu() performs the same steps under them —
	 * confirm whether the locks should be re-taken here.
	 */
rollback_cpu_up:
	arch_unregister_cpu(cpu_id);
	set_cpu_present(cpu_id, false);
	per_cpu(x86_cpu_to_apicid, cpu_id) = -1;
	num_processors--;
	return ret;

rollback_register_cpu:
	set_cpu_present(cpu_id, false);
	per_cpu(x86_cpu_to_apicid, cpu_id) = -1;
	num_processors--;
rollback_generic_cpu:
	cpu_hotplug_done();
	cpu_maps_update_done();
	unlock_device_hotplug();
	return ret;
}

/*
 * Take the vCPU identified by @apic_id offline and unregister it.
 *
 * The bootstrap processor (logical CPU 0) is never removed.
 *
 * Returns 0 on success; -EINVAL for the bootstrap processor; -ENODEV when
 * @apic_id does not map to a CPU or the CPU has no device; otherwise the
 * error from cpu_device_down().
 */
static int db_del_cpu(int apic_id)
{
	int cpu_id = get_cpu_id(apic_id);
	int ret;
	struct device *cpu_dev;

	/* get_cpu_id() returns -1 when no CPU uses this APIC id. */
	if (cpu_id < 0) {
		pr_err("cannot del vcpu: apic_id %d not found\n", apic_id);
		return -ENODEV;
	}

	if (cpu_id == 0) {
		pr_err("cannot del bootstrap processor.\n");
		return -EINVAL;
	}

	pr_info("deleting vcpu %d\n", cpu_id);

	lock_device_hotplug();

	cpu_dev = get_cpu_device(cpu_id);
	if (!cpu_dev) {
		pr_err("Cannot get cpu device with cpu_id: %d\n", cpu_id);
		/* ret would otherwise be returned uninitialised. */
		ret = -ENODEV;
		goto error;
	}
	ret = cpu_device_down(cpu_dev);
	if (ret) {
		pr_err("del vcpu failed, err: %d\n", ret);
		goto error;
	}

	cpu_maps_update_begin();
	cpu_hotplug_begin();

	/* Undo what db_add_cpu() set up for this CPU. */
	arch_unregister_cpu(cpu_id);
	set_cpu_present(cpu_id, false);
	per_cpu(x86_cpu_to_apicid, cpu_id) = -1;
	num_processors--;

	cpu_hotplug_done();
	cpu_maps_update_done();

error:
	unlock_device_hotplug();
	return ret;
}

/*
 * Handle DEVMGR_ADD_CPU on x86: hot-add the vCPUs listed in the request.
 *
 * All requested APIC ids are first checked to be unused; if any is already
 * in use nothing is added and index 0 is reported. Otherwise the vCPUs are
 * added in order and processing stops at the first failure. The index at
 * which processing stopped is reported through cpu_event_notification().
 *
 * A request with count == 0 adds nothing and is reported as success.
 *
 * Returns 0 on success or a negative errno on failure.
 */
static int add_cpu_dev(struct devmgr_req *req,
			struct devmgr_reply *rep)
{
	/* Must start at 0: with count == 0 no loop below assigns it. */
	int ret = 0;
	uint8_t i;
	int apic_id;

	uint8_t count = req->msg_load.cpu_dev_info.count;
	uint8_t apic_ver = req->msg_load.cpu_dev_info.apic_ver;
	uint8_t *apic_ids = req->msg_load.cpu_dev_info.apic_ids;

	pr_info("add vcpu number: %d\n", count);

	/* Pass 1: refuse the whole request if any APIC id is already taken. */
	for (i = 0; i < count; ++i) {
		apic_id = apic_ids[i];
		if (get_cpu_id(apic_id) != -1) {
			pr_err("cpu cannot be added: apci_id %d is already been used.\n", apic_id);
			ret = -EINVAL;
			cpu_event_notification(0, ret, DEVMGR_ADD_CPU, rep);
			return ret;
		}
	}

	/* Pass 2: add the vCPUs one by one, stopping at the first failure. */
	for (i = 0; i < count; ++i) {
		apic_id = apic_ids[i];
		ret = db_add_cpu(apic_id, apic_ver);
		if (ret != 0)
			break;
	}

	cpu_event_notification(i, ret, DEVMGR_ADD_CPU, rep);
	return ret;
}

/*
 * Handle DEVMGR_DEL_CPU on x86: hot-remove the vCPUs listed in the request.
 *
 * The request is refused (index 0 reported) when it would remove every
 * processor or when any APIC id does not map to a possible CPU. Otherwise
 * the vCPUs are removed in order and processing stops at the first failure.
 * The index at which processing stopped is reported through
 * cpu_event_notification().
 *
 * A request with count == 0 removes nothing and is reported as success.
 *
 * Returns 0 on success or a negative errno on failure.
 */
static int del_cpu_dev(struct devmgr_req *req,
			struct devmgr_reply *rep)
{
	/* Must start at 0: with count == 0 no loop below assigns it. */
	int ret = 0;
	uint8_t i;
	int cpu_id;

	uint8_t count = req->msg_load.cpu_dev_info.count;
	uint8_t *apic_ids = req->msg_load.cpu_dev_info.apic_ids;

	pr_info("del vcpu number : %d\n", count);

	if (count >= num_processors) {
		pr_err("cpu del parameter check error: cannot remove all vcpus\n");
		ret = -EINVAL;
		cpu_event_notification(0, ret, DEVMGR_DEL_CPU, rep);
		return ret;
	}

	/* Pass 1: validate every APIC id before touching any CPU. */
	for (i = 0; i < count; ++i) {
		cpu_id = get_cpu_id(apic_ids[i]);
		/*
		 * get_cpu_id() returns -1 for an unknown APIC id; that value
		 * must not be used to index the possible-CPU mask.
		 */
		if (cpu_id < 0 || !cpu_possible(cpu_id)) {
			pr_err("cpu %d cannot be deleted: cpu not possible\n", cpu_id);
			ret = -EINVAL;
			cpu_event_notification(0, ret, DEVMGR_DEL_CPU, rep);
			return ret;
		}
	}

	/* Pass 2: remove the vCPUs one by one, stopping at the first failure. */
	for (i = 0; i < count; ++i) {
		ret = db_del_cpu(apic_ids[i]);
		if (ret != 0)
			break;
	}

	cpu_event_notification(i, ret, DEVMGR_DEL_CPU, rep);
	return ret;
}
#elif defined(CONFIG_ARM64)
/*
 * Handle DEVMGR_ADD_CPU on arm64: add 'count' vCPUs, using the CPU numbers
 * that directly follow the number of CPUs online when the request arrived.
 * Processing stops at the first failure; the number of vCPUs added is
 * reported through cpu_event_notification().
 *
 * Returns 0 on success or the error from add_cpu().
 */
static int add_cpu_dev(struct devmgr_req *req, struct devmgr_reply *rep)
{
	const uint8_t nr_requested = req->msg_load.cpu_dev_info.count;
	const unsigned int first_cpu = num_online_cpus();
	int added = 0;
	int err = 0;

	pr_info("Dragonball device manager add vcpus!");
	pr_info("Current vcpu number: %d, Add vcpu number: %d\n",
		first_cpu, nr_requested);

	while (added < nr_requested) {
		err = add_cpu(first_cpu + added);
		if (err)
			break;
		added++;
	}

	cpu_event_notification(added, err, DEVMGR_ADD_CPU, rep);
	return err;
}

/*
 * Handle DEVMGR_DEL_CPU on arm64: remove 'count' vCPUs, starting from the
 * highest CPU number implied by the online count and walking downwards.
 * A request that would remove every online CPU is refused. Processing stops
 * at the first failure; the number of vCPUs removed is reported through
 * cpu_event_notification().
 *
 * Returns 0 on success, -EINVAL for a refused request, or the error from
 * remove_cpu().
 */
static int del_cpu_dev(struct devmgr_req *req, struct devmgr_reply *rep)
{
	const uint8_t nr_requested = req->msg_load.cpu_dev_info.count;
	const unsigned int nr_online = num_online_cpus();
	int removed = 0;
	int err = 0;

	pr_info("Dragonball device manager remove vcpus!");
	pr_info("Current vcpu number: %d, Delete vcpu number: %d\n",
		nr_online, nr_requested);

	if (nr_requested >= nr_online) {
		pr_err("cpu del parameter check error: cannot remove all vcpus\n");
		err = -EINVAL;
		cpu_event_notification(0, err, DEVMGR_DEL_CPU, rep);
		return err;
	}

	while (removed < nr_requested) {
		err = remove_cpu(nr_online - removed - 1);
		if (err)
			break;
		removed++;
	}

	cpu_event_notification(removed, err, DEVMGR_DEL_CPU, rep);
	return err;
}
#endif
#endif

/*
 * Dispatch table mapping a request type to its handler. Entries exist only
 * for the features enabled in the kernel configuration.
 */
static struct {
	enum devmgr_msg_type cmd;
	action_route_t fn;
} opt_map[] = {
#if defined(CONFIG_DRAGONBALL_HOTPLUG_VIRTIO_MMIO)
	{ .cmd = DEVMGR_ADD_MMIO, .fn = add_mmio_dev },
	{ .cmd = DEVMGR_DEL_MMIO, .fn = del_mmio_dev },
#endif
#ifdef CONFIG_DRAGONBALL_HOTPLUG_CPU
	{ .cmd = DEVMGR_ADD_CPU, .fn = add_cpu_dev },
	{ .cmd = DEVMGR_DEL_CPU, .fn = del_cpu_dev },
#endif
};

/*
 * Find the handler registered for the request's msg_type.
 *
 * Returns the handler, or NULL when the type has no entry in opt_map.
 */
static action_route_t get_action(struct devmgr_req *req)
{
	const int nr_entries = ARRAY_SIZE(opt_map);
	int idx;

	for (idx = 0; idx < nr_entries; idx++) {
		if (opt_map[idx].cmd == req->msg_header.msg_type)
			return opt_map[idx].fn;
	}

	return NULL;
}

static void db_devmgr_process(struct devmgr_req *req,
			      struct devmgr_reply *rep)
{
	int err;
	action_route_t action;
	struct devmgr_msg_header *req_mh = &req->msg_header;
	struct devmgr_msg_header *rep_mh = &rep->msg_header;

	if (req_mh->magic_version != DEVMGR_MAGIC_VERSION) {
		_fill_msg_header(rep_mh, 0, DEVMGR_CMD_ERR, 0);
		return;
	}

	action = get_action(req);
	if (action == NULL) {
		pr_err("%s: Not found valid command\n", __func__);
		_fill_msg_header(rep_mh, 0, DEVMGR_CMD_ERR, 0);
		return;
	}

	err = action(req, rep);
	if (err) {
		pr_err("%s: Command run failed, err: %d\n", __func__, err);
		_fill_msg_header(rep_mh, 0, DEVMGR_CMD_ERR, 0);
		return;
	}
}

/*
 * Thread function serving one connection.
 *
 * It first sends a DEVMGR_CMD_OK message, then loops: receive a request,
 * process it, send the reply. The loop ends when the thread is asked to
 * stop or a receive/send returns a value <= 0. The socket and the
 * per-connection state passed in @data are always released before returning.
 *
 * Returns the result of the initial send if it failed, otherwise 0.
 */
static int db_devmgr_server(void *data)
{
	struct task_res *ctx = data;
	char *rx_buf = (char *)&ctx->req;
	char *tx_buf = (char *)&ctx->reply;
	int rc = 0;
	int n;

	_fill_msg_header(&ctx->reply.msg_header, 0, DEVMGR_CMD_OK, 0);
	n = db_vsock_sendmsg(ctx->sock, tx_buf, DEVMGR_MSG_SIZE);
	if (n <= 0) {
		pr_err("%s: Server send message failed, err: %d\n", __func__, n);
		rc = n;
		goto out;
	}

	while (!kthread_should_stop()) {
		n = db_vsock_recvmsg(ctx->sock, rx_buf, DEVMGR_MSG_SIZE, 0);
		if (n <= 0)
			break;

		/* The result (OK or error) is written into ctx->reply. */
		db_devmgr_process(&ctx->req, &ctx->reply);

		n = db_vsock_sendmsg(ctx->sock, tx_buf, DEVMGR_MSG_SIZE);
		if (n <= 0)
			break;
	}

out:
	/* TODO: check who shutdown the socket, receiving or sending. */
	sock_release(ctx->sock);
	kfree(ctx);
	return rc;
}

/*
 * Connection handler registered with the vsock service: allocate the
 * per-connection state and start a kernel thread to serve @sock.
 *
 * Returns 0 when the thread was started, -ENOMEM when the state could not be
 * allocated, or the error from kthread_create().
 */
static int db_devmgr_handler(struct socket *sock)
{
	struct task_res *ctx;
	struct task_struct *worker;
	int err;

	/* TODO: ensure singleton, only one server exists */
	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	ctx->sock = sock;
	worker = kthread_create(db_devmgr_server, ctx, "db_dev_mgr");
	if (IS_ERR(worker)) {
		err = PTR_ERR(worker);
		pr_err("%s: Client process thread create failed, err: %d\n",
		       __func__, err);
		kfree(ctx);
		return err;
	}

	/* Record the task before waking it: the thread frees ctx when done. */
	ctx->task = worker;
	wake_up_process(worker);
	return 0;
}

/*
 * Register the device manager with the vsock service under DEVMGR_CMD_BYTE,
 * but only when the DB_FEAT_DEVMGR feature is reported; otherwise do nothing
 * and report success.
 */
static int __init db_device_manager_init(void)
{
	if (!has_dragonball_feature(DB_FEAT_DEVMGR))
		return 0;

	return register_db_vsock_service(DEVMGR_CMD_BYTE, db_devmgr_handler);
}

late_initcall(db_device_manager_init);

MODULE_AUTHOR("Alibaba, Inc.");
MODULE_DESCRIPTION("Dragonball Device Manager");
MODULE_LICENSE("GPL v2");
