Metadata-Version: 2.4
Name: metalog-jax
Version: 1.0.1
Summary: A JAX implementation of the Metalog distribution for flexible probability modeling.
Author: Travis Jefferies
Maintainer: Travis Jefferies
License-Expression: MIT
Keywords: jax,metalog,probability,distribution,quantile,statistics
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Typing :: Typed
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: chex>=0.1.91
Requires-Dist: flax>=0.12.0
Requires-Dist: jax>=0.8.0
Requires-Dist: plotly>=6.4.0
Dynamic: license-file

<p align="center">
  <img src="https://raw.githubusercontent.com/tjefferies/metalog_jax/main/docs/source/_static/logo.svg" alt="metalog-jax" width="400">
</p>

<h1 align="center">metalog-jax</h1>

<p align="center">
  <strong>GPU-accelerated metalog distributions for modern probabilistic modeling</strong>
</p>

<p align="center">
  <a href="https://pypi.org/project/metalog-jax/"><img src="https://img.shields.io/pypi/v/metalog-jax?style=flat-square&color=blue" alt="PyPI"></a>
  <a href="https://pypi.org/project/metalog-jax/"><img src="https://img.shields.io/pypi/pyversions/metalog-jax?style=flat-square" alt="Python"></a>
  <a href="https://github.com/astral-sh/uv"><img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json&style=flat-square" alt="uv"></a>
  <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-green.svg?style=flat-square" alt="License"></a>
</p>

<p align="center">
  <a href="https://github.com/tjefferies/metalog_jax/actions/workflows/quality-gate.yml"><img src="https://github.com/tjefferies/metalog_jax/actions/workflows/quality-gate.yml/badge.svg" alt="CI"></a>
  <a href="https://codecov.io/gh/tjefferies/metalog_jax"><img src="https://codecov.io/gh/tjefferies/metalog_jax/branch/main/graph/badge.svg" alt="Coverage"></a>
  <a href="https://tjefferies.github.io/metalog_jax/"><img src="https://img.shields.io/badge/docs-GitHub%20Pages-blue?style=flat-square" alt="Documentation"></a>
</p>

<p align="center">
  <a href="https://github.com/astral-sh/ruff"><img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json&style=flat-square" alt="Ruff"></a>
  <a href="https://github.com/astral-sh/ty"><img src="https://img.shields.io/badge/type%20checked-ty-blue?style=flat-square" alt="ty"></a>
  <a href="https://radon.readthedocs.io/"><img src="https://img.shields.io/badge/maintainability-A-brightgreen?style=flat-square" alt="Maintainability"></a>
</p>

<p align="center">
  <a href="https://semgrep.dev/"><img src="https://img.shields.io/badge/SAST-Semgrep-purple?style=flat-square" alt="Semgrep"></a>
  <a href="https://bandit.readthedocs.io/"><img src="https://img.shields.io/badge/security-Bandit-yellow?style=flat-square" alt="Bandit"></a>
  <a href="https://github.com/anchore/grype"><img src="https://img.shields.io/badge/vulnerabilities-Grype-orange?style=flat-square" alt="Grype"></a>
  <a href="https://cyclonedx.org/"><img src="https://img.shields.io/badge/SBOM-CycloneDX-lightgrey?style=flat-square" alt="SBOM"></a>
</p>

<p align="center">
  <a href="#installation">Installation</a> &bull;
  <a href="#quick-start">Quick Start</a> &bull;
  <a href="#features">Features</a> &bull;
  <a href="#documentation">Documentation</a> &bull;
  <a href="#citation">Citation</a>
</p>

---

## The Problem

Traditional probability distributions often fail to model real-world data accurately. You find yourself:

- Trying multiple distributions, hoping one fits
- Truncating distributions to enforce bounds
- Using mixture models that are hard to interpret
- Writing custom likelihood functions for edge cases

**What if one distribution family could fit virtually any continuous data?**

## The Solution

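**metalog-jax** implements the [Metalog distribution](http://metalogdistributions.com/), a flexible family of continuous distributions introduced by Tom Keelin (2016). A metalog is defined directly by its quantile function: a linear combination of simple basis terms in the cumulative probability, with coefficients estimated from data by linear regression (OLS or LASSO in this package). This semi-parametric form lets metalogs represent virtually any continuous probability distribution, with optional lower and/or upper bounds.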

```python
import jax.numpy as jnp
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
from metalog_jax.utils import DEFAULT_Y

# Your data
data = jnp.array([2.1, 3.5, 4.2, 5.8, 6.1, 7.3, 8.9, 12.4, 15.2, 18.7])

# Create validated input data
input_data = MetalogInputData.from_values(data, DEFAULT_Y, precomputed_quantiles=False)

# Configure metalog parameters
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Fit the metalog distribution
metalog = fit(input_data, params)

# Use it like any scipy distribution
median = metalog.ppf(jnp.array([0.5]))   # Quantile function
density = metalog.pdf(jnp.array([0.5]))  # Probability density
cumulative = metalog.cdf(10.0)           # CDF at x=10
```
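Under the hood, a metalog is defined by its quantile function. For the unbounded case, Keelin (2016) writes it as M(y) = a_1 + a_2 L + a_3 (y - 0.5) L + a_4 (y - 0.5) + a_5 (y - 0.5)^2 + ..., where y is the cumulative probability and L = ln(y / (1 - y)); the coefficients a_j are what `fit` estimates. The pure-Python sketch below evaluates that basis directly. It illustrates the math only, is not metalog-jax's implementation, and the helper names are hypothetical.

```python
import math


# Hypothetical helpers for illustration; not part of metalog-jax.
def metalog_basis(y: float, num_terms: int) -> list[float]:
    """Return the first `num_terms` unbounded metalog basis functions at probability y.

    Term order follows Keelin (2016): 1, L, (y-0.5)L, (y-0.5), (y-0.5)^2,
    (y-0.5)^2 L, (y-0.5)^3, ... where L = ln(y / (1 - y)).
    """
    if not 0.0 < y < 1.0:
        raise ValueError("y must lie strictly between 0 and 1")
    logit = math.log(y / (1.0 - y))
    c = y - 0.5
    terms = []
    for j in range(1, num_terms + 1):
        if j == 1:
            terms.append(1.0)
        elif j == 2:
            terms.append(logit)
        elif j == 3:
            terms.append(c * logit)
        elif j == 4:
            terms.append(c)
        elif j % 2 == 1:  # odd j >= 5: (y - 0.5)^((j - 1) / 2)
            terms.append(c ** ((j - 1) // 2))
        else:  # even j >= 6: (y - 0.5)^(j / 2 - 1) * L
            terms.append(c ** (j // 2 - 1) * logit)
    return terms


def metalog_quantile(y: float, coeffs: list[float]) -> float:
    """Evaluate the unbounded metalog quantile function M(y) = sum_j a_j * basis_j(y)."""
    return sum(a * b for a, b in zip(coeffs, metalog_basis(y, len(coeffs))))


# With coefficients (0, 1), M(y) = ln(y / (1 - y)): the standard logistic quantile function.
print(metalog_quantile(0.75, [0.0, 1.0]))  # ln(3)
```

Because M(y) is linear in the coefficients, fitting reduces to a regression of observed quantiles on these basis columns, which is why OLS and LASSO apply directly.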

## Why metalog-jax?

| Feature | metalog-jax | [pymetalog](https://github.com/tjefferies/pymetalog) |
|---------|:-----------:|:-----------:|
| GPU acceleration | Yes | No |
| Automatic differentiation | Yes | No |
| JIT compilation | Yes | No |
| Bounded distribution support | Yes | Yes |
| Multiple regression methods | Yes (OLS, LASSO) | OLS only |
| Hyperparameter grid search | Yes (vectorized) | No |
| Serialization (save/load) | Yes | No |
| Active development | Yes | No |

## Installation

### Using pip

```bash
pip install metalog-jax
```

### Using uv (recommended)

```bash
uv add metalog-jax
```

### From source

```bash
git clone https://github.com/tjefferies/metalog_jax.git
cd metalog_jax
make install    # Install all dependencies
```

### Requirements

- Python >= 3.11
- JAX >= 0.8.0
- Flax >= 0.12.0
- Chex >= 0.1.91
- Plotly >= 6.4.0

## Quick Start

### Fitting a Distribution

```python
import jax.numpy as jnp
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
from metalog_jax.utils import DEFAULT_Y

# Sample data (e.g., response times in milliseconds)
response_times = jnp.array([
    45, 52, 58, 61, 67, 72, 78, 85, 93, 102,
    115, 128, 145, 167, 195, 238, 312, 456, 623, 892
])

# Create validated input data
data = MetalogInputData.from_values(response_times, DEFAULT_Y, precomputed_quantiles=False)

# Configure the metalog (bounded below by 0)
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Fit the distribution
metalog = fit(data, params)

# Analyze your distribution
print(f"Median response time: {float(metalog.ppf(jnp.array([0.5]))[0]):.1f} ms")
print(f"95th percentile: {float(metalog.ppf(jnp.array([0.95]))[0]):.1f} ms")
print(f"99th percentile: {float(metalog.ppf(jnp.array([0.99]))[0]):.1f} ms")
```

### Working with Bounded Data

Metalogs natively support four boundedness types, selected with `MetalogBoundedness`. Bounds that a type does not use are ignored:

```python
from metalog_jax.base import MetalogParameters, MetalogBoundedness, MetalogFitMethod

# Unbounded: support on (-inf, +inf)
# Example: temperature anomalies, asset returns
params = MetalogParameters(
    boundedness=MetalogBoundedness.UNBOUNDED,
    lower_bound=0.0,  # ignored
    upper_bound=0.0,  # ignored
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Lower-bounded: support on (lower_bound, +inf)
# Example: response times, prices, distances
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,  # ignored
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Upper-bounded: support on (-inf, upper_bound)
# Example: time until deadline, remaining capacity
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_UPPER_BOUND,
    lower_bound=0.0,  # ignored
    upper_bound=100.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Fully bounded: support on (lower_bound, upper_bound)
# Example: percentages, probabilities, test scores
params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=100.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)
```
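The four types correspond to Keelin's (2016) transforms of the unbounded quantile function M(y): the lower-bounded (log) metalog is b_l + exp(M(y)), the upper-bounded one is b_u - exp(-M(y)), and the fully bounded (logit) metalog is (b_l + b_u exp(M(y))) / (1 + exp(M(y))). A minimal sketch of those maps follows; the function name is hypothetical and plain strings stand in for the `MetalogBoundedness` members.

```python
import math


# Hypothetical helper for illustration; `mode` strings stand in for MetalogBoundedness.
def apply_bounds(m: float, mode: str, lower: float = 0.0, upper: float = 0.0) -> float:
    """Map an unbounded metalog quantile value m = M(y) onto the requested support."""
    if mode == "unbounded":  # support (-inf, +inf)
        return m
    if mode == "lower":  # support (lower, +inf)
        return lower + math.exp(m)
    if mode == "upper":  # support (-inf, upper)
        return upper - math.exp(-m)
    if mode == "bounded":  # support (lower, upper)
        e = math.exp(m)
        return (lower + upper * e) / (1.0 + e)
    raise ValueError(f"unknown mode: {mode}")


# At m = 0, the bounded map lands exactly at the midpoint of the support.
print(apply_bounds(0.0, "bounded", lower=0.0, upper=100.0))  # 50.0
```

Since each map is strictly increasing, quantiles of the bounded distribution are just the transformed quantiles of the unbounded one.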

### Regularized Fitting

For noisy data or when using many terms, LASSO (L1) regularization can improve stability by shrinking small coefficients toward zero:

```python
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
from metalog_jax.regression import LassoParameters

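# Assumes `data` is the MetalogInputData built in "Fitting a Distribution" above
# LASSO (L1 regularization for sparse coefficients)
lasso_params = LassoParameters(
    lam=0.01,              # L1 regularization strength
    learning_rate=1e-3,    # optimizer step size
    num_iters=1000,        # maximum number of iterations
    tol=1e-6,              # convergence tolerance
    momentum=0.9,          # momentum coefficient for the iterative solver
)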
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,
    method=MetalogFitMethod.Lasso,
    num_terms=6,
)
metalog = fit(data, params, regression_hyperparams=lasso_params)
```
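The `LassoParameters` fields (`learning_rate`, `num_iters`, `tol`, `momentum`) suggest an iterative proximal-gradient solver in the spirit of FISTA (Beck & Teboulle, 2009, listed under References); that is an assumption about the internals. The sketch below shows the core idea with plain ISTA, deliberately without momentum, on the objective (1/2n)||Xb - y||^2 + lam * ||b||_1: take a gradient step on the squared error, then apply soft-thresholding, the proximal operator of the L1 penalty. Function names are hypothetical.

```python
# Hypothetical helpers for illustration; plain ISTA, not metalog-jax's solver.
def soft_threshold(z: float, t: float) -> float:
    """Proximal operator of t * |z|: shrink z toward zero by t, clipping at zero."""
    if z > t:
        return z - t
    if z < -t:
        return z + t
    return 0.0


def lasso_ista(X, y, lam, learning_rate, num_iters=500, tol=1e-10):
    """Minimize (1/2n)||X b - y||^2 + lam * ||b||_1 with plain ISTA (no momentum)."""
    n, p = len(X), len(X[0])
    b = [0.0] * p
    for _ in range(num_iters):
        # Residuals r = X b - y
        r = [sum(X[i][j] * b[j] for j in range(p)) - y[i] for i in range(n)]
        # Gradient of the squared-error term: (1/n) X^T r
        grad = [sum(X[i][j] * r[i] for i in range(n)) / n for j in range(p)]
        # Gradient step followed by the L1 proximal (soft-thresholding) step
        new_b = [soft_threshold(b[j] - learning_rate * grad[j], learning_rate * lam) for j in range(p)]
        converged = max(abs(nb - ob) for nb, ob in zip(new_b, b)) < tol
        b = new_b
        if converged:
            break
    return b


# Identity design: the small second coefficient is shrunk exactly to zero.
print(lasso_ista([[1.0, 0.0], [0.0, 1.0]], [3.0, 0.2], lam=0.5, learning_rate=1.0))
```

The exact zeros produced by soft-thresholding are what make LASSO coefficients sparse, unlike ridge-style shrinkage.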

### SPT Metalog (3-Term Analytical Fitting)

For rapid approximation with minimal data, use the Symmetric Percentile Triplet (SPT) method. An SPT metalog has three terms whose coefficients are computed analytically from just three quantiles, and the fit **validates feasibility upfront**: it checks that the resulting distribution has a valid (non-negative) PDF before returning. This fail-fast behavior keeps infeasible fits from causing errors downstream.

```python
import jax.numpy as jnp
from metalog_jax.base import MetalogBoundedness, SPTMetalogParameters
from metalog_jax.metalog import fit_spt_metalog

# Sample data (e.g., response times in milliseconds)
response_times = jnp.array([
    45, 52, 58, 61, 67, 72, 78, 85, 93, 102,
    115, 128, 145, 167, 195, 238, 312, 456, 623, 892
])

# Needs only 3 quantiles: at probabilities alpha, 0.5 (median), and 1 - alpha
# Use STRICTLY_LOWER_BOUND for non-negative data like response times
spt_params = SPTMetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    alpha=0.1,  # Uses 10th, 50th, and 90th percentiles
    lower_bound=0.0,
    upper_bound=0.0,
)

spt_metalog = fit_spt_metalog(response_times, spt_params)
print(f"SPT Median: {float(spt_metalog.ppf(jnp.array([0.5]))[0]):.1f}")
```
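For the unbounded case the three SPT coefficients have a closed form (Keelin, 2016). Writing k = ln((1 - alpha) / alpha) and q_lo, q_50, q_hi for the quantiles at alpha, 0.5, and 1 - alpha: a1 = q_50, a2 = (q_hi - q_lo) / (2k), and a3 = (q_hi + q_lo - 2 q_50) / ((1 - 2 alpha) k). A 3-term metalog is feasible (strictly increasing quantile function, hence positive PDF) exactly when a2 > 0 and |a3| / a2 < 1.66711. The sketch below implements those formulas and that check; the function names are hypothetical, and metalog-jax's own feasibility test may be formulated differently. For a lower bound b_l, it applies the same formulas to ln(q - b_l), matching the log-metalog transform.

```python
import math


# Hypothetical helpers for illustration; not part of metalog-jax.
def spt_coefficients(q_lo, q_50, q_hi, alpha, lower_bound=None):
    """Closed-form 3-term SPT metalog coefficients with a feasibility check.

    q_lo, q_50, q_hi are the quantiles at probabilities alpha, 0.5, 1 - alpha.
    If lower_bound is given, work on ln(q - lower_bound) (log-metalog).
    """
    if not 0.0 < alpha < 0.5:
        raise ValueError("alpha must lie in (0, 0.5)")
    if lower_bound is not None:
        q_lo, q_50, q_hi = (math.log(q - lower_bound) for q in (q_lo, q_50, q_hi))
    k = math.log((1.0 - alpha) / alpha)
    a1 = q_50
    a2 = (q_hi - q_lo) / (2.0 * k)
    a3 = (q_hi + q_lo - 2.0 * q_50) / ((1.0 - 2.0 * alpha) * k)
    # A 3-term metalog is feasible iff a2 > 0 and |a3| / a2 < ~1.66711
    if not (a2 > 0.0 and abs(a3) / a2 < 1.66711):
        raise ValueError("infeasible SPT inputs: the metalog PDF would not be positive")
    return a1, a2, a3


def spt_quantile(y, a1, a2, a3):
    """Unbounded 3-term metalog quantile: a1 + a2*L + a3*(y - 0.5)*L, with L = logit(y)."""
    logit = math.log(y / (1.0 - y))
    return a1 + a2 * logit + a3 * (y - 0.5) * logit


a = spt_coefficients(-1.0, 0.0, 1.0, alpha=0.1)
print(spt_quantile(0.9, *a))  # reproduces the input 90th percentile, 1.0, up to rounding
```

A strongly lopsided triplet such as (0.0, 0.1, 10.0) at alpha = 0.1 gives |a3| / a2 of about 2.45, so this sketch rejects it instead of returning a distribution with a negative density.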

## Features

### Full Probability Distribution API

```python
import jax.numpy as jnp

# Assumes `metalog` is a fitted distribution, e.g. from the Quick Start

# Quantile function (inverse CDF)
quantiles = metalog.ppf(jnp.array([0.1, 0.25, 0.5, 0.75, 0.9]))

# Probability density function
pdf_values = metalog.pdf(jnp.array([0.2, 0.4, 0.6, 0.8]))

# Log probability density (numerically stable)
log_pdf = metalog.logpdf(jnp.array([0.2, 0.4, 0.6, 0.8]))

# Cumulative distribution function
cdf_values = metalog.cdf(jnp.array([50.0, 100.0, 150.0]))

# Survival function (1 - CDF)
sf_values = metalog.sf(jnp.array([50.0, 100.0, 150.0]))

# Inverse survival function
isf_values = metalog.isf(jnp.array([0.1, 0.05, 0.01]))

# Summary statistics (properties, not methods)
mean = metalog.mean
variance = metalog.var
std_dev = metalog.std
mode = metalog.mode
median = metalog.median
```

### Random Sampling

Two PRNG backends are supported for random variate generation:

```python
from metalog_jax.base import MetalogRandomVariableParameters
from metalog_jax.utils import JaxUniformDistributionParameters, HDRPRNGParameters

# JAX-based random sampling (standard approach)
rv_params = MetalogRandomVariableParameters(
    prng_params=JaxUniformDistributionParameters(seed=42),
    size=10000,
)
samples = metalog.rvs(rv_params)

# HDR PRNG (Hubbard, 2019) for reproducible Monte Carlo simulations
# Multi-dimensional, counter-based PRNG ideal for parallel simulations
hdr_params = HDRPRNGParameters(
    trial=1,       # Simulation trial/iteration
    variable=0,    # Random variable identifier
    entity=0,      # Entity being simulated
    time=0,        # Time step
    agent=0,       # Agent/actor identifier
)
rv_params = MetalogRandomVariableParameters(
    prng_params=hdr_params,
    size=10000,
)
samples = metalog.rvs(rv_params)
```
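Because a metalog is specified by its quantile function, sampling needs no rejection step: draw u uniformly on (0, 1) and return Q(u), which is inverse transform sampling. That `rvs` works this way with either PRNG backend is an assumption. The stdlib sketch below shows the technique on a 2-term metalog with hypothetical coefficients, using Python's `random.Random` in place of the JAX or HDR generators.

```python
import math
import random


# Hypothetical helper for illustration; not part of metalog-jax.
def sample_metalog_2term(a1: float, a2: float, size: int, seed: int) -> list[float]:
    """Inverse transform sampling from M(y) = a1 + a2 * ln(y / (1 - y))."""
    rng = random.Random(seed)
    samples = []
    while len(samples) < size:
        u = rng.random()  # uniform on [0, 1)
        if u == 0.0:  # the quantile function is undefined at 0, so redraw
            continue
        samples.append(a1 + a2 * math.log(u / (1.0 - u)))
    return samples


draws = sample_metalog_2term(a1=10.0, a2=2.0, size=10_000, seed=42)
# The same seed reproduces the same draws.
assert draws == sample_metalog_2term(a1=10.0, a2=2.0, size=10_000, seed=42)
```

Reproducibility here comes entirely from the uniform stream; any generator that maps the same identifiers to the same uniforms (as a counter-based PRNG does) yields identical metalog samples.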

### Interactive Visualization

```python
from metalog_jax.base import MetalogPlotOptions

# Plot PDF
fig = metalog.plot(MetalogPlotOptions.PDF)
fig.show()

# Plot CDF
fig = metalog.plot(MetalogPlotOptions.CDF)
fig.show()

# Plot Survival Function
fig = metalog.plot(MetalogPlotOptions.SF)
fig.show()
```

### Serialization

```python
from pathlib import Path
from metalog_jax.metalog import Metalog

# Save to JSON
metalog.save(Path("my_distribution.json"))

# Load from JSON
loaded = Metalog.load(Path("my_distribution.json"))

# String serialization
json_str = metalog.dumps()
loaded2 = Metalog.loads(json_str)
assert metalog == loaded == loaded2
```

### Unified Grid Search with `fit_grid`

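The `fit_grid` function provides a unified interface for hyperparameter optimization.
It detects which axes to search from its inputs: stacked (2D) `x`/`y` arrays add a dataset axis, and the optional `l1_penalties` and `num_terms` arguments add penalty and term-count axes.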

#### Grid Search over L1 Penalties

```python
import jax.numpy as jnp
import numpy as np
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.grid_search import fit_grid, find_best_config, extract_metalog
from metalog_jax.utils import DEFAULT_Y

# Generate sample data
np.random.seed(42)
samples = jnp.array(np.random.beta(2, 5, 100))

# Create validated input data
data = MetalogInputData.from_values(samples, DEFAULT_Y, precomputed_quantiles=False)

params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=1.0,
    method=MetalogFitMethod.Lasso,
    num_terms=5,
)

# Grid search over L1 penalties
l1_penalties = jnp.array([0.0, 0.001, 0.01, 0.1])
result = fit_grid(data.x, data.y, params, l1_penalties=l1_penalties)

# Find the best configuration
best_idx, best_ks = find_best_config(result.ks_dist)
print(f"Best L1 penalty: {float(l1_penalties[int(best_idx)]):.4f}")
print(f"Best KS distance: {float(best_ks):.4f}")

# Extract the best metalog for use
best_metalog = extract_metalog(result, int(best_idx))
print(f"Median: {float(best_metalog.ppf(jnp.array([0.5]))[0]):.4f}")
```
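Conceptually, `find_best_config` picks the grid cell with the smallest KS distance and returns its index together with that distance, which is how the examples index back into `l1_penalties` and `num_terms_list`. A hypothetical pure-Python analogue over nested lists follows; unlike the 1D example above, it always returns an index tuple, and the KS values are made-up inputs.

```python
from itertools import product


# Hypothetical stand-in for find_best_config; assumes a rectangular nested-list grid.
def find_best_index(grid):
    """Return (index_tuple, value) of the smallest entry in a nested list of scores."""
    # Infer the grid shape by walking down the first element of each level.
    shape = []
    level = grid
    while isinstance(level, list):
        shape.append(len(level))
        level = level[0]

    def get(idx):
        value = grid
        for i in idx:
            value = value[i]
        return value

    best = min(product(*(range(n) for n in shape)), key=get)
    return best, get(best)


# Hypothetical KS distances for a 3 x 3 (L1 penalty x num_terms) grid
ks = [[0.09, 0.04, 0.06],
      [0.05, 0.03, 0.07],
      [0.08, 0.06, 0.10]]
print(find_best_index(ks))  # ((1, 1), 0.03)
```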

#### 2D Grid Search (L1 x num_terms)

```python
import jax.numpy as jnp
import numpy as np
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.grid_search import fit_grid, find_best_config, extract_metalog
from metalog_jax.utils import DEFAULT_Y

# Generate sample data
np.random.seed(42)
samples = jnp.array(np.random.beta(2, 5, 100))

# Create validated input data
data = MetalogInputData.from_values(samples, DEFAULT_Y, precomputed_quantiles=False)

params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=1.0,
    method=MetalogFitMethod.Lasso,
    num_terms=5,
)

# 2D grid search over L1 penalties and num_terms
l1_penalties = jnp.array([0.0, 0.01, 0.1])
num_terms_list = [5, 7, 9]

result = fit_grid(
    data.x, data.y, params,
    l1_penalties=l1_penalties,
    num_terms=num_terms_list
)

# Result shape is (len(l1_penalties), len(num_terms_list))
print(f"Grid shape: {result.ks_dist.shape}")  # (3, 3)

# Find the best configuration
best_idx, best_ks = find_best_config(result.ks_dist)
best_l1_idx, best_terms_idx = int(best_idx[0]), int(best_idx[1])
print(f"Best L1 penalty: {float(l1_penalties[best_l1_idx]):.4f}")
print(f"Best num_terms: {num_terms_list[best_terms_idx]}")
print(f"Best KS distance: {float(best_ks):.4f}")

# Extract the best metalog for use
best_metalog = extract_metalog(result, best_l1_idx, best_terms_idx)
print(f"Median: {float(best_metalog.ppf(jnp.array([0.5]))[0]):.4f}")
```

#### Batch Multiple Datasets with Full 3D Grid

```python
import jax.numpy as jnp
import numpy as np
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.grid_search import fit_grid, find_best_config, extract_metalog
from metalog_jax.utils import DEFAULT_Y

# Create 3 different datasets with different distributions
np.random.seed(42)
samples1 = jnp.array(np.random.beta(5, 2, 100))  # left-skewed
samples2 = jnp.array(np.random.beta(2, 5, 100))  # right-skewed
samples3 = jnp.array(np.random.beta(2, 2, 100))  # symmetric

data1 = MetalogInputData.from_values(samples1, DEFAULT_Y, precomputed_quantiles=False)
data2 = MetalogInputData.from_values(samples2, DEFAULT_Y, precomputed_quantiles=False)
data3 = MetalogInputData.from_values(samples3, DEFAULT_Y, precomputed_quantiles=False)

# Stack datasets for batch processing
batched_x = jnp.stack([data1.x, data2.x, data3.x])
batched_y = jnp.stack([data1.y, data2.y, data3.y])

params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=1.0,
    method=MetalogFitMethod.Lasso,
    num_terms=5,
)

l1_penalties = jnp.array([0.0, 0.01])
num_terms_list = [5, 7]

result = fit_grid(
    batched_x, batched_y, params,
    l1_penalties=l1_penalties,
    num_terms=num_terms_list
)

# Result shape is (n_datasets, len(l1_penalties), len(num_terms_list))
print(f"Grid shape: {result.ks_dist.shape}")  # (3, 2, 2)

# Find best configuration for each dataset and extract the metalog
dataset_names = ["left-skewed", "right-skewed", "symmetric"]
for d, name in enumerate(dataset_names):
    best_idx, best_ks = find_best_config(result.ks_dist[d])
    best_l1_idx, best_terms_idx = int(best_idx[0]), int(best_idx[1])

    # Extract the best metalog for this dataset
    best_metalog = extract_metalog(result, d, best_l1_idx, best_terms_idx)

    print(f"Dataset '{name}': L1={float(l1_penalties[best_l1_idx]):.4f}, "
          f"terms={num_terms_list[best_terms_idx]}, KS={float(best_ks):.4f}, "
          f"median={float(best_metalog.ppf(jnp.array([0.5]))[0]):.4f}")
```

`fit_grid` handles all 8 (2 x 2 x 2) combinations of these axes automatically:
- Single/batched datasets
- With/without L1 penalty grid
- With/without num_terms grid
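The examples above suggest a simple shape rule: each searched axis contributes one dimension, in the order dataset, L1 penalty, num_terms. The hypothetical helper below encodes that reading of the examples and enumerates the 8 combinations. It is not part of metalog-jax, and the empty shape it reports when no axis is searched is an assumption.

```python
from itertools import product


# Hypothetical helper for illustration; mirrors the shapes shown in the examples above.
def expected_grid_shape(n_datasets=None, n_l1=None, n_terms=None):
    """Shape implied by the README examples: one dimension per searched axis.

    Pass None for an axis that is not searched (single dataset, no L1 grid, no num_terms grid).
    """
    return tuple(n for n in (n_datasets, n_l1, n_terms) if n is not None)


# Enumerate all 2 x 2 x 2 = 8 combinations with sizes from the 3D example above.
for batched, l1_grid, terms_grid in product([False, True], repeat=3):
    shape = expected_grid_shape(
        n_datasets=3 if batched else None,
        n_l1=2 if l1_grid else None,
        n_terms=2 if terms_grid else None,
    )
    print(f"batched={batched!s:5} l1_grid={l1_grid!s:5} terms_grid={terms_grid!s:5} -> {shape}")
```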

### JAX Transformations

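Metalog methods compose with JAX transformations such as `jax.grad`. For example, differentiating the quantile function at a probability gives the quantile density: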

```python
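import jax
import jax.numpy as jnp

# Automatic differentiation of the quantile function of the fitted `metalog`
def quantile_at(prob):
    # ppf expects an array of probabilities; index out a scalar so jax.grad applies
    return metalog.ppf(jnp.array([prob]))[0]

gradient_fn = jax.grad(quantile_at)
gradient = gradient_fn(0.5)  # dQ/dp at the median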
```
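The derivative of a quantile function is the quantile density, and it is the reciprocal of the PDF at x = Q(p): f(Q(p)) = 1 / Q'(p). For the 3-term metalog Q(y) = a1 + a2 L + a3 (y - 0.5) L with L = ln(y / (1 - y)), the derivative has a closed form, since dL/dy = 1 / (y (1 - y)). This stdlib sketch, with hypothetical coefficients and function names, checks the closed form against a central finite difference, the same kind of cross-check one could run against `jax.grad`.

```python
import math


# Hypothetical helpers for illustration; not part of metalog-jax.
def q3(y, a1, a2, a3):
    """3-term metalog quantile function."""
    logit = math.log(y / (1.0 - y))
    return a1 + a2 * logit + a3 * (y - 0.5) * logit


def q3_density(y, a2, a3):
    """Analytic quantile density dQ/dy, using dL/dy = 1 / (y (1 - y))."""
    logit = math.log(y / (1.0 - y))
    dlogit = 1.0 / (y * (1.0 - y))
    return a2 * dlogit + a3 * (logit + (y - 0.5) * dlogit)


a1, a2, a3 = 10.0, 2.0, 0.5  # hypothetical, feasible: |a3| / a2 = 0.25 < 1.66711
y, h = 0.3, 1e-6
finite_diff = (q3(y + h, a1, a2, a3) - q3(y - h, a1, a2, a3)) / (2.0 * h)
print(q3_density(y, a2, a3), finite_diff)
print("pdf at x = Q(0.3):", 1.0 / q3_density(y, a2, a3))
```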

## Choosing the Number of Terms

| Terms | Use Case | Data Requirements |
|-------|----------|-------------------|
| 2 | Simple symmetric distributions | 6+ observations |
| 3-4 | Moderate skewness | 9-12+ observations |
| 5-6 | Heavy tails, asymmetry | 15-18+ observations |
| 7-9 | Complex multimodal shapes | 21-27+ observations |
| 10+ | Highly irregular distributions | 30+ observations |

**Rule of thumb**: Use at least 3x observations per term.
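The rule of thumb translates into a one-line cap on `num_terms`, which is handy when candidate term counts are chosen programmatically, for example the `num_terms` list passed to `fit_grid`. The helper below is hypothetical and not part of metalog-jax; its minimum of 2 terms follows the first row of the table.

```python
# Hypothetical helper for illustration; not part of metalog-jax.
def max_terms_for(n_observations: int, obs_per_term: int = 3, min_terms: int = 2) -> int:
    """Largest num_terms allowed by the 'at least 3 observations per term' rule of thumb."""
    if n_observations < min_terms * obs_per_term:
        raise ValueError(f"need at least {min_terms * obs_per_term} observations for {min_terms} terms")
    return n_observations // obs_per_term


print(max_terms_for(20))   # 6
print(max_terms_for(100))  # 33
```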

## Architecture

```
metalog_jax/
├── base/                    # Core abstractions
│   ├── core.py             # MetalogBase class with distribution methods
│   ├── data.py             # Input data validation and containers
│   ├── enums.py            # MetalogBoundedness, MetalogFitMethod
│   └── parameters.py       # Configuration dataclasses
├── regression/              # Fitting algorithms
│   ├── base.py             # RegressionModel, RegularizedParameters
│   ├── ols.py              # Ordinary Least Squares
│   └── lasso.py            # LASSO (L1 regularization)
├── metalog.py              # Metalog, SPTMetalog, fit, fit_spt_metalog
├── grid_search.py          # Unified fit_grid for hyperparameter optimization
└── utils.py                # HDRPRNG, KS distance, DEFAULT_Y, helpers
```

## Comparison with Standard Distributions

Metalogs can closely approximate a wide range of continuous distributions. The table below shows KS distances from fitting metalogs to several scipy distributions:

| Distribution | 5-term KS Distance | 7-term KS Distance |
|--------------|-------------------|-------------------|
| Normal | < 0.001 | < 0.0001 |
| Log-Normal | < 0.002 | < 0.0005 |
| Gamma | < 0.003 | < 0.001 |
| Beta | < 0.002 | < 0.0005 |
| Weibull | < 0.003 | < 0.001 |
| Chi-Square | < 0.004 | < 0.001 |
| Student's t | < 0.003 | < 0.001 |

*KS Distance: Kolmogorov-Smirnov distance (lower is better)*
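For reference, the KS distance between a sorted sample x_(1) <= ... <= x_(n) and a fitted CDF F is D = max_i max(i/n - F(x_(i)), F(x_(i)) - (i - 1)/n), the largest vertical gap between the empirical and fitted CDFs. A stdlib sketch of that textbook formula follows, with hypothetical function names. The library's own KS distance lives in `utils.py`, and its exact conventions (for example, which points it is evaluated on) may differ.

```python
import math


# Hypothetical helpers for illustration; not part of metalog-jax.
def ks_distance(samples, cdf):
    """Two-sided Kolmogorov-Smirnov statistic between a sample and a continuous CDF."""
    xs = sorted(samples)
    n = len(xs)
    d = 0.0
    for i, x in enumerate(xs, start=1):
        f = cdf(x)
        # The empirical CDF steps from (i - 1)/n to i/n at x; check both sides of the step.
        d = max(d, i / n - f, f - (i - 1) / n)
    return d


def logistic_cdf(x):
    """CDF of the 2-term metalog with a1 = 0, a2 = 1 (the standard logistic)."""
    return 1.0 / (1.0 + math.exp(-x))


print(ks_distance([0.0], logistic_cdf))  # 0.5
```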

## Documentation

- **[Online Documentation](https://tjefferies.github.io/metalog_jax/)** - Full documentation on GitHub Pages
- **[Getting Started Guide](https://tjefferies.github.io/metalog_jax/getting-started.html)** - Installation and basic usage
- **[API Reference](https://tjefferies.github.io/metalog_jax/api/modules.html)** - Complete API documentation

### Tutorials

Interactive notebooks are available in two formats:

**Jupyter Notebooks** (pre-executed, viewable in browser):
- [Basic Usage](https://github.com/tjefferies/metalog_jax/blob/main/docs/source/basic_usage.ipynb) - Core API and distribution methods
- [Grid Search](https://github.com/tjefferies/metalog_jax/blob/main/docs/source/fitting_grids.ipynb) - Hyperparameter optimization with `fit_grid`

**Marimo Notebooks** (interactive, run locally):
- `marimo run examples/basic_usage.py`
- `marimo run examples/fitting_grids.py`

## Contributing

We welcome contributions! Please follow these guidelines to ensure a smooth review process.

### Before You Start

1. **Clone** the repository locally
2. **Create** a feature branch: `git checkout -b feature/amazing-feature`
3. **Install** dependencies: `make install`

### Development Workflow

This project uses a Makefile that mirrors all CI/CD checks. Run `make help` to see all available targets.

**Write and test your changes:**

```bash
make test-quick    # Fast iteration (no coverage)
make test          # Full test suite with coverage
```

**Check code quality before committing:**

```bash
make format        # Auto-format code
make lint          # Check for issues
make typecheck     # Verify types
```

**Build documentation** (if you modified docstrings):

```bash
make docs          # Build once
make docs-live     # Live-reload during development
```

**Before pushing**, run the full quality gate to catch CI failures early:

```bash
make quality-gate
```

This runs the same checks as CI: formatting, linting, type checking, complexity metrics, tests, license compliance, and security scans.

### Submitting Changes

1. **Commit** with a clear message: `git commit -m 'Add amazing feature'`
2. **Push** to your branch: `git push origin feature/amazing-feature`
3. **Open** a Pull Request with a description of your changes

> **Note:** Pull Requests with failing CI will not be reviewed. Run `make quality-gate` locally first.

### Development Setup

```bash
git clone https://github.com/tjefferies/metalog_jax.git
cd metalog_jax
make install    # Install all dependencies
make test       # Run tests with coverage
make docs       # Build documentation
make help       # Show all available Make targets
```

### Code Style

- Follow existing code conventions in the repository
- Use type hints for all function signatures
- Write comprehensive docstrings following Google style
- Keep functions focused and single-purpose
- Prefer immutable data structures (Flax dataclasses)

## Citation

If you use metalog-jax in your research, please cite:

```bibtex
@software{metalog_jax,
  author = {Jefferies, Travis},
  title = {metalog-jax: GPU-accelerated metalog distributions for JAX},
  year = {2026},
  url = {https://github.com/tjefferies/metalog_jax}
}
```

And the original metalog paper:

```bibtex
@article{keelin2016metalog,
  author = {Keelin, Thomas W.},
  title = {The Metalog Distributions},
  journal = {Decision Analysis},
  volume = {13},
  number = {4},
  pages = {243--277},
  year = {2016},
  doi = {10.1287/deca.2016.0338}
}
```

## References

### Metalog Distributions

- Keelin, T. W. (2016). [The Metalog Distributions](https://doi.org/10.1287/deca.2016.0338). *Decision Analysis*, 13(4), 243-277.
- [Metalog Distributions Website](http://metalogdistributions.com/) — Official resource by Tom Keelin

### Regression Methods

- Hastie, T., Tibshirani, R., & Friedman, J. (2009). *The Elements of Statistical Learning: Data Mining, Inference, and Prediction* (2nd ed.). Springer. Chapter 3: Linear Methods for Regression.
- Tibshirani, R. (1996). [Regression Shrinkage and Selection via the Lasso](https://doi.org/10.1111/j.2517-6161.1996.tb02080.x). *Journal of the Royal Statistical Society: Series B (Methodological)*, 58(1), 267-288.

### Optimization Algorithms

- Beck, A., & Teboulle, M. (2009). [A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems](https://doi.org/10.1137/080716542). *SIAM Journal on Imaging Sciences*, 2(1), 183-202.
- Parikh, N., & Boyd, S. (2014). [Proximal Algorithms](https://doi.org/10.1561/2400000003). *Foundations and Trends in Optimization*, 1(3), 127-239.

### Statistical Methods

- Kolmogorov, A. N. (1933). Sulla determinazione empirica di una legge di distribuzione. *Giornale dell'Istituto Italiano degli Attuari*, 4, 83-91.
- Smirnov, N. V. (1948). Table for estimating the goodness of fit of empirical distributions. *Annals of Mathematical Statistics*, 19(2), 279-281.

### Random Number Generation

- Hubbard, D. W. (2019). [A Multi-dimensional, Counter-based Pseudo Random Number Generator as a Standard for Monte Carlo Simulations](https://doi.org/10.1109/WSC40007.2019.9004789). *Proceedings of the Winter Simulation Conference*, IEEE Press, 3064-3073.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

---

<p align="center">
  <sub>Built with JAX by <a href="https://github.com/tjefferies">Travis Jefferies</a></sub>
</p>
