# Base: PyTorch 2.0.0 runtime image built against CUDA 11.7 and cuDNN 8
# (taken from the image tag).
FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime

# Python dependencies, each pinned to an exact version for reproducible
# builds and listed alphabetically so future diffs stay small.
# --no-cache-dir keeps pip's download cache out of this image layer.
RUN pip install --no-cache-dir \
        'azureml-mlflow==1.39.0.post1' \
        'func_to_script==0.1.0' \
        'matplotlib==3.5.3' \
        'mlflow-skinny==1.26.1' \
        'torchmetrics==0.11.4'

