.. _mainspring_complication:

=======================================
Complication I: Mainspring
=======================================

   *Amortized ARG Inference via Structured Neural Posterior Estimation*

The Mechanism at a Glance
==========================

Every method in this book makes the same fundamental trade: mathematical
tractability versus biological realism. PSMC discretizes time and assumes
piecewise-constant demography. dadi collapses the genome to a site frequency
spectrum, discarding linkage information. ARGweaver is exact under the DSMC model
but costs :math:`O(S^2)` per site. tsinfer scales to millions of samples but
surrenders posterior inference entirely. Each Timepiece chooses a different point
on the Pareto frontier between accuracy and compute.

Deep learning can break this frontier. Not by replacing the gear train -- the
likelihood machinery the book spends hundreds of pages building -- but by **learning
to shortcut it**. Simulate millions of ARGs from the generative model (msprime).
Train a neural network to invert the simulation: map observed sequences back to the
ARG and the demography that produced them. At inference time, a single forward pass
replaces hours of MCMC or EM.
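
Concretely, the winding step is just a simulator loop. The sketch below draws a
demographic parameter from a prior, simulates an ARG with msprime, and keeps the
genotype matrix plus the true parameter as one training pair. The prior range,
sample size, and rates are illustrative placeholders (as is the helper name
``simulate_training_pair``), not the settings used later in this Complication.

.. code-block:: python

   import msprime
   import numpy as np

   rng = np.random.default_rng(seed=1)

   def simulate_training_pair(n_samples=50, seq_len=1e6,
                              mu=1.25e-8, r=1.25e-8):
       """Draw one (genotype matrix, true parameter) training pair."""
       # Illustrative prior: one constant N_e, drawn log-uniformly.
       # A realistic pipeline would draw a full N_e(t) trajectory.
       log10_ne = rng.uniform(3.0, 5.0)
       ts = msprime.sim_ancestry(
           samples=n_samples, population_size=10.0 ** log10_ne,
           sequence_length=seq_len, recombination_rate=r,
           random_seed=int(rng.integers(1, 2**31)),
       )
       ts = msprime.sim_mutations(ts, rate=mu,
                                  model=msprime.BinaryMutationModel(),
                                  random_seed=int(rng.integers(1, 2**31)))
       # Transpose (num_sites, 2 * n_samples) to (haplotypes, sites),
       # matching D in {0,1}^{n x L} below.
       genotypes = ts.genotype_matrix().T
       positions = ts.tables.sites.position
       return genotypes, positions, log10_ne

   # One pair; in practice this loop runs millions of times, in parallel.
   G, pos, theta = simulate_training_pair()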

The question is: **what architecture respects the mathematical structure of the
problem well enough to learn efficiently?** This Complication answers that question by
distilling one design principle from each Timepiece.

.. admonition:: The name

   The mainspring is the power source of a mechanical watch -- a coiled spring that
   stores energy and releases it through the gear train to drive the hands. In
   Mainspring, simulated ARGs are the stored energy, and the neural network is the
   gear train that converts them into inference power. Like a physical mainspring,
   the energy is wound in advance (during training) and released in precisely metered
   pulses (at inference time).

The four stages of Mainspring (a minimal PyTorch skeleton of the full pipeline
follows this list):

1. **The Genomic Encoder** (the escapement) -- A Transformer that processes the
   genotype matrix with sliding-window attention over positions (from PSMC's sequential
   Markov property) and Set Transformer attention over samples (from SMC++'s permutation
   invariance). Output: per-sample, per-position latent vectors.

2. **The Topology Decoder** (the gear train) -- A learned Li & Stephens model.
   Cross-attention between haplotypes identifies who is copying whom at each position
   (from tsinfer/lshmm). Hard attention via Gumbel-softmax yields discrete parent
   assignments (sketched in code after the diagram below). A breakpoint detector
   identifies where trees change. Output: a tree sequence topology.

3. **The Dating GNN** (the mainspring) -- A Graph Neural Network that runs learned
   message-passing on each local tree (from tsdate's inside-outside algorithm).
   Edge features include mutation count and genomic span (from Threads' sufficient
   statistics). Output: gamma-distributed node times (from Gamma-SMC; the gamma
   output head is also sketched after the diagram).

4. **The Demographic Decoder** (the case and dial) -- A conditional normalizing flow
   that maps the inferred coalescence-time distribution to a posterior over continuous
   :math:`N_e(t)` functions (from phlash). An SFS auxiliary loss provides
   physics-informed regularization (from dadi/moments/momi2).
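
Here is that composition in code, before the schematic. The skeleton below is a
minimal PyTorch sketch of how the four stages chain together; the module
internals are deliberately trivial stand-ins (a linear embedding instead of the
attention stacks, a pooled head instead of tree message passing, a point
trajectory instead of the flow), and every name and shape is an illustrative
assumption rather than the final architecture.

.. code-block:: python

   import torch
   import torch.nn as nn

   class GenomicEncoder(nn.Module):
       """Stage 1 stub: per-haplotype, per-position latents."""
       def __init__(self, d_model=64):
           super().__init__()
           self.embed = nn.Linear(1, d_model)

       def forward(self, genotypes):                   # (n, L) in {0, 1}
           return self.embed(genotypes.unsqueeze(-1).float())  # (n, L, d)

   class TopologyDecoder(nn.Module):
       """Stage 2 stub: copying scores between haplotype pairs."""
       def __init__(self, d_model=64):
           super().__init__()
           self.q = nn.Linear(d_model, d_model)
           self.k = nn.Linear(d_model, d_model)

       def forward(self, h):                           # (n, L, d)
           # (L, n, n): at each position, who might copy from whom.
           return torch.einsum('ild,jld->lij',
                               self.q(h), self.k(h)) / h.shape[-1] ** 0.5

   class DatingGNN(nn.Module):
       """Stage 3 stub: gamma parameters for node times."""
       def __init__(self, d_model=64):
           super().__init__()
           self.head = nn.Linear(d_model, 2)

       def forward(self, h):                           # (n, L, d)
           ab = self.head(h.mean(dim=1))               # pool over positions
           return nn.functional.softplus(ab) + 1e-4    # keep alpha, beta > 0

   class DemographicDecoder(nn.Module):
       """Stage 4 stub: a trajectory on a log-time grid stands in
       for the conditional normalizing flow."""
       def __init__(self, n_time_bins=32):
           super().__init__()
           self.head = nn.Linear(2, n_time_bins)

       def forward(self, gamma_params):                # (n, 2)
           return self.head(gamma_params).mean(dim=0)  # (n_time_bins,)

   class Mainspring(nn.Module):
       def __init__(self, d_model=64):
           super().__init__()
           self.encoder = GenomicEncoder(d_model)
           self.topology = TopologyDecoder(d_model)
           self.dating = DatingGNN(d_model)
           self.demography = DemographicDecoder()

       def forward(self, genotypes):
           h = self.encoder(genotypes)
           copy_scores = self.topology(h)
           gamma_params = self.dating(h)
           log_ne = self.demography(gamma_params)
           return copy_scores, gamma_params, log_ne

   model = Mainspring()
   scores, gammas, log_ne = model(torch.randint(0, 2, (8, 100)))

The schematic below traces the same data flow box by box: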

.. code-block:: text

   Genotype matrix D ∈ {0,1}^{n × L}
                      |
                      v
            +--------------------------+
            |  GENOMIC ENCODER         |
            |  Sliding-window attn     |
            |  (PSMC: sequential SMC)  |
            |  Set Transformer         |
            |  (SMC++: permutation eq) |
            +--------------------------+
                      |
                      v
            +--------------------------+
            |  TOPOLOGY DECODER        |
            |  Cross-attention         |
            |  (tsinfer: copying model)|
            |  Gumbel-softmax → edges  |
            |  Breakpoint detection    |
            +--------------------------+
                      |
                      v
            +--------------------------+
            |  DATING GNN              |
            |  Message passing on trees|
            |  (tsdate: inside-outside)|
            |  Gamma(α, β) output      |
            |  (Gamma-SMC: posteriors) |
            +--------------------------+
                      |
                      v
            +--------------------------+
            |  DEMOGRAPHIC DECODER     |
            |  Normalizing flow → N_e  |
            |  (phlash: continuous)    |
            |  SFS auxiliary loss      |
            |  (dadi/moments: physics) |
            +--------------------------+
                      |
                      v
            Full dated ARG + N_e(t)
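
Two steps in this pipeline deserve standalone sketches. First, the Topology
Decoder's hard attention (forward-referenced in stage 2): PyTorch's built-in
``gumbel_softmax`` draws a discrete parent assignment in the forward pass while
keeping a differentiable relaxation for the backward pass. The shapes and the
temperature ``tau`` below are illustrative.

.. code-block:: python

   import torch
   import torch.nn.functional as F

   L, n = 100, 8                        # positions, haplotypes
   copy_logits = torch.randn(L, n, n)   # stage-2 scores: (pos, child, parent)
   copy_logits.diagonal(dim1=1, dim2=2).fill_(-1e9)   # forbid self-copying

   # hard=True: exactly one-hot forward, soft gradients backward
   # (the straight-through estimator).
   parents = F.gumbel_softmax(copy_logits, tau=0.5, hard=True, dim=-1)
   parent_idx = parents.argmax(dim=-1)  # (L, n) discrete copying parents

Second, the Dating GNN's output head (forward-referenced in stage 3): node times
come out as :math:`\Gamma(\alpha, \beta)` parameters, and training minimizes the
negative log-likelihood of the simulator's true node times under that gamma via
``torch.distributions.Gamma``. The two-unit head and the stubbed embeddings are
assumptions; the real features are the GNN's final node states.

.. code-block:: python

   import torch
   import torch.nn as nn

   node_embed = torch.randn(15, 64)          # final GNN node states (stub)
   true_times = torch.rand(15) * 1e4 + 1.0   # simulator's node ages (stub)

   head = nn.Linear(64, 2)
   alpha, beta = nn.functional.softplus(head(node_embed)).unbind(-1)
   posterior = torch.distributions.Gamma(alpha + 1e-4, beta + 1e-4)
   loss = -posterior.log_prob(true_times).mean()
   loss.backward()   # the gamma head trains against exact simulated truths

Because the simulator knows every node time exactly, no MCMC over dates is ever
run during training -- the gamma head absorbs that uncertainty instead.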

.. admonition:: Prerequisites for this Complication

   Before starting Mainspring, you should have worked through:

   - :ref:`PSMC <psmc_timepiece>` -- the sequential Markov property and HMM inference
   - :ref:`tsinfer <tsinfer_timepiece>` -- the Li & Stephens copying model and tree
     sequence representation
   - :ref:`tsdate <tsdate_timepiece>` -- the inside-outside algorithm on trees and
     variational gamma posteriors
   - :ref:`The SMC <smc>` -- the sequential Markov coalescent approximation
   - :ref:`msprime <msprime_timepiece>` -- the coalescent simulator (for generating
     training data)

   Familiarity with Transformer architectures and PyTorch is assumed.

Chapters
========

.. toctree::
   :maxdepth: 2

   overview
   design_principles
   architecture
   training
   comparison
