// Copyright 2020 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file implements MLIR operation functions for the basic_kernels library.

#include "tfrt/basic_kernels/opdefs/basic_kernels.h"

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "tfrt/basic_kernels/opdefs/types.h"

using func::FuncOp;

namespace tfrt {
namespace compiler {

namespace {
// Builds region entry-block arguments by pairing each unresolved operand with
// the type at the same position and appending one OpAsmParser::Argument per
// pair to `args`. Iteration stops at the shorter of `operands` and `types`
// (llvm::zip semantics). The callers in this file resolve the operands against
// `types` before calling this, so the two normally have the same length.
void createArgs(ArrayRef<OpAsmParser::UnresolvedOperand> operands,
                ArrayRef<Type> types,
                SmallVector<OpAsmParser::Argument> &args) {
  for (auto [operand, type] : llvm::zip(operands, types)) {
    OpAsmParser::Argument &new_arg = args.emplace_back();
    new_arg.ssaName = operand;
    new_arg.type = type;
  }
}
}  // namespace

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//

// Parses the custom call form:
//
//   @callee(%arg0, %arg1, ...) {optional attrs} : (T0, T1, ...) -> (R0, ...)
//
// The callee symbol is stored in the "callee" attribute, the op's result types
// are taken from the trailing function type, and the operands are resolved
// against that function type's inputs.
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
  // Location reported if the operand list and the signature inputs disagree.
  auto name_loc = parser.getNameLoc();

  SymbolRefAttr callee;
  if (parser.parseAttribute(callee, "callee", result.attributes))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> call_operands;
  if (parser.parseOperandList(call_operands, OpAsmParser::Delimiter::Paren))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  FunctionType signature;
  if (parser.parseColonType(signature)) return failure();

  if (parser.addTypesToList(signature.getResults(), result.types))
    return failure();

  if (parser.resolveOperands(call_operands, signature.getInputs(), name_loc,
                             result.operands))
    return failure();

  return success();
}

// Prints the custom form accepted by CallOp::parse:
//   @callee(%operands) {other attrs} : (input types) -> (result types)
void CallOp::print(OpAsmPrinter &p) {
  auto callee = (*this)->getAttr("callee");
  p << " " << callee << '(';
  p.printOperands(getOperands());
  p << ')';

  // "callee" was already printed in front of the operand list, so keep it out
  // of the trailing attribute dictionary.
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"callee"});

  p << " : ";
  p.printType(getCalleeType());
}

// Verifies that the "callee" attribute names a func::FuncOp in the closest
// enclosing module, and that the function's signature exactly matches this
// call: same number and types of operands and results, in order.
LogicalResult CallOp::verify() {
  CallOp op = *this;
  // Check that the callee attribute was specified.
  auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!fnAttr)
    return op.emitOpError("requires a 'callee' symbol reference attribute");

  // The callee is resolved in the closest enclosing module. Report an error
  // instead of dereferencing a null ModuleOp when the call is not nested in
  // one (e.g. it is verified while detached).
  auto module = op->getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("must be nested in a module to resolve the callee");

  auto fn = module.lookupSymbol<func::FuncOp>(fnAttr.getValue());
  if (!fn)
    return op.emitOpError() << "'" << fnAttr.getValue()
                            << "' does not reference a valid function";

  // Verify that the operand and result types match the callee.
  auto fnType = fn.getFunctionType();
  if (fnType.getNumInputs() != op.getNumOperands())
    return op.emitOpError("incorrect number of operands for callee");

  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
    if (op.getOperand(i).getType() != fnType.getInput(i))
      return op.emitOpError("operand type mismatch");

  if (fnType.getNumResults() != op.getNumResults())
    return op.emitOpError("incorrect number of results for callee");

  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
    if (op.getResult(i).getType() != fnType.getResult(i))
      return op.emitOpError("result type mismatch");

  return success();
}

// Returns the function type implied by this call site: the operand types as
// inputs and the result types as outputs.
mlir::FunctionType CallOp::getCalleeType() {
  auto input_types = getOperandTypes();
  auto output_types = getResultTypes();
  return FunctionType::get(getContext(), input_types, output_types);
}

