Metadata-Version: 2.4
Name: dl4bi
Version: 0.1.3
Summary: Deep Learning for Bayesian Inference (DL4BI)
Requires-Python: <4.0,>=3.11
Requires-Dist: dl4bi-sps<1,>=0.1.2
Requires-Dist: einops<1,>=0.8
Requires-Dist: flax<0.13,>=0.12
Requires-Dist: hydra-core<2,>=1.3
Requires-Dist: jraph==0.0.6.dev0
Requires-Dist: matplotlib<4,>=3.10
Requires-Dist: numpy<3,>=2.2
Requires-Dist: optax<0.3,>=0.2
Requires-Dist: orbax-checkpoint<1,>=0.11
Requires-Dist: scikit-learn<2,>=1.7
Requires-Dist: scipy<2,>=1.15
Requires-Dist: scoringrules<1,>=0.7
Requires-Dist: tqdm<5,>=4.67
Requires-Dist: wandb<1,>=0.22
Provides-Extra: benchmarks
Requires-Dist: arviz<1,>=0.20; extra == 'benchmarks'
Requires-Dist: cdsapi<1,>=0.7; extra == 'benchmarks'
Requires-Dist: datasets<4,>=3.5; extra == 'benchmarks'
Requires-Dist: geopandas<2.0.0,>=1.1.3; extra == 'benchmarks'
Requires-Dist: numpyro<1,>=0.14; extra == 'benchmarks'
Requires-Dist: pandas<3,>=2.3; extra == 'benchmarks'
Requires-Dist: pillow<12,>=11.1; extra == 'benchmarks'
Requires-Dist: rapidfuzz<4.0.0,>=3.14.3; extra == 'benchmarks'
Requires-Dist: requests<3.0.0,>=2.33.0; extra == 'benchmarks'
Requires-Dist: seaborn<0.14.0,>=0.13.2; extra == 'benchmarks'
Requires-Dist: shapely<3.0.0,>=2.1.2; extra == 'benchmarks'
Requires-Dist: tensorflow-datasets<5.0.0,>=4.9.9; extra == 'benchmarks'
Requires-Dist: tensorflow<3.0.0,>=2.21.0; extra == 'benchmarks'
Requires-Dist: tiktoken<1,>=0.8; extra == 'benchmarks'
Requires-Dist: ucimlrepo<1,>=0.0.7; extra == 'benchmarks'
Requires-Dist: xarray<2026,>=2025.9; extra == 'benchmarks'
Provides-Extra: cpu
Requires-Dist: jax<0.9,>=0.8.1; extra == 'cpu'
Provides-Extra: cuda12
Requires-Dist: jax[cuda12]<0.9,>=0.8.1; extra == 'cuda12'
Provides-Extra: cuda13
Requires-Dist: jax[cuda13]<0.9,>=0.8.1; extra == 'cuda13'
Description-Content-Type: text/markdown

# Deep Learning for Bayesian Inference (dl4bi)

## Install
Install with the command that matches your setup. If JAX isn't already installed, we recommend one of the extras that pulls it in: `dl4bi[cpu]`, `dl4bi[cuda12]`, or `dl4bi[cuda13]`.
```bash
# Quote the extras so shells such as zsh don't treat the brackets as a glob.
pip install dl4bi # dl4bi
pip install "dl4bi[cpu]" # dl4bi + jax for CPU
pip install "dl4bi[cuda12]" # dl4bi + jax for CUDA-12
pip install "dl4bi[cuda13]" # dl4bi + jax for CUDA-13
pip install "dl4bi[benchmarks,cpu]" # benchmark deps + jax for CPU
pip install "dl4bi[benchmarks,cuda12]" # benchmark deps + jax for CUDA-12
pip install "dl4bi[benchmarks,cuda13]" # benchmark deps + jax for CUDA-13
```

The `benchmarks` extra installs the additional packages used by the benchmark scripts, especially under `benchmarks/meta_learning` and `benchmarks/vae`.
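
To confirm which optional packages ended up in your environment, a quick check along these lines may help. It is a minimal sketch using only the standard library; the `EXTRA_MODULES` mapping and the `missing_modules` helper are illustrative and not part of `dl4bi`, and the module lists are a small sample of each extra's dependencies.

```python
from importlib.util import find_spec

# Illustrative sample of import names per extra (an assumption for this
# sketch, not an exhaustive list); adjust to the packages you care about.
EXTRA_MODULES = {
    "cpu/cuda12/cuda13": ["jax"],
    "benchmarks": ["numpyro", "arviz", "xarray"],
}


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if find_spec(name) is None]


for extra, modules in EXTRA_MODULES.items():
    missing = missing_modules(modules)
    status = "ok" if not missing else f"missing: {', '.join(missing)}"
    print(f"{extra}: {status}")
```

If a line reports missing modules, reinstalling with the corresponding extra should resolve it.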

## View Documentation (Locally)
```bash
git clone git@github.com:MLGlobalHealth/dl4bi.git
cd dl4bi
uv run --with pdoc pdoc --docformat google --math dl4bi
```

## Cite
If you're using this package or some of its code, please cite the relevant paper(s):
```bibtex
@inproceedings{jenson2026scalable,
  title = {Scalable Spatiotemporal Inference with Biased Scan Attention Transformer Neural Processes},
  author = {Jenson, Daniel and Navott, Jhonathan and Grynfelder, Piotr and Zhang, Mengyan and Sharma, Makkunda and Semenova, Elizaveta and Flaxman, Seth},
  booktitle = {Proceedings of the 29th International Conference on Artificial Intelligence and Statistics},
  series = {Proceedings of Machine Learning Research},
  volume = {300},
  address = {Tangier, Morocco},
  year = {2026},
  publisher = {PMLR}
}

@inproceedings{navott2026deeprv,
  title = {{DeepRV}: Accelerating Spatiotemporal Inference with Pre-trained Neural Priors},
  author = {Navott, Jhonathan and Jenson, Daniel and Flaxman, Seth and Semenova, Elizaveta},
  booktitle = {Proceedings of the 29th International Conference on Artificial Intelligence and Statistics},
  series = {Proceedings of Machine Learning Research},
  volume = {300},
  address = {Tangier, Morocco},
  year = {2026},
  publisher = {PMLR}
}
```
Benchmarks are available for BSA-TNP [here](https://github.com/MLGlobalHealth/dl4bi/tree/main/benchmarks/meta_learning) and for DeepRV [here](https://github.com/MLGlobalHealth/dl4bi/tree/main/benchmarks/vae).

## Development Setup
- Install `uv`: `curl -LsSf https://astral.sh/uv/install.sh | sh`
- Clone the repository and `cd` into it: `git clone git@github.com:MLGlobalHealth/dl4bi.git && cd dl4bi`
- Install the latest Python with `uv`: `uv python install`
- Sync the project environment:
    - CPU JAX: `uv sync --extra cpu`
    - CUDA 12 JAX: `uv sync --extra cuda12`
    - CUDA 13 JAX: `uv sync --extra cuda13`
    - Benchmark deps + CPU JAX: `uv sync --extra benchmarks --extra cpu`
    - Benchmark deps + CUDA 12 JAX: `uv sync --extra benchmarks --extra cuda12`
    - Benchmark deps + CUDA 13 JAX: `uv sync --extra benchmarks --extra cuda13`
- `uv sync` creates `.venv`, installs the project in editable mode, includes the default `dev` dependency group, and picks a Python interpreter compatible with the project's `requires-python`
- Before making changes, install the shared development hooks: `uv run pre-commit install --install-hooks`
- Verify the hook setup once per clone with: `uv run pre-commit run --all-files`
- Keep the hooks installed for local development; commits on `main` run `pytest -q tests` through the shared `pre-commit` setup
- Run project commands through `uv`, e.g. `uv run pytest` or `uv run python gp.py`
- If you want to activate the virtualenv directly, use `source .venv/bin/activate`
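
If you manage interpreters yourself rather than through `uv`, you can check whether a given Python satisfies the package's `Requires-Python` range (`>=3.11,<4.0`). The sketch below hard-codes that range; `python_is_supported` is a hypothetical helper written for this example, not something the project provides.

```python
import sys

MIN_PYTHON = (3, 11)  # inclusive lower bound from Requires-Python
MAX_PYTHON = (4, 0)   # exclusive upper bound from Requires-Python


def python_is_supported(version=sys.version_info):
    """Return True if the (major, minor) pair falls inside the supported range."""
    return MIN_PYTHON <= tuple(version[:2]) < MAX_PYTHON


current = f"{sys.version_info.major}.{sys.version_info.minor}"
print(f"Python {current} supported: {python_is_supported()}")
```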

## Build and Publish to PyPI
Create a local `.env` file with the publish tokens:
```env
TEST_PYPI_TOKEN=pypi-...
PYPI_TOKEN=pypi-...
```
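
Before releasing, it can be worth checking that both tokens are actually set. The following is a minimal sketch that assumes the file uses plain `KEY=VALUE` lines as shown above; `parse_env` and `missing_tokens` are hypothetical helpers for this example, and the release script may read the file differently.

```python
from pathlib import Path


def parse_env(text):
    """Parse KEY=VALUE lines, skipping blank lines and # comments."""
    values = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop optional surrounding quotes around the value.
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def missing_tokens(values, required=("TEST_PYPI_TOKEN", "PYPI_TOKEN")):
    """Return the required token names that are absent or empty."""
    return [name for name in required if not values.get(name)]


env_path = Path(".env")
if env_path.exists():
    print("missing tokens:", missing_tokens(parse_env(env_path.read_text())))
```

An empty list means both `TEST_PYPI_TOKEN` and `PYPI_TOKEN` have values; it does not verify that the tokens are valid.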

Run the release helper from a clean `main` checkout:
```bash
uv run python scripts/release.py .env "AISTATS 2026"
```

The helper bumps the patch version, commits and tags `v<version> <message>`,
rebuilds `dist/`, publishes to TestPyPI and PyPI, pushes `main` and the tag,
and smoke-tests the published install targets.
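
To make the version step concrete, here is a rough sketch of what a patch bump and tag name look like. It assumes a plain `major.minor.patch` version string and a tag named `v<version>`; `bump_patch` and `tag_name` are illustrative helpers, not the contents of `scripts/release.py`.

```python
def bump_patch(version):
    """Increment the patch component of a 'major.minor.patch' version string."""
    major, minor, patch = (int(part) for part in version.split("."))
    return f"{major}.{minor}.{patch + 1}"


def tag_name(version):
    """Build the tag name for a version, assuming the 'v<version>' convention."""
    return f"v{version}"


new_version = bump_patch("0.1.3")
print(new_version, tag_name(new_version))  # prints: 0.1.4 v0.1.4
```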
