/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_BRGEMM_CONV_HPP
#define CPU_X64_JIT_BRGEMM_CONV_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/platform.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/brgemm/brgemm.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/cpu_reducer.hpp"
#include "cpu/x64/jit_brgemm_conv_comp_pad_kernel.hpp"
#include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp"
#include "cpu/x64/jit_brgemm_conv_utils.hpp"
#include "cpu/x64/jit_brgemm_post_ops.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

// Forward convolution primitive built on brgemm kernels (brg_kernels_) with
// separate post-op kernels (kernels_po_) for the given ISA. When
// `use_inversion` is true, kernel spatial indices passed through
// maybe_invert() are mirrored (K - 1 - k).
template <cpu_isa_t isa, bool use_inversion = false>
struct brgemm_convolution_fwd_t : public primitive_t {

    // Primitive descriptor. brgs_ / bd_masks hold shared brgemm descriptors
    // and their masks, jcp_ holds the convolution configuration. init() is
    // defined out of line.
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::hint_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , with_sum(false) {}

        ~pd_t() = default;

        // ------- DECLARE_COMMON_PD_t -----
        // Returns a heap-allocated copy, or nullptr if the copy is not
        // initialized. The copy constructor already copies brgs_ and
        // bd_masks (sharing the pointed-to objects); the loop sizes brgs_ to
        // brgs_sz_ and re-assigns the first brgs_sz_ entries of both vectors.
        pd_t *clone() const override {
            auto new_pd = utils::make_unique<pd_t>(*this);
            if (!new_pd->is_initialized()) return nullptr;
            new_pd->brgs_.resize(brgs_sz_);
            for (int i = 0; i < brgs_sz_; i++) {
                new_pd->brgs_[i] = brgs_[i];
                new_pd->bd_masks[i] = bd_masks[i];
            }
            return new_pd.release();
        }

        status_t create_primitive(
                std::pair<std::shared_ptr<primitive_t>, bool> &primitive,
                engine_t *engine,
                const cache_blob_t &cache_blob) const override {
            return primitive_t::create_primitive_common<
                    brgemm_convolution_fwd_t, pd_t>(
                    primitive, this, engine, false, cache_blob);
        }

        const char *name() const override {
            return JIT_IMPL_NAME_HELPER("brgconv:", isa, "");
        }
        // ---------------------------------

        status_t init(engine_t *engine);

        int brgs_sz_;
        std::vector<std::shared_ptr<brgemm_t>> brgs_;
        std::vector<std::shared_ptr<std::vector<char>>> bd_masks;
        bool with_sum;
        jit_brgemm_conv_conf_t jcp_;

        int ic_chunks;
        bool need_postwork;

        // batch sizes info for unrolled kernels
        int bs_c, first_bs;
        std::vector<int> batchsizes;

        // Flattened brgemm index over
        // (m, batch-size slot, do_initialization, is_N_tail, is_K_tail):
        // bs_c batch-size slots per m, and a factor of 2 for each flag.
        // The slot is batchsizes[bs] for unrolled kernels (jcp_.use_uker)
        // and 0 otherwise.
        int get_brg_idx(int bs, int m, bool do_initialization, bool is_N_tail,
                bool is_K_tail) const {
            // batchsizes is only read for unrolled kernels; catch an
            // out-of-range bs in debug builds before indexing it.
            assert(!jcp_.use_uker
                    || (bs >= 0
                            && static_cast<size_t>(bs) < batchsizes.size()));
            const auto bs_idx = jcp_.use_uker ? batchsizes[bs] : 0;
            assert(bs_idx >= 0);
            return (((m * bs_c + bs_idx) * 2
                            + static_cast<int>(do_initialization))
                                   * 2
                           + static_cast<int>(is_N_tail))
                    * 2
                    + static_cast<int>(is_K_tail);
        }

    protected:
        // Scales are accepted only for SRC, WEIGHTS and DST.
        bool arg_scales_ok() const {
            std::vector<int> supported_args
                    = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
            return attr_scales_ok(supported_args);
        }

        bool zero_points_ok() const {
            // Only common zero points are supported -> mask should only be 0
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }
    };

    brgemm_convolution_fwd_t(const pd_t *apd);

    ~brgemm_convolution_fwd_t() = default;

    status_t execute(const exec_ctx_t &ctx) const override;

protected:
    status_t init(engine_t *engine) override;

