/* Copyright (c) 2023 Intel Corporation

Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "itex/core/compiler/xla/service/gpu/gpu_compiler.h"

#include <level_zero/ze_api.h>
#include <stdlib.h>

#include <atomic>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "itex/core/compiler/mlir/hlo/transforms/itex_gpu_passes.h"
#include "itex/core/compiler/mlir/utils/name_utils.h"
#include "itex/core/compiler/mlir/xla/hlo_utils.h"
#include "itex/core/compiler/mlir/xla/type_to_shape.h"
#include "itex/core/compiler/xla/protobuf_util.h"
#include "itex/core/compiler/xla/service/algebraic_simplifier.h"
#include "itex/core/compiler/xla/service/all_gather_broadcast_reorder.h"
#include "itex/core/compiler/xla/service/all_gather_combiner.h"
#include "itex/core/compiler/xla/service/all_reduce_combiner.h"
#include "itex/core/compiler/xla/service/all_reduce_contiguous.h"
#include "itex/core/compiler/xla/service/all_reduce_folder.h"
#include "itex/core/compiler/xla/service/all_reduce_reassociate.h"
#include "itex/core/compiler/xla/service/all_to_all_decomposer.h"
#include "itex/core/compiler/xla/service/async_collective_creator.h"
#include "itex/core/compiler/xla/service/batchnorm_expander.h"
#include "itex/core/compiler/xla/service/bfloat16_normalization.h"
#include "itex/core/compiler/xla/service/bitcast_dtypes_expander.h"
#include "itex/core/compiler/xla/service/buffer_assignment.h"
#include "itex/core/compiler/xla/service/call_inliner.h"
#include "itex/core/compiler/xla/service/collectives_schedule_linearizer.h"
#include "itex/core/compiler/xla/service/comparison_expander.h"
#include "itex/core/compiler/xla/service/conditional_canonicalizer.h"
#include "itex/core/compiler/xla/service/conditional_simplifier.h"
#include "itex/core/compiler/xla/service/convert_mover.h"
#include "itex/core/compiler/xla/service/convolution_4d_expander.h"
#include "itex/core/compiler/xla/service/copy_insertion.h"
#include "itex/core/compiler/xla/service/dot_decomposer.h"
#include "itex/core/compiler/xla/service/dot_merger.h"
#include "itex/core/compiler/xla/service/dump.h"
#include "itex/core/compiler/xla/service/dynamic_dimension_simplifier.h"
#include "itex/core/compiler/xla/service/dynamic_index_splitter.h"
#include "itex/core/compiler/xla/service/dynamic_padder.h"
#include "itex/core/compiler/xla/service/eigh_expander.h"
#include "itex/core/compiler/xla/service/flatten_call_graph.h"
#include "itex/core/compiler/xla/service/gather_expander.h"
#include "itex/core/compiler/xla/service/gather_simplifier.h"
#include "itex/core/compiler/xla/service/gpu/alias_passthrough_params.h"
#include "itex/core/compiler/xla/service/gpu/all_reduce_blueconnect.h"
#include "itex/core/compiler/xla/service/gpu/fusion_merger.h"
#include "itex/core/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"
#include "itex/core/compiler/xla/service/gpu/gemm_rewriter.h"
#include "itex/core/compiler/xla/service/gpu/gpu_constants.h"
#include "itex/core/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "itex/core/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include "itex/core/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "itex/core/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h"
#include "itex/core/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
#include "itex/core/compiler/xla/service/gpu/gpu_scatter_expander.h"
#include "itex/core/compiler/xla/service/gpu/horizontal_input_fusion.h"
#include "itex/core/compiler/xla/service/gpu/horizontal_loop_fusion.h"
#include "itex/core/compiler/xla/service/gpu/instruction_fusion.h"
#include "itex/core/compiler/xla/service/gpu/ir_emission_utils.h"
#include "itex/core/compiler/xla/service/gpu/ir_emitter_context.h"
#include "itex/core/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "itex/core/compiler/xla/service/gpu/kernel_thunk.h"
#include "itex/core/compiler/xla/service/gpu/launch_dimensions.h"
#include "itex/core/compiler/xla/service/gpu/multi_output_fusion.h"
#include "itex/core/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
#include "itex/core/compiler/xla/service/gpu/reduction_dimension_grouper.h"
#include "itex/core/compiler/xla/service/gpu/reduction_layout_normalizer.h"
#include "itex/core/compiler/xla/service/gpu/reduction_splitter.h"
#include "itex/core/compiler/xla/service/gpu/runtime_intrinsics.h"
#include "itex/core/compiler/xla/service/gpu/scatter_slice_simplifier.h"
#include "itex/core/compiler/xla/service/gpu/stream_assignment.h"
#include "itex/core/compiler/xla/service/gpu/stream_executor_util.h"
#include "itex/core/compiler/xla/service/gpu/target_constants.h"
#include "itex/core/compiler/xla/service/gpu/thunk_schedule.h"
#include "itex/core/compiler/xla/service/gpu/tree_reduction_rewriter.h"
#include "itex/core/compiler/xla/service/gpu/variadic_op_splitter.h"
#include "itex/core/compiler/xla/service/hlo_computation.h"
#include "itex/core/compiler/xla/service/hlo_constant_folding.h"
#include "itex/core/compiler/xla/service/hlo_cse.h"
#include "itex/core/compiler/xla/service/hlo_dataflow_analysis.h"
#include "itex/core/compiler/xla/service/hlo_dce.h"
#include "itex/core/compiler/xla/service/hlo_instruction.h"
#include "itex/core/compiler/xla/service/hlo_instructions.h"
#include "itex/core/compiler/xla/service/hlo_parser.h"
#include "itex/core/compiler/xla/service/hlo_pass_fix.h"
#include "itex/core/compiler/xla/service/hlo_proto_util.h"
#include "itex/core/compiler/xla/service/hlo_verifier.h"
#include "itex/core/compiler/xla/service/llvm_ir/llvm_util.h"
#include "itex/core/compiler/xla/service/logistic_expander.h"
#include "itex/core/compiler/xla/service/loop_schedule_linearizer.h"
#include "itex/core/compiler/xla/service/operand_upcaster.h"
#include "itex/core/compiler/xla/service/optimization_barrier_expander.h"
#include "itex/core/compiler/xla/service/qr_expander.h"
#include "itex/core/compiler/xla/service/real_imag_expander.h"
#include "itex/core/compiler/xla/service/reduce_scatter_combiner.h"
#include "itex/core/compiler/xla/service/reshape_mover.h"
#include "itex/core/compiler/xla/service/result_caster.h"
#include "itex/core/compiler/xla/service/rng_bit_generator_expander.h"
#include "itex/core/compiler/xla/service/rng_expander.h"
#include "itex/core/compiler/xla/service/scatter_simplifier.h"
#include "itex/core/compiler/xla/service/sharding_propagation.h"
#include "itex/core/compiler/xla/service/sharding_remover.h"
#include "itex/core/compiler/xla/service/simplify_fp_conversions.h"
#include "itex/core/compiler/xla/service/sort_simplifier.h"
#include "itex/core/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.h"
#include "itex/core/compiler/xla/service/stable_sort_expander.h"
#include "itex/core/compiler/xla/service/transpose_folding.h"
#include "itex/core/compiler/xla/service/tuple_simplifier.h"
#include "itex/core/compiler/xla/service/while_loop_constant_sinking.h"
#include "itex/core/compiler/xla/service/while_loop_simplifier.h"
#include "itex/core/compiler/xla/service/while_loop_trip_count_annotator.h"
#include "itex/core/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "itex/core/utils/casts.h"
#include "itex/core/utils/env.h"
#include "itex/core/utils/env_var.h"
#include "itex/core/utils/logging.h"
#include "itex/core/utils/path.h"
#include "itex/core/utils/regexp.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/SplitModule.h"
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/passes.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"         // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"        // from @llvm-project
#include "mlir/Dialect/GPU/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"                          // from @llvm-project
#include "mlir/InitAllDialects.h"                        // from @llvm-project
#include "mlir/Pass/PassManager.h"                       // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "mlir/Transforms/LocationSnapshot.h"            // from @llvm-project
#include "mlir/Transforms/Passes.h"                      // from @llvm-project

namespace itex_xla {
namespace gpu {

// Name for the HLO dump taken before any optimization pass runs.
// NOTE(review): not referenced in the visible part of this file; presumably
// passed to the dump utilities (dump.h) further down — confirm.
constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";

class GpuBfloat16Support : public BFloat16Support {
 public:
  explicit GpuBfloat16Support(bool supports_matrix_multiplication)
      : supports_matrix_multiplication_(supports_matrix_multiplication) {}

  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64_t operand_index) const override {
    return BFloat16Support::SupportsBF16Operand(hlo, operand_index) ||
           IsSupported(hlo);
  }

  // Returns whether the backend supports BF16 output for the HLO instruction.
  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    return BFloat16Support::SupportsBF16Output(hlo) || IsSupported(hlo);
  }

 private:
  bool IsSupported(const HloInstruction& hlo) const {
    switch (hlo.opcode()) {
      case HloOpcode::kAllGather:
      case HloOpcode::kAllReduce:
      case HloOpcode::kAllReduceStart:
      case HloOpcode::kAllReduceDone:
      case HloOpcode::kAllToAll:
      case HloOpcode::kCollectivePermute:
      case HloOpcode::kReduceScatter:
      // Data movement only ops.
      case HloOpcode::kBroadcast:
      case HloOpcode::kConcatenate:
      case HloOpcode::kCopy:
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kDynamicUpdateSlice:
      case HloOpcode::kGather:
      case HloOpcode::kPad:
      case HloOpcode::kReshape:
      case HloOpcode::kReverse:
      case HloOpcode::kScatter:
      case HloOpcode::kSelect:
      case HloOpcode::kSelectAndScatter:
      case HloOpcode::kSlice:
      case HloOpcode::kTranspose:
      // Other special ops.
      case HloOpcode::kBitcast:
        return true;
      case HloOpcode::kConvolution:
        return IsConvBF16Supported();
      default:
        return supports_matrix_multiplication_ &&
               gpu::IsMatrixMultiplication(hlo);
    }
  }

  bool IsConvBF16Supported() const { return true; }

  bool supports_matrix_multiplication_;
};

namespace {
// Returns the byte size of a value of `shape`, using `pointer_size` for
// pointer-sized elements. Static shapes and tuples are sized by
// ShapeUtil::ByteSizeOf alone. Any other (dynamic, non-tuple) shape also
// reserves one S32 per dimension for its runtime dimension sizes.
int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
  const int64_t data_bytes = ShapeUtil::ByteSizeOf(shape, pointer_size);
  const bool needs_dynamic_metadata = !shape.is_static() && !shape.IsTuple();
  if (!needs_dynamic_metadata) {
    return data_bytes;
  }
  // Each dynamic dimension size is represented as a S32.
  const int64_t metadata_bytes =
      static_cast<int64_t>(sizeof(int32_t)) * shape.dimensions_size();
  return data_bytes + metadata_bytes;
}

// Returns true if `conv` matches a convolution form the GPU backend can
// lower, per conv_matchers: forward, backward-filter, or backward-input.
bool ConvIsLowerable(HloInstruction* conv) {
  if (conv_matchers::CanImplementAsGpuForwardConv(conv)) {
    return true;
  }
  if (std::get<0>(conv_matchers::MatchBackwardFilter(conv))) {
    return true;
  }
  return std::get<0>(conv_matchers::MatchBackwardInput(conv));
}

// Declared here so OptimizeHloModule can call it. It has internal linkage, so
// its definition must appear later in this translation unit.
bool EnableGpuMlirLowering();

}  // end anonymous namespace

// Local shorthands for nested types of GpuExecutable and LaunchDimensions.
using OwnedThunkSchedule = GpuExecutable::OwnedThunkSchedule;
using OwnedBefBuffer = GpuExecutable::OwnedBefBuffer;
using Dim3D = LaunchDimensions::Dim3D;

// Evaluates a Level Zero API call and logs FATAL if the result is not
// ZE_RESULT_SUCCESS.
//  * Wrapped in do { } while (0) so the macro is a single statement. That
//    makes `if (c) L0_SAFE_CALL(x); else ...` compile as expected.
//  * The result variable has a macro-specific name. A local named `status`
//    would be in scope inside its own initializer, so an argument referring
//    to a caller variable called `status` would read the uninitialized local.
#define L0_SAFE_CALL(call)                                   \
  do {                                                       \
    ze_result_t l0_safe_call_status_ = (call);               \
    if (l0_safe_call_status_ != ZE_RESULT_SUCCESS) {         \
      ITEX_LOG(FATAL) << "L0 error " << l0_safe_call_status_; \
    }                                                        \
  } while (0)

namespace {
std::vector<ze_device_handle_t> GetDeviceList() {
  uint32_t driver_count = 0;
  L0_SAFE_CALL(zeDriverGet(&driver_count, nullptr));
  std::vector<ze_driver_handle_t> driver_list(driver_count);
  L0_SAFE_CALL(zeDriverGet(&driver_count, driver_list.data()));

  std::vector<ze_device_handle_t> device_list_total;
  for (auto driver : driver_list) {
    uint32_t device_count = 0;
    L0_SAFE_CALL(zeDeviceGet(driver, &device_count, nullptr));
    std::vector<ze_device_handle_t> device_list(device_count);
    L0_SAFE_CALL(zeDeviceGet(driver, &device_count, device_list.data()));

    for (auto device : device_list) {
      ze_device_properties_t props{
          ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES,
      };
      L0_SAFE_CALL(zeDeviceGetProperties(device, &props));
      if (props.type == ZE_DEVICE_TYPE_GPU) {
        device_list_total.push_back(device);
      }
    }
  }
  return device_list_total;
}
}  // namespace
// Builds a GpuDeviceInfo from the first Level Zero GPU device returned by
// GetDeviceList(). Level Zero queries go through L0_SAFE_CALL, which logs
// FATAL on failure. The value stored as `threads_per_warp` is not queried from
// the device. It comes from the XLA_SUBGROUP_SIZE environment variable
// (default 32).
GpuDeviceInfo GetGpuDeviceInfo() {
  const std::vector<ze_device_handle_t> device_list = GetDeviceList();
  if (device_list.empty()) {
    // Indexing an empty list is undefined behavior, so fail loudly with an
    // actionable message instead.
    ITEX_LOG(FATAL) << "No Level Zero GPU device found; cannot query GPU "
                       "device info for XLA compilation.";
  }
  auto device = device_list.front();
  ze_device_properties_t prop{
      ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES,
  };
  L0_SAFE_CALL(zeDeviceGetProperties(device, &prop));
  ze_device_compute_properties_t c_prop{
      ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES,
  };
  L0_SAFE_CALL(zeDeviceGetComputeProperties(device, &c_prop));

  // Sub-group ("warp") width; overridable through the environment.
  int64_t xla_subgroup_size = 32;
  TF_ABORT_IF_ERROR(
      itex::ReadInt64FromEnvVar("XLA_SUBGROUP_SIZE", 32, &xla_subgroup_size));

  GpuDeviceInfo gpu_device_info;
  gpu_device_info.threads_per_block_limit = c_prop.maxTotalGroupSize;
  gpu_device_info.threads_per_core_limit =
      c_prop.maxTotalGroupSize / prop.numThreadsPerEU;
  gpu_device_info.threads_per_warp = xla_subgroup_size;
  gpu_device_info.block_dim_limit_x = c_prop.maxGroupCountX;
  gpu_device_info.block_dim_limit_y = c_prop.maxGroupCountY;
  gpu_device_info.block_dim_limit_z = c_prop.maxGroupCountZ;
  gpu_device_info.core_count = prop.numSlices * prop.numSubslicesPerSlice;
  gpu_device_info.shared_memory_per_block = c_prop.maxSharedLocalMemory;

  ITEX_VLOG(3) << "threads_per_block_limit: "
               << gpu_device_info.threads_per_block_limit;
  ITEX_VLOG(3) << "threads_per_core_limit: "
               << gpu_device_info.threads_per_core_limit;
  ITEX_VLOG(3) << "threads_per_warp: " << gpu_device_info.threads_per_warp;
  ITEX_VLOG(3) << "block_dim_limit_x: " << gpu_device_info.block_dim_limit_x;
  ITEX_VLOG(3) << "block_dim_limit_y: " << gpu_device_info.block_dim_limit_y;
  ITEX_VLOG(3) << "block_dim_limit_z: " << gpu_device_info.block_dim_limit_z;
  ITEX_VLOG(3) << "core_count: " << gpu_device_info.core_count;
  ITEX_VLOG(3) << "shared_memory_per_block: "
               << gpu_device_info.shared_memory_per_block;
  return gpu_device_info;
}
// L0_SAFE_CALL is only needed for the Level Zero queries above.
#undef L0_SAFE_CALL

// Records the platform and the LLVM target description. `pointer_size_` is
// the pointer size of the default address space (0) of `data_layout`. It is
// computed from the constructor argument rather than the `data_layout_`
// member, so the initializer does not rely on member declaration order.
GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
                         const char* target_triple, const char* data_layout)
    : platform_id_(platform_id),
      target_triple_(target_triple),
      data_layout_(data_layout),
      pointer_size_(llvm::DataLayout(data_layout).getPointerSize(
          /*AS=*/0)) {}

