# Base image: the pytorch/pytorch "runtime" variant, pinned to an explicit tag
# (PyTorch 2.9.1, CUDA 12.6, cuDNN 9) rather than a floating one so rebuilds
# start from the same framework and CUDA versions.
FROM pytorch/pytorch:2.9.1-cuda12.6-cudnn9-runtime

# Install OS-level tooling: git, a C/C++ toolchain (gcc, g++) and wget.
# NOTE(review): the toolchain is presumably needed by pip packages that build
# native extensions -- confirm against requirements.txt.
# ca-certificates is listed explicitly because --no-install-recommends no longer
# pulls it in as a wget recommendation, and HTTPS downloads depend on it.
# The apt package lists are deleted in the same layer so they never reach the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        ca-certificates \
        g++ \
        gcc \
        git \
        wget \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Device the model is loaded onto, e.g. "cpu" or "cuda:0".
ENV JAILBREAK_CHECK_DEVICE="cuda:0"

# Bake the embedding-based jailbreak detection model into the image at build
# time and publish its directory through EMBEDDING_CLASSIFIER_PATH.
WORKDIR /models
RUN wget --output-document=snowflake.onnx \
    https://huggingface.co/nvidia/NemoGuard-JailbreakDetect/resolve/main/snowflake.onnx
ENV EMBEDDING_CLASSIFIER_PATH="/models"

# All application files live under /app.
WORKDIR /app

# Copy only the dependency manifest first, so the install layer below stays
# cached until requirements.txt itself changes.
COPY requirements.txt .

# Install the requirements for the jailbreak detection server.
# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r requirements.txt

# Predownload the GPT2 model and tokenizer into the image.
# This runs before the application source is copied so that editing the source
# does not invalidate this layer and force the model to be downloaded again.
RUN python -c "from transformers import GPT2LMHeadModel, GPT2TokenizerFast; GPT2LMHeadModel.from_pretrained('gpt2-large'); GPT2TokenizerFast.from_pretrained('gpt2-large');"

# Copy the application source code last: it is the most frequently changing input.
COPY . .

# Document the TCP port the server is started on by default.
# EXPOSE is metadata only; the port still has to be published at run time.
EXPOSE 1337/tcp

# ENTRYPOINT fixes the executable (the server script), while CMD supplies the
# default arguments, which can be overridden by passing arguments to `docker run`.
ENTRYPOINT ["python", "/app/server.py"]
CMD ["--port=1337"]
