#!/bin/bash
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

# ================================================================================================ #
#   Globals
# ================================================================================================ #

# Name of this script as shown in the usage text
script_name=${0##*/}

# Command-line switches (1 = on, 0 = off)
cpu_bind=1
gpu_bind=1
use_xpu_smi=0
show_help=0

# Launch state: command prefix produced by cpu_binding, plus this process' node-local
# rank and the number of ranks on the node (filled in from the process manager)
prefix_command=""
local_id=0
local_size=0

# Lookup status of the device-query tools, refreshed by gpu_env (0 = tool found)
which_sycl_ls=1
which_clinfo=1
which_xpu_smi=1

# ================================================================================================ #
#   Functions
# ================================================================================================ #

# Print the command-line help text to stdout.
#   Globals read: script_name
usage()
{
    # The synopsis lists every short flag handled by parse_short_opt (including -x) and
    # the command that is launched once binding is set up
    cat << EOF
Usage: $script_name [-hcgnx] <command> [args...]

Options:
    -h, --help                 Show this message
    -c, --disable-cpu-bind     Don't perform CPU binding
    -g, --disable-gpu-bind     Don't perform GPU binding
    -n, --disable-bind         Don't perform CPU or GPU binding
    -x, --enable-xpu-smi       Use xpu-smi for CPU/GPU affinity checks
EOF
}

# Handle one long ("--name") option by updating the matching global flag.
#   $1: the option word
#   $2: the word that follows it on the command line (no current option uses it)
#   Returns 0 when only the option word itself was used; the caller shifts two words on a
#   non-zero return. An unrecognized option prints an error plus the usage text and exits 1.
parse_long_opt()
{
    local opt="$1"

    case "$opt" in
        --help)             show_help=1 ;;
        --disable-cpu-bind) cpu_bind=0 ;;
        --disable-gpu-bind) gpu_bind=0 ;;
        --disable-bind)     cpu_bind=0; gpu_bind=0 ;;
        --enable-xpu-smi)   use_xpu_smi=1 ;;
        *)
            echo "Error: Unknown option '$opt'" >&2
            usage
            exit 1
            ;;
    esac

    # None of the recognized options takes a value
    return 0
}

# Handle a cluster of short options (e.g. "cg" for "-cg"), one flag per character.
#   $1: the option characters with the leading '-' already removed
#   Globals written: show_help, cpu_bind, gpu_bind, use_xpu_smi
#   An unrecognized character prints an error plus the usage text and exits 1.
parse_short_opt()
{
    local arg="$1"
    # Cursor and current character stay local so same-named caller variables are not clobbered
    local i char

    for ((i = 0; i < ${#arg}; i++)); do
        char="${arg:$i:1}"

        case "$char" in
            h)
                show_help=1
                ;;
            c)
                cpu_bind=0
                ;;
            g)
                gpu_bind=0
                ;;
            n)
                cpu_bind=0
                gpu_bind=0
                ;;
            x)
                use_xpu_smi=1
                ;;
            *)
                # Report the offending flag itself (with its dash), not the whole cluster
                echo "Error: Unknown option '-$char'" >&2
                usage
                exit 1
                ;;
        esac
    done
}

