Metadata-Version: 2.4
Name: prism-antibody
Version: 0.3.0
Summary: PRISM - Partitioning Residue Identity in Somatic Maturation for antibody language modeling
Author: Anonymous
License: MIT
Project-URL: Homepage, https://github.com/RomeroLab-Duke/prism-antibody
Project-URL: Repository, https://github.com/RomeroLab-Duke/prism-antibody
Project-URL: Documentation, https://github.com/RomeroLab-Duke/prism-antibody#readme
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.13.0
Requires-Dist: pytorch-lightning>=2.0.0
Requires-Dist: transformers>=4.30.0
Requires-Dist: huggingface-hub>=0.14.0
Requires-Dist: numpy>=1.21.0
Requires-Dist: pandas>=1.3.0
Requires-Dist: pyarrow>=10.0.0
Requires-Dist: torchmetrics>=0.11.0
Requires-Dist: PyYAML>=6.0
Requires-Dist: tqdm>=4.60.0
Requires-Dist: tensorboard>=2.11.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: black>=22.0.0; extra == "dev"
Requires-Dist: flake8>=4.0.0; extra == "dev"
Requires-Dist: isort>=5.10.0; extra == "dev"
Requires-Dist: mypy>=0.950; extra == "dev"
Provides-Extra: analysis
Requires-Dist: matplotlib>=3.5.0; extra == "analysis"
Requires-Dist: seaborn>=0.11.0; extra == "analysis"
Requires-Dist: scikit-learn>=1.0.0; extra == "analysis"
Requires-Dist: scipy>=1.7.0; extra == "analysis"
Provides-Extra: all
Requires-Dist: prism-antibody[analysis,dev]; extra == "all"
Dynamic: license-file

# PRISM

**P**artitioning **R**esidue **I**dentity in **S**omatic **M**aturation

This is the official repository for the paper:

> **Explicit representation of germline and non-germline residues improves antibody language modeling**

PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.

---

# Part 1: User Guide

Everything you need to run inference with PRISM or finetune it on your own antibody data.

## Installation

```bash
pip install prism-antibody
```

Or install from source:

```bash
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .
```

### Verify Installation

```python
import prism
print(prism.__version__)  # 0.3.0
```

## Quick Start: Inference

```python
import prism

# Load from Hugging Face Hub (auto-downloads and caches)
model = prism.pretrained("RomeroLab-Duke/prism-antibody")

# Or load from a local checkpoint
# model = prism.pretrained("path/to/checkpoint.ckpt")

# Extract germline log-probabilities — [L, 20] numpy array
gl = model.extract_GL_logit("EVQLVESGGGLVQPGGSLRL")

# Extract non-germline log-probabilities — [L, 20] numpy array
ngl = model.extract_NGL_logit("EVQLVESGGGLVQPGGSLRL")

# Marginalized log-probs: logsumexp(GL, NGL) — [L, 20]
marg = model.extract_marginalized_logit("EVQLVESGGGLVQPGGSLRL")

# Full 53-vocab log-probs — [L, 53]
full = model.extract_full_logit("EVQLVESGGGLVQPGGSLRL")

# Alpha gating values (GL/NGL mixture weights) — [L]
alpha = model.extract_alpha("EVQLVESGGGLVQPGGSLRL")

# Mean-pooled embedding — [H]
emb = model.extract_embedding("EVQLVESGGGLVQPGGSLRL")

# Perplexity — scalar
ppl = model.perplexity("EVQLVESGGGLVQPGGSLRL")
```

All methods accept a single string or a list of strings:

```python
# Batch inference
sequences = ["EVQLVESGGGLVQ", "DIQMTQSPSSLSA", "QVQLVQSGAEVKK"]
embeddings = model.extract_embedding(sequences)  # [3, H] numpy array
ppls = model.perplexity(sequences)                # [3] numpy array
```
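As the quick-start comments note, the marginalized output is a per-position, per-amino-acid `logsumexp` over the GL and NGL heads. On toy log-probability arrays (illustrative shapes only, not real model output) the relationship looks like this:

```python
import numpy as np

# Toy stand-ins for the [L, 20] GL and NGL log-probability arrays (L=5 here)
rng = np.random.default_rng(0)
gl = np.log(rng.dirichlet(np.ones(20), size=5))   # [5, 20]
ngl = np.log(rng.dirichlet(np.ones(20), size=5))  # [5, 20]

# Numerically stable elementwise logsumexp: log(exp(gl) + exp(ngl))
m = np.maximum(gl, ngl)
marg = m + np.log(np.exp(gl - m) + np.exp(ngl - m))

# Each marginalized probability is the sum of the two head probabilities
assert np.allclose(np.exp(marg), np.exp(gl) + np.exp(ngl))
```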

## Available Methods

| Method | Returns | Description |
|--------|---------|-------------|
| `extract_GL_logit(seq)` | `[L, 20]` | Germline (uppercase) log-probabilities |
| `extract_NGL_logit(seq)` | `[L, 20]` | Non-germline (lowercase) log-probabilities |
| `extract_marginalized_logit(seq)` | `[L, 20]` | logsumexp(GL, NGL) per amino acid |
| `extract_full_logit(seq)` | `[L, 53]` | Full vocabulary log-probabilities |
| `extract_alpha(seq)` | `[L]` | Per-position alpha gating values |
| `extract_embedding(seq)` | `[H]` | Mean-pooled backbone embedding |
| `embed(seq, mode=...)` | varies | Embeddings (`"mean"`, `"per_residue"`, `"cls"`) |
| `perplexity(seq)` | scalar | Pseudo-perplexity |
| `pseudo_log_likelihood(seq)` | scalar | Pseudo-log-likelihood score |
| `predict_origin(seq)` | dict | GL/NGL origin predictions per residue |
| `score_mutations(wt, mut)` | scalar | Log-likelihood ratio at mutated positions |
| `logits(seq, head=...)` | tensor | Raw logits from any head |

Column order for 20-AA outputs follows `PrismModel.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"`.
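Because 20-AA outputs follow `AA_ORDER`, you can map residues to column indices without loading the model. A small helper (illustrative, not part of the API):

```python
import numpy as np

AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # column order used by PrismModel
AA_TO_COL = {aa: i for i, aa in enumerate(AA_ORDER)}

def wt_logprobs(logits: np.ndarray, seq: str) -> np.ndarray:
    """Pick the column matching each residue of `seq` from an [L, 20] array."""
    cols = [AA_TO_COL[aa] for aa in seq]
    return logits[np.arange(len(seq)), cols]

toy = np.zeros((3, 20))
toy[0, AA_TO_COL["E"]] = 1.0      # mark "E" at position 0
print(wt_logprobs(toy, "EVQ"))    # → [1. 0. 0.]
```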

## Finetuning on Your Data

Finetune PRISM on your own antibody sequences with a single method call:

```python
import prism

model = prism.pretrained("RomeroLab-Duke/prism-antibody")

best_checkpoint = model.finetune(
    data_path="my_antibodies.parquet",   # parquet, pickle, or csv
    output_dir="outputs/my_finetune",
    max_steps=5000,
    learning_rate=1e-4,
    batch_size=32,
)

# Model is now finetuned — use immediately
gl = model.extract_GL_logit("EVQLVESGGGLVQ...")
```

### Data Format

Your data file needs at least one of these columns:

| Column | Required | Description |
|--------|----------|-------------|
| `HEAVY_CHAIN_AA_SEQUENCE` | At least one | Heavy chain amino acid sequence |
| `LIGHT_CHAIN_AA_SEQUENCE` | At least one | Light chain amino acid sequence |

Optional columns (auto-detected):

| Column | Description |
|--------|-------------|
| `split` | `"train"` / `"valid"` / `"test"` (auto-generated 90/5/5 if absent) |
| `hc_mut_codes`, `lc_mut_codes` | NGL mutation codes (e.g. `"S30A;T52N"`) |
| `v_gene_heavy`, `j_gene_heavy`, `v_gene_light`, `j_gene_light` | Gene labels |
| `region_mask_heavy`, `region_mask_light` | Region annotations |

Accepted file formats: `.parquet`, `.pkl`/`.pickle`, `.csv`.
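One way to assemble such a file with pandas (a sketch using the column names from the tables above; the sequences are truncated placeholders and the optional columns are shown only to illustrate the format):

```python
import pandas as pd

df = pd.DataFrame({
    "HEAVY_CHAIN_AA_SEQUENCE": ["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    "LIGHT_CHAIN_AA_SEQUENCE": ["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
    "split": ["train", "valid"],          # optional; 90/5/5 auto-split if omitted
    "hc_mut_codes": ["S30A;T52N", None],  # optional NGL mutation codes
})
df.to_csv("my_antibodies.csv", index=False)  # .parquet / .pkl also accepted
```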

### Finetune Parameters

```python
model.finetune(
    data_path="data.parquet",
    output_dir="prism_finetune_output",
    # Training
    max_steps=5000,             # total training steps
    learning_rate=1e-4,         # peak LR (cosine schedule with warmup)
    batch_size=32,              # per-device batch size
    warmup_steps=500,           # linear warmup steps
    weight_decay=0.01,          # AdamW weight decay
    mask_prob=0.15,             # MLM masking probability
    gradient_accumulation_steps=1,
    # Trainer
    devices=1,                  # number of GPUs
    precision="bf16-mixed",     # training precision
    val_check_interval=500,     # validate every N steps
    # Data
    num_workers=4,
    seed=42,
)
```
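The comments above describe a linear warmup to the peak learning rate followed by cosine decay. A self-contained sketch of that schedule, assuming decay to zero at `max_steps` (the actual implementation may use a different floor):

```python
import math

def lr_at(step: int, peak_lr: float = 1e-4, warmup_steps: int = 500,
          max_steps: int = 5000) -> float:
    """Linear warmup to peak_lr, then cosine decay to 0 at max_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

print(lr_at(250))   # halfway through warmup → 5e-05
print(lr_at(500))   # peak → 0.0001
print(lr_at(5000))  # fully decayed → 0.0
```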

---

# Part 2: Developer Guide

For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.

## Development Installation

```bash
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody

# Install with dev + analysis extras
pip install -e ".[dev,analysis]"
```

## Project Structure

```
prism/
├── src/prism/                    # Core Python package
│   ├── __init__.py               # Package exports
│   ├── api.py                    # High-level inference & finetune API
│   ├── model.py                  # SFT_ESM2 PyTorch Lightning module
│   ├── io_utils.py               # Dataset & DataModule classes
│   ├── multimodal_io.py          # Gene vocabulary & antibody dataset
│   └── utils.py                  # Utility functions
│
├── configs/                      # Training configuration files
│   ├── v34_pretrain.yaml         # Stage 1: Pretraining config
│   └── v34_1b_finetune.yaml      # Stage 2: Finetuning config
│
├── script/                       # Executable scripts
│   ├── train_esm.py              # Main training script
│   ├── inference_esm.py          # Basic inference
│   ├── inference_esm_with_logprobs.py
│   ├── upload_to_hub.py          # Upload checkpoint to HF Hub
│   │
│   ├── data/                     # Data processing pipeline (1-8)
│   │   ├── 1.processing_and_filtering_*.py
│   │   ├── 2.visualize_unpaired_statistics.py
│   │   ├── 3.filter_by_p90_and_save.py
│   │   ├── 4-6.cluster_*.py
│   │   ├── 7.extract_gene_information.py
│   │   └── 8.extract_data_for_probe.py
│   │
│   └── analyze/                  # Analysis & evaluation scripts
│       ├── 1.pppl_calculation/   # Perplexity calculations
│       ├── 2.gl-ngl_calculation/ # GL/NGL embeddings & linear probes
│       ├── 3.zero-shot/          # Zero-shot prediction benchmarks
│       └── 4.thera-sabdab/       # Controllable generation experiments
│
├── tests/                        # Test suite
│   └── test_api.py               # API tests (59 tests)
│
├── pyproject.toml                # Package configuration
├── requirements.txt              # Explicit dependencies
└── README.md
```

## Core Modules

### `model.py` — SFT_ESM2

The central PyTorch Lightning module. Key components:

- **Base**: ESM2 transformer (Hugging Face) with optional SwiGLU activation
- **Multi-head architecture**: AA head + Origin head + Alpha gating + Final head
- **Gene conditioning**: V/J gene embeddings + region embeddings
- **Loss**: Focal loss with region-balanced and CDR-boosted variants
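
The focal loss mentioned above down-weights easy examples via a `(1 - p)^γ` factor on the cross-entropy term. A minimal numpy sketch, assuming the textbook formulation of Lin et al. (the repo's region-balanced and CDR-boosted variants add per-position weighting on top of this):

```python
import numpy as np

def focal_loss(log_probs: np.ndarray, targets: np.ndarray, gamma: float = 2.0) -> float:
    """Mean focal loss over positions.

    log_probs: [L, V] log-probabilities; targets: [L] class indices.
    """
    p_t = np.exp(log_probs[np.arange(len(targets)), targets])  # prob of true class
    return float(np.mean(-((1.0 - p_t) ** gamma) * np.log(p_t)))

# With gamma=0 this reduces to ordinary cross-entropy
lp = np.log(np.full((4, 20), 1 / 20))
assert np.isclose(focal_loss(lp, np.zeros(4, dtype=int), gamma=0.0), np.log(20))
```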

### `io_utils.py` — Data Loading

- `SeqSeqDataset`: Handles paired/unpaired antibody sequences with germline reconstruction
- `SFTDataModule`: Standard PyTorch Lightning DataModule
- `LazyShardedDataModule`: Memory-efficient sharded loading for large datasets
- `make_collate_fn_multihead`: Collate with 80/10/10 masking, gene encoding, region IDs
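
The 80/10/10 split above is the standard BERT-style MLM recipe: of the positions selected at `mask_prob`, 80% are replaced by the mask token, 10% by a random token, and 10% left unchanged. A numpy sketch of that selection (the function name, `-100` ignore index, and token IDs are illustrative, not the repo's collate code):

```python
import numpy as np

def mask_tokens(ids: np.ndarray, mask_id: int, vocab_size: int,
                mask_prob: float = 0.15, rng=None):
    """BERT-style 80/10/10 masking; labels carry -100 at unselected positions."""
    rng = rng or np.random.default_rng()
    ids = ids.copy()
    labels = np.full_like(ids, -100)

    selected = rng.random(ids.shape) < mask_prob   # positions to predict
    labels[selected] = ids[selected]

    r = rng.random(ids.shape)
    ids[selected & (r < 0.8)] = mask_id            # 80% → mask token
    rand = selected & (r >= 0.8) & (r < 0.9)       # 10% → random token
    ids[rand] = rng.integers(0, vocab_size, rand.sum())
    return ids, labels                             # remaining 10% unchanged

tokens = np.array([5, 7, 2, 9, 4, 1, 8, 3])
corrupted, labels = mask_tokens(tokens, mask_id=32, vocab_size=33,
                                rng=np.random.default_rng(0))
```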

### `multimodal_io.py` — Gene & Region Handling

- `GeneVocabulary`: Maps V/J gene strings to integer IDs
- `AntibodyDataset` / `AntibodyDataCollator` / `AntibodyMLMCollator`: Multi-modal data handling
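
A gene vocabulary of this kind is just a string-to-ID table with an unknown fallback. A toy sketch of the idea (the class name, method names, and unknown-token handling here are illustrative, not the actual `GeneVocabulary` API):

```python
class SimpleGeneVocab:
    """Toy gene-string → integer-ID mapping with an unknown fallback."""

    def __init__(self, genes):
        self.unk_id = 0  # reserve 0 for unseen genes
        self.ids = {g: i + 1 for i, g in enumerate(sorted(set(genes)))}

    def encode(self, gene: str) -> int:
        return self.ids.get(gene, self.unk_id)

vocab = SimpleGeneVocab(["IGHV3-23", "IGHJ4", "IGHV1-69"])
print(vocab.encode("IGHV3-23"))  # → 3
print(vocab.encode("IGHV9-99"))  # → 0 (unknown)
```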

## Training from Scratch

### Two-Stage Training Protocol

**Stage 1 — Pretraining** on a large unpaired OAS (Observed Antibody Space) dataset (~60M sequences):

```bash
python script/train_esm.py --config configs/v34_pretrain.yaml
```

**Stage 2 — Finetuning** on ~764K paired antibody sequences:

```bash
# Set pretrained_checkpoint_path in the config to point to Stage 1 best checkpoint
python script/train_esm.py --config configs/v34_1b_finetune.yaml
```

### Multi-GPU Training

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml
```

### Key Config Options

```yaml
data:
  data_path: "path/to/data"
  batch_size: 256
  mask_prob: 0.15

model:
  model_identifier: "esm2_t12_35M_UR50D"
  use_multihead_architecture: true
  use_alpha_gating: true
  ngl_loss_alpha: 3.0

training:
  max_steps: 100000
  peak_learning_rate: 4e-4
  warmup_steps: 2000
  gradient_accumulation_steps: 8

trainer:
  devices: 4
  precision: "bf16-mixed"
```
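With the settings above, the effective global batch size is the per-device batch multiplied by the gradient accumulation steps and the device count (assuming the usual data-parallel semantics, where PyTorch Lightning's DDP gives each device its own batch):

```python
batch_size = 256                   # per-device, from the config above
gradient_accumulation_steps = 8
devices = 4

# Sequences contributing to each optimizer step
effective_batch = batch_size * gradient_accumulation_steps * devices
print(effective_batch)  # → 8192
```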

## Analysis Pipeline

Located in `script/analyze/`, numbered by experiment stage:

| Stage | Directory | Description |
|-------|-----------|-------------|
| 1 | `pppl_calculation/` | Pseudo-perplexity comparison across models |
| 2 | `gl-ngl_calculation/` | Embedding extraction, linear probes, UMAP |
| 3 | `zero-shot/` | Binding affinity and developability benchmarks |
| 4 | `thera-sabdab/` | Controllable generation and mutation recovery |

## Data Processing Pipeline

Located in `script/data/`, processes OAS (Observed Antibody Space) data:

| Step | Purpose |
|------|---------|
| 1 | Filter OAS data, identify NGL positions |
| 2 | Visualize data distribution |
| 3 | Quality filtering by sequence coverage |
| 4-6 | MMseqs2 sequence clustering & deduplication |
| 7 | Parse V/J gene annotations |
| 8 | Prepare data for linear probe training |
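
Step 1 emits NGL mutation codes in the `S30A;T52N` format shown in the Data Format section, assumed here to encode wild-type residue, position, and mutant residue. A small parser for that format (illustrative; not a function from the repo):

```python
import re

def parse_mut_codes(codes: str):
    """Split 'S30A;T52N' into (wt, position, mut) tuples."""
    out = []
    for code in filter(None, codes.split(";")):
        m = re.fullmatch(r"([A-Z])(\d+)([A-Z])", code.strip())
        if m is None:
            raise ValueError(f"unrecognized mutation code: {code!r}")
        out.append((m.group(1), int(m.group(2)), m.group(3)))
    return out

print(parse_mut_codes("S30A;T52N"))  # → [('S', 30, 'A'), ('T', 52, 'N')]
```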

## Testing

```bash
# Run all tests
pytest tests/ -v

# Run specific test class
pytest tests/test_api.py::TestFinetune -v
```

## Linting

```bash
black --line-length 100 src/
isort --profile black --line-length 100 src/
flake8 src/
mypy src/
```

---

## License

MIT License — see [LICENSE](LICENSE) for details.
