Metadata-Version: 2.4
Name: diffuse-jax
Version: 0.1.0
Summary: A research-oriented package for diffusion-based generative modeling with modular components
Author-email: Jacopo Iollo <jacopo.iollo@inria.fr>
License: Apache-2.0
Project-URL: Homepage, https://github.com/jcopo/diffuse
Project-URL: Repository, https://github.com/jcopo/diffuse
Project-URL: Documentation, https://diffuse.readthedocs.io
Project-URL: Bug Tracker, https://github.com/jcopo/diffuse/issues
Keywords: diffusion,generative-models,machine-learning,jax,research
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4.31
Requires-Dist: jaxlib>=0.4.31
Requires-Dist: optax>=0.2.2
Requires-Dist: matplotlib>=3.8.2
Requires-Dist: numpy>=1.24.0
Requires-Dist: einops>=0.8.0
Requires-Dist: tqdm>=4.66.1
Requires-Dist: jax-tqdm>=0.1.2
Requires-Dist: flax>=0.8.5
Requires-Dist: chex>=0.1.86
Requires-Dist: jaxtyping>=0.2.24
Provides-Extra: dev
Requires-Dist: pytest>=6.0; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx; extra == "docs"
Requires-Dist: sphinx-rtd-theme; extra == "docs"
Dynamic: license-file
Dynamic: requires-python

# Diffuse

<p align="center">
  <img src="docs/_static/logo.png" alt="Denoising Process" style="width: 36%;">
</p>

A JAX-based Python package for research on diffusion-based generative models, built from modular components that can be swapped and combined for experimentation.

## Quick Start

```python
import jax

from diffuse.diffusion.sde import SDE, LinearSchedule
from diffuse.timer import VpTimer, HeunTimer
from diffuse.integrator import EulerIntegrator, DDIMIntegrator
from diffuse.denoisers.cond import DPSDenoiser

# Placeholders you must supply for your own problem:
#   score_fn          - score function (e.g. a trained network)
#   forward_model     - measurement operator of the inverse problem
#   measurement_state - the observed measurement
key = jax.random.PRNGKey(0)

# Define the SDE with a linear noise schedule
sde = SDE(beta=LinearSchedule(b_min=0.1, b_max=20.0, T=1.0))
n_steps = 100

# Choose a timer (time discretization)
timer = VpTimer(n_steps=n_steps, eps=0.001, tf=1.0)
# timer = HeunTimer(n_steps=n_steps, rho=7.0, sigma_min=0.002, sigma_max=1.0)

# Timer-aware integrator
# integrator = EulerIntegrator(sde=sde, timer=timer)
integrator = DDIMIntegrator(sde=sde, timer=timer)

# DPS denoiser using the timer-aware integrator
dps = DPSDenoiser(
    sde=sde,
    score=score_fn,
    integrator=integrator,
    forward_model=forward_model,
)

# Use separate PRNG keys for generation and the single step
gen_key, step_key = jax.random.split(key)

# Generate conditional samples
state, trajectory = dps.generate(gen_key, measurement_state, n_steps, n_samples=10)

# Single reverse step: x_t -> x_{t-1}
next_state = dps.step(step_key, state, measurement_state)
```
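To make the first two objects in the Quick Start concrete, here is a library-independent sketch of a linear β schedule and two time discretizations. It assumes the standard variance-preserving (VP) SDE, where `alpha_bar(t) = exp(-∫_0^t beta(s) ds)` and `x_t = sqrt(alpha_bar(t)) x_0 + sqrt(1 - alpha_bar(t)) eps`. It also assumes, from the parameter names only, that `VpTimer` yields a uniform grid from `tf` down to `eps` and that `HeunTimer` follows the ρ-schedule of Karras et al. (2022). All function names are hypothetical; this is not diffuse's implementation.

```python
import jax.numpy as jnp

# Hypothetical helpers illustrating an assumed VP-SDE parameterization,
# not diffuse's internals.

def linear_beta(t, b_min=0.1, b_max=20.0, T=1.0):
    """Linear schedule: beta(t) = b_min + (b_max - b_min) * t / T."""
    return b_min + (b_max - b_min) * t / T

def integrated_beta(t, b_min=0.1, b_max=20.0, T=1.0):
    """Closed-form integral of linear_beta from 0 to t."""
    return b_min * t + 0.5 * (b_max - b_min) * t**2 / T

def alpha_bar(t, **schedule_kwargs):
    """Signal coefficient of the VP marginal x_t = sqrt(ab) x_0 + sqrt(1 - ab) eps."""
    return jnp.exp(-integrated_beta(t, **schedule_kwargs))

def uniform_time_grid(n_steps, eps=1e-3, tf=1.0):
    """Decreasing grid of n_steps + 1 times from tf to eps (assumed VpTimer-like)."""
    return jnp.linspace(tf, eps, n_steps + 1)

def karras_sigma_grid(n_steps, rho=7.0, sigma_min=0.002, sigma_max=1.0):
    """Karras et al. rho-schedule, decreasing from sigma_max to sigma_min."""
    frac = jnp.arange(n_steps + 1) / n_steps
    max_root = sigma_max ** (1.0 / rho)
    min_root = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma^(1/rho) space, then map back with ** rho;
    # large rho concentrates steps near sigma_min.
    return (max_root + frac * (min_root - max_root)) ** rho
```

With the Quick Start parameters (`b_min=0.1`, `b_max=20.0`, `T=1.0`), `integrated_beta(1.0)` is 10.05, so `alpha_bar(1.0)` is about 4e-5: almost no signal survives at `tf = 1`.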

## Features

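- **Flexible noising process**: Support for various noise schedules (e.g. `LinearSchedule`) and diffusion processes
- **Timer-aware integration**: Integrators such as `EulerIntegrator` and `DDIMIntegrator` step along a time discretization supplied by a timer (`VpTimer`, `HeunTimer`)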
- **Conditional sampling**: DPS (Diffusion Posterior Sampling) and other conditional methods
- **Modular design**: Mix and match denoisers, integrators, timers, and forward models
- **Research-focused**: Built for experimentation with new diffusion techniques
- **Examples**: MNIST, Gaussian mixtures, and other applications
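DPS, listed above, corrects each unconditional reverse step with the gradient of a measurement residual evaluated at the Tweedie (posterior-mean) estimate of `x_0`. The sketch below shows only that correction in plain JAX. It assumes the VP marginal `x_t = sqrt(alpha_bar) x_0 + sqrt(1 - alpha_bar) eps`, a linear forward model `y = A x`, and a hypothetical constant step size `zeta`; `x_prev` stands for the output of any unconditional integrator step. It is not `DPSDenoiser`'s code.

```python
import jax
import jax.numpy as jnp

def tweedie_x0(x_t, score, alpha_bar):
    """Posterior-mean estimate of x_0 under the VP marginal (Tweedie's formula)."""
    return (x_t + (1.0 - alpha_bar) * score(x_t)) / jnp.sqrt(alpha_bar)

def dps_correct(x_prev, x_t, y, A, score, alpha_bar, zeta=1.0):
    """DPS-style guidance (sketch).

    x_prev: unconditional reverse-step output computed from x_t.
    The gradient is taken w.r.t. x_t (through the Tweedie estimate)
    and applied to x_prev. zeta is a hypothetical constant step size.
    """
    def squared_residual(x):
        x0_hat = tweedie_x0(x, score, alpha_bar)
        r = y - A @ x0_hat
        return jnp.sum(r**2)

    grad = jax.grad(squared_residual)(x_t)
    return x_prev - zeta * grad
```

With `score = lambda x: -x` (the exact score when `x_0` is standard normal under the VP marginal), the Tweedie estimate reduces to `sqrt(alpha_bar) * x_t`, which makes the correction easy to check by hand.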

## Installation

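From a clone of the repository, install in editable mode:

```bash
git clone https://github.com/jcopo/diffuse.git
cd diffuse
pip install -e .
```

Optional extras are declared for development (`dev`: pytest, pytest-cov) and documentation (`docs`: Sphinx), e.g. `pip install -e ".[dev]"`.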

## Examples

See the `examples/` directory for implementations including:
- MNIST digit generation
- Gaussian mixture modeling
- Conditional sampling demonstrations
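Gaussian mixtures are convenient for testing samplers because the score of the noised distribution is available in closed form. The sketch below assumes an isotropic mixture and the VP marginal `x_t = sqrt(alpha_bar) x_0 + sqrt(1 - alpha_bar) eps`, and computes the score by autodiff of the log-density. The function names are hypothetical and this is not the code in `examples/`; wiring it into `DPSDenoiser` would require matching whatever `score` signature the library expects.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def gmm_log_density(x, weights, means, stds, alpha_bar):
    """log p_t(x) for an isotropic Gaussian mixture noised by the VP marginal.

    If x_0 ~ sum_k w_k N(mu_k, s_k^2 I), then
    x_t ~ sum_k w_k N(sqrt(alpha_bar) mu_k, (alpha_bar s_k^2 + 1 - alpha_bar) I).
    Shapes: x (d,), weights (K,), means (K, d), stds (K,).
    """
    d = x.shape[-1]
    noised_means = jnp.sqrt(alpha_bar) * means             # (K, d)
    noised_vars = alpha_bar * stds**2 + (1.0 - alpha_bar)  # (K,)
    sq_dist = jnp.sum((x - noised_means) ** 2, axis=-1)    # (K,)
    log_comp = -0.5 * sq_dist / noised_vars - 0.5 * d * jnp.log(2 * jnp.pi * noised_vars)
    return logsumexp(jnp.log(weights) + log_comp)

def gmm_score(x, weights, means, stds, alpha_bar):
    """Exact score grad_x log p_t(x), obtained by differentiating the log-density."""
    return jax.grad(gmm_log_density)(x, weights, means, stds, alpha_bar)
```

For a single standard-normal component the noised distribution stays `N(0, I)`, so `gmm_score` returns `-x`, a convenient sanity check.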
