Metadata-Version: 2.2
Name: eigh-cuda128
Version: 0.4.1
Summary: Differentiable eigenvalue decomposition with JAX — CUDA 12.8 (GPU) build
Keywords: jax,eigenvalue,linear-algebra,gpu,cuda,autodiff
Author: Igor Sokolov
Maintainer: Igor Sokolov
License: Apache-2.0
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: POSIX :: Linux
Classifier: Environment :: GPU :: NVIDIA CUDA :: 12
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Mathematics
Project-URL: Homepage, https://github.com/Brogis1/eigh
Project-URL: Repository, https://github.com/Brogis1/eigh
Project-URL: Original Project, https://github.com/fishjojo/pyscfad
Requires-Python: >=3.10
Requires-Dist: jax>=0.5.0
Requires-Dist: jaxlib>=0.5.0
Requires-Dist: numpy>=1.20.0
Requires-Dist: nvidia-cuda-runtime-cu12
Requires-Dist: nvidia-cusolver-cu12
Requires-Dist: nvidia-cublas-cu12
Requires-Dist: nvidia-cusparse-cu12
Requires-Dist: nvidia-cuda-nvrtc-cu12
Requires-Dist: nvidia-nvjitlink-cu12
Requires-Dist: jax[cuda12]>=0.5.0
Provides-Extra: test
Requires-Dist: pytest>=7.0.0; extra == "test"
Requires-Dist: scipy>=1.7.0; extra == "test"
Description-Content-Type: text/markdown

# Differentiable Generalized Eigenvalue Decomposition

[![Tests](https://github.com/Brogis1/eigh/actions/workflows/tests.yml/badge.svg)](https://github.com/Brogis1/eigh/actions/workflows/tests.yml)

<img src="https://raw.githubusercontent.com/Brogis1/eigh/main/img/eig.png?v=4" alt="Eigh Logo" width="400">

Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from [pyscfad](https://github.com/fishjojo/pyscfad).

Prebuilt wheels on PyPI for Python 3.10–3.13 and JAX 0.5–0.10+: CPU for Linux x86_64
and macOS (Apple Silicon), GPU (CUDA 12) for Linux x86_64. See
[Installation](#installation--quick-start) and [Compatibility](#compatibility).

> ### New
> - Core code rewritten so that it also runs on older clusters pinned to JAX 0.4.x (common on GPU clusters).
> - Prebuilt CUDA wheels are available, but building from source is recommended: it is fast with this package and compiles against your specific JAX version.

## Features
- **Generalized Problems**: `A @ V = B @ V @ diag(W)` and the other two SciPy `type` forms (see [API Reference](#api-reference)).
- **JAX Integrated**: Full support for `jit`, `vmap`, `grad`, and `jvp`.
- **High Performance**: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
- **Precision**: `float32/64` and `complex64/128`.
- **Degeneracy Handling**: Configurable `deg_thresh` for stable gradients.

## Installation & Quick Start

### CPU

```bash
pip install eigh
```

Prebuilt **CPU-only** wheels — Linux (x86_64) and macOS (Apple Silicon),
Python 3.10–3.13, JAX 0.5+.

### GPU - Build from source (Recommended)

First make sure that your JAX installation already runs on your GPU.
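
A quick check using only standard JAX calls: ask JAX which backend it selected and force a small computation to finish. On a working CUDA install `jax.default_backend()` should report `"gpu"`; `"cpu"` means JAX is not using the GPU, which is worth fixing before building `eigh`.

```python
import jax
import jax.numpy as jnp

# Backend JAX will use by default: "gpu" for a CUDA device, "cpu" if none was found.
backend = jax.default_backend()
print("Default backend:", backend)
print("Devices:", jax.devices())

# Run a small matmul and wait for it to finish, so a broken CUDA setup fails here.
x = jnp.ones((256, 256))
result = (x @ x).block_until_ready()
print("Matmul check:", float(result[0, 0]))  # each entry is a sum of 256 ones
```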

Building from source covers any `jaxlib` / CUDA / glibc combination the wheels
don't. The main case is an environment pinned to an older release such as
**`jaxlib` 0.4.29**: the prebuilt wheels require `jaxlib` ≥ 0.5 because the FFI
binary ABI changed at 0.5, so a 0.5 wheel can't run on 0.4.x and vice versa. A
source build compiles against **whatever `jaxlib` is in your env** (0.4.29 or
0.5–0.10+), for CPU or GPU:

```bash
git clone https://github.com/Brogis1/eigh && cd eigh
pip install "scikit-build-core>=0.8" "nanobind>=1.0.0" cmake ninja
pip install . --no-build-isolation --no-deps
```

- `--no-build-isolation` compiles against the `jaxlib` already in your env.
- `--no-deps` keeps your pinned `jax`/`jaxlib` (essential on 0.4.29 — otherwise
  pip would upgrade it to ≥0.5).
- For **GPU**, have `nvcc` on `PATH` (`module load cuda/12.x`); look for
  `CUDA support enabled` in the build log (`CUDA not found` ⇒ `nvcc` not on
  `PATH`). Plain `jaxlib` (no CUDA) yields a CPU-only build.
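
Before building, it helps to know which side of the 0.5 ABI break the environment is on. A small sketch, assuming the installed `jaxlib` reports a PEP 440-style version string; `parse_release` and `needs_source_build` are hypothetical helpers that encode only the `jaxlib` ≥ 0.5 requirement stated above (they do not detect CUDA or glibc mismatches):

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(ver: str) -> tuple[int, ...]:
    """Leading numeric release parts: '0.4.29' -> (0, 4, 29), '0.5.0.dev1' -> (0, 5, 0)."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # stop at the first non-numeric piece, e.g. 'dev1'
            break
        parts.append(int(digits))
    return tuple(parts)


def needs_source_build(jaxlib_version: str) -> bool:
    """Hypothetical helper: True if jaxlib predates 0.5, the minimum for the prebuilt wheels."""
    return parse_release(jaxlib_version) < (0, 5)


try:
    installed = version("jaxlib")
    print(f"jaxlib {installed}: source build required: {needs_source_build(installed)}")
except PackageNotFoundError:
    print("jaxlib is not installed in this environment")
```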

Full details and the *why* — FFI ABI, pinned-`jaxlib` clusters, nvcc paths, GPU
verification — are in [docs/TECHNICAL_NOTES.md](docs/TECHNICAL_NOTES.md).


### GPU (CUDA 12, Linux x86_64)

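These prebuilt packages work only when your JAX, `jaxlib`, and CUDA libraries happen to match the ones they were built against, so building from source (see [GPU - Build from source](#gpu---build-from-source-recommended)) is strongly recommended. If you do use them, pick the package matching your cluster's CUDA version: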

```bash
pip install eigh-cuda120   # CUDA 12.0+ (works through 12.8+); the safe default
pip install eigh-cuda128   # CUDA 12.8+ (newer toolchain / glibc 2.34)
```

Both bundle the cuSOLVER kernel and the NVIDIA CUDA runtime libraries, and `import eigh`
auto-detects the GPU. They are separate packages built from this same repo, and
`import eigh` is identical for both. See [Compatibility](#compatibility).
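
To narrow down which of the two packages can work on a given node, a sketch: `compatible_gpu_wheels` and `detect_glibc` are hypothetical helpers that encode only the minimum CUDA and glibc versions listed under [Compatibility](#compatibility), with the safe default first. The CUDA version has to be supplied by hand (for example from `nvcc --version`), and a match here still does not guarantee that the wheel agrees with your `jaxlib`.

```python
import platform
from typing import Optional

# Minimum (CUDA, glibc) versions per package, as listed under Compatibility.
GPU_WHEELS = {
    "eigh-cuda120": ((12, 0), (2, 17)),  # safe default
    "eigh-cuda128": ((12, 8), (2, 34)),
}


def detect_glibc() -> Optional[tuple[int, ...]]:
    """Return the running glibc version, e.g. (2, 35), or None if it isn't glibc."""
    lib, ver = platform.libc_ver()
    if lib != "glibc" or not ver:
        return None
    return tuple(int(part) for part in ver.split(".") if part.isdigit())


def compatible_gpu_wheels(cuda: tuple[int, ...], glibc: tuple[int, ...]) -> list[str]:
    """Names of the GPU packages whose stated minimums are met, safe default first."""
    return [
        name
        for name, (min_cuda, min_glibc) in GPU_WHEELS.items()
        if tuple(cuda) >= min_cuda and tuple(glibc) >= min_glibc
    ]


# Example: a node with CUDA 12.4 (version typed in by hand).
glibc = detect_glibc()
if glibc is None:
    print("Not a glibc system: no prebuilt GPU wheel applies, build from source")
else:
    print(compatible_gpu_wheels((12, 4), glibc) or "No match: build from source")
```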


### Usage Example
```python
import jax
import jax.numpy as jnp
# Generalized eigensolvers extracted from pyscfad
from eigh import eigh, eigh_gen

# Enable float64 before creating any arrays
jax.config.update("jax_enable_x64", True)

# Standard problem A v = v λ and generalized problem A v = B v λ
A = jnp.array([[2., 1.], [1., 2.]])
B = jnp.array([[1., 0.5], [0.5, 1.]])  # B must be symmetric positive definite
w1, v1 = eigh(A)
w2, v2 = eigh_gen(A, B)

# Gradients of the eigenvalue sum with respect to A
grad1 = jax.grad(lambda A: eigh(A)[0].sum())(A)
grad2 = jax.grad(lambda A: eigh_gen(A, B)[0].sum())(A)
print("Eigenvalues:", w1, w2)
print("Eigenvectors:", v1, v2)
print("Gradients computed:", grad1.shape, grad2.shape)
```
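
The gradients in the example have a closed form that makes a handy sanity check. With the standard first-order formula ∂λᵢ/∂A = vᵢvᵢᵀ, the gradient of the eigenvalue sum is V Vᵀ, which is the identity for the standard problem and B⁻¹ for the generalized one (because Vᵀ B V = I). The sketch below checks this in pure JAX, with `jnp.linalg.eigvalsh` and a hypothetical Cholesky-based helper `gen_eigvalsh` standing in for `eigh` and `eigh_gen`. Note that the example's A equals 2B, so both generalized eigenvalues there are 2; the sketch uses a non-degenerate `A_gen` for the generalized check, although the closed form B⁻¹ does not depend on A.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

jax.config.update("jax_enable_x64", True)


def gen_eigvalsh(a, b):
    """Hypothetical stand-in for eigh_gen(a, b)[0]: Cholesky reduction, then eigvalsh."""
    L = cholesky(b, lower=True)                      # b = L @ L.T
    tmp = solve_triangular(L, a, lower=True)         # L^{-1} a
    c = solve_triangular(L, tmp.T, lower=True)       # L^{-1} a^T L^{-T}
    c = (c + c.T) / 2                                # symmetrize explicitly
    return jnp.linalg.eigvalsh(c)


B = jnp.array([[1.0, 0.5], [0.5, 1.0]])
A = jnp.array([[2.0, 1.0], [1.0, 2.0]])      # example's A: eigenvalues 1 and 3
A_gen = jnp.array([[2.0, 1.0], [1.0, 3.0]])  # generalized eigenvalues 2 and 10/3 with B

grad1 = jax.grad(lambda a: jnp.linalg.eigvalsh(a).sum())(A)
grad2 = jax.grad(lambda a: gen_eigvalsh(a, B).sum())(A_gen)
print(grad1)  # close to the identity
print(grad2)  # close to inv(B) = [[4/3, -2/3], [-2/3, 4/3]]
```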

## Benchmarks
Forward/backward scaling vs. matrix size, and gradient stability as eigenvalues approach degeneracy — for the JAX eigensolvers in `src/jax/`. See [benchmarks/suite/](benchmarks/suite/) for the scripts.

<p align="center">
  <img src="https://raw.githubusercontent.com/Brogis1/eigh/main/benchmarks/suite/figs/scaling_fwd.png" alt="Forward-pass scaling" width="45%">
  <img src="https://raw.githubusercontent.com/Brogis1/eigh/main/benchmarks/suite/figs/scaling_grad.png" alt="Backward-pass (gradient) scaling" width="45%">
</p>
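
The scripts in [benchmarks/suite/](benchmarks/suite/) are the reference. For a quick timing of your own, a generic sketch of the usual JAX pattern (compile once, then time calls with `block_until_ready()`); it uses `jnp.linalg.eigh` as a stand-in, and `time_eigh` is a hypothetical helper, so swap in the solver you want to measure:

```python
import time

import jax
import jax.numpy as jnp


def time_eigh(n: int, repeats: int = 5, seed: int = 0) -> dict[str, float]:
    """Median wall time (s) of a jitted forward eigh and its gradient for an n x n matrix."""
    x = jax.random.normal(jax.random.PRNGKey(seed), (n, n))
    a = (x + x.T) / 2  # random symmetric test matrix

    forward = jax.jit(lambda m: jnp.linalg.eigh(m)[0])
    backward = jax.jit(jax.grad(lambda m: jnp.linalg.eigh(m)[0].sum()))

    timings = {}
    for label, fn in (("forward", forward), ("grad", backward)):
        fn(a).block_until_ready()  # compile once, outside the timed loop
        samples = []
        for _ in range(repeats):
            start = time.perf_counter()
            fn(a).block_until_ready()
            samples.append(time.perf_counter() - start)
        samples.sort()
        timings[label] = samples[len(samples) // 2]
    return timings


for n in (32, 64, 128):
    print(n, time_eigh(n))
```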



## API Reference
- **`eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9)`**
  SciPy-compatible interface (mirrors `scipy.linalg.eigh`). `type` selects the problem: 1: `A@v = B@v@λ`, 2: `A@B@v = v@λ`, 3: `B@A@v = v@λ`.
- **`eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9)`**
  Lower-level generalized solver; `itype` takes the same values 1–3 as `type` above.
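
Since the `type` values follow `scipy.linalg.eigh`, the three conventions can be checked directly with SciPy. A sketch using SciPy as a stand-in; replacing the import with `from eigh import eigh` should give the same residuals, assuming the interface matches SciPy's as stated above:

```python
import numpy as np
from scipy.linalg import eigh

A = np.array([[2.0, 1.0], [1.0, 3.0]])
B = np.array([[1.0, 0.5], [0.5, 1.0]])  # symmetric positive definite


def residual(kind, w, v):
    """Max-norm residual of the equation that `type=kind` solves."""
    lam = np.diag(w)
    if kind == 1:
        return np.abs(A @ v - B @ v @ lam).max()   # A v = B v λ
    if kind == 2:
        return np.abs(A @ B @ v - v @ lam).max()   # A B v = v λ
    return np.abs(B @ A @ v - v @ lam).max()       # B A v = v λ


for kind in (1, 2, 3):
    w, v = eigh(A, B, type=kind)
    print(kind, w, residual(kind, w, v))
```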

## Degenerate Eigenvalues & Gradients
Gradients of individual eigenvalues (and of eigenvectors) are ill-defined when eigenvalues are degenerate (repeated). Functions that are symmetric in the eigenvalues, such as `sum`, `var`, or a trace, still have well-defined gradients. The `deg_thresh` parameter (default `1e-9`) keeps these stable by masking divisions by eigenvalue gaps smaller than the threshold.
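
In the usual reverse-mode formula for `eigh`, the eigenvector cotangent is scaled by 1/F_ij, where F_ij = λ_j − λ_i is the eigenvalue gap. A minimal sketch of threshold masking, assuming that is how `deg_thresh` is applied; `masked_inverse_gaps` is a hypothetical helper, not this package's code:

```python
import jax.numpy as jnp


def masked_inverse_gaps(w, deg_thresh=1e-9):
    """Hypothetical helper: 1 / F_ij with F_ij = w_j - w_i, zeroed where |F_ij| <= deg_thresh."""
    gaps = w[None, :] - w[:, None]              # gaps[i, j] = w_j - w_i
    keep = jnp.abs(gaps) > deg_thresh           # False on the diagonal and for degenerate pairs
    safe_gaps = jnp.where(keep, gaps, 1.0)      # never divide by a (near-)zero gap
    return jnp.where(keep, 1.0 / safe_gaps, 0.0)


w = jnp.array([1.0, 1.0, 2.0])                  # first two eigenvalues are degenerate
print(masked_inverse_gaps(w))
```

The inner `jnp.where` keeps the discarded branch finite, so the mask itself does not introduce NaNs if this matrix is differentiated further.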

## JAX Eigensolvers
A collection of differentiable generalized eigensolvers with different strategies for handling degenerate eigenvalues in reverse-mode gradients. Useful for training pipelines where degeneracies are common.

**If you just want a working solver, use `stable_eigh_pyscfad` / `stable_eigh_gen_pyscfad`** from [generalized_eigensolver_pyscfad.py](src/jax/generalized_eigensolver_pyscfad.py). They wrap the fast LAPACK/cuSOLVER kernels with a Lorentzian-broadened custom VJP, so gradients stay stable when eigenvalues are (nearly) degenerate.

On Windows, or if you cannot build the C++ kernels, use `stable_generalized_eigh` from [generalized_eigensolver_stable.py](src/jax/generalized_eigensolver_stable.py) instead — same gradient treatment, pure JAX.

The remaining solvers below are kept for benchmarking and for reproducing prior work; they are not recommended as defaults.

#### Recommended
| Solver | File | Strategy |
|---|---|---|
| **`stable_eigh_pyscfad` / `stable_eigh_gen_pyscfad`** | [generalized_eigensolver_pyscfad.py](src/jax/generalized_eigensolver_pyscfad.py) | LAPACK/cuSOLVER kernels + Lorentzian-broadened VJP [2] |
| **`stable_eigh` / `stable_generalized_eigh`** (pure-JAX) | [generalized_eigensolver_stable.py](src/jax/generalized_eigensolver_stable.py) | Pure-JAX Cholesky + Lorentzian-broadened VJP [2] |

#### Alternative stable solvers
| Solver | File | Strategy | Gradient notes |
|---|---|---|---|
| `subspace_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Custom VJP: Lorentzian broadening `F/(F²+ε²)` [2] | Stable |
| `subspace_generalized_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Symmetry-breaking perturbation + `subspace_eigh` [2,4] | Stable |
| `degen_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Custom VJP: mask degenerate `F_ij` by threshold [1,3] | Stable only for symmetric-subspace losses |
| `safe_generalized_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Cholesky + `degen_eigh` | Inherits `degen_eigh` caveat |
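
Both treatments act on the same 1/F_ij terms, with F_ij = λ_j − λ_i the eigenvalue gap: masking zeroes them below a threshold, while the Lorentzian `F/(F²+ε²)` stays bounded (at most 1/(2ε)) and vanishes for exactly degenerate pairs. A minimal pure-JAX sketch for a single real-symmetric matrix in float64; `lorentz_eigh` is a hypothetical name, not the code of `subspace_eigh` or `stable_eigh_pyscfad`:

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def lorentz_eigh(a, eps):
    """Eigendecomposition of a real symmetric matrix with a broadened backward pass."""
    w, v = jnp.linalg.eigh(a)
    return w, v


def _lorentz_eigh_fwd(a, eps):
    w, v = jnp.linalg.eigh(a)
    return (w, v), (w, v)  # save eigenpairs for the backward pass


def _lorentz_eigh_bwd(eps, res, cotangents):
    w, v = res
    w_bar, v_bar = cotangents
    gaps = w[None, :] - w[:, None]              # F_ij = w_j - w_i (0 on the diagonal)
    inv_gaps = gaps / (gaps**2 + eps**2)        # Lorentzian replacement for 1 / F_ij
    inner = jnp.diag(w_bar) + inv_gaps * (v.T @ v_bar)
    a_bar = v @ inner @ v.T
    return ((a_bar + a_bar.T) / 2,)             # symmetrize, as for a symmetric input


lorentz_eigh.defvjp(_lorentz_eigh_fwd, _lorentz_eigh_bwd)


def loss(eig_fn, a):
    """Depends on eigenvectors only through v**2, so sign flips don't matter."""
    w, v = eig_fn(a)
    weights = jnp.arange(a.shape[0], dtype=a.dtype)
    return jnp.sum(weights * w) + jnp.sum(weights * v[:, 0] ** 2)


a = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 0.5], [0.0, 0.5, 1.0]])
g_ref = jax.grad(lambda m: loss(lambda x: tuple(jnp.linalg.eigh(x)), m))(a)
g_lor = jax.grad(lambda m: loss(lambda x: lorentz_eigh(x, 1e-10), m))(a)
print(jnp.max(jnp.abs(g_ref - g_lor)))  # tiny: the eigenvalues here are well separated

g_deg = jax.grad(lambda m: loss(lambda x: lorentz_eigh(x, 1e-6), m))(jnp.eye(3))
print(jnp.all(jnp.isfinite(g_deg)))  # finite even though all eigenvalues are equal
```

Away from degeneracies ε barely changes the gradient; at a degeneracy the broadened terms drop out instead of blowing up, at the cost of a bias in the eigenvector part of the gradient.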

#### Baselines (not gradient-safe at degeneracies)
| Solver | File | Strategy |
|---|---|---|
| `standard_eig` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | `scipy.linalg.eigh` — non-differentiable reference |
| `jax_eig` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Plain Cholesky + `jnp.linalg.eigh`, default VJP |
| `generalized_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Symmetrized Cholesky with SPD shift, default VJP |
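
The Cholesky baselines reduce the generalized problem to a standard one with B = L Lᵀ: solve C y = y λ with C = L⁻¹ A L⁻ᵀ, then back-transform v = L⁻ᵀ y, which gives A v = B v λ with Vᵀ B V = I. A minimal sketch of that plain reduction with symmetrized inputs (it leaves out the SPD shift, whose exact form isn't given here); `cholesky_generalized_eigh` is a hypothetical name. Because gradients flow through the default VJP of `jnp.linalg.eigh`, they inherit its 1/F_ij terms, which is why these baselines are not gradient-safe at degeneracies.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

jax.config.update("jax_enable_x64", True)


def cholesky_generalized_eigh(a, b):
    """Hypothetical sketch: solve a @ v = b @ v @ diag(w) by Cholesky reduction."""
    a = (a + a.T) / 2                                   # symmetrize both inputs
    b = (b + b.T) / 2
    L = cholesky(b, lower=True)                         # b = L @ L.T (b must be SPD)
    tmp = solve_triangular(L, a, lower=True)            # L^{-1} a
    c = solve_triangular(L, tmp.T, lower=True)          # L^{-1} a L^{-T} (a symmetric)
    w, y = jnp.linalg.eigh(c)                           # standard problem, default VJP
    v = solve_triangular(L.T, y, lower=False)           # back-transform: v = L^{-T} y
    return w, v


A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
B = jnp.array([[1.0, 0.5], [0.5, 1.0]])
w, v = cholesky_generalized_eigh(A, B)
print(w)                                            # generalized eigenvalues 2 and 10/3
print(jnp.abs(A @ v - B @ v @ jnp.diag(w)).max())  # residual near machine precision
print(v.T @ B @ v)                                  # close to the identity (B-orthonormal)
```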



## Compatibility

- **Python:** 3.10–3.13.  **JAX:** 0.5 → 0.10+ for the prebuilt wheels; **jax 0.4.29
  via source build** (its FFI ABI differs from 0.5+, so it needs its own build;
  see [GPU - Build from source](#gpu---build-from-source-recommended)).
- **CPU wheels:** Linux x86_64 (`manylinux_2_28`, bundled OpenBLAS) and macOS
  arm64.
- **GPU wheels:** Linux x86_64, CUDA 12 — `eigh-cuda120` (CUDA 12.0+, glibc 2.17)
  and `eigh-cuda128` (CUDA 12.8+, glibc 2.34).
- **Windows:** no compiled wheel — use the pure-JAX solvers in [src/jax/](src/jax/),
  or build from source.

Full detail — the FFI binary-ABI rules, why GPU ships as separate packages, the
abi3 wheel matrix, HPC/cluster notes, and how to build on a pinned old `jaxlib` —
is in **[docs/TECHNICAL_NOTES.md](docs/TECHNICAL_NOTES.md)**.

### References
- [1] Kasim, M. F., & Vinko, S. M. *Learning the exchange–correlation functional from nature with fully differentiable density functional theory.* Phys. Rev. Lett. **127**, 126403 (2021). https://doi.org/10.1103/PhysRevLett.127.126403
- [2] Colburn, S., & Majumdar, A. *Inverse design and flexible parameterization of meta-optics using algorithmic differentiation.* Communications Physics **4**, 54 (2021). https://doi.org/10.1038/s42005-021-00568-6
- [3] JAX Issue #2748 — Differentiable `eigh` with degeneracies. https://github.com/jax-ml/jax/issues/2748
- [4] JAX Issue #5461 — Stable generalized `eigh`. https://github.com/jax-ml/jax/issues/5461


## Development & Testing
- **Requirements**: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
- **Tests**:
  ```bash
  pytest tests/test_eigh.py     # Core functionality
  pytest tests/test_eigh_gen.py # Generalized itypes
  pytest tests/test_eigh_jit.py # JIT & vmap
  ```
- **GPU Setup**:
  ```bash
  source setup_gpu_env_clean.sh
  ./run_gpu.sh python example_simple.py
  ```

## License & Citation
Apache License 2.0. If used in research, please cite:
```bibtex

@software{sokolov2026eigh,
  author={Sokolov, Igor},
  title={Eigh: Differentiable eigenvalue decomposition with jax (cpu/gpu)},
  url={https://github.com/Brogis1/eigh},
  year={2026}
}

@software{pyscfad,
  author = {Zhang, Xing},
  title = {PySCFad: Automatic Differentiation for PySCF},
  url = {https://github.com/fishjojo/pyscfad},
  year = {2021-2025}
}

@article{10.1063/5.0118200,
    author = {Zhang, Xing and Chan, Garnet Kin-Lic},
    title = {Differentiable quantum chemistry with PySCF for molecules and materials at the mean-field level and beyond},
    journal = {The Journal of Chemical Physics},
    volume = {157},
    number = {20},
    pages = {204801},
    year = {2022},
    month = {11},
    issn = {0021-9606},
    doi = {10.1063/5.0118200},
    url = {https://doi.org/10.1063/5.0118200},
}

@article{sokolov2026xc,
  title = {Quantum-enhanced neural exchange-correlation functionals},
  author = {Sokolov, Igor O. and Both, Gert-Jan and Bochevarov, Art D. and Dub, Pavel A. and Levine, Daniel S. and Brown, Christopher T. and Acheche, Shaheen and Barkoutsos, Panagiotis Kl. and Elfving, Vincent E.},
  journal = {Phys. Rev. A},
  volume = {113},
  issue = {1},
  pages = {012427},
  numpages = {24},
  year = {2026},
  month = {Jan},
  publisher = {American Physical Society},
  doi = {10.1103/m51l-fys2},
  url = {https://link.aps.org/doi/10.1103/m51l-fys2}
}
```
