/*******************************************************************************
 * Copyright 2022 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
#include "relu_backprop.hpp"
#include <compiler/ir/graph/fusible_op.hpp>

namespace sc {
namespace ops {

/**
 * Constructs a relu_backprop op.
 *
 * @param ins exactly two tensors. ins[0] is the forward-pass tensor the
 *     derivative mask is computed from: the forward dst relu(x) when the
 *     "use_dst" attribute is true, otherwise the forward src x. ins[1] is
 *     multiplied by that mask (presumably the incoming gradient diff_dst --
 *     confirm against callers).
 * @param outs either empty, in which case one output tensor is created with
 *     the same details as ins[0], or a list holding exactly one tensor.
 * @param attrs op attributes; get_graph_impl() may consult "use_dst".
 */
relu_backprop_op::relu_backprop_op(const std::vector<graph_tensor_ptr> &ins,
        const std::vector<graph_tensor_ptr> &outs, const any_map_t &attrs) {
    COMPILE_ASSERT(ins.size() == 2, "Wrong op input size.\n");
    // get_graph_impl() lowers this op to a graph with a single output, so a
    // caller-provided output list must not hold more than one tensor.
    COMPILE_ASSERT(outs.size() <= 1, "Wrong op output size.\n");
    info_.inputs_ = ins;
    if (outs.empty()) {
        // No output given: create one that mirrors ins[0]'s details.
        info_.outputs_.emplace_back(
                std::make_shared<graph_tensor>(this, ins[0]->details_));
    } else {
        info_.outputs_ = outs;
    }
    attrs_ = attrs;
    op_name_ = "relu_backprop";
}

void relu_backprop_op::get_graph_impl(std::shared_ptr<sc_graph_t> &graph) {
    // create new input logical tensors
    std::vector<graph_tensor_ptr> inputs, outputs;
    inputs = remake_logical_tensors(info_.inputs_);
    outputs = remake_logical_tensors(info_.outputs_);

    // input
    graph->make_input(inputs);

    // if "use_dst" is true, inputs0 is the result of forward, which is
    // relu(x). otherwise, inputs0 is the src of forward

    sc_op_ptr select_one, mul;
    if (attrs_.get_or_else("use_dst", true)) {
        select_one = graph->make("select_one", {inputs[0]}, {}, {});
        mul = graph->make(
                "mul", {inputs[1], select_one->get_outputs()[0]}, {}, {});
    } else {
        select_one = graph->make("select_one", {inputs[0]}, {}, {});
        mul = graph->make(
                "mul", {inputs[1], select_one->get_outputs()[0]}, {}, {});
    }

    // output
    graph->make_output(mul->get_outputs());
}

void relu_backprop_op::query_format(context_ptr ctx,
        std::vector<std::vector<format_stride_pair>> &supported_ins,
        std::vector<std::vector<format_stride_pair>> &supported_outs) {}

} // namespace ops

// Register ops::relu_backprop_op under the name "relu_backprop".
// NOTE(review): presumably this lets graph->make("relu_backprop", ...) create
// the op by name, as get_graph_impl() does for "select_one"/"mul" -- confirm
// against the OP_REGISTER macro definition.
OP_REGISTER(ops::relu_backprop_op, relu_backprop)
} // namespace sc