// Checks that the single block of `region` ends in a tfrt.return whose operand
// types equal `result_types` (same count, same order).
//
// On failure, the error is emitted on `op` if the block is empty or does not
// end in a tfrt.return. It is emitted on the return op if its operand types
// differ from `result_types`. The caller must already have ensured that
// `region` has exactly one block (checked only by an assert).
template <typename ResultTypeContainer>
static LogicalResult checkTFRTReturn(Operation *op, Region *region,
                                     ResultTypeContainer result_types) {
  assert(region->hasOneBlock() && "verifier should already check region size");
  auto &body = region->front();

  ReturnOp terminator;
  if (!body.empty()) terminator = dyn_cast<ReturnOp>(body.back());
  if (!terminator) return op->emitOpError("expected tfrt.return in body");

  auto returned_types = terminator.getOperandTypes();
  const bool types_match =
      std::equal(returned_types.begin(), returned_types.end(),
                 result_types.begin(), result_types.end());
  if (types_match) return success();

  return terminator.emitOpError()
         << "operand types don't match '" << op->getName() << "' result";
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//

// Verifies that both branch regions take entry-block arguments matching the
// op's operands after the first one (count and types), and that each region
// ends in a tfrt.return producing the op's result types.
LogicalResult IfOp::verify() {
  IfOp op = *this;
  // The ODS verifier already checked the first operand to be present and i1,
  // so every remaining operand feeds one block argument in each region.
  const unsigned num_forwarded = op.getNumOperands() - 1;

  auto *then_entry = &op.getThenRegion().front();
  if (then_entry->getNumArguments() != num_forwarded)
    return op.emitOpError("incorrect number of arguments to 'then' block");

  auto *else_entry = &op.getElseRegion().front();
  if (else_entry->getNumArguments() != num_forwarded)
    return op.emitOpError("incorrect number of arguments to 'else' block");

  for (unsigned idx = 0; idx < num_forwarded; ++idx) {
    Type operand_type = op.getOperand(idx + 1).getType();
    const bool then_ok = operand_type == then_entry->getArgument(idx).getType();
    const bool else_ok = operand_type == else_entry->getArgument(idx).getType();
    if (!then_ok || !else_ok)
      return op.emitOpError("operand/argument type mismatch");
  }

  LogicalResult then_result =
      checkTFRTReturn(op, &op.getThenRegion(), op.getResultTypes());
  if (failed(then_result)) return failure();

  return checkTFRTReturn(op, &op.getElseRegion(), op.getResultTypes());
}

// Parses the custom form:
//
//   %cond, %arg0, ... [attributes {...}] : (arg types) -> (result types)
//     { then-region } [else { else-region }]
//
// The first operand is the i1 condition. The remaining operands are typed by
// the inputs of the function type and become the entry-block arguments of
// both regions, reusing the operand names. The else region may be omitted
// only when there are no results. In that case an else block that just
// returns is synthesized.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> all_operands;
  if (parser.parseOperandList(all_operands)) return failure();

  const bool has_attr_keyword =
      succeeded(parser.parseOptionalKeyword("attributes"));
  if (has_attr_keyword && parser.parseOptionalAttrDict(result.attributes))
    return failure();

  FunctionType signature;
  llvm::SMLoc signature_loc = parser.getCurrentLocation();
  if (parser.parseColonType(signature)) return failure();
  if (parser.addTypesToList(signature.getResults(), result.types))
    return failure();

  if (all_operands.empty())
    return parser.emitError(parser.getCurrentLocation(), "expected condition");

  // Operand 0 is the condition; everything after it is forwarded to regions.
  auto forwarded = llvm::ArrayRef(all_operands).drop_front();
  auto forwarded_types = signature.getInputs();
  auto bool_type = IntegerType::get(result.getContext(), 1);
  if (parser.resolveOperand(all_operands[0], bool_type, result.operands))
    return failure();
  if (parser.resolveOperands(forwarded, forwarded_types, signature_loc,
                             result.operands))
    return failure();

  SmallVector<OpAsmParser::Argument> region_args;
  createArgs(forwarded, forwarded_types, region_args);

  Region *then_region = result.addRegion();
  if (parser.parseRegion(*then_region, region_args,
                         /*enableNameShadowing=*/true))
    return failure();

  Region *else_region = result.addRegion();
  if (succeeded(parser.parseOptionalKeyword("else")))
    return parser.parseRegion(*else_region, region_args,
                              /*enableNameShadowing=*/true);

  // While the else region is syntactically optional, it is structurally
  // required in the IR and by the op kernel implementation.  Fill in the
  // default implementation.
  if (!signature.getResults().empty())
    return parser.emitError(
        parser.getCurrentLocation(),
        "expected 'else' in 'tfrt.if' with result values");

  mlir::OpBuilder builder(result.getContext());
  auto *else_block = builder.createBlock(else_region);
  else_block->addArguments(
      forwarded_types,
      SmallVector<Location>(forwarded_types.size(), result.location));
  builder.create<ReturnOp>(result.location);
  return success();
}