// Runs optimization passes on the given HLO module.
//
// Pipelines run in this order; the first failing pipeline's status is
// returned immediately:
//   1. SPMD partitioning, only if config().use_spmd_partitioning() is set.
//   2. "optimization": expanders, BF16 normalization, dynamic padding and
//      fixed-point (layout-insensitive) simplification loops.
//   3. "collective-optimizations".
//   4. OptimizeHloConvolutionCanonicalization().
//   5. Layout assignment + OptimizeHloPostLayoutAssignment(); skipped when
//      EnableGpuMlirLowering() returns true.
//   6. "fusion" and "horizontal fusion", both run to a fixed point.
//   7. "post-fusion optimization": collective combiners/scheduling and a
//      final layout-sensitive algebraic simplification.
// Returns Status::OK() if every pipeline succeeds.
Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module) {
  const DebugOptions& debug_options = hlo_module->config().debug_options();

  // Layout-insensitive AlgebraicSimplifier options, shared by several
  // pipelines below. ConvIsLowerable is passed as the second constructor
  // argument (presumably the conv-is-lowerable callback; confirm against
  // AlgebraicSimplifierOptions).
  AlgebraicSimplifierOptions layout_insensitive_algsimp_opts({},
                                                             ConvIsLowerable);
  // "slow" minmax means we propagate nan.
  layout_insensitive_algsimp_opts.set_minmax_propagate_nan(
      !debug_options.xla_gpu_enable_fast_min_max());
  layout_insensitive_algsimp_opts.set_enable_conv_operand_swap(false);

  // SPMD: with more than one partition, clean up and simplify to a fixed
  // point, propagate shardings, then partition. With a single partition the
  // sharding annotations are simply removed.
  if (hlo_module->config().use_spmd_partitioning()) {
    HloPassPipeline spmd_pipeline("spmd-partitioner");
    const int64_t num_partitions = hlo_module->config().num_partitions();
    if (num_partitions > 1) {
      // Run some IR cleanup passes before running the SPMD partitioning
      // passes.
      spmd_pipeline.AddInvariantChecker<HloVerifier>(
          /*layout_sensitive=*/false,
          /*allow_mixed_precision=*/false);
      spmd_pipeline.AddPass<CallInliner>();
      spmd_pipeline.AddPass<ZeroSizedHloElimination>();
      spmd_pipeline.AddPass<ConditionalCanonicalizer>();

      HloPassPipeline& spmd_simplify =
          spmd_pipeline.AddPass<HloPassFix<HloPassPipeline>>("spmd-simplify");

      spmd_simplify.AddPass<AlgebraicSimplifier>(
          layout_insensitive_algsimp_opts);

      spmd_simplify.AddPass<SortSimplifier>();
      spmd_simplify.AddPass<TupleSimplifier>();
      spmd_simplify.AddPass<ScatterExpander>(
          ScatterExpander::kEliminateSimpleScatters);
      spmd_simplify.AddPass<GatherExpander>(
          GatherExpander::kEliminateSimpleGathers);
      spmd_simplify.AddPass<WhileLoopConstantSinking>();
      spmd_simplify.AddPass<WhileLoopSimplifier>();

      spmd_simplify.AddPass<ReshapeMover>();
      spmd_simplify.AddPass<HloConstantFolding>();
      spmd_simplify.AddPass<ConditionalSimplifier>();
      spmd_simplify.AddPass<HloDCE>();

      spmd_pipeline.AddPass<ShardingPropagation>(/*is_spmd=*/true);
      spmd_pipeline.AddPass<spmd::StatefulRngSpmdPartitioner>(
          num_partitions, hlo_module->config().replica_count());
    } else {
      // Remove redundant sharding ops when partition_count == 1.
      spmd_pipeline.AddPass<ShardingRemover>();
      spmd_pipeline.AddPass<HloDCE>();
    }
    TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status());
  }

  {
    HloPassPipeline pipeline("optimization");
    pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                              /*allow_mixed_precision=*/false);
    pipeline.AddPass<AllToAllDecomposer>();

    // Shared by OperandUpcaster and ResultCaster: matrix multiplications
    // (per gpu::IsMatrixMultiplication) are excluded from both passes.
    OpExpanderPass::PatternExtraFilter upcaster_filter =
        [&](const HloInstruction* instr) {
          return !gpu::IsMatrixMultiplication(*instr);
        };

    pipeline.AddPass<OperandUpcaster>(upcaster_filter);
    pipeline.AddPass<ResultCaster>(upcaster_filter);

    // Expand random number generation.
    pipeline.AddPass<RngExpander>();
    pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

    // Comparison total order expander
    pipeline.AddPass<ComparisonExpander>();

    // Remove zero-sized HLO from the input so that other passes don't have to
    // handle it.
    pipeline.AddPass<ZeroSizedHloElimination>();

    if (debug_options.xla_gpu_deterministic_ops()) {
      // Scatter is nondeterministic, so eliminate all Scatters.
      pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
    } else {
      // Only Scatters unsupported on XLA:GPU are eliminated.
      pipeline.AddPass<GpuScatterExpander>();
    }
    // TODO(phawkins): replace QR and Eigh decompositions with calls to
    // cuSOLVER.
    pipeline.AddPass<QrExpander>();
    pipeline.AddPass<EighExpander>();

    pipeline.AddPass<DynamicIndexSplitter>();

    // TODO(b/64094172): make Call work on GPU instead of inlining.
    pipeline.AddPass<CallInliner>();

    pipeline.AddPass<DotDecomposer>();

    pipeline.AddPass<Convolution4DExpander>();

    // Expand the sort op to support stable sorting if required.
    pipeline.AddPass<StableSortExpander>();

    // The normalization pass is built with a pointer to `bf16` but only runs
    // in pipeline.Run() at the end of this scope, so `bf16` must stay alive
    // until then.
    GpuBfloat16Support bf16(/*supports_matrix_multiplication=*/true);
    pipeline.AddPass<BFloat16Normalization>(&bf16);

    // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization.
    if (debug_options.xla_gpu_simplify_all_fp_conversions())
      pipeline.AddPass<SimplifyFPConversions>();

    pipeline.AddPass<BatchNormExpander>(
        /*rewrite_training_op=*/true,
        /*rewrite_inference_op=*/true,
        /*rewrite_grad_op=*/true);

    pipeline.AddPass<LogisticExpander>(
        /*expansion_type=*/LogisticExpansionType::kExp);
    pipeline.AddPass<ConditionalCanonicalizer>();
    pipeline.AddPass<DynamicDimensionSimplifier>();

    // Map the xla_gpu_shape_checks debug option onto DynamicPadder's shape
    // check mode. In RUNTIME mode each check is emitted as a side-effecting
    // custom call with target kXlaGpuAssertCustomCallTag.
    DynamicPadderOptions dynamic_padder_options;
    switch (hlo_module->config().debug_options().xla_gpu_shape_checks()) {
      case DebugOptions::IGNORE:
        dynamic_padder_options.shape_check_mode =
            DynamicDimensionInference::ShapeCheckMode::kIgnore;
        break;
      case DebugOptions::RUNTIME: {
        dynamic_padder_options.shape_check_mode =
            DynamicDimensionInference::ShapeCheckMode::kRuntime;
        dynamic_padder_options.assertion_generator = [&](HloInstruction* inst) {
          auto created = Cast<HloCustomCallInstruction>(
              inst->parent()->AddInstruction(HloInstruction::CreateCustomCall(
                  ShapeUtil::MakeTokenShape(), {inst},
                  kXlaGpuAssertCustomCallTag,
                  "Buffers have different size at runtime",
                  API_VERSION_STATUS_RETURNING)));
          created->set_custom_call_has_side_effect(true);
        };
        break;
      }
      case DebugOptions::COMPILE_TIME:
        dynamic_padder_options.shape_check_mode =
            DynamicDimensionInference::ShapeCheckMode::kCompileTime;
        break;
      default:
        ITEX_LOG(FATAL) << "Unreachable";
    }
    pipeline.AddPass<DynamicPadder>(dynamic_padder_options);

    // Build simplification pipeline.  The passes in here are run to a fixed
    // point.
    // Immediately-invoked lambda: the init-capture `pipeline` shadows the
    // outer pipeline and names the nested fixed-point sub-pipeline, so the
    // AddPass calls inside populate the sub-pipeline.
    [&, &pipeline =
            pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
      pipeline.AddInvariantCheckerDebug<HloVerifier>(
          /*layout_sensitive=*/false,
          /*allow_mixed_precision=*/false);

      // BatchNormExpander can create zero-sized ops, so zero-sized HLO
      // elimination has to come after that pass.
      pipeline.AddPass<ZeroSizedHloElimination>();

      pipeline.AddPass<GatherSimplifier>();
      pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
      pipeline.AddPass<ScatterSimplifier>();
      pipeline.AddPass<ScatterExpander>(
          ScatterExpander::kEliminateSimpleScatters);
      pipeline.AddPass<ScatterSliceSimplifier>();

      pipeline.AddPass<AlgebraicSimplifier>(layout_insensitive_algsimp_opts);
      pipeline.AddPass<BitcastDtypesExpander>();
      // AlgebraicSimplifier may add contracting dimensions to a dot.
      pipeline.AddPass<DotDecomposer>();
      // Only merge "smallish" dots.  This threshold was not set carefully, but
      // so far we know that 1mb is too small.
      pipeline.AddPass<DotMerger>(/*max_size_to_merge=*/int64_t{16} << 20);
      pipeline.AddPass<SortSimplifier>();
      pipeline.AddPass<TupleSimplifier>();
      pipeline.AddPass<WhileLoopConstantSinking>();
      pipeline.AddPass<WhileLoopSimplifier>();

      // TODO(b/134075051): Re-enable after b/134075051 is fixed.
      // pipeline.AddPass<SliceSinker>();

      pipeline.AddPass<ReshapeMover>();
      pipeline.AddPass<HloConstantFolding>();
      pipeline.AddPass<ConditionalSimplifier>();
      pipeline.AddPass<RealImagExpander>();

      pipeline.AddPass<TransposeFolding>();
      pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
      pipeline.AddPass<HloDCE>();
    }();

    // ConvertMover and ReshapeMover fight with each other: ConvertMover wants
    // to move some converts down the graph, but ReshapeMover wants to move them
    // up the graph.  As a compromise, let ReshapeMover run to a fixed point,
    // and then run ConvertMover + algsimp to a fixed point.
    [&, &pipeline =
            pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification-2")] {
      pipeline.AddPass<ConvertMover>();
      pipeline.AddPass<AlgebraicSimplifier>(layout_insensitive_algsimp_opts);
    }();

    // Run WhileLoopTripCountAnnotator at the end of the simplification
    // pipeline, before layout assignment and fusion.  This pass does some
    // pattern-matching on while bodies/conditions, and this is where the HLO is
    // "nicest".
    //
    // It's important that we don't make semantic changes (e.g. unrolling) to
    // any `while` loops after this point, because otherwise the trip-count
    // annotations added by this pass may not be correct after the
    // modifications.
    pipeline.AddPass<WhileLoopTripCountAnnotator>();
    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
  }

  // Optimize collectives generated by SPMD partitioning. Enable these passes
  // otherwise as well so that all collectives can get these optimizations.
  {
    HloPassPipeline collectives_pipeline("collective-optimizations");
    collectives_pipeline.AddPass<AllReduceFolder>();
    collectives_pipeline.AddPass<ReduceScatterCreator>();
    collectives_pipeline.AddPass<AllReduceReassociate>();

    // Run algebraic simplifier to reshape(broadcast) into a broadcast when
    // the reshape is just adding a unit dimension. This will help with the
    // AllGatherBroadcastReorder pass.
    collectives_pipeline.AddPass<AlgebraicSimplifier>(
        layout_insensitive_algsimp_opts);

    collectives_pipeline.AddPass<AllGatherBroadcastReorder>();
    TF_RETURN_IF_ERROR(collectives_pipeline.Run(hlo_module).status());
  }
  // Run target-specific HLO optimization passes for convolution
  // canonicalization.
  TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(hlo_module));

  // Layout assignment and the post-layout passes are skipped entirely when
  // MLIR lowering is enabled.
  if (!EnableGpuMlirLowering()) {
    {
      // Run layout assignment in a separate pipeline from
      // "post-layout-assignment" because we want everything after layout
      // assignment to have a layout-sensitive invariant-checker, but
      // HloPassPipeline also runs its invariant checker before any passes are
      // run, meaning, the pipeline that contains layout assignment cannot
      // contain a layout-sensitive verifier!
      HloPassPipeline pipeline("layout assignment");
      // Layout assignment uses alias analysis, which requires the call graph to
      // be flattened.
      pipeline.AddPass<FlattenCallGraph>();
      // Handed to GpuLayoutAssignment by pointer; it must outlive
      // pipeline.Run() below, which it does since both are in this scope.
      ChannelLayoutConstraints layout_constraints;
      pipeline.AddPass<GpuLayoutAssignment>(
          hlo_module->mutable_entry_computation_layout(), &layout_constraints);
      TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
    }

    // Run target-specific HLO optimization passes after layout assignment.
    TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module));
  }

  // Vertical fusion, repeated to a fixed point (HloPassFix).
  {
    HloPassFix<HloPassPipeline> fusion("fusion");
    // We try to split variadic ops with many parameters into several such ops
    // to avoid exceeding the parameter space.
    fusion.AddPass<VariadicOpSplitter>();
    fusion.AddInvariantCheckerDebug<HloVerifier>(
        /*layout_sensitive=*/true,
        /*allow_mixed_precision=*/false,
        LayoutAssignment::InstructionCanChangeLayout);
    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
    fusion.AddPass<FusionMerger>();
    fusion.AddPass<GpuMultiOutputFusion>();
    fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                           /*only_fusion_computations=*/true);
    fusion.AddPass<HloDCE>();
    TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
  }

  // Horizontal (loop and input) fusion, also repeated to a fixed point.
  {
    HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
    horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
    horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
    // FusionBitcastLift must be after InstructionFusion, as it undoes
    // part of it.
    // horizontal_fusion.AddPass<FusionBitcastLift>();
    horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                      /*only_fusion_computations=*/true);
    horizontal_fusion.AddPass<HloDCE>();
    TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
  }

  {
    HloPassPipeline pipeline("post-fusion optimization");
    // Collective combiner thresholds: all-gather 1 GiB, all-reduce from the
    // xla_gpu_all_reduce_combine_threshold_bytes debug option, reduce-scatter
    // 30 MiB; each combines at most 256 ops.
    pipeline.AddPass<AllGatherCombiner>(
        /*combine_threshold_in_bytes=*/1024 * 1024 * 1024,
        /*combine_threshold_count=*/256);
    pipeline.AddPass<AllReduceCombiner>(
        debug_options.xla_gpu_all_reduce_combine_threshold_bytes(),
        /*combine_threshold_count=*/256);
    pipeline.AddPass<ReduceScatterCombiner>(
        /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
        /*combine_threshold_count=*/256);

    if (!EnableGpuMlirLowering() &&
        debug_options.xla_gpu_all_reduce_contiguous()) {
      pipeline.AddPass<AllReduceContiguous>();
    }

    // BlueConnect decomposition only when a positive devices-per-host count
    // is configured.
    int32_t blueconnect_num_devices_per_host =
        debug_options.xla_gpu_all_reduce_blueconnect_num_devices_per_host();
    if (blueconnect_num_devices_per_host > 0) {
      pipeline.AddPass<AllReduceBlueConnect>(blueconnect_num_devices_per_host);
    }

    // When async all-reduce is enabled, every all-reduce is converted.
    if (debug_options.xla_gpu_enable_async_all_reduce()) {
      AsyncCollectiveCreator::CollectiveCreatorConfig config;
      config.convert_all_reduce = [](const HloInstruction*) { return true; };
      pipeline.AddPass<AsyncCollectiveCreator>(std::move(config));
    }

    pipeline.AddPass<CollectivesScheduleLinearizer>();

    // Copy of the shared simplifier options, switched to layout-sensitive
    // mode for this final simplification.
    AlgebraicSimplifierOptions options = layout_insensitive_algsimp_opts;
    options.set_is_layout_sensitive(true);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<OptimizationBarrierExpander>();
    pipeline.AddPass<TupleSimplifier>();

    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
  }

  return Status::OK();
}

