/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include "gpu/ocl/ref_reorder.hpp"

#include "common/utils.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// Brings the scratchpad key names into scope; execute() refers to
// key_reorder_scales unqualified.
using namespace dnnl::impl::memory_tracking::names;

// Evaluates to true unless getenv_int(x, 1) returns 0 (presumably: the
// environment variable named `x`, read as an integer with a default of 1 —
// confirm against getenv_int()).
// NOTE(review): not referenced anywhere in the visible part of this file —
// confirm whether it is still needed.
#define allowed(x) ((getenv_int(x, 1) != 0))

// Populates `conf` from the source/destination memory descriptors and the
// primitive attributes, then sets up the kernel dispatch over the
// destination's padded dimensions.
//
// engine: downcast to compute::compute_engine_t and used to create the
//         dispatch.
// Returns status::success. For an empty problem (conf.nelems == 0) the
// function returns before any dispatch is created.
status_t ref_reorder_t::pd_t::init_conf(engine_t *engine) {
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper dst_mdw(dst_md());

    conf.src_md_info = memory_desc_info_t::create(src_mdw);
    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);

    const auto &padded_dims = dst_mdw.padded_dims();
    const auto &oscales = attr()->output_scales_;
    const auto &zp = attr()->zero_points_;

    // A single, defined scale is handled as `alpha`; any other non-default
    // scales configuration turns on scale quantization.
    conf.scale_mask = oscales.mask_;
    conf.scales_num = get_attr_oscales_count(oscales.mask_, dst_mdw);
    const bool has_alpha = oscales.defined() && conf.scales_num == 1;
    conf.scale_quant = !oscales.has_default_values() && !has_alpha;
    conf.with_sum_ab = ((has_alpha && alpha() != 1.f) || beta() != 0.f);
    conf.with_sum_a = conf.with_sum_ab && beta() == 0.f;
    conf.has_padding = !src_mdw.is_dense() || !dst_mdw.is_dense();

    // Zero points: the common value is captured only when it is defined at
    // this point; otherwise 0 is stored and execute() passes a buffer.
    conf.with_src_zp = !zp.has_default_values(DNNL_ARG_SRC);
    conf.with_dst_zp = !zp.has_default_values(DNNL_ARG_DST);
    conf.common_src_zp = (conf.with_src_zp && zp.defined(DNNL_ARG_SRC))
            ? *zp.get(DNNL_ARG_SRC)
            : 0;
    conf.common_dst_zp = (conf.with_dst_zp && zp.defined(DNNL_ARG_DST))
            ? *zp.get(DNNL_ARG_DST)
            : 0;

    conf.ndims = src_mdw.ndims();
    conf.nelems = utils::array_product(padded_dims, conf.ndims);

    conf.sub_group_size = 1;

    // Nothing to dispatch for an empty problem.
    if (conf.nelems == 0) return status::success;

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    // Dimensions 0 and 1 are dispatched with block 1, every other dimension
    // with block 0 (the remaining entries are zero-initialized).
    const dim_t blocks[MAX_NDIMS] = {1, 1};

    conf.dispatch = compute_engine->create_dispatch(dst_mdw.md_);

    for (int i = 0; i < MAX_NDIMS; ++i) {
        const auto dim_str = utils::format("D%d", i);
        if (i < dst_mdw.ndims()) {
            // Keep the full dim_t width: storing the padded dimension in an
            // `int` would silently truncate very large dimensions.
            const dim_t dim = padded_dims[i];
            conf.dispatch.define_dim(dim_str, i, dim, blocks[i]);
        } else {
            // Dimensions beyond the destination rank get extent 1.
            conf.dispatch.define_dim(dim_str, 1);
        }
    }

    conf.dispatch.generate();
    return status::success;
}