// Prints the custom form accepted by IfOp::parse:
//
//   %cond, %arg0, ... [attributes {...}] : (arg types) -> (result types)
//     { then-region } else { else-region }
//
// IfOp::parse only reads an attribute dictionary after the `attributes`
// keyword, so the dictionary is printed with that keyword. Without it, an op
// carrying attributes would fail to re-parse.
void IfOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOperands(getOperands());
  // Prints nothing when the op has no attributes.
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
  p << " : (";
  interleaveComma(llvm::drop_begin(getOperandTypes(), 1), p);
  p << ") -> (";
  interleaveComma(getResultTypes(), p);
  p << ")";

  // Reuse the argument names provided to the op for the bbarg names within
  // the region.
  auto arg_name_values = llvm::drop_begin(getOperands(), 1);
  p.shadowRegionArgs(getThenRegion(), arg_name_values);
  p << ' ';
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false);
  p << " else";
  p.shadowRegionArgs(getElseRegion(), arg_name_values);
  p << ' ';
  p.printRegion(getElseRegion(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// CondOp
//===----------------------------------------------------------------------===//

// Verifies that "a_true_fn" and "b_false_fn" name func::FuncOp symbols in the
// closest enclosing module. Both functions must take the types of this op's
// operands after the first one (presumably the condition), and both must
// return exactly this op's result types.
LogicalResult CondOp::verify() {
  CondOp op = *this;
  // Check that the true/false function attributes are specified.
  auto trueFnAttr = op->getAttrOfType<FlatSymbolRefAttr>("a_true_fn");
  if (!trueFnAttr)
    return op.emitOpError("requires a 'a_true_fn' symbol reference attribute");

  // The diagnostic names the attribute that is actually looked up.
  auto falseFnAttr = op->getAttrOfType<FlatSymbolRefAttr>("b_false_fn");
  if (!falseFnAttr)
    return op.emitOpError("requires a 'b_false_fn' symbol reference attribute");

  // Both functions are resolved in the closest enclosing module. Report an
  // error instead of dereferencing a null ModuleOp when there is none.
  auto module = op->getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("must be nested in a module to resolve functions");

  auto trueFn = module.lookupSymbol<FuncOp>(trueFnAttr.getValue());
  if (!trueFn)
    return op.emitOpError() << "'" << trueFnAttr.getValue()
                            << "' does not reference a valid function";

  auto falseFn = module.lookupSymbol<FuncOp>(falseFnAttr.getValue());
  if (!falseFn)
    return op.emitOpError() << "'" << falseFnAttr.getValue()
                            << "' does not reference a valid function";

  // Verify that the operand and result types match the true/false function.
  auto trueFnType = trueFn.getFunctionType();
  if (trueFnType.getNumInputs() != op.getNumOperands() - 1)
    return op.emitOpError("incorrect number of operands for true function");

  auto falseFnType = falseFn.getFunctionType();
  if (falseFnType.getNumInputs() != op.getNumOperands() - 1)
    return op.emitOpError("incorrect number of operands for false function");

  for (unsigned i = 0, e = trueFnType.getNumInputs(); i != e; ++i) {
    if (op.getOperand(i + 1).getType() != trueFnType.getInput(i))
      return op.emitOpError("operand type mismatch for true function");

    if (op.getOperand(i + 1).getType() != falseFnType.getInput(i))
      return op.emitOpError("operand type mismatch for false function");
  }

  if (trueFnType.getNumResults() != op.getNumResults())
    return op.emitOpError("incorrect number of results for true function");

  if (falseFnType.getNumResults() != op.getNumResults())
    return op.emitOpError("incorrect number of results for false function");

  for (unsigned i = 0, e = trueFnType.getNumResults(); i != e; ++i) {
    if (op.getResult(i).getType() != trueFnType.getResult(i))
      return op.emitOpError("result type mismatch for true function");

    if (op.getResult(i).getType() != falseFnType.getResult(i))
      return op.emitOpError("result type mismatch for false function");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// RepeatI32Op
//===----------------------------------------------------------------------===//

// Parses the custom form:
//
//   %count, %arg0, ... [attributes {...}] [: arg types] { body }
//
// The first operand is the i32 trip count. The remaining operands are typed by
// the optional colon type list. Those types also become the op's result types.
// The operands become the body's entry-block arguments, reusing the operand
// names.
ParseResult RepeatI32Op::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> all_operands;
  if (parser.parseOperandList(all_operands)) return failure();

  if (succeeded(parser.parseOptionalKeyword("attributes")) &&
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<Type, 4> carried_types;
  llvm::SMLoc types_loc = parser.getCurrentLocation();
  if (parser.parseOptionalColonTypeList(carried_types)) return failure();
  if (parser.addTypesToList(carried_types, result.types)) return failure();

  if (all_operands.empty())
    return parser.emitError(parser.getCurrentLocation(), "expected trip count");

  const auto &trip_count = all_operands.front();
  auto carried = llvm::ArrayRef(all_operands).drop_front();
  auto i32 = IntegerType::get(result.getContext(), 32);

  if (parser.resolveOperand(trip_count, i32, result.operands))
    return failure();
  if (parser.resolveOperands(carried, carried_types, types_loc,
                             result.operands))
    return failure();

  // The body receives the carried values under the same names.
  SmallVector<OpAsmParser::Argument> body_args;
  createArgs(carried, carried_types, body_args);
  Region *body = result.addRegion();
  return parser.parseRegion(*body, body_args, /*enableNameShadowing=*/true);
}

// Prints the custom form accepted by RepeatI32Op::parse:
//
//   %count, %arg0, ... [attributes {...}] [: arg types] { body }
//
// RepeatI32Op::parse only reads an attribute dictionary after the `attributes`
// keyword, so the dictionary is printed with that keyword. Without it, an op
// carrying attributes would fail to re-parse.
void RepeatI32Op::print(OpAsmPrinter &p) {
  p << " ";
  p.printOperands(getOperands());
  // Prints nothing when the op has no attributes.
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
  if (getNumOperands() > 1) {
    p << " : ";
    interleaveComma(llvm::drop_begin(getOperandTypes(), 1), p);
  }

  // Reuse the argument names provided to the op for the bbarg names within
  // the region.
  SmallVector<Value, 4> arg_name_values(llvm::drop_begin(getOperands(), 1));
  p.shadowRegionArgs(getRegion(), arg_name_values);
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}

// Verifies that the operands after the trip count match the op's results
// one-to-one in count and type. Also verifies that the body ends in a
// tfrt.return producing the result types.
LogicalResult RepeatI32Op::verify() {
  RepeatI32Op op = *this;
  const unsigned num_results = op.getNumResults();
  if (op.getNumOperands() - 1 != num_results)
    return op.emitOpError("incorrect number of operands");

  for (unsigned r = 0; r < num_results; ++r) {
    const bool same_type =
        op.getResult(r).getType() == op.getOperand(r + 1).getType();
    if (!same_type) return op.emitOpError("operand/result type mismatch");
  }

  return checkTFRTReturn(op, &op.getRegion(), op.getResultTypes());
}

//===----------------------------------------------------------------------===//
// ParallelForI32Op
//===----------------------------------------------------------------------===//

// Parses the tfrt.parallel_for.i32 operation.
//
// Expected format:
//
//   %ch = tfrt.parallel_for.i32 %start to %end fixed %block_size,
//                              %loop_arg0 : !my.type {
//     ... parallel block compute function ...
//     tfrt.return ... : !tfrt.chain
//   }
//
// %start, %end and %block_size are resolved as i32. The body's entry block
// receives %start and %end (under the same names) followed by the additional
// operands; %block_size is not passed to the body. The op gets a single chain
// result type.
ParseResult ParallelForI32Op::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::UnresolvedOperand start;
  OpAsmParser::UnresolvedOperand end;
  OpAsmParser::UnresolvedOperand block_size;

  // Bounds: %start to %end fixed %block_size
  if (parser.parseOperand(start) || parser.parseKeyword("to") ||
      parser.parseOperand(end) || parser.parseKeyword("fixed") ||
      parser.parseOperand(block_size))
    return failure();

  // Optional extra operands, introduced by a comma.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> extra_operands;
  if (succeeded(parser.parseOptionalComma()) &&
      parser.parseOperandList(extra_operands))
    return failure();

  // Optional types for the extra operands.
  SmallVector<Type, 4> extra_types;
  llvm::SMLoc extra_types_loc = parser.getCurrentLocation();
  if (parser.parseOptionalColonTypeList(extra_types)) return failure();

  auto i32_type = IntegerType::get(result.getContext(), 32);
  if (parser.resolveOperand(start, i32_type, result.operands) ||
      parser.resolveOperand(end, i32_type, result.operands) ||
      parser.resolveOperand(block_size, i32_type, result.operands))
    return failure();

  if (parser.resolveOperands(extra_operands, extra_types, extra_types_loc,
                             result.operands))
    return failure();

  // The op yields a chain once all parallel blocks have completed.
  auto chain_type = compiler::ChainType::get(result.getContext());
  if (parser.addTypesToList(chain_type, result.types)) return failure();

  // Entry-block arguments: %start, %end, then the extra operands.
  SmallVector<OpAsmParser::UnresolvedOperand, 6> region_operands = {start, end};
  region_operands.append(extra_operands.begin(), extra_operands.end());
  SmallVector<Type, 6> region_types = {i32_type, i32_type};
  region_types.append(extra_types.begin(), extra_types.end());

  SmallVector<OpAsmParser::Argument> region_args;
  createArgs(region_operands, region_types, region_args);
  Region *body = result.addRegion();
  return parser.parseRegion(*body, region_args,
                            /*enableNameShadowing=*/true);
}

// Prints the custom form accepted by ParallelForI32Op::parse:
//
//   %start to %end fixed %block_size[, %extra... : extra types] { body }
//
// The body is printed without its entry-block argument list. Those arguments
// are shown under the names of every operand except %block_size (operand #2).
void ParallelForI32Op::print(OpAsmPrinter &p) {
  auto operands = getOperands();

  p << " ";
  p.printOperand(operands[0]);
  p << " to ";
  p.printOperand(operands[1]);
  p << " fixed ";
  p.printOperand(operands[2]);

  const bool has_extra_operands = getNumOperands() > 3;
  if (has_extra_operands) {
    p << ", ";
    p.printOperands(llvm::drop_begin(operands, 3));
    p << " : ";
    interleaveComma(llvm::drop_begin(getOperandTypes(), 3), p);
  }

  // Region argument names: %start, %end, then the extra operands.
  SmallVector<Value, 4> region_arg_names;
  region_arg_names.push_back(operands[0]);
  region_arg_names.push_back(operands[1]);
  llvm::append_range(region_arg_names, llvm::drop_begin(operands, 3));

  p.shadowRegionArgs(getRegion(), region_arg_names);
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}

// Verifies the body terminator. An operand-less tfrt.return is accepted (a
// synchronous body). Otherwise the returned types must equal the op's own
// result types (the chain produced by tfrt.parallel_for).
LogicalResult ParallelForI32Op::verify() {
  ParallelForI32Op op = *this;
  auto &body = op.getRegion().front();

  ReturnOp terminator;
  if (!body.empty()) terminator = dyn_cast<ReturnOp>(body.back());
  if (!terminator) return op.emitOpError("expected tfrt.return in body");

  // Synchronous parallel region can have a return op without operands.
  if (terminator.getNumOperands() == 0) return success();

  return checkTFRTReturn(op, &op.getRegion(), op.getResultTypes());
}

//===----------------------------------------------------------------------===//
// ParallelCallI32Op
//===----------------------------------------------------------------------===//

// Parses the tfrt.parallel_call.i32 operation.
//
// Expected format:
//
//   %ch = tfrt.parallel_call.i32 %start to %end fixed %block_size
//         @callee(%loop_arg0) : !my.type
//
// The three bound operands are resolved as i32. The parenthesized operands are
// typed by the optional trailing type list, and the callee symbol is stored in
// the "callee" attribute. The op gets a single chain result type and has no
// region.
ParseResult ParallelCallI32Op::parse(OpAsmParser &parser,
                                     OperationState &result) {
  OpAsmParser::UnresolvedOperand start;
  OpAsmParser::UnresolvedOperand end;
  OpAsmParser::UnresolvedOperand block_size;

  // Bounds: %start to %end fixed %block_size
  if (parser.parseOperand(start)) return failure();
  if (parser.parseKeyword("to")) return failure();
  if (parser.parseOperand(end)) return failure();
  if (parser.parseKeyword("fixed")) return failure();
  if (parser.parseOperand(block_size)) return failure();

  SymbolRefAttr callee_attr;
  if (parser.parseAttribute(callee_attr, "callee", result.attributes))
    return failure();

  // Parenthesized operands forwarded to the callee.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> call_operands;
  if (parser.parseOperandList(call_operands, OpAsmParser::Delimiter::Paren))
    return failure();

  SmallVector<Type, 4> call_types;
  llvm::SMLoc call_types_loc = parser.getCurrentLocation();
  if (parser.parseOptionalColonTypeList(call_types)) return failure();

  auto i32_type = IntegerType::get(result.getContext(), 32);
  if (parser.resolveOperand(start, i32_type, result.operands)) return failure();
  if (parser.resolveOperand(end, i32_type, result.operands)) return failure();
  if (parser.resolveOperand(block_size, i32_type, result.operands))
    return failure();

  if (parser.resolveOperands(call_operands, call_types, call_types_loc,
                             result.operands))
    return failure();

  // The op yields a chain once all parallel calls have completed.
  auto chain_type = compiler::ChainType::get(result.getContext());
  if (parser.addTypesToList(chain_type, result.types)) return failure();

  return success();
}

// Prints the custom form accepted by ParallelCallI32Op::parse:
//
//   %start to %end fixed %block_size @callee(%args...)[ : arg types]
//
// The type list is printed only when there is at least one forwarded operand.
void ParallelCallI32Op::print(OpAsmPrinter &p) {
  auto operands = getOperands();

  p << " ";
  p.printOperand(operands[0]);
  p << " to ";
  p.printOperand(operands[1]);
  p << " fixed ";
  p.printOperand(operands[2]);

  p << " " << (*this)->getAttr("callee") << '(';
  p.printOperands(llvm::drop_begin(operands, 3));
  p << ')';

  if (getNumOperands() <= 3) return;

  p << " : ";
  interleaveComma(llvm::drop_begin(getOperandTypes(), 3), p);
}

// Verifies that "callee" names a func::FuncOp in the closest enclosing module.
// Its inputs must be (i32, i32, <types of this op's operands after
// %block_size>); %block_size itself is not forwarded to the callee. Its
// results must be empty (synchronous body) or a single chain (asynchronous
// body).
LogicalResult ParallelCallI32Op::verify() {
  ParallelCallI32Op op = *this;
  // Check that the callee attribute was specified.
  auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!fnAttr)
    return op.emitOpError("requires a 'callee' symbol reference attribute");

  // The callee is resolved in the closest enclosing module. Report an error
  // instead of dereferencing a null ModuleOp when there is none.
  auto module = op->getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("must be nested in a module to resolve the callee");

  auto fn = module.lookupSymbol<FuncOp>(fnAttr.getValue());
  if (!fn)
    return op.emitOpError() << "'" << fnAttr.getValue()
                            << "' does not reference a valid function";

  // Verify that the operand and result types match the callee.
  auto fnType = fn.getFunctionType();

  // Callee must take start and end indices followed by parallel call operands
  // (this op's operands #0, #1 and #3 onwards; %block_size is skipped).
  if (fnType.getNumInputs() != op.getNumOperands() - 1)
    return op.emitOpError("incorrect number of callee operands");

  auto i32_type = IntegerType::get(op.getContext(), 32);
  for (unsigned i = 0; i != 2; ++i) {
    if (fnType.getInput(i) != i32_type)
      return op.emitOpError("callee must take start and end indices first");
  }

  // Callee input i (i >= 2) corresponds to op operand i + 1.
  for (unsigned i = 2, e = fnType.getNumInputs(); i != e; ++i) {
    if (op.getOperand(i + 1).getType() != fnType.getInput(i))
      return op.emitOpError("operand type mismatch");
  }

  // Callee must have empty results for synchronous body function, or a single
  // chain for an asynchronous body function.
  if (fnType.getNumResults() > 1)
    return op.emitOpError("invalid callee result type");

  if (fnType.getNumResults() == 1) {
    auto chain_type = compiler::ChainType::get(op.getContext());

    if (fnType.getResult(0) != chain_type)
      return op.emitOpError("async callee must return a chain");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

// Parses `[%v0, %v1, ... : T0, T1, ...]`. The colon type list is parsed, and
// required, only when at least one operand is present. An operand-less return
// has nothing after the op name.
ParseResult ReturnOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc operands_loc = parser.getCurrentLocation();

  SmallVector<OpAsmParser::UnresolvedOperand, 2> values;
  if (parser.parseOperandList(values)) return failure();

  SmallVector<Type, 2> value_types;
  if (!values.empty() && parser.parseColonTypeList(value_types))
    return failure();

  return parser.resolveOperands(values, value_types, operands_loc,
                                result.operands);
}

// Prints ` %v0, %v1 : T0, T1`, or nothing at all for an operand-less return.
void ReturnOp::print(OpAsmPrinter &p) {
  if (getNumOperands() == 0) return;

  p << ' ';
  p.printOperands(getOperands());
  p << " : ";
  interleaveComma(getOperandTypes(), p);
}

// When the immediate parent is a func::FuncOp, verifies that the returned
// values match the function's result types in count and type. A return nested
// in any other op, or one with no parent op at all, is accepted here. Ops such
// as tfrt.if check their own terminators.
LogicalResult ReturnOp::verify() {
  ReturnOp op = *this;
  // The parent is often a 'func' but not always. A detached return op has no
  // parent, and plain dyn_cast asserts on null, so tolerate that case.
  auto function = dyn_cast_or_null<FuncOp>(op->getParentOp());

  // We allow tfrt.return in arbitrary control flow structures.
  if (!function) return success();

  // The operand number and types must match the function signature.
  auto results = function.getFunctionType().getResults();
  if (op.getNumOperands() != results.size())
    return op.emitOpError("has ")
           << op.getNumOperands()
           << " operands, but enclosing function returns " << results.size();

  for (unsigned i = 0, e = results.size(); i != e; ++i)
    if (op->getOperand(i).getType() != results[i])
      return op.emitError()
             << "type of return operand " << i << " ("
             << op.getOperand(i).getType()
             << ") doesn't match function result type (" << results[i] << ")";

  return success();
}

}  // namespace compiler
}  // namespace tfrt

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tfrt/basic_kernels/opdefs/basic_kernels_opdefs.cpp.inc"
