.. _mainspring_training:

==========
Training
==========

   *A mainspring must be wound carefully. Too little tension and the watch stops.
   Too much and the spring breaks. The art is in the winding -- gradually, evenly,
   with increasing force.*

Training Mainspring requires three ingredients: a simulation engine that generates
(genotype matrix, true ARG, true demography) triples, a composite loss function that
scores each component of the prediction, and a curriculum that gradually increases
the complexity of the training distribution. This chapter builds all three.


The Simulation Engine
=======================

Mainspring learns by inverting simulations. The simulator is the generative model;
the network learns to be its approximate inverse. The quality of the training data
determines the ceiling of inference quality -- no network can learn to infer features
absent from the training simulations.

We use :ref:`msprime <msprime_timepiece>` as the simulation engine. Each training
example is generated by:

1. Sampling a demographic model from a prior distribution.
2. Simulating a tree sequence under that demographic model.
3. Extracting the genotype matrix, the true ARG, and the true demography.

.. code-block:: python

   import msprime
   import numpy as np

   def sample_demography(rng):
       """Sample a random demographic model from the training prior.

       Returns a msprime.Demography object and a callable N_e(t).
       """
       n_epochs = rng.integers(1, 8)  # 1-7 epochs, i.e. 0-6 size changes
       times = np.sort(rng.exponential(scale=5000, size=n_epochs - 1))
       times = np.concatenate([[0.0], times])
       sizes = 10 ** rng.uniform(2, 5, size=n_epochs)

       demography = msprime.Demography()
       demography.add_population(name="pop", initial_size=sizes[0])
       for i in range(1, len(times)):
           demography.add_population_parameters_change(
               time=times[i], initial_size=sizes[i], population="pop"
           )

       def ne_func(t):
           idx = np.searchsorted(times, t, side='right') - 1
           return sizes[idx]

       return demography, ne_func, times, sizes

   def simulate_training_example(n_samples, seq_length, mu, rho, rng):
       """Generate one (genotype_matrix, true_ARG, true_demography) triple."""
       demography, ne_func, times, sizes = sample_demography(rng)

       ts = msprime.sim_ancestry(
           samples=n_samples,
           ploidy=1,  # one haplotype per sample, matching the matrix shape below
           sequence_length=seq_length,
           recombination_rate=rho,
           demography=demography,
           random_seed=rng.integers(1, 2**31),
       )
       ts = msprime.sim_mutations(ts, rate=mu, random_seed=rng.integers(1, 2**31))

       genotype_matrix = ts.genotype_matrix().T  # (n_samples, n_sites)

       return {
           'genotypes': genotype_matrix,
           'tree_sequence': ts,
           'ne_func': ne_func,
           'ne_times': times,
           'ne_sizes': sizes,
       }
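
A quick usage check, using the human-scale rates from the configuration table below
(the seed is arbitrary):

.. code-block:: python

   rng = np.random.default_rng(42)
   example = simulate_training_example(n_samples=50, seq_length=100_000,
                                       mu=1.25e-8, rho=1.0e-8, rng=rng)
   example['genotypes'].shape   # (50, number of segregating sites)
   example['ne_sizes'].shape    # one entry per demographic epoch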

.. admonition:: Extracting training targets from the tree sequence

   The tree sequence ``ts`` returned by msprime contains everything we need:

   - **Topology targets**: For each local tree, the parent array
     ``tree.parent_array`` gives the true topology.
   - **Breakpoint targets**: ``ts.breakpoints()`` gives the true positions where
     trees change.
   - **Node time targets**: ``ts.tables.nodes.time`` gives the true time of every
     node.
   - **Demography targets**: The ``ne_func`` callable gives the true :math:`N_e(t)`
     at any time.

   The genotype matrix is the input; everything else is a training target.
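
For concreteness, here is a minimal sketch of pulling these targets out of a tree
sequence (the helper name ``extract_targets`` is ours, not part of Mainspring; real
training code would additionally align each local tree with the variant columns it
covers):

.. code-block:: python

   import numpy as np

   def extract_targets(ts):
       """Collect breakpoint, topology, and node-time targets from a tree sequence."""
       breakpoints = np.fromiter(ts.breakpoints(), dtype=np.float64)
       node_times = ts.tables.nodes.time          # true time of every node
       # copy: tskit reuses the tree's internal arrays while iterating
       parent_arrays = [tree.parent_array.copy() for tree in ts.trees()]
       return breakpoints, node_times, parent_arrays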


Scaling the Simulation Pipeline
---------------------------------

Training requires millions of simulated datasets. Generating them on the fly (a fresh
batch for every gradient step) is essential to avoid overfitting to a finite training
set. A typical training configuration:

.. list-table::
   :header-rows: 1
   :widths: 35 30 35

   * - Parameter
     - Value
     - Rationale
   * - :math:`n` (samples)
     - 20--100
     - Covers typical sample sizes for ARG inference
   * - :math:`L` (sequence length)
     - 50 kb -- 1 Mb
     - Covers gene-scale to chromosome-arm-scale
   * - :math:`\mu` (mutation rate)
     - :math:`1.25 \times 10^{-8}` / bp / gen
     - Human mutation rate
   * - :math:`\rho` (recombination rate)
     - :math:`1.0 \times 10^{-8}` / bp / gen
     - Human recombination rate
   * - :math:`N_e` range
     - :math:`10^2` -- :math:`10^5`
     - Covers bottlenecks through large populations
   * - Number of epochs
     - 1--7
     - Covers constant through complex demography
   * - Simulations per GPU-hour
     - ~10,000 (100 kb, 50 samples)
     - msprime is fast; I/O is the bottleneck

.. code-block:: python

   import numpy as np
   import torch
   from torch.utils.data import IterableDataset

   class MsprimeDataset(IterableDataset):
       def __init__(self, n_samples, seq_length, mu, rho):
           self.n_samples = n_samples
           self.seq_length = seq_length
           self.mu = mu
           self.rho = rho

       def __iter__(self):
           rng = np.random.default_rng()
           while True:
               example = simulate_training_example(
                   self.n_samples, self.seq_length, self.mu, self.rho, rng
               )
               yield self.tensorize(example)

       def tensorize(self, example):
           ts = example['tree_sequence']
           return {
               'genotypes': torch.tensor(example['genotypes'], dtype=torch.float32),
               'node_times': torch.tensor(ts.tables.nodes.time, dtype=torch.float32),
               'ne_sizes': torch.tensor(example['ne_sizes'], dtype=torch.float32),
               'ne_times': torch.tensor(example['ne_times'], dtype=torch.float32),
           }


The Loss Function
===================

Mainspring's loss is a weighted sum of four components, each corresponding to a
stage of the architecture:

.. math::

   \mathcal{L} = \mathcal{L}_{\text{topology}} + \lambda_{\text{time}} \cdot
   \mathcal{L}_{\text{time}} + \lambda_{\text{SFS}} \cdot \mathcal{L}_{\text{SFS}}
   + \lambda_{\text{demo}} \cdot \mathcal{L}_{\text{demo}}

.. list-table:: Loss components
   :header-rows: 1
   :widths: 18 20 32 30

   * - Component
     - Symbol
     - What it measures
     - Stage
   * - Topology loss
     - :math:`\mathcal{L}_{\text{topology}}`
     - Cross-entropy between predicted and true parent assignments + binary
       cross-entropy for breakpoints
     - Stage 2 (Topology Decoder)
   * - Time loss
     - :math:`\mathcal{L}_{\text{time}}`
     - Negative log-likelihood of true node times under predicted gamma
       distributions
     - Stage 3 (Dating GNN)
   * - SFS loss
     - :math:`\mathcal{L}_{\text{SFS}}`
     - :math:`\chi^2` distance between predicted and observed SFS
     - Stages 2+3 (physics regularizer)
   * - Demographic loss
     - :math:`\mathcal{L}_{\text{demo}}`
     - Negative log-likelihood of true :math:`N_e(t)` under the normalizing flow
       posterior + KL penalty
     - Stage 4 (Demographic Decoder)

The weights :math:`\lambda_{\text{time}}`, :math:`\lambda_{\text{SFS}}`, and
:math:`\lambda_{\text{demo}}` balance the loss components. They are adjusted during
curriculum training (see below).


Topology Loss
---------------

The topology loss has two terms:

.. math::

   \mathcal{L}_{\text{topology}} = -\frac{1}{nL} \sum_{i=1}^{n} \sum_{\ell=1}^{L}
   \log \alpha_{i, \pi_i^*(\ell)}^\ell \;+\; \text{BCE}(\hat{b}, b^*)

where :math:`\pi_i^*(\ell)` is the true parent of sample :math:`i` in the local tree
at position :math:`\ell`, :math:`\alpha_{ij}^\ell` is the predicted attention weight
(copying probability), :math:`\hat{b}` is the predicted breakpoint vector, and
:math:`b^*` is the true breakpoint indicator.
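
A sketch of the corresponding PyTorch computation follows; the tensor shapes and
argument names are illustrative assumptions, not Mainspring's actual interface (the
``topology_loss`` used in the training loop below takes different arguments):

.. code-block:: python

   import torch
   import torch.nn.functional as F

   def topology_loss_sketch(attn_log_probs, true_parents,
                            breakpoint_logits, true_breakpoints):
       """attn_log_probs:    (n, L, n_nodes) log copying probabilities log(alpha_ij^l)
       true_parents:      (n, L) long tensor, true parent index pi_i^*(l)
       breakpoint_logits: (L,) predicted breakpoint logits
       true_breakpoints:  (L,) 0/1 true breakpoint indicators
       """
       # mean negative log-probability assigned to the true parent at each site
       nll = -attn_log_probs.gather(-1, true_parents.unsqueeze(-1)).squeeze(-1).mean()
       # binary cross-entropy for the breakpoint indicator
       bce = F.binary_cross_entropy_with_logits(breakpoint_logits,
                                                true_breakpoints.float())
       return nll + bce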


Time Loss
-----------

The time loss is the negative log-likelihood of the true node times under the
predicted gamma distributions:

.. math::

   \mathcal{L}_{\text{time}} = -\frac{1}{|\mathcal{V}_{\text{int}}|}
   \sum_{v \in \mathcal{V}_{\text{int}}} \bigl[
   (\alpha_v - 1) \log t_v^* - \beta_v t_v^* + \alpha_v \log \beta_v
   - \log \Gamma(\alpha_v) \bigr]

where :math:`\mathcal{V}_{\text{int}}` is the set of internal nodes and
:math:`t_v^*` is the true time of node :math:`v`.
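
A sketch of this loss with ``torch.distributions.Gamma`` (shape/rate parameterization,
matching the formula above; the small clamp on the true times is an added guard, not
part of the formula):

.. code-block:: python

   import torch

   def time_loss_sketch(alpha, beta, true_times, eps=1e-8):
       """Mean Gamma negative log-likelihood of the true internal-node times.

       alpha, beta: (n_internal,) predicted shape and rate parameters
       true_times:  (n_internal,) true node times in generations
       """
       gamma = torch.distributions.Gamma(concentration=alpha, rate=beta)
       return -gamma.log_prob(true_times.clamp_min(eps)).mean()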


The SFS Loss as a Differentiable Physics Regularizer
-------------------------------------------------------

The SFS loss deserves special attention because it is the key connection between the
Timepieces that operate on summary statistics (:ref:`dadi <dadi_timepiece>`,
:ref:`moments <moments_timepiece>`, :ref:`momi2 <momi2_timepiece>`) and the full
ARG-based inference of Mainspring.

The SFS is a deterministic function of the ARG. For a sample of :math:`n` haplotypes:

.. math::

   \text{SFS}[k] = \mu \sum_{e \in \text{edges}} b(e) \cdot
   \mathbf{1}[\text{desc}(e) = k], \qquad k = 1, \ldots, n-1

where :math:`b(e)` is the branch length (in generations) of edge :math:`e`,
:math:`\text{desc}(e)` is the number of descendant leaves below edge :math:`e`, and
:math:`\mu` is the per-base-pair, per-generation mutation rate (in practice each
local tree's contribution is also scaled by the genomic span it covers, as in the
code below). This formula says: the expected number of sites with derived allele
frequency :math:`k/n` equals the total branch length subtending exactly :math:`k`
leaves, times the mutation rate.

This relationship is **differentiable** with respect to the predicted branch lengths.
If the network predicts node times :math:`\hat{t}_v` (from the dating GNN), the
branch length of edge :math:`(u, v)` is :math:`\hat{b}_{uv} = \hat{t}_u - \hat{t}_v`,
and we can compute:

.. code-block:: python

   import torch

   def differentiable_sfs(node_times, parent_array, n_leaves, mu, span):
       """Compute the expected SFS from predicted node times.

       This is differentiable w.r.t. node_times, allowing gradient flow
       from the SFS loss back through the dating GNN.
       """
       n = n_leaves
       sfs = torch.zeros(n + 1, dtype=node_times.dtype, device=node_times.device)

       for child in range(len(parent_array)):
           parent = parent_array[child]
           if parent < 0:
               continue
           branch_length = node_times[parent] - node_times[child]
           n_desc = count_descendants_below(child, parent_array, n_leaves)
           sfs[n_desc] = sfs[n_desc] + branch_length * mu * span

       return sfs[1:n]  # SFS[1] through SFS[n-1]
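
The snippet above calls ``count_descendants_below`` without defining it. One
straightforward (if unoptimized) implementation, assuming the sample leaves are nodes
``0 .. n_leaves - 1`` as in msprime's node numbering:

.. code-block:: python

   def count_descendants_below(node, parent_array, n_leaves):
       """Number of sample leaves in the subtree rooted at `node`.

       O(n_leaves * tree depth) per call; real code would precompute all counts
       once per local tree with a single bottom-up pass.
       """
       count = 0
       for leaf in range(n_leaves):
           u = leaf
           while u >= 0:              # walk rootward until we fall off the root
               if u == node:
                   count += 1
                   break
               u = parent_array[u]
       return count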

The SFS loss is then:

.. math::

   \mathcal{L}_{\text{SFS}} = \sum_{k=1}^{n-1}
   \frac{(\widehat{\text{SFS}}[k] - \text{SFS}_{\text{obs}}[k])^2}
   {\max(\text{SFS}_{\text{obs}}[k], \epsilon)}

This is a :math:`\chi^2`-type loss that down-weights rare frequency classes
(where the observed SFS may be zero or very small). The :math:`\epsilon` floor
prevents division by zero.
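
A sketch of this loss in PyTorch, with the observed SFS computed directly from the
genotype matrix (this matches the ``sfs_loss(predicted_sfs, genotypes)`` call in the
training loop below; the unbatched shapes and the floor value of 1.0 are assumptions):

.. code-block:: python

   import torch

   def observed_sfs(genotypes):
       """Unfolded SFS of a (n_haplotypes, n_sites) 0/1 genotype matrix."""
       n = genotypes.shape[0]
       counts = genotypes.sum(dim=0).long()                  # derived-allele count per site
       sfs = torch.bincount(counts, minlength=n + 1).float()
       return sfs[1:n]                                       # frequency classes 1 .. n-1

   def sfs_loss(predicted_sfs, genotypes, eps=1.0):
       obs = observed_sfs(genotypes)
       return ((predicted_sfs - obs) ** 2 / torch.clamp(obs, min=eps)).sum()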

.. admonition:: Why this works as a regularizer

   The SFS loss does not require knowing the true ARG. It compares the SFS implied
   by the **predicted** ARG to the SFS computed directly from the **observed**
   genotype matrix (which is always available). This provides a supervision signal
   that is independent of the topology and time losses -- it catches global errors
   that per-node losses miss. For example, if the network systematically
   under-estimates deep coalescence times, the predicted SFS will put too little
   weight in the high-frequency classes (deep branches subtend many descendants, so
   shortening them shifts weight from high toward low frequency classes). The SFS
   loss detects and corrects this.


Demographic Loss
-------------------

The demographic loss trains the normalizing flow to produce accurate posterior
distributions over :math:`N_e(t)`. For each training example, we have the true
:math:`N_e(t)` trajectory sampled from our prior. The loss is the negative
log-likelihood of the true trajectory under the flow:

.. math::

   \mathcal{L}_{\text{demo}} = -\log q_\phi(N_e^* \mid \mathbf{c}, \text{SFS})
   = -\log p(\mathbf{z}^*) - \log \left|\det \frac{\partial g_\phi^{-1}}
   {\partial N_e}\right|

where :math:`\mathbf{z}^* = g_\phi^{-1}(N_e^*)` is the true trajectory mapped back
to the base distribution, and the second term is the log-determinant of the Jacobian
of the inverse flow.
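
The exact objects produced by Stage 4 are not pinned down here, so the following is
only a sketch: it assumes the decoder returns, per example, the base-distribution
log-density of the inverse-mapped trajectory and the log-determinant of the
inverse-flow Jacobian (roughly what ``outputs['ne_posterior']`` and
``outputs['flow_log_det']`` might correspond to in the training loop below):

.. code-block:: python

   def demo_loss(base_log_prob, inv_log_det, true_ne=None):
       """Negative log-likelihood of the true N_e(t) trajectory under the flow.

       base_log_prob: log p(z*) with z* = g_phi^{-1}(N_e*), shape (batch,)
       inv_log_det:   log |det d g_phi^{-1} / d N_e|, shape (batch,)
       true_ne:       already consumed upstream when computing z*; kept for parity
       """
       return -(base_log_prob + inv_log_det).mean()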


Curriculum Training
=====================

Training Mainspring on the full prior from the start is inefficient. Complex
demographic models produce ARGs with deep coalescence events, many breakpoints, and
wide variation in branch lengths -- all of which are difficult for an untrained
network to predict. Instead, we use **curriculum training**: a sequence of phases
that gradually increase the complexity of the training distribution.

.. list-table:: Curriculum phases
   :header-rows: 1
   :widths: 8 18 25 25 24

   * - Phase
     - Demography
     - Focus
     - Loss weights
     - Duration
   * - 1
     - Constant :math:`N_e`
     - Topology and basic dating
     - :math:`\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=0.1, \lambda_{\text{demo}}=0`
     - 100k steps
   * - 2
     - 1--2 size changes
     - Dating under variable :math:`N_e`
     - :math:`\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=0.5, \lambda_{\text{demo}}=0.1`
     - 200k steps
   * - 3
     - Full prior (1--7 epochs)
     - Demographic inference
     - :math:`\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=1, \lambda_{\text{demo}}=1`
     - 500k steps
   * - 4
     - Complex + selection (SLiM)
     - Robustness to model misspecification
     - :math:`\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=1, \lambda_{\text{demo}}=1`
     - 200k steps

**Phase 1: Constant demography.** All simulations use a single constant population
size :math:`N_e \sim 10^{\mathcal{U}(2,5)}`. The Gumbel-softmax temperature starts
high (:math:`\tau = 5`) and anneals steadily from there (see the schedule below).
The network learns basic topology reconstruction and time estimation without needing
to handle demographic variation.

**Phase 2: Simple demographic changes.** Simulations include one or two step changes
in population size. The demographic decoder is activated
(:math:`\lambda_{\text{demo}} > 0`), and the network begins learning to map
coalescence-time distributions to :math:`N_e(t)`. The SFS loss weight increases to
:math:`0.5`, strengthening the physics regularizer as the ARGs become more complex.

**Phase 3: Full complexity.** The training prior covers the full range of demographic
models (1--7 epochs, arbitrary population sizes). All loss weights are set to 1. The
Gumbel-softmax temperature continues annealing toward :math:`\tau = 0.1`. This is
the longest phase and where most of the learning happens.

**Phase 4: Robustness (optional).** Training examples include simulations from SLiM
(which can model natural selection, population structure, and other complexities not
available in msprime). The network sees data generated under model misspecification
-- the true generative model is more complex than the coalescent assumed by the
architecture. This teaches the network to degrade gracefully rather than produce
confidently wrong answers.


Gumbel-Softmax Annealing
---------------------------

The temperature :math:`\tau` of the Gumbel-softmax in the topology decoder follows
an exponential annealing schedule:

.. math::

   \tau(t) = \tau_{\max} \cdot \left(\frac{\tau_{\min}}{\tau_{\max}}\right)^{t / T}

where :math:`t` is the training step, :math:`T` is the total number of steps, and
typically :math:`\tau_{\max} = 5.0`, :math:`\tau_{\min} = 0.05`.

.. code-block:: python

   def anneal_temperature(step, total_steps, tau_max=5.0, tau_min=0.05):
       return tau_max * (tau_min / tau_max) ** (step / total_steps)
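
Evaluated at the phase boundaries of the 1M-step run below (and assuming, as in the
training loop, that annealing runs over the full 1M steps rather than restarting per
phase), the schedule gives roughly:

.. code-block:: python

   >>> [round(anneal_temperature(s, 1_000_000), 2)
   ...  for s in (0, 100_000, 300_000, 800_000, 1_000_000)]
   [5.0, 3.15, 1.26, 0.13, 0.05]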


Training Pseudocode
=====================

The complete training loop:

.. code-block:: python

   import torch
   from torch.optim import AdamW
   from torch.optim.lr_scheduler import CosineAnnealingLR

   def train_mainspring(model, n_steps=1_000_000, batch_size=16,
                        lr=3e-4, device='cuda'):
       optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
       scheduler = CosineAnnealingLR(optimizer, T_max=n_steps)
       dataset = MsprimeDataset(n_samples=50, seq_length=100_000,
                                mu=1.25e-8, rho=1.0e-8)
       loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

       phase_boundaries = [100_000, 300_000, 800_000, 1_000_000]
       lambda_configs = [
           {'time': 1.0, 'sfs': 0.1, 'demo': 0.0},
           {'time': 1.0, 'sfs': 0.5, 'demo': 0.1},
           {'time': 1.0, 'sfs': 1.0, 'demo': 1.0},
           {'time': 1.0, 'sfs': 1.0, 'demo': 1.0},
       ]

       model.to(device)
       model.train()

       for step, batch in enumerate(loader):
           if step >= n_steps:
               break

           phase = sum(step >= b for b in phase_boundaries[:-1])
           lambdas = lambda_configs[phase]

           model.topology_decoder.tau = anneal_temperature(step, n_steps)

           genotypes = batch['genotypes'].to(device)
           true_times = batch['node_times'].to(device)
           true_ne = batch['ne_sizes'].to(device)

           outputs = model(genotypes)

           L_topo = topology_loss(outputs['topology'], batch, genotypes)
           L_time = time_loss(outputs['alpha'], outputs['beta'], true_times)
           L_sfs = sfs_loss(outputs['predicted_sfs'], genotypes)
           L_demo = demo_loss(outputs['ne_posterior'], outputs['flow_log_det'],
                              true_ne)

           loss = (L_topo
                   + lambdas['time'] * L_time
                   + lambdas['sfs'] * L_sfs
                   + lambdas['demo'] * L_demo)

           optimizer.zero_grad()
           loss.backward()
           torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
           optimizer.step()
           scheduler.step()

           if step % 1000 == 0:
               print(f"Step {step}: loss={loss.item():.4f} "
                     f"topo={L_topo.item():.4f} time={L_time.item():.4f} "
                     f"sfs={L_sfs.item():.4f} demo={L_demo.item():.4f} "
                     f"tau={model.topology_decoder.tau:.3f}")


Training Diagnostics
======================

Monitoring training requires tracking each loss component independently, plus
several diagnostic metrics:

.. list-table:: Training diagnostics
   :header-rows: 1
   :widths: 30 70

   * - Metric
     - What it tells you
   * - **Topology accuracy** (Robinson-Foulds distance)
     - Whether the predicted tree topologies match ground truth. Should decrease
       rapidly in Phase 1 and plateau in Phase 3.
   * - **Time calibration** (predicted gamma coverage)
     - Whether the predicted gamma distributions are well-calibrated: the fraction
       of true node times falling within the predicted 90% credible interval should
       be close to 0.9.
   * - **SFS residuals**
     - The per-frequency-class difference between predicted and observed SFS.
       Systematic biases (e.g., too few singletons) indicate structural errors in
       the predicted ARG.
   * - :math:`N_e(t)` **RMSE** (log scale)
     - Root mean squared error between predicted and true :math:`N_e(t)` on a
       log scale. Should decrease steadily throughout Phases 2--3.
   * - **Gumbel-softmax entropy**
     - The entropy of the attention weights in the topology decoder. Should decrease
       as :math:`\tau` anneals, indicating that parent assignments are becoming more
       confident.
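
A sketch of the time-calibration check from the table above, via the probability
integral transform: if the predicted gammas are calibrated, the CDF values of the true
times are uniform, so the central 90% interval should capture about 90% of them
(assumes a PyTorch version in which ``Gamma.cdf`` is implemented):

.. code-block:: python

   import torch

   def gamma_coverage(alpha, beta, true_times, level=0.90):
       """Fraction of true node times inside the central `level` credible interval."""
       u = torch.distributions.Gamma(concentration=alpha, rate=beta).cdf(true_times)
       lo, hi = (1.0 - level) / 2, 1.0 - (1.0 - level) / 2
       return ((u >= lo) & (u <= hi)).float().mean()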

.. admonition:: When to stop training

   Training is complete when (1) the validation loss plateaus for 50k steps, (2) the
   time calibration is within 5% of the nominal level across all time scales, and
   (3) the :math:`N_e(t)` RMSE on a held-out validation set of 1,000 simulations
   stops improving. In practice, Phase 3 accounts for most of the training time.
   Phase 4 (SLiM robustness) is optional and mainly relevant for applications where
   selection is expected.
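
For criterion (1), a minimal sketch of the plateau check, assuming the validation loss
is recorded at a regular cadence (a window of 50 evaluations corresponds to 50k steps
at a 1k-step cadence; both numbers are assumptions):

.. code-block:: python

   def has_plateaued(val_losses, window=50, rel_improvement=1e-3):
       """True once the best loss in the latest `window` evaluations improves on the
       best seen before that window by less than `rel_improvement` (relative)."""
       if len(val_losses) < 2 * window:
           return False
       recent_best = min(val_losses[-window:])
       earlier_best = min(val_losses[:-window])
       return (earlier_best - recent_best) < rel_improvement * abs(earlier_best)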