// Rewrites `hlo_module` into the form IrEmitter accepts. Unlike the
// optimization pipelines, these passes are required for correctness.
// AliasPassthroughParams is the only one gated, on the module config.
Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
  HloPassPipeline prepare("GPU-ir-emit-prepare");
  prepare.AddInvariantCheckerDebug<HloVerifier>(
      /*layout_sensitive=*/true,
      /*allow_mixed_precision=*/false,
      LayoutAssignment::InstructionCanChangeLayout);

  // Some results must live in a temporary buffer. For example, the buffer of
  // an external parameter is assumed immutable at this point and must not be
  // reused for output (b/27180329), so the output becomes a copy of the
  // parameter.
  //
  // Copy insertion runs immediately before IR emission. Running it earlier
  // risks unnecessary copies (a later pass adds an instruction that
  // materializes the value) or a missing copy (a later pass removes one). DCE
  // runs right before copy insertion (and sometimes after) so dead code does
  // not interfere with the rewrites.
  prepare.AddPass<HloDCE>();
  const bool alias_params = hlo_module->config().alias_passthrough_params();
  if (alias_params) {
    prepare.AddPass<AliasPassthroughParams>();
  }
  prepare.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
  prepare.AddPass<CopyInsertion>(GetCanShareBuffer());
  prepare.AddPass<GpuSanitizeConstantNames>();
  return prepare.Run(hlo_module).status();
}

