cmake_minimum_required(VERSION 3.22)
project(pydftb_torch_cpp LANGUAGES CXX)

# Default single-config generators (Makefiles, Ninja) to an optimized build when
# no build type was chosen.  Multi-config generators (Visual Studio, Xcode,
# Ninja Multi-Config) pick the configuration at build time and are left alone.
# GENERATOR_IS_MULTI_CONFIG asks the generator directly rather than inferring it
# from CMAKE_CONFIGURATION_TYPES, which a user may also set by hand.
get_property(_dftb_torch_is_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT _dftb_torch_is_multi_config AND NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
  set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS Debug Release RelWithDebInfo MinSizeRel)
endif()

# Build-tuning switches consumed by dftb_torch_apply_performance_options().
# A value given on the command line (-D...) or set as a normal variable by a
# parent project takes precedence over the defaults below (policy CMP0077 is
# NEW under cmake_minimum_required(VERSION 3.22)).
option(DFTB_TORCH_FORCE_PERFORMANCE_FLAGS "Force high-performance compiler flags for local validation builds" ON)
option(DFTB_TORCH_ENABLE_IPO "Enable interprocedural optimization when the compiler supports it" ON)
option(DFTB_TORCH_ENABLE_NATIVE_OPT "Enable CPU-native code generation for non-portable local builds" OFF)
option(DFTB_TORCH_ENABLE_FAST_MATH "Enable unsafe fast-math flags; disabled by default for numerical reproducibility" OFF)

# dftb_torch_apply_performance_options(<target>)
#
# Apply the project's opt-in code-generation switches to an existing target.
# Every flag is added PRIVATE, so it affects only how <target> itself is
# compiled and never propagates to targets that link it.
#
#   DFTB_TORCH_FORCE_PERFORMANCE_FLAGS  -O3 (MSVC: /O2 outside Debug) + NDEBUG
#   DFTB_TORCH_ENABLE_NATIVE_OPT        -march=native -mtune=native (MSVC: /arch:AVX2)
#   DFTB_TORCH_ENABLE_FAST_MATH         -ffast-math -fno-math-errno (MSVC: /fp:fast)
#
# Example: dftb_torch_apply_performance_options(dftb_torch_core)
function(dftb_torch_apply_performance_options target_name)
  if(NOT TARGET "${target_name}")
    message(FATAL_ERROR "dftb_torch_apply_performance_options: '${target_name}' is not a target")
  endif()

  if(DFTB_TORCH_FORCE_PERFORMANCE_FLAGS)
    if(MSVC)
      # cl.exe rejects /O2 combined with the /RTC1 runtime checks that CMake
      # puts in its default Debug flags (error D8016), so /O2 is forced only
      # for non-Debug configurations.  The generator expression is evaluated
      # per configuration, which also covers multi-config generators.
      target_compile_options(${target_name} PRIVATE "$<$<NOT:$<CONFIG:Debug>>:/O2>")
    else()
      target_compile_options(${target_name} PRIVATE -O3)
    endif()
    # One spelling for every toolchain (MSVC receives /DNDEBUG).
    target_compile_definitions(${target_name} PRIVATE NDEBUG)
  endif()

  if(DFTB_TORCH_ENABLE_NATIVE_OPT)
    if(MSVC)
      # cl.exe has no CPU-detecting equivalent of -march=native; this assumes
      # the machine running the binary supports AVX2.
      target_compile_options(${target_name} PRIVATE /arch:AVX2)
    else()
      target_compile_options(${target_name} PRIVATE -march=native -mtune=native)
    endif()
  endif()

  if(DFTB_TORCH_ENABLE_FAST_MATH)
    if(MSVC)
      target_compile_options(${target_name} PRIVATE /fp:fast)
    else()
      target_compile_options(${target_name} PRIVATE -ffast-math -fno-math-errno)
    endif()
  endif()
endfunction()

