// SPDX-License-Identifier: GPL-2.0
/*
 * Dragonball vsock server
 *
 * Copyright (C) 2019 Alibaba, Inc
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version
 * 2 of the License, or (at your option) any later version.
 *
 * Virtual Sockets (vsock) provide a reliable and convenient communication
 * channel between the guest and the host/VMM. This misc driver implements
 * an in-kernel vsock server that dispatches Dragonball requests to
 * registered service handlers.
 *
 * When the DB_FEAT_VSOCKSRV feature is present, module initialization
 * creates a kernel thread, which sets up an AF_VSOCK server socket
 * listening on the default port 0xDB (overridable with the "port"
 * kernel parameter).
 */

#define pr_fmt(fmt) "db-vsock-srv: " fmt

#include <linux/kernel.h>
#include <linux/kthread.h>
#include <linux/list.h>
#include <linux/module.h>
#include <linux/net.h>
#include <linux/slab.h>
#include <linux/vm_sockets.h>
#include <net/net_namespace.h>
#include <net/sock.h>
#include <dragonball/vsock_srv.h>
#include <dragonball/dragonball.h>

/*
 * Number of request_module() attempts for the vsock transport, and the
 * delay in milliseconds (passed to msleep()) after each failed attempt.
 */
#define VSOCK_RETRY_TIMES	20
#define VSOCK_RETRY_INTER	500

/*
 * One accepted client connection. Allocated by db_create_cli_conn(),
 * queued as a work item, and freed by db_conn_service().
 */
struct db_conn_info {
	struct work_struct work;
	struct socket *sock;
};

/* Maps a command byte sent by a client to the handler that serves it. */
struct db_service_entry {
	char			cmd;
	db_vsock_svc_handler_t	handler;
	struct list_head	list;
};

/*
 * Protects db_service_list (the registered commands). DEFINE_RWLOCK
 * declares a spinning rwlock_t, so nothing that may sleep should run
 * while it is held.
 */
static DEFINE_RWLOCK(db_service_lock);
static LIST_HEAD(db_service_list);

/* The server kthread; set by db_create_service() once it is created. */
static struct task_struct *db_service_task;
/* Listening port; overridable through the "port" kernel parameter. */
static unsigned int db_server_port = DB_SERVER_PORT;

/*
 * Create an in-kernel AF_VSOCK stream socket in init_net, bind it to @port
 * on any local CID (VMADDR_CID_ANY) and put it in the listening state with
 * a backlog of 10.
 *
 * Returns the listening socket, or an ERR_PTR() carrying the error from
 * socket creation, bind or listen. The socket is released on failure.
 */
struct socket *db_create_vsock_listener(unsigned int port)
{
	union {
		struct sockaddr sa;
		struct sockaddr_vm svm;
	} addr = {
		.svm = {
			.svm_family = AF_VSOCK,
			.svm_port = port,
			.svm_cid = VMADDR_CID_ANY,
		}
	};
	const int backlog = 10;
	struct socket *listener;
	int err;

	err = sock_create_kern(&init_net, AF_VSOCK, SOCK_STREAM, 0, &listener);
	if (err) {
		pr_err("Server vsock create failed, err: %d\n", err);
		return ERR_PTR(err);
	}

	err = listener->ops->bind(listener, &addr.sa, sizeof(addr.svm));
	if (err) {
		pr_err("Server vsock bind failed, err: %d\n", err);
		goto release;
	}

	err = listener->ops->listen(listener, backlog);
	if (err < 0) {
		pr_err("Server vsock listen error: %d\n", err);
		goto release;
	}

	return listener;

release:
	sock_release(listener);
	return ERR_PTR(err);
}
EXPORT_SYMBOL_GPL(db_create_vsock_listener);

/*
 * Send @len bytes from @buf on @sock using kernel_sendmsg() with an empty
 * message header (no destination address, no control data, no flags).
 *
 * Returns the kernel_sendmsg() result: the number of bytes sent, or a
 * negative error code.
 */
int db_vsock_sendmsg(struct socket *sock, char *buf, size_t len)
{
	struct msghdr msgh = {};
	struct kvec vec = {
		.iov_base = buf,
		.iov_len = len,
	};

	return kernel_sendmsg(sock, &msgh, &vec, 1, len);
}
EXPORT_SYMBOL_GPL(db_vsock_sendmsg);

/*
 * Receive up to @len bytes from @sock into @buf.
 *
 * @flags is forwarded unchanged to kernel_recvmsg(). It is exposed because
 * some users (vebpf) need MSG_WAITALL to wait until the full @len bytes
 * have arrived.
 *
 * Returns the kernel_recvmsg() result: the number of bytes received, or a
 * negative error code.
 */
