Metadata-Version: 2.4
Name: jrmt
Version: 0.1.0
Summary: JAXed Random Matrix Theory
Project-URL: Homepage, https://github.com/emaballarin/jrmt
Author-email: Emanuele Ballarin <emanuele@ballarin.cc>
License: MIT
License-File: LICENSE
Keywords: Differentiable Programming,JAX,Machine Learning,Random Matrices,Random Matrix Theory
Classifier: Development Status :: 3 - Alpha
Classifier: Environment :: Console
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.10
Requires-Dist: jax>=0.4
Requires-Dist: jaxtyping>=0.2
Description-Content-Type: text/markdown

# `jrmt` — _JAXed_ Random Matrix Theory

A JAX port of [`thrmt`](https://github.com/emaballarin/thrmt). Pure
functional samplers for the classical real / complex / quantum random
matrix ensembles, built on `jax.numpy` and `jaxtyping`.

## Parity with `thrmt`

`jrmt` mirrors `thrmt`'s public surface 1:1 — all 25 names from
`thrmt`'s `__init__` (16 ensembles + 9 historical aliases) are
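available under `jrmt`. The single intentional API break is the call
signature of `random_obs_csu`'s `evdist` callback (last row of the table
below); the other rows show how PyTorch idioms map to JAX.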

| Migration of …                                 | `thrmt` (PyTorch)                           | `jrmt` (JAX)                                                                                          |
| ---------------------------------------------- | ------------------------------------------- | ----------------------------------------------------------------------------------------------------- |
| PRNG                                           | implicit, global `torch.manual_seed(...)`   | explicit, **`key` is the first positional argument** of every random function                         |
| Batching                                       | `batch_shape=(B,)` kwarg returning `(B, …)` | `jax.vmap(lambda k: fn(k, …))(jax.random.split(root_key, B))` — pure-functional                       |
| Device                                         | `device=...` kwarg                          | dropped — use `jax.device_put` or JAX's implicit placement                                            |
| dtype                                          | `torch.{cdouble, cfloat, double, float}`    | `jnp.{complex128, complex64, float64, float32, bfloat16, float16}`                                    |
| Shape types                                    | runtime checks at the wrapper               | `jaxtyping` annotations: `Float[Array, "n n"]`, `Complex[Array, "n n"]`, `PRNGKeyArray`               |
| `random_obs_csu`'s `evdist` callback signature | `evdist(*batch_shape, size, dtype, device)` | **`evdist(key, size, dtype)`**: the only documented public-API break (forced by JAX's PRNG plumbing)  |

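To make the `(key, size, dtype)` callback contract concrete, here is a minimal sketch of an `evdist` together with a hypothetical consumer that conjugates the drawn spectrum by a Haar-random unitary. The names `uniform_evdist`, `haar_unitary`, and `observable_sketch` are illustrative and not part of `jrmt`. The sketch assumes that `dtype` is the real dtype of the eigenvalues and that the observable has the form U diag(λ) U†; it is not necessarily how `jrmt.random_obs_csu` is implemented. The Haar unitary comes from the QR decomposition of a complex Ginibre matrix, with the usual phase correction on the diagonal of R.

```python
import jax
import jax.numpy as jnp


def uniform_evdist(key, size, dtype):
    """Hypothetical evdist callback: `size` real eigenvalues, uniform on [-1, 1)."""
    return jax.random.uniform(key, (size,), dtype=dtype, minval=-1.0, maxval=1.0)


def haar_unitary(key, size, dtype=jnp.complex64):
    """Haar-random unitary from the QR decomposition of a complex Ginibre matrix."""
    z = jax.random.normal(key, (size, size), dtype=dtype)
    q, r = jnp.linalg.qr(z)
    d = jnp.diagonal(r)
    # Rescale column j of Q by the phase of R[j, j]; without this fix the
    # distribution of Q is not exactly Haar.
    return q * (d / jnp.abs(d))


def observable_sketch(key, size, evdist, dtype=jnp.float32):
    """Hypothetical sampler: U diag(lam) U^dagger with lam drawn by `evdist`."""
    k_ev, k_u = jax.random.split(key)
    lam = evdist(k_ev, size, dtype)  # the callback receives its own key
    u = haar_unitary(k_u, size)
    return (u * lam) @ u.conj().T  # u * lam scales column j of U by lam[j]


obs = observable_sketch(jax.random.key(0), 6, uniform_evdist)
```
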
For `random_rho_pure`, `thrmt`'s `bo_einsum` keyword is gone — the
JAX implementation always uses `jnp.einsum` for the outer product.

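The outer product that `random_rho_pure` forms with `jnp.einsum` can be sketched as follows. This assumes the state vector is drawn as a normalized complex Gaussian, which is uniformly distributed on the unit sphere (a Haar-random pure state). `pure_state_rho` is a hypothetical name, and this is not necessarily how `jrmt` draws the vector.

```python
import jax
import jax.numpy as jnp


def pure_state_rho(key, size, dtype=jnp.complex64):
    """Hypothetical sketch: density matrix |psi><psi| of a Haar-random pure state."""
    # A normalized complex Gaussian vector is uniform on the unit sphere of C^size.
    psi = jax.random.normal(key, (size,), dtype=dtype)
    psi = psi / jnp.linalg.norm(psi)
    # Outer product rho[i, j] = psi[i] * conj(psi[j]).
    return jnp.einsum("i,j->ij", psi, psi.conj())


rho = pure_state_rho(jax.random.key(1), 4)
```
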
## Install

```bash
pip install -e .
```

Requires Python ≥ 3.10, `jax` ≥ 0.4, and `jaxtyping` ≥ 0.2. To work in
`float64` / `complex128`, enable 64-bit precision at startup, before any
JAX arrays are created (otherwise JAX truncates 64-bit dtypes to 32-bit):

```python
from jax import config
config.update("jax_enable_x64", True)
```

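A quick way to confirm that the flag took effect is to inspect the dtypes JAX actually produces. This sketch uses only public JAX calls; setting the environment variable `JAX_ENABLE_X64=True` before launching Python has the same effect as the `config.update` call.

```python
import jax

# Turn on 64-bit types at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

z = jnp.zeros((2, 2), dtype=jnp.complex128)
x = jnp.asarray(1.0)  # Python floats now default to float64
print(z.dtype, x.dtype)  # complex128 float64
```
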
## Use

```python
import jax
import jax.numpy as jnp
import jrmt

key = jax.random.key(0)

# Single draw
h = jrmt.random_gue(key, size=8, sigma=1.0)
assert jnp.allclose(h, jnp.conjugate(h.T))

# Batched draw via vmap
keys = jax.random.split(key, 1000)
batch = jax.vmap(lambda k: jrmt.random_rho_hs(k, size=8))(keys)
assert batch.shape == (1000, 8, 8)
```

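Because every sampler takes its key explicitly, reproducibility, independence, and batching all reduce to key management. The sketch below uses a hypothetical stand-in, `gue_sketch`, so it runs without `jrmt`. It draws the Hermitian part of a complex Ginibre matrix, which is one common GUE convention and may differ from `jrmt.random_gue`'s normalization. `size` is marked static under `jax.jit` because it determines the output shape.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="size")
def gue_sketch(key, size, sigma=1.0):
    """Hypothetical GUE stand-in: sigma times the Hermitian part of a complex Ginibre matrix."""
    a = jax.random.normal(key, (size, size), dtype=jnp.complex64)
    return sigma * (a + a.conj().T) / 2


key = jax.random.key(0)

# Same key, same draw: there is no hidden global RNG state.
assert jnp.array_equal(gue_sketch(key, size=8), gue_sketch(key, size=8))

# Independent draws need distinct keys, obtained by splitting.
k1, k2 = jax.random.split(key)
assert not jnp.array_equal(gue_sketch(k1, size=8), gue_sketch(k2, size=8))

# thrmt's batch_shape=(B,) becomes a vmap over B split keys.
keys = jax.random.split(key, 16)
batch = jax.vmap(lambda k: gue_sketch(k, size=8))(keys)
```
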
## Tests

```bash
pytest                 # fast invariants + cross-library moment-matching
pytest -m slow         # asymptotic-law KS tests (~30 s)
```

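The idea behind an asymptotic-law check can be sketched in a few lines. First, rescale the spectrum of one large GOE draw so that it converges to the Wigner semicircle on [-2, 2]. Then measure the Kolmogorov-Smirnov distance to the semicircle CDF. This is an illustrative sketch, not the contents of the repository's slow tests; `goe_sketch`, `semicircle_cdf`, and `ks_distance` are hypothetical helpers. Eigenvalues are strongly correlated, so textbook KS p-values do not strictly apply. Here the statistic is used only as a distance with a loose threshold.

```python
import jax
import jax.numpy as jnp


def goe_sketch(key, n):
    """Hypothetical GOE sampler: symmetric part of a real Gaussian matrix."""
    a = jax.random.normal(key, (n, n))
    return (a + a.T) / 2  # off-diagonal variance 1/2


def semicircle_cdf(x):
    """CDF of the Wigner semicircle law supported on [-2, 2]."""
    x = jnp.clip(x, -2.0, 2.0)
    return 0.5 + x * jnp.sqrt(4.0 - x**2) / (4 * jnp.pi) + jnp.arcsin(x / 2) / jnp.pi


def ks_distance(samples, cdf):
    """Kolmogorov-Smirnov statistic sup_x |F_n(x) - F(x)|."""
    x = jnp.sort(samples)
    n = x.shape[0]
    f = cdf(x)
    i = jnp.arange(1, n + 1)
    return jnp.maximum(jnp.max(i / n - f), jnp.max(f - (i - 1) / n))


n = 400
h = goe_sketch(jax.random.key(0), n)
# Rescale so the off-diagonal variance is 1/n: the spectrum then fills [-2, 2].
eigs = jnp.linalg.eigvalsh(h) * jnp.sqrt(2.0 / n)
d = ks_distance(eigs, semicircle_cdf)
```
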
Cross-library moment-matching (`tests/test_cross_lib.py`) needs both
`thrmt` and `torch` installed in the environment; the other suites do
not.

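Cross-library moment-matching compares Monte Carlo moments of the same ensemble drawn by two independent implementations. The following self-contained sketch of the idea uses NumPy as a stand-in for `torch`, so it needs neither `thrmt` nor `torch`. Both implementations estimate E[tr(H²)]/n for the convention H = (A + Aᵀ)/2, whose exact value is (n + 1)/2. The helper names and tolerances are assumptions, not those used in `tests/test_cross_lib.py`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def goe_second_moment_jax(key, n, batch):
    """Monte Carlo estimate of E[tr(H^2)] / n, with H = (A + A^T) / 2 drawn by JAX."""
    a = jax.random.normal(key, (batch, n, n))
    h = (a + jnp.swapaxes(a, -1, -2)) / 2
    # For symmetric H, tr(H^2) is the sum of squared entries.
    return float(jnp.mean(jnp.sum(h**2, axis=(-2, -1)) / n))


def goe_second_moment_numpy(seed, n, batch):
    """The same estimator, drawn by NumPy as an independent reference."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((batch, n, n))
    h = (a + np.swapaxes(a, -1, -2)) / 2
    return float(np.mean(np.sum(h**2, axis=(-2, -1)) / n))


n, batch = 8, 4000
exact = (n + 1) / 2  # diagonal variance 1, off-diagonal variance 1/2
m_jax = goe_second_moment_jax(jax.random.key(0), n, batch)
m_np = goe_second_moment_numpy(0, n, batch)
# Standard error of each estimate is about 0.017 here; 0.15 is a loose band.
assert abs(m_jax - exact) < 0.15 and abs(m_np - exact) < 0.15
assert abs(m_jax - m_np) < 0.2
```
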