Metadata-Version: 2.4
Name: mlx-fisher
Version: 0.1.0
Summary: GPU-accelerated Fisher Information Matrix computation on Apple Silicon via MLX
Author-email: Sheng-Kai Huang <akai@fawstudio.com>
License-Expression: MIT
Project-URL: Homepage, https://github.com/akaiHuang/mlx-fisher
Project-URL: Repository, https://github.com/akaiHuang/mlx-fisher
Project-URL: Issues, https://github.com/akaiHuang/mlx-fisher/issues
Keywords: fisher-information,mlx,apple-silicon,gpu,natural-gradient,cosmology
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Physics
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: mlx>=0.10.0
Requires-Dist: numpy>=1.24.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Dynamic: license-file

# mlx-fisher

GPU-accelerated Fisher Information Matrix computation on Apple Silicon via [MLX](https://github.com/ml-explore/mlx).

**Author**: Sheng-Kai Huang (akai@fawstudio.com)

## Features

- **Fisher Information Matrix** from log-likelihood, model predictions, or samples
- **CMB C_l Fisher matrix** for cosmological parameter estimation
- **KL divergence** (classical) and **quantum relative entropy** D(rho||sigma)
- **Natural gradient descent** optimizer with online Fisher estimation
- All heavy linear algebra (eigendecomposition, matrix multiplication, matrix logarithm) runs on the Apple GPU

## Installation

```bash
# from a clone of the repository
pip install -e .
```

Requires Python 3.10+ and Apple Silicon (M1/M2/M3/M4).
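
A quick sanity check that MLX is targeting the GPU (`mx.default_device()` is standard MLX API):

```python
import mlx.core as mx

print(mx.default_device())   # expect Device(gpu, 0) on Apple Silicon
```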

## Quick Start

```python
import mlx.core as mx
from mlx_fisher import FisherMatrix, kl_divergence, quantum_relative_entropy

# --- Gaussian Fisher matrix from a model ---
x = mx.linspace(-3.0, 3.0, 1000)

def model(theta):
    return theta[0] * x**2 + theta[1] * x + theta[2]

theta0 = mx.array([1.0, -0.5, 2.0])
sigma = mx.ones((1000,)) * 0.5

F = FisherMatrix.from_model(model, theta0, sigma)
print(F.marginal_errors())   # 1-sigma errors on each parameter

# --- KL divergence ---
p = mx.array([0.4, 0.3, 0.2, 0.1])
q = mx.array([0.25, 0.25, 0.25, 0.25])
print(kl_divergence(p, q))   # D_KL(p || q) ≈ 0.106 nats

# --- Quantum relative entropy ---
d = 4
rho = mx.zeros((d, d)); rho[0, 0] = 1.0       # pure state |0><0|
sigma_dm = mx.eye(d) / d                        # maximally mixed
print(quantum_relative_entropy(rho, sigma_dm))   # = ln(d) = ln(4) ≈ 1.386
```
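
For a Gaussian likelihood with fixed, independent noise `sigma`, the Fisher matrix has the textbook form below, built from the model Jacobian; this should match what `from_model` computes, but check the source for the exact convention:

$$
F_{ij} = \sum_k \frac{1}{\sigma_k^2} \frac{\partial \mu_k}{\partial \theta_i} \frac{\partial \mu_k}{\partial \theta_j}, \qquad \mu(\theta) = \mathrm{model}(\theta)
$$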

## CMB Fisher Matrix

```python
import mlx.core as mx
from mlx_fisher import fisher_matrix_cl

def cl_fn(theta):
    """Map cosmological parameters to C_l power spectrum."""
    # Your Boltzmann solver here (e.g., CLASS wrapper)
    ...

theta_fid = mx.array([0.022, 0.12, 0.06, 0.96, 3.04, 67.4])
F = fisher_matrix_cl(cl_fn, theta_fid, f_sky=0.7, l_min=2, l_max=2500)
print(F.marginal_errors())
```
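
For a single spectrum, the standard cosmic-variance-limited Fisher matrix is shown below; `fisher_matrix_cl` exposes exactly these knobs (`f_sky`, `noise_cl` as $N_\ell$, `l_min`, `l_max`), though the implementation's convention may differ in detail:

$$
F_{ij} = \sum_{\ell=\ell_{\min}}^{\ell_{\max}} \frac{(2\ell+1)\, f_{\mathrm{sky}}}{2\,\left(C_\ell + N_\ell\right)^2} \frac{\partial C_\ell}{\partial \theta_i} \frac{\partial C_\ell}{\partial \theta_j}
$$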

## Natural Gradient Descent

```python
import mlx.core as mx
from mlx_fisher import NaturalGradientOptimizer

# Toy quadratic loss L(theta) = 0.5 * theta^T A theta: grad = A @ theta, Fisher = A.
# Swap in your own functions; fisher_estimator is assumed to receive the parameters.
A = mx.array([[3.0, 0.5], [0.5, 1.0]])
compute_gradient = lambda theta: A @ theta
compute_fisher = lambda theta: A

theta = mx.array([1.0, -2.0])
opt = NaturalGradientOptimizer(lr=1e-2, damping=1e-4)

for step in range(100):
    grad = compute_gradient(theta)
    theta = opt.step(theta, grad, fisher_estimator=compute_fisher)
```

## API Reference

### `FisherMatrix`
- `FisherMatrix.from_model(model_fn, theta, sigma)` -- Gaussian Fisher matrix
- `FisherMatrix.from_loglikelihood(log_lik, theta)` -- from log-likelihood function
- `FisherMatrix.from_samples(log_prob, theta, samples)` -- empirical Fisher
- `.inverse(reg=0.0)` -- covariance matrix (regularised inversion)
- `.marginal_errors(reg=0.0)` -- 1-sigma marginal errors
- `.eigenvalues()` -- eigenvalue spectrum
- `.condition_number()` -- matrix condition number
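
A sketch tying the diagnostics together, reusing the Quick Start model; the interpretive comments are standard Fisher-analysis readings rather than package-specific guarantees:

```python
import mlx.core as mx
from mlx_fisher import FisherMatrix

# Same quadratic model as the Quick Start
x = mx.linspace(-3.0, 3.0, 1000)
model = lambda theta: theta[0] * x**2 + theta[1] * x + theta[2]
theta0 = mx.array([1.0, -0.5, 2.0])
sigma = mx.ones((1000,)) * 0.5

F = FisherMatrix.from_model(model, theta0, sigma)
cov = F.inverse(reg=1e-10)     # parameter covariance, with a small ridge for stability
print(F.marginal_errors())     # 1-sigma errors: sqrt of the covariance diagonal
print(F.eigenvalues())         # small eigenvalues flag weakly constrained directions
print(F.condition_number())    # a large value warns that inversion is ill-conditioned
```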

### `fisher_matrix_cl(cl_fn, theta, f_sky, noise_cl, l_min, l_max)`
CMB power spectrum Fisher matrix with cosmic variance.

### `kl_divergence(p, q)`
Classical KL divergence D_KL(p || q) = sum_i p_i ln(p_i / q_i), in nats.
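
The computation is pure elementwise GPU work; a minimal reference sketch of the definition (the `eps` clamp guarding against `log(0)` is an assumption here, not necessarily how the package handles zeros):

```python
import mlx.core as mx

def kl_reference(p, q, eps=1e-12):
    # D_KL(p || q) = sum_i p_i * ln(p_i / q_i), in nats
    p = mx.maximum(p, eps)   # clamp to avoid log(0)
    q = mx.maximum(q, eps)
    return mx.sum(p * (mx.log(p) - mx.log(q)))
```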

### `quantum_relative_entropy(rho, sigma)`
Quantum relative entropy D(rho || sigma) = Tr[rho(ln rho - ln sigma)].
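
One standard route for Hermitian density matrices is an eigendecomposition-based matrix logarithm. A minimal reference sketch, not the package's actual implementation; it assumes `mx.linalg.eigh`, which is available in recent MLX releases:

```python
import mlx.core as mx

def matrix_log(a, eps=1e-12):
    # ln A = V diag(ln w) V^T for Hermitian A = V diag(w) V^T.
    # Eigenvalues are clamped so pure states (zero eigenvalues) do not yield log(0).
    # (Older MLX versions may require stream=mx.cpu for eigh.)
    w, v = mx.linalg.eigh(a)
    return (v * mx.log(mx.maximum(w, eps))) @ v.T

def qre_reference(rho, sigma):
    # D(rho || sigma) = Tr[rho (ln rho - ln sigma)].
    # Tr[A B] = sum(A * B^T); both factors are symmetric here, so no transpose needed.
    return mx.sum(rho * (matrix_log(rho) - matrix_log(sigma)))
```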

### `NaturalGradientOptimizer(lr, damping, fisher_update_interval, ema_decay)`
Natural gradient descent: `theta_new = theta - lr * F^{-1} @ grad`.
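
With nonzero `damping`, the update presumably includes the usual Tikhonov term (an assumption about how `damping` enters, matching the standard damped natural-gradient form):

$$
\theta \leftarrow \theta - \eta \, (F + \lambda I)^{-1} \nabla_\theta L
$$

where $\eta$ is `lr` and $\lambda$ is `damping`.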

## Benchmarks (M1 Max)

| Operation | Scale | MLX (ms) | NumPy (ms) | Speedup |
|-----------|-------|----------|------------|---------|
| KL divergence | 1M bins | 0.52 | 5.22 | **10x** |
| KL divergence | 10M bins | 2.23 | 51.40 | **23x** |
| Eigendecomposition | 512x512 | 17.25 | 58.07 | **3.4x** |
| Matrix multiply (Fisher) | 32768x512 | 4.33 | 169.96 | **39x** |

See [benchmark_results.md](benchmark_results.md) for full results.

## License

MIT
