Architecture
Four wheels, precisely meshed. The escape wheel receives energy from the gear train. The pallet fork converts rotational motion into oscillation. The balance spring stores and returns energy. The regulator adjusts the rate. Remove any one, and the mechanism stops.
Escapement has four modules, each with a distinct role in the variational inference pipeline. Unlike Mainspring, which is trained on simulated (input, target) pairs, every module here is trained end-to-end by maximizing the ELBO on observed data. There are no ground-truth labels. The coalescent likelihood is the supervision signal.
   Module 1             Module 2             Module 3            Module 4
   GENEALOGY           VARIATIONAL        DIFFERENTIABLE       DEMOGRAPHIC
    ENCODER           TREE POSTERIOR        LIKELIHOOD          INFERENCE
 (escape wheel)       (pallet fork)      (balance spring)      (regulator)

D ∈ {0,1}^{n×L}      h ∈ R^{n×L×d}        τ ~ q(τ|D,φ)         ELBO terms
       |                    |                    |                   |
       v                    v                    v                   v
┌─────────────┐     ┌──────────────┐     ┌───────────────┐     ┌───────────┐
│ Transformer │────▶│ Topology:    │────▶│ log P(D|τ,μ)  │     │ N_e(t):   │
│ sample ×    │     │  Gumbel-SM   │  τ  │ log P(τ|Ne,ρ) │◀────│ piecewise │
│ position    │     │ Branches:    │     │ H[q]          │     │ or spline │
│             │     │  LogNormal   │     │               │     │           │
│             │     │ Breakpoints: │     │ = ELBO        │     │           │
│             │     │  Bernoulli   │     │ (pure math)   │     │           │
└─────────────┘     └──────────────┘     └───────────────┘     └───────────┘
Module 1: Genealogy Encoder
The encoder transforms the raw genotype matrix \(\mathbf{D} \in \{0,1\}^{n \times L}\) into latent vectors \(\mathbf{h} \in \mathbb{R}^{n \times L \times d}\). It must capture two kinds of structure:
Inter-sample relationships at each genomic position (which samples share recent common ancestry here?)
Spatial correlations along the genome (how does ancestry change across positions due to recombination?)
The architecture is a Transformer that alternates between attention over samples and attention over positions – the same dual-axis design as Mainspring, but optimized against the ELBO rather than simulation-matching losses.
Sample Attention
At each genomic position, the \(n\) sample embeddings are processed by multi-head self-attention. This is permutation-equivariant over samples: if the sample order in \(\mathbf{D}\) is permuted, the output embeddings are permuted identically. This encodes the exchangeability of coalescent samples.
\[
\mathbf{E}_\ell' = \text{MHA}(\mathbf{E}_\ell),
\]
where \(\mathbf{E}_\ell \in \mathbb{R}^{n \times d}\) is the embedding matrix at position \(\ell\) and MHA is multi-head self-attention applied across the sample axis.
The attention weights \(\alpha_{ij}^\ell\) at each position have a direct interpretation: they measure how much sample \(i\) “looks at” sample \(j\). In a well-trained model, high attention corresponds to recent common ancestry – samples that coalesce early in the tree attend strongly to each other. This is the neural analogue of the Li & Stephens copying probabilities from tsinfer.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SampleAttention(nn.Module):
    """Multi-head self-attention over the sample axis at one genomic position."""

    def __init__(self, d_model, n_heads=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (B, N, D) -- the N sample embeddings at a single position
        residual = x
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)            # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)              # row i: how much sample i attends to each j
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return self.norm(residual + self.out_proj(out))
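Permutation equivariance is easy to verify empirically. Below is a minimal check with hypothetical sizes; dropout is disabled via .eval() so the two passes are deterministic:

# Sanity check: permuting the samples permutes the outputs identically.
attn = SampleAttention(d_model=64).eval()
x = torch.randn(2, 8, 64)                 # (batch, samples, d_model)
perm = torch.randperm(8)
with torch.no_grad():
    assert torch.allclose(attn(x)[:, perm], attn(x[:, perm]), atol=1e-5)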
Sliding-Window Positional Attention
After modeling inter-sample relationships at each position, the encoder captures spatial correlations along the genome using sliding-window self-attention. For each sample, the \(L\) positional embeddings are treated as a sequence, with attention restricted to a window of \(w\) positions.
The window size \(w\) should be roughly \(1/\rho\) – the expected distance between recombination events – converted from base pairs into a count of segregating sites. Within this window, the local tree is approximately constant, and the attention mechanism can detect the patterns of linkage disequilibrium that reveal the local genealogy.
\[
\mathbf{h}_i' = \text{MHA}_w(\mathbf{h}_i),
\]
where \(\mathbf{h}_i \in \mathbb{R}^{L \times d}\) is the embedding sequence for sample \(i\), and \(\text{MHA}_w\) restricts attention to positions within distance \(w\).
class PositionalAttention(nn.Module):
    """Self-attention along the genome, restricted to a sliding window."""

    def __init__(self, d_model, n_heads=4, window=64, dropout=0.1):
        super().__init__()
        self.window = window
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (B, L, D) -- one sample's embedding sequence along the genome
        residual = x
        B, L, D = x.shape
        W = min(self.window, L)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn_logits = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Mask out pairs of positions farther than W sites apart.
        positions = torch.arange(L, device=x.device)
        mask = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs() > W
        attn_logits = attn_logits.masked_fill(
            mask.unsqueeze(0).unsqueeze(0), float("-inf")
        )
        attn = F.softmax(attn_logits, dim=-1)
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, L, D)
        return self.norm(residual + self.out_proj(out))
Full Encoder
The complete encoder stacks \(K\) blocks, each containing one sample attention layer, one positional attention layer, and a feedforward network:
class GenealogyEncoder(nn.Module):
    def __init__(self, d_model=64, n_heads=4, n_layers=2,
                 window=64, dropout=0.1):
        super().__init__()
        self.allele_embed = nn.Linear(1, d_model)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "sample_attn": SampleAttention(d_model, n_heads, dropout),
                "pos_attn": PositionalAttention(d_model, n_heads, window, dropout),
                "ffn": nn.Sequential(
                    nn.Linear(d_model, 4 * d_model), nn.GELU(),
                    nn.Linear(4 * d_model, d_model), nn.Dropout(dropout)),
                "ffn_norm": nn.LayerNorm(d_model),
            }))

    def forward(self, genotypes):
        # genotypes: (B, N, L) binary allele matrix
        B, N, L = genotypes.shape
        x = self.allele_embed(genotypes.float().unsqueeze(-1))  # (B, N, L, d)
        D = x.shape[-1]
        for layer in self.layers:
            # Attention over samples, independently at each position.
            x_pos = x.permute(0, 2, 1, 3).reshape(B * L, N, D)
            x_pos = layer["sample_attn"](x_pos)
            x = x_pos.reshape(B, L, N, D).permute(0, 2, 1, 3)
            # Sliding-window attention along the genome, per sample.
            x_samp = x.reshape(B * N, L, D)
            x_samp = layer["pos_attn"](x_samp)
            x = x_samp.reshape(B, N, L, D)
            # Position-wise feedforward with residual connection.
            residual = x
            x = layer["ffn_norm"](residual + layer["ffn"](x))
        return x
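As a quick smoke test (hypothetical sizes), the encoder maps a binary genotype matrix to one latent vector per sample per position:

# Shape check: (B, n, L) genotypes -> (B, n, L, d) latents.
enc = GenealogyEncoder(d_model=64, n_layers=2)
D = torch.randint(0, 2, (1, 10, 200)).float()   # (batch, samples, sites)
h = enc(D)
assert h.shape == (1, 10, 200, 64)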
Shared encoder, different objective
Escapement’s encoder architecture is nearly identical to Mainspring’s genomic encoder. The crucial difference is the training objective. Mainspring trains the encoder to predict the true ARG from simulations. Escapement trains it to produce latent vectors from which the variational posterior can generate genealogies that maximize the ELBO. The same architecture, trained with different losses, learns different representations.
Module 2: Variational Tree Posterior
The variational posterior \(q(\tau \mid \mathbf{D}, \phi)\) maps the encoder’s latent vectors to a distribution over tree sequences. A tree sequence has three components – topology, branch lengths, and breakpoints – and the posterior factorizes accordingly:
\[
q(\tau \mid \mathbf{D}, \phi) = q_{\text{topo}}(\pi \mid \mathbf{D}, \phi)\; q_{\text{branch}}(\mathbf{t} \mid \mathbf{D}, \phi)\; q_{\text{break}}(\mathbf{b} \mid \mathbf{D}, \phi),
\]
where \(\pi\) denotes the parent assignments, \(\mathbf{t}\) the coalescence times, and \(\mathbf{b}\) the breakpoint indicators.
This mean-field factorization is an approximation: in reality, topology and branch lengths are correlated (e.g., star-like trees imply recent coalescence). The approximation enables tractable entropy computation and efficient sampling.
Topology: Gumbel-Softmax Parent Assignments
At each genomic position \(\ell\), each sample \(i\) chooses a parent \(j \neq i\) from the other samples. The parent assignment probabilities are computed by scaled dot-product attention:
\[
\alpha_{ij}^\ell = \frac{\exp\!\big(\mathbf{q}_i^{\ell\top} \mathbf{k}_j^\ell / \sqrt{d}\big)}{\sum_{j' \neq i} \exp\!\big(\mathbf{q}_i^{\ell\top} \mathbf{k}_{j'}^\ell / \sqrt{d}\big)},
\]
where \(\mathbf{q}_i^\ell = \mathbf{W}_Q \mathbf{h}_{i,\ell}\) and \(\mathbf{k}_j^\ell = \mathbf{W}_K \mathbf{h}_{j,\ell}\). Self-assignment is masked out (\(\alpha_{ii}^\ell = 0\)).
During training, the parent assignment is sampled via Gumbel-softmax with temperature \(\tau\) (overloading \(\tau\), which elsewhere denotes the genealogy):
\[
y_{ij}^\ell = \frac{\exp\!\big((\log \alpha_{ij}^\ell + g_{ij}) / \tau\big)}{\sum_{j' \neq i} \exp\!\big((\log \alpha_{ij'}^\ell + g_{ij'}) / \tau\big)}, \qquad g_{ij} \sim \text{Gumbel}(0, 1).
\]
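The TopologyHead below calls a gumbel_softmax_sample helper that is not shown in this section. A minimal sketch, equivalent in spirit to torch.nn.functional.gumbel_softmax, might look like:

def gumbel_softmax_sample(logits, temperature=1.0, hard=False):
    # Gumbel(0, 1) noise: g = -log(-log(U)) with U ~ Uniform(0, 1).
    u = torch.rand_like(logits).clamp(1e-9, 1.0 - 1e-9)
    g = -torch.log(-torch.log(u))
    y = F.softmax((logits + g) / temperature, dim=-1)
    if hard:
        # Straight-through: one-hot in the forward pass, soft gradients backward.
        one_hot = torch.zeros_like(y).scatter_(
            -1, y.argmax(dim=-1, keepdim=True), 1.0)
        y = (one_hot - y).detach() + y
    return y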
class TopologyHead(nn.Module):
    """Per-position parent assignment via scaled dot-product attention."""

    def __init__(self, d_model):
        super().__init__()
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)

    def forward(self, h, temperature=1.0, hard=False):
        # h: (B, N, D) latent vectors at one position
        Q = self.query_proj(h)
        K = self.key_proj(h)
        N = h.shape[1]
        logits = (Q @ K.transpose(-2, -1)) / (h.shape[-1] ** 0.5)
        # Mask self-assignment: a sample cannot be its own parent.
        mask = torch.eye(N, device=h.device, dtype=torch.bool).unsqueeze(0)
        logits = logits.masked_fill(mask, float("-inf"))
        # Differentiable sample of one parent per sample (helper defined above).
        parent_probs = gumbel_softmax_sample(logits, temperature, hard)
        # Clamp avoids 0 * -inf = NaN on the masked diagonal.
        log_probs = F.log_softmax(logits, dim=-1).clamp(min=-1e9)
        # log q of the chosen parents, used for the topology ELBO term.
        chosen_log_probs = (parent_probs * log_probs).sum(dim=-1)
        return parent_probs, chosen_log_probs
The topology entropy is the categorical entropy of the parent assignment distribution:
\[
H[q_{\text{topo}}] = -\sum_{\ell=1}^{L} \sum_{i=1}^{n} \sum_{j \neq i} \alpha_{ij}^\ell \log \alpha_{ij}^\ell.
\]
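In code, one subtlety is the masked diagonal: \(\exp(-\infty) = 0\), but \(0 \cdot (-\infty)\) evaluates to NaN in floating point, so the log-probabilities need clamping. A minimal sketch:

def topology_entropy(logits):
    # logits: (..., N, N) with -inf on the diagonal (self-assignment masked).
    p = F.softmax(logits, dim=-1)
    log_p = F.log_softmax(logits, dim=-1).clamp(min=-1e9)  # avoid 0 * -inf = NaN
    return -(p * log_p).sum(dim=(-2, -1))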
Branch Lengths: Log-Normal Reparameterization
Coalescence times are positive and typically span several orders of magnitude (from tens to millions of generations). The log-normal distribution is a natural choice:
\[
t_{i,\ell} \sim \text{LogNormal}(\mu_{i,\ell}, \sigma_{i,\ell}^2),
\]
where \(\mu_{i,\ell}\) and \(\sigma_{i,\ell}\) are predicted by an MLP from the latent vectors. The reparameterization trick enables gradient flow:
\[
t_{i,\ell} = \exp(\mu_{i,\ell} + \sigma_{i,\ell} \, \epsilon), \qquad \epsilon \sim \mathcal{N}(0, 1).
\]
class BranchLengthHead(nn.Module):
    def __init__(self, d_model, expected_tmrca=20000.0):
        super().__init__()
        self.log_expected = math.log(max(expected_tmrca, 1.0))
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.Linear(d_model, 2))
        # Initialize so predictions start near the expected pairwise TMRCA.
        with torch.no_grad():
            self.mlp[-1].bias[0] = self.log_expected
            self.mlp[-1].bias[1] = -1.0

    def forward(self, h):
        raw = self.mlp(h)
        log_mean = raw[..., 0]
        log_std = F.softplus(raw[..., 1]) + 1e-4   # strictly positive std
        return log_mean, log_std
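Sampling branch times from the head is then a single reparameterized draw. A usage sketch, where h stands in for a batch of encoder latents:

head = BranchLengthHead(d_model=64)
h = torch.randn(1, 10, 200, 64)                      # hypothetical latents
log_mean, log_std = head(h)
eps = torch.randn_like(log_mean)
branch_times = torch.exp(log_mean + log_std * eps)   # differentiable w.r.t. the MLP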
The output bias is initialized so that branch-length predictions start centered at \(\log(2 N_e)\) – the expected TMRCA for a pair of lineages under the coalescent. This prevents optimization from starting in a regime where all branches are unreasonably short or long.
The branch-length entropy is the log-normal entropy, which has a closed-form expression:
\[
H[\text{LogNormal}(\mu, \sigma^2)] = \mu + \tfrac{1}{2} \log(2 \pi e \sigma^2).
\]
Breakpoints: Bernoulli Probabilities
At each position \(\ell\), a recombination breakpoint occurs with probability \(b_\ell \in [0, 1]\). The breakpoint detector compares adjacent latent vectors and predicts whether the local tree changes:
class BreakpointHead(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * d_model, d_model), nn.GELU(),
            nn.Linear(d_model, 1))

    def forward(self, h):
        # h: (B, N, L, D); compare each pair of adjacent positions.
        h_left = h[:, :, :-1, :]
        h_right = h[:, :, 1:, :]
        pair = torch.cat([h_left, h_right], dim=-1)
        logits = self.mlp(pair).squeeze(-1)          # (B, N, L - 1)
        # Average evidence across samples: one probability per adjacent gap.
        probs = torch.sigmoid(logits.mean(dim=1))    # (B, L - 1)
        return probs
The breakpoint entropy is the Bernoulli entropy:
\[
H[q_{\text{break}}] = -\sum_{\ell} \big( b_\ell \log b_\ell + (1 - b_\ell) \log(1 - b_\ell) \big).
\]
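A minimal sketch, clamping probabilities away from 0 and 1 for numerical stability:

def breakpoint_entropy(probs, eps=1e-6):
    p = probs.clamp(eps, 1.0 - eps)
    return -(p * torch.log(p) + (1 - p) * torch.log(1 - p)).sum(dim=-1)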
Module 3: Differentiable Likelihood
Module 3 contains no neural networks. It is pure math: given a sampled genealogy \(\tau \sim q\), it computes the three ELBO terms. This module is the balance spring of the mechanism – it provides the restoring force that pulls the variational posterior toward genealogies consistent with the data and the coalescent model.
The full derivation of Module 3 is in The Differentiable Likelihood. Here we summarize the three components:
Mutation log-likelihood. For each sample \(i\) at position \(\ell\), the probability of the observed allele given the proposed parent \(j\) and coalescence time \(t\) is:
\[
P(D_{i\ell} \mid j, t) =
\begin{cases}
e^{-2 \mu t s} & \text{if } D_{i\ell} = D_{j\ell}, \\
1 - e^{-2 \mu t s} & \text{otherwise},
\end{cases}
\]
where \(s\) is the span in base pairs and the factor of 2 accounts for mutations on both lineages. The total mutation log-likelihood sums over all samples and positions (see the sketch after this list).
Coalescent log-prior. For constant \(N_e\), the TMRCA of a pair of lineages is exponentially distributed with rate \(1/(2N_e)\):
\[
P(t \mid N_e) = \frac{1}{2 N_e} \, e^{-t / (2 N_e)}.
\]
For piecewise-constant \(N_e(t)\), the hazard is integrated across time intervals (see The Differentiable Likelihood for the full derivation, and the constant-\(N_e\) sketch after this list).
Entropy. The sum of topology, branch-length, and breakpoint entropies, all computed in closed form from the variational parameters.
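The full implementations live in The Differentiable Likelihood; the sketches below are illustrative only, with assumed shapes noted in the comments. In particular, the variable-\(N_e\) prior used by the model additionally integrates the hazard over the time grid, which the constant-\(N_e\) version here omits.

def pairwise_mutation_loglik(genotypes, parent_probs, times, mu, span):
    # Illustrative sketch, not the chapter's exact API. Assumed shapes:
    #   genotypes:    (B, n, L)     alleles in {0, 1}
    #   parent_probs: (B, n, n, L)  soft assignment of sample i to parent j
    #   times:        (B, n, L)     coalescence time of i with its parent
    p_same = torch.exp(-2.0 * mu * times * span)    # P(no mutation on either lineage)
    # match[b, i, j, l] = 1 if samples i and j carry the same allele at site l.
    match = (genotypes.unsqueeze(2) == genotypes.unsqueeze(1)).float()
    lik = match * p_same.unsqueeze(2) + (1.0 - match) * (1.0 - p_same).unsqueeze(2)
    # Expected log-likelihood under the soft parent assignment.
    log_lik = (parent_probs * torch.log(lik.clamp(min=1e-12))).sum(dim=2)
    return log_lik.sum(dim=(1, 2))                  # (B,)

def coalescent_log_prior_constant_Ne(times, Ne):
    # Exponential(rate = 1 / (2 Ne)) log-density, summed over all branch times.
    rate = 1.0 / (2.0 * Ne)
    return (math.log(rate) - rate * times).sum(dim=(1, 2))   # (B,)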
Module 4: Demographic Inference
The fourth module parameterizes the effective population size trajectory \(N_e(t)\). This is a learnable parameter, optimized jointly with the variational posterior by maximizing the ELBO.
Escapement supports three parameterizations:
Piecewise-Constant
The simplest option: \(N_e(t)\) is constant within each of \(K\) time intervals. The parameters are \(K\) values in log-space:
n_bins = 20
log_Ne = nn.Parameter(torch.full((n_bins,), math.log(10000.0)))
time_grid = torch.linspace(0, 200_000, n_bins + 1)   # interval boundaries (generations)

def get_Ne(t):
    # Look up the interval containing each time t; Ne is constant within it.
    idx = torch.bucketize(t, time_grid[1:-1])
    return torch.exp(log_Ne)[idx]
This is the same parameterization as PSMC. The log-space representation ensures positivity and allows the optimizer to work on a natural scale (multiplicative changes in \(N_e\) correspond to additive changes in \(\log N_e\)).
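For example, querying the trajectory at a few (hypothetical) times returns one \(N_e\) value per query, with gradients flowing back to log_Ne:

t = torch.tensor([500.0, 25_000.0, 150_000.0])
print(get_Ne(t))   # three Ne values, one per interval hit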
Neural Spline
For smoother trajectories, \(N_e(t)\) can be parameterized as a monotonic rational-quadratic spline in log-space. The spline knots are at fixed time points, and the knot values and derivatives are predicted by a small MLP conditioned on the latent representation:
\[
\log N_e(t) = \text{RQS}\big(t;\, \mathbf{w}, \mathbf{h}, \mathbf{s}\big),
\]
where \(\mathbf{w}\) (widths), \(\mathbf{h}\) (heights), and \(\mathbf{s}\) (slopes) are the spline parameters. This allows \(N_e(t)\) to vary smoothly while maintaining the flexibility to capture sharp bottlenecks.
Gaussian Process
For full Bayesian treatment, \(\log N_e(t)\) can be modeled as a Gaussian process with an RBF kernel. The GP posterior is approximated by variational inducing points:
\[
\log N_e(t) \sim \mathcal{GP}\big(m(t), k(t, t')\big), \qquad k(t, t') = \sigma_f^2 \exp\!\left(-\frac{(t - t')^2}{2 \ell^2}\right).
\]
The inducing-point parameters are optimized jointly with the ELBO. This approach is inspired by SINGER, which uses a GP prior on branch lengths, and phlash, which uses SVGD for Bayesian demographic inference.
Joint optimization of φ and N_e
The variational parameters \(\phi\) (neural network weights) and the demographic parameters \(N_e(t)\) are optimized jointly by maximizing the same ELBO objective. This is possible because the coalescent prior \(P(\tau \mid N_e)\) depends on \(N_e(t)\), so the ELBO is a function of both \(\phi\) and \(N_e\). In practice, the two sets of parameters use different learning rates: the neural network parameters use a standard rate (\(3 \times 10^{-4}\)), while the \(N_e\) parameters use a higher rate (\(10^{-2}\)) because they are fewer and more directly constrained by the data.
param_groups = [
    {"params": encoder_params, "lr": 3e-4},
    {"params": [log_Ne], "lr": 1e-2},
]
optimizer = torch.optim.Adam(param_groups)
Putting It All Together
The complete Escapement model chains all four modules:
class Escapement(nn.Module):
    def __init__(self, d_model=64, n_heads=4, n_layers=2, window=64,
                 Ne=10000.0, mu=1.25e-8, rho=1e-8, span=100.0):
        super().__init__()
        self.encoder = GenealogyEncoder(d_model, n_heads, n_layers, window)
        self.var_posterior = VariationalTreePosterior(d_model, 2.0 * Ne)
        self.Ne = Ne
        self.mu = mu
        self.rho = rho
        self.span = span
        self.n_Ne_bins = 20
        self.log_Ne = nn.Parameter(
            torch.full((self.n_Ne_bins,), math.log(max(Ne, 1.0))))
        # Interval boundaries for the piecewise-constant Ne(t) prior.
        self.register_buffer(
            "time_grid", torch.linspace(0.0, 200_000.0, self.n_Ne_bins + 1))

    def forward(self, genotypes, temperature=1.0, hard=False):
        h = self.encoder(genotypes)                            # Module 1
        posterior = self.var_posterior(h, temperature, hard)   # Module 2
        bt = posterior["branch_times"].clamp(min=1.0)
        # Module 3: the three ELBO terms (helpers defined in
        # The Differentiable Likelihood).
        log_lik = pairwise_mutation_loglik(
            genotypes, posterior["parent_probs"], bt, self.mu, self.span)
        log_prior = coalescent_log_prior_variable_Ne(
            bt, torch.exp(self.log_Ne), self.time_grid)        # Module 4 enters here
        log_prior_break = breakpoint_log_prior(
            posterior["break_probs"], self.rho, self.span)
        # Topology entropy enters as -log q (single-sample estimate).
        log_q_topo = posterior["log_q_topology"].sum(dim=(1, 2))
        entropy = posterior["entropy_branches"] + posterior["entropy_breaks"]
        elbo = log_lik + log_prior + log_prior_break + entropy - log_q_topo
        return {"elbo": elbo, **posterior}

    def loss(self, genotypes, temperature=1.0):
        return -self.forward(genotypes, temperature)["elbo"].mean()
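A smoke test of the full pipeline, with hypothetical sizes; this assumes the Module 3 helpers and the VariationalTreePosterior from this chapter are in scope:

model = Escapement()
genotypes = torch.randint(0, 2, (1, 10, 200)).float()  # (B, n, L)
loss = model.loss(genotypes, temperature=0.5)
loss.backward()   # gradients reach both the encoder weights and log_Ne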
The cost of one ELBO evaluation breaks down by module as follows:

| Module | Complexity | Bottleneck |
|---|---|---|
| Genealogy Encoder | \(O(n^2 L d + n L w d)\) | Sample attention (\(n^2\) per position) + sliding window |
| Variational Posterior | \(O(n^2 L d + n L d)\) | Topology head (\(n^2\) attention per position) |
| Differentiable Likelihood | \(O(n^2 L + n L K)\) | Pairwise diff (\(n^2\)) + variable-\(N_e\) prior (\(K\) bins) |
| Demographic Inference | \(O(K)\) | Trivial (just \(K\) exponentiated parameters) |
Total: \(O(n^2 L d)\), dominated by the attention mechanisms. For typical applications (\(n \leq 50\), \(L \sim 500\), \(d = 64\)), each ELBO evaluation takes ~10 ms on a modern GPU.