#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/hip/HIPCachingAllocator.h>
#include <c10/hip/HIPException.h>
#include <c10/hip/HIPFunctions.h>
#include <c10/hip/HIPStream.h>

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/PyInterpreter.h>
#include <hip/hip_runtime_api.h>
#include <cstdint>
#include <optional>

namespace c10::hip::impl {

// DeviceGuardImplInterface implementation for DeviceType::HIP.
// Provides current-device / current-stream switching and event management
// on top of the HIP runtime and the c10::hip helper functions.
struct HIPGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::HIP;

  HIPGuardImpl() = default;
  // Throws (TORCH_CHECK) if `t` is anything other than DeviceType::HIP.
  explicit HIPGuardImpl(DeviceType t) {
    TORCH_CHECK(
        t == DeviceType::HIP,
        "HIPGuardImpl initialized with non-HIP DeviceType: ",
        t);
  }
  DeviceType type() const override {
    return DeviceType::HIP;
  }
  // Makes `d` the current HIP device and returns the device that was current
  // before the switch. Throws if `d` is not a HIP device.
  Device exchangeDevice(Device d) const override {
    TORCH_CHECK(d.is_hip(), "Expected a HIP device, but got ", d);
    auto old_device_index = c10::hip::ExchangeDevice(d.index());
    return Device(DeviceType::HIP, old_device_index);
  }
  // Returns the current HIP device; throws on a HIP runtime error.
  Device getDevice() const override {
    DeviceIndex device = 0;
    C10_HIP_CHECK(c10::hip::GetDevice(&device));
    return Device(DeviceType::HIP, device);
  }
  // Non-throwing variant of getDevice(): a HIP error is reported as a warning
  // and std::nullopt is returned instead.
  std::optional<Device> uncheckedGetDevice() const noexcept {
    DeviceIndex device{-1};
    const auto err = C10_HIP_ERROR_HANDLED(c10::hip::GetDevice(&device));
    C10_HIP_CHECK_WARN(err);
    if (err != hipSuccess) {
      return std::nullopt;
    }
    return Device(DeviceType::HIP, device);
  }
  // Makes `d` the current HIP device. Throws if `d` is not a HIP device or
  // the HIP runtime reports an error.
  void setDevice(Device d) const override {
    TORCH_CHECK(d.is_hip(), "Expected a HIP device, but got ", d);
    C10_HIP_CHECK(c10::hip::SetDevice(d.index()));
  }
  // Non-throwing device switch via c10::hip::MaybeSetDevice; errors are only
  // warned about (see MaybeSetDevice for its exact semantics).
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_HIP_CHECK_WARN(c10::hip::MaybeSetDevice(d.index()));
  }
  // Current HIP stream for device `d`.
  Stream getStream(Device d) const override {
    return getCurrentHIPStream(d.index()).unwrap();
  }
  // Default HIP stream for device `d`.
  Stream getDefaultStream(Device d) const override {
    return getDefaultHIPStream(d.index());
  }
  // Stream taken from the stream pool for device `d` at the given priority.
  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPool(priority, d.index());
  }
  // Stream taken from the pool for device `d`; `isHighPriority` selects the
  // high-priority pool.
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  // Makes `s` the current stream for its device and returns the stream that
  // was current for that device before.
  Stream exchangeStream(Stream s) const override {
    HIPStream cs(s);
    auto old_stream = getCurrentHIPStream(s.device().index());
    setCurrentHIPStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions

  // Creates a HIP event in `*hip_event` on the current device.
  // EventFlag::PYTORCH_DEFAULT maps to hipEventDisableTiming and
  // EventFlag::BACKEND_DEFAULT maps to hipEventDefault; any other flag throws.
  void createEvent(hipEvent_t* hip_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to HIP flag
    auto hip_flag = hipEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
        hip_flag = hipEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
        hip_flag = hipEventDefault;
        break;
      default:
        TORCH_CHECK(false, "HIP event received unknown flag");
    }

    C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      // Trace the event handle itself (*hip_event), not the address of the
      // caller's out-parameter. The record/wait/synchronize/deletion traces
      // below all identify the event by its handle value, so the creation
      // ID must match them.
      (*interp)->trace_gpu_event_creation(
          c10::kHIP, reinterpret_cast<uintptr_t>(*hip_event));
    }
  }

  // Destroys `event` (no-op if null) after temporarily switching to
  // `device_index`, then switches back. Because this is noexcept, HIP errors
  // are only reported as warnings.
  // NOTE(review): if GetDevice fails, orig_device stays -1 and is still passed
  // to SetDevice on the way out; confirm that is tolerated on this path.
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto hip_event = static_cast<hipEvent_t>(event);
    DeviceIndex orig_device{-1};
    C10_HIP_CHECK_WARN(c10::hip::GetDevice(&orig_device));
    C10_HIP_CHECK_WARN(c10::hip::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          c10::kHIP, reinterpret_cast<uintptr_t>(hip_event));
    }
    C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
    C10_HIP_CHECK_WARN(c10::hip::SetDevice(orig_device));
  }

  // Records `*event` on `stream`, lazily creating the event (with `flag`) if
  // `*event` is null, and writes the event handle back through `event`.
  // `device_index` must be -1 or match the stream's device.
  // NOTE(review): if event creation or hipEventRecord throws, the original
  // device is not restored.
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
    HIPStream hip_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!hip_event)
      createEvent(&hip_event, flag);
    C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
    // Makes the void* point to the (possibly just allocated) HIP event
    *event = hip_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kHIP,
          reinterpret_cast<uintptr_t>(hip_event),
          reinterpret_cast<uintptr_t>(hip_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }

  // Makes `stream` wait on `event` (hipStreamWaitEvent); no-op if `event` is
  // null. Temporarily switches to the stream's device.
  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    HIPStream hip_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_HIP_CHECK(hipStreamWaitEvent(
        hip_stream,
        hip_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          c10::kHIP,
          reinterpret_cast<uintptr_t>(hip_event),
          reinterpret_cast<uintptr_t>(hip_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  // Returns true if `event` is null or has completed, and false if it is not
  // ready yet. Throws on any other HIP error.
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    // Note: hipEventQuery can be safely called from any device
    const hipError_t err = C10_HIP_ERROR_HANDLED(hipEventQuery(hip_event));
    if (err != hipErrorNotReady) {
      C10_HIP_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)hipGetLastError();
    }
    return (err == hipSuccess);
  }

  // Stream-related functions

  // Forwards to HIPStream::query() for `stream`.
  bool queryStream(const Stream& stream) const override {
    HIPStream hip_stream{stream};
    return hip_stream.query();
  }

  // Forwards to HIPStream::synchronize() for `stream`.
  void synchronizeStream(const Stream& stream) const override {
    HIPStream hip_stream{stream};
    hip_stream.synchronize();
  }

  // Blocks the host until `event` completes; no-op if `event` is null.
  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_synchronization(
          c10::kHIP, reinterpret_cast<uintptr_t>(hip_event));
    }
    // Note: hipEventSynchronize can be safely called from any device
    C10_HIP_CHECK(hipEventSynchronize(hip_event));
  }

  // Note: synchronizeDevice can be safely called from any device
  // Runs hipDeviceSynchronize on `device_index`, then restores the previously
  // current device, including when the synchronization itself fails.
  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
    DeviceIndex orig_device{-1};
    C10_HIP_CHECK(c10::hip::GetDevice(&orig_device));
    C10_HIP_CHECK(c10::hip::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_device_synchronization(c10::kHIP);
    }
    const hipError_t err = C10_HIP_ERROR_HANDLED(hipDeviceSynchronize());
    if (C10_UNLIKELY(err != hipSuccess)) {
      // Restore the caller's device before surfacing the error. Restoration is
      // best-effort (warning only) so the original error is the one thrown.
      C10_HIP_CHECK_WARN(c10::hip::SetDevice(orig_device));
      C10_HIP_CHECK(err);
    }
    C10_HIP_CHECK(c10::hip::SetDevice(orig_device));
  }

  // Forwards to HIPCachingAllocator::recordStream for `data_ptr` on `stream`.
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    HIPStream hip_stream{stream};
    HIPCachingAllocator::recordStream(data_ptr, hip_stream);
  }

  // Returns the time in milliseconds between `event1` and `event2` (as a
  // double), computed on `device_index`. Throws if either event is null or if
  // hipEventElapsedTime fails (e.g. hipErrorNotReady). The previously current
  // device is restored in both the success and failure cases.
  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK(
        event1 && event2,
        "Both events must be recorded before calculating elapsed time.");
    // Even though hipEventElapsedTime can be safely called from any device, if
    // the current device is not initialized, it will create a new HIP context,
    // which will consume a lot of memory.
    DeviceIndex orig_device{-1};
    C10_HIP_CHECK(c10::hip::GetDevice(&orig_device));
    C10_HIP_CHECK(c10::hip::SetDevice(device_index));
    hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1);
    hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2);
    float time_ms = 0;
    // hipEventElapsedTime fails with hipErrorNotReady if either event is
    // recorded but not yet completed. Capture the status instead of throwing
    // immediately, so the caller's device is not left switched on that
    // (user-reachable) error path.
    const hipError_t err = C10_HIP_ERROR_HANDLED(
        hipEventElapsedTime(&time_ms, hip_event1, hip_event2));
    if (C10_UNLIKELY(err != hipSuccess)) {
      // Best-effort restore (warning only) so the original error is thrown.
      C10_HIP_CHECK_WARN(c10::hip::SetDevice(orig_device));
      C10_HIP_CHECK(err);
    }
    C10_HIP_CHECK(c10::hip::SetDevice(orig_device));
    return static_cast<double>(time_ms);
  }
};

} // namespace c10::hip::impl

#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