// Runs the HLO passes that have to follow layout assignment:
// - reduction rewrites;
// - layout-sensitive algebraic simplification;
// - transpose folding into matrix multiplications;
// - GEMM rewriting;
// - bf16 normalization;
// - final tuple/CSE cleanup.
Status GpuCompiler::OptimizeHloPostLayoutAssignment(HloModule* hlo_module) {
  const DebugOptions& debug_options = hlo_module->config().debug_options();
  HloPassPipeline pipeline("post-layout_assignment");
  pipeline.AddInvariantCheckerDebug<HloVerifier>(
      /*layout_sensitive=*/true,
      /*allow_mixed_precision=*/false,
      LayoutAssignment::InstructionCanChangeLayout);

  // Reduction rewrites: degenerate-dimension removal, layout normalization,
  // dimension grouping, then splitting / tree rewriting run to a fixed point.
  pipeline.AddPass<ReductionDegenerateDimRemover>();
  pipeline.AddPass<ReductionLayoutNormalizer>();
  pipeline.AddPass<ReductionDimensionGrouper>();
  pipeline.AddPass<HloPassFix<ReductionSplitter>>();
  pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>();

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  options.set_enable_conv_operand_swap(false);
  // "slow" minmax means we propagate nan.
  options.set_minmax_propagate_nan(
      !debug_options.xla_gpu_enable_fast_min_max());
  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);

  // GemmRewriter assumes that all transposes are folded into gemms, but,
  // since commit 7d529df, this is not always true at this point.
  // Therefore, rerun transpose folding.
  pipeline.AddPass<TransposeFolding>(
      [](const HloInstruction& dot,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return IsMatrixMultiplication(dot) ? candidate_operands
                                           : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);

  // Rewrite GEMMs into custom calls.
  pipeline.AddPass<GemmRewriter>();

  // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
  // pipeline.AddPass<GemmBroadcastFoldingRewriter>();

  // Run bf16 normalization again to catch matrix multiplications that
  // GemmRewriter did not turn into library GEMM calls.
  // `bf16` must outlive pipeline.Run() below, which it does as a local.
  GpuBfloat16Support bf16(false);
  pipeline.AddPass<BFloat16Normalization>(&bf16);

  // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization.
  if (debug_options.xla_gpu_simplify_all_fp_conversions()) {
    pipeline.AddPass<SimplifyFPConversions>();
  }
  /*
  // Choose the fastest algorithm for each conv.
  //
  // We pick the algorithm before fusion so we can generate better HLO. After
  // GpuConvRewriter, our convolutions are CustomCalls which return a
  // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of
  // scratch:
  //
  //   customcall = (f32[...], f32[0])
  //   return gte(customcall, 0)
  //
  // The algorithm picker then chooses the best algorithm, and potentially
  // increases the scratch space.  It replaces customcall with new_tuple,
  // giving us the following:
  //
  //   new_customcall = (f32[...], f32[N])
  //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
  //   return gte(new_tuple, 0)
  //
  // The new tuple and gte instructions then be simplified away, because
  // nobody is expected to use the scratch value.
  //
  // However, if we were to run GpuConvAlgorithmPicker after fusion
  // the gte(customcall, 0) would probably already be into a fusion node.  We
  // can't simplify across HloComputation boundaries, so in this case we
  // wouldn't be able to simplify away the new_tuple bits.
  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
  */
  // Clean up new_tuple described above.
  pipeline.AddPass<TupleSimplifier>();

  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());

  return Status::OK();
}

