Architecture
The movement of a grande complication is built in stages. Each stage transforms the energy of the mainspring into a more refined form, until the final stage moves the hands with perfect precision.
Mainspring processes a genotype matrix through four stages, each producing an increasingly refined representation of the evolutionary history encoded in the data. The stages mirror the factorization of the full posterior:

\[
q(\mathcal{A}, N_e \mid \mathbf{D}) \;=\; q(\mathcal{T} \mid \mathbf{D}) \; q(\mathbf{t} \mid \mathcal{T}, \mathbf{D}) \; q(N_e \mid \mathbf{t}),
\]

where \(\mathcal{A} = (\mathcal{T}, \mathbf{t})\) is the ARG decomposed into topology \(\mathcal{T}\) and node times \(\mathbf{t}\), and the genomic encoder provides the shared representation from which all three factors are decoded.
┌─────────────────────────────────────────────────────────┐
│ │
│ STAGE 1: GENOMIC ENCODER │
│ ┌─────────────┐ ┌─────────────┐ ┌──────────────┐ │
│ │ Embedding │──▶│ Set Transf. │──▶│ Sliding-Win │ │
│ │ (per site) │ │ (samples) │ │ Attn (pos.) │ │
│ └─────────────┘ └─────────────┘ └──────────────┘ │
│ D ∈ {0,1}^{n×L} ──▶ Z ∈ R^{n×L×d} │
│ │
├─────────────────────────────────────────────────────────┤
│ │
│ STAGE 2: TOPOLOGY DECODER │
│ ┌──────────────┐ ┌──────────────┐ ┌─────────────┐ │
│ │ Cross-attn │──▶│ Breakpoint │──▶│ Hard attn │ │
│ │ (Li&Stephens)│ │ detector │ │ (Gumbel-SM) │ │
│ └──────────────┘ └──────────────┘ └─────────────┘ │
│ Z ──▶ T = {(parent[], left, right)} │
│ │
├─────────────────────────────────────────────────────────┤
│ │
│ STAGE 3: DATING GNN │
│ ┌──────────────┐ ┌──────────────┐ ┌─────────────┐ │
│ │ Node/edge │──▶│ UP/DOWN msg │──▶│ Gamma heads │ │
│ │ features │ │ passing (×K) │ │ (α_v, β_v) │ │
│ └──────────────┘ └──────────────┘ └─────────────┘ │
│ T + muts ──▶ t_v ~ Gamma(α_v, β_v) │
│ │
├─────────────────────────────────────────────────────────┤
│ │
│ STAGE 4: DEMOGRAPHIC DECODER │
│ ┌──────────────┐ ┌──────────────┐ ┌─────────────┐ │
│ │ Coalescence │──▶│ Normalizing │──▶│ SFS loss │ │
│ │ time hist. │ │ flow │ │ (physics) │ │
│ └──────────────┘ └──────────────┘ └─────────────┘ │
│ t_v ──▶ q(N_e(t)) │
│ │
└─────────────────────────────────────────────────────────┘
Stage 1: Genomic Encoder
The encoder transforms the raw genotype matrix \(\mathbf{D} \in \{0,1\}^{n \times L}\) into a dense representation \(\mathbf{Z} \in \mathbb{R}^{n \times L \times d}\) that captures both inter-sample relationships and spatial correlations along the genome.
Embedding Layer
Each site is embedded independently. The input at site \(\ell\) is the column vector \(\mathbf{d}_\ell = (d_{1,\ell}, \ldots, d_{n,\ell})^\top \in \{0,1\}^n\). Each sample’s binary allele is embedded into \(\mathbb{R}^d\):

\[
\mathbf{e}_{i,\ell} \;=\; \mathbf{W}_{\text{allele}}[d_{i,\ell}] \;+\; \text{RFF}(\ell) \;+\; \hat{f}_\ell\, \mathbf{W}_{\text{freq}},
\]

where \(\mathbf{W}_{\text{allele}} \in \mathbb{R}^{2 \times d}\) is an allele embedding table, \(\text{RFF}(\ell)\) is a random Fourier feature positional encoding (see Design Principles – One Per Timepiece, Principle 8), and \(\hat{f}_\ell = \frac{1}{n}\sum_i d_{i,\ell}\) is the sample allele frequency at site \(\ell\), projected through \(\mathbf{W}_{\text{freq}} \in \mathbb{R}^d\).
```python
class GenomicEmbedding(nn.Module):
    def __init__(self, d_model, rff_sigma=10.0):
        super().__init__()
        self.allele_embed = nn.Embedding(2, d_model)
        self.freq_proj = nn.Linear(1, d_model, bias=False)
        self.rff = RandomFourierPositionalEncoding(d_model, sigma=rff_sigma)

    def forward(self, D):
        # D: (B, n, L) integer genotype matrix with entries in {0, 1}
        B, n, L = D.shape
        allele = self.allele_embed(D)                             # (B, n, L, d)
        positions = torch.arange(L, device=D.device).float()
        pos_enc = self.rff(positions)                             # (L, d)
        freq = D.float().mean(dim=1, keepdim=True).unsqueeze(-1)  # (B, 1, L, 1)
        freq_enc = self.freq_proj(freq)                           # (B, 1, L, d)
        return allele + pos_enc.unsqueeze(0).unsqueeze(0) + freq_enc
```
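`RandomFourierPositionalEncoding` is referenced above but not defined in this section. A minimal sketch, assuming the standard random Fourier feature construction (frozen Gaussian frequencies whose bandwidth is set by `sigma`); the class name matches the usage above, but the internals are an assumption:

```python
import torch
import torch.nn as nn

class RandomFourierPositionalEncoding(nn.Module):
    """Map scalar genomic positions to d_model-dim random Fourier features."""

    def __init__(self, d_model, sigma=10.0):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be even (sin/cos pairs)"
        # frozen random frequencies; larger sigma -> smoother encoding
        self.register_buffer('freqs', torch.randn(d_model // 2) / sigma)

    def forward(self, positions):
        # positions: (L,) float tensor -> (L, d_model)
        angles = positions[:, None] * self.freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```

Because the frequencies are a buffer rather than a parameter, they are saved with the model but never trained, matching the usual RFF recipe.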
Set Transformer over Samples
At each site, the \(n\) sample embeddings are processed by an induced set attention block (ISAB) that is permutation-equivariant over the sample dimension. This implements Principle 2 from Design Principles – One Per Timepiece.
The ISAB uses \(m\) inducing points to reduce the \(O(n^2)\) cost of full self-attention to \(O(nm)\):

\[
\text{ISAB}(\mathbf{E}_\ell) \;=\; \text{MAB}\big(\mathbf{E}_\ell,\; \text{MAB}(\mathbf{I}, \mathbf{E}_\ell)\big),
\]

where \(\text{MAB}(\mathbf{X}, \mathbf{Y}) = \text{LayerNorm}(\mathbf{X} + \text{MultiheadAttention}(\mathbf{X}, \mathbf{Y}, \mathbf{Y}))\) is a multihead attention block, \(\mathbf{I} \in \mathbb{R}^{m \times d}\) are learned inducing points, and \(\mathbf{E}_\ell \in \mathbb{R}^{n \times d}\) are the sample embeddings at site \(\ell\).
```python
class SampleEncoder(nn.Module):
    def __init__(self, d_model, n_heads, n_inducing, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            InducedSetAttention(d_model, n_heads, n_inducing)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (B*, n, d) -- permutation-equivariant over the sample axis
        for layer in self.layers:
            x = x + layer(x)  # residual ISAB blocks
        return self.norm(x)
```
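`InducedSetAttention` is used above but not defined here. A sketch of one ISAB block following the MAB definition given earlier; the layer layout is an assumption:

```python
import torch
import torch.nn as nn

class InducedSetAttention(nn.Module):
    """One ISAB block: samples -> inducing points -> samples, O(nm) attention."""

    def __init__(self, d_model, n_heads, n_inducing):
        super().__init__()
        self.inducing = nn.Parameter(torch.randn(n_inducing, d_model))
        self.attn_in = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.attn_out = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (B, n, d); output is permutation-equivariant over the n axis
        B = x.size(0)
        I = self.inducing.unsqueeze(0).expand(B, -1, -1)  # (B, m, d)
        h, _ = self.attn_in(I, x, x)                      # MAB(I, X): (B, m, d)
        h = self.norm1(I + h)
        out, _ = self.attn_out(x, h, h)                   # MAB(X, H): (B, n, d)
        return self.norm2(x + out)
```

The first attention is permutation-invariant over samples (they appear only as keys/values), so the second attention makes the whole block permutation-equivariant, as Principle 2 requires.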
Sliding-Window Positional Attention
After the Set Transformer processes each site independently over samples, we apply sliding-window self-attention along the genomic axis (Principle 1). Each sample’s row of \(L\) site embeddings is treated as a sequence, with attention restricted to a window of \(w\) positions:
```python
class GenomicEncoder(nn.Module):
    def __init__(self, d_model, n_heads, n_layers,
                 n_inducing=32, window_size=512):
        super().__init__()
        self.embedding = GenomicEmbedding(d_model)
        self.sample_encoder = SampleEncoder(d_model, n_heads, n_inducing, 2)
        self.positional_layers = nn.ModuleList([
            SlidingWindowAttention(d_model, n_heads, window_size)
            for _ in range(n_layers)
        ])
        self.ffn_layers = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(d_model),
                nn.Linear(d_model, 4 * d_model),
                nn.GELU(),
                nn.Linear(4 * d_model, d_model),
            )
            for _ in range(n_layers)
        ])

    def forward(self, D):
        Z = self.embedding(D)                       # (B, n, L, d)
        B, n, L, d = Z.shape
        # fold the genomic axis into the batch: attend over samples per site
        Z = Z.permute(0, 2, 1, 3).reshape(B * L, n, d)
        Z = self.sample_encoder(Z)                  # Set Transformer over samples
        Z = Z.reshape(B, L, n, d).permute(0, 2, 1, 3)
        # fold the sample axis into the batch: attend along the genome per sample
        Z = Z.reshape(B * n, L, d)
        for attn, ffn in zip(self.positional_layers, self.ffn_layers):
            Z = Z + attn(Z)                         # sliding-window attention
            Z = Z + ffn(Z)
        return Z.reshape(B, n, L, d)
```
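`SlidingWindowAttention` is also undefined in this section. A mask-based sketch is below; it is functionally correct but still materializes the full \(L \times L\) mask, so its memory is \(O(L^2)\) — a production version would use a banded attention kernel. The pre-norm residual placement is an assumption:

```python
import torch
import torch.nn as nn

class SlidingWindowAttention(nn.Module):
    """Self-attention restricted to a +/- window//2 band along the genome."""

    def __init__(self, d_model, n_heads, window_size):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.window = window_size

    def forward(self, x):
        # x: (B, L, d)
        L = x.size(1)
        idx = torch.arange(L, device=x.device)
        # True = pair of positions is masked out (outside the window)
        mask = (idx[None, :] - idx[:, None]).abs() > self.window // 2
        h = self.norm(x)  # pre-norm, matching the residual use in the encoder
        out, _ = self.attn(h, h, h, attn_mask=mask)
        return out
```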
Stage 2: Topology Decoder
The topology decoder converts the encoder’s latent representation into a sequence of local tree topologies with breakpoints. This is the most structurally novel component: it implements a learned Li & Stephens model (Principle 5).
Cross-Attention as Copying
At each genomic position \(\ell\), every sample \(i\) computes attention weights over all other samples. The attention weights represent the probability that sample \(i\) is “copying from” sample \(j\) at this position – the neural analogue of the Li & Stephens transition probabilities.
The attention matrix \(\mathbf{A}^\ell \in \mathbb{R}^{n \times n}\) at each position encodes the copying relationships. In a true Li & Stephens model, each row of this matrix would be a one-hot vector (each sample copies from exactly one source). We relax this to soft attention during training and gradually harden it.
Breakpoint Detection
Tree topology changes at recombination breakpoints. The breakpoint detector is a 1D convolution along the genomic axis that identifies positions where the copying pattern changes sharply:

\[
b_\ell \;=\; \sigma\!\left( \text{Conv1D}\!\left( \big\| \mathbf{A}^{\ell} - \mathbf{A}^{\ell-1} \big\|_F \right) \right),
\]

where \(b_\ell \in [0, 1]\) is the breakpoint probability at position \(\ell\) and \(\|\cdot\|_F\) is the Frobenius norm of the change in attention pattern. In practice the detector operates on differences of adjacent encoder embeddings, a cheaper proxy for the change in attention.
```python
class BreakpointDetector(nn.Module):
    def __init__(self, d_model, kernel_size=5):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, Z_diff):
        # Z_diff: (B, L-1, d) differences between adjacent site embeddings
        x = self.proj(Z_diff).squeeze(-1).unsqueeze(1)  # (B, 1, L-1)
        return torch.sigmoid(self.conv(x)).squeeze(1)   # (B, L-1)
```
Hard Attention via Gumbel-Softmax
To produce discrete tree topologies, we need hard parent assignments. During training, we use the Gumbel-softmax trick to maintain differentiability:

\[
\mathbf{p}_i^{\ell} \;=\; \operatorname{softmax}\!\left( \frac{\mathbf{s}_i^{\ell} + \mathbf{g}}{\tau} \right), \qquad g_j \sim \text{Gumbel}(0, 1),
\]

where \(\mathbf{s}_i^{\ell}\) is the row of copying scores for sample \(i\) at position \(\ell\) and \(\tau\) is an annealed temperature.
```python
class TopologyDecoder(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # n_heads kept for interface symmetry; the copying scores below use a
        # single dot-product head
        self.breakpoint_det = BreakpointDetector(d_model)
        self.tau = 1.0  # annealed during training

    def forward(self, Z, hard=False):
        B, n, L, d = Z.shape
        # copying scores for every (sample i, position ell, candidate source j)
        scores = torch.einsum('bild,bjld->bilj', Z, Z) / (d ** 0.5)  # (B, n, L, n)
        # a sample cannot copy from itself
        self_mask = torch.eye(n, dtype=torch.bool, device=Z.device)
        scores = scores.masked_fill(self_mask.unsqueeze(1), float('-inf'))
        if hard:
            # inference: deterministic argmax assignments
            parents = F.one_hot(scores.argmax(dim=-1), n).float()
        else:
            # training: straight-through Gumbel-softmax keeps gradients flowing
            parents = F.gumbel_softmax(scores, tau=self.tau, hard=True, dim=-1)
        Z_diff = Z[:, :, 1:, :] - Z[:, :, :-1, :]
        Z_diff_pooled = Z_diff.mean(dim=1)  # pool over samples: (B, L-1, d)
        breakpoints = self.breakpoint_det(Z_diff_pooled)
        return parents, breakpoints
```
At inference time (\(\tau \to 0\) or hard=True), the Gumbel-softmax collapses
to argmax, producing deterministic parent assignments. The topology is then assembled
into contiguous tree segments separated by breakpoints.
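The straight-through behaviour can be checked directly: the forward output is exactly one-hot, while gradients flow to the logits through the underlying soft sample.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
y = F.gumbel_softmax(logits, tau=0.5, hard=True)

# forward pass is a discrete one-hot assignment
assert ((y == 0) | (y == 1)).all() and y.sum() == 1

# backward pass uses the soft relaxation, so the logits receive gradients
y[0, 0].backward()
assert logits.grad is not None
```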
From attention to tree topology
The attention matrix \(\mathbf{A}^\ell\) does not directly encode a valid tree. To obtain a tree, we apply a greedy bottom-up construction: starting from the leaves, we iteratively merge the pair with the highest mutual attention weight, creating an internal node. The process continues until all samples are connected. This is reminiscent of hierarchical clustering, but the similarity metric is learned end-to-end.
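A minimal sketch of this greedy construction, in NumPy. The text specifies only the highest-mutual-attention merge; the average-linkage update for merged clusters and the node-numbering convention are assumptions:

```python
import numpy as np

def greedy_tree_from_attention(A):
    """Build a parent array for 2n-1 nodes from an n x n attention matrix.

    Leaves are nodes 0..n-1; internal nodes are numbered n..2n-2 in merge
    order; the root's parent is -1.
    """
    n = A.shape[0]
    S = ((A + A.T) / 2.0).astype(float).copy()  # symmetrize attention
    np.fill_diagonal(S, -np.inf)
    ids = list(range(n))                        # cluster id per row of S
    parent = -np.ones(2 * n - 1, dtype=int)
    for new in range(n, 2 * n - 1):
        # merge the active pair with the highest mutual affinity
        i, j = np.unravel_index(np.argmax(S), S.shape)
        if i > j:
            i, j = j, i
        parent[ids[i]], parent[ids[j]] = new, new
        merged = (S[i] + S[j]) / 2.0            # average-linkage affinity
        S[i, :], S[:, i] = merged, merged
        S = np.delete(np.delete(S, j, axis=0), j, axis=1)
        ids[i] = new
        del ids[j]
    return parent
```

On a matrix with two clear blocks (samples 0/1 similar, samples 2/3 similar), the first two merges pair up the blocks and the final merge creates the root.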
Stage 3: Dating GNN
Given the predicted topology, the dating GNN assigns times to internal nodes using learned message passing. This is the neural analogue of tsdate’s inside-outside algorithm (Principle 4), with gamma output heads (Principle 7) and per-segment sufficient statistics as input features (Principle 9).
Node and Edge Features
Each node \(v\) in a local tree is initialized with a feature vector:

\[
\mathbf{h}_v^{(0)} \;=\; \big[\, \mathbf{z}_{v,\ell} \;;\; \log(1 + \hat{t}_v) \,\big],
\]

where \(\mathbf{z}_{v,\ell}\) is the encoder output for leaf \(v\) at the midpoint of the tree’s genomic span, and \(\hat{t}_v\) is an initial time estimate from the Threads-style natural estimator.
Each edge \((u, v)\) carries features:

\[
\mathbf{f}_{uv} \;=\; \big[\, m_{uv},\; s_{uv},\; \hat{t}_{uv},\; n_{uv},\; \log(1 + s_{uv}),\; \log(1 + \hat{t}_{uv}) \,\big],
\]

where \(m_{uv}\) is the mutation count, \(s_{uv}\) the genomic span, \(\hat{t}_{uv}\) the natural time estimate, and \(n_{uv}\) the number of descendant leaves; the log-transformed copies stabilize the scale of the span and time features, giving six features in total to match the edge encoder below.
UP/DOWN Message Passing
The GNN alternates between UP passes (children to parent, analogous to tsdate’s inside pass) and DOWN passes (parent to children, analogous to the outside pass):

\[
\mathbf{m}^{\uparrow}_{u \to p(u)} = \text{MLP}_{\text{up}}\big(\big[\mathbf{h}^{(k)}_{u} ; \mathbf{f}_{u\,p(u)}\big]\big), \qquad
\mathbf{m}^{\downarrow}_{p(u) \to u} = \text{MLP}_{\text{down}}\big(\big[\mathbf{h}^{(k)}_{p(u)} ; \mathbf{f}_{u\,p(u)}\big]\big),
\]
\[
\mathbf{h}^{(k+1)}_{v} \;=\; \text{GRU}\Big( \textstyle\sum_{\text{messages into } v} \mathbf{m},\; \mathbf{h}^{(k)}_{v} \Big).
\]

The GRU (gated recurrent unit) update prevents the node features from drifting too far from their initial values while allowing iterative refinement. After \(K\) rounds (typically \(K = 6\)), the node features are decoded into gamma parameters.
```python
class DatingGNN(nn.Module):
    def __init__(self, d_model, n_rounds=6):
        super().__init__()
        self.n_rounds = n_rounds
        self.node_init = nn.Linear(d_model, d_model)
        self.edge_encoder = nn.Linear(6, d_model)
        self.up_msg = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU(),
                                    nn.Linear(d_model, d_model))
        self.down_msg = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU(),
                                      nn.Linear(d_model, d_model))
        self.gru = nn.GRUCell(d_model, d_model)
        self.alpha_head = nn.Linear(d_model, 1)
        self.beta_head = nn.Linear(d_model, 1)

    def forward(self, node_features, edge_features, parent_array):
        h = self.node_init(node_features)     # (V, d)
        f = self.edge_encoder(edge_features)  # (V, d): edge (v, parent[v]) indexed by v
        for _ in range(self.n_rounds):
            msg = torch.zeros_like(h)
            for child, parent in enumerate(parent_array):
                if parent < 0:                # root has no parent
                    continue
                up = self.up_msg(torch.cat([h[child], f[child]], dim=-1))
                msg[parent] += up
                down = self.down_msg(torch.cat([h[parent], f[child]], dim=-1))
                msg[child] += down
            h = self.gru(msg, h)
        alpha = F.softplus(self.alpha_head(h)) + 1.0  # alpha > 1
        beta = torch.exp(self.beta_head(h))           # beta > 0
        return alpha, beta
```
Gamma Output Heads
The final node features \(\mathbf{h}_v^{(K)}\) are decoded into gamma parameters \((\alpha_v, \beta_v)\):

\[
\alpha_v \;=\; 1 + \operatorname{softplus}\big(\mathbf{w}_\alpha^\top \mathbf{h}_v^{(K)}\big), \qquad
\beta_v \;=\; \exp\big(\mathbf{w}_\beta^\top \mathbf{h}_v^{(K)}\big).
\]

The predicted time distribution for node \(v\) is then \(t_v \sim \text{Gamma}(\alpha_v, \beta_v)\), with mean \(\mathbb{E}[t_v] = \alpha_v / \beta_v\) and variance \(\text{Var}(t_v) = \alpha_v / \beta_v^2\).
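Since `torch.distributions.Gamma` uses the same shape/rate parameterization, posterior summaries are cheap; credible intervals come from sampling, because PyTorch's Gamma exposes no analytic inverse CDF. The particular values below are illustrative only:

```python
import torch
from torch.distributions import Gamma

torch.manual_seed(0)
alpha, beta = torch.tensor(4.0), torch.tensor(0.02)
post = Gamma(concentration=alpha, rate=beta)

mean = post.mean      # alpha / beta = 200 generations
var = post.variance   # alpha / beta^2 = 10,000

# 90% credible interval via Monte Carlo (no closed-form icdf in torch)
samples = post.sample((100_000,))
lo, hi = torch.quantile(samples, torch.tensor([0.05, 0.95]))
```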
Cross-Tree Consistency
Adjacent local trees share most of their topology and node times. To enforce consistency, we add a cross-tree regularizer that penalizes large changes in predicted node times between adjacent trees:

\[
\mathcal{L}_{\text{consist}} \;=\; \sum_{\ell} \sum_{v \,\in\, \mathcal{V}_\ell \cap \mathcal{V}_{\ell+1}} \Big( \log \mathbb{E}\big[t_v^{(\ell)}\big] - \log \mathbb{E}\big[t_v^{(\ell+1)}\big] \Big)^2,
\]

where \(\mathcal{V}_\ell\) is the set of nodes in local tree \(\ell\) and the intersection identifies nodes shared between adjacent trees.
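For one pair of adjacent trees, the regularizer reduces to a squared difference of log-times over the shared nodes; a hypothetical helper (the matching of shared nodes is assumed to be done by the caller):

```python
import torch

def cross_tree_consistency(t_left, t_right):
    """Cross-tree penalty for one pair of adjacent local trees.

    t_left, t_right: (n_shared,) predicted times for the nodes shared by
    trees ell and ell+1.
    """
    return ((t_left.log() - t_right.log()) ** 2).sum()
```

Working in log-time makes the penalty scale-free: a 10% disagreement costs the same for recent and ancient nodes.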
Stage 4: Demographic Decoder
The final stage maps the inferred coalescence-time distribution to a posterior over \(N_e(t)\) trajectories. This is where the ARG’s status as a sufficient statistic (Principle 3) pays off: the demographic decoder operates entirely on the predicted coalescence times, not on the raw genotype matrix.
Coalescence-Time Histogram
From the dated ARG, we extract a histogram of coalescence times. For each internal node \(v\) at time \(t_v\) with \(k_v\) children, we record \(k_v - 1\) coalescence events at time \(t_v\). Binning these into \(B\) logarithmically spaced time intervals gives a vector \(\mathbf{c} \in \mathbb{R}^B\):

\[
c_b \;=\; \sum_{v} (k_v - 1)\, \mathbb{1}\big[ t_v \in [\tau_b, \tau_{b+1}) \big],
\]

where \(\tau_1 < \cdots < \tau_{B+1}\) are the bin edges.
This histogram, together with the predicted SFS from the ARG, forms the input to the normalizing flow.
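The binning step can be sketched as follows; the helper name is hypothetical, and Mainspring's `build_coalescence_histogram` is assumed to implement a differentiable (soft-binned) equivalent:

```python
import torch

def coalescence_histogram(times, n_children, edges):
    """Bin (k_v - 1) coalescence events per internal node into time intervals.

    times: (V,) node times; n_children: (V,) child counts k_v;
    edges: (B+1,) ascending bin edges (log-spaced in the text).
    """
    events = (n_children - 1).float()  # k_v - 1 events at time t_v
    idx = (torch.bucketize(times, edges) - 1).clamp(0, len(edges) - 2)
    hist = torch.zeros(len(edges) - 1)
    hist.index_add_(0, idx, events)
    return hist
```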
Normalizing Flow
The demographic decoder is a conditional normalizing flow that transforms a simple base distribution (standard normal) into a posterior over \(N_e(t)\) functions, conditioned on the coalescence-time histogram and SFS:

\[
\mathbf{N}_e \;=\; g_\phi(\mathbf{z};\, \mathbf{c}, \boldsymbol{\xi}), \qquad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_B),
\]

where \(g_\phi\) is an invertible neural network parameterized by \(\phi\), \(\mathbf{c}\) is the coalescence-time histogram, and \(\boldsymbol{\xi}\) is the predicted SFS. The \(N_e(t)\) trajectory is represented as a vector of \(B\) values on the same log-spaced time grid, with linear interpolation between grid points.
```python
class DemographicDecoder(nn.Module):
    def __init__(self, n_time_bins, n_flow_layers, d_cond):
        super().__init__()
        self.condition_net = nn.Sequential(
            nn.Linear(2 * n_time_bins, d_cond),
            nn.ReLU(),
            nn.Linear(d_cond, d_cond),
        )
        self.flow_layers = nn.ModuleList([
            AffineCouplingLayer(n_time_bins, d_cond)
            for _ in range(n_flow_layers)
        ])

    def forward(self, coal_histogram, sfs, n_samples=1):
        cond = self.condition_net(torch.cat([coal_histogram, sfs], dim=-1))
        z = torch.randn(n_samples, coal_histogram.size(-1),
                        device=coal_histogram.device)
        log_det = 0.0
        for layer in self.flow_layers:
            z, ld = layer(z, cond)
            log_det = log_det + ld
        ne_trajectory = F.softplus(z)  # N_e(t) must be positive
        return ne_trajectory, log_det
```
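`AffineCouplingLayer` is not defined in this section. A minimal conditional coupling layer in the RealNVP style is sketched below; the hidden width, the tanh-bounded scale, and the conditioning scheme are assumptions, and a real flow would also permute dimensions between layers so every coordinate gets transformed:

```python
import torch
import torch.nn as nn

class AffineCouplingLayer(nn.Module):
    """Transform one half of z conditioned on the other half and the context."""

    def __init__(self, dim, d_cond, hidden=128):
        super().__init__()
        self.half = dim // 2
        self.net = nn.Sequential(
            nn.Linear(dim - self.half + d_cond, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * self.half),
        )

    def forward(self, z, cond):
        z1, z2 = z[:, :self.half], z[:, self.half:]
        if cond.dim() == 1:
            cond = cond.unsqueeze(0)
        cond = cond.expand(z.size(0), -1)
        log_s, t = self.net(torch.cat([z2, cond], dim=-1)).chunk(2, dim=-1)
        log_s = torch.tanh(log_s)       # bounded scales for stability
        z1 = z1 * torch.exp(log_s) + t  # affine map; z2 passes through unchanged
        return torch.cat([z1, z2], dim=-1), log_s.sum(dim=-1)
```

Because `z2` is untouched, the Jacobian is triangular and its log-determinant is just the sum of the `log_s` entries, which is what `DemographicDecoder` accumulates.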
SFS Auxiliary Loss
The predicted ARG implies a predicted SFS, which must be consistent with the observed SFS. This consistency check is the physics-informed regularizer from Principle 6:

\[
\mathcal{L}_{\text{SFS}} \;=\; \sum_{k=1}^{n-1} \big( \xi_k^{\text{pred}} - \xi_k^{\text{obs}} \big)^2,
\]

where the predicted SFS \(\boldsymbol{\xi}^{\text{pred}}\) is computed differentiably from the ARG branch lengths and descendant counts, and the observed SFS \(\boldsymbol{\xi}^{\text{obs}}\) is computed directly from the genotype matrix.
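The observed side of this check is a direct count over the genotype matrix; a sketch with hypothetical helper names (the differentiable predicted-SFS computation from branch lengths is not shown):

```python
import torch

def observed_sfs(D):
    """Observed SFS from an (n, L) binary genotype matrix: number of sites
    with derived-allele count k, for k = 1..n-1 (monomorphic sites dropped)."""
    n = D.shape[0]
    counts = D.sum(dim=0)  # derived-allele count per site
    return torch.bincount(counts, minlength=n + 1)[1:n].float()

def sfs_loss(pred_sfs, obs_sfs):
    """Squared-error consistency between predicted and observed spectra."""
    return ((pred_sfs - obs_sfs) ** 2).sum()
```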
Why the SFS loss matters
Without the SFS loss, the network can produce ARGs that correctly reconstruct the topology and approximate the node times but systematically miscount the number of lineages at each frequency class. The SFS loss acts as a global consistency check: it catches errors in the predicted ARG that local losses (topology accuracy, node time likelihood) might miss. This is analogous to how a watchmaker, after assembling each gear individually, checks that the overall gear train produces the correct time – a global test that catches assembly errors invisible at the component level.
Putting It All Together
The complete Mainspring model chains all four stages:
```python
class Mainspring(nn.Module):
    def __init__(self, d_model=256, n_heads=8, n_encoder_layers=6,
                 n_gnn_rounds=6, n_time_bins=64, n_flow_layers=8):
        super().__init__()
        self.encoder = GenomicEncoder(d_model, n_heads, n_encoder_layers)
        self.topology_decoder = TopologyDecoder(d_model, n_heads)
        self.dating_gnn = DatingGNN(d_model, n_gnn_rounds)
        self.demographic_decoder = DemographicDecoder(
            n_time_bins, n_flow_layers, d_cond=128
        )

    def forward(self, D, hard=False):
        Z = self.encoder(D)
        parents, breakpoints = self.topology_decoder(Z, hard=hard)
        topology = self.extract_trees(parents, breakpoints)
        node_feats, edge_feats, parent_arrays = self.build_gnn_input(Z, topology)
        alphas, betas = self.dating_gnn(node_feats, edge_feats, parent_arrays)
        times = alphas / betas  # point estimate = gamma mean
        coal_hist = self.build_coalescence_histogram(times, topology)
        pred_sfs = self.compute_sfs(times, topology)
        ne_posterior, log_det = self.demographic_decoder(coal_hist, pred_sfs)
        return {
            'topology': topology,
            'breakpoints': breakpoints,
            'alpha': alphas,
            'beta': betas,
            'times': times,
            'ne_posterior': ne_posterior,
            'flow_log_det': log_det,
            'predicted_sfs': pred_sfs,
        }
```
Computational Complexity
| Stage | Complexity | Bottleneck |
|---|---|---|
| Genomic Encoder | \(O(n m L d + n L w d)\) | ISAB over samples (\(nm\) per site) + sliding-window attention (\(w\) per position) |
| Topology Decoder | \(O(n^2 L d)\) | Cross-attention at each site |
| Dating GNN | \(O(K n L d)\) | \(K\) message-passing rounds on trees with \(O(n)\) nodes |
| Demographic Decoder | \(O(B^2)\) | Normalizing flow on \(B\) time bins (negligible) |

Total: \(O(n^2 L d)\), dominated by the topology decoder’s cross-attention — linear in sequence length and quadratic in sample count. For typical applications (\(n \leq 100\), \(L \sim 10^4\)), this is feasible on a single GPU.