Teacher-Student Training

The master watchmaker does not teach the apprentice by describing the finished watch. She hands over a hundred movements, each regulated to chronometer precision, and says: make yours keep the same time. The apprentice learns the mapping from mechanism to accuracy, not by theory alone, but by studying the master’s calibrated examples.

Balance Wheel’s SFS Predictor is trained by a teacher-student protocol. The teacher is moments (or dadi, or momi2): a classical SFS solver that computes the exact expected SFS \(\mathbf{M}^*(\Theta)\) for any demographic model \(\Theta\). The student is the neural network. The training objective is simple: make the student’s predictions match the teacher’s outputs.

This chapter describes both phases of the protocol: Phase 1 (learning the SFS mapping) and Phase 2 (inference on real data, where the teacher is no longer needed).

Why This Is Not Simulation-Based Inference

Before diving into the training procedure, we must address a common confusion. Balance Wheel’s training might look like simulation-based inference (SBI), but it is fundamentally different.

SBI vs. Balance Wheel training

| Property | Simulation-based inference | Balance Wheel training |
|---|---|---|
| What is generated | Synthetic genomes (coalescent + mutations) | Expected SFS from ODE/PDE solver |
| Training pairs | (genome, parameters) | (parameters, expected SFS) |
| Direction | Learns the inverse map: genome \(\to\) parameters | Learns the forward map: parameters \(\to\) SFS |
| Stochasticity | Each simulation is a stochastic realization | Each teacher output is deterministic |
| Cost per example | \(O(n \cdot L)\) (simulate a full genome) | \(O(n^k)\) per ODE step for \(k\) populations (compute SFS only) |
| What is learned | A posterior \(q(\Theta \mid \text{data})\) | A function \(\hat{M}(\Theta) \approx M(\Theta)\) |

SBI (Mainspring) simulates entire genomes with msprime and trains a network to invert the simulation, mapping from observed data to a posterior over the parameters. Balance Wheel never simulates a genome. It calls moments’ ODE solver to compute the expected SFS for sampled parameter values. The “training data” are evaluations of a deterministic function, not stochastic simulations. This is classical function approximation (knowledge distillation), not statistical inference.

The practical consequence

SBI requires millions of simulated genomes (each expensive). Balance Wheel requires millions of moments evaluations (each cheap). For a single population with \(n = 20\), one moments evaluation takes ~10 ms; one msprime simulation of 100 kb takes ~100 ms. Balance Wheel’s training data is 10× cheaper to generate per example, and each example is noise-free (no stochastic variation to average over).

Phase 1: Learning the SFS Mapping

The goal of Phase 1 is to train the neural network to approximate the mapping \(\Theta \to \mathbf{M}(\Theta)\) that the classical solver computes.

Training data generation

We sample random demographic parameters from a broad prior, compute the exact expected SFS with the teacher, and store the \((\Theta, \mathbf{M}^*(\Theta))\) pairs:

import numpy as np
import moments

def sample_demography(rng, max_epochs=6):
    """Sample a random piecewise-constant demography."""
    K = rng.integers(1, max_epochs + 1)                # number of epochs
    log_sizes = rng.normal(0, 1.5, size=K)             # log(N_e / N_ref) per epoch
    raw_times = np.sort(rng.exponential(0.5, size=K))  # epoch start times
    raw_times[0] = 0.0                                 # first epoch starts at t = 0
    return log_sizes, raw_times, K

def compute_teacher_sfs(log_sizes, raw_times, n=20, theta=1.0):
    """Compute the exact expected SFS using moments."""
    sizes = np.exp(log_sizes)
    # Start from the equilibrium spectrum, then integrate forward
    # through each epoch.
    fs = moments.Spectrum(
        moments.LinearSystem_1D.steady_state_1D(n, theta=theta))
    for i in range(len(sizes)):
        # Duration of epoch i; the final epoch gets a fixed duration.
        dt = raw_times[i + 1] - raw_times[i] if i + 1 < len(raw_times) \
            else 0.1
        if dt > 0:
            fs.integrate([sizes[i]], dt, theta=theta)
    return fs.data[1:-1]   # drop the fixed (monomorphic) classes

def generate_training_set(n_examples=100_000, n=20, seed=42):
    """Generate training pairs (parameters, exact SFS)."""
    rng = np.random.default_rng(seed)
    dataset = []
    for _ in range(n_examples):
        log_sizes, raw_times, K = sample_demography(rng)
        try:
            sfs = compute_teacher_sfs(log_sizes, raw_times, n=n)
            if np.all(np.isfinite(sfs)) and np.all(sfs > 0):
                dataset.append({
                    'log_sizes': log_sizes,
                    'raw_times': raw_times,
                    'n_epochs': K,
                    'sfs': sfs,
                })
        except Exception:
            # the teacher can fail for extreme draws; skip those examples
            continue
    return dataset
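
A quick smoke test of the generator (the sizes here are only illustrative):

dataset = generate_training_set(n_examples=1_000, n=20)
print(len(dataset))               # slightly fewer than 1,000: failed draws are skipped
print(dataset[0]['sfs'].shape)    # (19,): one entry per frequency class 1..n-1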

The prior over demographic parameters should be broad enough to cover the parameter space of interest:

Training prior

| Parameter | Distribution | Rationale |
|---|---|---|
| Number of epochs \(K\) | \(\text{Uniform}\{1, 2, \ldots, 6\}\) | Covers simple to moderately complex models |
| \(\log(N_e / N_{\text{ref}})\) | \(\mathcal{N}(0, 1.5)\) | Spans bottlenecks (\(\sim 0.05\, N_{\text{ref}}\)) to expansions (\(\sim 20\, N_{\text{ref}}\)) |
| Epoch durations (coalescent units) | \(\text{Exp}(\text{mean } 0.5)\) | Most epochs are short; occasional long ancient epochs |
| Sample size \(n\) | \(\{10, 20, 30, 50, 100\}\) | Train on multiple sample sizes simultaneously |

Training procedure

The training loss is the mean squared error between the neural SFS prediction and the teacher’s exact SFS:

\[\mathcal{L}_{\text{MSE}} = \frac{1}{n-1} \sum_{j=1}^{n-1} \left(\hat{M}_j(\Theta) - M_j^*(\Theta)\right)^2\]

In practice, we train on the log-SFS to handle the wide dynamic range (SFS entries span several orders of magnitude, with rare classes much smaller than common ones):

\[\mathcal{L}_{\text{log-MSE}} = \frac{1}{n-1} \sum_{j=1}^{n-1} \left(\ln \hat{M}_j(\Theta) - \ln M_j^*(\Theta)\right)^2\]

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

def train_phase1(model, dataset, n_epochs=100, batch_size=256,
                 lr=3e-4, device='cuda'):
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    model.to(device)
    model.train()

    for epoch in range(n_epochs):
        total_loss = 0.0
        n_batches = 0

        for batch in make_batches(dataset, batch_size):
            log_sizes = batch['log_sizes'].to(device)
            log_times = batch['log_times'].to(device)
            true_sfs = batch['sfs'].to(device)
            n = true_sfs.shape[-1] + 1

            pred_sfs = model(log_sizes, log_times, n, theta_L=1.0)

            loss = ((torch.log(pred_sfs + 1e-10)
                     - torch.log(true_sfs + 1e-10)) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        scheduler.step()

        if epoch % 10 == 0:
            avg_loss = total_loss / n_batches
            print(f"Epoch {epoch}: log-MSE = {avg_loss:.6f}")
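
train_phase1 relies on a make_batches helper that is not shown above. Note that generate_training_set stores raw epoch times while the training loop consumes log-transformed times; the batching helper is a natural place to bridge the two. A minimal sketch, assuming zero-padding to a fixed max_epochs and a log1p transform so the leading time of 0 stays finite (adapt both encoding choices to the encoder actually used):

import numpy as np
import torch

def make_batches(dataset, batch_size, max_epochs=6, shuffle=True):
    """Yield mini-batches of padded parameter tensors and teacher SFS."""
    order = np.random.permutation(len(dataset)) if shuffle \
        else np.arange(len(dataset))
    for start in range(0, len(dataset), batch_size):
        chunk = [dataset[i] for i in order[start:start + batch_size]]
        B = len(chunk)
        log_sizes = torch.zeros(B, max_epochs)
        log_times = torch.zeros(B, max_epochs)
        for b, ex in enumerate(chunk):
            K = ex['n_epochs']
            log_sizes[b, :K] = torch.as_tensor(ex['log_sizes'],
                                               dtype=torch.float32)
            # log1p keeps the first epoch time (0.0) finite in log space
            log_times[b, :K] = torch.log1p(
                torch.as_tensor(ex['raw_times'], dtype=torch.float32))
        sfs = torch.stack([torch.as_tensor(ex['sfs'], dtype=torch.float32)
                           for ex in chunk])
        yield {'log_sizes': log_sizes, 'log_times': log_times, 'sfs': sfs}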

Training data generation strategy

The quality of the neural approximation depends critically on the coverage of the training set. Three strategies improve coverage:

1. Importance sampling of the parameter space. Regions of parameter space where the SFS changes rapidly (e.g., near sharp bottlenecks) need more training examples. After an initial uniform training run, identify high-error regions and oversample them (a minimal sketch of the resampling step follows this list).

2. Boundary cases. The SFS has known analytical forms at the boundaries of parameter space: the equilibrium SFS (\(M_j \propto \theta / j\) for constant \(N_e\)), the post-bottleneck SFS (excess singletons), and the post-expansion SFS (deficit of singletons). Include these boundary cases explicitly in the training set and verify that the network reproduces them.

3. Multi-scale training. Train on a mix of simple models (\(K = 1, 2\)) and complex models (\(K = 4, 5, 6\)). Simple models provide clean signal for the basic SFS shape; complex models teach the network to handle interactions between epochs.
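
For strategy 1, a minimal sketch of the resampling step, assuming the hard examples have already been identified by their validation error; jitter_resample and its parameters are illustrative, not part of Balance Wheel’s API:

def jitter_resample(hard_examples, rng, n_new=20_000, jitter=0.25):
    """Generate new training pairs near known high-error examples by
    Gaussian jitter of their parameters, recomputing the teacher SFS."""
    new_examples = []
    while len(new_examples) < n_new:
        ex = hard_examples[rng.integers(len(hard_examples))]
        log_sizes = ex['log_sizes'] + rng.normal(
            0, jitter, size=ex['log_sizes'].shape)
        raw_times = np.sort(np.abs(ex['raw_times'] + rng.normal(
            0, jitter, size=ex['raw_times'].shape)))
        raw_times[0] = 0.0
        try:
            sfs = compute_teacher_sfs(log_sizes, raw_times)
        except Exception:
            continue  # skip draws the teacher cannot handle
        if np.all(np.isfinite(sfs)) and np.all(sfs > 0):
            new_examples.append({'log_sizes': log_sizes,
                                 'raw_times': raw_times,
                                 'n_epochs': ex['n_epochs'],
                                 'sfs': sfs})
    return new_examples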

How many training examples?

Empirically, 100,000 examples suffice for one-population models with \(K \leq 6\) epochs and \(n \leq 100\). The log-MSE typically converges below \(10^{-4}\) after 50–100 epochs of training; since \(\sqrt{10^{-4}} = 0.01\), this is an RMS log-error of about 0.01, meaning the neural SFS matches the teacher to within roughly 1% relative error per frequency bin. For multi-population models, more examples are needed (see Handling Multiple Populations).

Phase 2: Inference on Real Data

Once Phase 1 is complete, the neural SFS Predictor is frozen. Phase 2 uses it to perform demographic inference on observed data – the same task that dadi and moments perform, but 100–1000× faster per likelihood evaluation.

The observed SFS is computed from a VCF file or genotype matrix:

def compute_observed_sfs(genotypes, n):
    """Compute the unfolded SFS from a phased genotype matrix.

    genotypes: (n, L) binary tensor; rows are haplotypes, columns are
        sites, entries are 0 (ancestral) / 1 (derived)
    Returns: (n-1,) integer tensor of SFS counts, excluding the
        monomorphic classes 0 and n
    """
    freq_counts = genotypes.sum(dim=0)   # derived-allele count per site
    sfs = torch.bincount(freq_counts.long(), minlength=n + 1)
    return sfs[1:n]                      # frequency classes 1..n-1
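
A toy check on random genotypes (the 10% derived-allele density is arbitrary):

g = (torch.rand(20, 1_000) < 0.1).long()   # 20 haplotypes, 1,000 sites
obs = compute_observed_sfs(g, n=20)
assert obs.shape == (19,) and obs.sum() <= 1_000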

Given the observed SFS, we optimize the demographic parameters \(\Theta\) to maximize the Poisson log-likelihood through the neural SFS Predictor:
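
fit_demography below relies on a poisson_log_likelihood helper that is not defined elsewhere in this chapter. A minimal sketch, dropping the \(\log(x_j!)\) term (it does not depend on \(\Theta\), so it does not affect the MLE):

def poisson_log_likelihood(observed, expected, eps=1e-10):
    """Poisson log-likelihood of the observed SFS counts given the
    expected SFS, up to a parameter-independent constant."""
    expected = expected.clamp(min=eps)   # guard against log(0)
    return (observed * torch.log(expected) - expected).sum()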

def fit_demography(model, observed_sfs, theta_L, n_steps=2000,
                   lr=0.01, n_epochs_model=4):
    """Find MLE demographic parameters using Balance Wheel."""
    n = observed_sfs.shape[0] + 1
    # Optimize in log space so sizes and times stay positive without
    # constrained optimization.
    log_sizes = nn.Parameter(torch.zeros(n_epochs_model))
    log_times = nn.Parameter(
        torch.linspace(-2, 2, n_epochs_model))
    optimizer = torch.optim.Adam([log_sizes, log_times], lr=lr)

    model.eval()
    for step in range(n_steps):
        # Sorting keeps the epoch boundaries ordered; gradients flow
        # through the sort's permutation.
        sorted_times = torch.sort(log_times)[0]
        expected_sfs = model(
            log_sizes.unsqueeze(0), sorted_times.unsqueeze(0),
            n, theta_L)
        neg_ll = -poisson_log_likelihood(
            observed_sfs.float(), expected_sfs.squeeze(0))

        optimizer.zero_grad()
        neg_ll.backward()
        optimizer.step()

        if step % 500 == 0:
            print(f"Step {step}: -log L = {neg_ll.item():.2f}")

    # Re-sort so the returned times reflect the final optimizer step.
    sorted_times = torch.sort(log_times)[0]
    return torch.exp(log_sizes).detach(), torch.exp(sorted_times).detach()

This is the same optimization that dadi and moments perform, but each likelihood evaluation takes ~0.1 ms instead of ~10 ms (moments) or ~100 ms (dadi). For gradient-based optimization, the per-evaluation and per-step costs compare as follows:

Inference time comparison

| Method | Per evaluation | 2,000 evaluations | Per gradient step |
|---|---|---|---|
| dadi | ~100 ms | ~200 s | ~2 s (finite differences: \(2 \times 10\) evaluations for 10 params) |
| moments | ~10 ms | ~20 s | ~20 ms (autodiff) |
| Balance Wheel | ~0.1 ms | ~0.2 s | ~0.2 ms (autodiff) |

Validation

How do we know the student has learned well? Three validation strategies.

1. Held-out test set. Generate 10,000 parameter vectors not seen during training, compute both the teacher SFS and the neural SFS, and compare:

def validate(model, test_set, device='cuda'):
    """Mean and 95th-percentile relative SFS error on held-out examples.

    Assumes test_set entries are preprocessed like training batches
    (tensors, with log-transformed epoch times).
    """
    model.eval()
    errors = []
    with torch.no_grad():
        for example in test_set:
            log_sizes = example['log_sizes'].unsqueeze(0).to(device)
            log_times = example['log_times'].unsqueeze(0).to(device)
            true_sfs = example['sfs'].to(device)
            n = true_sfs.shape[0] + 1
            pred_sfs = model(log_sizes, log_times, n, theta_L=1.0)
            rel_error = (
                (pred_sfs.squeeze() - true_sfs).abs()
                / true_sfs.clamp(min=1e-8)
            ).mean()
            errors.append(rel_error.item())
    return np.mean(errors), np.percentile(errors, 95)

Target: mean relative error < 1%, 95th percentile < 5%.

2. Known analytical cases. The equilibrium SFS for constant \(N_e\) has the closed form:

\[M_j = \frac{\theta L}{j}, \qquad j = 1, \ldots, n-1\]

Verify that the neural network reproduces this for constant-size demography. Similarly, check the SFS after a recent exponential expansion or a severe bottleneck, where approximate analytical forms exist.
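
A minimal sketch of the equilibrium check, assuming a one-epoch constant-size model is encoded as zero log-size and zero log-time as in the sketches above (check_equilibrium is illustrative):

def check_equilibrium(model, n=20, theta_L=1.0):
    """Return the maximum relative error of the network against the
    analytical equilibrium SFS M_j = theta_L / j."""
    j = torch.arange(1, n, dtype=torch.float32)
    analytic = theta_L / j
    log_sizes = torch.zeros(1, 1)   # one epoch, N_e = N_ref
    log_times = torch.zeros(1, 1)   # epoch starts at t = 0
    with torch.no_grad():
        pred = model(log_sizes, log_times, n, theta_L).squeeze(0)
    return ((pred - analytic).abs() / analytic).max().item()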

3. Likelihood surface comparison. For a fixed observed SFS, compute the log-likelihood surface \(\ell(\Theta)\) using both moments and Balance Wheel on a grid of \(\Theta\) values. The surfaces should be nearly identical:

\[\left|\ell_{\text{neural}}(\Theta) - \ell_{\text{moments}}(\Theta)\right| \ll 1 \quad \text{for all } \Theta \text{ in the grid}\]

A log-likelihood difference of < 0.5 is typically acceptable (smaller than typical random noise in the SFS due to sampling). If the difference exceeds 1.0 anywhere in the region of interest, the network needs more training examples in that region.
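
One convenient variant of this check reuses the held-out test set, which already pairs parameter draws with teacher SFS values, in place of an explicit grid. A sketch, assuming test_set entries are encoded as in validate above with teacher SFS computed at \(\theta = 1\) (compare_loglik is illustrative):

def compare_loglik(model, observed_sfs, test_set, theta_L=1.0):
    """Maximum |neural - teacher| log-likelihood gap over a set of
    parameter draws (here the held-out test set instead of a grid)."""
    model.eval()
    gaps = []
    with torch.no_grad():
        for ex in test_set:
            n = ex['sfs'].shape[0] + 1
            pred = model(ex['log_sizes'].unsqueeze(0),
                         ex['log_times'].unsqueeze(0),
                         n, theta_L).squeeze(0)
            ll_neural = poisson_log_likelihood(observed_sfs.float(), pred)
            ll_teacher = poisson_log_likelihood(observed_sfs.float(),
                                                theta_L * ex['sfs'])
            gaps.append((ll_neural - ll_teacher).abs().item())
    return max(gaps)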

When the student surpasses the teacher

In rare cases, the neural SFS prediction may actually be smoother than the teacher’s output. moments can have numerical artifacts for extreme parameter values (very large \(N_e\) changes, very short epochs), while the neural network smoothly interpolates. This is not a bug – it is the regularizing effect of the MLP’s smooth activation functions. However, it means the neural SFS may not exactly match moments in these edge cases. Always validate against the teacher in the regime of interest.

Practical Considerations

Training time and compute

| Component | Cost | Wall-clock time |
|---|---|---|
| Teacher data generation (100k examples) | 100k × 10 ms (moments) | ~17 minutes |
| Phase 1 training (100 epochs) | 100 epochs × ~390 batches × ~1 ms | ~40 seconds (single GPU) |
| Validation (10k test examples) | 10k × 0.1 ms | ~1 second |

The total Phase 1 cost is dominated by training data generation. This is a one-time cost – once the network is trained, it can be used for any observed SFS with compatible sample size and parameter range.

Retraining triggers

The trained network should be retrained if:

  • The parameter space of interest shifts (e.g., moving from human to Drosophila demography with different \(N_e\) ranges).

  • The sample size changes beyond the training range.

  • The number of populations changes (requires retraining the multi-population encoder).

  • The teacher is updated (e.g., moments releases a numerically improved solver).

Balance Wheel as a compiler

Think of Phase 1 as compilation: the slow, careful computation of moments is compiled into the fast, approximate computation of the neural network. Just as a compiled program runs faster than the interpreted source, the neural SFS evaluation runs faster than the ODE integration. And just as you must recompile when the source changes, you must retrain when the parameter space or teacher changes. The compilation cost is paid once; the speedup is enjoyed on every subsequent query.