Training on Real Data
Winding a mainspring requires an external key. Regulating an escapement requires only the mechanism itself – you adjust the regulator screw, observe the rate, and iterate. The data is both the input and the standard.
This chapter describes how Escapement is trained. The distinction from Mainspring is fundamental: there is no simulation engine, no curriculum of increasing complexity, no simulated ground truth. There is only the observed genotype matrix \(\mathbf{D}\) and the coalescent likelihood. The training loop optimizes the ELBO on real data.
The Training Loop
Escapement’s training loop is conceptually simple:
Encode: pass \(\mathbf{D}\) through the genealogy encoder to produce latent vectors \(\mathbf{h}\).
Sample: draw a genealogy \(\tau \sim q(\tau \mid \mathbf{D}, \phi)\) from the variational posterior.
Evaluate: compute the three ELBO terms (mutation likelihood, coalescent prior, entropy) using the differentiable likelihood.
Maximize: backpropagate through the ELBO and update \(\phi\) (network weights) and \(\theta\) (\(N_e(t)\) parameters).
No msprime. No ground truth. No simulated ARGs. The ELBO is the only loss.
```python
import torch
from torch.optim import Adam


def train_escapement(model, genotypes, n_steps=2000, lr_encoder=3e-4,
                     lr_Ne=1e-2, device='cuda'):
    """Train Escapement on a single observed genotype matrix."""
    genotypes = genotypes.to(device)
    model = model.to(device)

    # Separate learning rates for the encoder (phi) and N_e(t) (theta).
    param_groups = model.get_param_groups(lr_encoder, lr_Ne)
    optimizer = Adam(param_groups)

    for step in range(n_steps):
        temperature = anneal_temperature(step, n_steps)
        loss = model.loss(genotypes, temperature=temperature)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        if step % 100 == 0:
            with torch.no_grad():  # diagnostics only; no graph needed
                out = model(genotypes, temperature=temperature)
            print(f"Step {step:4d} | ELBO {out['elbo'].mean():.1f} | "
                  f"loglik {out['log_lik'].mean():.1f} | "
                  f"prior {out['log_prior_coal'].mean():.1f} | "
                  f"H {out['entropy'].mean():.1f} | "
                  f"tau {temperature:.3f}")
    return model
```
Per-dataset optimization
Unlike Mainspring, which trains once and runs on many datasets, Escapement optimizes separately for each dataset. This is a feature, not a bug: the variational posterior is tailored to the specific data, not averaged across a training distribution. The cost is time (~10–30 minutes per dataset on a GPU). The benefit is that the posterior reflects the actual data, not a simulation prior.
Temperature Annealing for Gumbel-Softmax
The Gumbel-softmax temperature \(\tau\) controls the sharpness of the topology assignments. At high temperature, parent assignments are nearly uniform (maximum entropy). At low temperature, they approach one-hot vectors (deterministic topology).
Escapement uses an exponential annealing schedule:

\[
\tau(t) = \tau_{\max} \left( \frac{\tau_{\min}}{\tau_{\max}} \right)^{t/T},
\]

where \(t\) is the optimization step, \(T\) is the total number of steps, \(\tau_{\max} = 1.0\), and \(\tau_{\min} = 0.1\).
```python
def anneal_temperature(step, total_steps, tau_max=1.0, tau_min=0.1):
    """Exponentially decay the temperature from tau_max to tau_min."""
    return tau_max * (tau_min / tau_max) ** (step / total_steps)
```
The annealing schedule serves two purposes:
Exploration. At high temperature, the variational posterior explores a wide range of topologies. The ELBO gradients push the posterior toward promising regions without committing to a single topology too early.
Exploitation. At low temperature, the posterior concentrates on the best topology found during exploration. The branch lengths and breakpoints are refined under this near-deterministic topology.
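Plugging the boundary steps of the default schedule into `anneal_temperature` shows how these phases fall out of the formula (the values correspond to the table below):

```python
for step in [0, 500, 1000, 1500, 2000]:
    print(step, round(anneal_temperature(step, 2000), 3))
# 0    1.0
# 500  0.562
# 1000 0.316
# 1500 0.178
# 2000 0.1
```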
| Phase | Temperature | Behavior |
|---|---|---|
| Early (steps 0–500) | \(\tau \approx 1.0\) | Soft parent assignments. Topology is uncertain. Gradients flow easily. Branch-length scale is calibrated. |
| Middle (steps 500–1500) | \(\tau \approx 0.3\text{--}0.5\) | Parent assignments sharpen. Topology structure emerges. Breakpoints begin to localize. |
| Late (steps 1500–2000) | \(\tau \approx 0.1\) | Near-deterministic topology. Fine-tuning of branch lengths, breakpoints, and \(N_e(t)\). ELBO converges. |
Warm-Starting
Escapement’s optimization landscape is multi-modal: many different genealogies can explain the same genotype matrix reasonably well. Starting from a random initialization risks getting trapped in a poor local optimum, especially for the discrete topology.
Two warm-starting strategies dramatically improve convergence:
From Mainspring
The recommended approach is to initialize Escapement from Mainspring’s output. Mainspring provides a fast (~1 second), approximate ARG. Escapement then refines this ARG using the coalescent likelihood.
```python
import torch


def warm_start_from_mainspring(escapement_model, mainspring_model,
                               genotypes):
    """Initialize Escapement's encoder from Mainspring's output."""
    # Run Mainspring once to get its point-estimate ARG.
    with torch.no_grad():
        mainspring_out = mainspring_model(genotypes, hard=True)

    # Copy encoder weights; strict=False tolerates mismatched heads.
    escapement_model.encoder.load_state_dict(
        mainspring_model.encoder.state_dict(), strict=False)

    # Calibrate the branch-length head's bias to Mainspring's mean log-time.
    with torch.no_grad():
        ms_times = mainspring_out['times']
        mean_log_t = torch.log(ms_times.clamp(min=1.0))
        escapement_model.var_posterior.branch_lengths.mlp[-1].bias[0] = (
            mean_log_t.mean().item())
    return escapement_model
```
This is the hybrid pipeline described in Comparison and Limitations. The analogy to horology is precise: the mainspring provides the initial energy (a good starting point), and the escapement regulates it into precise, calibrated motion (a principled posterior).
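As a minimal sketch of that pipeline, assuming `mainspring_model` and `escapement_model` have been constructed elsewhere with compatible encoders:

```python
# Hypothetical end-to-end hybrid run: Mainspring proposes, Escapement refines.
escapement_model = warm_start_from_mainspring(escapement_model,
                                              mainspring_model, genotypes)
escapement_model = train_escapement(escapement_model, genotypes,
                                    n_steps=500)  # warm starts converge faster
```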
From tsinfer
An alternative warm-start uses tsinfer to provide an initial topology estimate. Since tsinfer scales to much larger sample sizes than Escapement, this is useful when Mainspring is not available:
```python
import math

import numpy as np
import torch


def warm_start_from_tsinfer(escapement_model, ts_inferred, genotypes):
    """Initialize branch-length parameters for a tsinfer warm start.

    ts_inferred supplies the topology; here only the branch-length
    scale is calibrated, from mean pairwise divergence.
    """
    n = genotypes.shape[1]
    pairwise_div = np.zeros((n, n))
    G = genotypes[0].numpy()
    for i in range(G.shape[0]):
        for j in range(i + 1, G.shape[0]):
            pairwise_div[i, j] = np.mean(G[i] != G[j])
            pairwise_div[j, i] = pairwise_div[i, j]

    # Crude TMRCA estimate: divergence accumulates at rate 2 * mu * span.
    mean_div = pairwise_div[np.triu_indices_from(pairwise_div, k=1)].mean()
    estimated_tmrca = mean_div / (2 * escapement_model.mu
                                  * escapement_model.span)

    with torch.no_grad():
        escapement_model.var_posterior.branch_lengths.mlp[-1].bias[0] = (
            math.log(max(estimated_tmrca, 1.0)))
    return escapement_model
```
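For reference, one way to obtain `ts_inferred` using tsinfer's `SampleData`/`infer` API (the site positions below are placeholders; real coordinates should come from your variant data):

```python
import tsinfer

# G is the (n_samples, n_sites) 0/1 matrix from above.
with tsinfer.SampleData(sequence_length=G.shape[1]) as sample_data:
    for pos in range(G.shape[1]):
        sample_data.add_site(pos, G[:, pos].astype(int).tolist(),
                             alleles=["0", "1"])
ts_inferred = tsinfer.infer(sample_data)
```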
Joint vs. Alternating Optimization
Escapement optimizes two sets of parameters:
\(\phi\): neural network weights (encoder + variational posterior heads)
\(\theta\): demographic parameters (\(N_e(t)\))
These can be optimized jointly (a single optimizer, both updated in the same gradient step) or in an alternating scheme (update \(\phi\) for \(K_1\) steps, then \(\theta\) for \(K_2\) steps, and repeat).
Joint Optimization
The simplest approach: both parameter sets share the same ELBO objective and are updated simultaneously with different learning rates.
```python
param_groups = [
    {"params": encoder_params, "lr": 3e-4},  # phi: encoder + posterior heads
    {"params": [log_Ne], "lr": 1e-2},        # theta: demographic parameters
]
optimizer = Adam(param_groups)
```
This works well in practice because the \(N_e(t)\) parameters are few (typically 20) and directly constrained by the coalescent prior. The different learning rates accommodate the different scales: the neural network needs small steps to avoid destabilizing the latent representations, while \(N_e(t)\) can take larger steps because the coalescent prior provides a strong signal.
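`train_escapement` above calls `model.get_param_groups`; a plausible implementation, assuming `log_Ne` is the model's only demographic parameter (the same split used by `alternating_train` below):

```python
def get_param_groups(self, lr_encoder, lr_Ne):
    """Split parameters into phi (network weights) and theta (demography)."""
    phi_params = [p for name, p in self.named_parameters()
                  if name != "log_Ne"]
    return [
        {"params": phi_params, "lr": lr_encoder},
        {"params": [self.log_Ne], "lr": lr_Ne},
    ]
```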
Alternating Optimization
For difficult cases (complex demography, many samples), alternating optimization can be more stable:
```python
import torch
from torch.optim import Adam


def alternating_train(model, genotypes, n_outer=100, n_phi=10, n_theta=5):
    # Separate optimizers: phi (all network weights) and theta (log_Ne).
    opt_phi = Adam([p for n, p in model.named_parameters()
                    if n != "log_Ne"], lr=3e-4)
    opt_theta = Adam([model.log_Ne], lr=1e-2)

    for outer in range(n_outer):
        temp = anneal_temperature(outer, n_outer)

        # Update phi with the demography held fixed.
        for _ in range(n_phi):
            loss = model.loss(genotypes, temperature=temp)
            opt_phi.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt_phi.step()

        # Update N_e with the genealogy proposals held fixed.
        for _ in range(n_theta):
            loss = model.loss(genotypes, temperature=temp)
            opt_theta.zero_grad()
            loss.backward()
            opt_theta.step()
```
The rationale: when \(\phi\) is updated with fixed \(N_e\), the network learns to propose genealogies consistent with the current demography. When \(N_e\) is updated with fixed \(\phi\), the demography adjusts to match the genealogies currently being proposed. This alternation can help escape saddle points where joint optimization stalls.
Variance Reduction for Discrete Gradients
The Gumbel-softmax provides biased but low-variance gradient estimates for the discrete topology. For the highest-quality gradients, Escapement supports two more advanced variance-reduction techniques.
NVIL (Neural Variational Inference and Learning)
NVIL (Mnih & Gregor 2014) uses a learned baseline to reduce the variance of the REINFORCE gradient for discrete variables:

\[
\hat{g}_{\text{NVIL}} = \left( f(\tau) - c_\psi(\mathbf{h}) \right) \nabla_\phi \log q(\tau \mid \mathbf{D}, \phi),
\]

where \(f(\tau)\) is the per-sample ELBO term and \(c_\psi(\mathbf{h})\) is a neural baseline that predicts the expected ELBO from the latent representation, trained to minimize \((f(\tau) - c_\psi(\mathbf{h}))^2\).
```python
import torch.nn as nn


class NVILBaseline(nn.Module):
    """Predicts the expected ELBO, c_psi(h), from the latent representation."""

    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(),
            nn.Linear(d_model, 1))

    def forward(self, h):
        # Pool over the sample and site dimensions, then regress to a scalar.
        return self.net(h.mean(dim=(1, 2))).squeeze(-1)
```
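How the baseline enters a training step, as a sketch with hypothetical tensor names (`f` for the per-sample ELBO estimate, `log_q` for the log-probability of the sampled topology, `h` for the latent encoding):

```python
baseline = nvil_baseline(h)                    # c_psi(h)
advantage = (f - baseline).detach()            # centered learning signal
reinforce_loss = -(advantage * log_q).mean()   # variance-reduced score-function term
baseline_loss = ((f.detach() - baseline) ** 2).mean()
(reinforce_loss + baseline_loss).backward()
```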
RELAX
RELAX (Grathwohl et al. 2018) combines the Gumbel-softmax with a control variate that uses both the discrete and relaxed samples:

\[
\hat{g}_{\text{RELAX}} = \left[ f(\text{hard}) - c_\psi(\tilde{\text{soft}}) \right] \nabla_\phi \log q(\text{hard} \mid \phi) + \nabla_\phi \, c_\psi(\text{soft}) - \nabla_\phi \, c_\psi(\tilde{\text{soft}}),
\]

where \(\text{hard}\) is the discrete sample, \(\text{soft}\) is the Gumbel-softmax relaxation, \(\tilde{\text{soft}}\) is a relaxation conditioned on the discrete sample, and \(c_\psi\) is a learned control variate. RELAX provides unbiased, low-variance gradient estimates but is more complex to implement.
In practice, the Gumbel-softmax with straight-through estimation is sufficient for most applications. NVIL and RELAX are recommended only when the optimization fails to converge (typically for large \(n\) or complex demographies).
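The straight-through estimator mentioned above takes only a few lines; a minimal sketch over a tensor of parent-assignment logits (PyTorch's `F.gumbel_softmax(..., hard=True)` implements the same trick internally):

```python
import torch
import torch.nn.functional as F


def straight_through_gumbel(logits, temperature):
    """Hard one-hot sample in the forward pass, soft gradients in the backward pass."""
    soft = F.gumbel_softmax(logits, tau=temperature, hard=False)
    index = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, index, 1.0)
    # Evaluates to `hard`, but gradients flow through `soft`.
    return hard + (soft - soft.detach())
```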
Practical Considerations
Window Size
Escapement processes the genotype matrix in windows along the genome. The window size \(w\) (in number of segregating sites) controls the trade-off between local accuracy and computational cost:
| Window size | Behavior | When to use |
|---|---|---|
| Small (\(w \leq 32\)) | Captures only very local LD. May miss long-range haplotype sharing. | Very high recombination rate; testing and debugging. |
| Medium (\(w = 64\text{--}128\)) | Good balance. Captures LD within ~1 expected tree span. | Default for most applications. |
| Large (\(w \geq 256\)) | Captures long-range LD. High memory cost (\(O(w^2)\) per sample). | Low recombination rate; large tree spans. |
A reasonable default is \(w \approx 1/\hat{\rho}\), where \(\hat{\rho}\) is the recombination rate in units of recombinations per segregating site.
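In code, with the helper name and clamping range chosen here for illustration (the bounds mirror the table above):

```python
def suggest_window(rho_per_site, w_min=32, w_max=256):
    """Rule of thumb: w ~ 1/rho, clamped to a practical range."""
    return int(min(max(round(1.0 / rho_per_site), w_min), w_max))

# e.g. 0.01 recombinations per segregating site -> w = 100
```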
Batch Construction
For long genomes, Escapement processes windows along the genome, either non-overlapping or overlapping, and averages their ELBO contributions. With overlapping windows (stride < window size), each site contributes to several windows, which smooths breakpoint estimates:
```python
def process_genome_in_windows(model, genotypes, window=256,
                              stride=192, temperature=1.0):
    """Average the ELBO over (possibly overlapping) genomic windows."""
    B, N, L = genotypes.shape
    total_elbo = 0.0
    n_windows = 0
    for start in range(0, L - window + 1, stride):
        chunk = genotypes[:, :, start:start + window]
        out = model(chunk, temperature=temperature)
        total_elbo += out['elbo'].sum()
        n_windows += 1
    return total_elbo / n_windows
```
Convergence Monitoring
Since there is no ground truth, convergence must be monitored through the ELBO itself and its components:
| Diagnostic | What it tells you |
|---|---|
| ELBO trajectory | Should increase monotonically (on average). Stalling indicates a local optimum or a learning-rate issue. |
| Mutation log-likelihood | Should increase as the proposed genealogy better explains the data. If it plateaus early, the topology may be stuck. |
| Coalescent log-prior | Should increase as branch lengths become consistent with \(N_e(t)\). If it decreases while the likelihood increases, the branch lengths are being pulled away from the coalescent prior (tension between data and model). |
| Entropy | Should decrease during annealing (topology sharpens). If it stays high, the model is uncertain about the topology, possibly because the data is not informative enough. |
| \(N_e(t)\) trajectory | Plot \(N_e(t)\) every 100 steps. It should stabilize. Oscillations indicate the learning rate is too high. |
```python
def monitor_convergence(model, genotypes, step, history):
    """Record ELBO components and the current N_e(t) estimate."""
    with torch.no_grad():
        out = model(genotypes, temperature=0.5)
        history['elbo'].append(out['elbo'].mean().item())
        history['loglik'].append(out['log_lik'].mean().item())
        history['prior'].append(out['log_prior_coal'].mean().item())
        history['entropy'].append(out['entropy'].mean().item())
        history['Ne'].append(model.get_Ne().cpu().numpy().copy())
    return history
```
When to stop
Stop training when all of the following hold (a minimal check implementing the first two criteria follows the list):
The ELBO has not improved by more than 0.1% over the last 200 steps.
The \(N_e(t)\) trajectory has stabilized (relative change < 1% per 100 steps).
The Gumbel-softmax temperature has reached \(\tau_{\min}\).
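A sketch of such a check, with hypothetical name and tolerances; it assumes `history` is populated by `monitor_convergence` every 100 steps, and the temperature criterion is satisfied by construction once the annealing schedule finishes:

```python
def should_stop(history, elbo_tol=1e-3, ne_tol=0.01):
    """Criteria 1 and 2 above, given one history entry per 100 steps."""
    if len(history['elbo']) < 3:
        return False
    # ELBO improvement over the last 200 steps (two monitoring intervals).
    prev, last = history['elbo'][-3], history['elbo'][-1]
    elbo_flat = abs(last - prev) <= elbo_tol * abs(prev)
    # Relative change in N_e(t) over the last 100 steps.
    ne_prev, ne_last = history['Ne'][-2], history['Ne'][-1]
    ne_stable = (abs(ne_last - ne_prev) / abs(ne_prev)).max() < ne_tol
    return elbo_flat and ne_stable
```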
Typical convergence: 1,000–3,000 steps for simple demography (constant \(N_e\)), 3,000–10,000 steps for complex demography (multiple bottlenecks). With warm-starting from Mainspring, these numbers drop by a factor of 3–5.
A Complete Training Example
Putting it all together: training Escapement on a genotype matrix from a population with a bottleneck.
```python
import math

import torch

from model import Escapement

torch.manual_seed(42)

# Stand-in data: random genotypes of shape (batch, n_samples, n_sites).
# Replace with a real observed genotype matrix.
n_samples, n_sites = 20, 200
genotypes = torch.bernoulli(torch.full((1, n_samples, n_sites), 0.15))

model = Escapement(
    d_model=64, n_heads=4, n_layers=2, window=64,
    Ne=10_000, mu=1.25e-8, rho=1e-8, span=100.0)

# Piecewise-constant N_e(t) on a fixed time grid (6 bins).
time_grid = torch.tensor([0, 2000, 5000, 10000, 20000, 50000, 1e8])
model.log_Ne = torch.nn.Parameter(torch.full((6,), math.log(10000.0)))
model.n_Ne_bins = 6
model.set_time_grid(time_grid)

optimizer = torch.optim.Adam(model.get_param_groups(lr_encoder=3e-4,
                                                    lr_Ne=1e-2))

for step in range(2000):
    # Linear annealing to tau_min = 0.1 (a simple alternative to the
    # exponential schedule defined earlier).
    temp = max(0.1, 1.0 - step / 2000)
    loss = model.loss(genotypes, temperature=temp)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    optimizer.step()

results = model.infer(genotypes)
ne_trajectory = model.get_Ne()
print("Inferred N_e(t):", ne_trajectory.detach().numpy())
print("Final ELBO:", results['elbo'].mean().item())
```
No simulations were harmed in the making of this inference.