Metadata-Version: 2.2
Name: eigh
Version: 0.2.2
Summary: Differentiable eigenvalue decomposition with JAX (CPU/GPU)
Keywords: jax,eigenvalue,linear-algebra,gpu,cuda,autodiff
Author: Igor Sokolov
Maintainer: Igor Sokolov
License: Apache-2.0
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: POSIX :: Linux
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Mathematics
Project-URL: Homepage, https://github.com/Brogis1/eigh
Project-URL: Repository, https://github.com/Brogis1/eigh
Project-URL: Original Project, https://github.com/fishjojo/pyscfad
Requires-Python: >=3.10
Requires-Dist: jax>=0.4.0
Requires-Dist: jaxlib>=0.4.0
Requires-Dist: numpy>=1.20.0
Provides-Extra: cuda
Requires-Dist: jax[cuda12]>=0.4.0; extra == "cuda"
Provides-Extra: cuda-local
Requires-Dist: jax[cuda12-local]>=0.4.0; extra == "cuda-local"
Provides-Extra: test
Requires-Dist: pytest>=7.0.0; extra == "test"
Requires-Dist: scipy>=1.7.0; extra == "test"
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: scipy>=1.7.0; extra == "dev"
Description-Content-Type: text/markdown

# Differentiable Generalized Eigenvalue Decomposition

[![Tests](https://github.com/Brogis1/eigh/actions/workflows/tests.yml/badge.svg)](https://github.com/Brogis1/eigh/actions/workflows/tests.yml)

<img src="https://raw.githubusercontent.com/Brogis1/eigh/main/img/eig.png?v=2" alt="Eigh Logo" width="400">

Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from [pyscfad](https://github.com/fishjojo/pyscfad).

Wheels on PyPI: https://pypi.org/project/eigh/ — Linux (manylinux_2_28, x86_64) and macOS (x86_64, arm64), Python 3.10–3.12. GPU path (cuSOLVER) is tested locally; CI runs CPU tests only.

**Windows:** no prebuilt wheel. The pure-JAX solvers in [src/jax/](src/jax/) (e.g. `safe_generalized_eigh`, `subspace_generalized_eigh`, `stable_generalized_eigh`) work out-of-the-box — `pip install jax numpy scipy` and import directly from that module. The fast LAPACK/cuSOLVER-backed `eigh` / `eigh_gen` kernels (and therefore `stable_eigh_pyscfad` / `stable_eigh_gen_pyscfad`) require building from source against a local BLAS/LAPACK.

## Features
- **Generalized Problems**: solves `A @ V = B @ V @ diag(W)` and the related itypes (`A@B@v = v@λ`, `B@A@v = v@λ`).
- **JAX Integrated**: Full support for `jit`, `vmap`, `grad`, and `jvp`.
- **High Performance**: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
- **Precision**: `float32/64` and `complex64/128`.
- **Degeneracy Handling**: Configurable `deg_thresh` for stable gradients.

## Installation & Quick Start

```bash
# Install from source
pip install .

# GPU support (pulls in bundled CUDA 12 libraries)
pip install .[cuda]

# GPU support against a locally installed CUDA toolkit
pip install .[cuda-local]
```

### Usage Example
```python
import jax
import jax.numpy as jnp
# Eigensolver extracted from PySCFAD
from eigh import eigh

jax.config.update("jax_enable_x64", True)
A = jnp.array([[2., 1.], [1., 2.]])
w, v = eigh(A) # Standard
grad = jax.grad(lambda A: eigh(A)[0].sum())(A) # Differentiable
```

## Benchmarks
Forward/backward scaling vs. matrix size, and gradient stability as eigenvalues approach degeneracy — for the JAX eigensolvers in `src/jax/`. See [benchmarks/suite/](benchmarks/suite/) for the scripts.

<p align="center">
  <img src="https://raw.githubusercontent.com/Brogis1/eigh/main/benchmarks/suite/figs/scaling_fwd.png" alt="Forward-pass scaling" width="45%">
  <img src="https://raw.githubusercontent.com/Brogis1/eigh/main/benchmarks/suite/figs/scaling_grad.png" alt="Backward-pass (gradient) scaling" width="45%">
</p>



## API Reference
- **`eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9)`**
  SciPy-compatible interface (mirrors `scipy.linalg.eigh`). `type` selects the problem: 1: `A@v = B@v@λ`, 2: `A@B@v = v@λ`, 3: `B@A@v = v@λ`.
- **`eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9)`**
  Lower-level generalized solver.

## Degenerate Eigenvalues & Gradients
Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions (like `sum`, `var`, `trace`) have stable gradients. The `deg_thresh` parameter (default `1e-9`) masks divisions by near-zero gaps to maintain stability.

## JAX Eigensolvers (`src/jax/`)
A collection of differentiable generalized eigensolvers with different strategies for handling degenerate eigenvalues in reverse-mode gradients. Useful for training pipelines where degeneracies are common.

**If you just want a working solver, use `stable_eigh_pyscfad` / `stable_eigh_gen_pyscfad`** from [generalized_eigensolver_pyscfad.py](src/jax/generalized_eigensolver_pyscfad.py). They wrap the fast LAPACK/cuSOLVER kernels with a Lorentzian-broadened custom VJP, so gradients stay stable when eigenvalues are (nearly) degenerate.

On Windows, or if you cannot build the C++ kernels, use `stable_generalized_eigh` from [generalized_eigensolver_stable.py](src/jax/generalized_eigensolver_stable.py) instead — same gradient treatment, pure JAX.

The remaining solvers below are kept for benchmarking and for reproducing prior work; they are not recommended as defaults.

#### Recommended
| Solver | File | Strategy |
|---|---|---|
| **`stable_eigh_pyscfad` / `stable_eigh_gen_pyscfad`** | [generalized_eigensolver_pyscfad.py](src/jax/generalized_eigensolver_pyscfad.py) | LAPACK/cuSOLVER kernels + Lorentzian-broadened VJP [2] |
| **`stable_eigh` / `stable_generalized_eigh`** (pure-JAX) | [generalized_eigensolver_stable.py](src/jax/generalized_eigensolver_stable.py) | Pure-JAX Cholesky + Lorentzian-broadened VJP [2] |

#### Alternative stable solvers
| Solver | File | Strategy | Gradient notes |
|---|---|---|---|
| `subspace_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Custom VJP: Lorentzian broadening `F/(F²+ε²)` [2] | Stable |
| `subspace_generalized_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Symmetry-breaking perturbation + `subspace_eigh` [2,4] | Stable |
| `degen_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Custom VJP: mask degenerate `F_ij` by threshold [1,3] | Stable only for symmetric-subspace losses |
| `safe_generalized_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Cholesky + `degen_eigh` | Inherits `degen_eigh` caveat |

#### Baselines (not gradient-safe at degeneracies)
| Solver | File | Strategy |
|---|---|---|
| `standard_eig` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | `scipy.linalg.eigh` — non-differentiable reference |
| `jax_eig` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Plain Cholesky + `jnp.linalg.eigh`, default VJP |
| `generalized_eigh` | [generalized_eigensolver.py](src/jax/generalized_eigensolver.py) | Symmetrized Cholesky with SPD shift, default VJP |
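The symmetry-breaking idea used by `subspace_generalized_eigh` [2,4] can be sketched in plain JAX, independently of this package: shift the matrix by a tiny fixed diagonal so no two eigenvalues coincide exactly, then differentiate through `jnp.linalg.eigh` as usual (the perturbation scale here is an illustrative choice, not this repository's implementation):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def perturbed_eigh(A, eps=1e-9):
    # Break exact degeneracies with a tiny, fixed diagonal shift so the
    # default eigh VJP never divides by a zero eigenvalue gap.
    n = A.shape[-1]
    return jnp.linalg.eigh(A + jnp.diag(eps * jnp.arange(n)))

# The identity has an exactly degenerate spectrum, where the default VJP of
# jnp.linalg.eigh is ill-defined for eigenvector gradients; after the shift
# the gradient is large but finite.
A = jnp.eye(2)
g = jax.grad(lambda A: perturbed_eigh(A)[1].sum())(A)
```

The price, as noted in [4], is a controlled bias of order `eps` in the results, traded for gradients that never divide by an exactly zero gap.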

### References
- [1] Kasim, M. F., & Vinko, S. M. *Learning the exchange–correlation functional from nature with fully differentiable density functional theory.* Phys. Rev. Lett. **127**, 126403 (2021). https://doi.org/10.1103/PhysRevLett.127.126403
- [2] Colburn, S., & Majumdar, A. *Inverse design and flexible parameterization of meta-optics using algorithmic differentiation.* Communications Physics **4**, 54 (2021). https://doi.org/10.1038/s42005-021-00568-6
- [3] JAX Issue #2748 — Differentiable `eigh` with degeneracies. https://github.com/jax-ml/jax/issues/2748
- [4] JAX Issue #5461 — Stable generalized `eigh`. https://github.com/jax-ml/jax/issues/5461


## Development & Testing
- **Requirements**: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
- **Tests**:
  ```bash
  pytest tests/test_eigh.py     # Core functionality
  pytest tests/test_eigh_gen.py # Generalized itypes
  pytest tests/test_eigh_jit.py # JIT & vmap
  ```
- **GPU Setup**:
  ```bash
  source setup_gpu_env_clean.sh
  ./run_gpu.sh python example_simple.py
  ```

## License & Citation
Apache License 2.0. If used in research, please cite:
```bibtex
@software{pyscfad,
  author = {Zhang, Xing},
  title = {PySCFad: Automatic Differentiation for PySCF},
  url = {https://github.com/fishjojo/pyscfad},
  year = {2021-2025}
}

@software{sokolov2026eigh,
  title={Eigh: Differentiable Eigenvalue Decomposition with {JAX} ({CPU}/{GPU})},
  author={Sokolov, Igor},
  url={https://github.com/Brogis1/eigh},
  year={2026}
}

@article{sokolov2026xc,
  title = {Quantum-enhanced neural exchange-correlation functionals},
  author = {Sokolov, Igor O. and Both, Gert-Jan and Bochevarov, Art D. and Dub, Pavel A. and Levine, Daniel S. and Brown, Christopher T. and Acheche, Shaheen and Barkoutsos, Panagiotis Kl. and Elfving, Vincent E.},
  journal = {Phys. Rev. A},
  volume = {113},
  issue = {1},
  pages = {012427},
  numpages = {24},
  year = {2026},
  month = {Jan},
  publisher = {American Physical Society},
  doi = {10.1103/m51l-fys2},
  url = {https://link.aps.org/doi/10.1103/m51l-fys2}
}
```
