# Core JAX ecosystem for diffusion-model computation (keep jax and jaxlib pinned to the same version)
jax==0.4.31
jaxlib==0.4.31
flax==0.8.5
optax==0.2.2
chex==0.1.86
jaxtyping==0.2.24

# Array operations and linear algebra
numpy==1.26.4
einops==0.8.0

# Visualization and plotting
matplotlib==3.8.2

# Progress bars and utilities
tqdm==4.66.1
jax-tqdm==0.1.2

# Testing framework
pytest

# Development and debugging
pdbp

# Experiment tracking for hyperparameter sweeps (optional at runtime, but installed unless removed from this list)
wandb

# Documentation dependencies (optional; uncomment the lines below to build the docs)
# sphinx
# myst-parser
# sphinx-rtd-theme