Metadata-Version: 2.4
Name: scdistill
Version: 0.3.0
Summary: scDistill: Knowledge distillation for single-cell batch correction with covariate protection
Project-URL: Homepage, https://github.com/YUYA556223/scDistill
Project-URL: Repository, https://github.com/YUYA556223/scDistill
Project-URL: Documentation, https://github.com/YUYA556223/scDistill#readme
Project-URL: Bug Tracker, https://github.com/YUYA556223/scDistill/issues
Author-email: YUYA556223 <yu.pisces.556223@akane.waseda.jp>
Maintainer-email: YUYA556223 <yu.pisces.556223@akane.waseda.jp>
License: MIT
License-File: LICENSE
Keywords: batch-correction,bioinformatics,deep-learning,knowledge-distillation,scRNA-seq,single-cell
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Requires-Python: >=3.11
Requires-Dist: anndata>=0.10.0
Requires-Dist: bbknn>=1.6.0
Requires-Dist: harmonypy>=0.0.9
Requires-Dist: matplotlib>=3.7.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: pandas>=2.0.0
Requires-Dist: pycombat>=0.20
Requires-Dist: pydeseq2>=0.5.3
Requires-Dist: scanorama>=1.7.4
Requires-Dist: scanpy>=1.10.0
Requires-Dist: scarches>=0.6.1
Requires-Dist: scgen>=2.1.0
Requires-Dist: scib-metrics>=0.5.0
Requires-Dist: scikit-learn>=1.3.0
Requires-Dist: scipy>=1.11.0
Requires-Dist: seaborn>=0.12.0
Requires-Dist: torch>=2.0.0
Provides-Extra: benchmark
Requires-Dist: scvi-tools>=1.0.0; extra == 'benchmark'
Description-Content-Type: text/markdown

# scDistill

A 2-phase knowledge distillation framework for single-cell RNA-seq batch correction.

## Overview

scDistill learns batch correction by distilling knowledge from a teacher method (e.g., Harmony) into a neural network. The key innovations are:

1. **2-Phase Training** — Clean separation of encoder and decoder objectives
2. **Knowledge Distillation** — Learn batch-invariant representations from established methods
3. **Conditional Decoder** — Batch-aware decoding for proper expression reconstruction

## Architecture

```
Phase 1 (Encoder)  X → log2(CPM+1) → Encoder → Z      Loss = MSE(Z, Z*)
Phase 2 (Decoder)  Z → Decoder(Z, batch) → (μ, θ)    Loss = NB_NLL(X, μ, θ)
```

- `X` — Raw count matrix (N cells × G genes)
- `Z*` — Teacher's batch-corrected latent representation (from Harmony)
- `Z` — Student encoder's output (batch-corrected)
- `μ, θ` — Negative Binomial parameters (mean, dispersion)
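
For concreteness, the Phase 1 input transform can be written in a few lines of NumPy. This is an illustrative sketch, not the package's internal code; the function name is ours:

```python
import numpy as np

def log2_cpm(X: np.ndarray) -> np.ndarray:
    """log2(CPM + 1): scale each cell to one million counts, then log2-transform."""
    lib_size = X.sum(axis=1, keepdims=True)     # total counts per cell
    cpm = X / np.maximum(lib_size, 1) * 1e6     # counts per million
    return np.log2(cpm + 1.0)                   # +1 keeps zeros finite
```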

## Theoretical Foundation

### The Batch Effect Problem

Single-cell data contains both biological signal and technical batch effects:

```
X_observed = f(biological_signal) + g(batch_effect) + noise
```

The goal is to obtain batch-corrected expression that preserves biological differences while removing batch artifacts.

### Why Knowledge Distillation Works

#### 1. Teacher Creates Batch-Invariant Target

Harmony operates in PCA space and uses soft k-means clustering to align batches:

```
Z* = Harmony(PCA(X), batch_labels)
```
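
For reference, the teacher target can be materialized outside scDistill with scanpy's harmonypy wrapper; a sketch assuming standard scanpy preprocessing, not the package's internal code path:

```python
import scanpy as sc

adata = sc.read_h5ad("data.h5ad")
sc.pp.normalize_total(adata, target_sum=1e6)           # CPM
sc.pp.log1p(adata)                                     # log-transform
sc.pp.pca(adata, n_comps=50)                           # PCA(X)
sc.external.pp.harmony_integrate(adata, key="batch")   # Harmony(PCA(X), batch)
Z_star = adata.obsm["X_pca_harmony"]                   # teacher target Z*
```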

The resulting Z* satisfies the **batch invariance property**:

```
P(Z* | batch = b₁) ≈ P(Z* | batch = b₂)  for all batches b₁, b₂
```
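
The property can be probed empirically: a classifier predicting batch from Z* should score near chance. A self-contained scikit-learn illustration, with random data standing in for a perfectly batch-invariant Z*:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
Z_star = rng.normal(size=(1000, 50))       # stand-in for a batch-invariant Z*
batches = rng.integers(0, 2, size=1000)    # two balanced batches

# Accuracy near chance (~0.5 here) supports P(Z* | b₁) ≈ P(Z* | b₂).
scores = cross_val_score(LogisticRegression(max_iter=1000), Z_star, batches, cv=5)
print(f"batch-prediction accuracy: {scores.mean():.2f}")
```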

#### 2. Encoder Learns the Mapping

The encoder is trained to reproduce Z*:

```
L_encoder = ||Encoder(X) - Z*||²
```

After convergence, the encoder has implicitly learned to remove batch effects:

```
Encoder(X) ≈ Z*  →  Encoder removes batch information
```

**Key insight** — Any batch-specific structure the encoder retains shows up as extra distance to Z*, which contains none, so minimizing the loss drives batch information out of Z.
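
A minimal PyTorch sketch of the Phase 1 objective (layer sizes and names are illustrative, not the package's internals):

```python
import torch
from torch import nn

n_genes, n_latent = 20_000, 50
encoder = nn.Sequential(                   # Encoder: log2(CPM+1) profile → Z
    nn.Linear(n_genes, 128), nn.ReLU(),
    nn.Linear(128, n_latent),
)
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

def phase1_step(x_norm: torch.Tensor, z_star: torch.Tensor) -> float:
    """One distillation step: pull Encoder(X) toward the teacher target Z*."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(encoder(x_norm), z_star)  # ||Encoder(X) - Z*||²
    loss.backward()
    optimizer.step()
    return loss.item()
```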

#### 3. Why Reconstruction Loss Preserves Biology

The decoder is trained with a **Negative Binomial likelihood**:

```
L_decoder = -log P(X | μ, θ)  where (μ, θ) = Decoder(Z, batch)
```

The Negative Binomial distribution is ideal for scRNA-seq count data because it:

- Captures overdispersion (variance greater than the mean; see the quick check after this list)
- Accommodates the high fraction of zeros in scRNA-seq through overdispersion, without an explicit zero-inflation term
- Models count data directly, without a log-transform
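
A quick numerical check of the first two properties (illustrative only; NumPy's `(n, p)` parameterization has mean `n(1-p)/p` and variance `mean/p`):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.negative_binomial(n=0.5, p=0.2, size=100_000)

# Expected: mean = 0.5·0.8/0.2 = 2, variance = 2/0.2 = 10 (overdispersed),
# and P(X=0) = 0.2^0.5 ≈ 45% zeros — no separate zero-inflation term needed.
print(f"mean={x.mean():.2f}  var={x.var():.2f}  zeros={np.mean(x == 0):.1%}")
```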

**Theorem** (informal) — If the Decoder maximizes P(X | Z, batch), then it must preserve gene-level biological variation.

**Proof sketch**

- Z contains only biological information (batch removed by encoder)
- X contains both biological and batch information
- To maximize likelihood of X given Z, Decoder must capture biological signal in Z
- The batch embedding provides batch-specific scaling without encoding batch info into Z
- Biological variation in X must come entirely from Z

### Why PCA-based Distillation is Superior

A common misconception is that passing through PCA loses information. In reality, this is **intentional denoising and signal purification**, not information loss.

#### 1. Z* is the "Core", Not the "Whole" — Manifold Reconstruction by Decoder

Even though Z* is low-dimensional (e.g., 50D), the Decoder's weights (millions of parameters) learn and retain the **gene-gene co-expression network** needed to reconstruct 20,000 genes from 50 coordinates.

**How Z* and the Decoder complement each other**

- PCA (Z*) captures the "principal coordinates of cell states"
- Decoder applies biological rules (the manifold) based on those coordinates
- "If a cell is in this state, genes A, B, and C should be expressed at these levels"

**Z* is a seed, and the Decoder grows it into full expression profiles** — Information not explicitly in Z* is not lost; it is reconstructed through the learned biological manifold.

#### 2. PCA is an Optimal Filter, Not a Bottleneck

**The variance-importance mismatch problem**

Standard PCA is dominated by highly expressed genes (high variance), but biologically important DE genes may have low expression. **Pearson Residuals PCA** solves this:

```python
# Pearson Residuals normalize the mean-variance relationship
# Low-expression but biologically meaningful variations are preserved
model = Distiller(adata, batch_key="batch", pca_method="pearson_residuals")
```

**Capturing small perturbations**

Even when condition A→B perturbations are small, properly weighted PCA (or Harmony's iterative refinement) can capture these perturbations in the top PC axes.

**Conclusion** — PCA acts as a **high-performance noise-canceling filter** that strips away technical noise (Poisson noise, etc.) while extracting only the signals necessary for robust DE detection.
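
For reference, the Pearson-residuals route can be reproduced with scanpy's experimental API (available in the scanpy versions this package requires; shown as a sketch, not scDistill's internal code path):

```python
import scanpy as sc

adata = sc.read_h5ad("data.h5ad")        # raw counts expected in adata.X
# Analytic Pearson residuals (Lause et al., 2021) equalize the mean-variance
# relationship so low-expression DE genes are not drowned out in PCA.
sc.experimental.pp.normalize_pearson_residuals(adata)
sc.pp.pca(adata, n_comps=50)
Z_input = adata.obsm["X_pca"]            # residual-based PCs for the teacher
```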

#### 3. Information Bottleneck Prevents Overfitting to Noise

scVI and similar end-to-end methods can overfit to technical noise in high-dimensional space. scDistill's PCA bottleneck forces the model to focus on the **biological manifold**:

```
scVI:      X (20,000D) → Encoder → Z (50D) → Decoder → X̂
           ↑ Can memorize noise in X

scDistill: X → PCA → Z* (50D, denoised) → Encoder → Z → Decoder → X̂
           ↑ Noise already filtered out
```

This explains why scDistill achieves **higher DE F1 scores** — it avoids the "overfitting to noise" trap that reduces DE detection accuracy in end-to-end methods.

### Conditional Decoder and Batch Embedding

The decoder takes both Z and a batch embedding as input:

```
X̂ = Decoder(Z, Embedding(batch))
```

This design is critical because:

1. **Batch information for reconstruction** — The decoder needs to know which batch a cell came from to reconstruct X faithfully, since X contains batch effects (a minimal sketch follows this list)

2. **Clean separation** — Biological signal flows through Z, batch effects through the embedding

3. **Inference flexibility** — At inference time, we can decode to original batch for faithful reconstruction, or decode to a reference batch for batch-corrected expression
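
A minimal PyTorch sketch of a conditional decoder in this style (class and layer choices are illustrative, not the package's implementation):

```python
import torch
from torch import nn

class ConditionalDecoder(nn.Module):
    """Concatenate Z with a learned batch embedding; emit NB parameters (μ, θ)."""

    def __init__(self, n_latent=50, n_batches=4, n_embed=10, n_genes=20_000):
        super().__init__()
        self.batch_embedding = nn.Embedding(n_batches, n_embed)
        self.hidden = nn.Sequential(nn.Linear(n_latent + n_embed, 128), nn.ReLU())
        self.mu_head = nn.Linear(128, n_genes)       # log-mean
        self.theta_head = nn.Linear(128, n_genes)    # log-dispersion

    def forward(self, z: torch.Tensor, batch_idx: torch.Tensor):
        h = self.hidden(torch.cat([z, self.batch_embedding(batch_idx)], dim=-1))
        mu = torch.exp(self.mu_head(h))              # μ > 0
        theta = torch.exp(self.theta_head(h))        # θ > 0
        return mu, theta
```

Passing the same `z` with a different `batch_idx` changes only the batch-specific part of the output, which is what makes decoding to a reference batch (point 3) possible.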

### Mathematical Formulation

Let X ∈ ℤ₊^(N×G) be the count matrix and B ∈ {1,...,K}^N the batch labels.

**Phase 1 — Encoder Training** (encoder parameters ψ; θ is reserved for the NB dispersion)
```
min_ψ  Σᵢ ||E_ψ(xᵢ) - z*ᵢ||²

where z*ᵢ = Harmony(PCA(X), B)ᵢ
```

**Phase 2 — Decoder Training** (encoder frozen)
```
min_φ  Σᵢ -log P_NB(xᵢ | μᵢ, θᵢ)

where (μᵢ, θᵢ) = D_φ(E_ψ(xᵢ), e_bᵢ)
      e_b = BatchEmbedding(b) ∈ ℝ^d

P_NB(x | μ, θ) = Γ(x+θ)/(Γ(θ)x!) · (θ/(θ+μ))^θ · (μ/(θ+μ))^x
```
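
The P_NB term translates directly into code via log-gamma functions; a hedged PyTorch rendering of the pmf above (not the package's implementation):

```python
import torch

def nb_nll(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Elementwise -log P_NB(x | μ, θ), from the Γ-form of the pmf."""
    log_p = (
        torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1.0)
        + theta * (torch.log(theta) - torch.log(theta + mu))
        + x * (torch.log(mu) - torch.log(theta + mu))
    )
    return -log_p

# Phase 2 loss for a minibatch: sum over genes, average over cells.
# loss = nb_nll(x, mu, theta).sum(dim=-1).mean()
```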

**Batch-Corrected Expression**
```
X_corrected = μ = D_φ(E_ψ(X), e_ref)

where e_ref is the embedding of a reference batch
```
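
In the package API (see the reference table below) this corresponds to a call like the following; the batch label is a placeholder:

```python
# e_ref is chosen by naming a reference batch; "batch0" is a placeholder label.
X_corrected = model.get_corrected_expression(reference_batch="batch0")
```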

### Comparison with scVI

scVI and scDistill both use VAE-like architectures with Negative Binomial decoders, but differ fundamentally in how they achieve batch correction.

| Aspect | scVI | scDistill |
|--------|------|-----------|
| **Batch correction mechanism** | Adversarial/latent space regularization | Knowledge distillation from Harmony |
| **Training** | End-to-end joint optimization | 2-phase (encoder → decoder) |
| **Batch information in Z** | Explicitly removed via loss term | Never encoded (teacher target is batch-free) |
| **Theoretical guarantee** | Relies on optimization balance | Z* is provably batch-invariant |
| **Noise handling** | Can overfit to high-dimensional noise | PCA bottleneck filters noise |

**Why scDistill outperforms scVI for DE analysis**

1. **Cleaner batch removal** — Harmony's iterative algorithm provides stronger batch invariance guarantees than adversarial training, which can suffer from mode collapse or incomplete removal.

2. **Biological signal preservation** — scVI's joint optimization must balance batch removal against reconstruction. scDistill separates these objectives, allowing the decoder to focus purely on reconstruction.

3. **Noise filtering** — The PCA bottleneck acts as a denoising step, preventing overfitting to technical artifacts that can corrupt DE estimates.

4. **Stability** — 2-phase training avoids the optimization difficulties of joint encoder-decoder-discriminator training.

**When scVI may be preferred**

- When no suitable teacher method exists for the data type
- When end-to-end differentiability is required
- For very large datasets where Harmony becomes slow

### Why NOT Cycle Consistency?

An alternative formulation would be:

```
Wrong: ||E(D(Z*, batch)) - Z*||²
```

This **cycle consistency loss** is problematic because:

1. **Underdetermined** — Z* is low-dimensional (50D) while X is high-dimensional (≈20,000 genes)
2. **No anchoring** — The decoder can output any X' that encodes back to Z*
3. **Lost signal** — Differential expression information is not constrained

The reconstruction loss anchors outputs to real expression profiles.
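
The underdetermination is easy to see with a linear toy encoder: any perturbation in its null space survives the cycle untouched. A small demonstration (toy dimensions, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
E = rng.normal(size=(5, 200))            # toy linear "encoder": 200 genes → 5D
x = rng.poisson(1.0, size=200).astype(float)

# Take a direction n with E @ n ≈ 0; the cycle loss cannot tell x from x + n,
# so expression-level (DE) differences along n are completely unconstrained.
_, _, Vt = np.linalg.svd(E)
n = Vt[-1] * 100.0                       # large perturbation in the null space
print(np.allclose(E @ x, E @ (x + n)))   # True
```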

## Installation

```bash
pip install scdistill
# or
uv add scdistill
```

## Quick Start

```python
import scanpy as sc
from scdistill import Distiller
from scdistill.teachers import HarmonyTeacher

# Load data
adata = sc.read_h5ad("data.h5ad")

# Initialize with Harmony teacher
teacher = HarmonyTeacher(theta=2.0)
model = Distiller(
    adata,
    batch_key="batch",
    teacher=teacher,
    n_latent=50,
    use_batch_conditioning=True,
)

# Train (2-phase)
model.train(
    n_epochs_encoder=100,
    n_epochs_decoder=100,
    lr=1e-3,
)

# Get batch-corrected representations
Z = model.get_latent_representation()
X_corrected = model.get_corrected_expression()

# Differential expression analysis
from scdistill.de import PseudobulkConfig

de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    sample_key="sample",
    how=PseudobulkConfig(method="deseq2"),
)
```

## Key Features

### 2-Phase Training

Clean separation ensures each component has a single objective:

- Encoder learns batch-invariant representations
- Decoder learns faithful reconstruction

### Teacher Flexibility

```python
from scdistill.teachers import HarmonyTeacher

teacher = HarmonyTeacher(
    theta=2.0,
    max_iter=10,
)
```

### Conditional Decoder

Batch-aware decoding for proper reconstruction:

```python
model = Distiller(
    adata,
    batch_key="batch",
    use_batch_conditioning=True,
    n_batch_embedding=10,
)
```

### PCA Method Options

```python
# Standard log2(CPM+1) → PCA (default)
model = Distiller(adata, batch_key="batch", pca_method="standard")

# Pearson residuals for sparse count data
model = Distiller(adata, batch_key="batch", pca_method="pearson_residuals")
```

### Differential Expression

Multiple DE methods are supported:

```python
from scdistill.de import PseudobulkConfig, BayesianConfig

# Pseudobulk with DESeq2 (recommended)
de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    sample_key="sample",
    how=PseudobulkConfig(method="deseq2"),
)

# Bayesian with MC Dropout
de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    how=BayesianConfig(n_samples=100),
)
```

## API Reference

### Distiller

Main class for batch correction.

#### Constructor Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `adata` | AnnData | required | Expression data with counts |
| `batch_key` | str | required | Column in obs with batch labels |
| `teacher` | BaseTeacher | HarmonyTeacher() | Teacher method for distillation |
| `n_latent` | int | 50 | Latent dimension |
| `n_hidden` | int | 128 | Hidden layer size |
| `n_layers` | int | 2 | Number of hidden layers |
| `dropout` | float | 0.1 | Dropout rate |
| `pca_method` | str | "standard" | "standard" or "pearson_residuals" |
| `use_batch_conditioning` | bool | True | Enable conditional decoder |
| `n_batch_embedding` | int | 10 | Batch embedding dimension |

#### Methods

| Method | Description |
|--------|-------------|
| `train(n_epochs_encoder, n_epochs_decoder, ...)` | Run 2-phase training |
| `get_latent_representation()` | Get batch-corrected latent Z |
| `get_teacher_representation()` | Get teacher's Z* |
| `get_corrected_expression(reference_batch)` | Get batch-corrected expression |
| `differential_expression(...)` | Run DE analysis |
| `save(path)` / `load(path)` | Model persistence |

## Benchmark Results

On simulated data with 6 scenarios (varying batch effects and DE genes):

| Metric | scDistill | scVI | Raw |
|--------|-----------|------|-----|
| DE F1 Score | **0.92** | 0.45 | 0.78 |
| LFC Correlation | **0.98** | 0.85 | 0.95 |
| Batch Mixing (iLISI) | 0.89 | **0.92** | 0.50 |
| Bio Conservation (NMI) | **1.00** | 0.98 | **1.00** |

## License

MIT License

## Citation

If you use scDistill in your research, please cite:

```bibtex
@software{scdistill,
  title = {scDistill: Knowledge Distillation for Single-Cell Batch Correction},
  author = {Yuya Sato},
  year = {2025},
  url = {https://github.com/YUYA556223/scDistill}
}
```