int db_vsock_recvmsg(struct socket *sock, char *buf, size_t len, int flags)
{
	struct msghdr msgh = {};
	struct kvec vec = {
		.iov_base = buf,
		.iov_len = len,
	};

	return kernel_recvmsg(sock, &msgh, &vec, 1, len, flags);
}
EXPORT_SYMBOL_GPL(db_vsock_recvmsg);

static int db_vsock_recvcmd(struct socket *cli_socket, char *cmd)
{
	int ret;
	char rcv;
	long timeout;
	struct kvec vec;
	struct msghdr msg;

	memset(&vec, 0, sizeof(vec));
	memset(&msg, 0, sizeof(msg));
	vec.iov_base = &rcv;
	vec.iov_len = 1;

	timeout = cli_socket->sk->sk_rcvtimeo;
	cli_socket->sk->sk_rcvtimeo = DB_INIT_TIMEOUT * HZ;
	ret = kernel_recvmsg(cli_socket, &msg, &vec, 1, 1, 0);
	cli_socket->sk->sk_rcvtimeo = timeout;
	*cmd = rcv;

	return ret;
}

/*
 * The workqueue handler for vsock work_struct.
 *
 * Each worker-pool bound to an actual CPU implements concurrency management
 * by hooking into the scheduler. The worker-pool is notified whenever an
 * active worker wakes up or sleeps and keeps track of the number of the
 * currently runnable workers. Generally, work items are not expected to hog
 * a CPU and consume many cycles. That means maintaining just enough concurrency
 * to prevent work processing from stalling should be optimal.
 *
 * So it's OK to sleep in a workqueue handler, it won't cause too many worker
 * threads.
 */
static void db_conn_service(struct work_struct *work)
{
	struct db_conn_info *conn_info =
		container_of(work, struct db_conn_info, work);
	struct db_service_entry *service_entry;
	int len, ret = -1;
	char cmd;

	len = db_vsock_recvcmd(conn_info->sock, &cmd);
	if (len <= 0)
		goto recv_failed;

	read_lock(&db_service_lock);
	list_for_each_entry(service_entry, &db_service_list, list) {
		if (cmd == service_entry->cmd) {
			ret = service_entry->handler(conn_info->sock);
			break;
		}
	}
	read_unlock(&db_service_lock);

recv_failed:
	if (ret) {
		sock_release(conn_info->sock);
		pr_info("Client connection closed, error code: %d\n", ret);
	}
	kfree(conn_info);
}

/*
 * Hand an accepted connection @sock off to the system workqueue, where
 * db_conn_service() dispatches it.
 *
 * On success the queued work item owns @sock. On failure nothing is
 * queued and the caller still owns @sock.
 *
 * Returns 0 on success or -ENOMEM if the work item cannot be allocated.
 */
static int db_create_cli_conn(struct socket *sock)
{
	struct db_conn_info *info = kmalloc(sizeof(*info), GFP_KERNEL);

	if (!info)
		return -ENOMEM;

	INIT_WORK(&info->work, db_conn_service);
	info->sock = sock;
	schedule_work(&info->work);

	return 0;
}

static int db_vsock_server(void *data)
{
	struct socket *sock;
	int err, retry;

	/*
	 * In kangaroo scenario, both module and built-in vsock should be
	 * supported, but request_module will return error when vsock is built-in.
	 * So we use IS_BUILTIN macro to check this.
	 */
	if (!IS_BUILTIN(CONFIG_VIRTIO_VSOCKETS)) {
		/*
		 * Previously dragonball vsock server would retry creating vsock listener
		 * if vsock module has not been inserted. We change this logic by using
		 * request_module to waiting for vsock module. And we also have to change
		 * the retry logic to 20 times with 500 ms each interval in order to work
		 * on AMD server.
		 */
		for (retry = VSOCK_RETRY_TIMES; retry > 0; retry--) {
			err = request_module("vmw_vsock_virtio_transport");
			if (err == 0)
				break;
			pr_err("Not ready to load vsock module, due to error %d.\n", err);
			msleep(VSOCK_RETRY_INTER);
		}

		if (retry == 0) {
			pr_err("Vsock module not loaded, exceeds max retry times.\n");
			return err;
		}
	}
	sock = db_create_vsock_listener(db_server_port);
	if (IS_ERR(sock)) {
		err = PTR_ERR(sock);
		pr_err("Init server err: %d\n", err);
		return err;
	}

	while (!kthread_should_stop()) {
		struct socket *conn;

		conn = sock_alloc();
		if (!conn)
			return -ENOMEM;

		conn->type = sock->type;
		conn->ops  = sock->ops;

		/* 0:propotal 1:kernel */
		err = sock->ops->accept(sock, conn, 0, 1);
		if (err < 0) {
			pr_err("Server accept err: %d\n", err);
			sock_release(conn);
			continue;
		}

		err = db_create_cli_conn(conn);
		if (err)
			pr_err("Create client connetion err: %d\n", err);
	}

	return 0;
}

