.. _escapement_likelihood:

===============================
The Differentiable Likelihood
===============================

   *The balance spring is pure physics. No mechanism, no escapement, no gear
   train -- just the restoring force of an elastic strip, precisely calibrated
   to oscillate at a known frequency. It is the simplest component in the watch,
   and the most critical.*

Module 3 of Escapement contains no neural networks. It is the mathematical core
of the system: given a sampled genealogy :math:`\tau \sim q(\tau \mid
\mathbf{D}, \phi)`, it computes three scalar quantities whose sum is the ELBO.
These three quantities -- the mutation log-likelihood, the genealogy log-prior
(coalescent times plus recombination breakpoints), and the variational
entropy -- are the same equations derived across every Timepiece in this book,
now assembled into a single differentiable objective.

This chapter derives each term from first principles, connects each to its
source Timepiece, shows how gradients flow through the reparameterization trick,
and provides the complete PyTorch implementation.


Mutation Log-Likelihood
=========================

The mutation log-likelihood answers: given the proposed genealogy :math:`\tau`,
how well does it explain the observed data :math:`\mathbf{D}`?

Derivation
-----------

Under the infinite-sites mutation model, mutations accumulate on each edge of
the genealogy as a Poisson process with rate :math:`\mu` per base pair per
generation. For sample :math:`i` with assigned parent sample :math:`j`, whose
lineages coalesce at TMRCA :math:`t_{i,\ell}` at position :math:`\ell`, over a
segment spanning :math:`s` base pairs, the expected number of mutations on the
two lineages back to the MRCA is:

.. math::

   \lambda_{i,\ell} = 2 \mu \cdot t_{i,\ell} \cdot s

The factor of 2 accounts for mutations on both the :math:`i \to \text{MRCA}`
and :math:`j \to \text{MRCA}` branches. The probability that :math:`i` and
:math:`j` differ at this position (at least one mutation occurred) is:

.. math::

   P(d_{i,\ell} \neq d_{j,\ell} \mid t_{i,\ell}) =
   1 - e^{-\lambda_{i,\ell}} = 1 - e^{-2\mu \cdot t_{i,\ell} \cdot s}

This is the infinite-sites mismatch probability, a two-allele analogue of the
Jukes-Cantor correction and a standard approximation used in
:ref:`tsdate <tsdate_timepiece>` and :ref:`ARGweaver <argweaver_timepiece>`.

The log-likelihood of a single observation :math:`d_{i,\ell}` given the
proposed parent :math:`j` and coalescence time :math:`t_{i,\ell}` is a
Bernoulli log-likelihood on the mismatch indicator:

.. math::

   \log P(d_{i,\ell} \mid j, t_{i,\ell}) = m_{i,\ell} \log p_{i,\ell}
   + (1 - m_{i,\ell}) \log(1 - p_{i,\ell})

where :math:`m_{i,\ell} = |d_{i,\ell} - d_{j,\ell}|` is the observed
mismatch (0 or 1) and :math:`p_{i,\ell} = 1 - e^{-2\mu t_{i,\ell} s}` is the
predicted mismatch probability.

Soft Parent Assignments
-------------------------

Because the topology is represented as soft parent probabilities
:math:`\alpha_{ij}^\ell` (from the Gumbel-softmax), the "parent genotype" is
a probability-weighted average:

.. math::

   \hat{d}_{j(i),\ell} = \sum_{j \neq i} \alpha_{ij}^\ell \cdot d_{j,\ell}

and the "mismatch" is the soft absolute difference
:math:`\tilde{m}_{i,\ell} = |d_{i,\ell} - \hat{d}_{j(i),\ell}|`. This allows
gradients to flow through the topology via the Gumbel-softmax.

.. code-block:: python

   import torch

   def pairwise_mutation_loglik(genotypes, parent_probs, branch_times,
                                mu, span=1.0):
       # genotypes: (B, N, L) 0/1 alleles; parent_probs: (B, L, N, N) soft
       # parent assignments; branch_times: (B, N, L) coalescence times.
       B, N, L = genotypes.shape
       # Soft parent genotype: probability-weighted average over candidates.
       parent_geno = torch.einsum("blij,bjl->bil", parent_probs, genotypes)
       obs_diff = (genotypes - parent_geno).abs()

       # Expected mutations 2*mu*t*s; clamp keeps exp(-rate) from underflow.
       rate = (2.0 * mu * branch_times * span).clamp(max=20.0)
       p_mismatch = (1.0 - torch.exp(-rate)).clamp(1e-7, 1.0 - 1e-7)

       # Bernoulli log-likelihood of the (soft) mismatch indicator.
       loglik = (obs_diff * torch.log(p_mismatch)
                 + (1.0 - obs_diff) * torch.log(1.0 - p_mismatch))
       return loglik.sum(dim=(1, 2))

The ``clamp(max=20.0)`` on the rate prevents ``exp(-rate)`` from underflowing
to zero for very long branches, which would pin ``p_mismatch`` at exactly 1;
the ``clamp(1e-7, 1.0 - 1e-7)`` on the mismatch probability prevents
``log(0)`` in either term.
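
As a quick shape check, a minimal driver (the sizes and random inputs are
illustrative only):

.. code-block:: python

   # Hypothetical sizes: B=2 batches, N=4 samples, L=50 sites.
   B, N, L = 2, 4, 50
   genotypes = torch.randint(0, 2, (B, N, L)).float()
   logits = torch.randn(B, L, N, N)
   logits.diagonal(dim1=-2, dim2=-1).fill_(-1e9)   # forbid self-parenting
   parent_probs = torch.softmax(logits, dim=-1)
   branch_times = torch.exp(torch.randn(B, N, L))  # positive times

   print(pairwise_mutation_loglik(genotypes, parent_probs, branch_times,
                                  mu=1e-4).shape)  # torch.Size([2])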

.. admonition:: Connection to tsdate's mutation likelihood

   :ref:`tsdate <tsdate_timepiece>` derives the same Poisson mutation model
   but applies it to a fixed tree topology with known parent assignments. The
   edge log-likelihood in tsdate is:

   .. math::

      \log P(\text{mutations on edge } e) = m_e \log(\mu t_e s_e)
      - \mu t_e s_e - \log(m_e!)

   Escapement uses a simpler Bernoulli approximation (mismatch vs. no mismatch)
   rather than counting mutations per edge, because the parent assignment is
   soft (probabilistic) rather than hard (deterministic). When the Gumbel-softmax
   temperature is low and the parent assignment approaches one-hot, the two
   formulations converge.
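
For reference, a direct transcription of this Poisson edge log-likelihood is a
one-liner with hard mutation counts. The sketch below is for comparison only;
``mutation_counts``, ``edge_times``, and ``edge_spans`` are illustrative names,
and nothing in Escapement calls this function:

.. code-block:: python

   def poisson_edge_loglik(mutation_counts, mu, edge_times, edge_spans):
       # log P(m_e) = m_e log(mu t_e s_e) - mu t_e s_e - log(m_e!)
       lam = (mu * edge_times * edge_spans).clamp(min=1e-12)
       return (mutation_counts * torch.log(lam) - lam
               - torch.lgamma(mutation_counts + 1.0))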

Auxiliary Pairwise-Difference Likelihood
-------------------------------------------

In addition to the primary parent-based likelihood, Escapement includes an
auxiliary pairwise-difference log-likelihood that provides a global signal
for branch-length calibration:

.. math::

   \log P_{\text{pair}}(\mathbf{D} \mid t) = \sum_{i < j} \sum_{\ell=1}^{L}
   \left[m_{ij,\ell} \log p_{ij,\ell} + (1 - m_{ij,\ell})
   \log(1 - p_{ij,\ell})\right]

where :math:`m_{ij,\ell} = |d_{i,\ell} - d_{j,\ell}|`,
:math:`p_{ij,\ell} = 1 - e^{-2\mu \hat{t}_{ij,\ell} s}`, and the pairwise
TMRCA is estimated as :math:`\hat{t}_{ij,\ell} = (t_{i,\ell} + t_{j,\ell})/2`.
This term is weighted by a small coefficient (typically 0.1) and provides a
signal that does not depend on the topology -- only on the overall scale of
branch lengths.

.. code-block:: python

   def pairwise_diff_loglik(genotypes, branch_times, mu, span=1.0):
       B, N, L = genotypes.shape
       # All-pairs mismatch indicators: (B, N, N, L).
       diff = (genotypes.unsqueeze(2) - genotypes.unsqueeze(1)).abs()
       # Pairwise TMRCA estimate: (t_i + t_j) / 2.
       t_pair = (branch_times.unsqueeze(2) + branch_times.unsqueeze(1)) / 2.0
       rate_pair = (2.0 * mu * t_pair * span).clamp(max=20.0)
       p_mismatch = (1.0 - torch.exp(-rate_pair)).clamp(1e-7, 1.0 - 1e-7)
       loglik = (diff * torch.log(p_mismatch)
                 + (1.0 - diff) * torch.log(1.0 - p_mismatch))
       # Count each unordered pair once via the strict upper triangle.
       mask = torch.triu(torch.ones(N, N, device=genotypes.device), diagonal=1)
       return (loglik * mask.unsqueeze(0).unsqueeze(-1)).sum(dim=(1, 2, 3))


Coalescent Log-Prior
======================

The coalescent log-prior answers: is the proposed genealogy plausible under the
coalescent model with the proposed :math:`N_e(t)`?

Constant :math:`N_e`
-----------------------

For a pair of lineages under the Kingman coalescent with constant effective
population size :math:`N_e`, the TMRCA is exponentially distributed:

.. math::

   T \sim \text{Exp}\!\left(\frac{1}{2N_e}\right), \qquad
   p(t \mid N_e) = \frac{1}{2N_e} e^{-t/(2N_e)}

The log-prior for a single coalescence time :math:`t` is:

.. math::

   \log P(t \mid N_e) = \log\frac{1}{2N_e} - \frac{t}{2N_e}

For a complete genealogy with :math:`n` samples, there are :math:`n - 1`
coalescence events. Under Escapement's factored approximation, the prior is
the product over all sample-position pairs:

.. math::

   \log P(\tau \mid N_e) = \sum_{i=1}^{n} \sum_{\ell=1}^{L}
   \left[\log\frac{1}{2N_e} - \frac{t_{i,\ell}}{2N_e}\right]

.. code-block:: python

   import math

   def coalescent_log_prior(branch_times, Ne=10000.0):
       # Exponential(rate = 1/(2*Ne)) log-density, summed over samples, sites.
       rate = 1.0 / (2.0 * Ne)
       log_prior = math.log(rate) - rate * branch_times
       return log_prior.sum(dim=(1, 2))

.. admonition:: Connection to msprime and PSMC

   The exponential distribution of pairwise coalescence times is derived in
   :ref:`msprime <msprime_timepiece>` as the foundation of the Kingman
   coalescent. :ref:`PSMC <psmc_timepiece>` uses the same distribution but
   discretizes it into time intervals for the HMM. Escapement uses the
   continuous distribution directly, avoiding discretization artifacts.
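
As a sanity check, the constant-:math:`N_e` prior should agree with PyTorch's
built-in exponential distribution (shapes are illustrative):

.. code-block:: python

   Ne = 10000.0
   t = torch.rand(1, 4, 50) * 4 * Ne + 1.0
   expected = torch.distributions.Exponential(1.0 / (2.0 * Ne)).log_prob(t)
   assert torch.allclose(coalescent_log_prior(t, Ne=Ne),
                         expected.sum(dim=(1, 2)))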

Piecewise-Constant :math:`N_e(t)`
------------------------------------

For realistic demography, :math:`N_e(t)` varies through time. With
piecewise-constant :math:`N_e(t)` on a grid
:math:`0 = t_0 < t_1 < \cdots < t_K`, the coalescent rate in interval
:math:`k` is :math:`\lambda_k = 1/(2 N_e^{(k)})`. The log-prior of a
coalescence time :math:`t` involves integrating the hazard:

.. math::

   \log P(t \mid N_e(\cdot)) = \log \lambda(t) - \int_0^t \lambda(s)\, ds

where :math:`\lambda(t)` is the instantaneous rate at the coalescence time
and the integral is the cumulative hazard:

.. math::

   \int_0^t \lambda(s)\, ds = \sum_{k=0}^{K-1} \lambda_k \cdot
   \min(\Delta t_k,\; \max(0,\; t - t_k))

Here :math:`\Delta t_k = t_{k+1} - t_k` is the width of interval :math:`k`.
The coalescence time :math:`t` contributes :math:`\lambda_k \Delta t_k`
for each interval it fully spans, and a partial contribution
:math:`\lambda_k (t - t_k)` for the interval in which it falls.

.. code-block:: python

   def coalescent_log_prior_variable_Ne(branch_times, Ne_fn, time_grid):
       # Ne_fn: (K,) piecewise-constant Ne values; time_grid: (K,) interval
       # start times, with time_grid[0] == 0.
       K = Ne_fn.shape[0]
       Ne_clamped = Ne_fn.clamp(min=1.0)
       cumulative_hazard = torch.zeros_like(branch_times)
       t_remaining = branch_times.clone()
       instantaneous_rate = torch.zeros_like(branch_times)

       for k in range(K):
           rate_k = 1.0 / (2.0 * Ne_clamped[k])
           if k < K - 1:
               dt_interval = time_grid[k + 1] - time_grid[k]
           else:
               # The final interval extends beyond every sampled time.
               dt_interval = t_remaining.max().detach().item() + 1.0
           # Time each lineage spends inside interval k.
           dt_used = torch.clamp(t_remaining, max=dt_interval)
           cumulative_hazard = cumulative_hazard + rate_k * dt_used
           # Record lambda(t) for times that terminate in this interval.
           in_this_bin = (t_remaining > 0) & (t_remaining <= dt_interval)
           instantaneous_rate = torch.where(
               in_this_bin, rate_k, instantaneous_rate)
           t_remaining = (t_remaining - dt_used).clamp(min=0.0)

       instantaneous_rate = instantaneous_rate.clamp(min=1e-30)
       log_prior = torch.log(instantaneous_rate) - cumulative_hazard
       return log_prior.sum(dim=(1, 2))
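
With a single interval, the piecewise prior reduces to the
constant-:math:`N_e` prior, which gives a convenient regression test (shapes
are illustrative):

.. code-block:: python

   t = torch.rand(1, 4, 50) * 50000.0 + 1.0
   assert torch.allclose(
       coalescent_log_prior_variable_Ne(
           t, Ne_fn=torch.tensor([10000.0]), time_grid=torch.tensor([0.0])),
       coalescent_log_prior(t, Ne=10000.0))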

.. admonition:: Gradient flow through N_e

   The coalescent log-prior is differentiable with respect to
   :math:`N_e^{(k)}` (through the ``Ne_fn`` tensor). This is what enables joint
   optimization of the demography: when the ELBO is maximized, the gradients
   through the coalescent prior push :math:`N_e(t)` toward values consistent
   with the proposed coalescence times, while the gradients through the mutation
   likelihood push the coalescence times toward values consistent with the
   observed data. The two signals converge to a self-consistent estimate of
   both the genealogy and the demography.
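
A minimal sketch of that joint optimization (the grid, sizes, and the
``log_Ne`` parameterization are assumptions for illustration), showing that
gradients reach the demography:

.. code-block:: python

   log_Ne = torch.nn.Parameter(torch.full((4,), 9.2))  # exp(9.2) ~ 10^4
   time_grid = torch.tensor([0.0, 1000.0, 5000.0, 20000.0])
   t = torch.rand(1, 4, 50) * 40000.0 + 1.0

   log_prior = coalescent_log_prior_variable_Ne(t, log_Ne.exp(), time_grid)
   (-log_prior.mean()).backward()
   print(log_Ne.grad)  # nonzero: every interval the times cross gets a signal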


Breakpoint Log-Prior
======================

Under the Sequential Markov Coalescent, recombination breakpoints between
adjacent sites occur as a Poisson process with rate :math:`\rho` per base pair
per generation. For a segment of span :math:`s` base pairs, the probability of
at least one recombination event (and hence a tree change) is:

.. math::

   P(\text{breakpoint}) = 1 - e^{-\rho \cdot s}

The breakpoint log-prior for the predicted breakpoint probabilities
:math:`b_\ell` is:

.. math::

   \log P(b \mid \rho) = \sum_{\ell=1}^{L-1} \left[
   b_\ell \log(1 - e^{-\rho s}) + (1 - b_\ell) \log(e^{-\rho s})\right]

.. code-block:: python

   def breakpoint_log_prior(break_probs, rho, span=1.0):
       # Prior Bernoulli parameter from the SMC recombination rate.
       p_break = max(1.0 - math.exp(-rho * span), 1e-8)
       p_no_break = 1.0 - p_break
       bp = break_probs.clamp(1e-8, 1.0 - 1e-8)
       # Expected log-prior under the predicted breakpoint probabilities.
       log_prior = bp * math.log(p_break) + (1.0 - bp) * math.log(p_no_break)
       return log_prior.sum(dim=1)

This term acts as a regularizer: it penalizes predicted breakpoint probabilities
that are inconsistent with the recombination rate. For typical human
recombination rates (:math:`\rho \approx 10^{-8}` per bp per generation) and
typical inter-site spacing (:math:`s \approx 100` bp), the prior breakpoint
probability is very low (:math:`\approx 10^{-6}`), encouraging the model to
predict few breakpoints.
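
Plugging the numbers in confirms the scale:

.. code-block:: python

   rho, span = 1e-8, 100.0
   print(f"{1.0 - math.exp(-rho * span):.2e}")  # ~1.00e-06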


Entropy Decomposition
=======================

The entropy :math:`H[q]` of the variational posterior decomposes into three
independent terms due to the mean-field factorization:

.. math::

   H[q] = H_{\text{topo}}[q] + H_{\text{branch}}[q] + H_{\text{break}}[q]

Topology Entropy (Categorical)
---------------------------------

The topology at each position is a categorical distribution over parent
assignments:

.. math::

   H_{\text{topo}}[q] = -\sum_{\ell=1}^{L} \sum_{i=1}^{n} \sum_{j \neq i}
   \alpha_{ij}^\ell \log \alpha_{ij}^\ell

This is computed by the topology head as ``chosen_log_probs``. In the ELBO, the
topology term appears as :math:`-\mathbb{E}_q[\log q(\text{topology})]`:

.. math::

   -\mathbb{E}_q[\log q(\pi)] = -\sum_{\ell, i}
   \sum_{j \neq i} \alpha_{ij}^\ell \log \alpha_{ij}^\ell
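
A minimal implementation, assuming the same ``(B, L, N, N)`` layout for
``parent_probs`` as in the mutation likelihood:

.. code-block:: python

   def topology_entropy(parent_probs, eps=1e-8):
       # -sum over (l, i, j) of alpha * log(alpha), per batch element
       return -(parent_probs
                * parent_probs.clamp(min=eps).log()).sum(dim=(1, 2, 3))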

Branch-Length Entropy (Log-Normal)
-------------------------------------

The branch lengths are log-normally distributed with parameters
:math:`\mu_{i,\ell}` and :math:`\sigma_{i,\ell}`. The log-normal entropy has a
closed-form expression:

.. math::

   H[\text{LogNormal}(\mu, \sigma)] = \mu + \frac{1}{2}
   + \log \sigma + \frac{1}{2}\log(2\pi)

The total branch-length entropy sums over all sample-position pairs:

.. math::

   H_{\text{branch}}[q] = \sum_{i=1}^{n} \sum_{\ell=1}^{L}
   \left[\mu_{i,\ell} + \frac{1}{2} + \log \sigma_{i,\ell}
   + \frac{1}{2}\log(2\pi)\right]

.. code-block:: python

   def lognormal_entropy(log_mean, log_std):
       # log_mean = mu, log_std = log(sigma) of the underlying normal
       return log_mean + 0.5 + log_std + 0.5 * math.log(2 * math.pi)
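
The closed form can be checked against PyTorch's ``LogNormal.entropy()``
(note that ``lognormal_entropy`` takes :math:`\log\sigma`, not :math:`\sigma`):

.. code-block:: python

   mu, sigma = torch.randn(5), torch.rand(5) + 0.1
   assert torch.allclose(
       lognormal_entropy(mu, sigma.log()),
       torch.distributions.LogNormal(mu, sigma).entropy())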

.. admonition:: Why log-normal and not gamma?

   :ref:`tsdate <tsdate_timepiece>` and :ref:`Gamma-SMC <gamma_smc_timepiece>`
   use gamma distributions for coalescence-time posteriors, which is the
   conjugate choice for the exponential coalescent prior. Escapement uses
   log-normal instead, for a practical reason: the log-normal
   reparameterization trick (:math:`t = \exp(\mu + \sigma\epsilon)`,
   :math:`\epsilon \sim \mathcal{N}(0,1)`) produces lower-variance gradient
   estimates than the gamma reparameterization.

   The gamma distribution *can* be reparameterized (Figurnov et al. 2018), but
   the resulting gradient estimates have higher variance, especially for small
   shape parameters. Since Escapement's optimization is already challenging
   (discrete topology, multi-modal landscape), we prioritize gradient quality
   over distributional faithfulness.

Breakpoint Entropy (Bernoulli)
---------------------------------

The breakpoint at each position is Bernoulli-distributed:

.. math::

   H[\text{Bernoulli}(b_\ell)] = -b_\ell \log b_\ell
   - (1 - b_\ell) \log(1 - b_\ell)

The total breakpoint entropy is:

.. math::

   H_{\text{break}}[q] = \sum_{\ell=1}^{L-1} H[\text{Bernoulli}(b_\ell)]

.. code-block:: python

   def bernoulli_entropy(p, eps=1e-8):
       p = p.clamp(eps, 1.0 - eps)
       return -(p * p.log() + (1.0 - p) * (1.0 - p).log())
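
Putting the three pieces together (a sketch: ``topology_entropy`` is the
helper sketched earlier in this section, and shapes follow this chapter's
conventions):

.. code-block:: python

   def total_entropy(parent_probs, log_mean, log_std, break_probs):
       # H[q] = H_topo + H_branch + H_break, one scalar per batch element
       return (topology_entropy(parent_probs)
               + lognormal_entropy(log_mean, log_std).sum(dim=(1, 2))
               + bernoulli_entropy(break_probs).sum(dim=1))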


Gradient Flow via the Reparameterization Trick
=================================================

The ELBO is an expectation over the variational posterior, plus the entropy,
which is available in closed form:

.. math::

   \text{ELBO} = \mathbb{E}_{q(\tau \mid \mathbf{D}, \phi)}\!\left[
   f(\tau, \theta)\right] + H[q]

where :math:`f(\tau, \theta) = \log P(\mathbf{D} \mid \tau, \mu)
+ \log P(\tau \mid N_e, \rho)`. Only the expectation requires Monte Carlo
estimation. To optimize :math:`\phi` by gradient descent, we need
:math:`\nabla_\phi \text{ELBO}`. The reparameterization trick rewrites the
expectation so that the randomness is independent of :math:`\phi`:

**Branch lengths.** :math:`t_{i,\ell} = \exp(\mu_{i,\ell} + \sigma_{i,\ell}
\cdot \epsilon)`, where :math:`\epsilon \sim \mathcal{N}(0,1)` is fixed noise.
The gradient :math:`\nabla_\phi t_{i,\ell}` flows through :math:`\mu` and
:math:`\sigma` to the encoder.
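
In code, the branch-length sample is two lines (a sketch, with ``log_std``
holding :math:`\log\sigma` as in the entropy section):

.. code-block:: python

   def sample_branch_times(log_mean, log_std):
       # eps is sampled independently of the encoder parameters, so
       # dt/dmu and dt/dsigma are ordinary deterministic gradients.
       eps = torch.randn_like(log_mean)
       return torch.exp(log_mean + log_std.exp() * eps)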

**Breakpoints.** Bernoulli variables are discrete, so the standard
reparameterization trick does not apply. Instead, Escapement uses the
**Gumbel-sigmoid** relaxation:

.. math::

   \tilde{b}_\ell = \sigma\!\left(\frac{\log \frac{b_\ell}{1 - b_\ell}
   + g_1 - g_2}{\tau}\right), \qquad g_1, g_2 \sim \text{Gumbel}(0, 1)

where :math:`\tau` here denotes the relaxation temperature, not the genealogy.
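
One common implementation draws the Gumbel difference directly as logistic
noise, since :math:`g_1 - g_2 \sim \text{Logistic}(0, 1)` (a sketch; argument
names are illustrative):

.. code-block:: python

   def gumbel_sigmoid(logits, temperature):
       # logit(u) for u ~ Uniform(0, 1) is a Logistic(0, 1) sample,
       # equal in distribution to the difference of two Gumbels.
       u = torch.rand_like(logits).clamp(1e-8, 1.0 - 1e-8)
       return torch.sigmoid((logits + torch.log(u) - torch.log1p(-u))
                            / temperature)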

**Topology.** The Gumbel-softmax trick (described in
:ref:`escapement_variational`) provides gradients through the discrete parent
assignments. The straight-through estimator uses hard assignments in the
forward pass and soft assignments in the backward pass.
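
PyTorch ships this combination as ``torch.nn.functional.gumbel_softmax``; a
one-line sketch (``parent_logits`` and ``temperature`` are illustrative names):

.. code-block:: python

   import torch.nn.functional as F

   # hard=True: one-hot parents forward, soft gradients backward (the STE).
   parents = F.gumbel_softmax(parent_logits, tau=temperature, hard=True,
                              dim=-1)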

.. code-block:: text

   Gradient flow in Escapement:

   ┌──────────────────────────────────────────────────────────────┐
   │                                                              │
   │  ELBO = log P(D|τ,μ) + log P(τ|Ne,ρ) + H[q]                  │
   │    │                                                         │
   │    ├── ∂/∂t  (branch lengths)   ← reparameterization trick   │
   │    │    └── ∂/∂μ, ∂/∂σ          ← backprop through encoder   │
   │    │                                                         │
   │    ├── ∂/∂π  (topology)         ← Gumbel-softmax STE         │
   │    │    └── ∂/∂α  (logits)      ← backprop through encoder   │
   │    │                                                         │
   │    ├── ∂/∂b  (breakpoints)      ← Gumbel-sigmoid             │
   │    │    └── ∂/∂logit            ← backprop through encoder   │
   │    │                                                         │
   │    └── ∂/∂Ne  (demography)      ← direct gradient            │
   │         └── through coalescent prior                         │
   │                                                              │
   └──────────────────────────────────────────────────────────────┘

All three sources of randomness (Gaussian for branch lengths, Gumbel for
topology, a Gumbel difference for breakpoints) are sampled independently of
:math:`\phi`, so the gradient estimator
:math:`\nabla_\phi f(\tau(\phi, \epsilon), \theta)` is unbiased for the
relaxed objective and has low variance; the straight-through estimator trades
a small bias for hard samples in the forward pass. This is the same
reparameterization trick used in variational autoencoders (Kingma & Welling
2014), extended to the structured latent space of tree sequences.
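
Assembled end to end, the per-batch ELBO is a direct sum of the functions
defined in this chapter (a sketch: the argument layout, the constant
:math:`N_e`, and the 0.1 auxiliary weight are illustrative defaults):

.. code-block:: python

   def elbo(genotypes, parent_probs, branch_times, break_probs,
            log_mean, log_std, mu, rho, Ne=10000.0, pair_weight=0.1):
       loglik = (pairwise_mutation_loglik(genotypes, parent_probs,
                                          branch_times, mu)
                 + pair_weight * pairwise_diff_loglik(genotypes,
                                                      branch_times, mu))
       log_prior = (coalescent_log_prior(branch_times, Ne)
                    + breakpoint_log_prior(break_probs, rho))
       entropy = total_entropy(parent_probs, log_mean, log_std, break_probs)
       return (loglik + log_prior + entropy).mean()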


Connection to Source Timepieces
==================================

Every component of the differentiable likelihood can be traced to a specific
Timepiece:

.. list-table:: Likelihood components and their Timepiece origins
   :header-rows: 1
   :widths: 20 30 25 25

   * - Component
     - Formula
     - Source Timepiece
     - What Escapement changes
   * - Poisson mutation model
     - :math:`P(\text{diff}) = 1 - e^{-2\mu t s}`
     - :ref:`tsdate <tsdate_timepiece>` (mutation likelihood)
     - Soft parent assignments instead of fixed topology
   * - Exponential coalescent
     - :math:`T \sim \text{Exp}(1/(2N_e))`
     - :ref:`msprime <msprime_timepiece>` (coalescent theory)
     - Continuous :math:`N_e(t)`, jointly optimized
   * - Piecewise-constant hazard
     - :math:`\int_0^t \lambda(s) ds`
     - :ref:`PSMC <psmc_timepiece>` (discretized coalescent)
     - Direct integration, no discretization of :math:`t`
   * - SMC breakpoint model
     - :math:`P(\text{break}) = 1 - e^{-\rho s}`
     - :ref:`PSMC <psmc_timepiece>` (SMC approximation)
     - Learned breakpoint detection, not HMM transitions
   * - Gamma/log-normal times
     - :math:`t \sim \text{LogNormal}(\mu, \sigma)`
     - :ref:`tsdate <tsdate_timepiece>` / :ref:`Gamma-SMC <gamma_smc_timepiece>`
     - Neural parameterization instead of hand-derived EP
   * - Factored prior
     - :math:`P(\tau) \approx \prod_\ell P(\mathcal{T}_\ell)`
     - :ref:`PSMC <psmc_timepiece>` (SMC factorization)
     - Same approximation, used in variational objective

The differentiable likelihood is not a new population-genetic model. It is a
**compilation** of existing models into a form that supports automatic
differentiation. The Timepieces provide the equations; Escapement provides the
gradient infrastructure.
