.. _balance_wheel_training:

==========================
Teacher-Student Training
==========================

   *The master watchmaker does not teach the apprentice by describing the
   finished watch. She hands over a hundred movements, each regulated to
   chronometer precision, and says: make yours keep the same time. The
   apprentice learns the mapping from mechanism to accuracy, not by theory
   alone, but by studying the master's calibrated examples.*

Balance Wheel's SFS Predictor is trained by a **teacher-student protocol**. The
teacher is :ref:`moments <moments_timepiece>` (or :ref:`dadi <dadi_timepiece>`,
or :ref:`momi2 <momi2_timepiece>`): a classical SFS solver that computes the
exact expected SFS :math:`\mathbf{M}^*(\Theta)` for any demographic model
:math:`\Theta`. The student is the neural network. The training objective is
simple: make the student's predictions match the teacher's outputs.

This chapter describes both phases of the protocol: Phase 1 (learning the SFS
mapping) and Phase 2 (inference on real data, where the teacher is no longer
needed).


Why This Is Not Simulation-Based Inference
============================================

Before diving into the training procedure, we must address a common confusion.
Balance Wheel's training might look like simulation-based inference (SBI), but
it is fundamentally different.

.. list-table:: SBI vs. Balance Wheel training
   :header-rows: 1
   :widths: 22 39 39

   * - Property
     - Simulation-based inference
     - Balance Wheel training
   * - What is generated
     - Synthetic genomes (coalescent + mutations)
     - Expected SFS from ODE/PDE solver
   * - Training pairs
     - (genome, parameters)
     - (parameters, expected SFS)
   * - Direction
     - Learns inverse: genome :math:`\to` parameters
     - Learns forward: parameters :math:`\to` SFS
   * - Stochasticity
     - Each simulation is a stochastic realization
     - Each teacher output is deterministic
   * - Cost per example
     - :math:`O(n \cdot L)` (simulate full genome)
     - :math:`O(n^k)` per ODE step, for :math:`k` populations (compute SFS only)
   * - What is learned
     - A posterior :math:`q(\Theta \mid \text{data})`
     - A function :math:`\hat{M}(\Theta) \approx M(\Theta)`

SBI (:ref:`Mainspring <mainspring_complication>`) simulates entire genomes with
msprime and trains a network to invert the simulation -- mapping from observed
data to posterior over parameters. Balance Wheel never simulates a genome. It
calls moments' ODE solver to compute the *expected* SFS for sampled parameter
values. The "training data" are evaluations of a deterministic function, not
stochastic simulations. This is classical function approximation (knowledge
distillation), not statistical inference.

.. admonition:: The practical consequence

   SBI requires millions of simulated genomes (each expensive). Balance Wheel
   requires millions of moments evaluations (each cheap). For a single
   population with :math:`n = 20`, one moments evaluation takes ~10 ms; one
   msprime simulation of 100 kb takes ~100 ms. Balance Wheel's training data
   is 10× cheaper to generate per example, and each example is noise-free
   (no stochastic variation to average over).


Phase 1: Learning the SFS Mapping
====================================

The goal of Phase 1 is to train the neural network to approximate the mapping
:math:`\Theta \to \mathbf{M}(\Theta)` that the classical solver computes.

Training data generation
--------------------------

We sample random demographic parameters from a broad prior, compute the exact
expected SFS with the teacher, and store the :math:`(\Theta, \mathbf{M}^*(\Theta))`
pairs:

.. code-block:: python

   import numpy as np
   import moments

   def sample_demography(rng, max_epochs=6):
       """Sample a random piecewise-constant demography."""
       K = rng.integers(1, max_epochs + 1)                  # number of epochs
       log_sizes = rng.normal(0, 1.5, size=K)               # log(N_e / N_ref) per epoch
       raw_times = np.sort(rng.exponential(0.5, size=K))    # epoch start times (coalescent units)
       raw_times[0] = 0.0                                    # first epoch starts at t = 0
       return log_sizes, raw_times, K

   def compute_teacher_sfs(log_sizes, raw_times, n=20, theta=1.0):
       """Compute exact expected SFS using moments."""
       sizes = np.exp(log_sizes)
       # Start from the equilibrium spectrum, then integrate forward in time
       # through each epoch; wrapping in moments.Spectrum provides .integrate().
       fs = moments.Spectrum(
           moments.LinearSystem_1D.steady_state_1D(n, theta=theta))
       for i in range(len(sizes)):
           # Epoch duration: gap to the next start time; the final (most
           # recent) epoch gets a fixed nominal duration.
           dt = raw_times[i + 1] - raw_times[i] if i + 1 < len(raw_times) \
               else 0.1
           if dt > 0:
               fs.integrate([sizes[i]], dt, theta=theta)
       return fs.data[1:-1]

   def generate_training_set(n_examples=100_000, n=20, seed=42):
       """Generate training pairs (parameters, exact SFS)."""
       rng = np.random.default_rng(seed)
       dataset = []
       for _ in range(n_examples):
           log_sizes, raw_times, K = sample_demography(rng)
           try:
               sfs = compute_teacher_sfs(log_sizes, raw_times, n=n)
               if np.all(np.isfinite(sfs)) and np.all(sfs > 0):
                   dataset.append({
                       'log_sizes': log_sizes,
                       'raw_times': raw_times,
                       'n_epochs': K,
                       'sfs': sfs,
                   })
           except Exception:
               continue
       return dataset

The prior over demographic parameters should be broad enough to cover the
parameter space of interest:

.. list-table:: Training prior
   :header-rows: 1
   :widths: 30 30 40

   * - Parameter
     - Distribution
     - Rationale
   * - Number of epochs :math:`K`
     - :math:`\text{Uniform}\{1, 2, \ldots, 6\}`
     - Covers simple to moderately complex models
   * - :math:`\log(N_e / N_{\text{ref}})`
     - :math:`\mathcal{N}(0, 1.5)`
     - Spans bottlenecks (:math:`\sim 0.05 N_{\text{ref}}`) to expansions
       (:math:`\sim 20 N_{\text{ref}}`)
   * - Epoch durations (coalescent units)
     - :math:`\text{Exp}(0.5)`
     - Most epochs are short; occasional long ancient epochs
   * - Sample size :math:`n`
     - :math:`\{10, 20, 30, 50, 100\}`
     - Train on multiple sample sizes simultaneously
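
The code above fixes ``n = 20``. To train on several sample sizes
simultaneously, as the last row of the prior table prescribes, the sample size
can be drawn per example. A minimal sketch, reusing ``sample_demography`` and
``compute_teacher_sfs`` from above (the ``SAMPLE_SIZES`` constant and the
mixed-``n`` generator are illustrative, not part of the recipe above):

.. code-block:: python

   SAMPLE_SIZES = [10, 20, 30, 50, 100]

   def generate_mixed_n_training_set(n_examples=100_000, seed=42):
       """Like generate_training_set, but with n drawn per example."""
       rng = np.random.default_rng(seed)
       dataset = []
       for _ in range(n_examples):
           n = int(rng.choice(SAMPLE_SIZES))
           log_sizes, raw_times, K = sample_demography(rng)
           try:
               sfs = compute_teacher_sfs(log_sizes, raw_times, n=n)
           except Exception:
               continue
           if np.all(np.isfinite(sfs)) and np.all(sfs > 0):
               dataset.append({'log_sizes': log_sizes, 'raw_times': raw_times,
                               'n_epochs': K, 'n': n, 'sfs': sfs})
       return dataset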

Training procedure
--------------------

The training loss is the mean squared error between the neural SFS prediction
and the teacher's exact SFS:

.. math::

   \mathcal{L}_{\text{MSE}} = \frac{1}{n-1} \sum_{j=1}^{n-1}
   \left(\hat{M}_j(\Theta) - M_j^*(\Theta)\right)^2

In practice, we train on the log-SFS to handle the wide dynamic range (SFS
entries span several orders of magnitude, with high-frequency bins holding far
fewer sites than low-frequency bins):

.. math::

   \mathcal{L}_{\text{log-MSE}} = \frac{1}{n-1} \sum_{j=1}^{n-1}
   \left(\ln \hat{M}_j(\Theta) - \ln M_j^*(\Theta)\right)^2

.. code-block:: python

   import torch
   import torch.nn as nn
   from torch.optim import AdamW
   from torch.optim.lr_scheduler import CosineAnnealingLR

   def train_phase1(model, dataset, n_epochs=100, batch_size=256,
                    lr=3e-4, device='cuda'):
       optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
       scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

       model.to(device)
       model.train()

       for epoch in range(n_epochs):
           total_loss = 0.0
           n_batches = 0

           # make_batches collates examples into tensors; a sketch follows this block.
           for batch in make_batches(dataset, batch_size):
               log_sizes = batch['log_sizes'].to(device)
               log_times = batch['log_times'].to(device)
               true_sfs = batch['sfs'].to(device)
               n = true_sfs.shape[-1] + 1

               pred_sfs = model(log_sizes, log_times, n, theta_L=1.0)

               loss = ((torch.log(pred_sfs + 1e-10)
                        - torch.log(true_sfs + 1e-10)) ** 2).mean()

               optimizer.zero_grad()
               loss.backward()
               torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
               optimizer.step()

               total_loss += loss.item()
               n_batches += 1

           scheduler.step()

           if epoch % 10 == 0:
               avg_loss = total_loss / n_batches
               print(f"Epoch {epoch}: log-MSE = {avg_loss:.6f}")

Training-set coverage strategies
------------------------------------

The quality of the neural approximation depends critically on the coverage of
the training set. Three strategies improve coverage:

**1. Importance sampling of the parameter space.** Regions of parameter space
where the SFS changes rapidly (e.g., near sharp bottlenecks) need more training
examples. After an initial uniform training run, identify high-error regions
and oversample them.

**2. Boundary cases.** The SFS has known analytical forms at the boundaries of
parameter space: the equilibrium SFS (:math:`M_j \propto \theta / j` for constant
:math:`N_e`), the post-expansion SFS (an excess of singletons), and the
post-bottleneck SFS (a deficit of singletons). Include these boundary cases
explicitly in the training set and verify that the network reproduces them; a
sketch of injecting the equilibrium case appears after strategy 3 below.

**3. Multi-scale training.** Train on a mix of simple models (:math:`K = 1, 2`)
and complex models (:math:`K = 4, 5, 6`). Simple models provide clean signal
for the basic SFS shape; complex models teach the network to handle interactions
between epochs.
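
A minimal sketch of the boundary-case idea from strategy 2, using the
closed-form equilibrium SFS at the reference size so the teacher call can be
skipped (the ``equilibrium_examples`` helper is illustrative, not part of the
recipe above):

.. code-block:: python

   import numpy as np

   def equilibrium_examples(n=20, theta=1.0):
       """Constant-size (N_e = N_ref) boundary case with closed-form SFS."""
       j = np.arange(1, n)
       return [{
           'log_sizes': np.array([0.0]),   # N_e = N_ref
           'raw_times': np.array([0.0]),   # single epoch starting at t = 0
           'n_epochs': 1,
           'sfs': theta / j,               # M_j = theta / j at equilibrium
       }]

   # Append the boundary case(s) to the sampled training set:
   # dataset = generate_training_set() + equilibrium_examples()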

.. admonition:: How many training examples?

   Empirically, 100,000 examples suffice for 1-population models with
   :math:`K \leq 6` epochs and :math:`n \leq 100`. The log-MSE typically
   converges below :math:`10^{-4}` after 50--100 epochs of training, meaning
   the neural SFS matches the teacher to within 1% relative error per
   frequency bin. For multi-population models, more examples are needed
   (see :ref:`balance_wheel_multi_pop`).


Phase 2: Inference on Real Data
=================================

Once Phase 1 is complete, the neural SFS Predictor is frozen. Phase 2 uses it
to perform demographic inference on observed data -- the same task that dadi and
moments perform, but 100--1000× faster per likelihood evaluation.

The observed SFS is computed from a VCF file or genotype matrix:

.. code-block:: python

   def compute_observed_sfs(genotypes, n):
       """Compute the folded or unfolded SFS from a genotype matrix.

       genotypes: (n, L) binary matrix — rows are haplotypes, columns are sites
       Returns: (n-1,) integer tensor — SFS counts
       """
       freq_counts = genotypes.sum(dim=0)
       sfs = torch.zeros(n + 1, dtype=torch.long)
       for j in range(n + 1):
           sfs[j] = (freq_counts == j).sum()
       return sfs[1:n]
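
For concreteness, a toy usage on a hand-built haplotype matrix (four
haplotypes, six sites; the numbers are purely illustrative):

.. code-block:: python

   import torch

   genotypes = torch.tensor([
       [1, 0, 1, 1, 0, 1],
       [0, 1, 0, 1, 0, 1],
       [0, 0, 1, 1, 0, 1],
       [0, 0, 0, 0, 0, 1],
   ])
   # Derived-allele counts per site: 1, 1, 2, 3, 0, 4
   print(compute_observed_sfs(genotypes, n=4))   # tensor([2, 1, 1])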

Given the observed SFS, we optimize the demographic parameters :math:`\Theta`
to maximize the Poisson log-likelihood through the neural SFS Predictor:

.. code-block:: python

   def fit_demography(model, observed_sfs, theta_L, n_steps=2000,
                      lr=0.01, n_epochs_model=4):
       """Find MLE demographic parameters using Balance Wheel."""
       n = observed_sfs.shape[0] + 1
       log_sizes = nn.Parameter(torch.zeros(n_epochs_model))
       log_times = nn.Parameter(
           torch.linspace(-2, 2, n_epochs_model))
       optimizer = torch.optim.Adam([log_sizes, log_times], lr=lr)

       model.eval()
       for step in range(n_steps):
           sorted_times = torch.sort(log_times)[0]
           expected_sfs = model(
               log_sizes.unsqueeze(0), sorted_times.unsqueeze(0),
               n, theta_L)
           # poisson_log_likelihood is an assumed helper; a sketch follows this block.
           neg_ll = -poisson_log_likelihood(
               observed_sfs.float(), expected_sfs.squeeze(0))

           optimizer.zero_grad()
           neg_ll.backward()
           optimizer.step()

           if step % 500 == 0:
               print(f"Step {step}: -log L = {neg_ll.item():.2f}")

       return torch.exp(log_sizes).detach(), torch.exp(sorted_times).detach()
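
``poisson_log_likelihood`` is assumed above but not defined in this chapter. A
minimal sketch of the standard Poisson composite log-likelihood over SFS bins,
dropping the :math:`\log(\text{obs}_j!)` term that does not depend on
:math:`\Theta`:

.. code-block:: python

   import torch

   def poisson_log_likelihood(observed_sfs, expected_sfs, eps=1e-10):
       """sum_j [ obs_j * log(exp_j) - exp_j ], up to a Theta-independent constant."""
       expected_sfs = expected_sfs.clamp(min=eps)
       return (observed_sfs * torch.log(expected_sfs) - expected_sfs).sum()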

This is the same optimization that dadi and moments perform, but each likelihood
evaluation takes ~0.1 ms instead of ~10 ms (moments) or ~100 ms (dadi). For
gradient-based optimization with 2,000 steps, the total inference time is:

.. list-table:: Inference time comparison
   :header-rows: 1
   :widths: 25 25 25 25

   * - Method
     - Per evaluation
     - Per gradient step
     - 2,000 steps
   * - dadi
     - ~100 ms
     - ~2 s (finite diff., 10 params: :math:`2 \times 10` evaluations)
     - ~4,000 s
   * - moments
     - ~10 ms
     - ~20 ms (AD)
     - ~40 s
   * - Balance Wheel
     - ~0.1 ms
     - ~0.2 ms
     - **~0.4 s**


Validation
============

How do we know the student has learned well? Three validation strategies.

**1. Held-out test set.** Generate 10,000 parameter vectors not seen during
training, compute both the teacher SFS and the neural SFS, and compare:

.. code-block:: python

   def validate(model, test_set, device='cuda'):
       model.eval()
       errors = []
       with torch.no_grad():
           # test examples are assumed to hold log-times, like training batches
           for example in test_set:
               log_sizes = example['log_sizes'].unsqueeze(0).to(device)
               log_times = example['log_times'].unsqueeze(0).to(device)
               true_sfs = example['sfs'].to(device)
               n = true_sfs.shape[0] + 1
               pred_sfs = model(log_sizes, log_times, n, theta_L=1.0)
               rel_error = (
                   (pred_sfs.squeeze() - true_sfs).abs()
                   / true_sfs.clamp(min=1e-8)
               ).mean()
               errors.append(rel_error.item())
       return np.mean(errors), np.percentile(errors, 95)

Target: mean relative error < 1%, 95th percentile < 5%.

**2. Known analytical cases.** The equilibrium SFS for constant :math:`N_e` has
the closed form:

.. math::

   M_j = \frac{\theta L}{j}, \qquad j = 1, \ldots, n-1

Verify that the neural network reproduces this for constant-size demography.
Similarly, check the SFS after a recent exponential expansion or a severe
bottleneck, where approximate analytical forms exist.
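
A minimal check of the constant-size case, assuming a single epoch at the
reference size encodes the constant-:math:`N_e` demography, and using the same
log-time convention as the batching sketch above (times offset by ``1e-3``
before the log):

.. code-block:: python

   import numpy as np
   import torch

   def check_equilibrium(model, n=20, theta_L=1.0, tol=0.01, device='cuda'):
       """Compare the neural SFS for constant N_e against theta * L / j."""
       model.eval()
       with torch.no_grad():
           log_sizes = torch.zeros(1, 1, device=device)              # N_e = N_ref
           log_times = torch.tensor([[np.log(1e-3)]], device=device)  # epoch at t = 0
           pred = model(log_sizes, log_times, n, theta_L)
       pred = pred.squeeze(0).cpu().numpy()
       analytic = theta_L / np.arange(1, n)
       rel_error = np.abs(pred - analytic) / analytic
       assert rel_error.max() < tol, f"max relative error {rel_error.max():.3%}"
       return rel_error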

**3. Likelihood surface comparison.** For a fixed observed SFS, compute the
log-likelihood surface :math:`\ell(\Theta)` using both moments and Balance
Wheel on a grid of :math:`\Theta` values. The surfaces should be nearly
identical:

.. math::

   \left|\ell_{\text{neural}}(\Theta) - \ell_{\text{moments}}(\Theta)\right|
   \ll 1 \quad \text{for all } \Theta \text{ in the grid}

A log-likelihood difference of < 0.5 is typically acceptable (smaller than
typical random noise in the SFS due to sampling). If the difference exceeds 1.0
anywhere in the region of interest, the network needs more training examples in
that region.
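
A sketch of the grid comparison over a two-epoch slice of parameter space (the
sizes of an older and a more recent epoch, with a fixed change time), using the
``compute_teacher_sfs`` helper defined earlier as the teacher; the slice, the
grid, and the function name are illustrative choices:

.. code-block:: python

   import numpy as np
   import torch

   def compare_loglik_surfaces(model, observed_sfs, grid_old, grid_recent,
                               theta_L=1.0, t_change=0.5, eps=1e-3):
       """Max |neural - teacher| log-likelihood gap over a 2-epoch grid."""
       obs = np.asarray(observed_sfs, dtype=float)
       n = obs.shape[0] + 1
       max_gap = 0.0
       for log_nu_old in grid_old:          # log size of the older epoch
           for log_nu_new in grid_recent:   # log size of the recent epoch
               log_sizes = np.array([log_nu_old, log_nu_new])
               raw_times = np.array([0.0, t_change])
               m_teacher = theta_L * compute_teacher_sfs(log_sizes, raw_times, n=n)
               with torch.no_grad():
                   m_neural = model(
                       torch.as_tensor(log_sizes, dtype=torch.float32).unsqueeze(0),
                       torch.as_tensor(np.log(raw_times + eps),
                                       dtype=torch.float32).unsqueeze(0),
                       n, theta_L).squeeze(0).cpu().numpy()
               # Poisson log-likelihood, dropping the Theta-independent term.
               ll_teacher = np.sum(obs * np.log(m_teacher) - m_teacher)
               ll_neural = np.sum(obs * np.log(m_neural) - m_neural)
               max_gap = max(max_gap, abs(ll_teacher - ll_neural))
       return max_gap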

.. admonition:: When the student surpasses the teacher

   In rare cases, the neural SFS prediction may actually be *smoother* than the
   teacher's output. moments can have numerical artifacts for extreme parameter
   values (very large :math:`N_e` changes, very short epochs), while the neural
   network smoothly interpolates. This is not a bug -- it is the regularizing
   effect of the MLP's smooth activation functions. However, it means the neural
   SFS may not exactly match moments in these edge cases. Always validate against
   the teacher in the regime of interest.


Practical Considerations
==========================

Training time and compute
---------------------------

.. list-table::
   :header-rows: 1
   :widths: 35 30 35

   * - Component
     - Cost
     - Wall-clock time
   * - Teacher data generation (100k examples)
     - 100k × 10 ms (moments)
     - ~17 minutes
   * - Phase 1 training (100 epochs)
     - 100 epochs × (100k / 256) batches × ~1 ms
     - ~40 seconds (single GPU)
   * - Validation (10k test examples)
     - 10k × 0.1 ms
     - ~1 second

The total Phase 1 cost is dominated by training data generation. This is a
one-time cost -- once the network is trained, it can be used for any observed
SFS with compatible sample size and parameter range.

Retraining triggers
---------------------

The trained network should be retrained if:

- The parameter space of interest shifts (e.g., moving from human to
  *Drosophila* demography with different :math:`N_e` ranges).
- The sample size changes beyond the training range.
- The number of populations changes (requires retraining the multi-population
  encoder).
- The teacher is updated (e.g., moments releases a numerically improved solver).

.. admonition:: Balance Wheel as a compiler

   Think of Phase 1 as *compilation*: the slow, careful computation of moments
   is compiled into the fast, approximate computation of the neural network.
   Just as a compiled program runs faster than the interpreted source, the
   neural SFS evaluation runs faster than the ODE integration. And just as you
   must recompile when the source changes, you must retrain when the parameter
   space or teacher changes. The compilation cost is paid once; the speedup is
   enjoyed on every subsequent query.
