Metadata-Version: 2.3
Name: kernax
Version: 0.2.0
Summary: Regularized Stein thinning using JAX
Requires-Dist: blackjax>=1.2.5,<2.0.0
Requires-Dist: jax>=0.6.2,<1.0.0
Requires-Dist: numpy>=2.2.6,<3.0.0
Requires-Dist: scipy>=1.15.3,<2.0.0
Requires-Dist: matplotlib>=3.10.5 ; extra == 'plot'
Requires-Python: >=3.10
Provides-Extra: plot
Description-Content-Type: text/markdown

# Kernax: kernel-based MCMC post-processing algorithms

Kernax is a small package that implements kernel-based post-processing and subsampling algorithms. Three main algorithms are provided:

* The vanilla Stein thinning algorithm proposed by M. Riabiz et al. in [Optimal thinning of MCMC output](https://academic.oup.com/jrsssb/article-abstract/84/4/1059/7073269).
* The regularized Stein thinning algorithm proposed by C. Bénard et al. in [Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9a8eb202c060b7d81f5889631cbcd47e-Abstract-Conference.html).
* The greedy maximum mean discrepancy subsampling algorithm (see, e.g., [Optimal quantisation of probability measures using maximum mean discrepancy](http://proceedings.mlr.press/v130/teymur21a.html)).
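
For intuition, vanilla Stein thinning can be sketched as a greedy search: at each step it picks the sample that minimizes half of its own Stein kernel value plus the sum of its Stein kernel values with the points already selected. The sketch below is written in plain NumPy and is not kernax's implementation. The function names, the inverse multiquadric base kernel and the single scalar lengthscale are assumptions made for illustration.

```python
import numpy as np


def stein_kernel_matrix(x, scores, lengthscale):
    """Langevin Stein kernel built on the IMQ base kernel (1 + r^2 / l^2)^(-1/2)."""
    n, d = x.shape
    diff = x[:, None, :] - x[None, :, :]  # (n, n, d) pairwise differences
    r2 = np.sum(diff**2, axis=-1)  # squared pairwise distances
    u = 1.0 + r2 / lengthscale**2
    # Divergence term: div_x div_y k(x, y)
    div_term = d / lengthscale**2 * u**-1.5 - 3.0 * r2 / lengthscale**4 * u**-2.5
    # Cross terms: <grad_x k, s(y)> + <grad_y k, s(x)>
    score_diff = scores[:, None, :] - scores[None, :, :]
    cross_term = np.sum(diff * score_diff, axis=-1) / lengthscale**2 * u**-1.5
    # Score term: k(x, y) <s(x), s(y)>
    score_term = (scores @ scores.T) * u**-0.5
    return div_term + cross_term + score_term


def greedy_stein_thinning(x, scores, lengthscale, m):
    """Greedily select m indices (repeats are allowed)."""
    kp = stein_kernel_matrix(x, scores, lengthscale)
    # Running objective for every candidate: k_p(x, x) / 2 + sum over selected points
    objective = 0.5 * np.diag(kp).copy()
    selected = []
    for _ in range(m):
        idx = int(np.argmin(objective))
        selected.append(idx)
        objective += kp[:, idx]
    return np.array(selected)


# Toy usage on a standard Gaussian sample, whose score is grad log p(x) = -x
rng = np.random.default_rng(0)
x = rng.standard_normal((200, 2))
scores = -x
indices = greedy_stein_thinning(x, scores, lengthscale=1.0, m=20)
thinned = x[indices]
```

The full kernel matrix costs O(n²) memory, which is fine for a toy sample but is one reason to prefer a dedicated implementation for long MCMC chains.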

## Quick start

The following example applies the Stein thinning algorithm to a Gaussian sample:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal
from kernax.utils import median_heuristic
from kernax import SteinThinning

rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000,2))

def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))
score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)

lengthscale = jnp.array([median_heuristic(x)])
stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)
```
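
The `median_heuristic` helper used above sets the kernel lengthscale from the data. A common form of this heuristic takes the median of the pairwise Euclidean distances between samples. Whether kernax uses exactly this definition (for instance with squaring or rescaling) is not stated here, so the NumPy sketch below, including the hypothetical name `median_pairwise_distance`, is illustrative only.

```python
import numpy as np


def median_pairwise_distance(x):
    """Median of the Euclidean distances between all distinct pairs of rows of x."""
    diff = x[:, None, :] - x[None, :, :]
    dist = np.sqrt(np.sum(diff**2, axis=-1))
    # Keep each unordered pair once and drop the zero diagonal
    upper = np.triu_indices(x.shape[0], k=1)
    return np.median(dist[upper])


# Three collinear points with pairwise distances 5, 10 and 5
pts = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
lengthscale_guess = median_pairwise_distance(pts)
```

Tying the lengthscale to the typical distance between samples keeps the kernel from being either nearly constant or nearly diagonal on the data.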

If you want to apply the regularized method, you need a few more lines:

```python
from kernax import laplace_log_p_softplus
from kernax import RegularizedSteinThinning

# Log-density values at each sample (not the scores, which are passed separately)
log_p = jax.vmap(logprob_fn, 0)(x)
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
```

## Documentation

Documentation is available at [readthedocs](https://kernax.readthedocs.io/en/latest/?kernax=latest).

## Contributing

This code is not meant to be an evolving library. However, feel free to create issues and merge requests.

## Install guide

### As a user

A Python wheel is available on [PyPI](https://pypi.org/project/kernax/). You can install kernax in your Python environment as follows:

```console
pip install kernax
```

### As a developer

We recommend using [uv](https://docs.astral.sh/uv/getting-started/installation/). First clone this repository, then run:

```bash
uv sync
```

This will create a virtual environment that you can use for developing kernax. If you're not familiar with `uv`, have a look at their [Getting started](https://docs.astral.sh/uv/getting-started/) section.

## Paper experiments

This code implements the regularized Stein thinning algorithm introduced in the paper [Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9a8eb202c060b7d81f5889631cbcd47e-Abstract-Conference.html).

Please consider citing the paper when using this library:
```bibtex
@article{benard2023kernel,
  title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
  author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
  journal={arXiv preprint arXiv:2301.13528},
  year={2023}
}
```

All the numerical experiments presented in the [paper](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9a8eb202c060b7d81f5889631cbcd47e-Abstract-Conference.html) can be reproduced with the scripts made available in the example folder.

In particular:

* Figures 1, 2 & 3 can be reproduced with the script `example/mog_randn.py`

* Each experiment in Section 4 and Appendix 1 can be reproduced with the scripts gathered in the following folders:
    * Gaussian mixture: `example/mog4_mcmc` and `example/mog4_mcmc_dim`
    * Mixture of banana-shaped distributions: `example/mobt2_mcmc` and `example/mobt2_mcmc_dim`
    * Bayesian logistic regression: `example/logistic_regression.py`

* Two additional scripts are also available to reproduce figures shown in the supplementary material:
    * Figure 2: `example/mog_weight_weights.py`
    * Figure 6: `example/mog4_mcmc_lambda`