Metadata-Version: 2.4
Name: prism-antibody
Version: 0.5.0
Summary: PRISM - Partitioning Residue Identity in Somatic Maturation for antibody language modeling
Author: Anonymous
License: MIT
Project-URL: Homepage, https://github.com/RomeroLab-Duke/prism-antibody
Project-URL: Repository, https://github.com/RomeroLab-Duke/prism-antibody
Project-URL: Documentation, https://github.com/RomeroLab-Duke/prism-antibody#readme
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.13.0
Requires-Dist: pytorch-lightning>=2.0.0
Requires-Dist: transformers>=4.30.0
Requires-Dist: huggingface-hub>=0.14.0
Requires-Dist: numpy>=1.21.0
Requires-Dist: pandas>=1.3.0
Requires-Dist: pyarrow>=10.0.0
Requires-Dist: torchmetrics>=0.11.0
Requires-Dist: PyYAML>=6.0
Requires-Dist: tqdm>=4.60.0
Requires-Dist: tensorboard>=2.11.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: black>=22.0.0; extra == "dev"
Requires-Dist: flake8>=4.0.0; extra == "dev"
Requires-Dist: isort>=5.10.0; extra == "dev"
Requires-Dist: mypy>=0.950; extra == "dev"
Provides-Extra: analysis
Requires-Dist: matplotlib>=3.5.0; extra == "analysis"
Requires-Dist: seaborn>=0.11.0; extra == "analysis"
Requires-Dist: scikit-learn>=1.0.0; extra == "analysis"
Requires-Dist: scipy>=1.7.0; extra == "analysis"
Provides-Extra: all
Requires-Dist: prism-antibody[analysis,dev]; extra == "all"
Dynamic: license-file

# PRISM

**P**artitioning **R**esidue **I**dentity in **S**omatic **M**aturation

This is the official repository for the paper:

> **Explicit representation of germline and non-germline residues improves antibody language modeling**

PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns to predict amino acid identity and to classify each position as germline or non-germline (GL/NGL).

---

# Part 1: User Guide

Everything you need to run inference or finetune PRISM on your own antibody data.

## Installation

```bash
pip install prism-antibody
```

Or install from source:

```bash
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .
```

### Verify Installation

```python
import prism
print(prism.__version__)
```

## API Overview

PRISM has **3 core methods** that cover all use cases:

| Method | Cost | Returns |
|--------|------|---------|
| `forward()` | 1 forward pass | logits, embeddings, origin, alpha |
| `pseudo_log_likelihood()` | L forward passes (one per residue) | PLL, perplexity, per-position log-probs (4 modes) |
| `score_mutations()` | 2M forward passes (two per mutated position) | masked-marginal mutation scores (4 modes) |

All methods accept single strings or lists (batch). All support paired (VH+VL) and unpaired input.

## `forward()` --- Logits, Embeddings, Everything

Single forward pass through the model. Returns all outputs as a dict of numpy arrays.

```python
import prism
model = prism.pretrained("RomeroLab-Duke/prism-antibody")

# Paired heavy + light chain (typical usage)
result = model.forward(
    heavy_chains="EVQLVESGGGLVQPGGSLRL",
    light_chains="DIQMTQSPSSLSASVG",
)
# result["heavy"]["final_logits"]  -> [L_vh, 53]  alpha-gated combined logits
# result["heavy"]["aa_logits"]     -> [L_vh, 33]  AA head logits (pre-gating)
# result["heavy"]["origin_logits"] -> [L_vh]      GL/NGL classification logits
# result["heavy"]["alpha"]         -> [L_vh]      gating values
# result["heavy"]["embedding"]     -> [L_vh, H]   per-residue hidden states
# result["light"] has the same keys for the light chain
```

### Derive Any Signal You Need

```python
import numpy as np

# GL/NGL log-probabilities (slice from 53-vocab)
gl_logits  = result["heavy"]["final_logits"][:, model.GL_INDICES]   # [L_vh, 20]
ngl_logits = result["heavy"]["final_logits"][:, model.NGL_INDICES]  # [L_vh, 20]

# Mean-pooled embedding
vh_emb = result["heavy"]["embedding"].mean(axis=0)  # [H]
vl_emb = result["light"]["embedding"].mean(axis=0)  # [H]

# GL/NGL origin probability per residue
origin_prob = 1 / (1 + np.exp(-result["heavy"]["origin_logits"]))  # sigmoid
# origin_prob > 0.5 = predicted NGL (somatic mutation)
```

### Masked Prediction

Mask specific positions to get context-only predictions:

```python
result = model.forward(
    heavy_chains="EVQLVESGGGLVQPGGSLRL",
    light_chains="DIQMTQSPSSLSASVG",
    mask_positions=[5, 10, 15],  # mask these heavy chain positions
)
```

### Batch Inference

```python
results = model.forward(
    heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# List of 2 dicts, each with "heavy" and "light" sub-dicts
```

### Unpaired Input

For heavy-only, light-only, or any single-chain input:

```python
result = model.forward("EVQLVESGGGLVQPGGSLRL")                     # positional arg
result = model.forward(heavy_chains="EVQLVESGGGLVQPGGSLRL")        # explicit heavy
result = model.forward(light_chains="DIQMTQSPSSLSASVG")            # explicit light
# Returns flat dict (no "heavy"/"light" nesting)
```

## `pseudo_log_likelihood()` --- PLL and Perplexity

Masked-marginal scoring: each position is masked in turn, the model predicts it from the surrounding context, and log P(true token) is accumulated. Returns **all 4 scoring modes** in one pass, with progress bars.

```python
# Paired (typical)
result = model.pseudo_log_likelihood(
    heavy_chains="EVQLVESGGGLVQPGGSLRL",
    light_chains="DIQMTQSPSSLSASVG",
)
# {
#   "marginalized": {"pll": -45.3, "perplexity": 2.34, "per_position": [L_vh + L_vl]},
#   "gl":           {"pll": -50.1, "perplexity": 2.71, "per_position": [L_vh + L_vl]},
#   "ngl":          {"pll": -48.2, "perplexity": 2.56, "per_position": [L_vh + L_vl]},
#   "exact":        {"pll": -50.1, "perplexity": 2.71, "per_position": [L_vh + L_vl]},
# }

# Quick access
ppl = result["marginalized"]["perplexity"]
per_pos = result["gl"]["per_position"]  # per-residue GL log-probs

# Batch paired (with progress bar)
results = model.pseudo_log_likelihood(
    heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
```

### Scoring Modes

| Mode | What it scores | Use case |
|------|---------------|----------|
| `marginalized` | `logsumexp(GL, NGL)` | General-purpose |
| `gl` | Uppercase (germline) token log-prob | Germline likeness |
| `ngl` | Lowercase (non-germline) token log-prob | Somatic mutation preference |
| `exact` | Raw token log-prob in 53-vocab | Direct 53-vocab scoring |
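For a single position, the modes relate as sketched below. The toy probabilities are made up; only the `logsumexp(GL, NGL)` relation comes from the table above:

```python
import math

# Toy log-probs for the true residue's uppercase (GL) and lowercase (NGL)
# tokens at one position -- illustrative values, not model output
gl_lp = math.log(0.30)   # what the "gl" mode scores
ngl_lp = math.log(0.10)  # what the "ngl" mode scores

# "marginalized" combines the two casings of the same residue via logsumexp
marginalized = math.log(math.exp(gl_lp) + math.exp(ngl_lp))

# "exact" would score whichever casing actually appears in the 53-vocab input
print(marginalized)  # log(0.40)
```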

### Unpaired

```python
result = model.pseudo_log_likelihood("EVQLVESGGGLVQPGGSLRL")
```

## `score_mutations()` --- Mutation Effect Prediction

Masked-marginal scoring at mutation positions. For each mutated position, the model masks that position in both the WT and mutant sequences and computes the log-likelihood difference (mutant minus WT). Returns all 4 modes.

```python
# Paired (typical)
result = model.score_mutations(
    wt="EVQLVESGGGLVQPGGSLRL",
    mutant="EVQLVASGGGLVQPGGSLRL",  # E6A
    wt_light_chains="DIQMTQSPSSLSA",
    mut_light_chains="DIQMTQSPSSLSA",  # same light chain, or different
)
# {
#   "positions": [5],  # 0-indexed mutation positions
#   "marginalized": {"score": 0.42, "per_position": [1]},
#   "gl":           {"score": 0.31, "per_position": [1]},
#   "ngl":          {"score": 0.55, "per_position": [1]},
#   "exact":        {"score": 0.31, "per_position": [1]},
# }
# score > 0 = mutant preferred over WT at that position

# Batch (multiple WT/mutant pairs)
results = model.score_mutations(
    wt=["EVQLVESGG", "DIQMTQSPS"],
    mutant=["EVQLVASGG", "DIQMAQSPS"],
    wt_light_chains=["DIQMT", "EVQLV"],
    mut_light_chains=["DIQMT", "EVQLV"],
)
```

### Unpaired

```python
result = model.score_mutations(wt="ACDEF", mutant="GCKEF")
# result["positions"] == [0, 2]  (A->G, D->K)
```
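The `positions` field follows directly from comparing the two strings. A plain-Python sketch of that convention (not PRISM's internal code; equal lengths assumed, substitutions only):

```python
from typing import List

def mutation_positions(wt: str, mutant: str) -> List[int]:
    """Return 0-indexed positions where wt and mutant differ (no indels)."""
    assert len(wt) == len(mutant), "substitutions only"
    return [i for i, (a, b) in enumerate(zip(wt, mutant)) if a != b]

print(mutation_positions("ACDEF", "GCKEF"))  # [0, 2]  (A->G, D->K)
```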

## Reference

### `forward()` Return Dict

| Key | Shape | Description |
|-----|-------|-------------|
| `final_logits` | `[L, 53]` | Alpha-gated combined logits (53-vocab) |
| `aa_logits` | `[L, 33]` | AA head logits, before gating |
| `origin_logits` | `[L]` | GL/NGL binary classification logits |
| `alpha` | `[L]` | Per-position gating values |
| `embedding` | `[L, H]` | Per-residue hidden states from backbone |

When paired, `forward()` returns `{"heavy": {...}, "light": {...}}` with the above keys in each sub-dict.

### Index Constants

- `model.GL_INDICES` --- 20 uppercase AA token indices in the 53-vocab
- `model.NGL_INDICES` --- 20 lowercase AA token indices in the 53-vocab
- `model.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"` --- column order for the 20 AA indices
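Together these constants map logits columns back to residues. A sketch with made-up logits: assume `gl_row` stands in for one position's GL slice, `final_logits[pos, model.GL_INDICES]` (the real index values live on the model and are not reproduced here):

```python
import numpy as np

AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # matches model.AA_ORDER

# Pretend this is one position's GL slice: final_logits[pos, model.GL_INDICES]
gl_row = np.zeros(len(AA_ORDER))
gl_row[AA_ORDER.index("W")] = 5.0  # strongly favor tryptophan

probs = np.exp(gl_row) / np.exp(gl_row).sum()  # softmax over the 20 GL columns
best_aa = AA_ORDER[int(probs.argmax())]
print(best_aa)  # W
```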

### Input Modes

All 3 methods accept the same input patterns:

```python
# Paired (returns "heavy"/"light" split for forward, combined for PLL/mutations)
model.forward(heavy_chains="VH...", light_chains="VL...")

# Unpaired (positional or keyword)
model.forward("SEQ...")
model.forward(heavy_chains="VH...")
model.forward(light_chains="VL...")

# Batch (list of strings)
model.forward(heavy_chains=["VH1", "VH2"], light_chains=["VL1", "VL2"])
```

## Finetuning on Your Data

```python
import prism

model = prism.pretrained("RomeroLab-Duke/prism-antibody")

best_checkpoint = model.finetune(
    data_path="my_antibodies.parquet",   # parquet, pkl, or csv
    output_dir="outputs/my_finetune",
    max_steps=5000,
    learning_rate=1e-4,
    batch_size=32,
)

# Model is now finetuned --- use immediately
result = model.forward(heavy_chains="EVQLVESGGGLVQ...", light_chains="DIQMT...")
```

### Data Format

Your data file needs at least one of these columns:

| Column | Required | Description |
|--------|----------|-------------|
| `HEAVY_CHAIN_AA_SEQUENCE` | At least one | Heavy chain amino acid sequence |
| `LIGHT_CHAIN_AA_SEQUENCE` | At least one | Light chain amino acid sequence |

Optional columns (auto-detected):

| Column | Description |
|--------|-------------|
| `split` | `"train"` / `"valid"` / `"test"` (auto-generated 90/5/5 if absent) |
| `hc_mut_codes`, `lc_mut_codes` | NGL mutation codes (e.g. `"S30A;T52N"`) |
| `v_gene_heavy`, `j_gene_heavy`, `v_gene_light`, `j_gene_light` | Gene labels |
| `region_mask_heavy`, `region_mask_light` | Region annotations |

### Finetune Parameters

```python
model.finetune(
    data_path="data.parquet",
    output_dir="prism_finetune_output",
    max_steps=5000,             # total training steps
    learning_rate=1e-4,         # peak LR (cosine schedule with warmup)
    batch_size=32,              # per-device batch size
    warmup_steps=500,           # linear warmup steps
    weight_decay=0.01,          # AdamW weight decay
    mask_prob=0.15,             # MLM masking probability
    gradient_accumulation_steps=1,
    devices=1,                  # number of GPUs
    precision="bf16-mixed",     # training precision
    val_check_interval=500,     # validate every N steps
    num_workers=4,
    seed=42,
)
```

---

# Part 2: Developer Guide

For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.

## Development Installation

```bash
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e ".[dev,analysis]"
```

## Project Structure

```
prism/
├── src/prism/                    # Core Python package
│   ├── __init__.py               # Package exports
│   ├── api.py                    # High-level inference & finetune API
│   ├── model.py                  # SFT_ESM2 PyTorch Lightning module
│   ├── io_utils.py               # Dataset & DataModule classes
│   ├── multimodal_io.py          # Gene vocabulary & antibody dataset
│   └── utils.py                  # Utility functions
│
├── configs/                      # Training configuration files
├── script/                       # Training, inference, analysis scripts
├── tests/                        # Test suite
├── pyproject.toml                # Package configuration
└── README.md
```

## Training from Scratch

### Two-Stage Training Protocol

**Stage 1 --- Pretraining** on a large unpaired OAS dataset (>60M sequences):

```bash
python script/train_esm.py --config configs/v34_pretrain.yaml
```

**Stage 2 --- Finetuning** on ~764K paired antibody sequences:

```bash
python script/train_esm.py --config configs/v34_1b_finetune.yaml
```

### Multi-GPU Training

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml
```

## Testing

```bash
pytest tests/ -v
```

---

## License

MIT License --- see [LICENSE](LICENSE) for details.
