#!/usr/bin/env bash
#
# direnv environment for this project: detects the CUDA version and an X
# display, runs the `dev` just recipe through uv, activates ./.venv and
# watches pyproject.toml for changes.
#
# Strict mode: abort on unhandled errors and on unset variables, and make a
# pipeline fail when any of its stages fails.
set -o errexit -o nounset -o pipefail

# info MESSAGE...
#   Emit a status message through direnv's log_status, overriding
#   DIRENV_LOG_FORMAT for this single call only so the message is rendered
#   with the "envrc: %s" template. All arguments are forwarded unchanged.
info() {
  DIRENV_LOG_FORMAT='envrc: %s' \
    log_status "$@"
}

# Detect the CUDA major version from `nvidia-smi --version` and export it as
# CUDA_VERSION. NOTE(review): presumably consumed by the `just dev` recipe or
# the build; confirm.
#
# The probe is best-effort. nvidia-smi can be present but fail: the driver may
# not be loaded, or an older release may reject `--version`. In that case the
# previous bare `var=$(nvidia-smi … | awk …)` aborted the whole .envrc under
# `set -euo pipefail`. Now it just leaves CUDA_VERSION unset.
if nvidia_smi="$(type -p "nvidia-smi")" && [[ -n "$nvidia_smi" ]]; then
  cuda_major_version=""
  # Status is checked by the `if`, so a failing nvidia-smi cannot trip errexit.
  if nvidia_smi_version="$("$nvidia_smi" --version 2>/dev/null)"; then
    # Expected line shape: "CUDA Version <spaces>: <version>". Keep only the
    # integer major part, and skip non-numeric values (int() would give 0).
    cuda_major_version=$(awk '
      BEGIN { FS = "[[:space:]]*:[[:space:]]*" }
      /^CUDA Version/ { major = int($2); if (major > 0) print major; exit }
    ' <<<"$nvidia_smi_version")
  fi
  if [[ -n "$cuda_major_version" ]]; then
    info "Detected CUDA version: $cuda_major_version"
    export CUDA_VERSION="$cuda_major_version"
  else
    info "nvidia-smi did not report a CUDA version, leaving CUDA_VERSION unset"
  fi
else
  info "Did not detect nvidia-smi, using non-GPU environment"
fi

# If xhost is installed, grant the current local user access to the X server
# (+si:localuser:<user>). That call doubles as the display probe: only when it
# succeeds is XLA_PYTHON_CLIENT_MEM_FRACTION exported. A missing xhost or a
# failing call (e.g. no reachable display) is silently tolerated.
if xhost_path="$(type -p "xhost")" && [[ -n "$xhost_path" ]]; then
  if "$xhost_path" +si:localuser:"$(whoami)" >/dev/null 2>&1; then
    info "Display present, setting XLA_PYTHON_CLIENT_MEM_FRACTION to 80%"
    export XLA_PYTHON_CLIENT_MEM_FRACTION=".80"
  fi
fi

# NOTE(review): uv documents UV_TORCH_BACKEND for `uv pip` commands; confirm it
# actually influences the `uv run` sync below.
export UV_TORCH_BACKEND="auto"

# Register watches *before* anything that can fail. Under `set -e` a failing
# `uv run` aborts the script, and the old trailing `watch_file` was never
# reached. Fixing pyproject.toml then did not trigger a reload. uv.lock is
# watched too because `uv run` syncs the environment from it. Because direnv
# records the mtime at this point, a re-lock by `uv run` may cause one extra
# (no-op) reload.
watch_file pyproject.toml uv.lock

# Run the project's `dev` just recipe through uv (with the dev dependency
# group), then activate the project virtualenv for this shell.
uv run --dev just dev
source ./.venv/bin/activate
