// Copyright 2021 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Implementation of the hipfft API forwarding calls to symbols dynamically
// loaded from the real library.
#include "tfrt/gpu/wrapper/hipfft_stub.h"

#include <utility>

#include "symbol_loader.h"

// Memoizes load of the .so for this ROCm library.
static void *LoadSymbol(const char *symbol_name) {
  static SymbolLoader loader("libhipfft.so");
  return loader.GetAddressOfSymbol(symbol_name);
}

// Looks up 'symbol_name' through LoadSymbol() and returns it as a 'Func *'.
// The result is whatever LoadSymbol() produced; callers such as DynamicCall()
// treat a null result as "symbol not available".
//
// The trailing parameter exists only so that 'Func' can be deduced from an
// existing function pointer (e.g. GetFunctionPointer("name", &some_function)).
// Its value is ignored, so it is left unnamed to avoid -Wunused-parameter.
//
// NOTE: converting a data pointer ('void *') to a function pointer is only
// conditionally-supported in ISO C++. It is well-defined on POSIX platforms,
// where SymbolLoader is presumably dlsym-based.
template <typename Func>
static Func *GetFunctionPointer(const char *symbol_name,
                                Func * /*func*/ = nullptr) {
  return reinterpret_cast<Func *>(LoadSymbol(symbol_name));
}

// Calls function 'symbol_name' in the shared library with 'args'.
//
// 'Func' is the function type of the entry point being forwarded. The unnamed
// 'Func *' non-type parameter is presumably the address of that entry point's
// own declaration (see TODO below). It is not used in the body, but it makes
// each entry point a distinct instantiation with its own cached pointer.
//
// Returns HIPFFT_INTERNAL_ERROR if the symbol could not be resolved. Otherwise
// returns the result of the loaded function, with 'args' perfectly forwarded.
// TODO(gkg): Change to 'auto Func' when C++17 is allowed.
template <typename Func, Func *, typename... Args>
static hipfftResult DynamicCall(const char *symbol_name, Args &&...args) {
  // Resolved once per instantiation; function-local static initialization is
  // thread-safe. A failed lookup is cached as null and is not retried.
  static Func *const func_ptr = GetFunctionPointer<Func>(symbol_name);
  if (!func_ptr) return HIPFFT_INTERNAL_ERROR;
  return func_ptr(std::forward<Args>(args)...);
}

// Everything defined in the included file gets C linkage, so the exported stub
// symbols are unmangled like the hipfft C API this file implements.
// NOTE(review): hipfft_stub.cc.inc is not visible here. It presumably defines
// one stub per hipfft entry point that forwards through DynamicCall() above;
// confirm against the generator.
extern "C" {
#include "hipfft_stub.cc.inc"
}