# Build the numactl prefix that pins this process to its share of CPU cores.
#   Globals read:    local_id, local_size
#   Globals written: local_numa_id (NUMA node picked for this rank), prefix_command
#   Prints a warning on stderr and leaves prefix_command untouched when the NUMA layout
#   cannot be read, has more than 2 nodes, or no rank slot matches local_id.
cpu_binding()
{
    # Note: currently assumes round-robin core numbering
    local numactl_output cpu_numa_nodes
    numactl_output=$(numactl -H)
    cpu_numa_nodes=$(echo "$numactl_output" | grep -Po "^node [0-9]+ cpus: [0-9 ]+" | wc -l)

    # Validate the node count BEFORE it is used as a divisor; otherwise a failed numactl call
    # aborts with "division by 0" instead of reaching this warning
    if [ "$cpu_numa_nodes" -eq 0 ]; then
        echo -e "\033[33mWARNING: Could not process NUMA information. Consider running with '--disable-cpu-bind' and using external binding mechanisms to ensure best performance.\033[0m" >&2
        return
    fi

    if [ "$cpu_numa_nodes" -gt 2 ]; then
        echo -e "\033[33mWARNING: More than 2 NUMA nodes detected. Consider running with '--disable-cpu-bind' and using external binding mechanisms to ensure best performance.\033[0m" >&2
        return
    fi

    local numa_split=$((local_size / cpu_numa_nodes))
    local numa_remainder=$((local_size % cpu_numa_nodes))

    # Per-node CPU id list and CPU count, indexed by NUMA node number
    local -a node_cpus=()
    local -a node_core_counts=()
    local idx cpus
    for ((idx = 0; idx < cpu_numa_nodes; idx++)); do
        cpus=$(echo "$numactl_output" | grep -Po "node $idx cpus: \K[0-9 ]+")
        node_cpus[idx]="$cpus"
        # Unquoted on purpose: word-splitting lets wc count the CPU ids
        node_core_counts[idx]=$(echo $cpus | wc -w)
    done

    # Split procs across socket(s): every node gets numa_split ranks and the first
    # numa_remainder nodes get one extra
    local prev=0 add plow phigh found=0
    for ((idx = 0; idx < cpu_numa_nodes; idx++)); do
        add=0
        if ((numa_remainder > 0)); then
            add=1
            numa_remainder=$((numa_remainder - 1))
        fi
        if ((local_id < prev + numa_split + add)); then
            # local_numa_id is intentionally not local (the original note says GPU binding
            # shares it)
            local_numa_id=$idx
            plow=$prev
            phigh=$((prev + numa_split + add))
            found=1
            break
        fi
        prev=$((prev + numa_split + add))
    done

    # No slot exists when local_size is 0 (e.g. unsupported process manager) or when
    # local_id >= local_size; stop here rather than dividing by an empty rank range below
    if ((found == 0)); then
        echo -e "\033[33mWARNING: Could not map local rank '$local_id' (of $local_size) to a NUMA node. Skipping CPU binding.\033[0m" >&2
        return
    fi

    # Split cores across procs
    # procs with affinity to $local_numa_id are [$plow, $phigh)
    # skip first proc on each socket (reserved for OS)
    # skip last proc on each socket (reserved for proxy thread)
    # TODO: assumes maximum 2 sockets
    # NOTE(review): the trailing "/ 2" presumably halves the count for SMT siblings - confirm
    local num_skip=2
    local cores_per_proc=$(((node_core_counts[local_numa_id] - num_skip) / (phigh - plow) / 2))
    local core_low
    # Field index is 1-based and the first CPU is skipped, hence the "+ 2";
    # the list is left unquoted so cut sees single-space separated ids
    core_low=$(echo ${node_cpus[$local_numa_id]} | cut -d ' ' -f $((cores_per_proc * (local_id - plow) + 2)))
    local core_high=$((core_low + cores_per_proc - 1))
    local core_assignment="$core_low-$core_high"

    prefix_command="numactl --all -C ${core_assignment}"
}

