Metadata-Version: 2.4
Name: normix
Version: 0.2.0
Summary: JAX package for Generalized Hyperbolic distributions as exponential families
Project-URL: Homepage, https://github.com/xshi19/normix
Project-URL: Repository, https://github.com/xshi19/normix
Project-URL: Documentation, https://xshi19.github.io/normix/
Project-URL: Bug Tracker, https://github.com/xshi19/normix/issues
Author: normix developers
License: MIT
License-File: LICENSE
Keywords: distributions,exponential-family,fat-tail,generalized-hyperbolic,normal-inverse-gamma,normal-inverse-gaussian,probability,statistics,variance-gamma
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.12
Requires-Dist: equinox>=0.13
Requires-Dist: jax>=0.4.38
Requires-Dist: jaxopt>=0.8
Requires-Dist: numpy>=2.0.0
Requires-Dist: scipy>=1.14.0
Provides-Extra: cuda12
Requires-Dist: jax[cuda12]>=0.4.38; extra == 'cuda12'
Provides-Extra: dev
Requires-Dist: ipykernel>=6.29; extra == 'dev'
Requires-Dist: matplotlib>=3.9; extra == 'dev'
Requires-Dist: pytest>=8.0; extra == 'dev'
Provides-Extra: docs
Requires-Dist: nbsphinx>=0.9; extra == 'docs'
Requires-Dist: pandoc>=2.4; extra == 'docs'
Requires-Dist: sphinx-rtd-theme>=3.0; extra == 'docs'
Requires-Dist: sphinx>=9.0; extra == 'docs'
Provides-Extra: plotting
Requires-Dist: matplotlib>=3.10.0; extra == 'plotting'
Provides-Extra: scripts
Requires-Dist: lxml>=6.0.2; extra == 'scripts'
Requires-Dist: yfinance>=1.1.0; extra == 'scripts'
Description-Content-Type: text/markdown

# normix

JAX package for Generalized Hyperbolic distributions as exponential families.

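Built on [Equinox](https://docs.kidger.site/equinox/) and designed for float64 precision throughout. Enable JAX's 64-bit mode with `jax.config.update("jax_enable_x64", True)` at startup, as the examples below do.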

## Installation

```bash
pip install normix
```

Install optional plotting helpers with:

```bash
pip install "normix[plotting]"
```

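For local development, either sync the project environment with uv (which installs normix in editable mode) or use an editable pip install:

```bash
uv sync            # with uv
pip install -e .   # or, with pip
```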

## Quick Start

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

# Fit GH distribution to data via EM
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 3))

model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(3), gamma=jnp.zeros(3),
    sigma=jnp.eye(3), p=-0.5, a=2.0, b=1.0,
)
result = BatchEMFitter(max_iter=100).fit(model, X)

# Evaluate log-density (batched via vmap)
log_p = jax.vmap(result.model.log_prob)(X)   # shape (1000,)
```

## Distributions

### Univariate (exponential family)

| Class | Parameters | Description |
|---|---|---|
| `Gamma` | `alpha`, `beta` | Shape α > 0, rate β > 0 |
| `InverseGamma` | `alpha`, `beta` | Shape α > 0, rate β > 0 |
| `InverseGaussian` | `mu`, `lam` | Mean μ > 0, shape λ > 0 |
| `GIG` / `GeneralizedInverseGaussian` | `p`, `a`, `b` | Generalized Inverse Gaussian |
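For reference, the standard GIG density with index $p$ and $a, b > 0$ is shown below, where $K_p$ is the modified Bessel function of the second kind. This assumes, without confirming it, that normix's `p`, `a`, `b` follow this common convention; check the class docstrings before relying on it.

$$
f_{\mathrm{GIG}}(y;\, p, a, b) = \frac{(a/b)^{p/2}}{2 K_p\!\left(\sqrt{ab}\right)}\, y^{p-1} \exp\!\left(-\frac{a y + b/y}{2}\right), \qquad y > 0.
$$

Under this convention, the other univariate classes are special or limiting cases:

- $b \to 0$ with $p > 0$: `Gamma` with $\alpha = p$, $\beta = a/2$ (density $\propto y^{\alpha-1} e^{-\beta y}$).
- $a \to 0$ with $p < 0$: `InverseGamma` with $\alpha = -p$, $\beta = b/2$ (density $\propto y^{-\alpha-1} e^{-\beta/y}$).
- $p = -1/2$: `InverseGaussian` with $\mu = \sqrt{b/a}$, $\lambda = b$.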

### Multivariate

| Class | Parameters | Description |
|---|---|---|
| `MultivariateNormal` | `mu`, `L_Sigma` | Mean μ, lower Cholesky factor L_Sigma of the covariance Σ (Σ = L_Sigma L_Sigmaᵀ) |

### Normal Variance-Mean Mixtures (marginal)

| Class | Subordinator | Parameters |
|---|---|---|
| `VarianceGamma` | Gamma | `mu`, `gamma`, `L_Sigma`, `alpha`, `beta` |
| `NormalInverseGamma` | InverseGamma | `mu`, `gamma`, `L_Sigma`, `alpha`, `beta` |
| `NormalInverseGaussian` | InverseGaussian | `mu`, `gamma`, `L_Sigma`, `mu_ig`, `lam` |
| `GeneralizedHyperbolic` | GIG | `mu`, `gamma`, `L_Sigma`, `p`, `a`, `b` |
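The marginal classes are normal variance-mean mixtures. Under the standard construction, assumed here rather than taken from normix's source, $X = \mu + \gamma Y + \sqrt{Y}\, L_\Sigma Z$ with $Z \sim \mathcal{N}(0, I)$ and $Y$ drawn from the subordinator. Below is a minimal sampler for the `VarianceGamma` case. It assumes `beta` is the Gamma rate, as in the univariate table, and `sample_variance_gamma` is a hypothetical helper, not part of normix.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sample_variance_gamma(key, n, mu, gamma, L_sigma, alpha, beta):
    """Draw n samples of X = mu + gamma * Y + sqrt(Y) * L_sigma @ Z.

    Hypothetical helper: Y ~ Gamma(shape alpha, rate beta), Z ~ N(0, I).
    """
    key_y, key_z = jax.random.split(key)
    # jax.random.gamma draws Gamma(alpha) with unit rate; divide by beta for rate beta
    y = jax.random.gamma(key_y, alpha, shape=(n,)) / beta
    z = jax.random.normal(key_z, (n, mu.shape[0]))
    # Row-wise (L_sigma @ z_i)^T == z_i^T @ L_sigma^T, so Cov[X | Y = y] = y * Sigma
    return mu + y[:, None] * gamma + jnp.sqrt(y)[:, None] * (z @ L_sigma.T)
```

With this construction, $\mathbb{E}[X] = \mu + \gamma\,\mathbb{E}[Y]$ and $\operatorname{Cov}[X] = \mathbb{E}[Y]\,\Sigma + \operatorname{Var}[Y]\,\gamma\gamma^\top$. The skewness vector γ therefore shifts the mean and also adds variance along its own direction.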

### Joint distributions

The `Joint*` classes (e.g. `JointGeneralizedHyperbolic`) model the joint density $f(x, y)$ of the observation $X$ and the latent mixing variable $Y$. They are exponential families and are used internally in the E-step of EM.
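The following derivation sketch shows why the joint is an exponential family. It assumes the standard construction $X \mid Y = y \sim \mathcal{N}(\mu + \gamma y,\, y\Sigma)$ with $Y \sim \mathrm{GIG}(p, a, b)$ and density $\propto y^{p-1} e^{-(ay + b/y)/2}$. Expanding $\log f(x, y) = \log \mathcal{N}(x;\, \mu + \gamma y,\, y\Sigma) + \log f_{\mathrm{GIG}}(y)$ gives an expression that is linear in a fixed set of statistics of $(x, y)$. normix's ordering and scaling of the natural parameters may differ.

$$
\begin{aligned}
\log f(x, y) ={}& \left(p - 1 - \tfrac{d}{2}\right) \log y
 - \tfrac{1}{2}\left(b + \mu^\top \Sigma^{-1} \mu\right) \frac{1}{y}
 - \tfrac{1}{2}\left(a + \gamma^\top \Sigma^{-1} \gamma\right) y \\
&+ \left(\Sigma^{-1}\gamma\right)^\top x
 + \left(\Sigma^{-1}\mu\right)^\top \frac{x}{y}
 - \tfrac{1}{2} \operatorname{tr}\!\left(\Sigma^{-1} \frac{x x^\top}{y}\right)
 - \psi,
\end{aligned}
$$

Here $\psi$ collects every term that does not depend on $(x, y)$. The sufficient statistic is therefore $t(x, y) = \left(\log y,\; 1/y,\; y,\; x,\; x/y,\; x x^\top / y\right)$. Because $x$ is observed, the E-step only needs $\mathbb{E}[\log Y \mid x]$, $\mathbb{E}[1/Y \mid x]$ and $\mathbb{E}[Y \mid x]$ for each observation.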

## Exponential Family API

All univariate and joint distributions subclass `ExponentialFamily(eqx.Module)`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import Gamma

X = jnp.array([1.0, 1.5, 2.0, 2.5])
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))

# Log-density (single observation)
dist.log_prob(jnp.array(1.5))

# Three parametrizations
theta = dist.natural_params()       # natural parameters θ
eta   = dist.expectation_params()   # expectation parameters η = E[t(X)]
I     = dist.fisher_information()   # Fisher information I(θ) = ∇²ψ(θ)

# Constructors
dist2 = Gamma.from_natural(theta)
dist3 = Gamma.from_expectation(eta)
dist4 = Gamma.fit_mle(X)           # η̂ = mean t(xᵢ)
```
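The identities in the comments above, $\eta = \nabla\psi(\theta)$ and $I(\theta) = \nabla^2\psi(\theta)$, can be checked with plain JAX autodiff. The sketch below assumes sufficient statistics $t(x) = (\log x, x)$ and natural parameters $\theta = (\alpha - 1, -\beta)$ for a Gamma with shape $\alpha$ and rate $\beta$. normix's internal ordering and scaling may differ, and `gamma_log_partition` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

jax.config.update("jax_enable_x64", True)


def gamma_log_partition(theta):
    """psi(theta) for Gamma(alpha, rate beta), t(x) = (log x, x), theta = (alpha - 1, -beta)."""
    alpha = theta[0] + 1.0
    beta = -theta[1]
    return gammaln(alpha) - alpha * jnp.log(beta)


theta = jnp.array([2.0 - 1.0, -1.0])             # alpha = 2, beta = 1
# Expectation parameters: (E[log X], E[X]) = (digamma(alpha) - log(beta), alpha / beta)
eta = jax.grad(gamma_log_partition)(theta)
# Fisher information = Cov[t(X)]; e.g. fisher[1, 1] = Var[X] = alpha / beta**2
fisher = jax.hessian(gamma_log_partition)(theta)

# MLE in expectation coordinates: eta_hat is the sample mean of t(x_i)
X = jnp.array([1.0, 1.5, 2.0, 2.5])
eta_hat = jnp.stack([jnp.log(X).mean(), X.mean()])
```

Mapping `eta_hat` back to θ requires solving $\nabla\psi(\theta) = \hat\eta$. This inversion step is what a `from_expectation`-style constructor has to perform.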

## EM Algorithm

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

d = 3
X = jax.random.normal(jax.random.PRNGKey(0), (1000, d))  # placeholder: replace with your (n, d) data array

# Initialise from classical parameters
model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(d), gamma=jnp.zeros(d), sigma=jnp.eye(d),
    p=-0.5, a=2.0, b=1.0,
)

# Select the CPU backends for the E- and M-steps (the fast, non-JIT path)
fitter = BatchEMFitter(max_iter=200, tol=1e-6,
                       e_step_backend='cpu', m_step_backend='cpu')
result = fitter.fit(model, X)
fitted = result.model
```
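EM is tractable because the latent variable stays in the GIG family after conditioning. Assume the standard construction $X \mid Y = y \sim \mathcal{N}(\mu + \gamma y,\, y\Sigma)$ and $Y \sim \mathrm{GIG}(p, a, b)$ with density $\propto y^{p-1} e^{-(ay + b/y)/2}$. Then $Y \mid X = x \sim \mathrm{GIG}\big(p - d/2,\; a + \gamma^\top\Sigma^{-1}\gamma,\; b + (x-\mu)^\top\Sigma^{-1}(x-\mu)\big)$. The sketch below computes these E-step quantities. `gh_posterior_gig` and `gig_moments` are hypothetical helpers, not normix functions, and SciPy's `kve` stands in for normix's `log_kv`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import solve_triangular
from scipy.special import kve

jax.config.update("jax_enable_x64", True)


def gh_posterior_gig(x, mu, gamma, sigma, p, a, b):
    """Return (p', a', b') of Y | X = x for a GH normal variance-mean mixture."""
    # For clarity Sigma is factored here; a real E-step would factor it once
    L = jnp.linalg.cholesky(sigma)
    # Whiten: z = L^{-1}(x - mu), g = L^{-1} gamma, so z @ z = (x-mu)^T Sigma^{-1} (x-mu)
    z = solve_triangular(L, x - mu, lower=True)
    g = solve_triangular(L, gamma, lower=True)
    d = x.shape[-1]
    return p - d / 2, a + g @ g, b + z @ z


def gig_moments(p, a, b):
    """Return E[Y] and E[1/Y] for Y ~ GIG(p, a, b) via Bessel-K ratios."""
    p, a, b = float(p), float(a), float(b)
    omega = np.sqrt(a * b)
    # kve(v, z) = K_v(z) * exp(z); the exp(z) factor cancels in each ratio
    k_p = kve(p, omega)
    e_y = np.sqrt(b / a) * kve(p + 1, omega) / k_p
    e_inv_y = np.sqrt(a / b) * kve(p - 1, omega) / k_p
    return e_y, e_inv_y


# One observation, d = 1, with the prior used above (p = -0.5, a = 2, b = 1)
p_post, a_post, b_post = gh_posterior_gig(
    jnp.array([2.0]), jnp.zeros(1), jnp.zeros(1), jnp.eye(1), -0.5, 2.0, 1.0
)  # p' = -1.0, a' = 2.0, b' = 5.0
e_y, e_inv_y = gig_moments(p_post, a_post, b_post)
```

`gh_posterior_gig` uses only JAX operations, so it can be mapped over rows with `jax.vmap`. `gig_moments` calls SciPy and is not JIT-able, which mirrors the CPU/JAX split described for `log_kv`.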

## Bessel Functions

```python
import jax

jax.config.update("jax_enable_x64", True)

from normix import log_kv        # or: from normix.utils.bessel import log_kv

# JIT-able, differentiable (backend='jax', default)
log_kv(v=0.5, z=2.0)

# Faster CPU path, used in the EM hot path (not JIT-able)
log_kv(v=0.5, z=2.0, backend='cpu')
```
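Two facts motivate a dedicated log-space Bessel function. First, $K_v(z)$ decays like $e^{-z}$, so it underflows in float64 for large $z$. Second, a differentiable version needs $\partial_z \log K_v(z)$, which follows from the recurrence $K_v'(z) = -K_{v-1}(z) - (v/z)\,K_v(z)$. Below is a minimal NumPy/SciPy sketch of both, built on SciPy's exponentially scaled `kve`. The helper names are hypothetical, and this is not normix's `log_kv` or its custom JVP.

```python
import numpy as np
from scipy.special import kv, kve


def log_kv_scaled(v, z):
    """log K_v(z) computed as log(kve(v, z)) - z, since kve(v, z) = K_v(z) * exp(z)."""
    return np.log(kve(v, z)) - z


def dlog_kv_dz(v, z):
    """d/dz log K_v(z) = -K_{v-1}(z) / K_v(z) - v / z; exp(z) scaling cancels in the ratio."""
    return -kve(v - 1, z) / kve(v, z) - v / z


z = 800.0
print(kv(0.5, z))             # underflows to 0.0 in float64
print(log_kv_scaled(0.5, z))  # finite; for v = 1/2 equals 0.5 * log(pi / (2 z)) - z
```

The derivative identity is the kind of rule a custom JVP can supply, so autodiff never has to differentiate through the Bessel evaluation itself.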

## Package Layout

```
normix/
├── exponential_family.py         # ExponentialFamily base class
├── distributions/                # All distribution implementations
│   ├── gamma.py
│   ├── inverse_gamma.py
│   ├── inverse_gaussian.py
│   ├── generalized_inverse_gaussian.py
│   ├── normal.py
│   ├── variance_gamma.py
│   ├── normal_inverse_gamma.py
│   ├── normal_inverse_gaussian.py
│   └── generalized_hyperbolic.py
├── mixtures/                     # Joint and marginal base classes
├── fitting/em.py                 # BatchEMFitter, EMResult
└── utils/
    ├── bessel.py                 # log_kv with custom JVP
    ├── constants.py              # Shared numerical constants
    ├── plotting.py               # Notebook helpers
    └── validation.py             # EM validation helpers
```

## Development

```bash
uv run pytest tests/              # run tests
uv run jupyter lab                # notebooks
make -C docs html                 # build docs
```

## References

- Barndorff-Nielsen, O. E. (1977). Exponentially decreasing distributions for the logarithm of particle size.
- Eberlein, E., & Keller, U. (1995). Hyperbolic distributions in finance.

## License

MIT — see [LICENSE](LICENSE).
