#!/bin/bash
# ============================================================================
# NVCC Wrapper for Air.rs — v1.1.5
#
# Fixes applied transparently to every nvcc invocation (each one is skipped
# when the caller already passes the corresponding flag):
#   1. -fPIC         — required for shared-library (.so) builds on Linux
#   2. -arch=sm_XX   — compile for the *actual* installed GPU ISA instead of
#                      the old default nvcc uses without this flag (sm_52 on
#                      recent CUDA releases). Taken from $NVCC_ARCH when set,
#                      otherwise detected at build time via nvidia-smi.
#   3. -O3           — maximum host-compiler optimization level
#   4. --use_fast_math — hardware-accelerated transcendentals (exp, tanh, …),
#                      trading strict IEEE accuracy for speed
# ============================================================================

set -euo pipefail

# ── Scan the caller's arguments ───────────────────────────────────────────────
# Every argument is forwarded verbatim (in order) via ARGS; the HAS_* flags
# only record which of the injected options the caller already supplied, so
# the wrapper never appends a conflicting duplicate.
# nvcc accepts short and long spellings, and both "-opt=value" and
# "-opt value" forms, so all of them are matched here. Missing one (e.g.
# "-arch sm_80" or "--gpu-architecture=sm_86") would make the wrapper add a
# second -arch, which nvcc rejects as a redefinition of the argument.
ARGS=()
HAS_PIC=0
HAS_ARCH=0
HAS_OPT=0
HAS_FAST_MATH=0

for arg in "$@"; do
    ARGS+=("$arg")
    case "$arg" in
        -fPIC|-fpic)
            HAS_PIC=1 ;;
        -arch|-arch=*|--gpu-architecture|--gpu-architecture=*)
            HAS_ARCH=1 ;;
        -gencode*|--generate-code|--generate-code=*)
            HAS_ARCH=1 ;;
        -O*|--optimize|--optimize=*)
            HAS_OPT=1 ;;
        --use_fast_math|-use_fast_math)
            HAS_FAST_MATH=1 ;;
    esac
done

# ── Find the real nvcc (avoid recursion into this wrapper) ────────────────────
# Walk every nvcc on PATH in order and take the first one that is not this
# script. Identity is checked with `-ef` (same device and inode, symlinks
# followed), so the guard works wherever the repository is checked out and
# even when PATH reaches the wrapper through a symlink; the historical path
# pattern is kept as an extra guard.
# `type -aP` is a bash builtin; `which` may be absent in minimal containers.
WRAPPER_SELF="${BASH_SOURCE[0]}"
REAL_NVCC=""
while IFS= read -r candidate; do
    [[ -n "$candidate" ]] || continue
    # Skip this wrapper
    if [[ "$candidate" -ef "$WRAPPER_SELF" ]]; then
        continue
    fi
    case "$candidate" in
        */Air.rs/scripts/nvcc) continue ;;
    esac
    REAL_NVCC="$candidate"
    break
done < <(type -aP nvcc 2>/dev/null || true)

# Fallback: check common CUDA install paths (never the wrapper itself).
# An unset CUDA_HOME/CUDA_PATH yields "/bin/nvcc", which is simply probed.
if [[ -z "$REAL_NVCC" ]]; then
    for p in \
        "${CUDA_HOME:-}/bin/nvcc" \
        "${CUDA_PATH:-}/bin/nvcc" \
        "/usr/local/cuda/bin/nvcc" \
        "/usr/bin/nvcc"
    do
        if [[ -x "$p" && ! "$p" -ef "$WRAPPER_SELF" ]]; then
            REAL_NVCC="$p"
            break
        fi
    done
fi

if [[ -z "$REAL_NVCC" ]]; then
    echo "nvcc wrapper: real nvcc not found" >&2
    exit 1
fi

# ── 1. Position Independent Code ─────────────────────────────────────────────
# Linux shared objects need PIC host code. Match every Linux OSTYPE value
# ("linux-gnu", "linux-musl", "linux-android", ...), not only glibc systems.
# The flag is forwarded to the host compiler via -Xcompiler.
if [[ "$OSTYPE" == linux* && $HAS_PIC -eq 0 ]]; then
    ARGS+=("-Xcompiler" "-fPIC")
fi

# ── 2. GPU Architecture Targeting ────────────────────────────────────────────
# Only inject -arch if the caller hasn't already set one.
# Compiling for the installed GPU (e.g. sm_89 Ada Lovelace, sm_90 Hopper,
# sm_100 Blackwell) instead of nvcc's built-in default architecture gives
# access to newer ISA features and hardware-specific code paths.
if [[ $HAS_ARCH -eq 0 ]]; then
    if [[ -n "${NVCC_ARCH:-}" ]]; then
        # Strategy 1: honour env override (set by build_air.sh/ps1 or CI).
        # A bare compute capability such as "89" is expanded to "sm_89";
        # anything else (sm_XX, compute_XX, native, ...) is passed verbatim.
        if [[ "$NVCC_ARCH" =~ ^[0-9]+$ ]]; then
            ARGS+=("-arch=sm_${NVCC_ARCH}")
        else
            ARGS+=("-arch=${NVCC_ARCH}")
        fi
    else
        # Strategy 2: query nvidia-smi for the first GPU's compute capability.
        # The whole output is captured instead of piping into `head -1`: with
        # `pipefail`, head closing the pipe early could SIGPIPE nvidia-smi on
        # multi-GPU hosts and make the whole query look like a failure.
        if GPU_CAPS=$(nvidia-smi --query-gpu=compute_cap \
                      --format=csv,noheader 2>/dev/null); then
            COMPUTE_CAP=${GPU_CAPS%%$'\n'*}            # first GPU only
            COMPUTE_CAP=${COMPUTE_CAP//[[:space:]]/}   # drop spaces / CR
            COMPUTE_CAP=${COMPUTE_CAP//./}             # "8.9" -> "89"
            if [[ "$COMPUTE_CAP" =~ ^[0-9]+$ ]]; then
                ARGS+=("-arch=sm_${COMPUTE_CAP}")
            fi
        fi
        # If nvidia-smi is missing or fails (CI/Docker without GPU), no -arch
        # is added and nvcc uses its own default; the build still succeeds,
        # just without ISA-specific optimizations.
    fi
fi

# ── 3. Optimization flags ─────────────────────────────────────────────────────
# Append the default optimization flags only when the caller did not pass
# its own optimization level / fast-math switch.
if (( HAS_OPT == 0 )); then
    ARGS+=("-O3")
fi
if (( HAS_FAST_MATH == 0 )); then
    ARGS+=("--use_fast_math")
fi

# Replace this shell with the real compiler so its exit status and signals
# go straight back to the caller.
exec "$REAL_NVCC" "${ARGS[@]}"