// Runs the HLO optimization pipeline on `module`, followed by the passes that
// IR emission requires, and hands the module back to the caller.
// `stream_exec` and `options` are currently unused.
StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  // DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
  // The post-optimization HLO is dumped in RunBackend, so it is not dumped
  // here.
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
  HloModule* const raw_module = module.get();
  TF_RETURN_IF_ERROR(OptimizeHloModule(raw_module));
  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(raw_module));
  return std::move(module);
}

// Per-output metadata of the compiled entry function, keyed by ShapeIndex
// (presumably into the entry result shape; filled by
// GpuExecutable::SetUpMlirAllocation via GetMlirAllocationInfo).
using OutputInfoMap =
    absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
// Forward declaration so CompileModuleToLlvmIrImpl can call it. The definition
// further below rebuilds the BufferAllocations and output metadata from the
// LMHLO entry function's arguments and attributes.
static Status GetMlirAllocationInfo(mlir::func::FuncOp func,
                                    std::vector<BufferAllocation>* allocations,
                                    OutputInfoMap* output_info,
                                    Shape* output_shape,
                                    EntryFunctionAttributes* entry_func_attrs);

namespace {
// Erases from `llvm_module` each global that corresponds to one of `constants`
// whose initializer is not part of the IR (non-empty `content`) and that has
// no remaining uses in the module.
void RemoveUnusedAndUninitializedGlobals(
    llvm::Module* llvm_module,
    const std::vector<GpuExecutable::ConstantInfo>& constants) {
  for (const GpuExecutable::ConstantInfo& constant : constants) {
    // A constant with empty content is initialized inside the LLVM IR itself,
    // so its global has to stay.
    if (constant.content.empty()) continue;

    llvm::GlobalVariable* global =
        llvm_module->getGlobalVariable(constant.symbol_name);
    ITEX_CHECK(global != nullptr);
    if (!global->use_empty()) continue;  // Still referenced: keep it.
    global->eraseFromParent();
  }
}

// Reports whether the ENABLE_GPU_MLIR_LOWERING environment variable selects
// the MLIR GPU lowering pipeline (default: false). The variable is re-read on
// every call. A failed read goes through TF_ABORT_IF_ERROR.
bool EnableGpuMlirLowering() {
  bool enable_gpu_mlir_lowering{false};
  TF_ABORT_IF_ERROR(itex::ReadBoolFromEnvVar("ENABLE_GPU_MLIR_LOWERING", false,
                                             &enable_gpu_mlir_lowering));
  return enable_gpu_mlir_lowering;
}

// Reports whether DISABLE_GPU_MLIR_FALLBACK is set (default: false). When it
// is, failures of the ENABLE_GPU_MLIR_LOWERING pipeline become errors instead
// of falling back to the LLVM kernel path. The variable is re-read on every
// call.
bool DisableGpuMlirFallBack() {
  bool disable_gpu_mlir_fallback{false};
  TF_ABORT_IF_ERROR(itex::ReadBoolFromEnvVar("DISABLE_GPU_MLIR_FALLBACK", false,
                                             &disable_gpu_mlir_fallback));
  return disable_gpu_mlir_fallback;
}
}  // namespace

