/**
 * Copyright (c) 2024 OceanBase
 * OceanBase CE is licensed under Mulan PubL v2.
 * You can use this software according to the terms and conditions of the Mulan PubL v2.
 * You may obtain a copy of Mulan PubL v2 at:
 *          http://license.coscl.org.cn/MulanPubL-2.0
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PubL v2 for more details.
 */

#ifndef OCEANBASE_LIB_OB_VECTOR_DIV_H_
#define OCEANBASE_LIB_OB_VECTOR_DIV_H_

#include "lib/utility/ob_print_utils.h"
#include "lib/oblog/ob_log.h"
#include "lib/ob_define.h"
#include "common/object/ob_obj_compare.h"
#include "ob_vector_op_common.h"

namespace oceanbase
{
namespace common
{
// Declares the entry point for dividing a float vector by a scalar.
struct ObVectorDiv
{
  // a = a / f
  // NOTE(review): the definition is not in this header; presumably each of the
  // len elements of a is divided by f in place and an OB_* code is returned.
  // Confirm against the implementation file.
  static int calc(float *a, float f, const int64_t len);
};

#define SIMD_DIV(type, set, load, div, store) \
  type divisor_simd = set(f);                 \
  for (int i = 0; i < dim; i += batch) {      \
    type values = load(&a[i]);                \
    type tmp_res = div(values, divisor_simd); \
    store(&a[i], tmp_res);                    \
  }

// Scalar in-place division: a[i] = a[i] / f for every i in [0, len).
// A divisor below 1e-10 (zero and every negative value included) is rejected
// with OB_DIVISION_BY_ZERO and a is left untouched; otherwise OB_SUCCESS.
OB_INLINE static int vector_div_normal(float *a, float f, const int64_t len)
{
  int ret = OB_SUCCESS;
  if (!(f < 1e-10)) {
    int64_t idx = 0;
    while (idx < len) {
      a[idx] = a[idx] / f;
      ++idx;
    }
  } else {
    ret = OB_DIVISION_BY_ZERO;
    LIB_LOG(WARN, "f should > 0", K(ret), K(f));
  }
  return ret;
}

// for others: no vector instructions are used, the call is forwarded to the
// scalar loop in vector_div_normal and its return code is passed back as is.
OB_DECLARE_DEFAULT_CODE(inline static int vector_div(float *a, float f, const int64_t len) {
  const int ret = vector_div_normal(a, f, len);
  return ret;
})

OB_DECLARE_SSE_AND_AVX_CODE(
    // Divides the leading multiple-of-4 part of a by f in place, four floats per
    // 128-bit operation (unaligned load and store). Elements past that part are
    // not touched. f is not validated here. Always returns OB_SUCCESS.
    inline static int vector_div_simd4_avx128(float *a, float f, const int64_t len) {
      int ret = OB_SUCCESS;
      const int64_t batch = 128 / (sizeof(float) * 8);  // 4 floats per __m128
      const int64_t dim = len >> 2 << 2;                // len rounded down to a multiple of 4
      SIMD_DIV(__m128, _mm_set1_ps, _mm_loadu_ps, _mm_div_ps, _mm_storeu_ps);
      return ret;
    }

    // Divides all len elements of a by f in place: the multiple-of-4 prefix goes
    // through the 128-bit kernel and the remaining 0..3 elements through the
    // scalar loop. Returns OB_SUCCESS, or the error reported by the scalar loop
    // for the tail (OB_DIVISION_BY_ZERO when f < 1e-10).
    inline static int vector_div_simd4_avx128_extra(float *a, float f, const int64_t len) {
      int ret = OB_SUCCESS;
      const int64_t dim = len >> 2 << 2;
      if (OB_FAIL(vector_div_simd4_avx128(a, f, dim))) {
        LIB_LOG(WARN, "failed to cal vector div", K(ret), K(len), K(dim));
      // Only the data pointer advances by dim. The divisor is a scalar and must
      // be the same f for the tail as for the prefix.
      } else if (0 < len - dim && OB_FAIL(vector_div_normal(a + dim, f, len - dim))) {
        LIB_LOG(WARN, "failed to cal vector div", K(ret), K(len), K(dim));
      }
      return ret;
    })

// for sse,sse2,sse3,ssse3,sse4,popcnt
// Rejects f < 1e-10 with OB_DIVISION_BY_ZERO. Otherwise more than 4 elements go
// through the 128-bit path and shorter inputs through the scalar loop.
OB_DECLARE_SSE42_SPECIFIC_CODE(inline static int vector_div(float *a, float f, const int64_t len) {
  int ret = OB_SUCCESS;
  if (!(f < 1e-10)) {
    if (len > 4) {
      if (OB_FAIL(vector_div_simd4_avx128_extra(a, f, len))) {
        LIB_LOG(WARN, "failed to cal div", K(ret));
      }
    } else if (OB_FAIL(vector_div_normal(a, f, len))) {
      LIB_LOG(WARN, "failed to cal div", K(ret));
    }
  } else {
    ret = OB_DIVISION_BY_ZERO;
    LIB_LOG(WARN, "f should > 0", K(ret), K(f));
  }
  return ret;
})

OB_DECLARE_AVX_ALL_CODE(
    // Divides the leading multiple-of-8 part of a by f in place, eight floats per
    // 256-bit operation (unaligned load and store). Elements past that part are
    // not touched. f is not validated here. Always returns OB_SUCCESS.
    inline static int vector_div_simd8_avx256(float *a, float f, const int64_t len) {
      int ret = OB_SUCCESS;
      const int64_t batch = 256 / (sizeof(float) * 8);  // 8 floats per __m256
      const int64_t dim = len >> 3 << 3;                // len rounded down to a multiple of 8
      SIMD_DIV(__m256, _mm256_set1_ps, _mm256_loadu_ps, _mm256_div_ps, _mm256_storeu_ps);
      return ret;
    }

    // Divides all len elements of a by f in place: the multiple-of-8 prefix goes
    // through the 256-bit kernel and the remaining 0..7 elements are handed to
    // vector_div_simd4_avx128_extra. Returns OB_SUCCESS or the error reported by
    // that tail call.
    inline static int vector_div_simd8_avx256_extra(float *a, float f, const int64_t len) {
      int ret = OB_SUCCESS;
      const int64_t dim = len >> 3 << 3;
      if (OB_FAIL(vector_div_simd8_avx256(a, f, dim))) {
        LIB_LOG(WARN, "failed to cal vector div", K(ret), K(len), K(dim));
      // Only the data pointer advances by dim. The divisor is a scalar and must
      // be the same f for the tail as for the prefix.
      } else if (0 < len - dim
                 && OB_FAIL(vector_div_simd4_avx128_extra(a + dim, f, len - dim))) {
        LIB_LOG(WARN, "failed to cal vector div", K(ret), K(len), K(dim));
      }
      return ret;
    })

// for avx2
// Rejects f < 1e-10 with OB_DIVISION_BY_ZERO. Otherwise the widest path whose
// lane count is strictly below len is used: 256-bit for len > 8, 128-bit for
// len > 4, and the scalar loop for anything shorter.
OB_DECLARE_AVX_AND_AVX2_CODE(inline static int vector_div(float *a, float f, const int64_t len) {
  int ret = OB_SUCCESS;
  if (!(f < 1e-10)) {
    if (len > 8) {
      if (OB_FAIL(vector_div_simd8_avx256_extra(a, f, len))) {
        LIB_LOG(WARN, "failed to cal div", K(ret));
      }
    } else if (len > 4) {
      if (OB_FAIL(vector_div_simd4_avx128_extra(a, f, len))) {
        LIB_LOG(WARN, "failed to cal div", K(ret));
      }
    } else if (OB_FAIL(vector_div_normal(a, f, len))) {
      LIB_LOG(WARN, "failed to cal div", K(ret));
    }
  } else {
    ret = OB_DIVISION_BY_ZERO;
    LIB_LOG(WARN, "f should > 0", K(ret), K(f));
  }
  return ret;
})

// for avx512f,avx512bw,avx512vl
OB_DECLARE_AVX512_SPECIFIC_CODE(
    // Divides the leading multiple-of-16 part of a by f in place, sixteen floats
    // per 512-bit operation (unaligned load and store). Elements past that part
    // are not touched. f is not validated here. Always returns OB_SUCCESS.
    OB_INLINE static int vector_div_simd16_avx512(float *a, float f, const int64_t len) {
      int ret = OB_SUCCESS;
      const int64_t batch = 512 / (sizeof(float) * 8);  // 16 floats per __m512
      const int64_t dim = len >> 4 << 4;                // len rounded down to a multiple of 16
      SIMD_DIV(__m512, _mm512_set1_ps, _mm512_loadu_ps, _mm512_div_ps, _mm512_storeu_ps);
      return ret;
    }

    // Divides all len elements of a by f in place: the multiple-of-16 prefix goes
    // through the 512-bit kernel and the remaining 0..15 elements are handed to
    // vector_div_simd8_avx256_extra. Returns OB_SUCCESS or the error reported by
    // that tail call.
    OB_INLINE static int vector_div_simd16_avx512_extra(float *a, float f, const int64_t len) {
      int ret = OB_SUCCESS;
      const int64_t dim = len >> 4 << 4;
      if (OB_FAIL(vector_div_simd16_avx512(a, f, dim))) {
        LIB_LOG(WARN, "failed to cal vector div", K(ret), K(len), K(dim));
      // Only the data pointer advances by dim. The divisor is a scalar and must
      // be the same f for the tail as for the prefix.
      } else if (0 < len - dim
                 && OB_FAIL(vector_div_simd8_avx256_extra(a + dim, f, len - dim))) {
        LIB_LOG(WARN, "failed to cal vector div", K(ret), K(len), K(dim));
      }
      return ret;
    }

    // Rejects f < 1e-10 with OB_DIVISION_BY_ZERO. Otherwise the widest path whose
    // lane count is strictly below len is used: 512-bit for len > 16, 256-bit for
    // len > 8, 128-bit for len > 4, and the scalar loop for anything shorter.
    inline static int vector_div(float *a, float f, const int64_t len) {
      int ret = OB_SUCCESS;
      if (f < 1e-10) {
        ret = OB_DIVISION_BY_ZERO;
        LIB_LOG(WARN, "f should > 0", K(ret), K(f));
      } else if (16 < len) {
        if (OB_FAIL(vector_div_simd16_avx512_extra(a, f, len))) {
          LIB_LOG(WARN, "failed to cal vector div", K(ret));
        }
      } else if (8 < len) {
        if (OB_FAIL(vector_div_simd8_avx256_extra(a, f, len))) {
          LIB_LOG(WARN, "failed to cal vector div", K(ret));
        }
      } else if (4 < len) {
        if (OB_FAIL(vector_div_simd4_avx128_extra(a, f, len))) {
          LIB_LOG(WARN, "failed to cal vector div", K(ret));
        }
      } else if (OB_FAIL(vector_div_normal(a, f, len))) {
        LIB_LOG(WARN, "failed to cal vector div", K(ret));
      }
      return ret;
    })

}  // namespace common
}  // namespace oceanbase
#endif