Overview of Balance Wheel
The balance wheel doesn’t track every molecule of air or every vibration of the case. It feels the aggregate force of the balance spring and responds with a precise oscillation. The SFS is the aggregate pressure of evolution on a genome: it doesn’t track every haplotype, but it captures the net effect of demographic history on allele frequencies.
Mainspring inverts simulations to infer full ARGs. Escapement uses the coalescent likelihood on raw genotypes to infer genealogies and \(N_e(t)\). Both operate on sequence-level data – they see every site, every haplotype.
Balance Wheel takes a different path. It operates on the Site Frequency Spectrum (SFS), the same summary statistic that dadi and moments use. The SFS compresses the genome into a histogram of allele frequencies – a massive dimensionality reduction (millions of sites \(\to\) \(n - 1\) counts for \(n\) samples, or \(n_1 \times n_2\) entries for two populations). This compression discards spatial information (LD, haplotype structure) but retains everything needed for demographic inference under the Poisson Random Field model.
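This compression is easy to see in code. A minimal sketch (NumPy; the helper name `sfs_from_genotypes` is illustrative, not part of any library) of how a genotype matrix collapses into the one-population SFS:

```python
import numpy as np

def sfs_from_genotypes(G):
    """Collapse a genotype matrix (sites x haplotypes, 0/1 derived alleles)
    into the unfolded SFS: counts of sites at each derived frequency 1..n-1."""
    n = G.shape[1]
    counts = G.sum(axis=1)                      # derived-allele count per site
    sfs = np.bincount(counts, minlength=n + 1)  # histogram over 0..n
    return sfs[1:n]                             # drop monomorphic classes 0 and n

# toy example: 5 segregating sites, 4 haplotypes
G = np.array([[1, 0, 0, 0],
              [1, 1, 0, 0],
              [1, 1, 1, 0],
              [1, 0, 0, 0],
              [0, 1, 0, 0]])
print(sfs_from_genotypes(G))  # [3 1 1]: three singletons, one doubleton, one tripleton
```

Millions of rows in `G` become just `n - 1` numbers, which is the entire input Balance Wheel ever sees.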
The question is: can we replace dadi’s PDE solver and moments’ ODE integrator with a neural network, while keeping the same Poisson likelihood and the same demographic parameters?
Yes. And the result is faster, differentiable end-to-end, and handles model classes that are intractable for dadi/moments – continuous \(N_e(t)\), high-dimensional joint SFS, complex multi-population topologies.
The SFS as a Sufficient Statistic
Under the Poisson Random Field (PRF) model, the SFS is a sufficient statistic for demographic parameters \(\Theta\). The PRF model assumes:
Infinitely many sites: each SNP arises on an independent genealogy.
Low mutation rate: at most one mutation per site.
Free recombination between sites (no LD).
Under these assumptions, the number of SNPs in each frequency class is an independent Poisson random variable:
\[
D_j \sim \text{Poisson}\big(M_j(\Theta)\big), \qquad j = 1, \dots, n - 1,
\]
where \(D_j\) is the observed count of SNPs with derived allele frequency \(j/n\) and \(M_j(\Theta)\) is the expected count under demographic model \(\Theta\). The log-likelihood factorizes:
\[
\ell(\Theta) = \sum_{j=1}^{n-1} \Big( D_j \log M_j(\Theta) - M_j(\Theta) \Big) - \sum_{j=1}^{n-1} \log D_j!
\]
Since the last term is a constant, maximizing the log-likelihood requires only the mapping \(\Theta \to \mathbf{M}(\Theta)\) – the expected SFS as a function of demographic parameters. This is precisely what dadi and moments compute. And it is precisely what Balance Wheel learns to approximate.
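The resulting objective is a few lines of code. A minimal sketch in NumPy, dropping the constant \(\log D_j!\) term since it does not affect optimization (the function name `poisson_loglik` is illustrative):

```python
import numpy as np

def poisson_loglik(D, M):
    """Poisson log-likelihood of observed SFS counts D given expected SFS M,
    omitting the constant -sum(log D_j!) term that is independent of Theta."""
    D, M = np.asarray(D, float), np.asarray(M, float)
    return float(np.sum(D * np.log(M) - M))

# toy check: observed counts vs. two candidate expected spectra
D  = np.array([30.0, 14.0, 9.0])
M1 = np.array([30.0, 15.0, 10.0])   # close to D  -> higher likelihood
M2 = np.array([10.0, 10.0, 10.0])   # poor fit    -> lower likelihood
assert poisson_loglik(D, M1) > poisson_loglik(D, M2)
```

Everything that differs between dadi, moments, and Balance Wheel lives inside `M`; this likelihood is shared by all three.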
Why the SFS is enough
The SFS discards haplotype structure, LD, and all spatial information along the genome. Yet for demographic inference under the PRF model, it loses nothing. The intuition: if sites are independent (free recombination, low mutation), then the frequency spectrum captures all the information about \(N_e(t)\), split times, and migration rates. The SFS is a low-dimensional projection of a high-dimensional dataset, and under the PRF model, it is the optimal such projection.
This is not true for all questions. The SFS cannot distinguish between selective sweeps and bottlenecks (both shift the spectrum toward rare alleles). It cannot detect gene conversion (which affects LD patterns). It cannot resolve fine-scale recombination rate variation. For these questions, you need sequence-level methods like Mainspring or Escapement.
What dadi and moments Actually Compute
Both methods solve the same inference problem: given \(\Theta\), compute \(\mathbf{M}(\Theta)\), then maximize the Poisson log-likelihood \(\ell(\Theta)\). They differ only in how they compute \(\mathbf{M}(\Theta)\).
dadi solves the Wright-Fisher diffusion PDE. For one population with variable size \(\nu(t) = N_e(t) / N_{\text{ref}}\):
\[
\frac{\partial \phi}{\partial t} = \frac{1}{2} \frac{\partial^2}{\partial x^2} \left[ \frac{x(1 - x)}{2\,\nu(t)}\, \phi(x, t) \right],
\]
where \(\phi(x, t)\) is the density of alleles at frequency \(x\) at time \(t\). The expected SFS entry is obtained by integrating \(\phi\) against binomial sampling weights. This requires discretizing the frequency axis on a grid of \(G\) points and the time axis into piecewise-constant epochs. Cost: \(O(G^k)\) for \(k\) populations.
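The projection from density to spectrum can be sketched numerically. Assuming a frequency grid and trapezoid integration (the helper `sfs_from_density` is hypothetical, not dadi's API), the standard-neutral density \(\phi(x) = \theta/x\) recovers the classic \(M_j = \theta/j\) spectrum:

```python
import numpy as np
from math import comb

def trapezoid(y, x):
    """Plain trapezoid rule over grid x."""
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0)

def sfs_from_density(phi, x, n):
    """Project a density phi on grid x onto the expected SFS for sample size n:
    M_j = integral of phi(x) * C(n, j) * x^j * (1 - x)^(n - j) dx, j = 1..n-1."""
    return np.array([
        trapezoid(phi * comb(n, j) * x**j * (1 - x) ** (n - j), x)
        for j in range(1, n)
    ])

# sanity check against the known closed form M_j = theta / j
theta, n = 1.0, 10
x = np.linspace(1e-6, 1 - 1e-6, 20001)
M = sfs_from_density(theta / x, x, n)
assert np.allclose(M, theta / np.arange(1, n), rtol=1e-2)
```

The check works because \(\int_0^1 (\theta/x)\, \binom{n}{j} x^j (1-x)^{n-j}\, dx = \theta/j\), the equilibrium spectrum of a constant-size population.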
moments derives and integrates ODEs for the SFS entries directly:
\[
\frac{d\mathbf{M}}{dt} = \frac{1}{2\,\nu(t)}\, \mathcal{D}\, \mathbf{M} + \theta\, \mathcal{U},
\]
where the drift and mutation operators are sparse linear transformations on the SFS vector. No frequency grid is needed. Cost: \(O(n^k)\) per ODE step for \(k\) populations with sample size \(n\).
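The integration pattern, though not the real operators, can be sketched in a few lines. The matrix `D` below is a made-up placeholder, not the actual moments drift operator; the point is only the shape of the Euler loop:

```python
import numpy as np

def integrate_sfs(M0, D, u, nu, theta, T, steps=1000):
    """Euler integration of the schematic system
    dM/dt = D @ M / (2 * nu(t)) + theta * u.
    D and u here are toy stand-ins, NOT the real moments operators."""
    M, dt = np.asarray(M0, float).copy(), T / steps
    for k in range(steps):
        t = k * dt
        M += dt * (D @ M / (2.0 * nu(t)) + theta * u)
    return M

n = 4                                  # sample size -> n - 1 = 3 SFS entries
D = -np.eye(n - 1)                     # placeholder linear operator
u = np.array([1.0, 0.0, 0.0])          # new mutations enter as singletons
M = integrate_sfs(np.zeros(n - 1), D, u, nu=lambda t: 1.0, theta=2.0, T=5.0)
print(M)  # the singleton class relaxes toward its toy equilibrium of 4
```

In real moments the drift operator couples adjacent frequency classes and the system is stiff; this toy version only shows why the cost per step scales with the number of SFS entries, \(O(n^k)\).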
Both methods produce the same \(\mathbf{M}(\Theta)\) (up to numerical precision) and use the same Poisson likelihood. The bottleneck is the forward computation: for \(k \geq 3\) populations and large sample sizes, both become prohibitively expensive. A full treatment of these computations is in What dadi and moments Actually Compute.
Balance Wheel’s Approach
Balance Wheel replaces the PDE/ODE solver with a neural function approximator. The idea is simple: learn the mapping
\[
\Theta \;\longmapsto\; \mathbf{z}_\Theta \;\longmapsto\; \hat{\mathbf{M}}(\Theta) \approx \mathbf{M}(\Theta)
\]
directly from examples.
The network is trained to reproduce the exact SFS that moments or dadi would compute, using those tools as a teacher. Once trained, the network produces the expected SFS in a single forward pass – no PDE to solve, no ODE to integrate, no grid to refine. The Poisson likelihood is then evaluated on the neural SFS prediction, exactly as dadi and moments would.
The three modules:
Demography Encoder (the hairspring): encodes population size histories, split times, and migration rates into a dense vector \(\mathbf{z}_\Theta\).
SFS Predictor (the balance wheel): a neural network that maps \(\mathbf{z}_\Theta \to \hat{\mathbf{M}}(\Theta)\). Replaces the PDE/ODE solver.
Poisson Likelihood (the impulse pin): the exact same Poisson log-likelihood that dadi and moments optimize. No approximation, no neural network. Gradients flow through the SFS Predictor via backpropagation.
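A minimal sketch of what the SFS Predictor module might look like in PyTorch. Layer sizes, class name, and the softplus output head are illustrative assumptions, not Balance Wheel's actual architecture:

```python
import torch
import torch.nn as nn

class SFSPredictor(nn.Module):
    """Toy MLP mapping demographic parameters Theta to a positive
    expected SFS M_hat(Theta). Sizes are illustrative only."""
    def __init__(self, n_params=3, n=20, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_params, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n - 1),
        )

    def forward(self, theta):
        # softplus keeps every SFS entry strictly positive,
        # as the Poisson likelihood requires
        return torch.nn.functional.softplus(self.net(theta))

model = SFSPredictor()
M_hat = model(torch.randn(8, 3))   # batch of 8 parameter vectors
print(M_hat.shape)                 # torch.Size([8, 19])
assert (M_hat > 0).all()
```

Note how cheap the forward pass is relative to a PDE or ODE solve; this is the entire source of the speedups claimed below.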
Three Reasons Balance Wheel Matters
Why not just use moments directly? Three reasons.
1. Speed for complex models. For \(k\) populations with large sample sizes, moments costs \(O(n^k)\) per ODE step, and each likelihood evaluation requires integrating the full ODE system. For three or more populations with \(n > 50\), a single likelihood evaluation takes seconds. The neural network is \(O(1)\) – a single forward pass through a small MLP, requiring ~0.1 ms regardless of model complexity. This enables algorithms that require thousands of evaluations: Bayesian posterior sampling, bootstrap confidence intervals, exhaustive model comparison.
| Method | 1-pop (n=20) | 2-pop (n=20) | 3-pop (n=20) |
|---|---|---|---|
| dadi | ~100 ms | ~10 s | Impractical |
| moments | ~10 ms | ~500 ms | ~60 s |
| Balance Wheel | ~0.1 ms | ~0.1 ms | ~0.1 ms |
2. Continuous demography. dadi and moments require piecewise-constant \(N_e(t)\). Real demography is continuous. Balance Wheel can parameterize \(N_e(t)\) as a neural spline or Gaussian process and still compute the SFS, because the SFS Predictor learns a smooth mapping that generalizes beyond the piecewise-constant training examples. This eliminates the need to choose the number of epochs – a model-selection problem that plagues classical approaches.
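One way to sketch such a continuous parameterization, assuming a tiny MLP in place of a spline or Gaussian process (the class name `SizeHistory` and the anchor size are hypothetical):

```python
import torch
import torch.nn as nn

class SizeHistory(nn.Module):
    """Toy continuous size history N_e(t): a small MLP with a softplus
    head so N_e(t) is positive at every time. Illustrative only."""
    def __init__(self, hidden=32, n_anchor=500.0):
        super().__init__()
        self.n_anchor = n_anchor  # baseline effective size (hypothetical)
        self.net = nn.Sequential(
            nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1)
        )

    def forward(self, t):  # t: times, shape (T,)
        out = self.net(t.unsqueeze(-1)).squeeze(-1)
        return self.n_anchor * torch.nn.functional.softplus(out)

history = SizeHistory()
t = torch.linspace(0.0, 2.0, 100, requires_grad=True)
Ne = history(t)
assert (Ne > 0).all()       # valid population sizes everywhere
Ne.sum().backward()         # and differentiable in t ...
assert t.grad is not None   # ... so gradients flow end-to-end
```

No epoch count to choose: the number of "epochs" is effectively infinite, and smoothness comes from the network itself.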
3. Gradient quality. dadi computes gradients via finite differences (perturb each parameter, re-solve the PDE). moments uses automatic differentiation through the ODE solver, which is better but can be numerically unstable for stiff systems. Balance Wheel gives exact gradients via backpropagation through a stable neural network – no numerical issues, no step-size sensitivity, no wasted function evaluations.
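The backpropagation path can be sketched with an untrained stand-in predictor (random weights; `predictor` below is not the trained Balance Wheel network, just a shape-compatible MLP):

```python
import torch

# random-weight stand-in for the trained SFS Predictor (3 params -> 19 SFS entries)
predictor = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.ReLU(),
                                torch.nn.Linear(64, 19), torch.nn.Softplus())

theta = torch.tensor([1.0, 0.5, 0.1], requires_grad=True)  # demographic parameters
D = torch.full((19,), 10.0)                                # observed SFS counts (toy)

M = predictor(theta) + 1e-8              # neural expected SFS, kept positive
loglik = (D * torch.log(M) - M).sum()    # Poisson log-likelihood (constant dropped)
loglik.backward()                        # one backward pass, no finite differences
print(theta.grad)                        # exact gradient w.r.t. all parameters at once
```

One forward and one backward pass yield the full gradient, whereas finite differences would need an extra likelihood evaluation per parameter.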
Comparison Table
| Feature | dadi | moments | Balance Wheel |
|---|---|---|---|
| SFS computation | PDE on frequency grid | ODE for SFS entries | Neural forward pass |
| Speed per SFS eval | ~100 ms | ~10 ms | ~0.1 ms |
| Gradient method | Finite differences | AD through ODE | Backprop through MLP |
| Continuous \(N_e(t)\) | No (piecewise-constant) | No (piecewise-constant) | Yes (neural spline) |
| Multi-pop scaling | \(O(G^k)\), impractical for \(k > 3\) | \(O(n^k)\) | \(O(1)\) forward pass |
| Uncertainty | Profile likelihood | Profile likelihood | Full posterior (HMC/NUTS) |
| Model comparison | AIC from point estimate | AIC from point estimate | Marginal likelihood via importance sampling |
| Training cost | None (classical solver) | None (classical solver) | One-time (moments evals) |
| Accuracy guarantee | Numerical precision of PDE | Numerical precision of ODE | Teacher quality ceiling |
Reading the table
No method dominates all rows. dadi and moments are classical, trusted, and require no training. For a single model with two populations analyzed once, running moments directly is simpler and more trustworthy. Balance Wheel wins when you need many likelihood evaluations – posterior sampling, model comparison grids, bootstrap resampling, or when you need to handle \(k \geq 3\) populations where the classical solvers become impractical.
Honest Limitations
Balance Wheel is not a universal replacement for dadi and moments. It has four fundamental limitations.
1. It inherits the SFS’s limitations. The SFS discards linkage disequilibrium, haplotype structure, and all spatial information along the genome. Balance Wheel cannot detect recombination rate variation, recent selective sweeps (which affect LD more than the SFS), or complex admixture patterns that leave signatures in haplotype sharing but not allele frequencies. For these questions, use Escapement or Mainspring.
2. Teacher quality ceiling. Balance Wheel can only be as accurate as the moments/dadi computation it was trained on. If moments has numerical issues for extreme parameter values (very large populations, very recent events), the neural network will inherit those issues or extrapolate poorly. The student cannot surpass the teacher.
3. Generalization to unseen topologies. The multi-population version must be trained on a distribution of population tree topologies. If the true topology is outside this distribution (e.g., a 6-population model when training only covered up to 4), the network may fail silently. Classical methods handle any topology that can be specified in their framework.
4. Single-dataset analysis may prefer moments directly. For a single dataset analyzed once with a well-specified two-population model, running moments is simpler, more transparent, and gives the exact answer. Balance Wheel’s advantages emerge only when you need thousands of likelihood evaluations or when the model complexity exceeds what the classical solvers can handle.
The Road Ahead
The remaining chapters of this Complication build Balance Wheel from first principles:
What dadi and moments Actually Compute – A deep dive into the PDE solver (dadi), the ODE system (moments), and the coalescent computation (momi2). Understanding what we are replacing.
Architecture – The three modules in detail: Demography Encoder, SFS Predictor, and Poisson Likelihood. PyTorch code for each.
Teacher-Student Training – How to train the SFS Predictor using moments as a teacher. Why this is not simulation-based inference. Validation strategies.
Posterior Inference via HMC – Using the fast differentiable likelihood for Bayesian posterior sampling. Credible intervals, posterior predictive checks, model comparison.
Handling Multiple Populations – The GNN encoder for population trees, the multi-dimensional SFS predictor, and why Balance Wheel scales where dadi/moments cannot.
Comparison and Limitations – Systematic comparison across all three Complications, connections to every Timepiece, and a decision tree for choosing the right tool.
Each chapter follows the book’s rhythm: motivation, math, code, verification. The math here is the Poisson log-likelihood – the same likelihood that dadi and moments optimize, now evaluated 1000× faster through a neural network. The verification is comparing neural SFS predictions against exact computations: does the student match the teacher, and does the posterior make sense?
```python
import torch
from balance_wheel import BalanceWheel

model = BalanceWheel(d_model=128, n_heads=4, n_layers=2,
                     max_epochs=10, max_n=100)

# Phase 1: Train on moments evaluations (one-time cost)
model.train_on_teacher(n_examples=100_000, teacher="moments")

# Phase 2: Inference on real data
observed_sfs = torch.tensor([3012, 1580, 1102, 845, ...])  # from VCF
result = model.fit(observed_sfs, n=20, theta_L=5000.0, method="HMC")

print(result.posterior_median)      # demographic parameters
print(result.credible_intervals)    # 95% CI on all parameters
print(result.marginal_likelihood)   # for model comparison
```
No PDE to solve. No ODE to integrate. Just a forward pass, a Poisson likelihood, and gradient descent – or, for the full posterior, Hamiltonian Monte Carlo through a landscape that the neural network makes smooth and fast to traverse.