// Everything CompileModuleToLlvmIrImpl produces for RunBackend to build a
// GpuExecutable.
struct CompileModuleResults {
  // LLVM module that IR emission writes into.
  std::unique_ptr<llvm::Module> llvm_module;
  std::unique_ptr<BufferAssignment> buffer_assignment;
  std::vector<BufferAllocation> allocations;
  absl::variant<OwnedThunkSchedule, OwnedBefBuffer> thunks_or_bef;
  EntryFunctionAttributes entry_func_attrs;
  std::vector<GpuExecutable::ConstantInfo> constants;
  OutputInfoMap output_info;
  Shape output_shape;
  // Name of the LMHLO module. Replaced by the GPU module's name when the MLIR
  // lowering path finds one.
  std::string module_name;
  // SPIR-V binaries from the MLIR lowering path.
  std::vector<std::vector<uint8_t>> spv_binary_vec;
  // "__spv__"-prefixed kernel names, one per gpu.launch_func op, in IR order.
  std::vector<std::string> launch_func_seqs;
  // (grid, block) dimensions, element-wise parallel to launch_func_seqs.
  std::vector<std::pair<Dim3D, Dim3D>> launch_func_dims;
  // Whether the MLIR/SPIR-V path produced the kernels. Defaults to false so a
  // default-constructed result is never read with an indeterminate value.
  bool mlir_compiler_success = false;
};

// Lowers `hlo_module` to LLVM IR plus a thunk schedule and stores the outcome
// in `results`.
//
// Outline:
//   1. Schedule the module and run buffer assignment.
//   2. Convert the scheduled HLO into an LMHLO MLIR module.
//   3. If ENABLE_GPU_MLIR_LOWERING is set, run the gpu-fusion-rewrite MLIR
//      pipeline. Then take the SPIR-V binary of the first resulting GPU module,
//      together with the kernel name and constant launch dimensions of every
//      gpu.launch_func op. Any failure in this step clears
//      `results->mlir_compiler_success`, or is returned as an InternalError
//      when DISABLE_GPU_MLIR_FALLBACK is set.
//   4. Emit thunks / LLVM IR from a freshly converted LMHLO entry function.
//   5. If the MLIR path succeeded, bind the "spv_dummy_kernel" KernelThunks,
//      in order, to the recorded kernel names and launch dimensions.
//
// The order of `thunk_sequence` corresponds to
// `hlo_schedule->ThunkLaunchOrder()`.
static Status CompileModuleToLlvmIrImpl(
    HloModule* hlo_module, llvm::LLVMContext* llvm_context,
    const std::string& target_triple, const std::string& data_layout,
    const std::string& platform_name, GpuDeviceInfo gpu_device_info,
    const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
    int pointer_size, CompileModuleResults* results) {
  results->llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
  results->llvm_module->setTargetTriple(target_triple);
  results->llvm_module->setDataLayout(data_layout);

  std::unique_ptr<StreamAssignment> stream_assignment =
      AssignStreams(*hlo_module);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<GpuHloSchedule> hlo_schedule,
      GpuHloSchedule::Build(hlo_module, *stream_assignment, pointer_size));

  auto buffer_size_bytes_function =
      [pointer_size](const BufferValue& buffer_value) -> int64_t {
    return GetSizeOfShape(buffer_value.shape(), pointer_size);
  };

  TF_ASSIGN_OR_RETURN(
      results->buffer_assignment,
      BufferAssigner::Run(
          hlo_module, hlo_schedule->ConsumeHloOrdering(),
          buffer_size_bytes_function,
          /*color_alignment=*/
          [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
          /*allocate_buffers_for_constants=*/true,
          /*colorer=*/BufferAssigner::DefaultColorer(),
          /*must_not_live_out=*/{}, can_share_buffer_function));

  ITEX_VLOG(1) << "Buffer Assignment Stats "
               << results->buffer_assignment->GetStats().ToString();
  DumpHloModuleIfEnabled(*hlo_module, *results->buffer_assignment,
                         absl::StrCat("gpu_after_optimizations"));

  mlir::MLIRContext init_mlir_context;
  mlir::MLIRContext* mlir_context = &init_mlir_context;
  mlir_context->loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
                            mlir::arith::ArithDialect, mlir::func::FuncDialect,
                            mlir::lmhlo_gpu::LmhloGpuDialect>();
  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
      mlir::ModuleOp::create(mlir::Builder(mlir_context).getUnknownLoc());

  TF_RETURN_IF_ERROR(
      HloToLhloModule(*results->buffer_assignment, *hlo_module, *mlir_module));

  results->module_name = mlir::GetNameFromLoc(mlir_module->getLoc());

  if (DumpingEnabledForHloModule(*hlo_module)) {
    DumpToFileInDirOrStdout(*hlo_module, "lmhlo", mlir_module.get());
  }

  // Read the environment once so every decision below agrees on it.
  const bool enable_mlir_lowering = EnableGpuMlirLowering();
  results->mlir_compiler_success = enable_mlir_lowering;
  if (enable_mlir_lowering) {
    const bool disable_fallback = DisableGpuMlirFallBack();
    mlir::PassManager pm(mlir_context);

    bool print_after_pass = false;
    TF_ABORT_IF_ERROR(itex::ReadBoolFromEnvVar("ENABLE_MLIR_VERBOSE", false,
                                               &print_after_pass));
    if (print_after_pass) {
      pm.getContext()->disableMultithreading();
      pm.enableIRPrinting(
          /*shouldPrintBeforePass*/
          nullptr, /*shouldPrintAfterPass*/
          [](mlir::Pass*, mlir::Operation*) { return true; },
          /*printModuleScope=*/true,
          /*printAfterOnlyOnChange=*/true, /*printAfterOnlyOnFailure=*/false);
    }

    pm.addPass(mlir::createGpuFusionRewritePass());
    if (failed(pm.run(mlir_module.get()))) {
      ITEX_LOG(WARNING) << "Failed to run gpu-fusion-rewrite pass";
      results->mlir_compiler_success = false;
      if (disable_fallback) {
        return InternalError("Failed to run gpu-fusion-rewrite pass");
      }
    } else {
      auto gpu_modules = mlir_module.get().getOps<mlir::gpu::GPUModuleOp>();
      if (gpu_modules.empty()) {
        results->mlir_compiler_success = false;
        if (disable_fallback) {
          return InternalError("Failed to generate GPU Modules");
        }
      } else {
        // Only the first GPU module is used; RunBackend likewise consumes
        // only spv_binary_vec[0].
        auto gpu_module = *gpu_modules.begin();
        results->module_name = gpu_module.getName().str();
        auto spv_attr = gpu_module->getAttrOfType<mlir::StringAttr>(
            mlir::gpu::getDefaultGpuBinaryAnnotation());
        if (!spv_attr) {
          if (disable_fallback) {
            return InternalError("SPIRV kernel binary is empty.");
          }
          results->mlir_compiler_success = false;
        } else {
          llvm::StringRef spv_bytes = spv_attr.getValue();
          results->spv_binary_vec.emplace_back(spv_bytes.bytes_begin(),
                                               spv_bytes.bytes_end());

          // Record kernel names and launch dimensions in IR order. They are
          // matched positionally with the "spv_dummy_kernel" thunks below.
          // Non-constant dimensions cannot be baked into a thunk, so they
          // abort this path instead of crashing on an empty optional.
          bool launch_dims_are_constant = true;
          mlir_module->walk([&](mlir::gpu::LaunchFuncOp launch_func_op) {
            using mlir::getConstantIntValue;
            if (!launch_dims_are_constant) return;
            auto grids = launch_func_op.getGridSizeOperandValues();
            auto blocks = launch_func_op.getBlockSizeOperandValues();
            auto grid_x = getConstantIntValue(grids.x);
            auto grid_y = getConstantIntValue(grids.y);
            auto grid_z = getConstantIntValue(grids.z);
            auto block_x = getConstantIntValue(blocks.x);
            auto block_y = getConstantIntValue(blocks.y);
            auto block_z = getConstantIntValue(blocks.z);
            if (!grid_x || !grid_y || !grid_z || !block_x || !block_y ||
                !block_z) {
              launch_dims_are_constant = false;
              return;
            }
            results->launch_func_dims.emplace_back(
                Dim3D{*grid_x, *grid_y, *grid_z},
                Dim3D{*block_x, *block_y, *block_z});
            results->launch_func_seqs.emplace_back(
                "__spv__" + launch_func_op.getKernelModuleName().str());
          });
          if (!launch_dims_are_constant) {
            if (disable_fallback) {
              return InternalError(
                  "gpu.launch_func has non-constant launch dimensions.");
            }
            results->mlir_compiler_success = false;
          }
        }
      }
    }

    // The pass pipeline above ran on `mlir_module` and may have rewritten it.
    // Convert the HLO again into a fresh LMHLO module for the IR emission below.
    mlir_module =
        mlir::ModuleOp::create(mlir::Builder(mlir_context).getUnknownLoc());

    TF_RETURN_IF_ERROR(HloToLhloModule(*results->buffer_assignment, *hlo_module,
                                       *mlir_module));
  }

  auto entry_function = mlir::cast<mlir::func::FuncOp>(
      mlir_module->lookupSymbol(hlo_module->entry_computation()->name()));

  TF_RETURN_IF_ERROR(GetMlirAllocationInfo(
      entry_function, &results->allocations, &results->output_info,
      &results->output_shape, &results->entry_func_attrs));

  IrEmitterContext ir_emitter_context(
      /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr, platform_name,
      gpu_device_info, mlir_context, results->llvm_module.get(),
      results->mlir_compiler_success);

  ir_emitter_context.set_allocations(results->allocations);

  TF_ASSIGN_OR_RETURN(
      auto ir_emitter,
      IrEmitterUnnested::Create(hlo_module->config(), &ir_emitter_context));

  {
    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");

    TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.getBody()));

    bool supports_runtime_managed_constants =
        // TODO(b/218527186): Implement this feature for BEF as well.
        !IsBefThunkEnabled(hlo_module->config()) &&
        hlo_module->config().debug_options().xla_gpu_enable_shared_constants();
    if (supports_runtime_managed_constants) {
      // Remove these globals from the generated code to indicate that XLA is
      // responsible for allocating and initializing them.
      RemoveUnusedAndUninitializedGlobals(ir_emitter_context.llvm_module(),
                                          ir_emitter_context.constants());
    }

    results->constants = std::move(ir_emitter_context.constants());
  }

  results->thunks_or_bef =
      absl::make_unique<ThunkSchedule>(ir_emitter->ConsumeThunkSequence());
  if (results->mlir_compiler_success && enable_mlir_lowering &&
      absl::holds_alternative<OwnedThunkSchedule>(results->thunks_or_bef)) {
    // Bind each placeholder KernelThunk, in total order, to the next recorded
    // MLIR kernel name and launch dimensions.
    auto& thunks = absl::get<OwnedThunkSchedule>(results->thunks_or_bef);
    size_t kernel_idx = 0;
    for (const std::unique_ptr<Thunk>& thunk : thunks->TotalOrder()) {
      auto* kernel_thunk = dynamic_cast<KernelThunk*>(thunk.get());
      if (kernel_thunk == nullptr ||
          kernel_thunk->kernel_name() != "spv_dummy_kernel") {
        continue;
      }
      if (kernel_idx >= results->launch_func_seqs.size() ||
          kernel_idx >= results->launch_func_dims.size()) {
        return InternalError(
            "More spv_dummy_kernel thunks than recorded gpu.launch_func ops.");
      }
      kernel_thunk->set_kernel_name(results->launch_func_seqs[kernel_idx]);
      const auto& [grids, blocks] = results->launch_func_dims[kernel_idx];
      kernel_thunk->set_launch_dims(LaunchDimensions(grids, blocks));
      ++kernel_idx;
    }
  }
  return Status::OK();
}

