The Differentiable Likelihood

The balance spring is pure physics. No mechanism, no escapement, no gear train – just the restoring force of an elastic strip, precisely calibrated to oscillate at a known frequency. It is the simplest component in the watch, and the most critical.

Module 3 of Escapement contains no neural networks. It is the mathematical core of the system: given a sampled genealogy \(\tau \sim q(\tau \mid \mathbf{D}, \phi)\), it computes three scalar quantities whose sum is the ELBO. These three quantities – the mutation log-likelihood, the genealogy log-prior (with coalescent and breakpoint terms), and the variational entropy – are the same equations derived across every Timepiece in this book, now assembled into a single differentiable objective.

This chapter derives each term from first principles, connects each to its source Timepiece, shows how gradients flow through the reparameterization trick, and provides the complete PyTorch implementation.

Mutation Log-Likelihood

The mutation log-likelihood answers: given the proposed genealogy \(\tau\), how well does it explain the observed data \(\mathbf{D}\)?

Derivation

Under the infinite-sites mutation model, mutations accumulate on each edge of the genealogy as a Poisson process with rate \(\mu\) per base pair per generation. For sample \(i\) and its proposed parent \(j\), with TMRCA \(t_{i,\ell}\) at position \(\ell\) and a segment spanning \(s\) base pairs, the expected number of mutations on the two lineages from the leaves up to their MRCA is:

\[\lambda_{i,\ell} = 2 \mu \cdot t_{i,\ell} \cdot s\]

The factor of 2 accounts for mutations on both the \(i \to \text{MRCA}\) and \(j \to \text{MRCA}\) branches. The probability that \(i\) and \(j\) differ at this position (at least one mutation occurred) is:

\[P(d_{i,\ell} \neq d_{j,\ell} \mid t_{i,\ell}) = 1 - e^{-\lambda_{i,\ell}} = 1 - e^{-2\mu \cdot t_{i,\ell} \cdot s}\]

This is the two-allele analogue of the Jukes-Cantor model, a standard approximation used in tsdate and ARGweaver.

The log-likelihood of a single observation \(d_{i,\ell}\) given the proposed parent \(j\) and coalescence time \(t_{i,\ell}\) is a Bernoulli log-likelihood on the mismatch indicator:

\[\log P(d_{i,\ell} \mid j, t_{i,\ell}) = m_{i,\ell} \log p_{i,\ell} + (1 - m_{i,\ell}) \log(1 - p_{i,\ell})\]

where \(m_{i,\ell} = |d_{i,\ell} - d_{j,\ell}|\) is the observed mismatch (0 or 1) and \(p_{i,\ell} = 1 - e^{-2\mu t_{i,\ell} s}\) is the predicted mismatch probability.

Soft Parent Assignments

Because the topology is represented as soft parent probabilities \(\alpha_{ij}^\ell\) (from the Gumbel-softmax), the “parent genotype” is a probability-weighted average:

\[\hat{d}_{j(i),\ell} = \sum_{j \neq i} \alpha_{ij}^\ell \cdot d_{j,\ell}\]

and the “mismatch” is the soft absolute difference \(\tilde{m}_{i,\ell} = |d_{i,\ell} - \hat{d}_{j(i),\ell}|\). This allows gradients to flow through the topology via the Gumbel-softmax.

import math
import torch

def pairwise_mutation_loglik(genotypes, parent_probs, branch_times,
                             mu, span=1.0):
    # genotypes:    (B, N, L) 0/1 alleles
    # parent_probs: (B, L, N, N) soft parent assignments alpha_ij^l
    # branch_times: (B, N, L) coalescence times t_{i,l} in generations
    B, N, L = genotypes.shape

    # Soft parent genotype: probability-weighted average over candidates j.
    parent_geno = torch.einsum("blij,bjl->bil", parent_probs, genotypes)
    obs_diff = (genotypes - parent_geno).abs()   # soft mismatch in [0, 1]

    # Predicted mismatch probability: 1 - exp(-2 * mu * t * s).
    rate = (2.0 * mu * branch_times * span).clamp(max=20.0)
    p_mismatch = (1.0 - torch.exp(-rate)).clamp(1e-7, 1.0 - 1e-7)

    # Bernoulli log-likelihood of the (soft) mismatch indicator.
    loglik = (obs_diff * torch.log(p_mismatch)
              + (1.0 - obs_diff) * torch.log(1.0 - p_mismatch))
    return loglik.sum(dim=(1, 2))   # one scalar per batch element

The clamp(max=20.0) on the rate keeps the mismatch probability from saturating at exactly 1 for very long branches (at a rate of 20, \(e^{-20} \approx 2 \times 10^{-9}\)), and the clamp(1e-7, 1.0 - 1e-7) on the mismatch probability prevents log(0).
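A minimal usage sketch (the shapes and values here are illustrative, not Escapement defaults; masking the diagonal so that \(j \neq i\) is omitted for brevity):

B, N, L = 4, 8, 64
genotypes = torch.randint(0, 2, (B, N, L)).float()     # 0/1 alleles
parent_logits = torch.randn(B, L, N, N)
parent_probs = torch.softmax(parent_logits, dim=-1)    # soft parents
branch_times = torch.exp(torch.randn(B, N, L) + 9.0)   # median ~8,100 generations
loglik = pairwise_mutation_loglik(genotypes, parent_probs, branch_times,
                                  mu=1.25e-8, span=100.0)
print(loglik.shape)                                    # torch.Size([4])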

Connection to tsdate’s mutation likelihood

tsdate derives the same Poisson mutation model but applies it to a fixed tree topology with known parent assignments. The edge log-likelihood in tsdate is:

\[\log P(\text{mutations on edge } e) = m_e \log(\mu t_e s_e) - \mu t_e s_e - \log(m_e!)\]

Escapement uses a simpler Bernoulli approximation (mismatch vs. no mismatch) rather than counting mutations per edge, because the parent assignment is soft (probabilistic) rather than hard (deterministic). When the Gumbel-softmax temperature is low and the parent assignment approaches one-hot, the two formulations converge.
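That convergence is easy to check numerically; a sketch in plain Python (the values are illustrative): for a hard parent assignment, small \(\lambda = 2\mu t s\), and per-site mutation counts \(m_e \in \{0, 1\}\), the two log-likelihoods agree to first order in \(\lambda\).

lam = 2 * 1.25e-8 * 1_000 * 100      # lambda = 2*mu*t*s = 2.5e-3
for m in (0, 1):
    bern = m * math.log(1.0 - math.exp(-lam)) - (1 - m) * lam
    pois = m * math.log(lam) - lam   # log(m!) = 0 for m in {0, 1}
    print(m, round(bern, 6), round(pois, 6))
# m = 0: both equal -lam exactly; m = 1: log(1 - e^(-lam)) ~ log(lam) for small lam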

Auxiliary Pairwise-Difference Likelihood

In addition to the primary parent-based likelihood, Escapement includes an auxiliary pairwise-difference log-likelihood that provides a global signal for branch-length calibration:

\[\log P_{\text{pair}}(\mathbf{D} \mid t) = \sum_{i < j} \sum_{\ell=1}^{L} \left[m_{ij,\ell} \log p_{ij,\ell} + (1 - m_{ij,\ell}) \log(1 - p_{ij,\ell})\right]\]

where \(m_{ij,\ell} = |d_{i,\ell} - d_{j,\ell}|\), \(p_{ij,\ell} = 1 - e^{-2\mu \hat{t}_{ij,\ell} s}\), and the pairwise TMRCA is approximated as \(\hat{t}_{ij,\ell} = (t_{i,\ell} + t_{j,\ell})/2\). This term is weighted by a small coefficient (typically 0.1) and provides a signal that does not depend on the topology – only on the overall scale of branch lengths.

def pairwise_diff_loglik(genotypes, branch_times, mu, span=1.0):
    # All-pairs mismatch indicators: (B, N, N, L).
    B, N, L = genotypes.shape
    diff = (genotypes.unsqueeze(2) - genotypes.unsqueeze(1)).abs()

    # Pairwise TMRCA estimate: average of the two per-sample times.
    t_pair = (branch_times.unsqueeze(2) + branch_times.unsqueeze(1)) / 2.0
    rate_pair = (2.0 * mu * t_pair * span).clamp(max=20.0)
    p_mismatch = (1.0 - torch.exp(-rate_pair)).clamp(1e-7, 1.0 - 1e-7)

    loglik = (diff * torch.log(p_mismatch)
              + (1.0 - diff) * torch.log(1.0 - p_mismatch))

    # Count each unordered pair i < j exactly once.
    mask = torch.triu(torch.ones(N, N, device=genotypes.device), diagonal=1)
    return (loglik * mask.unsqueeze(0).unsqueeze(-1)).sum(dim=(1, 2, 3))
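The auxiliary term is then blended with the primary likelihood; a sketch of the 0.1 weighting described above, continuing the illustrative variables from the earlier usage example:

mut_ll = pairwise_mutation_loglik(genotypes, parent_probs, branch_times,
                                  mu=1.25e-8, span=100.0)
pair_ll = pairwise_diff_loglik(genotypes, branch_times, mu=1.25e-8, span=100.0)
data_loglik = mut_ll + 0.1 * pair_ll   # small weight: branch-scale calibration only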

Coalescent Log-Prior

The coalescent log-prior answers: is the proposed genealogy plausible under the coalescent model with the proposed \(N_e(t)\)?

Constant \(N_e\)

For a pair of lineages under the Kingman coalescent with constant diploid effective population size \(N_e\), the TMRCA (in generations) is exponentially distributed:

\[T \sim \text{Exp}\!\left(\frac{1}{2N_e}\right), \qquad p(t) = \frac{1}{2N_e} e^{-t/(2N_e)}\]

The log-prior for a single coalescence time \(t\) is:

\[\log P(t \mid N_e) = \log\frac{1}{2N_e} - \frac{t}{2N_e}\]

For a complete genealogy with \(n\) samples, there are \(n - 1\) coalescence events. Under Escapement’s factored approximation, the prior is the product over all sample-position pairs:

\[\log P(\tau \mid N_e) = \sum_{i=1}^{n} \sum_{\ell=1}^{L} \left[\log\frac{1}{2N_e} - \frac{t_{i,\ell}}{2N_e}\right]\]

def coalescent_log_prior(branch_times, Ne=10000.0):
    # log p(t) = log(1 / (2 Ne)) - t / (2 Ne), summed over samples and sites.
    rate = 1.0 / (2.0 * Ne)
    log_prior = math.log(rate) - rate * branch_times
    return log_prior.sum(dim=(1, 2))

Connection to msprime and PSMC

The exponential distribution of pairwise coalescence times is derived in msprime as the foundation of the Kingman coalescent. PSMC uses the same distribution but discretizes it into time intervals for the HMM. Escapement uses the continuous distribution directly, avoiding discretization artifacts.
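That distribution is easy to verify by simulation; a sketch using msprime's public API (standalone, not Escapement code):

import msprime

Ne = 10_000
tmrcas = []
for seed in range(1, 2001):
    # One diploid individual gives two lineages; the root time is their TMRCA.
    ts = msprime.sim_ancestry(samples=1, population_size=Ne, random_seed=seed)
    tree = ts.first()
    tmrcas.append(tree.time(tree.root))

print(sum(tmrcas) / len(tmrcas))   # close to the Exp mean 2 * Ne = 20,000 generations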

Piecewise-Constant \(N_e(t)\)

For realistic demography, \(N_e(t)\) varies through time. With piecewise-constant \(N_e(t)\) on a grid \(0 = t_0 < t_1 < \cdots < t_K\), the coalescent rate in interval \(k\) is \(\lambda_k = 1/(2 N_e^{(k)})\). The log-prior of a coalescence time \(t\) involves integrating the hazard:

\[\log P(t \mid N_e(\cdot)) = \log \lambda(t) - \int_0^t \lambda(s)\, ds\]

where \(\lambda(t)\) is the instantaneous rate at the coalescence time and the integral is the cumulative hazard:

\[\int_0^t \lambda(s)\, ds = \sum_{k=0}^{K-1} \lambda_k \cdot \min(\Delta t_k,\; \max(0,\; t - t_k))\]

Here \(\Delta t_k = t_{k+1} - t_k\) is the width of interval \(k\). The coalescence time \(t\) contributes \(\lambda_k \Delta t_k\) for each interval it fully spans, and a partial contribution \(\lambda_k (t - t_k)\) for the interval in which it falls. For example, with grid \((0, 1000, 5000)\), \(N_e^{(0)} = 10{,}000\), \(N_e^{(1)} = 20{,}000\), and \(t = 3000\), the cumulative hazard is \(1000 \cdot \tfrac{1}{20{,}000} + 2000 \cdot \tfrac{1}{40{,}000} = 0.1\).

def coalescent_log_prior_variable_Ne(branch_times, Ne_fn, time_grid):
    # Ne_fn: (K,) learnable Ne per interval; time_grid: (K,) left edges.
    # Sweep the grid left to right, consuming each time interval by interval
    # to accumulate the hazard integral and record the instantaneous rate of
    # the interval in which each coalescence time falls.
    K = Ne_fn.shape[0]
    Ne_clamped = Ne_fn.clamp(min=1.0)                 # guard against Ne <= 0
    cumulative_hazard = torch.zeros_like(branch_times)
    t_remaining = branch_times.clone()
    instantaneous_rate = torch.zeros_like(branch_times)

    for k in range(K):
        rate_k = 1.0 / (2.0 * Ne_clamped[k])
        if k < K - 1:
            dt_interval = time_grid[k + 1] - time_grid[k]
        else:
            # Last interval is open-ended: wide enough to absorb any time.
            dt_interval = t_remaining.max().detach().item() + 1.0
        # Time spent in this interval: its full width, or the remainder of t.
        dt_used = torch.clamp(t_remaining, max=dt_interval)
        cumulative_hazard = cumulative_hazard + rate_k * dt_used
        # Times that end inside this interval take its instantaneous rate.
        in_this_bin = (t_remaining > 0) & (t_remaining <= dt_interval)
        instantaneous_rate = torch.where(
            in_this_bin, rate_k, instantaneous_rate)
        t_remaining = (t_remaining - dt_used).clamp(min=0.0)

    instantaneous_rate = instantaneous_rate.clamp(min=1e-30)  # avoid log(0)
    log_prior = torch.log(instantaneous_rate) - cumulative_hazard
    return log_prior.sum(dim=(1, 2))

Gradient flow through \(N_e\)

The coalescent log-prior is differentiable with respect to \(N_e^{(k)}\) (through the Ne_fn tensor). This is what enables joint optimization of the demography: when the ELBO is maximized, the gradients through the coalescent prior push \(N_e(t)\) toward values consistent with the proposed coalescence times, while the gradients through the mutation likelihood push the coalescence times toward values consistent with the observed data. The two signals converge to a self-consistent estimate of both the genealogy and the demography.
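A sketch of that gradient path (the grid, shapes, and values are illustrative):

time_grid = torch.tensor([0.0, 1_000.0, 5_000.0, 20_000.0])
Ne_fn = torch.full((4,), 10_000.0, requires_grad=True)   # one Ne per interval
branch_times = torch.exp(torch.randn(2, 8, 64) + 9.0)

log_prior = coalescent_log_prior_variable_Ne(branch_times, Ne_fn, time_grid)
log_prior.sum().backward()
print(Ne_fn.grad)   # nonzero: the prior pulls Ne(t) toward the sampled times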

Breakpoint Log-Prior

Under the Sequential Markov Coalescent, recombination breakpoints between adjacent sites occur as a Poisson process with rate \(\rho\) per base pair per generation. For a segment of span \(s\) base pairs, the probability of at least one recombination event (and hence a tree change) is:

\[P(\text{breakpoint}) = 1 - e^{-\rho \cdot s}\]

The breakpoint log-prior for the predicted breakpoint probabilities \(b_\ell\) is:

\[\log P(b \mid \rho) = \sum_{\ell=1}^{L-1} \left[ b_\ell \log(1 - e^{-\rho s}) + (1 - b_\ell) \log(e^{-\rho s})\right]\]

def breakpoint_log_prior(break_probs, rho, span=1.0):
    # Prior probability of >= 1 recombination event between adjacent sites.
    p_break = max(1.0 - math.exp(-rho * span), 1e-8)
    p_no_break = 1.0 - p_break
    bp = break_probs.clamp(1e-8, 1.0 - 1e-8)
    log_prior = bp * math.log(p_break) + (1.0 - bp) * math.log(p_no_break)
    return log_prior.sum(dim=1)

This term acts as a regularizer: it penalizes predicted breakpoint probabilities that are inconsistent with the recombination rate. For typical human recombination rates (\(\rho \approx 10^{-8}\) per bp per generation) and typical inter-site spacing (\(s \approx 100\) bp), the prior breakpoint probability is very low (\(\approx 10^{-6}\)), encouraging the model to predict few breakpoints.

Entropy Decomposition

The entropy \(H[q]\) of the variational posterior decomposes into three independent terms due to the mean-field factorization:

\[H[q] = H_{\text{topo}}[q] + H_{\text{branch}}[q] + H_{\text{break}}[q]\]

Topology Entropy (Categorical)

The topology at each position is a categorical distribution over parent assignments:

\[H_{\text{topo}}[q] = -\sum_{\ell=1}^{L} \sum_{i=1}^{n} \sum_{j \neq i} \alpha_{ij}^\ell \log \alpha_{ij}^\ell\]

This is computed by the topology head as chosen_log_probs. In the ELBO, the topology term appears as \(-\mathbb{E}_q[\log q(\text{topology})]\):

\[-\mathbb{E}_q[\log q(\pi)] = -\sum_{\ell, i} \sum_{j} \alpha_{ij}^\ell \log \alpha_{ij}^\ell\]
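In code this is one reduction over the soft-assignment tensor (a sketch; the function name is illustrative, and the clamp guards log(0) for near-one-hot assignments):

def topology_entropy(parent_probs, eps=1e-12):
    # parent_probs: (B, L, N, N) soft parent assignments alpha_ij^l
    return -(parent_probs * parent_probs.clamp(min=eps).log()).sum(dim=(1, 2, 3))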

Branch-Length Entropy (Log-Normal)

The branch lengths are log-normally distributed with parameters \(\mu_{i,\ell}\) and \(\sigma_{i,\ell}\). The log-normal entropy has a closed-form expression:

\[H[\text{LogNormal}(\mu, \sigma)] = \mu + \frac{1}{2} + \log \sigma + \frac{1}{2}\log(2\pi)\]

The total branch-length entropy sums over all sample-position pairs:

\[H_{\text{branch}}[q] = \sum_{i=1}^{n} \sum_{\ell=1}^{L} \left[\mu_{i,\ell} + \frac{1}{2} + \log \sigma_{i,\ell} + \frac{1}{2}\log(2\pi)\right]\]

def lognormal_entropy(mu, log_sigma):
    # Closed form: mu + 1/2 + log(sigma) + (1/2) log(2*pi); log_sigma = log(sigma).
    return mu + 0.5 + log_sigma + 0.5 * math.log(2 * math.pi)

Why log-normal and not gamma?

tsdate and Gamma-SMC use gamma distributions for coalescence-time posteriors, which is the conjugate choice for the exponential coalescent prior. Escapement uses log-normal instead, for a practical reason: the log-normal reparameterization trick (\(t = \exp(\mu + \sigma\epsilon)\), \(\epsilon \sim \mathcal{N}(0,1)\)) produces lower-variance gradient estimates than the gamma reparameterization.

The gamma distribution can be reparameterized (Figurnov et al. 2018), but the resulting gradient estimates have higher variance, especially for small shape parameters. Since Escapement’s optimization is already challenging (discrete topology, multi-modal landscape), we prioritize gradient quality over distributional faithfulness.
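Both pathwise samplers are available in PyTorch; a sketch of the contrast (the variables are illustrative):

mu = torch.zeros(5, requires_grad=True)
sigma = torch.ones(5)

# Explicit reparameterization: sample the noise, then transform deterministically.
eps = torch.randn(5)
t_lognormal = torch.exp(mu + sigma * eps)
t_lognormal.sum().backward()   # exact pathwise gradient lands in mu.grad

# Gamma also supports rsample(), via implicit reparameterization
# (Figurnov et al. 2018), with higher-variance gradients for small shapes.
shape = torch.full((5,), 2.0, requires_grad=True)
t_gamma = torch.distributions.Gamma(shape, torch.ones(5)).rsample()
t_gamma.sum().backward()       # works too; gradient lands in shape.grad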

Breakpoint Entropy (Bernoulli)

The breakpoint at each position is Bernoulli-distributed:

\[H[\text{Bernoulli}(b_\ell)] = -b_\ell \log b_\ell - (1 - b_\ell) \log(1 - b_\ell)\]

The total breakpoint entropy is:

\[H_{\text{break}}[q] = \sum_{\ell=1}^{L-1} H[\text{Bernoulli}(b_\ell)]\]

def bernoulli_entropy(p, eps=1e-8):
    # Clamp keeps log() finite at p = 0 or p = 1.
    p = p.clamp(eps, 1.0 - eps)
    return -(p * p.log() + (1.0 - p) * (1.0 - p).log())
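As a sanity check, this matches PyTorch's built-in Bernoulli entropy:

p = torch.tensor([0.01, 0.5, 0.99])
print(bernoulli_entropy(p))
print(torch.distributions.Bernoulli(probs=p).entropy())   # same values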

Gradient Flow via the Reparameterization Trick

The ELBO is an expectation over the variational posterior:

\[\text{ELBO} = \mathbb{E}_{q(\tau \mid \mathbf{D}, \phi)}\!\left[ f(\tau, \theta)\right]\]

where \(f(\tau, \theta) = \log P(\mathbf{D} \mid \tau, \mu) + \log P(\tau \mid N_e, \rho) - \log q(\tau \mid \mathbf{D}, \phi)\); the expectation of the last term is the entropy \(H[q]\). To optimize \(\phi\) by gradient descent, we need \(\nabla_\phi \text{ELBO}\), but the sampling distribution itself depends on \(\phi\). The reparameterization trick rewrites the expectation so that the randomness is independent of \(\phi\):

Branch lengths. \(t_{i,\ell} = \exp(\mu_{i,\ell} + \sigma_{i,\ell} \cdot \epsilon)\), where \(\epsilon \sim \mathcal{N}(0,1)\) is noise drawn independently of \(\phi\). The gradient \(\nabla_\phi t_{i,\ell}\) flows through \(\mu\) and \(\sigma\) to the encoder.

Breakpoints. Bernoulli variables are discrete, so the standard reparameterization trick does not apply. Instead, Escapement uses the Gumbel-sigmoid relaxation:

\[\tilde{b}_\ell = \sigma\!\left(\frac{\log \frac{b_\ell}{1 - b_\ell} + g_1 - g_2}{\tau}\right), \qquad g_1, g_2 \sim \text{Gumbel}(0, 1)\]

where \(\sigma(\cdot)\) is the logistic sigmoid and \(\tau\) here denotes the relaxation temperature, not the genealogy.

Topology. The Gumbel-softmax trick (described in Variational Inference Without Simulations) provides gradients through the discrete parent assignments. The straight-through estimator uses hard assignments in the forward pass and soft assignments in the backward pass.
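A sketch of the three samplers side by side (F.gumbel_softmax is PyTorch's built-in; the Gumbel-sigmoid has no built-in, so it is written out; the function names are illustrative):

import torch.nn.functional as F

def sample_branch_times(mu, log_sigma):
    eps = torch.randn_like(mu)                     # noise independent of phi
    return torch.exp(mu + log_sigma.exp() * eps)   # pathwise-differentiable

def sample_topology(parent_logits, temp=0.5):
    # Straight-through Gumbel-softmax: hard one-hot forward, soft backward.
    return F.gumbel_softmax(parent_logits, tau=temp, hard=True, dim=-1)

def sample_breakpoints(break_logits, temp=0.5):
    # Gumbel-sigmoid: g1 - g2 is logistic noise added to the logit.
    u1 = torch.rand_like(break_logits).clamp_min(1e-9)
    u2 = torch.rand_like(break_logits).clamp_min(1e-9)
    g1, g2 = -torch.log(-torch.log(u1)), -torch.log(-torch.log(u2))
    return torch.sigmoid((break_logits + g1 - g2) / temp)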

Gradient flow in Escapement:

┌──────────────────────────────────────────────────────────────┐
│                                                              │
│  ELBO = log P(D|τ,μ) + log P(τ|Ne,ρ) + H[q]                  │
│    │                                                         │
│    ├── ∂/∂t  (branch lengths)  ← reparameterization trick    │
│    │    └── ∂/∂μ, ∂/∂σ  ← backprop through encoder           │
│    │                                                         │
│    ├── ∂/∂π  (topology)        ← Gumbel-softmax STE          │
│    │    └── ∂/∂α  (logits)  ← backprop through encoder       │
│    │                                                         │
│    ├── ∂/∂b  (breakpoints)     ← Gumbel-sigmoid              │
│    │    └── ∂/∂logit  ← backprop through encoder             │
│    │                                                         │
│    └── ∂/∂Ne  (demography)     ← direct gradient             │
│         └── through coalescent prior                         │
│                                                              │
└──────────────────────────────────────────────────────────────┘

All three sources of randomness (Gaussian for branch lengths, Gumbel for topology, Gumbel for breakpoints) are sampled independently of \(\phi\), so the gradient estimator \(\nabla_\phi f(\tau(\phi, \epsilon), \theta)\) is unbiased and low-variance. This is the same reparameterization trick used in variational autoencoders (Kingma & Welling 2014), extended to the structured latent space of tree sequences.

Connection to Source Timepieces

Every component of the differentiable likelihood can be traced to a specific Timepiece:

Likelihood components and their Timepiece origins

| Component | Formula | Source Timepiece | What Escapement changes |
| --- | --- | --- | --- |
| Poisson mutation model | \(P(\text{diff}) = 1 - e^{-2\mu t s}\) | tsdate (mutation likelihood) | Soft parent assignments instead of fixed topology |
| Exponential coalescent | \(T \sim \text{Exp}(1/(2N_e))\) | msprime (coalescent theory) | Continuous \(N_e(t)\), jointly optimized |
| Piecewise-constant hazard | \(\int_0^t \lambda(s)\, ds\) | PSMC (discretized coalescent) | Direct integration, no discretization of \(t\) |
| SMC breakpoint model | \(P(\text{break}) = 1 - e^{-\rho s}\) | PSMC (SMC approximation) | Learned breakpoint detection, not HMM transitions |
| Gamma/log-normal times | \(t \sim \text{LogNormal}(\mu, \sigma)\) | tsdate / Gamma-SMC | Neural parameterization instead of hand-derived EP |
| Factored prior | \(P(\tau) \approx \prod_\ell P(\mathcal{T}_\ell)\) | PSMC (SMC factorization) | Same approximation, used in variational objective |

The differentiable likelihood is not a new population-genetic model. It is a compilation of existing models into a form that supports automatic differentiation. The Timepieces provide the equations; Escapement provides the gradient infrastructure.