Metadata-Version: 2.4
Name: d3-dna
Version: 0.1.3
Summary: D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences
Author: Anirban Sarkar
Author-email: Alejandra Durán <144379180+aduranu@users.noreply.github.com>
License: MIT
Project-URL: Homepage, https://github.com/anirbansarkar-cs/d3-dna
Project-URL: Repository, https://github.com/anirbansarkar-cs/d3-dna
Project-URL: Issues, https://github.com/anirbansarkar-cs/d3-dna/issues
Keywords: dna,diffusion,generative,deep-learning,bioinformatics,genomics
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: POSIX :: Linux
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: pytorch-lightning>=2.0.0
Requires-Dist: omegaconf>=2.3.0
Requires-Dist: numpy>=1.23.0
Requires-Dist: h5py>=3.7.0
Requires-Dist: tqdm>=4.64.0
Requires-Dist: einops>=0.6.0
Requires-Dist: scipy>=1.10.0
Provides-Extra: flash
Requires-Dist: flash-attn>=2.0.0; extra == "flash"
Provides-Extra: logging
Requires-Dist: wandb>=0.16.0; extra == "logging"
Provides-Extra: all
Requires-Dist: flash-attn>=2.0.0; extra == "all"
Requires-Dist: wandb>=0.16.0; extra == "all"
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: build>=1.0; extra == "dev"
Requires-Dist: twine>=4.0; extra == "dev"
Dynamic: license-file

# d3-dna

[![PyPI](https://img.shields.io/pypi/v/d3-dna.svg)](https://pypi.org/project/d3-dna/) [![Python](https://img.shields.io/pypi/pyversions/d3-dna.svg)](https://pypi.org/project/d3-dna/) [![License](https://img.shields.io/pypi/l/d3-dna.svg)](./LICENSE)

[[examples](./examples/)][[Zenodo data + checkpoints](https://zenodo.org/records/19774653)][[PyPI](https://pypi.org/project/d3-dna/)]

Discrete diffusion of DNA sequences. `d3-dna` is a standalone, pip-installable library for training and sampling SEDD-style discrete diffusion models on nucleotide sequences, packaged with self-contained reproducibility examples for the K562, HepG2, DeepSTARR, and FANTOM5 promoter benchmarks. Models are config-driven: the same `D3Trainer` / `D3Sampler` / `D3Evaluator` API works for any dataset whose sequences fit in a fixed window, with global per-sample labels, per-position labels, or no conditioning at all (the conditioning mode is determined by the shape of `y`, not by a flag). Architectures are swappable: a 12-block diffusion transformer (DDiT) and a 256-channel dilated convolutional model ship out of the box, and either can be swapped in via the `cfg.model.architecture` switch. Dataset-specific logic (oracles, masking, strand averaging, real-data layout) lives in the per-example directories, never in the core library, so adopting `d3-dna` on a new dataset amounts to writing a Dataset class and an oracle.

The fully-populated K562 example under [`examples/k562/`](./examples/k562/) is the best place to start: it reproduces the published transformer and convolutional configurations end-to-end (training, sampling, evaluation) against the pretrained checkpoints on Zenodo.

## Installation

```bash
pip install d3-dna
```

Extras: `[flash]` adds [flash-attention](https://github.com/Dao-AILab/flash-attention) for faster training on long sequences (otherwise the transformer falls back to PyTorch SDPA — identical model quality, slower at long sequence lengths); `[logging]` adds Weights & Biases; `[all]` installs both. Flash-attention compiles from source and imports `torch` during its build, so install it on a machine with a CUDA toolchain and disable build isolation so the existing torch install is visible:

```bash
pip install d3-dna
pip install flash-attn --no-build-isolation
```

Cold install on a fresh Python env takes about 60–90 s on a 1 Gbit link (dominated by the ~3 GB PyTorch + CUDA-runtime wheel download), under 10 s when the wheels are already cached. The `[flash]` source build adds another 5–15 min on a single GPU host.
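
To see which attention path an environment can use after installation, it is enough to check whether the `flash_attn` package is importable. The helper below is a minimal sketch using only the standard library; `flash_attn_available` is a hypothetical name, and the library's own fallback check is not shown in this README and may differ.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the flash_attn package can be found in this environment."""
    # find_spec returns None when a top-level package is not installed.
    return importlib.util.find_spec("flash_attn") is not None


backend = "flash-attn" if flash_attn_available() else "PyTorch SDPA fallback"
print(f"attention backend available: {backend}")
```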

## Demo

After installation, clone this repo and run a self-contained K562 sampling demo:

```bash
git clone https://github.com/anirbansarkar-cs/d3-dna.git
cd d3-dna/examples/k562
python sample.py --random-labels --num-samples 100 --steps 20 --replicates 1
```

The first run downloads the pretrained K562 transformer checkpoint from Zenodo (~1.4 GB, one-time, cached in `examples/k562/cache/`). After that, sampling itself takes about 5 s on an NVIDIA H100 NVL. Output (in `examples/k562/generated/`): `sample_0.npz` of shape `(100, 230, 4)` and `sample_0.fasta`.

## Usage

The four public classes are `D3Trainer`, `D3Sampler`, `D3Evaluator`, and `BaseSPMSEValidationCallback`, all re-exported from `d3_dna`. Each operates on standard PyTorch `Dataset` objects and an OmegaConf `cfg`, so adopting `d3-dna` in an existing pipeline is mostly a matter of writing the Dataset, picking a config, and instantiating one of these classes.

#### Define a Dataset

Each item is `(X, y)`, where `X` is a `LongTensor` of token indices and `y` is the conditioning label.

```python
import torch
from torch.utils.data import Dataset
import h5py

class MyDNADataset(Dataset):
    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # one-hot (N, L, 4) -> argmax -> (N, L) token indices
            self.X = torch.from_numpy(f[f'onehot_{split}'][:]).argmax(dim=-1)
            self.y = torch.tensor(f[f'y_{split}'][:], dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
```
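
The dataset above assumes the HDF5 file already stores one-hot arrays. As a dependency-free sketch of what the `argmax` step does, the snippet below encodes a DNA string as one-hot rows and recovers token indices. The `A, C, G, T -> 0, 1, 2, 3` channel order and the helper names are assumptions for illustration; check the convention used by your own data.

```python
# Assumed alphabet order; verify against your dataset before reusing this.
ALPHABET = "ACGT"
TOKEN_OF = {base: i for i, base in enumerate(ALPHABET)}


def one_hot(seq):
    """Encode a DNA string as L rows, each a length-4 one-hot list."""
    return [[1 if TOKEN_OF[base] == j else 0 for j in range(4)]
            for base in seq.upper()]


def to_tokens(rows):
    """Argmax over the last axis: (L, 4) one-hot rows -> L token indices."""
    return [max(range(4), key=row.__getitem__) for row in rows]


def to_string(tokens):
    """Decode token indices back to a DNA string."""
    return "".join(ALPHABET[t] for t in tokens)


rows = one_hot("ACGTTG")
tokens = to_tokens(rows)
print(tokens)  # [0, 1, 2, 3, 3, 2]
```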

The shape of `y` decides the conditioning mode automatically (no flag): `(N, signal_dim)` broadcasts a global label across positions; `(N, sequence_length, signal_dim)` adds a per-position label element-wise. K562 / HepG2 / DeepSTARR use the global form, FANTOM5 promoter uses the per-position form. See [`d3_dna/models/transformer.py:EmbeddingLayer.forward`](./d3_dna/models/transformer.py) for the dispatch.
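
The shape rule can be written down as a small classifier over the batched label shape. This is only an illustration of the rule stated above, not the dispatch code in `EmbeddingLayer.forward`; `conditioning_mode` is a hypothetical helper.

```python
def conditioning_mode(y_shape, sequence_length):
    """Classify a batched label shape as 'global' or 'per_position'.

    Hypothetical helper that mirrors the shape rule described in the text;
    it is not part of d3_dna.
    """
    if len(y_shape) == 2:
        # (N, signal_dim): one label vector per sequence.
        return "global"
    if len(y_shape) == 3 and y_shape[1] == sequence_length:
        # (N, sequence_length, signal_dim): one label vector per position.
        return "per_position"
    raise ValueError(f"unsupported label shape {y_shape!r}")


print(conditioning_mode((512, 1), 230))       # global (K562 / HepG2 style)
print(conditioning_mode((512, 2), 249))       # global (DeepSTARR style)
print(conditioning_mode((8, 1024, 1), 1024))  # per_position (promoter style)
```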

#### Train

`D3Trainer` wraps a PyTorch Lightning trainer around `D3LightningModule`. Pass the config (a path or an `OmegaConf` object), then `.fit(train_dataset, val_dataset)`.

```python
from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)
```

Resume from a checkpoint with `resume_from='path/to/last.ckpt'`, and attach user callbacks with `callbacks=[...]`. The training config (including dataset metadata) is embedded in the saved checkpoint via `save_hyperparameters`, so a checkpoint is self-describing for inference.

#### Sample

`D3Sampler` loads a checkpoint plus a user-instantiated model, generates one-hot sequences via a predictor-corrector (PC) sampler, and writes NPZ + FASTA to disk.

```python
from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)

sequences = sampler.generate(
    checkpoint='outputs/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')
```

For long runs, use `.generate_batched(...)` and tune `cfg.sampling.batch_size` against your GPU memory.
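
When choosing a batch size, it helps to know how many batches a run needs and how large the last one is. The sketch below shows only that arithmetic; how `.generate_batched(...)` splits work internally is not documented here, and `batch_sizes` is a hypothetical helper.

```python
def batch_sizes(num_samples, batch_size):
    """Split num_samples into full batches plus one smaller remainder batch."""
    if num_samples < 0 or batch_size <= 0:
        raise ValueError("num_samples must be >= 0 and batch_size must be > 0")
    full, remainder = divmod(num_samples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])


print(batch_sizes(1000, 512))        # [512, 488]
print(len(batch_sizes(39340, 512)))  # 77
```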

#### Evaluate

`D3Evaluator` is a dataset-agnostic dispatcher over four metrics: paired oracle MSE (`mse`), a per-feature Kolmogorov–Smirnov (KS) test on oracle predictions (`ks`), k-mer Jensen–Shannon distance (`js`), and discriminator AUROC (`auroc`). The caller supplies pre-loaded samples, real data, and an oracle exposing a `.predict(x)` method.

```python
from d3_dna import D3Evaluator

ev = D3Evaluator(tests=['mse', 'ks', 'js', 'auroc'], device='cuda')
results = ev.evaluate(
    samples=generated_one_hot,        # (N, L, 4) ndarray or torch tensor
    real_data=real_test_one_hot,      # (N, L, 4)
    oracle=my_oracle,                 # must implement .predict
    kmer_ks=[6],
)
```

Dataset-specific oracle loading, masking, and strand averaging live in `examples/<name>/`, never in the core library.
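
To make the `js` metric concrete, here is a dependency-free sketch of a k-mer Jensen–Shannon distance between two sets of sequences. It pools k-mer counts across each set and uses base-2 logarithms so the distance lies in [0, 1]. Both choices are assumptions, and the library's own `js` test may normalize or smooth differently.

```python
import math
from collections import Counter


def kmer_distribution(seqs, k):
    """Pooled k-mer frequencies over a list of DNA strings."""
    counts = Counter()
    for s in seqs:
        for i in range(len(s) - k + 1):
            counts[s[i:i + k]] += 1
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no k-mers found; are the sequences shorter than k?")
    return {kmer: c / total for kmer, c in counts.items()}


def js_distance(p, q):
    """Jensen-Shannon distance: sqrt of the JS divergence, log base 2."""
    keys = set(p) | set(q)
    # Mixture distribution m = (p + q) / 2 over the union of k-mers.
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in keys}

    def kl(a):
        # KL(a || m); terms with a[x] == 0 contribute nothing.
        return sum(a[x] * math.log2(a[x] / m[x]) for x in a if a[x] > 0)

    jsd = 0.5 * kl(p) + 0.5 * kl(q)
    return math.sqrt(max(jsd, 0.0))  # clamp tiny negative rounding error


real = kmer_distribution(["ACGTACGT", "ACGTTTGA"], k=3)
generated = kmer_distribution(["ACGTACGA", "TTGACGTA"], k=3)
print(round(js_distance(real, generated), 3))
```

Identical k-mer distributions give a distance of 0, and distributions with no k-mers in common give 1.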

#### Train-time SP-MSE

`BaseSPMSEValidationCallback` is the abstract callback for periodic SP-MSE validation against an oracle during training. Subclass it once per dataset and override `get_default_sampling_steps()` and `get_oracle_predictions(samples)`. Each example directory ships a concrete subclass — e.g. `K562MSECallback` in [`examples/k562/callbacks.py`](./examples/k562/callbacks.py).
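
The override pattern looks roughly like the sketch below. To stay runnable without `d3_dna` installed, it uses a local stand-in for the abstract base. Only the two method names come from the text above. The stand-in class, the toy oracle, the constructor arguments, the return types, and the use of plain strings as samples are all assumptions; see `examples/k562/callbacks.py` for a real subclass.

```python
from abc import ABC, abstractmethod


class SPMSECallbackStandIn(ABC):
    """Local stand-in for the abstract base; NOT the real d3_dna class."""

    @abstractmethod
    def get_default_sampling_steps(self):
        """Number of sampling steps to use during validation (assumed int)."""

    @abstractmethod
    def get_oracle_predictions(self, samples):
        """Run the dataset's oracle on generated samples."""


class ToyOracle:
    """Hypothetical oracle: scores each sequence by its GC fraction."""

    def predict(self, seqs):
        return [sum(base in "GC" for base in s) / len(s) for s in seqs]


class ToyMSECallback(SPMSECallbackStandIn):
    """Dataset-specific subclass: supplies the step count and the oracle call."""

    def __init__(self, oracle, steps=20):
        self.oracle = oracle
        self.steps = steps

    def get_default_sampling_steps(self):
        return self.steps

    def get_oracle_predictions(self, samples):
        return self.oracle.predict(samples)


callback = ToyMSECallback(ToyOracle())
print(callback.get_oracle_predictions(["GCGC", "ATAT"]))  # [1.0, 0.0]
```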

## Examples

Each example reproduces a published D3 configuration end-to-end. Data, oracle weights, and pretrained transformer + convolutional checkpoints auto-download from Zenodo on first run.

| Example | Sequences | Conditioning | Zenodo |
|---|---|---|---|
| [`k562`](./examples/k562/) | 230 bp MPRA | global activity `(N, 1)` | [19774653](https://zenodo.org/records/19774653) |
| [`hepg2`](./examples/hepg2/) | 230 bp MPRA | global activity `(N, 1)` | [19774653](https://zenodo.org/records/19774653) |
| [`deepstarr`](./examples/deepstarr/) | 249 bp enhancers | dual-head activity `(N, 2)` | [19774653](https://zenodo.org/records/19774653) |
| [`promoter`](./examples/promoter/) | 1024 bp FANTOM5 | per-position CAGE `(N, 1024, 1)` | [19738941](https://zenodo.org/records/19738941) |
| [`minimal`](./examples/minimal/) | scaffold for a new dataset | — | — |

Each `examples/<name>/` is a self-contained `train.py` / `sample.py` / `evaluate.py` flow plus a `Dataset`, oracle, validation callback, and config YAMLs for both architectures. Per-example READMEs carry the reference numbers and reproduction recipe.

## System requirements

Linux (POSIX); verified on RHEL 8 (kernel 4.18). Not tested on Windows or macOS. Python ≥3.9, verified on 3.11 and 3.12. Any CUDA-capable GPU works for training and sampling; benchmarks here use one NVIDIA H100 NVL (driver 580, CUDA 13). The `[flash]` extra additionally requires an Ampere-or-newer GPU and a CUDA toolchain (`nvcc`, `CUDA_HOME`) at install time. Python dependencies (auto-installed by `pip`): `torch≥2.0` (verified on 2.12), `pytorch-lightning≥2.0` (2.6), `omegaconf≥2.3`, `numpy≥1.23`, `scipy≥1.10`, `h5py≥3.7`, `tqdm≥4.64`, `einops≥0.6`; `flash-attn≥2.0` and `wandb≥0.16` for the optional extras.

## Sampling performance

Single-GPU ballpark from the K562 example (230 bp, 20-step Euler predictor, transformer, bf16-mixed, no flash-attn) on one H100 NVL: ~6 step/s at batch 512, ~130 seq/s end-to-end, full 39,340-sequence test set in ~5 min. The per-step rate is roughly batch-size-invariant until GPU memory becomes the limit, so doubling the batch roughly doubles end-to-end throughput. At a fixed batch size, wall time scales linearly with `num_samples × steps`, and roughly linearly with sequence length (sub-linearly when flash-attn is enabled).
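
The quoted figures can be cross-checked with a little arithmetic. The sketch below derives the ideal sequence rate from the step rate and estimates wall time from the end-to-end rate. The function names are hypothetical, and the gap between the ideal and end-to-end rates is not itemized in this README.

```python
def ideal_seq_per_s(steps_per_s, batch_size, steps):
    """Sequences per second if only the sampling steps cost time."""
    return steps_per_s * batch_size / steps


def wall_time_s(num_samples, seq_per_s):
    """Seconds needed to generate num_samples at a given sequence rate."""
    return num_samples / seq_per_s


ideal = ideal_seq_per_s(steps_per_s=6, batch_size=512, steps=20)
total = wall_time_s(num_samples=39_340, seq_per_s=130)

print(f"ideal rate: {ideal:.1f} seq/s")          # 153.6 seq/s
print(f"test set:   {total / 60:.1f} min")       # 5.0 min
```

The ideal rate of about 154 seq/s sits above the quoted ~130 seq/s end-to-end, and 39,340 sequences at 130 seq/s comes to roughly 5 minutes, matching the figure above.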

## Citation

If you use `d3-dna` in your work, please cite the accompanying paper ([bioRxiv preprint](https://www.biorxiv.org/content/10.1101/2024.05.23.595630v3)). A BibTeX entry is included at [`CITATION.bib`](./CITATION.bib):

```bibtex
@article{sarkar2024d3dna,
    title = {Designing {DNA} With Tunable Regulatory Activity Using Discrete Diffusion},
    author = {Sarkar, Anirban and Duran, Alejandra and Yu, Yiyang and Lin, Da-Wei and Kang, Yijie and Somia, Nirali and Mantilla, Pablo and Zhou, Jessica and Nagai, Masayuki and Tang, Ziqi and Hanington, Kaarina and Chang, Kenneth and Koo, Peter K.},
    journal = {bioRxiv},
    year = {2024},
    doi = {10.1101/2024.05.23.595630},
    url = {https://www.biorxiv.org/content/10.1101/2024.05.23.595630v3},
    publisher = {Cold Spring Harbor Laboratory},
    note = {Preprint, version 3}
}
```

## Related

`d3-dna` extracts the core training, sampling, and evaluation components from the [D3-DNA-Discrete-Diffusion](https://github.com/anirbansarkar-cs/D3-DNA-Discrete-Diffusion) research codebase, which holds the full ablations, analysis pipelines, and experiment scripts.