// LLVM diagnostic callback: renders the diagnostic to text and only logs it at
// verbosity level 5. `context` is not used.
static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info,
                                  void* context) {
  std::string rendered;
  {
    llvm::raw_string_ostream sink(rendered);
    llvm::DiagnosticPrinterRawOStream printer(sink);
    diag_info.print(printer);
  }  // `sink` is destroyed (and flushed into `rendered`) here.
  ITEX_VLOG(5) << rendered;
}

// Verifies `llvm_module`, compiles it with CompileTargetBinary, and, when
// dumping is enabled for `debug_module`, dumps the optimized IR and the
// backend's textual output. Returns the backend's (text, binary) pair.
StatusOr<std::pair<std::string, std::vector<uint8_t>>>
GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
                                   std::unique_ptr<llvm::Module> llvm_module,
                                   const HloModule* debug_module) {
  using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;

  // Logged whenever dumping is requested but there is no HLO module to derive
  // file names from.
  const auto report_unnamed_dump = [] {
    ITEX_LOG(ERROR)
        << "Dumping is not implemented since the file name cannot be "
           "inferred. Please implement (potentially MLIR) module -> "
           "filename heuristic.";
  };

  // Verifies, compiles and optionally dumps one LLVM module. `shard_number`
  // only affects the names of the dumped files.
  const auto compile_single_module =
      [this, &module_config, debug_module, &report_unnamed_dump](
          llvm::Module* llvm_module,
          absl::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
    {
      XLA_SCOPED_LOGGING_TIMER(
          "GpuCompiler::RunBackend - Running LLVM verifier");

      // Send LLVM diagnostics to NullDiagnosticHandler (verbose logging).
      llvm_module->getContext().setDiagnosticHandlerCallBack(
          NullDiagnosticHandler, nullptr);

      std::string err;
      llvm::raw_string_ostream err_stream(err);

      // verifyModule() returns true if the module is broken.
      TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
          << "Invalid LLVM IR before optimizations:\n"
          << err_stream.str()
          << "\nThis probably indicates a bug in the HLO -> LLVM IR "
             "lowering. Rerun with --xla_dump_to to get the IR"
          << (debug_module
                  ? absl::StrCat(" and looks for files with name containing: *",
                                 FilenameFor(*debug_module, "", ""), "*")
                  : ".");
    }

    StatusOr<BackendCompileResult> result =
        CompileTargetBinary(module_config, llvm_module, debug_module);
    if (!result.ok()) return result;

    const bool should_dump =
        DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
                                   module_config.debug_options());

    // Optimized LLVM IR dump.
    if (should_dump) {
      if (debug_module == nullptr) {
        report_unnamed_dump();
      } else if (shard_number.has_value()) {
        llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
                                 /*optimized=*/true,
                                 std::to_string(*shard_number));
      } else {
        llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
                                 /*optimized=*/true);
      }
    }

    if (user_post_optimization_hook_) {
      user_post_optimization_hook_(*llvm_module);
    }

    // Backend text output, dumped with a "ptx" file suffix.
    if (should_dump) {
      if (debug_module == nullptr) {
        report_unnamed_dump();
      } else {
        const absl::string_view ptx = result->first;
        const std::string suffix =
            shard_number.has_value() ? std::to_string(*shard_number) + ".ptx"
                                     : std::string("ptx");
        DumpToFileInDirOrStdout(*debug_module, "", suffix, ptx);
      }
    }

    return result;
  };

  // TODO(ITEX): add threadpool
  llvm::Module* const only_module = llvm_module.get();
  return compile_single_module(only_module, /*shard_number=*/absl::nullopt);
}

