Metadata-Version: 2.4
Name: diffuse-jax
Version: 0.1.1
Summary: A research-oriented package for diffusion-based generative modeling with modular components
Author-email: Jacopo Iollo <jacopo.iollo@inria.fr>
License: Apache-2.0
Project-URL: Homepage, https://github.com/jcopo/diffuse
Project-URL: Repository, https://github.com/jcopo/diffuse
Project-URL: Documentation, https://diffuse.readthedocs.io
Project-URL: Bug Tracker, https://github.com/jcopo/diffuse/issues
Keywords: diffusion,generative-models,machine-learning,jax,research
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.12
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4.31
Requires-Dist: jaxlib>=0.4.31
Requires-Dist: optax>=0.2.2
Requires-Dist: matplotlib>=3.8.2
Requires-Dist: numpy>=1.24.0
Requires-Dist: einops>=0.8.0
Requires-Dist: tqdm>=4.66.1
Requires-Dist: jax-tqdm>=0.1.2
Requires-Dist: flax>=0.12.0
Requires-Dist: chex>=0.1.86
Requires-Dist: jaxtyping>=0.2.24
Requires-Dist: Pillow>=10.0.0
Requires-Dist: huggingface-hub>=0.20.0
Requires-Dist: myst-nb>=1.3.0
Provides-Extra: dev
Requires-Dist: pytest>=6.0; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: pytest-xdist; extra == "dev"
Requires-Dist: nbmake>=1.5.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx>=7.1.0; extra == "docs"
Requires-Dist: sphinx-book-theme>=1.0.0; extra == "docs"
Requires-Dist: sphinx-autodoc-typehints>=1.24.0; extra == "docs"
Requires-Dist: sphinx-copybutton>=0.5.2; extra == "docs"
Requires-Dist: sphinx-design>=0.5.0; extra == "docs"
Requires-Dist: sphinx-togglebutton>=0.3.0; extra == "docs"
Requires-Dist: myst-parser>=2.0.0; extra == "docs"
Requires-Dist: myst-nb>=1.0.0; extra == "docs"
Requires-Dist: linkify-it-py>=2.0.0; extra == "docs"
Requires-Dist: ipython>=8.0.0; extra == "docs"
Requires-Dist: sphinxcontrib-tikz>=0.4.0; extra == "docs"
Requires-Dist: sphinxcontrib-bibtex>=2.6.0; extra == "docs"
Requires-Dist: sphinxext-opengraph>=0.8.2; extra == "docs"
Provides-Extra: flux
Requires-Dist: sentencepiece>=0.1.99; extra == "flux"
Requires-Dist: tokenizers>=0.15.0; extra == "flux"
Requires-Dist: orbax-checkpoint>=0.5.0; extra == "flux"
Dynamic: license-file
Dynamic: requires-python

# Diffuse

<p align="center">
  <img src="docs/_static/logo.png" alt="Denoising Process" style="width: 36%;">
</p>

Diffuse is a Python package for research in diffusion-based generative modeling. Its components are modular, so models, predictors, integrators, and timers can be swapped and combined for experimentation.

## Quick Start

### Unconditional Generation

```python
import jax
from diffuse import Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.timer import VpTimer

# Define flow matching model
flow = Flow(tf=1.0)

# Create predictor with your neural network
# (network_fn is your own model function; it is not defined in this snippet)
predictor = Predictor(
    model=flow,
    network=network_fn,
    prediction_type="velocity",  # or "noise", "sample"
)

# Setup timer and integrator
timer = VpTimer(n_steps=100, eps=0.001, tf=1.0)
integrator = EulerIntegrator(model=flow, timer=timer)

# Create denoiser
denoiser = Denoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    x0_shape=(height, width, channels),  # shape of one sample, e.g. (28, 28, 1) for MNIST
)

# Generate samples (split the key so each call gets its own randomness)
key = jax.random.PRNGKey(42)
gen_key, step_key = jax.random.split(key)
state, trajectory = denoiser.generate(
    rng_key=gen_key,
    n_steps=100,
    n_particles=10,
    keep_history=False,
)

# Single denoising step
next_state = denoiser.step(step_key, state)  # x_t -> x_{t-1}
```
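
To make the role of the integrator more concrete, here is a minimal plain-JAX sketch of Euler sampling for a flow model. It does not use the Diffuse API. It assumes the linear path `x_t = (1 - t) * x0 + t * eps` with noise at `t = 1` and data at `t = 0`, which may differ from the convention used by `Flow`. The names `MU`, `toy_velocity`, and `euler_sample` are hypothetical, and the toy velocity field stands in for a trained network.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup: the "data distribution" is a single point MU.
MU = jnp.array([2.0, -1.0])


def toy_velocity(x, t):
    # Velocity field for the assumed path x_t = (1 - t) * MU + t * eps.
    # In practice a trained network would play this role.
    return (x - MU) / t


def euler_sample(velocity_fn, x1, n_steps):
    # Time grid running from noise (t = 1) down to data (t = 0).
    ts = jnp.linspace(1.0, 0.0, n_steps + 1)

    def step(x, t_pair):
        t, t_next = t_pair
        # Explicit Euler update; (t_next - t) is negative here.
        x_next = x + (t_next - t) * velocity_fn(x, t)
        return x_next, x_next

    x0, trajectory = jax.lax.scan(step, x1, (ts[:-1], ts[1:]))
    return x0, trajectory


key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (10, 2))  # 10 particles drawn from the noise prior
samples, trajectory = euler_sample(toy_velocity, x1, n_steps=100)
```

For this toy field every particle ends at `MU`, and `trajectory` stores the state after each of the 100 steps.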

### Conditional Generation with DPS

```python
import jax
from diffuse import Flow, Predictor
from diffuse.integrators import EulerIntegrator
from diffuse.denoisers import DPSDenoiser
from diffuse.timer import VpTimer

# Define flow matching model
flow = Flow(tf=1.0)

# Create predictor
predictor = Predictor(
    model=flow,
    network=network_fn,
    prediction_type="velocity",
)

# Setup timer and integrator
timer = VpTimer(n_steps=100, eps=0.001, tf=1.0)
integrator = EulerIntegrator(model=flow, timer=timer)

# DPS denoiser for conditional sampling
dps = DPSDenoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    forward_model=forward_model,  # your measurement operator; not defined in this snippet
    x0_shape=(height, width, channels),
)

# Generate conditional samples
# (measurement_state, the observation to condition on, is assumed to exist already)
key = jax.random.PRNGKey(42)
gen_key, step_key = jax.random.split(key)
state, trajectory = dps.generate(
    rng_key=gen_key,
    measurement_state=measurement_state,
    n_steps=100,
    n_particles=10,
)

# Single conditional step
next_state = dps.step(step_key, state, measurement_state)  # x_t -> x_{t-1}
```
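
The idea behind DPS-style conditioning can be sketched in plain JAX, independently of `DPSDenoiser`. The sketch below assumes a squared-residual data misfit and differentiates it through a denoised estimate of the clean sample. The function `dps_guidance`, the shrinkage "denoiser", and the masking operator are all hypothetical stand-ins chosen for illustration, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def dps_guidance(x_t, t, y, denoise_fn, forward_fn, scale=1.0):
    # Data misfit evaluated on the denoised estimate of the clean sample.
    def data_misfit(x):
        x0_hat = denoise_fn(x, t)
        residual = y - forward_fn(x0_hat)
        return jnp.sum(residual**2)

    # The gradient flows through the denoiser back to the noisy state x_t.
    grad = jax.grad(data_misfit)(x_t)
    # Move against the gradient to reduce the misfit.
    return -scale * grad


# Hypothetical stand-ins: a shrinkage "denoiser" and an operator that
# observes only the first coordinate.
toy_denoise = lambda x, t: x / (1.0 + t)
toy_forward = lambda x0: x0[:1]

x_t = jnp.array([0.0, 5.0])
y = jnp.array([1.0])
guidance = dps_guidance(x_t, 1.0, y, toy_denoise, toy_forward)
```

Only the observed coordinate receives a correction; the unobserved one is left to the unconditional update.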

## Features

- **Flow Matching & Diffusion**: Support for both flow-based models and SDE-based diffusion processes
- **Flexible Prediction Types**: Velocity, noise, and sample prediction for different model architectures
- **Timer-aware Integration**: Advanced timing schemes (VpTimer, HeunTimer, FluxTimer) for improved sampling
- **Multiple Integrators**: EulerIntegrator, DDIMIntegrator, DPM++, Heun, and more
- **Conditional Sampling**: DPS (Diffusion Posterior Sampling) for inverse problems and conditioning
- **Modular Design**: Mix and match models, predictors, denoisers, integrators, and timers
- **JAX-Powered**: Efficient computation with automatic differentiation and JIT compilation
- **Research-Focused**: Built for experimentation with new diffusion and flow matching techniques
- **Examples**: MNIST, Gaussian mixtures, text-to-image generation, and more
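
As an illustration of how the velocity, noise, and sample prediction types relate, the following plain-JAX sketch converts between them. It assumes the linear path `x_t = (1 - t) * x0 + t * eps` with velocity `v = eps - x0`; this convention and the helper names are assumptions for the example, and the formulas used inside `Predictor` may differ.

```python
import jax
import jax.numpy as jnp


def velocity_to_sample(x_t, v, t):
    # x0 = x_t - t * v under the assumed linear path.
    return x_t - t * v


def velocity_to_noise(x_t, v, t):
    # eps = x_t + (1 - t) * v under the assumed linear path.
    return x_t + (1.0 - t) * v


def sample_to_velocity(x_t, x0, t):
    # Inverse of velocity_to_sample; requires t > 0.
    return (x_t - x0) / t


def noise_to_velocity(x_t, eps, t):
    # Inverse of velocity_to_noise; requires t < 1.
    return (eps - x_t) / (1.0 - t)


key_x0, key_eps = jax.random.split(jax.random.PRNGKey(0))
x0 = jax.random.normal(key_x0, (4, 2))
eps = jax.random.normal(key_eps, (4, 2))
t = 0.3
x_t = (1.0 - t) * x0 + t * eps
v = eps - x0

x0_rec = velocity_to_sample(x_t, v, t)
eps_rec = velocity_to_noise(x_t, v, t)
```

Because the three quantities are linear functions of one another at a fixed `t`, a network trained for one prediction type can be reused with a sampler that expects another.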

## Installation

```bash
pip install diffuse-jax
```

Or with uv:

```bash
uv add diffuse-jax
```

## Examples

See the `examples/` directory for implementations including:
- MNIST digit generation
- Gaussian mixture modeling
- Conditional sampling demonstrations

## Citation

If you use Diffuse in your research, please cite the library:

```bibtex
@software{diffuse2024,
  title = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year = {2025},
  url = {https://github.com/jcopo/diffuse}
}
```
