# syntax=docker/dockerfile:1
# Seamstress Dockerfile (CPU/GPU)
#
# Build args:
#   BASE_IMAGE  - base image to use (e.g., nvidia/cuda:12.8.0-devel-ubuntu22.04)
#   USE_CUDA    - "1" to install CUDA extras (torch + jax-cuda + flash-attn), "0" for CPU-only
#
# Notes:
# - GPU builds require a CUDA *devel* image with nvcc.
# - flash-attn requires --no-build-isolation and build tools (hatchling/editables).

ARG BASE_IMAGE=ubuntu:22.04
FROM ${BASE_IMAGE}

# pipefail so `cmd | tee log` pipelines fail the build when cmd fails (DL4006).
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

ARG USE_CUDA=0
ARG RUN_GIT_LFS=0
# Limit build parallelism to avoid RAM spikes during flash-attn compile.
ARG MAX_JOBS=8
# ARG (not ENV): keeps apt noninteractive for every RUN in this build without
# leaking the setting into containers started from the final image.
ARG DEBIAN_FRONTEND=noninteractive

WORKDIR /workspace

# System dependencies (mirrors README Linux/Ubuntu setup).
# --no-install-recommends keeps the image lean (hadolint DL3015);
# ca-certificates is listed explicitly because it is otherwise only a
# "recommends" of curl's libcurl and HTTPS downloads below would fail.
# apt list cleanup happens in the same layer so the cache never lands in a layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.11 python3.11-venv python3.11-dev \
    build-essential ca-certificates git git-lfs curl zsh gh \
    libegl1 libgles2 libgl1-mesa-dev ffmpeg libosmesa6 libosmesa6-dev \
    libglfw3 libglfw3-dev libglew-dev \
    && rm -rf /var/lib/apt/lists/*

# Install uv (the installer script drops the binary into ~/.local/bin or
# ~/.cargo/bin depending on its version, so both directories go on PATH).
RUN curl -LsSf https://astral.sh/uv/install.sh | sh

# Runtime environment, grouped in a single instruction.
# /opt/venv/bin is first on PATH so bare `python`/`pip` resolve to the venv.
ENV PATH="/opt/venv/bin:/root/.cargo/bin:/root/.local/bin:${PATH}" \
    CUDA_HOME="/usr/local/cuda" \
    VIRTUAL_ENV="/opt/venv" \
    XDG_RUNTIME_DIR="/tmp/seamstress-xdg-runtime" \
    MAX_JOBS="${MAX_JOBS}" \
    CMAKE_BUILD_PARALLEL_LEVEL="${MAX_JOBS}"

# Install the starship prompt and hook it into both interactive shells,
# all in one layer since the steps are logically a single setup operation.
RUN curl -LsSf https://starship.rs/install.sh | sh -s -- -y \
 && echo 'eval "$(starship init bash)"' >> /root/.bashrc \
 && echo 'eval "$(starship init zsh)"' >> /root/.zshrc

# Copy repo into the image (starship config is read from the checkout).
# NOTE(review): `COPY . .` this early invalidates every later layer on any
# source change, and pulls in whatever the build context holds — confirm a
# .dockerignore excludes .git, build output, and local env files.
COPY . /workspace
ENV STARSHIP_CONFIG="/workspace/starship.toml"


# Create venv outside the bind mount so it isn't hidden by docker-compose
RUN python3.11 -m venv /opt/venv

# Ensure pip is available in the venv
RUN /opt/venv/bin/python -m ensurepip --upgrade && /opt/venv/bin/python -m pip --version

# Headless rendering runtime dir (wgpu expects XDG_RUNTIME_DIR to exist)
RUN mkdir -p /tmp/seamstress-xdg-runtime && chmod 700 /tmp/seamstress-xdg-runtime

# Build tools for editable installs with no-build-isolation.
# --no-cache-dir keeps pip's wheel cache out of the image layer (hadolint DL3042).
RUN /opt/venv/bin/python -m pip install --no-cache-dir hatchling editables setuptools wheel packaging psutil ninja

# Sync dependencies (stage torch first so flash-attn can import it during build)
# Limit build parallelism to avoid OOM on lower-RAM hosts (per flash-attn docs)
#
# The CUDA-branch steps are chained with `&&` (not `;`): with `;` a failed
# `uv sync` or torch staging step would be masked by a later succeeding
# command, because the RUN's exit status is only that of the last statement.
# The `| tee` pipelines still propagate failures thanks to the pipefail SHELL.
RUN mkdir -p /workspace/build-logs && \
    if [ "${USE_CUDA}" = "1" ]; then \
      uv sync --active --extra torch --extra jax-cuda \
        --extra-index-url https://download.pytorch.org/whl/cu128 \
        2>&1 | tee /workspace/build-logs/uv_sync_torch_jax.log && \
      /opt/venv/bin/python -m pip install --no-cache-dir psutil packaging ninja setuptools wheel \
        2>&1 | tee /workspace/build-logs/uv_pip_build_deps.log && \
      /opt/venv/bin/python -m pip install --no-cache-dir flash-attn --no-build-isolation \
        2>&1 | tee /workspace/build-logs/pip_flash_attn.log ; \
    else \
      uv sync --active --extra torch \
        2>&1 | tee /workspace/build-logs/uv_sync_torch.log ; \
    fi

# Fail the image build immediately if the learning stack is not present:
# importing both packages and printing their versions in a single -c one-liner.
RUN /opt/venv/bin/python -c 'import pytorch_lightning; import torch; print("torch", torch.__version__); print("pytorch_lightning", pytorch_lightning.__version__)'

# Interactive dev shell by default (exec form; override at `docker run`).
# NOTE(review): no USER directive — the container runs as root. Fine for a
# dev image; add a non-root user if this is ever deployed as a service.
CMD ["bash"]