// Translates `conf` into the build options and integer macros of the kernel
// context used to build the reference reorder kernel.
//
// kernel_ctx: context that receives the options and definitions.
// Returns status::success. For an empty problem (conf.nelems == 0) the
// context is left untouched.
status_t ref_reorder_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    if (conf.nelems == 0) return status::success;

    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.add_option("-cl-std=CL2.0");

    // WITH_SUM_A and WITH_SUM_AB are mutually exclusive; neither is defined
    // when conf.with_sum_ab is false.
    if (conf.with_sum_a)
        kernel_ctx.define_int("WITH_SUM_A", 1);
    else if (conf.with_sum_ab)
        kernel_ctx.define_int("WITH_SUM_AB", 1);

    if (conf.scale_quant) {
        kernel_ctx.define_int("SCALE_QUANT", 1);
        kernel_ctx.define_int("SCALE_MASK", conf.scale_mask);
    }

    // A zero point counts as "runtime" when one is present but the stored
    // common value is 0; execute() passes a buffer under the same condition.
    const bool runtime_src_zp = conf.with_src_zp && conf.common_src_zp == 0;
    const bool runtime_dst_zp = conf.with_dst_zp && conf.common_dst_zp == 0;

    kernel_ctx.define_int("WITH_SRC_ZPOINTS", conf.with_src_zp);
    kernel_ctx.define_int("RUNTIME_SRC_ZPOINTS", runtime_src_zp);
    kernel_ctx.define_int("WITH_DST_ZPOINTS", conf.with_dst_zp);
    kernel_ctx.define_int("RUNTIME_DST_ZPOINTS", runtime_dst_zp);

    def_dispatch(kernel_ctx, conf.dispatch);

    kernel_ctx.define_int("REF_REORDER", 1);

    kernel_ctx.define_int("PAD_FILL_ZERO", conf.has_padding);

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    kernel_ctx.print_options();
    return status::success;
}

// Books scratchpad space under key_reorder_scales: conf.scales_num entries of
// sizeof(float) bytes each, aligned to OCL_BUFFER_ALIGNMENT. Nothing is booked
// when conf.scales_num is not positive.
void ref_reorder_t::pd_t::init_scratchpad() {
    if (conf.scales_num <= 0) return;

    auto registrar = scratchpad_registry().registrar();
    registrar.book(memory_tracking::names::key_reorder_scales, conf.scales_num,
            sizeof(float), OCL_BUFFER_ALIGNMENT);
}

// Builds the kernel argument list and launches the reference reorder kernel.
//
// Kernel arguments, by index:
//   0 - source storage
//   1 - destination storage
//   2 - alpha
//   3 - beta
//   4 - scales storage (empty storage when conf.scale_quant is off)
//   5 - source zero points: argument storage, or conf.common_src_zp
//   6 - destination zero points: argument storage, or conf.common_dst_zp
//
// Returns status::success for an empty problem, the failing status of
// map_data()/unmap_data() if copying the scales fails, and otherwise the
// status returned by parallel_for().
status_t ref_reorder_t::execute(const exec_ctx_t &ctx) const {

    status_t status = status::success;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_FROM);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_TO);
    CHECK(status);

    const auto &conf = pd()->conf;
    if (conf.nelems == 0) return status::success;

    float alpha = pd()->alpha();
    float beta = pd()->beta();

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, dst);
    arg_list.set(2, alpha);
    arg_list.set(3, beta);

    // Declared at function scope — presumably so the scratchpad storage
    // outlives arg_list and the kernel launch below.
    std::shared_ptr<memory_storage_t> scales;
    if (conf.scale_quant) {
        const auto &oscales = pd()->attr()->output_scales_;
        if (oscales.defined()) {
            // Scales are known on the host: copy them into the scratchpad
            // storage booked under key_reorder_scales.
            scales = ctx.get_scratchpad_grantor().get_memory_storage(
                    key_reorder_scales);

            void *mapped = nullptr;
            status = scales->map_data(
                    &mapped, ctx.stream(), sizeof(float) * oscales.count_);
            if (status != status::success) return status;
            utils::array_copy(static_cast<float *>(mapped), oscales.scales_,
                    oscales.count_);
            status = scales->unmap_data(mapped, ctx.stream());
            if (status != status::success) return status;
            arg_list.set(4, *scales);
        } else {
            // Scales are not defined in the attributes: take them from the
            // execution arguments.
            auto &runtime_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES);
            arg_list.set(4, runtime_scales);
        }
    } else {
        arg_list.set(4, memory_storage_t::empty_storage());
    }

    // Same conditions as RUNTIME_SRC_ZPOINTS / RUNTIME_DST_ZPOINTS in
    // init_kernel_ctx(): a buffer is passed when a zero point is present but
    // the stored common value is 0, the common value otherwise.
    if (conf.with_src_zp && conf.common_src_zp == 0) {
        auto &zps = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
        arg_list.set(5, zps);
    } else {
        arg_list.set(5, conf.common_src_zp);
    }

    if (conf.with_dst_zp && conf.common_dst_zp == 0) {
        auto &zps = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
        arg_list.set(6, zps);
    } else {
        arg_list.set(6, conf.common_dst_zp);
    }

    auto nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl
