# Always first: cmake_minimum_required sets policy defaults. The range form
# requires 3.16 but opts into NEW-policy behavior up to 3.29 when a newer
# CMake is used, without breaking older installs.
cmake_minimum_required(VERSION 3.16...3.29)

# XTransformers - High-Performance Kernel for Xoron Multimodal Model
# Supports: NVIDIA CUDA, AMD ROCm, Intel oneAPI, Apple Metal, CPU (AVX2/AVX512/AMX)

project(xtransformers VERSION 1.0.0 LANGUAGES CXX)

# --- Hardware backend options ---------------------------------------------
# Each backend is opt-in; its toolchain is probed only when the option is ON.
option(XT_USE_CUDA "Enable NVIDIA CUDA support" OFF)
option(XT_USE_ROCM "Enable AMD ROCm support" OFF)
option(XT_USE_ONEAPI "Enable Intel oneAPI support" OFF)
# NOTE(review): XT_USE_METAL and XT_USE_TRITON only appear in the summary
# printout at the bottom of this file -- no Metal/Triton build logic is
# visible here; confirm they are consumed elsewhere (e.g. Python packaging).
option(XT_USE_METAL "Enable Apple Metal support" OFF)
option(XT_USE_TRITON "Enable Triton JIT kernels" ON)
option(XT_BUILD_TESTS "Build unit tests" ON)
# NOTE(review): XT_BUILD_EXAMPLES is declared but no examples directory is
# added in this file -- verify it is used elsewhere.
option(XT_BUILD_EXAMPLES "Build examples" ON)

# --- CPU instruction-set options (auto-detected by default) ----------------
# All OFF by default; cmake/DetectCPU.cmake (included below when XT_NATIVE
# is OFF) is presumably responsible for turning the right ones on.
option(XT_NATIVE "Enable -march=native" OFF)
option(XT_AVX2 "Enable AVX2" OFF)
option(XT_AVX512 "Enable AVX512" OFF)
option(XT_AVX512_VNNI "Enable AVX512-VNNI" OFF)
option(XT_AVX512_BF16 "Enable AVX512-BF16" OFF)
option(XT_AVX512_VBMI "Enable AVX512-VBMI" OFF)
option(XT_AMX "Enable Intel AMX" OFF)
option(XT_ARM_NEON "Enable ARM NEON" OFF)
option(XT_ARM_SVE "Enable ARM SVE" OFF)

# --- Feature options -------------------------------------------------------
# These map 1:1 onto compile definitions of the same name (see the
# target_compile_definitions section near the end of the file).
option(XT_GGUF_SUPPORT "Enable GGUF model format support" ON)
option(XT_EXPERT_OFFLOAD "Enable expert offloading to CPU" ON)
option(XT_KV_CACHE_COMPRESSION "Enable KV cache compression (MLA)" ON)

# JIT compilation and runtime hardware-feature dispatch.
option(XT_ENABLE_JIT "Enable JIT compilation for kernels" ON)
option(XT_RUNTIME_DISPATCH "Enable runtime CPU/GPU feature dispatch" ON)

# Link Time Optimization (applied via pybind11's THIN_LTO path below).
option(XT_ENABLE_LTO "Enable Link Time Optimization" OFF)

# C++20 required; PIC is mandatory for a shared Python extension module;
# compile_commands.json is exported for tooling (clangd, clang-tidy).
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Use fmt in header-only mode so the extension does not need to link libfmt.
add_compile_definitions(FMT_HEADER_ONLY)

# Auto-detect CPU features when the user has not opted into -march=native
# (in that case the single native flag below subsumes per-feature detection).
if(NOT XT_NATIVE)
    include(cmake/DetectCPU.cmake)
endif()

# Optimization flags.
# -ffast-math is NOT IEEE-754 conformant (reassociation, assumes no
# NaN/Inf); gate it behind an option so numerically sensitive builds can
# opt out. Default ON preserves the previous behavior.
# NOTE(review): these are GCC/Clang spellings; MSVC would reject them --
# confirm MSVC is out of scope for this project.
option(XT_FAST_MATH "Enable -ffast-math (non-IEEE-754-conformant optimizations)" ON)
if(XT_FAST_MATH)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
else()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
endif()

# Architecture specific flags (GCC/Clang spellings).
message(STATUS "System Processor: ${CMAKE_SYSTEM_PROCESSOR}")
set(ARCH_FLAGS "")

if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")
    message(STATUS "ARM architecture detected")
    # Emit a single -march= value. The previous code appended one -march=
    # per option, and when both NEON and SVE were enabled only the last
    # flag took effect. armv8.2-a+sve implies Advanced SIMD (NEON), so SVE
    # wins when both are requested.
    if(XT_ARM_SVE)
        set(ARCH_FLAGS "-march=armv8.2-a+sve")
    elseif(XT_ARM_NEON)
        set(ARCH_FLAGS "-march=armv8-a+simd")
    endif()
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64)$")
    message(STATUS "x86_64 architecture detected")
    if(XT_NATIVE)
        # Native overrides all per-feature selections below.
        set(ARCH_FLAGS "-march=native")
    else()
        # Per-feature flags accumulate; XT_AVX512* options assume the base
        # XT_AVX512 set is also enabled (DetectCPU.cmake presumably
        # guarantees this -- TODO confirm).
        if(XT_AVX2)
            set(ARCH_FLAGS "${ARCH_FLAGS} -mavx2 -mfma -mf16c")
        endif()
        if(XT_AVX512)
            set(ARCH_FLAGS "${ARCH_FLAGS} -mavx512f -mavx512bw -mavx512dq -mavx512vl")
        endif()
        if(XT_AVX512_VNNI)
            set(ARCH_FLAGS "${ARCH_FLAGS} -mavx512vnni")
        endif()
        if(XT_AVX512_BF16)
            set(ARCH_FLAGS "${ARCH_FLAGS} -mavx512bf16")
        endif()
        if(XT_AVX512_VBMI)
            set(ARCH_FLAGS "${ARCH_FLAGS} -mavx512vbmi -mavx512vbmi2")
        endif()
        if(XT_AMX)
            set(ARCH_FLAGS "${ARCH_FLAGS} -mamx-tile -mamx-int8 -mamx-bf16")
        endif()
    endif()
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_FLAGS}")
message(STATUS "Architecture flags: ${ARCH_FLAGS}")

# Required packages: OpenMP for CPU parallelism, Python for the extension.
find_package(OpenMP REQUIRED)
# NOTE(review): for an extension module, COMPONENTS "Development.Module"
# (CMake >= 3.18) avoids requiring a full libpython development install --
# worth switching once the minimum CMake version allows it.
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# pybind11: prefer the vendored copy when present, else a system install.
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/third_party/pybind11/CMakeLists.txt")
    add_subdirectory(third_party/pybind11)
else()
    find_package(pybind11 REQUIRED)
endif()

# hwloc for NUMA topology awareness (optional; linked only if found below).
find_package(PkgConfig)
if(PKG_CONFIG_FOUND)
    pkg_search_module(HWLOC IMPORTED_TARGET hwloc)
endif()

# libnuma (optional; linked only if found below).
find_library(NUMA_LIBRARY NAMES numa)

# CUDA backend
if(XT_USE_CUDA)
    # CMAKE_CUDA_ARCHITECTURES must be set BEFORE enable_language(CUDA):
    # it is consumed during CUDA compiler detection (policy CMP0104), so
    # assigning it afterwards -- as the previous code did -- has no effect
    # on that initial check and can make enable_language fail outright.
    if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
        # Volta through Hopper: wide coverage of datacenter/consumer GPUs.
        set(CMAKE_CUDA_ARCHITECTURES "70;75;80;86;89;90" CACHE STRING "CUDA architectures")
    endif()
    enable_language(CUDA)
    find_package(CUDAToolkit REQUIRED)
    set(CMAKE_CUDA_STANDARD 17)
    set(CMAKE_CUDA_STANDARD_REQUIRED ON)

    message(STATUS "CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
    add_compile_definitions(XT_CUDA_ENABLED)
endif()

# ROCm backend
if(XT_USE_ROCM)
    # $ENV{ROCM_PATH} expands to "" when the variable is unset, which the
    # previous code appended as an empty CMAKE_PREFIX_PATH entry. Guard the
    # read and fall back to the standard ROCm install prefix.
    if(DEFINED ENV{ROCM_PATH})
        list(APPEND CMAKE_PREFIX_PATH "$ENV{ROCM_PATH}")
    else()
        list(APPEND CMAKE_PREFIX_PATH "/opt/rocm")
    endif()
    find_package(hip REQUIRED)
    find_package(rocblas REQUIRED)
    add_compile_definitions(XT_ROCM_ENABLED USE_HIP=1)
endif()

# Intel oneAPI backend
if(XT_USE_ONEAPI)
    # NOTE(review): the IntelDPCPP config module is deprecated upstream in
    # favor of IntelSYCL -- confirm which module the targeted oneAPI
    # release ships before relying on this probe.
    find_package(IntelDPCPP QUIET)
    if(IntelDPCPP_FOUND)
        add_compile_definitions(XT_ONEAPI_ENABLED)
    else()
        # Soft failure by design: oneAPI absence downgrades to CPU-only.
        message(WARNING "Intel oneAPI requested but not found")
    endif()
endif()

# Collect source files.
# CONFIGURE_DEPENDS (CMake >= 3.12, within this project's minimum) makes
# the build re-check the glob so newly added files trigger a re-configure
# instead of being silently omitted. An explicit source list would still be
# more reliable across generators; consider migrating.
file(GLOB_RECURSE CPU_SOURCES CONFIGURE_DEPENDS
    "xtransformers_kernel/cpu_backend/*.cpp"
    "xtransformers_kernel/cpu_backend/*.h"
)

file(GLOB_RECURSE OPERATOR_SOURCES CONFIGURE_DEPENDS
    "xtransformers_kernel/operators/*.cpp"
    "xtransformers_kernel/operators/*.hpp"
)

# NOTE(review): PYTHON_SOURCES is collected but never referenced in this
# file -- confirm it is consumed elsewhere (e.g. packaging/install logic).
file(GLOB_RECURSE PYTHON_SOURCES CONFIGURE_DEPENDS
    "xtransformers_kernel/python/*.py"
)

set(ALL_SOURCES
    ${CPU_SOURCES}
    ${OPERATOR_SOURCES}
    xtransformers_kernel/ext_bindings.cpp
)

# CUDA sources (compiled only when the CUDA backend is enabled above).
if(XT_USE_CUDA)
    file(GLOB_RECURSE CUDA_SOURCES CONFIGURE_DEPENDS
        "xtransformers_kernel/cuda/*.cu"
        "xtransformers_kernel/cuda/*.cpp"
    )
    list(APPEND ALL_SOURCES ${CUDA_SOURCES})
endif()

# Build the Python extension module via pybind11.
# When LTO is requested we enable IPO for every target configured after
# this point (this also covers the tests subdirectory added later) and pass
# THIN_LTO through to pybind11; otherwise the helper variable expands to
# nothing and the plain form is used.
if(XT_ENABLE_LTO)
    set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON)
    set(_xt_lto_arg THIN_LTO)
else()
    set(_xt_lto_arg "")
endif()
pybind11_add_module(xtransformers_ext MODULE ${_xt_lto_arg} ${ALL_SOURCES})

# Private include paths: the kernel tree itself plus vendored third-party headers.
target_include_directories(xtransformers_ext PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/xtransformers_kernel
    ${CMAKE_CURRENT_SOURCE_DIR}/third_party
)

# Link libraries (all PRIVATE: nothing consumes this module at link time).
target_link_libraries(xtransformers_ext PRIVATE OpenMP::OpenMP_CXX)

# Optional NUMA-awareness libraries, linked only when detected above.
if(HWLOC_FOUND)
    target_link_libraries(xtransformers_ext PRIVATE PkgConfig::HWLOC)
endif()

if(NUMA_LIBRARY)
    target_link_libraries(xtransformers_ext PRIVATE ${NUMA_LIBRARY})
endif()

if(XT_USE_CUDA)
    target_link_libraries(xtransformers_ext PRIVATE CUDA::cudart)
    # cuBLAS is not present in every CUDAToolkit install; link when found.
    if(TARGET CUDA::cublas)
        target_link_libraries(xtransformers_ext PRIVATE CUDA::cublas)
    endif()
endif()

if(XT_USE_ROCM)
    # rocBLAS exports the namespaced imported target roc::rocblas; linking
    # the bare name "rocblas" bypassed its usage requirements (include
    # dirs, transitive deps). Keep the plain name as a fallback for older
    # ROCm releases that do not export the namespaced target.
    if(TARGET roc::rocblas)
        target_link_libraries(xtransformers_ext PRIVATE hip::host roc::rocblas)
    else()
        target_link_libraries(xtransformers_ext PRIVATE hip::host rocblas)
    endif()
endif()

# Map configure-time feature options onto compile-time macros of the same
# name. $<$<BOOL:opt>:MACRO> emits the macro only when the option is truthy,
# which is exactly equivalent to the per-option if() blocks it replaces.
target_compile_definitions(xtransformers_ext PRIVATE
    $<$<BOOL:${XT_GGUF_SUPPORT}>:XT_GGUF_SUPPORT>
    $<$<BOOL:${XT_EXPERT_OFFLOAD}>:XT_EXPERT_OFFLOAD>
    $<$<BOOL:${XT_KV_CACHE_COMPRESSION}>:XT_KV_CACHE_COMPRESSION>
    $<$<BOOL:${XT_ENABLE_JIT}>:XT_ENABLE_JIT>
    $<$<BOOL:${XT_RUNTIME_DISPATCH}>:XT_RUNTIME_DISPATCH>
)

# Tests (CTest); the tests/ subdirectory defines the individual test targets.
if(XT_BUILD_TESTS)
    enable_testing()
    add_subdirectory(tests)
endif()

# Install the built extension into the "xtransformers" package directory.
# NOTE(review): the destination is relative to CMAKE_INSTALL_PREFIX --
# confirm this matches how the Python packaging step locates the module.
install(TARGETS xtransformers_ext
    LIBRARY DESTINATION xtransformers
)

# Configuration summary, printed at the end of every configure run so the
# chosen backends and CPU features are visible in the configure log.
message(STATUS "")
message(STATUS "========================================")
message(STATUS "XTransformers Build Configuration")
message(STATUS "========================================")
message(STATUS "CUDA:        ${XT_USE_CUDA}")
message(STATUS "ROCm:        ${XT_USE_ROCM}")
message(STATUS "oneAPI:      ${XT_USE_ONEAPI}")
message(STATUS "Metal:       ${XT_USE_METAL}")
message(STATUS "Triton:      ${XT_USE_TRITON}")
message(STATUS "AVX2:        ${XT_AVX2}")
message(STATUS "AVX512:      ${XT_AVX512}")
message(STATUS "AMX:         ${XT_AMX}")
message(STATUS "GGUF:        ${XT_GGUF_SUPPORT}")
message(STATUS "Expert Offload: ${XT_EXPERT_OFFLOAD}")
message(STATUS "JIT:         ${XT_ENABLE_JIT}")
message(STATUS "Runtime Dispatch: ${XT_RUNTIME_DISPATCH}")
message(STATUS "========================================")
message(STATUS "")