.. _escapement_variational:

==========================================
Variational Inference Without Simulations
==========================================

   *The watchmaker who understands the physics of the escapement can regulate
   the watch without a reference clock. The equations themselves are the
   standard.*

The core innovation of Escapement is not a new neural architecture -- it is a
new **objective function**. Instead of matching simulated data (supervised
learning), Escapement maximizes a lower bound on the probability of the observed
data under the coalescent model (variational inference). This chapter derives
that bound, explains each term, and addresses the two technical barriers that
prevented this approach until now.


The Evidence Lower Bound for Tree Sequences
=============================================

Let :math:`\mathbf{D} \in \{0,1\}^{n \times L}` be the observed genotype
matrix, :math:`\tau` a tree sequence (topology, coalescence times, and
breakpoints), and :math:`\theta = (N_e(t), \mu, \rho)` the population-genetic
parameters. The **marginal likelihood** (evidence) of the data is:

.. math::

   P(\mathbf{D} \mid \theta) = \int P(\mathbf{D} \mid \tau, \mu) \cdot
   P(\tau \mid N_e, \rho) \; d\tau

This integral is over all possible tree sequences consistent with :math:`n`
samples at :math:`L` sites -- a space of staggering dimension. The topologies
alone are combinatorial (the number of labeled rooted binary trees on :math:`n`
leaves is :math:`(2n-3)!! = 1 \times 3 \times 5 \times \cdots \times (2n-3)`),
and for each topology the coalescence times are continuous. No method can
evaluate this integral exactly.
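
To get a feel for the combinatorics, a few lines of Python evaluate
:math:`(2n-3)!!` directly (a back-of-the-envelope check, not part of
Escapement):

.. code-block:: python

   import math

   def num_topologies(n):
       """(2n-3)!!: labeled rooted binary tree topologies on n leaves."""
       return math.prod(range(1, 2 * n - 2, 2))

   for n in (5, 10, 50):
       print(n, num_topologies(n))
   # n = 50 already gives roughly 2.7e76 topologies -- and that is a
   # single local tree, before coalescence times or recombination enter.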

Variational inference replaces the intractable posterior
:math:`P(\tau \mid \mathbf{D}, \theta)` with a tractable approximation
:math:`q(\tau \mid \mathbf{D}, \phi)`, parameterized by :math:`\phi`. The
quality of the approximation is measured by the KL divergence:

.. math::

   \text{KL}\!\left[q(\tau \mid \mathbf{D}, \phi) \;\|\;
   P(\tau \mid \mathbf{D}, \theta)\right] =
   \mathbb{E}_q\!\left[\log \frac{q(\tau \mid \mathbf{D}, \phi)}
   {P(\tau \mid \mathbf{D}, \theta)}\right] \geq 0

Expanding the posterior using Bayes' rule and rearranging:

.. math::

   \log P(\mathbf{D} \mid \theta) &= \text{KL}\!\left[q \| P(\tau \mid
   \mathbf{D}, \theta)\right] +
   \mathbb{E}_q\!\left[\log P(\mathbf{D} \mid \tau, \mu) +
   \log P(\tau \mid N_e, \rho) - \log q(\tau \mid \mathbf{D}, \phi)\right]

   &\geq \underbrace{\mathbb{E}_q\!\left[\log P(\mathbf{D} \mid \tau, \mu)
   \right]}_{\text{mutation likelihood}} +
   \underbrace{\mathbb{E}_q\!\left[\log P(\tau \mid N_e, \rho)
   \right]}_{\text{coalescent prior}} +
   \underbrace{H[q]}_{\text{entropy}}

The last line is the **Evidence Lower Bound** (ELBO). Since the KL divergence
is non-negative, the ELBO is always less than or equal to the log-evidence.
Maximizing the ELBO with respect to :math:`\phi` simultaneously (1) tightens
the bound and (2) pushes :math:`q` toward the true posterior.

.. admonition:: The ELBO as a loss function

   The negative ELBO is the loss that Escapement minimizes:

   .. math::

      \mathcal{L}(\phi, \theta) = -\text{ELBO} = -\mathbb{E}_q\!\left[
      \log P(\mathbf{D} \mid \tau, \mu) + \log P(\tau \mid N_e, \rho)
      \right] - H[q]

   This loss requires only three things: (1) the ability to sample
   :math:`\tau \sim q`, (2) the ability to evaluate the mutation likelihood
   and coalescent prior analytically, and (3) the ability to compute the
   entropy of :math:`q`. No simulations appear anywhere.
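
Those three requirements translate directly into a loss skeleton. The
callables below are placeholders for exposition, not Escapement's API:

.. code-block:: python

   def elbo(sample_q, log_lik, log_prior, entropy):
       """ELBO from the three ingredients the loss requires."""
       tau = sample_q()              # (1) sample tau ~ q(tau | D, phi)
       return (log_lik(tau)          # (2) analytical mutation likelihood
               + log_prior(tau)      #     and coalescent prior
               + entropy())          # (3) analytical entropy of q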


The Three ELBO Terms
======================

Each term in the ELBO has a clear population-genetic interpretation and a
clear source in the Timepieces.

.. list-table:: The three ELBO terms
   :header-rows: 1
   :widths: 16 28 28 28

   * - Term
     - Mathematical form
     - Population-genetic meaning
     - Source Timepiece(s)
   * - Mutation likelihood
     - :math:`\mathbb{E}_q[\log P(\mathbf{D} \mid \tau, \mu)]`
     - Does the proposed genealogy explain the observed mutations?
     - :ref:`tsdate <tsdate_timepiece>` (mutation model),
       :ref:`ARGweaver <argweaver_timepiece>` (Poisson on edges)
   * - Coalescent prior
     - :math:`\mathbb{E}_q[\log P(\tau \mid N_e, \rho)]`
     - Is the proposed genealogy plausible under the coalescent?
     - :ref:`msprime <msprime_timepiece>` (Kingman coalescent),
       :ref:`PSMC <psmc_timepiece>` (SMC factorization)
   * - Entropy
     - :math:`H[q] = -\mathbb{E}_q[\log q(\tau)]`
     - How uncertain is the variational posterior? (regularizer)
     - :ref:`tsdate <tsdate_timepiece>` (gamma entropy),
       :ref:`Gamma-SMC <gamma_smc_timepiece>` (continuous posteriors)

**Term 1: Mutation likelihood.** Given a genealogy :math:`\tau`, each edge
:math:`e` with time span :math:`t_e` and genomic span :math:`s_e` accumulates
mutations at rate :math:`\mu`. Under the infinite-sites model, the number of
mutations on edge :math:`e` is Poisson(:math:`\mu \cdot t_e \cdot s_e`). The
mutation likelihood asks: given the proposed tree, are the observed mutations in
the right places?
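
A minimal sketch of this term, assuming per-edge mutation counts and spans
are available as tensors (the names are illustrative):

.. code-block:: python

   import torch

   def poisson_mutation_loglik(counts, t_span, s_span, mu):
       """Sum over edges of log Poisson(counts; mu * t_e * s_e)."""
       rate = mu * t_span * s_span            # expected mutations per edge
       return (counts * torch.log(rate) - rate
               - torch.lgamma(counts + 1.0)).sum()  # log(counts!) via lgamma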

**Term 2: Coalescent prior.** The Kingman coalescent specifies the distribution
of coalescence times given :math:`N_e(t)`. For :math:`k` lineages at time
:math:`t`, the rate of coalescence is :math:`\binom{k}{2}/N_e(t)`. The
coalescent prior asks: are the proposed coalescence times consistent with the
proposed demography?
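
For a single tree and constant :math:`N_e`, this prior is a sum of
exponential waiting-time log-densities. A sketch with illustrative names:

.. code-block:: python

   import torch

   def kingman_log_prior(waits, ne):
       """Log-density of inter-coalescent waiting times, constant Ne.

       waits[j] is the waiting time while k = n - j lineages remain.
       """
       n = waits.shape[0] + 1
       k = torch.arange(n, 1, -1, dtype=waits.dtype)  # n, n-1, ..., 2
       rate = k * (k - 1) / (2.0 * ne)                # C(k, 2) / Ne
       return (torch.log(rate) - rate * waits).sum()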

**Term 3: Entropy.** The entropy of the variational posterior regularizes the
optimization. Without the entropy term, the optimizer would collapse :math:`q`
to a delta function at the maximum a posteriori (MAP) genealogy, losing all
uncertainty information. The entropy term encourages :math:`q` to remain
appropriately diffuse.
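
For a factored gamma posterior over node times, the entropy is available in
closed form, e.g. through ``torch.distributions``:

.. code-block:: python

   import torch

   def gamma_entropy(shape, rate):
       """Closed-form entropy of independent Gamma(shape, rate) node times."""
       return torch.distributions.Gamma(shape, rate).entropy().sum()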


Why No Simulations Appear
===========================

It is worth pausing to understand precisely why simulations are absent from the
ELBO. The training loop is:

1. Receive observed data :math:`\mathbf{D}` (a real genotype matrix).
2. Run :math:`\mathbf{D}` through the encoder to produce latent vectors.
3. Sample a genealogy :math:`\tau \sim q(\tau \mid \mathbf{D}, \phi)` from the
   variational posterior.
4. Evaluate :math:`\log P(\mathbf{D} \mid \tau, \mu)` using the analytical
   Poisson mutation model.
5. Evaluate :math:`\log P(\tau \mid N_e, \rho)` using the analytical Kingman
   coalescent.
6. Compute :math:`H[q]` from the parameters of :math:`q` (analytical for gamma,
   Bernoulli, categorical distributions).
7. Sum the three terms to get the ELBO.
8. Backpropagate and update :math:`\phi` and :math:`\theta`.

Steps 4 and 5 use the **same analytical formulas** derived in the Timepieces.
:ref:`tsdate <tsdate_timepiece>` derives the Poisson mutation likelihood.
:ref:`msprime <msprime_timepiece>` and :ref:`PSMC <psmc_timepiece>` derive the
coalescent prior. These formulas take a proposed genealogy :math:`\tau` and
return a number. They do not simulate anything.

The simulation-based approach (Mainspring) generates
:math:`(\mathbf{D}, \tau^*)` pairs from msprime, where :math:`\tau^*` is the
true genealogy, and trains the network to predict :math:`\tau^*` from
:math:`\mathbf{D}`. The likelihood-based approach (Escapement) takes only
:math:`\mathbf{D}`, proposes :math:`\tau`, and asks whether :math:`\tau`
explains :math:`\mathbf{D}` well under the coalescent model. The analytical
formulas replace the need for ground truth.

.. code-block:: python

   def compute_elbo(model, genotypes, temperature=1.0):
       """Single ELBO evaluation on observed data. No simulations."""
       out = model(genotypes, temperature=temperature)

       elbo = (out['log_lik']          # E_q[log P(D | tau, mu)]
               + out['log_prior_coal']  # E_q[log P(tau | Ne, rho)]
               + out['log_prior_break'] # E_q[log P(breaks | rho)]
               + out['entropy']         # H[q] (branch lengths + breakpoints)
               - out['log_q_topo'])     # -E_q[log q(topology)]

       return elbo


Why This Wasn't Done Before
=============================

If the ELBO requires only analytical formulas that have been known for decades,
why wasn't Escapement built earlier? Two technical barriers stood in the way.

Barrier 1: Discrete Topologies Are Not Differentiable
-------------------------------------------------------

The tree topology is a discrete object: each sample has a parent, chosen from a
finite set of possible ancestors. The ELBO cannot be maximized by gradient
descent if the objective is a function of discrete choices -- the gradient is
zero almost everywhere and undefined at the discontinuities.
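
The failure is easy to reproduce in a couple of lines:

.. code-block:: python

   import torch

   logits = torch.tensor([0.2, 1.5, -0.3], requires_grad=True)
   hard = torch.nn.functional.one_hot(logits.argmax(), num_classes=3)
   print(hard.grad_fn)  # None: argmax severs the graph; no gradient flows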

Classical variational inference for discrete latent variables uses
**coordinate ascent** (updating one variable at a time, holding others fixed) or
**EM** (computing posterior expectations analytically). These work for simple
models but scale poorly to the exponentially large space of tree topologies.

Escapement overcomes this barrier with the **Gumbel-softmax trick**
(Jang et al. 2017; Maddison et al. 2017). Instead of sampling a hard parent
assignment :math:`\pi_i = j` (one-hot), it samples a soft assignment
:math:`\tilde{\pi}_i \in \Delta^{n-1}` from a continuous relaxation of the
categorical distribution:

.. math::

   \tilde{\pi}_{ij} = \frac{\exp\bigl((\log \alpha_{ij} + g_j) / \lambda\bigr)}
   {\sum_{k \neq i} \exp\bigl((\log \alpha_{ik} + g_k) / \lambda\bigr)}, \qquad
   g_j \sim \text{Gumbel}(0, 1)

where :math:`\alpha_{ij}` are the attention weights (unnormalized
log-probabilities of :math:`i` choosing :math:`j` as parent), :math:`g_j` are
Gumbel noise samples, and :math:`\lambda > 0` is a temperature parameter
(written :math:`\lambda` rather than the conventional :math:`\tau`, which this
chapter reserves for the genealogy).

As :math:`\lambda \to 0`, :math:`\tilde{\pi}_i` approaches a one-hot vector
(hard assignment). As :math:`\lambda \to \infty`, it approaches a uniform
distribution. Crucially, the Gumbel-softmax is differentiable with respect
to :math:`\alpha_{ij}` for all :math:`\lambda > 0`, enabling gradient-based
optimization of the topology.
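
The temperature behavior is easy to verify with PyTorch's built-in
implementation (whose keyword for the temperature is ``tau``):

.. code-block:: python

   import torch
   import torch.nn.functional as F

   logits = torch.tensor([0.2, 1.5, -0.3])
   for lam in (10.0, 1.0, 0.1):
       print(lam, F.gumbel_softmax(logits, tau=lam))
   # lambda = 10 -> near-uniform; lambda = 0.1 -> near one-hot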

.. admonition:: The straight-through estimator

   During the forward pass, Escapement uses the **straight-through estimator**:
   it takes the argmax of the Gumbel-softmax output (producing a hard one-hot
   vector) but uses the soft Gumbel-softmax for the backward pass. This gives
   the best of both worlds: discrete topology for the likelihood evaluation,
   continuous gradients for optimization.

   .. code-block:: python

      import torch
      import torch.nn.functional as F

      def straight_through_gumbel(logits, temperature):
          soft = F.gumbel_softmax(logits, tau=temperature)  # soft sample
          hard = torch.zeros_like(soft).scatter_(-1, soft.argmax(-1, True), 1.0)
          return hard - soft.detach() + soft  # hard forward, soft backward


Barrier 2: The Coalescent Prior Over Full ARGs Is Intractable
---------------------------------------------------------------

The coalescent prior :math:`P(\tau \mid N_e, \rho)` over a full ancestral
recombination graph involves all local trees and their correlations induced by
recombination. The joint prior over the full ARG is:

.. math::

   P(\text{ARG} \mid N_e, \rho) = P(\mathcal{T}_1 \mid N_e)
   \prod_{\ell=2}^{T} P(\mathcal{T}_\ell \mid \mathcal{T}_{\ell-1}, N_e, \rho)

where :math:`T` is the number of local trees and the transition probability
:math:`P(\mathcal{T}_\ell \mid \mathcal{T}_{\ell-1})` involves the full
Sequentially Markov Coalescent (SMC) transition kernel. Evaluating this kernel
exactly requires marginalizing over all possible recombination events between
adjacent trees -- an intractable computation for trees with more than a few
leaves.

Escapement's solution is the **SMC factorization** from
:ref:`PSMC <psmc_timepiece>`. Under the SMC approximation, the transition
between adjacent trees depends only on the time at which recombination occurs
and the lineage that recombines. Escapement simplifies further, treating the
prior as approximately factored across local trees, with the breakpoint prior
providing the coupling:

.. math::

   P(\tau \mid N_e, \rho) \approx \prod_{\ell=1}^{T}
   P(\mathcal{T}_\ell \mid N_e) \cdot \prod_{\ell=1}^{L-1}
   P(b_\ell \mid \rho)

where :math:`P(\mathcal{T}_\ell \mid N_e)` is the Kingman coalescent prior for
local tree :math:`\ell` and :math:`P(b_\ell \mid \rho)` is the Bernoulli
breakpoint prior at position :math:`\ell`. This is an approximation -- it
ignores the correlation between adjacent tree topologies -- but it makes the
coalescent prior tractable and differentiable.
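
Under this factorization the log-prior is a plain sum, with the Bernoulli
term evaluated in expectation under :math:`q`'s breakpoint probabilities. A
sketch with illustrative names:

.. code-block:: python

   import torch

   def factored_log_prior(tree_log_priors, q_break, rho_per_site):
       """Per-tree Kingman terms plus expected Bernoulli breakpoint terms.

       tree_log_priors: log P(T_l | Ne), one entry per local tree
       q_break: q's breakpoint probability at each of the L - 1 positions
       """
       p = torch.as_tensor(rho_per_site)
       e_log_bern = q_break * torch.log(p) + (1 - q_break) * torch.log1p(-p)
       return tree_log_priors.sum() + e_log_bern.sum()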

.. admonition:: Connection to PSMC's SMC approximation

   :ref:`PSMC <psmc_timepiece>` built its inference on the Sequentially Markov
   Coalescent, a tractable approximation to the full coalescent with
   recombination. The key
   insight is that the genealogy at position :math:`\ell+1` depends on the
   genealogy at position :math:`\ell` only through a local modification: one
   lineage detaches and recoalesces. This Markov property makes HMM inference
   possible.

   Escapement uses the same Markov property but in a different way: instead of
   building a transition matrix (PSMC's approach), it factors the prior into
   per-tree and per-breakpoint terms. The per-tree prior is the Kingman
   coalescent for a single tree. The per-breakpoint prior is a Bernoulli
   reflecting the SMC recombination probability. This is a coarser
   approximation than the full SMC, but it is fast, differentiable, and
   sufficient for the variational objective.


Connection to tsdate's Variational Gamma Method
==================================================

Escapement's variational inference is most closely related to the variational
gamma method in :ref:`tsdate <tsdate_timepiece>`. Both methods:

1. Define a variational posterior over coalescence times.
2. Use the Poisson mutation likelihood as one term of the ELBO.
3. Use the coalescent prior as another term.
4. Optimize by maximizing the ELBO.

The key differences:

.. list-table:: tsdate vs. Escapement variational inference
   :header-rows: 1
   :widths: 25 37 38

   * - Aspect
     - :ref:`tsdate <tsdate_timepiece>`
     - Escapement
   * - Variational family
     - Gamma distributions (hand-chosen, analytical updates)
     - Neural network (learned, flexible, but approximate)
   * - Topology
     - Fixed (from :ref:`tsinfer <tsinfer_timepiece>`)
     - Jointly inferred (Gumbel-softmax)
   * - Demography
     - Not inferred (assumed known)
     - Jointly inferred (learnable :math:`N_e(t)`)
   * - Update rule
     - Expectation Propagation (closed-form message passing)
     - Gradient descent (backpropagation through the ELBO)
   * - Scalability
     - Millions of nodes (fast EP)
     - Thousands of nodes (GPU-bound optimization)
   * - Amortization
     - None (per-dataset optimization)
     - Partial (encoder shares parameters across windows)

.. math::

   \underbrace{\text{tsdate}}_{\text{fixed topology, analytical updates}}
   \;\xrightarrow{\text{add neural encoder, Gumbel-softmax topology}}\;
   \underbrace{\text{Escapement}}_{\text{joint topology + times, learned updates}}

tsdate can be viewed as a special case of Escapement where the topology is
fixed, the variational family is restricted to factored gamma distributions, and
the updates are derived analytically rather than learned. Escapement trades
tsdate's elegance and scalability for flexibility and the ability to infer
topology jointly with coalescence times.

.. admonition:: Why not just use tsdate?

   tsdate is excellent for what it does: dating nodes on a fixed tree sequence.
   But it cannot infer topology (it requires tsinfer as a preprocessor), it
   cannot infer demography (it assumes known :math:`N_e(t)`), and its factored
   gamma posterior cannot capture correlations between node times on different
   trees.

   Escapement addresses all three limitations -- at the cost of requiring GPU
   optimization, producing an approximate (not exact) posterior, and scaling to
   fewer samples. The right choice depends on the scientific question and the
   available compute.


The Complete Variational Objective
====================================

Combining all terms, Escapement's ELBO for a single genotype matrix
:math:`\mathbf{D}` is:

.. math::

   \text{ELBO}(\phi, \theta) =
   \underbrace{\sum_{\ell=1}^{L} \sum_{i=1}^{n}
   \log P(d_{i,\ell} \mid \tau_\ell, \mu)}_{\text{mutation log-likelihood}}
   + \underbrace{\sum_{\ell=1}^{T} \log P(\mathcal{T}_\ell \mid N_e)
   + \sum_{\ell=1}^{L-1} \log P(b_\ell \mid \rho)}_{\text{coalescent log-prior}}
   + \underbrace{H_{\text{topo}}[q] + H_{\text{branch}}[q]
   + H_{\text{break}}[q]}_{\text{variational entropy}}

where the expectation over :math:`q` is approximated by a single Monte Carlo
sample :math:`\tau \sim q(\tau \mid \mathbf{D}, \phi)` (the
reparameterization trick makes this a low-variance estimator).
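
A quick check that the reparameterized gamma draw carries gradients back to
the variational parameters:

.. code-block:: python

   import torch

   shape = torch.tensor([2.0, 3.0], requires_grad=True)
   rate = torch.tensor([1.0, 0.5], requires_grad=True)
   t = torch.distributions.Gamma(shape, rate).rsample()  # pathwise sample
   t.sum().backward()
   print(shape.grad is not None, rate.grad is not None)  # True True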

The parameters being optimized are:

- :math:`\phi`: all neural network weights (encoder + variational posterior
  heads)
- :math:`\theta`: the demographic parameters (:math:`N_e(t)` values on a
  piecewise-constant grid, or neural spline parameters)

The mutation rate :math:`\mu` and recombination rate :math:`\rho` are treated
as known constants (from independent estimates), though they could in principle
be optimized jointly.
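
The learnable :math:`N_e(t)` on a piecewise-constant grid can be expressed in
a handful of lines (a sketch, not Escapement's actual parameterization):

.. code-block:: python

   import math

   import torch

   class PiecewiseNe(torch.nn.Module):
       """Learnable piecewise-constant Ne(t) with fixed epoch boundaries."""

       def __init__(self, edges, init_ne=1e4):
           super().__init__()
           self.register_buffer('edges', torch.as_tensor(edges))  # K-1 bounds
           self.log_ne = torch.nn.Parameter(                 # log keeps Ne > 0
               torch.full((len(edges) + 1,), math.log(init_ne)))

       def forward(self, t):
           """Ne at each time in tensor t, differentiable in log_ne."""
           return self.log_ne[torch.bucketize(t, self.edges)].exp()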

.. code-block:: python

   import torch

   def training_step(model, genotypes, optimizer, temperature):
       """One gradient step maximizing the ELBO on observed data."""
       loss = model.loss(genotypes, temperature=temperature)  # -ELBO
       optimizer.zero_grad()
       loss.backward()
       torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
       optimizer.step()
       return loss.item()
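
A typical outer loop anneals the temperature toward its hard-assignment
limit as training proceeds; the schedule below is illustrative, not
Escapement's default:

.. code-block:: python

   for step in range(10_000):                  # illustrative step count
       temperature = max(0.1, 0.999 ** step)   # geometric decay to a floor
       loss = training_step(model, genotypes, optimizer, temperature)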

The derivation is complete. Every term in the ELBO corresponds to an equation
derived in a Timepiece. The neural network's only role is to parameterize the
variational posterior :math:`q` -- to propose genealogies that explain the
observed data well under the coalescent model. The math is the same. The
optimization is new.