// Compiles an optimized HLO module into a GpuExecutable.
//
// Where the kernel binary comes from:
// - the MLIR/SPIR-V pipeline, when ENABLE_GPU_MLIR_LOWERING is set and either
//   that pipeline succeeded or DISABLE_GPU_MLIR_FALLBACK forbids falling back;
// - otherwise, the LLVM module from CompileModuleToLlvmIrImpl, compiled with
//   CompileToTargetBinary.
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
  // std::string slow_compilation_msg =
  //     absl::StrCat("Compiling module ", module->name());
  // auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);

  llvm::LLVMContext llvm_context;

  GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo();

  // if (module->config().hlo_profiling_enabled() || ITEX_VLOG_IS_ON(1)) {
  //   HloCostAnalysis::Options options{ShapeSizeBytesFunction()};
  //   options.set_bytes_per_second(
  //       stream_exec->GetDeviceDescription().memory_bandwidth());
  //   GpuHloCostAnalysis cost_analysis(options);
  //   TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
  //   ITEX_VLOG(1) << "HLO memory read+written: "
  //           << itex::strings::HumanReadableNumBytes(
  //                  cost_analysis.bytes_accessed());
  //   if (module->config().hlo_profiling_enabled()) {
  //     ITEX_LOG(ERROR) << "--xla_hlo_profile for GPU is unsupported.";
  //   }
  // }

  CompileModuleResults compile_module_results;
  TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
      module.get(), &llvm_context, target_triple_, data_layout_,
      stream_exec->platform()->Name(), gpu_device_info, GetCanShareBuffer(),
      pointer_size_, &compile_module_results));

  if (user_pre_optimization_hook_) {
    user_pre_optimization_hook_(*compile_module_results.llvm_module);
  }
  std::string ir_module_string_before_opt;
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();
  if (embed_ir_in_executable) {
    ir_module_string_before_opt =
        llvm_ir::DumpModuleToString(*compile_module_results.llvm_module);
  }

  llvm_ir::DumpIrIfEnabled(*module, *compile_module_results.llvm_module,
                           /*optimized=*/false);

  using BackendCompileResult = std::pair<std::string, std::vector<uint8_t>>;
  BackendCompileResult backend_result;
  if (EnableGpuMlirLowering() &&
      (DisableGpuMlirFallBack() ||
       compile_module_results.mlir_compiler_success)) {
    // With fallback disabled, this branch is taken regardless of
    // mlir_compiler_success, so do not assume a binary was produced before
    // indexing spv_binary_vec[0].
    if (compile_module_results.spv_binary_vec.empty()) {
      return InternalError(
          "No SPIRV kernel binary was produced by the MLIR GPU pipeline.");
    }
    ITEX_VLOG(1) << "SPIRV kernel is passed to GpuExecutable.";
    backend_result.first = compile_module_results.module_name;
    // TODO(ITEX): Assume there is only 1 kernel in each mlir module. A method
    // is needed to combine multiple spirv kernels to a single binary in
    // ThunkSequence order.
    backend_result.second = std::move(compile_module_results.spv_binary_vec[0]);
  } else {
    ITEX_VLOG(1) << "Build kernel via LLVM kernel compilation.";
    TF_ASSIGN_OR_RETURN(
        backend_result,
        CompileToTargetBinary(module->config(),
                              std::move(compile_module_results.llvm_module),
                              module.get()));
    if (DumpingEnabledForHloModule(*module) &&
        absl::holds_alternative<OwnedThunkSchedule>(
            compile_module_results.thunks_or_bef)) {
      const ThunkSchedule& thunk_schedule =
          *absl::get<OwnedThunkSchedule>(compile_module_results.thunks_or_bef);
      DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
                              thunk_schedule.ToString());
    }
  }

  auto buffer_assignment_proto = std::make_unique<BufferAssignmentProto>(
      compile_module_results.buffer_assignment->ToProto());

  // Make it shared to be captured in the following lambda.
  std::shared_ptr<const BufferAssignment> buffer_assignment(
      std::move(compile_module_results.buffer_assignment));

  TF_ASSIGN_OR_RETURN(
      auto gpu_executable,
      GpuExecutable::Create(
          {std::move(backend_result.first), std::move(backend_result.second),
           std::move(compile_module_results.thunks_or_bef),
           compile_module_results.entry_func_attrs,
           std::move(compile_module_results.constants),
           std::move(compile_module_results.output_info),
           compile_module_results.module_name,
           compile_module_results.output_shape,
           std::move(compile_module_results.allocations),
           std::move(buffer_assignment_proto),
           [buffer_assignment] { return buffer_assignment->ToVerboseString(); },
           std::move(module)}));
  if (embed_ir_in_executable) {
    ITEX_DCHECK_NE("", ir_module_string_before_opt);
    gpu_executable->set_ir_module_string(ir_module_string_before_opt);
  }

  // Dump computation proto state and buffer assignment for debug and test, if
  // dump or embed_ir_in_executable is enabled.
  // if (embed_ir_in_executable ||
  //     DumpingEnabledForHloModule(gpu_executable->module())) {
  //   auto hlo_proto = absl::make_unique<HloProto>();
  //   if (hlo_proto_) {
  //     *hlo_proto = *hlo_proto_;
  //   } else {
  //     *hlo_proto->mutable_hlo_module() = gpu_executable->module().ToProto();
  //   }
  //   *hlo_proto->mutable_buffer_assignment() = buffer_assignment->ToProto();
  //   gpu_executable->set_hlo_proto(std::move(hlo_proto));
  // }
  // gpu_executable->set_debug_info(buffer_assignment->GetStats().ToString());
  return static_cast<std::unique_ptr<Executable>>(std::move(gpu_executable));
}

HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
  // Copy the pointer size into the closure by value so the returned function
  // never refers back to this GpuCompiler.
  const auto ptr_size = pointer_size_;
  return [ptr_size](const Shape& shape) {
    return GetSizeOfShape(shape, ptr_size);
  };
}

// Analyzes the function signature to reconstruct a vector of BufferAllocation
// objects, as well as other output information.
//
// This function also serves as a half-baked verifier for function arg
// attributes, since a full verifier doesn't exist yet. Via TF_RET_CHECK it
// rejects:
// - non-shaped or non-statically-shaped arguments;
// - unknown lmhlo.* argument attributes;
// - argument attributes of an unexpected attribute kind;
// - a missing "result_xla_shape" function attribute.
//
// Args:
//   func: the LMHLO entry function.
//   allocations: must be empty on entry; receives the reconstructed
//     allocations.
//   output_info, output_shape: filled in by
//     GpuExecutable::SetUpMlirAllocation.
//   entry_func_attrs: receives one buffer entry per argument plus the result
//     XLA shape, so this metadata can be persisted (BEF does not persist
//     function attributes).
static Status GetMlirAllocationInfo(mlir::func::FuncOp func,
                                    std::vector<BufferAllocation>* allocations,
                                    OutputInfoMap* output_info,
                                    Shape* output_shape,
                                    EntryFunctionAttributes* entry_func_attrs) {
  ITEX_CHECK(allocations->empty());
  const unsigned num_args = func.getNumArguments();
  allocations->reserve(num_args);

  // Byte size of each argument buffer: element count times element width.
  std::vector<int64_t> buffer_sizes;
  buffer_sizes.reserve(num_args);
  for (unsigned i = 0; i < num_args; ++i) {
    mlir::BlockArgument arg = func.getArgument(i);

    TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
    mlir::ShapedType type = arg.getType().cast<mlir::ShapedType>();
    // getNumElements() is only valid for statically shaped types.
    TF_RET_CHECK(type.hasStaticShape());
    TF_ASSIGN_OR_RETURN(auto element_type_bytes,
                        GetElementTypeBytes(type.getElementType()));
    size_t size = type.getNumElements() * element_type_bytes;
    buffer_sizes.push_back(size);
  }

  for (unsigned i = 0; i < num_args; ++i) {
    for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
      TF_RET_CHECK(attr.getName() == "lmhlo.params" ||
                   attr.getName() == "lmhlo.param_shape_index" ||
                   attr.getName() == "lmhlo.constant_name" ||
                   attr.getName() == "lmhlo.must_alias" ||
                   attr.getName() == "lmhlo.output_index");
    }
  }

  // Encode buffer parameter metadata in a proto for persisting, because BEF
  // doesn't persist function attributes. Each attribute's kind is checked
  // before it is cast, so a malformed attribute yields an error instead of a
  // crash.
  for (unsigned i = 0; i < num_args; ++i) {
    auto* buffer = entry_func_attrs->add_buffers();
    if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
      TF_RET_CHECK(param_attr.isa<mlir::IntegerAttr>());
      buffer->set_lmhlo_params_present(true);
      buffer->set_lmhlo_params(param_attr.cast<mlir::IntegerAttr>().getInt());
    }
    if (auto shape_index_attr = func.getArgAttr(i, "lmhlo.param_shape_index")) {
      TF_RET_CHECK(shape_index_attr.isa<mlir::DenseIntElementsAttr>());
      auto* param_shape_index = buffer->mutable_lmhlo_param_shape_index();
      for (const llvm::APInt& element :
           shape_index_attr.cast<mlir::DenseIntElementsAttr>()) {
        param_shape_index->add_indices(element.getSExtValue());
      }
    }
    if (auto constant_name_attr = func.getArgAttr(i, "lmhlo.constant_name")) {
      TF_RET_CHECK(constant_name_attr.isa<mlir::StringAttr>());
      buffer->set_lmhlo_constant_name(
          constant_name_attr.cast<mlir::StringAttr>().str());
    }
    if (func.getArgAttr(i, "lmhlo.must_alias")) {
      buffer->set_lmhlo_must_alias(true);
    }
    if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
      TF_RET_CHECK(output_index_attr.isa<mlir::DenseIntElementsAttr>());
      auto* output_index = buffer->mutable_lmhlo_output_index();
      for (const llvm::APInt& element :
           output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
        output_index->add_indices(element.getSExtValue());
      }
    }
  }

  // getAttrOfType returns a null attribute when it is absent or of another
  // kind; dereferencing that would crash, so check it first.
  auto result_xla_shape =
      func->getAttrOfType<mlir::StringAttr>("result_xla_shape");
  TF_RET_CHECK(static_cast<bool>(result_xla_shape))
      << "Entry function is missing a string result_xla_shape attribute.";
  entry_func_attrs->set_result_xla_shape(result_xla_shape.getValue().str());

  return GpuExecutable::SetUpMlirAllocation(func, buffer_sizes, allocations,
                                            output_info, output_shape);
}

}  // namespace gpu
}  // namespace itex_xla