private:
    // Raw storage for one AMX tile palette (AMX_PALETTE_SIZE bytes); one
    // entry per kernel is kept in brg_kernel_palettes_.
    struct S_t {
        char a[AMX_PALETTE_SIZE];
    };

    //  brgemm convolution execution context: the SRC / WEIGHTS / BIAS input
    //  pointers, the DST output pointer and the binary post-op rhs arguments
    //  taken from the given exec_ctx_t.
    struct brgemm_exec_ctx_t {
        brgemm_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd)
            : src(CTX_IN_MEM(const char *, DNNL_ARG_SRC))
            , weights(CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS))
            , bias(CTX_IN_MEM(const char *, DNNL_ARG_BIAS))
            , dst(CTX_OUT_MEM(char *, DNNL_ARG_DST))
            , post_ops_binary_rhs_arg_vec(binary_injector::prepare_binary_args(
                      pd->attr()->post_ops_, ctx)) {}
        const char *const __restrict src;
        const char *const __restrict weights;
        const char *const __restrict bias;
        char *const __restrict dst;
        const std::vector<const void *> post_ops_binary_rhs_arg_vec;
    };

    struct brgemm_thread_ctx_t;

    // Flattened post-op kernel index over (m, do_postwork, is_N_tail), with a
    // factor of 2 for each flag.
    static int get_ker_po_idx(int m, bool do_postwork, bool is_N_tail) {
        return (m * 2 + static_cast<int>(do_postwork)) * 2
                + static_cast<int>(is_N_tail);
    }

    // Input extent along one spatial dimension for dst_size outputs with
    // kernel size k, the given stride and dilation, clamped to max_src_size.
    // NOTE(review): relies on calculate_end_padding(0, dst_size, 0, ...)
    // returning the required input length; confirm against its definition.
    static int get_inp_size(
            int max_src_size, int dst_size, int k, int stride, int dilate) {
        const auto res = nstl::min(max_src_size,
                calculate_end_padding(0, dst_size, 0, stride,
                        calculate_extended_filter_size(k, dilate)));
        return res;
    }

    // Returns the mirrored index K - 1 - k when use_inversion is set,
    // otherwise k unchanged.
    int maybe_invert(int k, int K) const {
        return use_inversion ? K - 1 - k : k;
    }
    void get_kw_range(
            int ow, int &kw_s, int &kw_full_s, int &kw_full_e, int &kw_e) const;
    void get_ow_range(int ow, int kw, int &ow_s, int &ow_e) const;

    void ker_base(brgemm_thread_ctx_t &btc) const;
    void ker_trans(brgemm_thread_ctx_t &btc, char *inp_buffer) const;
    void ker_vpad(brgemm_thread_ctx_t &btc) const;

    void perform_outwork(char *dst_base, char *dst, char *c_buffer,
            const char *bias_w, int od, int oh, int ow, int g_oc,
            bool is_oc_tail, int ker_ow_s, int ker_ow_f, int kd_l, int kh_l,
            const void *post_ops_binary_rhs_arg_vec, const float *oscales,
            int32_t src_zp_vals, int32_t *src_zp_ptr, int32_t *dst_zp_ptr,
            int32_t *s8s8_compensation, bool maybe_do_init, bool do_postwork,
            bool do_post_comp, const float *dst_scales) const;

    void call_brgemm_kernel(brgemm_thread_ctx_t &btc, int brg_idx,
            int batch_size, char *ptr_C, char *ptr_D, const char *bias_w,
            int g_oc, bool do_postops, const void *binary_post_ops_rhs,
            int32_t src_zp_vals, int32_t *src_zp_ptr, int32_t *dst_zp_ptr,
            int32_t *s8s8_comp, bool do_only_comp) const;

    void maybe_conv_inp(int ithr, const char *__restrict src,
            char *__restrict inp_buffer, uint8_t *__restrict inp_buffer_mask,
            int g, int n, int icc, int odb, int ohb, int owb, int last_g,
            int last_n, int last_icc, int last_odb, int last_ohb,
            int last_owb) const;

    status_t add_po_kernel(brgemm_t *bcfg, int ker_idx, bool is_init);
    void add_po_kernels(int i_N, int init_bcast_dim, int po_bcast_dim);
    status_t add_brg_kernel(int bs, int M, int i_N, int i_K, int i_init);

    status_t cal_compensation(const char *__restrict weights,
            int32_t *src_zp_buffer, int32_t *s8s8_comp_buffer) const;
    int get_comp_ker_idx(const int kd_b, const int kd_e, const int kh_b,
            const int kh_e, const int kw_b, const int kw_e) const;
    int get_comp_offset(const int g, const int ocb, const int ow,
            const int kd_b, const int kd_e, const int kh_b, const int kh_e,
            const int kw_b, const int kw_e) const;
    const pd_t *pd() const {
        return static_cast<const pd_t *>(primitive_t::pd().get());
    }

    std::vector<std::unique_ptr<brgemm_kernel_t>> brg_kernels_;
    std::vector<std::unique_ptr<jit_brgemm_kernel_post_ops<isa>>> kernels_po_;
    std::unique_ptr<jit_avx512_core_brgemm_conv_trans_kernel::
                    jit_avx512_core_brgemm_conv_trans_kernel_t>
            copy_to_pbuffer_;
    std::unique_ptr<jit_avx512_core_brgemm_conv_comp_pad_kernel::
                    jit_avx512_core_brgemm_conv_comp_pad_kernel_t>
            comp_vpad_pbuffer_;
    std::vector<S_t> brg_kernel_palettes_;

    size_t acc_dsz, bia_dsz, src_dsz, wei_dsz, dst_dsz;

    const memory_desc_wrapper bias_d;

    // pre-calculated values
    std::vector<dim_t> owb_kw_top_vpads;
    std::vector<dim_t> owb_kw_bottom_vpads;
    std::vector<dim_t> kd_bs, kd_es, kh_bs, kh_es, kw_bs, kw_es;

    int KD, KH, KW, EXT_KD, EXT_KH, EXT_KW, KS, KD_BLOCK, KH_BLOCK, KW_BLOCK,
            KD_BLOCK_PAD, KH_BLOCK_PAD, ID, IH, IW, IDP, IHP, IWP, OD, OH, OW,
            SD, SH, SW, FP, TP, LP, DD, DH, DW;
    dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz;
    dim_t wei_g_stride, wei_ic_stride, wei_ocb_stride;
    dim_t wei_kw_stride, wei_kh_stride, wei_kd_stride;
    dim_t pbuf_w_sz, pbuf_h_sz, pbuf_d_sz;
    dim_t ker_vpad_sz, comp_ocb_sz, comp_ker_sz, comp_kw_sz;

    bool need_compensation;
    bool is_amx;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
