Metadata-Version: 2.4
Name: psax
Version: 0.1.1
Summary: JAX implementation of Protein Strain Analysis (per-site weighted finite strain, PSA)
Author-email: Marielle Russo <67157875+maraxen@users.noreply.github.com>
License: MIT License
        
        Copyright (c) 2026 Marielle Russo
        
        Portions derived from prxteinmpnn (https://github.com/maraxen/prxteinmpnn),
        used for the legacy global-PSA regression snapshot under ``src/old_psa/``.
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
        
Project-URL: Homepage, https://github.com/maraxen/psax
Project-URL: Repository, https://github.com/maraxen/psax
Project-URL: Issues, https://github.com/maraxen/psax/issues
Keywords: jax,protein,strain,psa,elasticity,mechanics,structure
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Classifier: Topic :: Scientific/Engineering :: Physics
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: flatbuffers>=24.3.25
Requires-Dist: jax>=0.4.35
Requires-Dist: numpy>=1.25
Requires-Dist: typer>=0.12
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Requires-Dist: pytest-cov>=5.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx>=7.0; extra == "docs"
Requires-Dist: myst-parser>=3.0; extra == "docs"
Requires-Dist: sphinx-rtd-theme>=3.0; extra == "docs"
Dynamic: license-file

# psax

JAX implementation of **Protein Strain Analysis** (PSA): per-site weighted finite strain and deformation gradients, consistent with the reference implementation [**Sartori-Lab/PSA**](https://github.com/Sartori-Lab/PSA) and the original paper:

> Pablo Sartori and Stanislas Leibler, *Evolutionary conservation of mechanical strain distributions in functional transitions of protein structures*, **Phys. Rev. X** 14, 011042 (2024). **DOI:** [10.1103/PhysRevX.14.011042](https://doi.org/10.1103/PhysRevX.14.011042) (APS: [link.aps.org](https://link.aps.org/doi/10.1103/PhysRevX.14.011042)).

BibTeX and preprint links: [`references/sartori_leibler_2024_prx.md`](references/sartori_leibler_2024_prx.md).

## Install

```bash
pip install psax
# editable dev
pip install -e ".[dev]"
# optional: parity tests against upstream PSA (install VCS pin separately — not a PyPI extra)
pip install "psa @ git+https://github.com/Sartori-Lab/PSA.git@a55f44eea3c8165d618cfd607f1c3cebe7535cbb"
```

Requires Python 3.11+.

### Structure I/O (`proxide`)

**proxide** is **not** on PyPI. To run `psax run` or `run_pairwise_psa_from_structures`, install proxide from a local checkout in the same environment:

```bash
pip install /path/to/proxide
```

If this repo lives in a **uv workspace** next to a local `proxide` member, add a workspace source mapping (see [uv workspace sources](https://docs.astral.sh/uv/concepts/projects/workspaces/)):

```toml
[tool.uv.sources]
proxide = { workspace = true }
```

## PyPI and releases

| Workflow | Purpose |
|----------|---------|
| [`.github/workflows/ci.yml`](.github/workflows/ci.yml) | Tests, docs, upstream parity, wheel smoke on `main` |
| [`.github/workflows/publish.yml`](.github/workflows/publish.yml) | Build with `python -m build` and upload to **PyPI** on **GitHub Release** (OIDC trusted publishing) |

Configure **trusted publishing** on PyPI for this repository and workflow. If you use protection rules, also add a GitHub **Environment** named `pypi`. To trigger the workflow, tag a release (e.g. `v0.1.0`) and publish a **GitHub Release**, or run it manually via **workflow_dispatch**.

## StableHLO / `jax.export` (AOT)

For deployment and compiler toolchains, psax exposes a thin export layer in [`psax.stablehlo`](src/psax/stablehlo.py). Shapes are static: coordinates are `(N, 3)` and directed-edge arrays are `(E,)`, and `method` / `ridge_eps` / `rcond` are fixed at export time. Serialized artifacts use JAX’s export format (VHLO / calling-convention version) and are **tied to the JAX version** you built with; see the [JAX export documentation](https://jax.readthedocs.io/en/latest/jax.export.html) and the [OpenXLA StableHLO + JAX tutorial](https://openxla.org/stablehlo/tutorials/jax-export).

CLI:

```bash
psax emit-stablehlo --n 128 --edges 4096 --out psa_strain.bin
```

`run_pairwise_psa` is **not** the export entrypoint: it builds weights outside of `jit`, so it cannot be traced as a single fixed-shape function.
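
As a rough illustration of the export flow described above, the sketch below runs `jax.export` on a toy fixed-shape function and round-trips it through the serialized form. `toy_kernel` and the sizes are stand-ins invented for this example; this is not the psax pipeline or the `psax.stablehlo` API.

```python
import jax
import jax.numpy as jnp
from jax import export


def toy_kernel(x_ref, x_def):
    # Stand-in for a fixed-shape kernel: per-site displacement magnitude.
    return jnp.linalg.norm(x_def - x_ref, axis=-1)


n_sites = 128  # static N, fixed at export time
coords = jax.ShapeDtypeStruct((n_sites, 3), jnp.float32)

# Trace and lower once for the given static shapes.
exported = export.export(jax.jit(toy_kernel))(coords, coords)

# Round-trip through the serialized artifact (tied to the JAX version used).
blob = exported.serialize()
restored = export.deserialize(blob)

x = jnp.zeros((n_sites, 3), jnp.float32)
out = restored.call(x, x + 1.0)  # shape (128,)
```

Calling `restored.call` with a different `N` fails, which is why the shapes have to be chosen up front.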

## Batching ref / design pairs (`vmap`)

- **Aligned batch** — `B` independent pairs `(x_ref[b], x_def[b])` with the **same** graph (shared `edges_i`, `edges_j`, `edge_weights`): use `psax.batch.deformation_gradient_per_site_vmap_pairs` or `run_pairwise_psa_batched_shared_graph` when the dense weight mask `(N, N)` is shared.
- **Ref × design grid** — all combinations `(b_ref, b_design)` with **fixed** shared edges: `psax.batch.deformation_gradient_per_site_grid_ref_design`.
- **Per-sample edges** — same batch size and edge count `E`: `deformation_gradient_per_site_batched_edges` / `edge_batch_mode="per_pair"`.
- Different `E` per sample: pad/mask to a common edge count, or loop in Python with `psax.utils.safe_map` (see the docs and the `psax.bucket` placeholder).

## Quickstart (two coordinate sets → per-site F → strain)

```python
import jax.numpy as jnp
from psax.run.pipeline import run_pairwise_psa

x_ref = jnp.array(...)  # (N, 3)
x_def = jnp.array(...)  # (N, 3)
out = run_pairwise_psa(x_ref, x_def, r_inner=6.0, r_outer=8.0)
F = out.deformation_gradient       # (N, 3, 3)
E = out.green_lagrange_strain      # (N, 3, 3)
lam = out.principal_strain_eigenvalues  # (N, 3)
```
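
For orientation, the Green–Lagrange strain is built from the deformation gradient as \(\mathbf{E} = \tfrac{1}{2}(\mathbf{F}^\top \mathbf{F} - \mathbf{I})\), and the principal strains are its eigenvalues. The sketch below computes both with plain `jax.numpy` on a hand-built `F`; it does not call psax, and the eigenvalue ordering psax returns is not stated here, so treat the ascending order as a property of `eigvalsh` only.

```python
import jax.numpy as jnp


def green_lagrange_strain(F):
    # E = 0.5 * (F^T F - I), batched over the leading axis; F is (N, 3, 3).
    C = jnp.einsum("nki,nkj->nij", F, F)  # right Cauchy-Green tensor F^T F
    return 0.5 * (C - jnp.eye(3))


# One site stretched by 10% along x.
F = jnp.diag(jnp.array([1.1, 1.0, 1.0]))[None]  # (1, 3, 3)
E = green_lagrange_strain(F)                     # (1, 3, 3)
lam = jnp.linalg.eigvalsh(E)                     # (1, 3), ascending order
```

Here the only nonzero strain component is `E[0, 0, 0] = 0.5 * (1.1**2 - 1) = 0.105`, slightly larger than the 10% small-strain value because of the quadratic term.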

With structures on disk (requires **proxide**):

```python
from psax.run.pipeline import run_pairwise_psa_from_structures

out = run_pairwise_psa_from_structures("ref.pdb", "def.pdb", align_kabsch=False)
```
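
The `align_kabsch` flag refers to Kabsch rigid superposition. As background, a minimal sketch of the textbook algorithm follows; `kabsch_align` is a hypothetical helper written for this example and may differ from the implementation in `psax.alignment`.

```python
import jax.numpy as jnp


def kabsch_align(mobile, target):
    """Rigidly superpose `mobile` onto `target`; both are (N, 3) with matched rows."""
    mob_centroid = mobile.mean(axis=0)
    tgt_centroid = target.mean(axis=0)
    P = mobile - mob_centroid
    Q = target - tgt_centroid
    H = P.T @ Q  # 3x3 cross-covariance
    U, _, Vt = jnp.linalg.svd(H)
    # Flip the last axis if needed so the result is a proper rotation (det = +1).
    d = jnp.sign(jnp.linalg.det(Vt.T @ U.T))
    D = jnp.diag(jnp.ones(3).at[2].set(d))
    R = Vt.T @ D @ U.T
    return P @ R.T + tgt_centroid


# Demo: rotate and translate a small point set, then superpose it back.
theta = 0.7
Rz = jnp.array(
    [
        [jnp.cos(theta), -jnp.sin(theta), 0.0],
        [jnp.sin(theta), jnp.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ]
)
target = jnp.array(
    [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.2, 0.3], [0.2, 1.0, 1.1], [-0.8, 0.4, 0.9]]
)
mobile = target @ Rz.T + jnp.array([3.0, -2.0, 1.0])

superposed = kabsch_align(mobile, target)
rmsd = jnp.sqrt(jnp.mean(jnp.sum((superposed - target) ** 2, axis=-1)))
```

Because the demo applies a pure rigid motion, the RMSD after superposition is zero up to floating-point error.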

## CLI

```bash
psax --version   # or -V
psax version
psax run --help
psax emit-stablehlo --help
psax run ref.pdb def.pdb -o out.npz --bfactors-out colored.pdb --template-pdb ref.pdb
psax export template.pdb out.pdb --values "1.0,2.0,3.0"
```

The CLI dispatches to `psax.run.pipeline`, `psax.stablehlo`, and `psax.io.export`.
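
To make the B-factor export concrete, the sketch below writes per-residue values into the temperature-factor field of PDB `ATOM`/`HETATM` records (columns 61-66) using only the standard library. How `psax export` maps `--values` onto residues is not documented here, so the one-value-per-residue, file-order mapping is an assumption, and `set_bfactors` and `atom_line` are hypothetical helpers.

```python
def set_bfactors(pdb_text, values):
    """Return PDB text with one value per residue written to the B-factor column."""
    out_lines = []
    residue_index = -1
    last_key = None
    for line in pdb_text.splitlines():
        if line.startswith(("ATOM", "HETATM")):
            # Chain ID (col 22) plus residue number and insertion code (cols 23-27).
            key = (line[21], line[22:27])
            if key != last_key:
                residue_index += 1
                last_key = key
            line = line.ljust(66)
            # B-factor occupies columns 61-66, formatted as %6.2f.
            line = f"{line[:60]}{values[residue_index]:6.2f}{line[66:]}"
        out_lines.append(line)
    return "\n".join(out_lines)


def atom_line(serial, name, res_name, chain, res_seq, x, y, z):
    # Fixed-column ATOM record with occupancy 1.00 and B-factor 0.00.
    return (
        f"ATOM  {serial:5d} {name:<4s} {res_name:>3s} {chain}{res_seq:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}"
    )


pdb_text = "\n".join(
    [
        atom_line(1, " N  ", "ALA", "A", 1, 0.0, 0.0, 0.0),
        atom_line(2, " CA ", "ALA", "A", 1, 1.458, 0.0, 0.0),
        atom_line(3, " N  ", "GLY", "A", 2, 2.0, 1.2, 0.0),
    ]
)
colored = set_bfactors(pdb_text, [1.0, 2.5])
bfactor_fields = [line[60:66] for line in colored.splitlines()]
```

Passing fewer values than residues raises an `IndexError`, which is a reasonable failure mode for a mismatch between the template and the value list.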

## Parity & limitations

| What is validated | How |
|-------------------|-----|
| JAX vs in-repo NumPy PSA loop | `tests/test_core_parity.py` (`parity_numpy`) |
| JAX vs **installed** `psa.elastic.deformation_gradient` (dense weights, fp64) | `tests/parity_upstream/` (`parity_upstream`), install `psa` from git (see Install) |
| Symmetrized `D` + solvers vs NumPy reference | Covered indirectly when `ridge_eps=0` matches upstream dense path |

**Not** fully cross-checked here: upstream sparse/Numba dict fast paths, energy/rotation pipelines in PSA, or every PSA strain helper name-for-name. Validation covers only the dense deformation-gradient path and the Green–Lagrange strain built from `F`.

See `tests/parity_upstream/README.md` for dense vs sparse scope.
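
The shape of a parity check like the first row of the table can be sketched as follows: an explicit NumPy loop over a dense weight matrix compared against a `jax.ops.segment_sum` version over directed edges, both in fp64. Both functions are written for this example (they are not the psax or PSA implementations) and assume the standard weighted least-squares form \(\mathbf{F}_i = \mathbf{A}_i \mathbf{D}_i^{-1}\).

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # match NumPy double precision


def deformation_gradient_numpy(x_ref, x_def, weights):
    # Reference: explicit loop over sites with a dense (N, N) weight matrix.
    n = x_ref.shape[0]
    F = np.zeros((n, 3, 3))
    for i in range(n):
        A = np.zeros((3, 3))
        D = np.zeros((3, 3))
        for j in range(n):
            w = weights[i, j]
            if w == 0.0:
                continue
            d_ref = x_ref[j] - x_ref[i]
            d_def = x_def[j] - x_def[i]
            A += w * np.outer(d_def, d_ref)
            D += w * np.outer(d_ref, d_ref)
        F[i] = A @ np.linalg.inv(D)
    return F


@jax.jit
def deformation_gradient_jax(x_ref, x_def, edges_i, edges_j, edge_weights):
    n = x_ref.shape[0]
    d_ref = x_ref[edges_j] - x_ref[edges_i]
    d_def = x_def[edges_j] - x_def[edges_i]
    w = edge_weights[:, None, None]
    A = jax.ops.segment_sum(w * d_def[:, :, None] * d_ref[:, None, :], edges_i, num_segments=n)
    D = jax.ops.segment_sum(w * d_ref[:, :, None] * d_ref[:, None, :], edges_i, num_segments=n)
    # F D = A with symmetric D, so solve D F^T = A^T and transpose back.
    return jnp.swapaxes(jnp.linalg.solve(D, jnp.swapaxes(A, -1, -2)), -1, -2)


rng = np.random.default_rng(0)
n_sites = 8
x_ref = rng.normal(size=(n_sites, 3))
F0 = np.eye(3) + 0.1 * rng.normal(size=(3, 3))
x_def = x_ref @ F0.T + 0.01 * rng.normal(size=(n_sites, 3))

weights = np.ones((n_sites, n_sites)) - np.eye(n_sites)  # all-to-all, no self edges
edges_i, edges_j = np.nonzero(weights)                   # built outside jit
edge_weights = weights[edges_i, edges_j]

F_np = deformation_gradient_numpy(x_ref, x_def, weights)
F_jax = deformation_gradient_jax(x_ref, x_def, edges_i, edges_j, edge_weights)
np.testing.assert_allclose(np.asarray(F_jax), F_np, atol=1e-9)
```

Without `jax_enable_x64` the JAX side runs in float32 and a tolerance this tight would fail, which is the reason the test suite enables it.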

## Documentation (Sphinx)

```bash
pip install -e ".[docs]"
sphinx-build -b html docs docs/_build/html
```

## Testing & coverage

```bash
pytest
pytest --cov=psax --cov-report=term-missing
```

Markers: `unit`, `parity_numpy`, `parity_upstream`, `structure_io`, `slow` (see `pyproject.toml`).

## API sketch

- Build a **directed edge list** from a dense weight matrix **outside** `jit` (`psax.graph.edges_from_dense_weight_matrix`).
- **`deformation_gradient_per_site`**: \(\mathbf{F}_i\) from weighted bond vectors using `jax.ops.segment_sum`, with **`method=`** `solve` / `lstsq` / `svd` and optional **`return_diagnostics`** (condition numbers).
- **Strain / kinematics** (`psax.core`): Green–Lagrange and small-strain Lagrange, **Euler** strains, **Cauchy–Green invariants**, **principal stretches**, **rotation axis/angle** (polar decomposition).
- **Energy** (`psax.core.energy`): Saint Venant–Kirchhoff density and unit helpers.
- **Alignment** (`psax.alignment`): vendored **soft Smith–Waterman / Needleman–Wunsch** (`sequence`), **Kabsch** rigid superposition (`structure`).
- **Spatial** (`psax.spatial`): cylindrical/spherical coordinates, tensor rotation into cylindrical frame, rough **effective volumes**.
- **I/O** (`psax.io`): **`load_structure`** via **proxide** (optional), **B-factor** export without BioPython.
- **AOT** (`psax.stablehlo`): **`jax.export`** of the PSA strain pipeline at fixed shapes.
- **Synthetic tests** (`psax.testing.forms`): rods, twist/spin/radial deformations.
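
As background for the kinematics and energy items in the list above, the sketch below derives the rotation angle and principal stretches from a polar decomposition \(\mathbf{F} = \mathbf{R}\mathbf{U}\) and evaluates a Saint Venant–Kirchhoff energy density. The function names and the Lamé parameter values are hypothetical; they are not the `psax.core` API, and the SVD route assumes \(\det \mathbf{F} > 0\).

```python
import jax.numpy as jnp


def polar_rotation_and_stretches(F):
    """Polar decomposition F = R U via SVD: rotation, principal stretches, angle (rad)."""
    left, stretches, right_t = jnp.linalg.svd(F)
    R = left @ right_t  # proper rotation when det(F) > 0
    cos_angle = jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    return R, stretches, jnp.arccos(cos_angle)


def svk_energy_density(E, lam, mu):
    """Saint Venant-Kirchhoff density: (lam / 2) tr(E)^2 + mu tr(E E)."""
    return 0.5 * lam * jnp.trace(E) ** 2 + mu * jnp.trace(E @ E)


# Demo: 20% stretch along x followed by a 0.3 rad rotation about z.
angle = 0.3
Rz = jnp.array(
    [
        [jnp.cos(angle), -jnp.sin(angle), 0.0],
        [jnp.sin(angle), jnp.cos(angle), 0.0],
        [0.0, 0.0, 1.0],
    ]
)
U = jnp.diag(jnp.array([1.2, 1.0, 1.0]))
F = Rz @ U

R, stretches, theta = polar_rotation_and_stretches(F)
E = 0.5 * (F.T @ F - jnp.eye(3))             # Green-Lagrange strain
energy = svk_energy_density(E, lam=1.0, mu=1.0)  # illustrative Lame parameters
```

The rotation drops out of `E`, so the energy depends only on the stretch: with `E = diag(0.22, 0, 0)` it evaluates to `0.5 * 0.22**2 + 0.22**2 = 0.0726`.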

Legacy **global** least-squares \(\mathbf{F}\) (from the prxteinmpnn snapshot) lives under `old_psa` for regression tests only; it is **not** the published per-site PSA in [Phys. Rev. X **14**, 011042 (2024)](https://doi.org/10.1103/PhysRevX.14.011042).

## JAX notes

- **`safe_map`** (`psax.utils.safe_map`): when batching many independent structures, pass a **Python `int`** `batch_size` into `safe_map` (not a traced value). Inside `jit`, prefer fixed-shape kernels and **bucketing** (see `psax.bucket`) instead of compiling over Python loops.
- **Float precision:** enable `jax_enable_x64` if you need to match double-precision NumPy/SciPy baselines (tests set this in `conftest.py`).
- **Eigenvectors:** `principal_strain_eigensystem` fixes eigenvector signs deterministically (no random hemisphere flips).
- **RNG:** core PSA is deterministic; if you add stochastic workflows, derive subkeys with `jax.random.fold_in(key, index)`.
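
The bucketing advice in the first note can be illustrated by padding edge lists to a small set of fixed lengths, so that a jitted kernel compiles once per bucket rather than once per edge count. The bucket sizes and helper names below are hypothetical and unrelated to whatever `psax.bucket` provides; padded edges carry zero weight, so they do not change the result.

```python
from functools import partial

import jax
import numpy as np

BUCKETS = (8, 32, 128)  # hypothetical padded edge counts


def pad_edges_to_bucket(edges_i, edges_j, edge_weights, buckets=BUCKETS):
    # Runs outside jit: pick the smallest bucket that fits (raises StopIteration
    # if none does) and pad with zero-weight self-edges on site 0.
    n_edges = len(edges_i)
    target = next(b for b in buckets if b >= n_edges)
    pad = target - n_edges
    return (
        np.pad(edges_i, (0, pad)),
        np.pad(edges_j, (0, pad)),
        np.pad(edge_weights, (0, pad)),
    )


@partial(jax.jit, static_argnames="n_sites")
def weighted_degree(edges_i, edge_weights, n_sites):
    # Fixed-shape kernel: sum of outgoing edge weights per site.
    return jax.ops.segment_sum(edge_weights, edges_i, num_segments=n_sites)


# Two samples with different edge counts (5 and 7) land in the same bucket (8).
ei_a = np.array([0, 0, 1, 2, 3])
ej_a = np.array([1, 2, 2, 3, 0])
w_a = np.full(5, 0.5)
ei_b = np.array([0, 1, 1, 2, 2, 3, 3])
ej_b = np.array([1, 0, 2, 1, 3, 2, 0])
w_b = np.ones(7)

pi_a, pj_a, pw_a = pad_edges_to_bucket(ei_a, ej_a, w_a)
pi_b, pj_b, pw_b = pad_edges_to_bucket(ei_b, ej_b, w_b)

deg_a = weighted_degree(pi_a, pw_a, n_sites=4)
deg_b = weighted_degree(pi_b, pw_b, n_sites=4)  # same shapes as sample a
```

Coarser buckets mean fewer compilations but more wasted work on padding, so the bucket list is a tuning choice.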

## Prior art

The `old_psa` package is adapted from **prxteinmpnn**’s PSA helpers (global \(\mathbf{F}\) and dense linear weights). The **per-site** method implemented in `psax.core` follows [**Sartori-Lab/PSA**](https://github.com/Sartori-Lab/PSA) and [Sartori & Leibler, *Phys. Rev. X* **14**, 011042 (2024)](https://doi.org/10.1103/PhysRevX.14.011042).

## Citation

If you use this software in published work, please cite the **PSA paper** (and this codebase if appropriate):

```bibtex
@article{PhysRevX.14.011042,
  title = {Evolutionary Conservation of Mechanical Strain Distributions in Functional Transitions of Protein Structures},
  author = {Sartori, Pablo and Leibler, Stanislas},
  journal = {Phys. Rev. X},
  volume = {14},
  issue = {1},
  pages = {011042},
  year = {2024},
  publisher = {American Physical Society},
  doi = {10.1103/PhysRevX.14.011042},
  url = {https://link.aps.org/doi/10.1103/PhysRevX.14.011042}
}
```

Reference implementation: [https://github.com/Sartori-Lab/PSA](https://github.com/Sartori-Lab/PSA).  
More detail: [`references/sartori_leibler_2024_prx.md`](references/sartori_leibler_2024_prx.md).