/*
 * Create the "db-vsock-srv" kernel thread running db_vsock_server(),
 * record it in db_service_task and start it.
 *
 * Returns 0 on success or the kthread_create() error.
 */
static int db_create_service(void)
{
	struct task_struct *task;

	task = kthread_create(db_vsock_server, NULL, "db-vsock-srv");
	if (IS_ERR(task)) {
		int err = PTR_ERR(task);

		pr_err("Server task create failed, err: %d\n", err);
		return err;
	}

	db_service_task = task;
	wake_up_process(task);
	return 0;
}

/*
 * Kernel parameter setter that configures the vsock server listening port.
 *
 * Add the following to the kernel boot parameters to choose the port:
 *    "dragonball_vsock_srv.port=@port_number"
 *
 * The value must be '@' followed by a decimal port in the range (0, 1024].
 * Anything else is rejected with -EINVAL and the current port is kept.
 * That includes a missing '@', trailing characters, and numbers that are
 * out of range or do not fit in an unsigned int.
 */
static int db_vsock_srv_cmdline_set(const char *device,
				    const struct kernel_param *kp)
{
	unsigned int port;

	/*
	 * kstrtouint() reports overflow instead of truncating. sscanf("%u")
	 * would silently accept e.g. "@4294967297" as port 1.
	 */
	if (device[0] != '@' || kstrtouint(device + 1, 10, &port) ||
	    port == 0 || port > 1024) {
		pr_err("Using @<port> format and port range (0, 1024].\n");
		return -EINVAL;
	}

	db_server_port = port;
	return 0;
}

/* Parameter ops for "port": only a .set callback is provided (no .get). */
static const struct kernel_param_ops db_vsock_srv_cmdline_param_ops = {
	.set = db_vsock_srv_cmdline_set,
};

/*
 * Register the "port" parameter with sysfs permissions 0400. No backing
 * variable is passed (NULL) because the setter stores the parsed value
 * into db_server_port itself.
 */
device_param_cb(port, &db_vsock_srv_cmdline_param_ops, NULL, 0400);

/*
 * Register @handler for the command byte @cmd.
 *
 * After the VMM client connects, it sends one character that selects how
 * the server treats the connection. The server passes the client socket
 * to the handler registered for that character. A handler that returns 0
 * keeps the socket; a non-zero return makes the server release it.
 *
 * Returns:
 * - 0 on success,
 * - -EINVAL if @handler is NULL,
 * - -EEXIST if a handler is already registered for @cmd,
 * - -ENOMEM if the entry cannot be allocated.
 */
int register_db_vsock_service(const char cmd, db_vsock_svc_handler_t handler)
{
	struct db_service_entry *new_entry, *service_entry;
	int rc = 0;

	if (!handler)
		return -EINVAL;

	/*
	 * Allocate before taking db_service_lock. It is a spinning rwlock_t,
	 * and a GFP_KERNEL allocation may sleep.
	 */
	new_entry = kzalloc(sizeof(*new_entry), GFP_KERNEL);
	if (!new_entry)
		return -ENOMEM;
	new_entry->cmd = cmd;
	new_entry->handler = handler;

	write_lock(&db_service_lock);
	list_for_each_entry(service_entry, &db_service_list, list) {
		if (cmd == service_entry->cmd) {
			rc = -EEXIST;
			break;
		}
	}
	if (!rc)
		list_add_tail(&new_entry->list, &db_service_list);
	write_unlock(&db_service_lock);

	/* Duplicate command: the pre-allocated entry was never linked. */
	if (rc)
		kfree(new_entry);

	return rc;
}
EXPORT_SYMBOL_GPL(register_db_vsock_service);

int unregister_db_vsock_service(const char cmd)
{
	int rc = -EEXIST;
	struct db_service_entry *service_entry, *n;

	write_lock(&db_service_lock);
	list_for_each_entry_safe(service_entry, n, &db_service_list, list) {
		if (cmd == service_entry->cmd) {
			list_del(&service_entry->list);
			rc = 0;
			break;
		}
	}
	write_unlock(&db_service_lock);

	return rc;
}
EXPORT_SYMBOL_GPL(unregister_db_vsock_service);

/*
 * Initcall: start the vsock server thread only when the Dragonball
 * DB_FEAT_VSOCKSRV feature is present. Otherwise do nothing and report
 * success.
 */
static int __init db_vsock_srv_init(void)
{
	if (!has_dragonball_feature(DB_FEAT_VSOCKSRV))
		return 0;

	return db_create_service();
}

late_initcall(db_vsock_srv_init);

MODULE_AUTHOR("Alibaba, Inc.");
MODULE_DESCRIPTION("Dragonball vsock server");
MODULE_LICENSE("GPL v2");
