Metadata-Version: 2.4
Name: jaxht
Version: 0.1.0
Summary: JAX-native spherical harmonic transforms (BK-regime first): GPU-capable, differentiable, dependency-controlled.
Project-URL: Homepage, https://github.com/jrcheshire/jht
Project-URL: Repository, https://github.com/jrcheshire/jht
Project-URL: Issues, https://github.com/jrcheshire/jht/issues
Project-URL: Changelog, https://github.com/jrcheshire/jht/blob/main/CHANGELOG.md
Author-email: James Cheshire <cheshire@caltech.edu>
Maintainer-email: James Cheshire <cheshire@caltech.edu>
License-Expression: MIT
License-File: LICENSE
Keywords: automatic-differentiation,cosmic-microwave-background,cosmology,gpu,healpix,jax,nufft,spherical-harmonic-transform
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Topic :: Scientific/Engineering :: Astronomy
Classifier: Topic :: Scientific/Engineering :: Physics
Classifier: Typing :: Typed
Requires-Python: >=3.11
Requires-Dist: jax>=0.4.30
Requires-Dist: numpy>=1.24
Provides-Extra: dev
Requires-Dist: mypy; extra == 'dev'
Requires-Dist: pytest; extra == 'dev'
Requires-Dist: ruff; extra == 'dev'
Provides-Extra: validation
Requires-Dist: ducc0; extra == 'validation'
Requires-Dist: healpy; extra == 'validation'
Description-Content-Type: text/markdown

# jht — JAX Harmonic Transforms

JAX-native spherical harmonic transforms (map ↔ aₗₘ): **GPU-capable**, **fully
differentiable**, and **dependency-controlled** (pure JAX + numpy at runtime — no
compiled C++ extension, no heavyweight third-party SHT library). Scoped to the
BICEP/Keck regime — **spin-0 and spin-2** on the **HEALPix RING** pixelization,
ℓ_max ≲ 1000, nside ≤ ~2048 — but written cleanly so it can serve as a general
transform dependency.

It exists to serve the GPU / differentiable tier of analysis that a CPU-only C++
transform (ducc0) structurally cannot, while *owning the numerics*. See
[`docs/motivation.md`](docs/motivation.md) for the full decision record.

## Status (2026-06-10)

Phases 0–4 **complete and validated** (190 tests pass + 8 GPU-gated skips,
CPU/float64):

- **On-grid transforms** — spin-0 & spin-2 synthesis (`aₗₘ→map`) and the exact
  adjoint `Sᵀ`, validated to machine precision vs healpy **and** ducc0; spin-2
  inverse at the HEALPix floor with **no s2fft-style structural defect**.
- **Accuracy** — jht's own ring quadrature weights + Jacobi iteration reach
  ~1e-13 on band-limited maps (matches `healpy.map2alm(use_weights=True)`); see
  [`docs/accuracy.md`](docs/accuracy.md).
- **Partial-sky** — masked pseudo-aₗₘ, a cut-sky CG deconvolution, and a masked
  Wiener filter / constrained realization (the MUSE inner solve); see
  [`docs/masked.md`](docs/masked.md).
- **Off-grid (NUFFT)** — `synthesis_general` / `adjoint_synthesis_general`
  evaluate a band-limited field at **arbitrary pointings** (spin 0–3), alm- **and**
  pointing-differentiable. The JAX-native replacement for ducc0's
  `sht.experimental.synthesis_general` (on-grid SHT + this NUFFT = the full ducc0
  surface bk-jax depends on); see [`docs/offgrid.md`](docs/offgrid.md).
- **Differentiability** — native JAX autodiff (`jacfwd ≡ jacrev`, tight adjoint
  identity), plus a convention-clean real-DOF layer `jht.diff`; see
  [`docs/design.md`](docs/design.md) §Differentiability.
- **GPU** — pure JAX, so the transforms run on CUDA with no code change. Measured on
  Cannon A100/V100 (fp64): GPU==CPU parity ~1e-13 across the BK regime **including
  nside=2048**, forward synthesis 14–60× faster than the CPU. See **Performance** below and
  [`docs/gpu.md`](docs/gpu.md).
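
The weighted-plus-iterated analysis in the **Accuracy** bullet above can be illustrated with a toy model. The sketch below is not jht's implementation: it assumes the iteration takes the usual residual-correction form `a ← a + A(m − S a)`, where `S` is synthesis and `A` is an approximate (weighted) analysis. The 3×2 matrix `S` and the scalar `WEIGHT` are hypothetical stand-ins chosen only so the example runs in plain Python.

```python
# Toy stand-in: S maps 2 "alm" coefficients to 3 "pixels".
S = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
WEIGHT = 0.4  # hypothetical scalar playing the role of quadrature weights


def synthesize(a):
    """Apply S: coefficients -> pixels."""
    return [sum(S[p][k] * a[k] for k in range(2)) for p in range(3)]


def approx_analyze(m):
    """Apply A = WEIGHT * S^T, a crude approximate inverse of S."""
    return [WEIGHT * sum(S[p][k] * m[p] for p in range(3)) for k in range(2)]


def iterated_analysis(m, niter):
    """Start from A m, then apply niter residual corrections."""
    a = approx_analyze(m)
    for _ in range(niter):
        residual = [mp - sp for mp, sp in zip(m, synthesize(a))]
        correction = approx_analyze(residual)
        a = [ak + ck for ak, ck in zip(a, correction)]
    return a


a_true = [1.0, -2.0]
m = synthesize(a_true)  # an exactly "band-limited" input


def max_error(a):
    return max(abs(x - y) for x, y in zip(a, a_true))


err_0 = max_error(iterated_analysis(m, 0))
err_3 = max_error(iterated_analysis(m, 3))
err_60 = max_error(iterated_analysis(m, 60))
```

Each pass multiplies the error by `I − A S`, so the better `A` approximates the inverse, the fewer iterations are needed. In this toy the error shrinks geometrically and reaches round-off level only after many passes, because a single scalar weight is a poor approximate inverse.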

## Install

Standard env is [pixi](https://pixi.sh):

```bash
pixi install          # CPU env (osx-arm64 / linux-64)
pixi run test         # the gated suite
```

GPU (CUDA, linux-64 — see [`docs/gpu.md`](docs/gpu.md)):

```bash
pixi run -e gpu python scripts/gpu_check.py   # on an NVIDIA box
```

As a dependency in another project (runtime deps are just `jax` + `numpy`):

```bash
pip install jaxht        # once released on PyPI — then `import jht`
# or track the repo directly:
pip install "jaxht @ git+https://github.com/jrcheshire/jht.git"
```

## Quick start

float64 is **opt-in per entry point** — enable it *before creating any array*
(library code never touches the global config):

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jht

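nside, lmax, spin = 256, 512, 0
# Placeholder input: zero aₗₘ in healpy packing (assumed 1-D for spin 0); substitute your own.
alm = jnp.zeros((lmax + 1) * (lmax + 2) // 2, dtype=jnp.complex128)
m = jht.synthesis(alm, nside, lmax, spin=spin)          # aₗₘ -> map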
a = jht.map2alm(m, nside, lmax, spin=spin, niter=3)     # map -> aₗₘ (weighted + iterated)
cl = jht.bandpower(a, lmax, spin=spin)                  # angular auto-power C_ℓ
```

`spin=2` takes/returns `(E, B)` aₗₘ of shape `(2, …)` and `(Q, U)` maps of shape
`(2, npix)`. `jht.adjoint_synthesis` is the **exact unweighted transpose** `Sᵀ`
(the operator seam / VJP), distinct from `map2alm` (the approximate inverse). For
gradient-based work use the real-DOF layer `jht.synthesis_real` /
`jht.analysis_real` (plain ℝⁿ→ℝᵐ, no complex-conjugate convention subtlety).
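
To make "exact transpose" concrete, the sketch below runs the dot-product test `⟨S a, m⟩ = ⟨a, Sᵀ m⟩` on a toy real-valued linear map, using `jax.vjp` to obtain the transpose. It does not call jht: `toy_synthesis` and `S_matrix` are hypothetical stand-ins, and the same check would apply to any real-DOF linear operator.

```python
import jax

jax.config.update("jax_enable_x64", True)  # set before any array is created
import jax.numpy as jnp

# Hypothetical stand-in for a synthesis operator: a fixed real matrix (linear).
k_op, k_a, k_m = jax.random.split(jax.random.PRNGKey(0), 3)
S_matrix = jax.random.normal(k_op, (12, 5))


def toy_synthesis(a):
    return S_matrix @ a


a = jax.random.normal(k_a, (5,))   # "coefficient" vector
m = jax.random.normal(k_m, (12,))  # "map" vector

# jax.vjp returns S(a) and a function that applies the transpose to a cotangent.
s_a, apply_transpose = jax.vjp(toy_synthesis, a)
(st_m,) = apply_transpose(m)

# For an exact transpose the two inner products agree to round-off.
lhs = jnp.vdot(s_a, m)
rhs = jnp.vdot(a, st_m)
adjoint_gap = float(jnp.abs(lhs - rhs))
```

An approximate inverse such as a weighted analysis would generally fail this identity, which is why the two operators are kept distinct.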

## Conventions

healpy m-major triangular aₗₘ packing, orthonormal Yₗₘ with the Condon–Shortley
phase, HEALPix-internal (COSMO) polarization — verified against healpy 1.19.0 and
ducc0 0.41.0. Pinned in [`docs/design.md`](docs/design.md).
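
As an illustration of the m-major triangular packing (with `mmax = lmax`), the sketch below computes the flat index of a coefficient. The index formula is the one healpy uses; the helper names `alm_size` and `alm_index` are hypothetical and are not part of jht's API.

```python
def alm_size(lmax):
    """Number of stored coefficients for 0 <= m <= l <= lmax."""
    return (lmax + 1) * (lmax + 2) // 2


def alm_index(ell, m, lmax):
    """Flat index of a_lm in m-major triangular packing (m >= 0 only)."""
    if not 0 <= m <= ell <= lmax:
        raise ValueError("require 0 <= m <= ell <= lmax")
    return m * (2 * lmax + 1 - m) // 2 + ell


lmax = 3
# m-major order: all ell for m=0, then all ell for m=1, and so on.
order = [(ell, m) for m in range(lmax + 1) for ell in range(m, lmax + 1)]
indices = [alm_index(ell, m, lmax) for ell, m in order]
```

Walking the coefficients in this order yields the indices `0 … alm_size(lmax) − 1` with no gaps. Only `m ≥ 0` is stored, since for a real field the negative-`m` coefficients follow from conjugation.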

## Accuracy tiers (the contract)

jht targets the **GPU/differentiable tier** where the HEALPix ~1e-3 sampling
floor is acceptable; weights + iteration close it to ~1e-13 on band-limited
inputs. It is **not** a drop-in for ducc0's purity-critical (~1e-4 E→B-leakage)
production path. Tolerances are a-priori and gate-driven, never relaxed without
sign-off. Residual mismatches are logged in
[`DISCREPANCIES.md`](DISCREPANCIES.md).

## Performance

Pure JAX runs unchanged on CUDA. Measured on Cannon A100 (incl. a 20 GB MIG) / V100, fp64:

- **GPU==CPU parity ~1e-13** across the BK regime, **including nside=2048** (synthesis and `map2alm`).
- **Forward synthesis 14–60×** faster than the 8-core CPU; fp64/fp32 ≈ 2.2×.
- **Off-grid forward** ~0.5–0.9 s at ℓ_max=1000 (independent of the number of points; recursion-bound), with the pointing gradient costing about as much as one forward pass.
- **nside=2048** compiles and runs on GPU — a ~20 GB slice holds synthesis + `map2alm`; the one-time compile is multi-minute (jit-cached).

The recurring GPU lesson: fp64/complex scatters are catastrophic on GPU, so jht packs and assembles via **gathers**. CPU perf model + memory in [`docs/performance.md`](docs/performance.md); GPU detail in [`docs/gpu.md`](docs/gpu.md).
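
For orientation on the sizes involved, the sketch below counts only the input and output arrays at the top of the BK regime. It uses the standard HEALPix pixel count `12·nside²` and the triangular aₗₘ count. It does not model the transform's intermediate workspace, so it is not an estimate of jht's total memory use; the helper names are hypothetical.

```python
def healpix_npix(nside):
    """Number of pixels in a HEALPix map."""
    return 12 * nside * nside


def alm_count(lmax):
    """Number of stored a_lm (m >= 0) up to lmax."""
    return (lmax + 1) * (lmax + 2) // 2


nside, lmax = 2048, 1000
npix = healpix_npix(nside)
n_alm = alm_count(lmax)

FLOAT64_BYTES = 8
COMPLEX128_BYTES = 16

# Spin-2 carries two components: (Q, U) maps and (E, B) coefficients.
qu_map_gib = 2 * npix * FLOAT64_BYTES / 2**30
eb_alm_mib = 2 * n_alm * COMPLEX128_BYTES / 2**20

print(f"npix = {npix}, (Q, U) fp64 maps = {qu_map_gib:.2f} GiB")
print(f"n_alm = {n_alm}, (E, B) complex128 = {eb_alm_mib:.1f} MiB")
```

The maps are far larger than the coefficient arrays, so the pixel side dominates the input and output footprint at nside=2048.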

## Using jht as a dependency

jht is standalone and consumer-agnostic. The operator/grad seam a downstream
needs (e.g. to use jht *in place of ducc0*) — and the accuracy boundary — are
documented in [`docs/consumers.md`](docs/consumers.md). Any backend-selection
wiring lives in the consumer, not here.

## Docs

- [`docs/design.md`](docs/design.md) — technical design, conventions, the crux, differentiability.
- [`docs/accuracy.md`](docs/accuracy.md) — the accuracy contract + ring-weight algorithm.
- [`docs/masked.md`](docs/masked.md) — partial-sky estimators.
- [`docs/performance.md`](docs/performance.md) — CPU perf model + memory.
- [`docs/gpu.md`](docs/gpu.md) — the GPU env, the x64 requirement, the harness.
- [`docs/consumers.md`](docs/consumers.md) — the downstream-dependency seam.
- [`docs/motivation.md`](docs/motivation.md) — why jht exists.
- [`ROADMAP.md`](ROADMAP.md) — phased plan + gates.
