cmake_minimum_required(VERSION 3.18)

# ---------------------------------------------------------------------------
# Declare the project with C++ only, then enable CUDA if a usable CUDA
# compiler is found. On macOS and systems without nvcc, a CPU-only extension
# is built.
#
# check_language() is run after project(), and CUDA is enabled explicitly
# with enable_language(). This is the pattern documented by CheckLanguage,
# and it keeps a single, unconditional project() call.
# ---------------------------------------------------------------------------
project(vikshep_core LANGUAGES CXX)

include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
  # CMake initializes CMAKE_CUDA_ARCHITECTURES while enabling the CUDA
  # language. The project default (architecture 80) must therefore be set
  # *before* enable_language(CUDA); a check placed afterwards never fires.
  # An explicit -DCMAKE_CUDA_ARCHITECTURES=... still wins, as does the
  # CUDAARCHS environment variable (honored by CMake 3.20+).
  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND NOT DEFINED ENV{CUDAARCHS})
    set(CMAKE_CUDA_ARCHITECTURES 80)
  endif()
  enable_language(CUDA)
  set(VIKSHEP_HAS_CUDA ON)
  message(STATUS "[Vikshep] CUDA detected — building GPU+CPU extension")
else()
  set(VIKSHEP_HAS_CUDA OFF)
  message(STATUS "[Vikshep] No CUDA — building CPU-only extension")
endif()

# Python (interpreter + extension-module development files) and pybind11 are
# required in both the GPU and the CPU-only configuration.
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
find_package(pybind11 CONFIG REQUIRED)

# Project-wide C++ standard for every target defined below.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# ---------------------------------------------------------------------------
# Target: _vikshep_core — Python extension module (PyPI wheel)
#
# GPU build:      wst_bindings.cu and memory_staging.cu are compiled as CUDA
#                 and linked against the CUDA runtime and cuFFT.
# CPU-only build: wst_bindings.cu is copied to a .cpp file in the build tree,
#                 compiled by the C++ compiler with VIKSHEP_CPU_ONLY defined.
#                 memory_staging.cu is not part of this build.
# ---------------------------------------------------------------------------
if(VIKSHEP_HAS_CUDA)
  # ── GPU build: compile .cu files with nvcc ──────────────────────────────
  set(CMAKE_CUDA_STANDARD 17)

  # The target's CUDA_ARCHITECTURES is taken from CMAKE_CUDA_ARCHITECTURES.
  # CMake initializes that variable when the CUDA language is enabled (from
  # -D, the CUDAARCHS environment variable, or the compiler default), so it
  # is always defined here. An `if(NOT DEFINED ...)` default at this point
  # could never take effect; any project default must be set before CUDA is
  # enabled.

  find_package(CUDAToolkit REQUIRED)

  pybind11_add_module(_vikshep_core MODULE
    cpp/wst_bindings.cu
    cpp/memory_staging.cu
  )

  target_include_directories(_vikshep_core PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/cpp")
  target_link_libraries(_vikshep_core PRIVATE
    CUDA::cudart
    CUDA::cufft
  )

  # The generator expression is quoted and ';'-separated. Written unquoted
  # with a space, it is split into two arguments and is no longer recognized
  # as a generator expression.
  target_compile_options(_vikshep_core PRIVATE
    "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr;-O3>"
  )
  # -O3 is GCC/Clang syntax; MSVC has no -O3 and reports it as unknown.
  if(NOT MSVC)
    target_compile_options(_vikshep_core PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-O3>")
  endif()

else()
  # ── CPU-only build: compile the .cu as plain C++ ────────────────────────
  # Copy the source to a .cpp file in the build tree so the C++ compiler
  # handles it. configure_file() re-runs the copy whenever the .cu changes.
  set(CPU_BINDING_SRC "${CMAKE_CURRENT_BINARY_DIR}/wst_bindings_cpu.cpp")
  configure_file("${CMAKE_CURRENT_SOURCE_DIR}/cpp/wst_bindings.cu" "${CPU_BINDING_SRC}" COPYONLY)

  pybind11_add_module(_vikshep_core MODULE "${CPU_BINDING_SRC}")

  target_include_directories(_vikshep_core PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/cpp")

  # Define VIKSHEP_CPU_ONLY so the source can skip CUDA code paths
  target_compile_definitions(_vikshep_core PRIVATE VIKSHEP_CPU_ONLY)

  # -O3 is GCC/Clang syntax; MSVC has no -O3 and reports it as unknown.
  if(NOT MSVC)
    target_compile_options(_vikshep_core PRIVATE -O3)
  endif()
endif()

# Install the built extension (.so / .pyd) under the `vikshep` package
# directory. scikit-build-core then places it next to the package's
# __init__.py inside the wheel.
install(
  TARGETS _vikshep_core
  DESTINATION vikshep
)

# ---------------------------------------------------------------------------
# Target 2: omni_wst_bridge — Shared library for the Rust Phase 2 orchestrator
#
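# Built only when BUILD_CXX_BRIDGE is ON and CUDA is available. The library is
# compiled as position-independent code (required for a .so). The bridge
# header (wst_bridge.h) references "rust/cxx.h", which is generated by the
# Rust build.rs when this .so is compiled from within the `omni-wst-sys`
# crate. For standalone builds, set OMNI_CXX_INCLUDE_DIR to a directory that
# provides rust/cxx.h (for example, a minimal stub header).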
# ---------------------------------------------------------------------------
option(BUILD_CXX_BRIDGE "Build the Rust cxx FFI bridge shared library" OFF)

# The bridge sources are CUDA files; report a requested-but-impossible build
# instead of skipping it silently.
if(BUILD_CXX_BRIDGE AND NOT VIKSHEP_HAS_CUDA)
  message(WARNING "[Vikshep] BUILD_CXX_BRIDGE is ON but CUDA was not detected — omni_wst_bridge will not be built")
endif()

if(BUILD_CXX_BRIDGE AND VIKSHEP_HAS_CUDA)
  add_library(omni_wst_bridge SHARED
    cpp/wst_bridge.cu
    cpp/memory_staging.cu
  )

  # Allow callers to supply the directory containing the cxx-generated
  # rust/cxx.h. It is searched before the project's own cpp/ directory.
  if(DEFINED OMNI_CXX_INCLUDE_DIR)
    target_include_directories(omni_wst_bridge PRIVATE ${OMNI_CXX_INCLUDE_DIR})
  endif()

  target_include_directories(omni_wst_bridge PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/cpp")

  target_link_libraries(omni_wst_bridge PRIVATE
    CUDA::cudart
    CUDA::cufft
  )

  # The generator expression is quoted and ';'-separated. Written unquoted
  # with a space, it is split into separate arguments and is no longer
  # recognized as a generator expression.
  # No hand-written -fPIC: the POSITION_INDEPENDENT_CODE property below makes
  # CMake add the correct PIC flag for each language and compiler.
  target_compile_options(omni_wst_bridge PRIVATE
    "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr;-O3>"
  )
  # -O3 is GCC/Clang syntax; MSVC has no -O3 and reports it as unknown.
  if(NOT MSVC)
    target_compile_options(omni_wst_bridge PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-O3>")
  endif()

  set_target_properties(omni_wst_bridge PROPERTIES
    CUDA_SEPARABLE_COMPILATION ON
    POSITION_INDEPENDENT_CODE ON
  )

  install(TARGETS omni_wst_bridge
    LIBRARY DESTINATION lib
    ARCHIVE DESTINATION lib
  )
  install(FILES cpp/wst_bridge.h DESTINATION include)
endif()
