Metadata-Version: 2.4
Name: d3-dna
Version: 0.1.1
Summary: D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences
Author: Anirban Sarkar
Author-email: Alejandra Durán <144379180+aduranu@users.noreply.github.com>
License: MIT
Project-URL: Homepage, https://github.com/anirbansarkar-cs/d3-dna
Project-URL: Repository, https://github.com/anirbansarkar-cs/d3-dna
Project-URL: Issues, https://github.com/anirbansarkar-cs/d3-dna/issues
Keywords: dna,diffusion,generative,deep-learning,bioinformatics,genomics
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: POSIX :: Linux
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: pytorch-lightning>=2.0.0
Requires-Dist: omegaconf>=2.3.0
Requires-Dist: numpy>=1.23.0
Requires-Dist: h5py>=3.7.0
Requires-Dist: tqdm>=4.64.0
Requires-Dist: einops>=0.6.0
Requires-Dist: scipy>=1.10.0
Provides-Extra: flash
Requires-Dist: flash-attn>=2.0.0; extra == "flash"
Provides-Extra: logging
Requires-Dist: wandb>=0.16.0; extra == "logging"
Provides-Extra: all
Requires-Dist: flash-attn>=2.0.0; extra == "all"
Requires-Dist: wandb>=0.16.0; extra == "all"
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: build>=1.0; extra == "dev"
Requires-Dist: twine>=4.0; extra == "dev"
Dynamic: license-file

# D3-DNA: DNA Discrete Diffusion

`d3-dna` is a standalone, pip-installable library for training and sampling discrete diffusion models on DNA sequences — the core `D3Trainer` / `D3Sampler` / `D3Evaluator` API plus the transformer and convolutional architectures, with all dataset-specific logic factored out.

This repository ships two things:

1. The `d3_dna/` package (published on PyPI as [`d3-dna`](https://pypi.org/project/d3-dna/)) — model architectures, diffusion math, training loop, sampler, and dataset-agnostic evaluation metrics.
2. A small set of minimal reproducibility examples under `examples/`: one self-contained directory per benchmark dataset (K562, HepG2, DeepSTARR, FANTOM5 promoter), plus a `minimal/` scaffold for new datasets. Each benchmark directory provides a `Dataset`, an oracle, a validation callback, and `train.py` / `sample.py` / `evaluate.py` scripts that reproduce the published configuration end-to-end using the data and pretrained checkpoints on Zenodo.

The full research codebase, ablations, and analysis pipelines live separately at [D3-DNA-Discrete-Diffusion](https://github.com/anirbansarkar-cs/D3-DNA-Discrete-Diffusion).

## System requirements

**Operating system.** Linux (POSIX). Verified on RHEL 8 (kernel 4.18). Not tested on Windows or macOS.

**Python.** ≥3.9. Verified on 3.11 and 3.12.

**Python dependencies** (auto-installed by `pip`):

| Package | Required | Verified |
|---|---|---|
| `torch` | ≥2.0 | 2.12.0 (CUDA 13.0 build) |
| `pytorch-lightning` | ≥2.0 | 2.6.1 |
| `omegaconf` | ≥2.3 | 2.3 |
| `numpy` | ≥1.23 | 2.4 |
| `scipy` | ≥1.10 | 1.17 |
| `h5py` | ≥3.7 | 3.16 |
| `tqdm` | ≥4.64 | 4.67 |
| `einops` | ≥0.6 | 0.8 |
| `flash-attn` (extra `[flash]`) | ≥2.0 | 2.5.8–2.8.x |
| `wandb` (extra `[logging]`) | ≥0.16 | 0.27 |

**Hardware.** Any CUDA-capable GPU; the benchmarks in this README use an NVIDIA H100 NVL (driver 580, CUDA 13.0). When `flash-attn` is not installed, the transformer falls back to PyTorch SDPA, so pre-Ampere GPUs are supported, although attention is slower at long sequence lengths. The `[flash]` extra additionally requires an Ampere-or-newer GPU and a CUDA toolchain (`nvcc`, `CUDA_HOME`) at install time. CPU-only operation works for imports and small-scale sampling but is impractical for training.

## Installation

```bash
# Core package
pip install d3-dna

# Extras are quoted so shells such as zsh do not treat the brackets as a glob.

# With flash attention (faster training on long sequences)
pip install "d3-dna[flash]"

# With Weights & Biases logging
pip install "d3-dna[logging]"

# Everything
pip install "d3-dna[all]"
```

**GPU acceleration**: `d3-dna[flash]` installs [flash attention](https://github.com/Dao-AILab/flash-attention) for faster, more memory-efficient training on long sequences. Without it, the package uses PyTorch's built-in scaled dot-product attention (SDPA) — same model quality, just slower for long inputs.

Flash-attention compiles from source and imports `torch` during its build, so install it on a machine with a CUDA toolchain (`CUDA_HOME` set, `nvcc` available) and disable build isolation so the existing torch install is visible:

```bash
pip install d3-dna
pip install flash-attn --no-build-isolation
```
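To check which attention path an environment is likely to take, a minimal sketch is to look for the `flash_attn` module with the standard library. This assumes the SDPA fallback is triggered simply by whether `flash_attn` is importable; `d3_dna`'s actual detection logic may differ, and a package that is installed but fails to load its compiled CUDA extension would still report `True` here. `flash_attn_installed` is a hypothetical helper, not part of `d3_dna`.

```python
import importlib.util


def flash_attn_installed() -> bool:
    """Return True if the `flash_attn` package can be found on the import path.

    This only checks installation: it does not import the module, load the
    compiled CUDA kernels, or check the GPU architecture.
    """
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    backend = "flash-attn" if flash_attn_installed() else "PyTorch SDPA (fallback)"
    print(f"Expected attention backend: {backend}")
```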

**Typical install time.** Cold install into a fresh Python env on a 1 Gbit link is dominated by the ~3 GB PyTorch + CUDA-runtime wheel download — about 60–90 s end-to-end on a clean conda env, under 10 s if the wheels are already cached locally. The `[flash]` extra additionally compiles flash-attention from source, which takes 5–15 min the first time on a single GPU host.

## Demo

After installing the package, clone this repo to get the example scripts and run a self-contained K562 sampling demo:

```bash
git clone https://github.com/anirbansarkar-cs/d3-dna.git
cd d3-dna/examples/k562
python sample.py --random-labels --num-samples 100 --steps 20 --replicates 1
```

The first run downloads the pretrained K562 transformer checkpoint from Zenodo (~1.4 GB, one-time, cached in `examples/k562/cache/`). After that, the sampling itself takes about 5 s on an NVIDIA H100 NVL.

Expected output (in `examples/k562/generated/`):

- `sample_0.npz` — one-hot tensor of shape `(100, 230, 4)`
- `sample_0.fasta` — the same 100 sequences in FASTA format

`python sample.py --help` lists every override (config path, predictor, batch size, output dir, etc.).
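To work with the generated one-hot array directly rather than through the FASTA file, a minimal NumPy sketch can decode it back to strings. It assumes the channel order is A, C, G, T and, because the array name inside `sample_0.npz` is not documented here, it reads the first stored array. Both helpers are hypothetical, not part of `d3_dna`.

```python
import numpy as np

ALPHABET = "ACGT"  # assumed channel order; verify against sample_0.fasta


def onehot_to_strings(onehot: np.ndarray, alphabet: str = ALPHABET) -> list[str]:
    """Convert an (N, L, 4) one-hot array to a list of N sequence strings."""
    if onehot.ndim != 3 or onehot.shape[-1] != len(alphabet):
        raise ValueError(f"expected shape (N, L, {len(alphabet)}), got {onehot.shape}")
    indices = onehot.argmax(axis=-1)  # (N, L) integer token ids
    lookup = np.array(list(alphabet))  # index -> nucleotide character
    return ["".join(row) for row in lookup[indices]]


def load_npz_onehot(path: str) -> np.ndarray:
    """Load the first array stored in an .npz file (no key name is assumed)."""
    with np.load(path) as data:
        return data[data.files[0]]
```

For example, `onehot_to_strings(load_npz_onehot("generated/sample_0.npz"))` should yield 100 strings of length 230; comparing the first few against `sample_0.fasta` is a quick way to confirm the assumed channel order.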

## Quickstart

### 1. Define your dataset

```python
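import h5py
import torch
from torch.utils.data import Dataset


class MyDNADataset(Dataset):
    """One split of an HDF5 file holding one-hot `X_<split>` and target `Y_<split>` arrays."""

    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # Assumes X is stored as (N, 4, L): argmax over the channel axis yields
            # (N, L) integer token indices. Use dim=-1 if your file is (N, L, 4).
            self.X = torch.tensor(f[f'X_{split}'][:]).argmax(dim=1)
            # Cast targets to float32 (HDF5 arrays are often float64).
            self.y = torch.tensor(f[f'Y_{split}'][:], dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]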
```
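The `argmax(dim=1)` above is only correct when the HDF5 file stores one-hot arrays as `(N, 4, L)`; for `(N, L, 4)` data (the layout the demo writes to `sample_0.npz`) the channel axis is the last one. A minimal NumPy sketch that picks the axis explicitly and rejects input that is not strictly one-hot follows; `onehot_to_indices` is a hypothetical helper, not part of `d3_dna`, and it assumes a 4-letter alphabet by default.

```python
import numpy as np


def onehot_to_indices(x: np.ndarray, num_classes: int = 4) -> np.ndarray:
    """Convert a one-hot batch in (N, C, L) or (N, L, C) layout to (N, L) int64 indices.

    The channel axis is the one whose size equals num_classes; the shape is
    rejected as ambiguous when both trailing axes have that size.
    """
    if x.ndim != 3:
        raise ValueError(f"expected a 3-D one-hot array, got shape {x.shape}")
    if x.shape[1] == num_classes and x.shape[2] != num_classes:
        channel_axis = 1  # (N, C, L)
    elif x.shape[2] == num_classes and x.shape[1] != num_classes:
        channel_axis = 2  # (N, L, C)
    else:
        raise ValueError(f"cannot infer the channel axis from shape {x.shape}")
    # argmax would silently map soft or all-zero positions to some index,
    # so require exactly one 1 per position.
    if not (np.isin(x, [0, 1]).all() and (x.sum(axis=channel_axis) == 1).all()):
        raise ValueError("input is not strictly one-hot")
    return x.argmax(axis=channel_axis).astype(np.int64)
```

Inside `__init__`, `torch.from_numpy(onehot_to_indices(f[f'X_{split}'][:]))` would then replace the `argmax` call regardless of the stored layout.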

### 2. Write a config

```yaml
dataset:
  name: my_dataset
  sequence_length: 249
  num_classes: 4
  signal_dim: 2

ngpus: 1
tokens: 4

model:
  architecture: transformer
  hidden_size: 256
  cond_dim: 128
  n_blocks: 8
  n_heads: 8
  dropout: 0.1
  class_dropout_prob: 0.1

training:
  batch_size: 128
  accum: 1
  max_epochs: 300
  ema: 0.9999

# ... see examples/minimal/config.yaml for full template
```
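Before launching a long run, it can help to check one dataset item against the `dataset:` block of the config. The sketch below is a hypothetical helper (`check_sample` is not part of `d3_dna`) and assumes the item layout of the `MyDNADataset` example, namely a 1-D tensor of token ids plus a 1-D target vector, rather than any documented trainer requirement.

```python
from typing import Any, Mapping

import numpy as np


def check_sample(x: Any, y: Any, cfg: Mapping) -> None:
    """Hypothetical pre-flight check: does one (x, y) dataset item match the config?

    Assumes x is a 1-D sequence of integer token ids of length
    dataset.sequence_length and y is a 1-D vector of dataset.signal_dim targets.
    Raises ValueError on the first mismatch.
    """
    ds = cfg["dataset"]
    x, y = np.asarray(x), np.asarray(y)  # also accepts CPU torch tensors
    if x.shape != (ds["sequence_length"],):
        raise ValueError(f"x has shape {x.shape}, expected ({ds['sequence_length']},)")
    if not np.issubdtype(x.dtype, np.integer):
        raise ValueError(f"x should hold integer token ids, got dtype {x.dtype}")
    if x.min() < 0 or x.max() >= ds["num_classes"]:
        raise ValueError(f"token ids must lie in [0, {ds['num_classes']})")
    if y.shape != (ds["signal_dim"],):
        raise ValueError(f"y has shape {y.shape}, expected ({ds['signal_dim']},)")
```

With the files from the Quickstart, a call would look like `check_sample(*MyDNADataset('data.h5', 'train')[0], OmegaConf.load('config.yaml'))`; an OmegaConf `DictConfig` supports the same `cfg["dataset"]["..."]` indexing as a plain dict.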

### 3. Train

```python
from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)
```

### 4. Sample

```python
from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)

sequences = sampler.generate(
    checkpoint='experiments/checkpoints/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')
```

## API Reference

| Class | Purpose |
|---|---|
| `D3Trainer` | Train a D3 model: `trainer.fit(train_dataset=..., val_dataset=...)` |
| `D3Sampler` | Generate sequences: `sampler.generate(checkpoint=..., model=..., num_samples=...)` |
| `D3Evaluator` | Evaluate with oracle: subclass and implement `load_oracle_model()` |
| `TransformerModel` | D3-Tran architecture (config-driven) |
| `ConvolutionalModel` | D3-Conv architecture (config-driven) |

## Sampling performance

Rough single-GPU ballpark from the K562 example (230 bp sequences, 20-step Euler predictor, transformer backbone, bf16-mixed, no flash-attn) on one H100 NVL:

| Sequences | Batch | Steps | Wall time | Throughput |
|---|---|---|---|---|
| 39,340 (full K562 test set, 1 replicate) | 512 | 20 | ~5 min | ~130 seq/s |

At these batch sizes sampling does not saturate the GPU: the per-step rate (~6 step/s at batch 512) stays roughly constant as batch size changes, so doubling the batch roughly doubles end-to-end throughput until the GPU saturates or runs out of memory. At a fixed batch size, wall time scales linearly with `num_samples × steps`, and roughly linearly with sequence length without flash-attn (sub-linearly with it).
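The scaling described above gives a simple back-of-the-envelope estimate. The sketch assumes each batch runs all `steps` sequentially, that one step of a full batch takes `1 / step_rate` seconds regardless of batch size, and that the final partial batch costs as much as a full one; it ignores checkpoint loading and I/O. `estimate_sampling_time` is a hypothetical helper, not part of `d3_dna`.

```python
import math


def estimate_sampling_time(num_samples: int, batch_size: int, steps: int,
                           step_rate: float) -> tuple[int, int, float]:
    """Return (num_batches, total_steps, seconds) for batched diffusion sampling."""
    num_batches = math.ceil(num_samples / batch_size)  # last batch may be partial
    total_steps = num_batches * steps                  # steps run one after another
    seconds = total_steps / step_rate
    return num_batches, total_steps, seconds
```

With the K562 numbers (39,340 sequences, batch 512, 20 steps, ~6 step/s) this gives 77 batches and 1,540 steps, or roughly 257 s, the same order as the ~5 min measured in the table.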

## Mixed precision

Both training and sampling default to a per-architecture autocast policy, picked by `d3_dna.modules.precision.precision_for_cfg`:

| Architecture | Lightning precision | Autocast dtype | GradScaler |
|---|---|---|---|
| `transformer` | `bf16-mixed` | `torch.bfloat16` | not used (bf16 has the same exponent range as fp32) |
| `convolutional` | `16-mixed` | `torch.float16` | installed automatically by Lightning |

The same dtype flows through `get_score_fn` so the loss path and the sampling path share a single autocast policy. `LayerNorm` opts out of autocast and runs in fp32; `score.exp()` in the predictor path is on PyTorch's fp32-cast list, so the score and all post-model sampler arithmetic land in fp32 regardless of architecture.

To override the default (e.g. when a checkpoint was trained under a different policy), set `training.precision` to `'16-mixed'` or `'bf16-mixed'` in the config. **Promoter is a known exception:** `examples/promoter/config_transformer.yaml` overrides the default back to `16-mixed` because the public `D3_Tran_Promoter.ckpt` on Zenodo was trained and validated under fp16 and degrades sharply if sampled in bf16. New transformer training runs in the other examples should keep the `bf16-mixed` default.
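The selection rule in this section (per-architecture default, explicit `training.precision` override) can be restated in plain Python. This is an illustrative sketch of the table, not the implementation of `precision_for_cfg`, whose signature and return type are not documented here; `resolve_precision` and the dtype name strings are hypothetical.

```python
from typing import Mapping

# Defaults restated from the table above (hypothetical code, not d3_dna internals).
DEFAULT_PRECISION = {"transformer": "bf16-mixed", "convolutional": "16-mixed"}
AUTOCAST_DTYPE = {"bf16-mixed": "bfloat16", "16-mixed": "float16"}


def resolve_precision(cfg: Mapping) -> tuple[str, str, bool]:
    """Return (lightning_precision, autocast_dtype_name, uses_grad_scaler).

    An explicit cfg["training"]["precision"] wins; otherwise the
    per-architecture default applies.
    """
    override = (cfg.get("training") or {}).get("precision")
    arch = cfg["model"]["architecture"]
    precision = override if override is not None else DEFAULT_PRECISION[arch]
    if precision not in AUTOCAST_DTYPE:
        raise ValueError(f"unsupported precision {precision!r}")
    # fp16's narrow exponent range needs loss scaling; bf16 shares fp32's range.
    return precision, AUTOCAST_DTYPE[precision], precision == "16-mixed"
```

In PyTorch code the dtype name would correspond to `torch.bfloat16` or `torch.float16`, passed as `torch.autocast(device_type="cuda", dtype=...)`.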

## Architecture

```
d3_dna/
├── models/          # TransformerModel, ConvolutionalModel, EMA
├── diffusion.py     # Noise schedules, transition graphs, losses
├── sampling.py      # PC sampler, predictors, D3Sampler
├── trainer.py       # Lightning module, D3Trainer
├── evaluator.py     # SP-MSE callback, D3Evaluator
└── io.py            # Checkpoint loading, data utilities
```