# Default for pybind11's LTO/strip "extras" on the extension module, computed
# only when the user has not supplied DFTB_TORCH_PYBIND11_EXTRAS.  It resolves
# ON when performance flags are forced or CMAKE_BUILD_TYPE names an optimized
# configuration, and OFF otherwise.  With the defaults at the top of this file
# (DFTB_TORCH_FORCE_PERFORMANCE_FLAGS=ON) it therefore resolves ON.
if(NOT DEFINED DFTB_TORCH_PYBIND11_EXTRAS)
  set(_dftb_torch_extras_default OFF)
  if(DFTB_TORCH_FORCE_PERFORMANCE_FLAGS)
    set(_dftb_torch_extras_default ON)
  elseif(CMAKE_BUILD_TYPE MATCHES "^(Release|RelWithDebInfo|MinSizeRel)$")
    set(_dftb_torch_extras_default ON)
  endif()
  # No FORCE needed: the guard above guarantees no cache entry exists yet.
  set(DFTB_TORCH_PYBIND11_EXTRAS ${_dftb_torch_extras_default} CACHE BOOL "Enable pybind11 LTO/strip extras for the extension module")
endif()

find_package(Python
  COMPONENTS Interpreter Development.Module
  REQUIRED
)

# Editable/source-tree builds may run without the build frontend having put a
# pip-installed pybind11 on CMAKE_PREFIX_PATH.  While pybind11_DIR is still
# unset (or NOTFOUND), ask the selected interpreter where pybind11 keeps its
# CMake package files and cache that directory.  A failed query stays silent
# here; the REQUIRED find_package below reports the missing package.
if(NOT pybind11_DIR)
  execute_process(
    COMMAND ${Python_EXECUTABLE} -c "import pybind11; print(pybind11.get_cmake_dir())"
    OUTPUT_VARIABLE _dftb_pybind11_cmake_dir
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  if(_dftb_pybind11_cmake_dir)
    set(pybind11_DIR "${_dftb_pybind11_cmake_dir}"
      CACHE PATH "pybind11 CMake package directory" FORCE)
  endif()
endif()

find_package(pybind11 CONFIG REQUIRED)

# Prefer the torch package imported by this Python interpreter.  Mixing an
# external LIBTORCH_HOME with pip torch can produce RPATH cycles and ABI/library
# mismatches, especially for torch::Tensor pybind11 casters.
option(DFTB_TORCH_USE_PYTHON_TORCH "Use Python torch package to locate LibTorch" ON)

if(DFTB_TORCH_USE_PYTHON_TORCH)
  # Ask the interpreter for three paths, one per line:
  #   sys.prefix, <torch>/share/cmake/Torch, <torch>/lib
  # as_posix() yields forward slashes on Windows so the values are valid CMake
  # paths as-is.
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -c "import sys, torch, pathlib; print(pathlib.Path(sys.prefix).as_posix()); print((pathlib.Path(torch.__file__).parent / 'share' / 'cmake' / 'Torch').as_posix()); print((pathlib.Path(torch.__file__).parent / 'lib').as_posix())"
    RESULT_VARIABLE _PY_TORCH_RESULT
    OUTPUT_VARIABLE _PY_TORCH_INFO
    ERROR_VARIABLE _PY_TORCH_ERROR
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_STRIP_TRAILING_WHITESPACE
  )
  # RESULT_VARIABLE is an exit code, or an error string if Python could not be
  # started; either non-zero case is fatal.
  if(NOT _PY_TORCH_RESULT EQUAL 0)
    message(FATAL_ERROR "DFTB_TORCH_USE_PYTHON_TORCH=ON but Python torch could not be queried: ${_PY_TORCH_ERROR}")
  endif()
  # Python's text-mode stdout writes "\r\n" on Windows; drop the "\r" so the
  # first two entries do not keep a trailing carriage return.
  string(REPLACE "\r" "" _PY_TORCH_INFO "${_PY_TORCH_INFO}")
  string(REPLACE "\n" ";" _PY_TORCH_INFO_LIST "${_PY_TORCH_INFO}")
  list(LENGTH _PY_TORCH_INFO_LIST _PY_TORCH_INFO_LEN)
  if(_PY_TORCH_INFO_LEN LESS 3)
    message(FATAL_ERROR "DFTB_TORCH_USE_PYTHON_TORCH=ON but Python torch could not be queried")
  endif()
  # Take the last three lines so anything printed to stdout while importing
  # torch cannot shift the indices.
  list(GET _PY_TORCH_INFO_LIST -3 _PYTHON_PREFIX)
  list(GET _PY_TORCH_INFO_LIST -2 _TORCH_CMAKE_DIR)
  list(GET _PY_TORCH_INFO_LIST -1 TORCH_PYTHON_LIBDIR)
  list(PREPEND CMAKE_PREFIX_PATH "${_PYTHON_PREFIX}" "${_TORCH_CMAKE_DIR}")
  set(Torch_DIR "${_TORCH_CMAKE_DIR}" CACHE PATH "Torch CMake package directory" FORCE)
  if(DEFINED ENV{LIBTORCH_HOME})
    # NOTE: CMAKE_IGNORE_PREFIX_PATH was added in CMake 3.23; with CMake 3.22
    # (the declared minimum) this entry has no effect.
    list(APPEND CMAKE_IGNORE_PREFIX_PATH "$ENV{LIBTORCH_HOME}")
  endif()
  find_package(Torch REQUIRED CONFIG PATHS "${_TORCH_CMAKE_DIR}" NO_DEFAULT_PATH)
else()
  find_package(Torch REQUIRED)
endif()

find_package(ZLIB REQUIRED)

# Locate torch_python, which provides the torch::Tensor pybind11 casters used
# by the bindings.  The first search targets the torch that was just
# configured: the Python torch lib/ directory when DFTB_TORCH_USE_PYTHON_TORCH
# is ON, otherwise ${TORCH_INSTALL_PREFIX}/lib.  Only in the latter case can a
# second lookup in the Python torch lib/ directory find something new, so the
# fallback query runs only then.
if(DFTB_TORCH_USE_PYTHON_TORCH)
  set(_dftb_torch_python_search_dirs "${TORCH_PYTHON_LIBDIR}")
else()
  set(_dftb_torch_python_search_dirs "${TORCH_INSTALL_PREFIX}/lib")
endif()
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${_dftb_torch_python_search_dirs}" NO_DEFAULT_PATH)
if(NOT TORCH_PYTHON_LIBRARY AND NOT DFTB_TORCH_USE_PYTHON_TORCH)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -c "import torch, pathlib; print(pathlib.Path(torch.__file__).parent / 'lib')"
    OUTPUT_VARIABLE TORCH_PYTHON_LIBDIR
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  list(APPEND _dftb_torch_python_search_dirs "${TORCH_PYTHON_LIBDIR}")
  # A NOTFOUND cache value makes find_library search again.
  find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_PYTHON_LIBDIR}" NO_DEFAULT_PATH)
endif()
if(NOT TORCH_PYTHON_LIBRARY)
  message(FATAL_ERROR "Could not locate the torch_python library (libtorch_python) required for torch::Tensor pybind11 casters; searched: ${_dftb_torch_python_search_dirs}")
endif()
message(STATUS "Using Torch_DIR: ${Torch_DIR}")
message(STATUS "Using torch_python: ${TORCH_PYTHON_LIBRARY}")

# Core numerical library: a static archive built with position-independent
# code so it can be linked into the Python extension module.  Torch include
# directories and libraries (plus ZLIB) are PUBLIC usage requirements, so
# every target linking dftb_torch_core inherits them.
set(DFTB_CORE_SOURCES
  src/core/basis.cpp src/core/skf.cpp src/core/slater_koster.cpp
  src/core/hamiltonian.cpp src/core/eigensolver.cpp src/core/occupations.cpp
  src/core/repulsion.cpp src/core/calculator.cpp src/core/calculator_compute.cpp src/core/calculator_cp_periodic.cpp src/core/cp_response.cpp
  src/core/calculator_molecular_derivatives.cpp
  src/core/calculator_periodic_closed_form.cpp
  src/core/charges.cpp src/core/gamma.cpp src/core/scc.cpp src/core/periodic.cpp src/core/multipole.cpp src/core/frechet.cpp
  src/lab/hessian.cpp
)
add_library(dftb_torch_core STATIC ${DFTB_CORE_SOURCES})
set_target_properties(dftb_torch_core PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON)

# IPO is enabled only when the toolchain supports it.  DFTB_TORCH_IPO_SUPPORTED
# is also read later when configuring the _dftb_torch module.
if(DFTB_TORCH_ENABLE_IPO)
  include(CheckIPOSupported)
  check_ipo_supported(RESULT DFTB_TORCH_IPO_SUPPORTED OUTPUT DFTB_TORCH_IPO_MESSAGE)
  if(DFTB_TORCH_IPO_SUPPORTED)
    set_property(TARGET dftb_torch_core PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
  else()
    # Report the toolchain's reason so a build without IPO is explainable.
    message(STATUS "DFTB_TORCH_ENABLE_IPO=ON but IPO is not supported: ${DFTB_TORCH_IPO_MESSAGE}")
  endif()
endif()

target_include_directories(dftb_torch_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include ${TORCH_INCLUDE_DIRS})
target_link_libraries(dftb_torch_core PUBLIC ${TORCH_LIBRARIES} ZLIB::ZLIB)

# Warning flags apply only to the core sources (PRIVATE).
if(MSVC)
  target_compile_options(dftb_torch_core PRIVATE /W4)
else()
  target_compile_options(dftb_torch_core PRIVATE
    -Wall -Wextra -Wpedantic
    -Wno-unused-function
    -Wno-unused-parameter
    -Wno-unused-variable
    -Wno-unused-but-set-variable
  )
endif()
dftb_torch_apply_performance_options(dftb_torch_core)
# Python extension module.  When DFTB_TORCH_PYBIND11_EXTRAS is OFF the module is
# created with NO_EXTRAS, skipping pybind11's LTO/strip extras so the result
# stays debuggable; Torch-heavy translation units already dominate build time,
# so LTO mainly lengthens development builds.  Note that the default for that
# option resolves ON whenever DFTB_TORCH_FORCE_PERFORMANCE_FLAGS is ON (its
# default), so pass -DDFTB_TORCH_PYBIND11_EXTRAS=OFF to get the lean variant.
set(_dftb_torch_module_mode "")
if(NOT DFTB_TORCH_PYBIND11_EXTRAS)
  set(_dftb_torch_module_mode NO_EXTRAS)
endif()
pybind11_add_module(_dftb_torch ${_dftb_torch_module_mode}
  src/bindings/_module.cpp
  src/bindings/kernel_bindings.cpp
)

set_target_properties(_dftb_torch PROPERTIES
  CXX_STANDARD 17
  CXX_STANDARD_REQUIRED ON
)
target_include_directories(_dftb_torch PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/include
  ${TORCH_INCLUDE_DIRS}
)
target_link_libraries(_dftb_torch PRIVATE
  dftb_torch_core
  ${TORCH_LIBRARIES}
  ${TORCH_PYTHON_LIBRARY}
)
dftb_torch_apply_performance_options(_dftb_torch)

# Mirror the core library's IPO setting (DFTB_TORCH_IPO_SUPPORTED is only set
# when DFTB_TORCH_ENABLE_IPO triggered the support check).
if(DFTB_TORCH_ENABLE_IPO AND DFTB_TORCH_IPO_SUPPORTED)
  set_property(TARGET _dftb_torch PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
endif()

install(TARGETS _dftb_torch LIBRARY DESTINATION pydftb_torch)