# Pick the GPU (or GPU sub-device) for this rank and export it through ZE_AFFINITY_MASK.
#   Globals read: which_sycl_ls, which_clinfo, which_xpu_smi, cpu_bind, local_id, local_size
#   Exports:      ZE_AFFINITY_MASK, ONEAPI_DEVICE_SELECTOR (inherited values are cleared first)
#   Returns early with a warning on stderr when neither sycl-ls nor clinfo is available, or
#   when xpu-smi reports devices with differing tile counts.
gpu_binding()
{
    # Reset Level-zero/SYCL environment
    unset ZE_AFFINITY_MASK
    unset ONEAPI_DEVICE_SELECTOR
    unset SYCL_DEVICE_FILTER

    local root_count=0
    local sub_count=0
    local platforms platform root_dev sub_dev

    # Determine how to set ZE_AFFINITY_MASK
    if [ "$which_sycl_ls" -eq 0 ]; then
        root_count=$(ONEAPI_DEVICE_SELECTOR='level_zero:*' sycl-ls 2>/dev/null | wc -l)
        sub_count=$(ONEAPI_DEVICE_SELECTOR='level_zero:*.*' sycl-ls 2>/dev/null | wc -l)
    elif [ "$which_clinfo" -eq 0 ]; then
        platforms=$(clinfo -l | grep -i Platform | wc -l)

        # Stop at the first platform whose device 0 is a GPU
        for ((platform = 0; platform < platforms; platform++)); do
            if [ "$(clinfo -d "${platform}:0" --prop DEVICE_TYPE 2>/dev/null | grep -Po GPU | wc -l)" -gt 0 ]; then
                break
            fi
        done

        root_count=$(clinfo --prop DEVICE_TYPE 2>/dev/null | grep -Po GPU | wc -l)
        sub_count=$(clinfo -d "${platform}:0" --prop MAX_SUB_DEVICES 2>/dev/null | awk '{print $2}')
    else
        # Can't use xpu-smi here because we can't detect the format for ZE_AFFINITY_MASK
        echo -e "\033[33mWARNING: Could not detect devices on system. Ensure either clinfo or sycl-ls is available.\033[0m" >&2
        return
    fi

    # A tool that printed nothing counts as "none found" instead of breaking the tests below
    root_count=${root_count:-0}
    sub_count=${sub_count:-0}

    if [ "$cpu_bind" -eq 0 ]; then
        # Get the local numa node in case cpu binding is disabled
        # NOTE(review): this local copy is not read again inside this function
        local local_numa_id
        local_numa_id=$(numactl --show | grep -Po "(?<=^nodebind: ).*" | awk '{print $NF}')
    fi

    # Every selectable device, as "<root>" or "<root>.<sub>" affinity-mask entries
    local -a all_devs=()
    if [ "$sub_count" -eq 0 ]; then
        for ((root_dev = 0; root_dev < root_count; root_dev++)); do
            all_devs+=("${root_dev}")
        done
    else
        # Current sub_count is total subdevices in the node - fix it to be per device
        sub_count=$((sub_count / root_count))
        for ((root_dev = 0; root_dev < root_count; root_dev++)); do
            for ((sub_dev = 0; sub_dev < sub_count; sub_dev++)); do
                all_devs+=("${root_dev}.${sub_dev}")
            done
        done
    fi

    local idx=0
    local leftover_devs_per_numa=$(((${#all_devs[@]} - local_size) / 2))
    local s0_count idx_offset mask

    # Use xpu-smi to match GPU and CPU affinity
    if [ "$which_xpu_smi" -eq 0 ]; then
        # xpu-smi will always provide hierarchy info regarding root and sub devices, thus
        # xpu_root_count may not equal root_count (i.e. when ZE_FLAT_DEVICE_HIERARCHY=FLAT)
        local xpu_root_count xpu_sub_count dev_sub_count
        # Pattern is quoted so the shell cannot glob-expand it against files in the cwd
        xpu_root_count=$(xpu-smi discovery --dump 1 2>/dev/null | grep -P '[0-9]+' | wc -l)
        xpu_sub_count=$(xpu-smi discovery -d 0 2>/dev/null | grep -Po "(?<=Number of Tiles: )[0-9]+")

        # Check for homogeneity: device 0 is the reference, so compare devices 1..N-1 against
        # it. A dedicated variable is used so sub_count is not overwritten.
        for ((root_dev = 1; root_dev < xpu_root_count; root_dev++)); do
            dev_sub_count=$(xpu-smi discovery -d "${root_dev}" 2>/dev/null | grep -Po "(?<=Number of Tiles: )[0-9]+")
            if [ "$dev_sub_count" != "$xpu_sub_count" ]; then
                echo -e "\033[33mWARNING: Detected different accelerator architectures. Consider running with '--disable-gpu-bind' and using external binding mechanisms to ensure best performance.\033[0m" >&2
                return
            fi
        done

        # Devices local to NUMA node 0 go to s0_devs, all others to s1_devs
        local -a s0_devs=()
        local -a s1_devs=()
        local dev_cpu_list dev_numa
        for ((root_dev = 0; root_dev < xpu_root_count; root_dev++)); do
            dev_cpu_list=$(xpu-smi topology -d "${root_dev}" 2>/dev/null | grep -Po "(?<=Local CPU List: )[0-9,-]+")
            dev_numa=$(numactl --all -C "${dev_cpu_list}" --show | grep -Po "(?<=^nodebind: ).*" | awk '{print $NF}')
            for ((sub_dev = 0; sub_dev < xpu_sub_count; sub_dev++)); do
                if [ "$dev_numa" = "0" ]; then
                    s0_devs+=("${all_devs[idx]}")
                else
                    s1_devs+=("${all_devs[idx]}")
                fi
                idx=$((idx + 1))
            done
        done

        s0_count=${#s0_devs[@]}
        idx_offset=$((s0_count - leftover_devs_per_numa))
        if ((local_id < idx_offset)); then
            mask=${s0_devs[local_id]}
        else
            mask=${s1_devs[local_id - idx_offset]}
        fi
    else
        # Without topology info, treat the first half of the device list as socket 0
        s0_count=$((${#all_devs[@]} / 2))
        idx_offset=$((s0_count - leftover_devs_per_numa))
        if ((local_id < idx_offset)); then
            mask=${all_devs[local_id]}
        else
            mask=${all_devs[s0_count + local_id - idx_offset]}
        fi
    fi

    export ZE_AFFINITY_MASK="$mask"
    export ONEAPI_DEVICE_SELECTOR="level_zero:0"
}

# Export GPU driver settings and record which device-query tools can be run.
#   Globals read:    use_xpu_smi
#   Globals written: which_clinfo, which_sycl_ls, which_xpu_smi (0 = tool available)
#   Exports:         NEOReadDebugKeys, ZE_ENABLE_PCI_ID_DEVICE_ORDER and, depending on the
#                    detected device name, RenderCompressedBuffersEnabled or
#                    EnableImplicitScaling
gpu_env()
{
    export NEOReadDebugKeys=1
    export ZE_ENABLE_PCI_ID_DEVICE_ORDER=1

    # Determine which utilities are available for device detection and topology.
    # 'command -v' is a shell builtin, so detection does not depend on an external 'which'.
    command -v clinfo > /dev/null 2>&1
    which_clinfo=$?
    command -v sycl-ls > /dev/null 2>&1
    which_sycl_ls=$?
    if [ "$use_xpu_smi" -eq 1 ]; then
        command -v xpu-smi > /dev/null 2>&1
        which_xpu_smi=$?
    else
        # xpu-smi is only consulted when explicitly requested
        which_xpu_smi=1
    fi

    # Device description text from the first available tool
    local output=""
    local platforms platform

    if [ "$which_sycl_ls" -eq 0 ]; then
        output=$(ONEAPI_DEVICE_SELECTOR='level_zero:*' sycl-ls 2>/dev/null)
    elif [ "$which_clinfo" -eq 0 ]; then
        platforms=$(clinfo -l | grep -ci Platform)
        # Use the name of device 0 on the first platform that reports a GPU
        for ((platform = 0; platform < platforms; platform++)); do
            if [ "$(clinfo -d "${platform}:0" --prop DEVICE_TYPE 2>/dev/null | grep -Po GPU | wc -l)" -gt 0 ]; then
                output=$(clinfo -d "${platform}:0" --prop DEVICE_NAME 2>/dev/null)
                break
            fi
        done
    elif [ "$which_xpu_smi" -eq 0 ]; then
        # Property ID 2 is Device Name
        output=$(xpu-smi discovery --dump 2 2>/dev/null)
    fi

    # Match inside the shell: the tool output is never word-split or glob-expanded
    # (sycl-ls prints bracketed tokens such as "[level_zero:gpu]")
    local arc_b_series='Intel.*Arc.*B[0-9]+[[:space:]]+Graphics'
    local data_center='Intel.*Data.*Center.*GPU'

    if [[ $output =~ $arc_b_series ]]; then
        # Intel(R) Arc(TM) B-Series GPU Family
        # Necessary for GPU IPC
        export RenderCompressedBuffersEnabled=0
    elif [[ $output =~ $data_center ]]; then
        # Intel(R) Data Center GPU Max Series
        # No support for Implicit Scaling
        export EnableImplicitScaling=0
    fi
}

# ================================================================================================ #
#   Main script
# ================================================================================================ #

# Parse input arguments: consume leading options and stop at the first word that is not an
# option; everything left in "$@" afterwards is the command to launch
while [ $# -gt 0 ]; do
    case "$1" in
        # End of options marker
        # Must be tested before the '--*' and '-*' patterns, which would otherwise match it
        # and reject '--' as an unknown option
        --)
            shift
            break
            ;;

        # Long options
        --*)
            # A non-zero return means the option also consumed the following word
            if parse_long_opt "$1" "$2"; then
                shift 1
            else
                shift 2
            fi
            ;;

        # Short options (possibly clustered, e.g. -cg)
        -*)
            parse_short_opt "${1#-}"
            shift
            ;;

        # Non-option args
        *)
            break
            ;;
    esac
done

# Help takes effect only after all options were parsed, so unknown options still error out first
if [ "$show_help" -eq 1 ]; then
    usage
    exit 0
fi

# Detect process manager from the name of the parent process and read this process'
# node-local rank and rank count from the environment variables that manager provides
parent_process=$(ps -p "$PPID" -o comm=)

# Comparisons are quoted so an empty or multi-word process name cannot break the test
if [ "$parent_process" = "slurmstepd" ]; then
    local_id=$SLURM_LOCALID
    # Keep only the text before the first '(' (e.g. "4(x2)" -> "4")
    local_size=${SLURM_STEP_TASKS_PER_NODE%%(*}
elif [ "$parent_process" = "hydra_pmi_proxy" ]; then
    local_id=$MPI_LOCALRANKID
    local_size=$MPI_LOCALNRANKS
elif [ "$parent_process" = "palsd" ]; then
    local_id=$PALS_LOCAL_RANKID
    local_size=$PALS_LOCAL_SIZE
else
    # Only worth a warning when some binding was actually requested
    if [ "$cpu_bind" -eq 1 ] || [ "$gpu_bind" -eq 1 ]; then
        echo -e "\033[33mWARNING: Process not launched with a supported process manager.\033[0m" >&2
    fi
fi

# Perform binding (if applicable)
if [ "$cpu_bind" -eq 1 ]; then
    cpu_binding
fi

# Even if gpu binding is disabled, some environment variables may need to be set for specific GPUs
gpu_env

if [ "$gpu_bind" -eq 1 ]; then
    gpu_binding
fi

# Invoke the main program
# prefix_command is deliberately unquoted so it splits into its numactl words (or vanishes
# when empty); "$@" passes every user argument through as its own word, so arguments
# containing spaces or glob characters reach the program unchanged
${prefix_command} "$@"
