// SPDX-License-Identifier: GPL-2.0
#include <stdint.h>
#include <stddef.h>
#include <arm_sve.h>

#include "utils.h"

/*
 * Copy exactly 16 bytes from src to dst.
 *
 * The bytes are staged through a 128-bit temporary, which lets the compiler
 * emit a single 16-byte load/store pair.  __builtin_memcpy is used instead of
 * dereferencing a cast __uint128_t pointer for two reasons:
 *  - neither pointer is guaranteed to be 16-byte aligned, because callers
 *    pass arbitrary byte offsets;
 *  - the cast would also access the bytes through an incompatible type,
 *    which violates strict aliasing.
 * Both would be undefined behavior.
 */
static inline void mov_1vector(uint8_t *dst, const uint8_t *src)
{
	__uint128_t chunk;

	__builtin_memcpy(&chunk, src, sizeof(chunk));
	__builtin_memcpy(dst, &chunk, sizeof(chunk));
}

/*
 * Copy exactly 32 bytes from src to dst as two 16-byte chunks.
 *
 * Both chunks are loaded before either is stored, matching the original
 * access order.  __builtin_memcpy replaces dereferences of cast __uint128_t
 * pointers.  Those dereferences were undefined behavior because:
 *  - callers pass addresses such as dst + n - 32 that need not be 16-byte
 *    aligned;
 *  - the cast also broke strict aliasing.
 */
static inline void mov_2vector(uint8_t *dst, const uint8_t *src)
{
	__uint128_t x0, x1;

	__builtin_memcpy(&x0, src, sizeof(x0));
	__builtin_memcpy(&x1, src + 16, sizeof(x1));
	__builtin_memcpy(dst, &x0, sizeof(x0));
	__builtin_memcpy(dst + 16, &x1, sizeof(x1));
}

/*
 * Move eight consecutive SVE vectors from src to dst under predicate m.
 * With an all-true m, that is 8 * svcntb() bytes.
 *
 * Vectors are handled in ascending pairs: each pair is loaded (low, then
 * high) before either element of the pair is stored.
 */
static inline void mov_8vector(uint8_t *dst, const uint8_t *src, svbool_t m)
{
	for (int64_t vnum = 0; vnum < 8; vnum += 2) {
		svuint8_t lo = svld1_vnum_u8(m, src, vnum);
		svuint8_t hi = svld1_vnum_u8(m, src, vnum + 1);

		svst1_vnum_u8(m, dst, vnum, lo);
		svst1_vnum_u8(m, dst, vnum + 1, hi);
	}
}

/*
 * Build a byte-lane predicate covering the half-open index range
 * [start, end).  Lane k is active while start + k < end, using signed
 * 64-bit comparison.
 */
static inline svbool_t u8xn_elt_mask(int64_t start, int64_t end)
{
	const svbool_t active = svwhilelt_b8_s64(start, end);

	return active;
}

/*
 * Copy n bytes from src to dst using SVE.
 *
 * Bulk phase: while at least eight vectors' worth of bytes remain, move
 * them with mov_8vector() under an all-lanes predicate.
 *
 * Tail phase: move the rest one vector at a time.  A range predicate
 * disables lanes past n, so the final vector may be partial.  Any n is
 * handled correctly, including n smaller than eight vectors.
 */
static inline void memcpy_ge8vector(uint8_t *dst, const uint8_t *src, size_t n)
{
	const size_t vl = svcntb();
	const size_t stride = vl * 8;
	const svbool_t all = u8xn_elt_mask(0, (int64_t)vl);

	for (; n >= stride; n -= stride) {
		mov_8vector(dst, src, all);
		dst += stride;
		src += stride;
	}

	for (size_t off = 0; off < n; off += vl) {
		svbool_t pg = u8xn_elt_mask((int64_t)off, (int64_t)n);

		svst1_u8(pg, dst + off, svld1_u8(pg, src + off));
	}
}

/*
 * Copy w bytes from the start of s, and the w bytes ending at s + n, to the
 * same positions in d.  When w <= n <= 2 * w the two chunks overlap or abut,
 * so together they cover all n bytes.  w must be at most 8.
 *
 * Both loads happen before either store.  __builtin_memcpy avoids the
 * misaligned, type-punned pointer dereferences a cast would need.
 */
static inline void copy_head_tail(uint8_t *d, const uint8_t *s, size_t n,
				  size_t w)
{
	uint8_t head[8], tail[8];

	__builtin_memcpy(head, s, w);
	__builtin_memcpy(tail, s + n - w, w);
	__builtin_memcpy(d, head, w);
	__builtin_memcpy(d + n - w, tail, w);
}

/*
 * memcpy_sve - copy n bytes from src to dst.
 * @dst: destination buffer
 * @src: source buffer
 * @n:   number of bytes to copy
 *
 * As with memcpy(), the two regions are assumed not to overlap.
 *
 * Strategy by size:
 *   n < 16          two overlapping scalar copies of 8, 4 or 2 bytes,
 *                   or a single byte
 *   16 <= n <= 32   two overlapping 16-byte moves
 *   33 <= n <= 64   two overlapping 32-byte moves
 *   65 <= n < 128   predicated SVE loop, one vector per iteration
 *   n >= 128        align the source to 16 bytes, then memcpy_ge8vector()
 *
 * Return: @dst, the caller's original destination pointer.  An earlier
 * version returned the advanced cursor on the n >= 128 path.
 */
void *memcpy_sve(void *dst, const void *src, size_t n)
{
	/* Byte cursors: avoid pointer arithmetic on void * (a GNU extension). */
	uint8_t *d = dst;
	const uint8_t *s = src;

	if (n < 16) {
		if (n >= 8) {
			copy_head_tail(d, s, n, 8);
		} else if (n >= 4) {
			copy_head_tail(d, s, n, 4);
		} else if (n >= 2) {
			copy_head_tail(d, s, n, 2);
		} else if (n == 1) {
			*d = *s;
		}
		return dst;
	}

	/*
	 * Tail addresses are written d + n - k.  The sub-expression d + n is at
	 * most one past the end, which is valid.  The old form d - k + n first
	 * formed a pointer before the object, which is undefined behavior.
	 */
	if (n <= 32) {
		mov_1vector(d, s);
		mov_1vector(d + n - 16, s + n - 16);
		return dst;
	}

	if (n <= 64) {
		mov_2vector(d, s);
		mov_2vector(d + n - 32, s + n - 32);
		return dst;
	}

	if (n < 128) {
		const int64_t vl = (int64_t)svcntb();

		for (int64_t i = 0; i < (int64_t)n; i += vl) {
			svbool_t pg = u8xn_elt_mask(i, (int64_t)n);

			svst1_u8(pg, d + i, svld1_u8(pg, s + i));
		}
		return dst;
	}

	/*
	 * Bring the source up to a 16-byte boundary:
	 *  - copy one unaligned 16-byte chunk;
	 *  - skip forward by the distance to the boundary (1..15 bytes).
	 * The bulk copy rewrites some of those bytes with the same values, so
	 * the overlap is harmless.
	 *
	 * NOTE(review): an earlier version named this offset "dstofss", which
	 * suggests aligning the destination was intended.  The code aligns the
	 * source; confirm which side should be aligned.
	 */
	size_t ofs = (uintptr_t)s & 0x0F;

	if (ofs != 0) {
		ofs = 16 - ofs;
		mov_1vector(d, s);
		d += ofs;
		s += ofs;
		n -= ofs;
	}

	memcpy_ge8vector(d, s, n);

	/* Return the caller's pointer, not the advanced cursor. */
	return dst;
}
