Metadata-Version: 2.4
Name: ksd-metric
Version: 0.1.0
Summary: Kernel Stein Discrepancy
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4.13
Requires-Dist: jaxlib>=0.4.13
Requires-Dist: jaxtyping>=0.2.19
Requires-Dist: sphinx>=8.1.3
Requires-Dist: sphinx-autobuild>=2024.10.3
Requires-Dist: sphinx-autodoc-typehints>=3.0.1
Requires-Dist: sphinx-rtd-theme>=3.0.2
Provides-Extra: docs
Requires-Dist: sphinx; extra == "docs"
Requires-Dist: sphinx_rtd_theme; extra == "docs"
Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
Requires-Dist: sphinx-markdown-builder; extra == "docs"
Dynamic: license-file

# KSD-Metric
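Kernel Stein discrepancy (KSD) in JAX. KSD measures how well a set of samples approximates a target distribution p. It depends on p only through its score function ∇ log p, so the target density's normalizing constant is not required.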
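For reference, this is the standard kernel Stein discrepancy built on the Langevin Stein operator. Here s_p is the score of the target and k is a base kernel such as the inverse multiquadric (IMQ). This README does not say whether `kernel_stein_discrepancy` returns KSD or KSD². It also does not say whether it uses the V-statistic shown here or a U-statistic, which drops the i = j terms. Read this as background, not as a specification of the package.

```latex
s_p(x) = \nabla_x \log p(x)

k_p(x, y) = s_p(x)^\top s_p(y)\, k(x, y)
          + s_p(x)^\top \nabla_y k(x, y)
          + s_p(y)^\top \nabla_x k(x, y)
          + \sum_{l=1}^{d} \frac{\partial^2 k(x, y)}{\partial x_l \, \partial y_l}

\mathrm{KSD}^2\big(\{x_i\}_{i=1}^{N}\big) = \frac{1}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} k_p(x_i, x_j)

% IMQ base kernel, commonly with c = 1 and \beta = -1/2
k(x, y) = \big(c^2 + \lVert x - y \rVert_2^2\big)^{\beta}, \qquad c > 0,\ \beta \in (-1, 0)
```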

## Example
```python
import jax
from jax import numpy as jnp
from jax.scipy.stats import multivariate_normal

from ksd_metric.kernel import KernelJax
from ksd_metric.stein import KernelSteinDiscrepancyJax
from ksd_metric.target import TargetDistributionJax
from ksd_metric.utils import JaxKernelFunction

# Example usage of KernelSteinDiscrepancyJax with a multivariate normal target
# and the inverse multiquadric (IMQ) kernel.
key = jax.random.PRNGKey(42)
dim = 2
mean = jnp.zeros(dim)
cov = jnp.eye(dim)
N = 10_000
# Draw N samples from the target itself, so the discrepancy should be small.
x = jax.random.multivariate_normal(key, mean, cov, shape=(N,))

# Define the target distribution and kernel
log_target_pdf = lambda x: multivariate_normal.logpdf(x, mean=mean, cov=cov)
target = TargetDistributionJax(log_target_pdf=log_target_pdf)
kernel = KernelJax(lambda x, y: JaxKernelFunction.imq(x, y, jnp.eye(dim)))
ksd = KernelSteinDiscrepancyJax(target=target, kernel=kernel)

# Compute the kernel Stein discrepancy
print(ksd.kernel_stein_discrepancy(x))
```
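The following sketch makes the computation concrete. It is a minimal, self-contained JAX example that builds the Langevin Stein kernel with automatic differentiation and averages it over all sample pairs, which gives the V-statistic estimate of KSD². It is not the `ksd_metric` implementation. The names `imq`, `make_stein_kernel` and `ksd_vstat` are hypothetical. The IMQ kernel here assumes plain Euclidean distance with c = 1 and β = -1/2; the package's `JaxKernelFunction.imq` takes an extra matrix argument whose role is not documented in this README.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def imq(x, y, c=1.0, beta=-0.5):
    """Hypothetical IMQ base kernel k(x, y) = (c^2 + ||x - y||^2)^beta (Euclidean distance assumed)."""
    return (c**2 + jnp.sum((x - y) ** 2)) ** beta


def make_stein_kernel(log_p, k):
    """Build the Langevin Stein kernel k_p(x, y) from log p and a base kernel k."""
    score = jax.grad(log_p)          # s_p(x) = grad_x log p(x)
    grad_x = jax.grad(k, argnums=0)  # grad_x k(x, y)
    grad_y = jax.grad(k, argnums=1)  # grad_y k(x, y)

    def k_p(x, y):
        sx, sy = score(x), score(y)
        # Matrix of mixed second derivatives d^2 k / (dx_i dy_j); only its trace is needed.
        mixed = jax.jacfwd(grad_y, argnums=0)(x, y)
        return (
            jnp.dot(sx, sy) * k(x, y)
            + jnp.dot(sx, grad_y(x, y))
            + jnp.dot(sy, grad_x(x, y))
            + jnp.trace(mixed)
        )

    return k_p


def ksd_vstat(log_p, k, xs):
    """V-statistic estimate of KSD^2: the mean of k_p over all N^2 sample pairs."""
    k_p = make_stein_kernel(log_p, k)
    # Nested vmap builds the (N, N) Gram matrix of the Stein kernel.
    gram = jax.vmap(lambda xi: jax.vmap(lambda xj: k_p(xi, xj))(xs))(xs)
    return jnp.mean(gram)


key = jax.random.PRNGKey(0)
dim, n = 2, 500
mean, cov = jnp.zeros(dim), jnp.eye(dim)
log_p = lambda x: multivariate_normal.logpdf(x, mean, cov)

good = jax.random.multivariate_normal(key, mean, cov, shape=(n,))
bad = good + 1.0  # the same samples, shifted away from the target mean

ksd2_good = ksd_vstat(log_p, imq, good)
ksd2_bad = ksd_vstat(log_p, imq, bad)
print(ksd2_good, ksd2_bad)  # ksd2_bad should be much larger than ksd2_good

# KSD only sees the score, so an unnormalized log density gives the same value.
log_p_unnormalized = lambda x: -0.5 * jnp.sum(x**2)
print(ksd_vstat(log_p_unnormalized, imq, good))
```

Comparing samples from the target with the same samples shifted by one unit is a quick sanity check, because the shifted set should score much higher. The last lines show that dropping the normalizing constant of log p leaves the result unchanged, since only its gradient enters the Stein kernel. This sketch materializes the full N × N Gram matrix, so its time and memory grow quadratically with the number of samples.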
