Complication III: Balance Wheel

Neural SFS Inference via Differentiable Diffusion

The Mechanism at a Glance

Mainspring inverts simulations to infer full ARGs. Escapement uses the coalescent likelihood on raw genotypes to infer genealogies and \(N_e(t)\). Both operate on sequence-level data – they see every site, every haplotype.

Balance Wheel takes a different path: it operates on the Site Frequency Spectrum (SFS), the same summary statistic that dadi and moments use. The SFS compresses the genome into a histogram of allele frequencies – a massive dimensionality reduction (millions of sites \(\to\) \(n - 1\) counts for \(n\) samples, or on the order of \(n_1 \times n_2\) entries for two populations). This compression discards spatial information (LD, haplotype structure), but under the Poisson Random Field model, which treats sites as independent, the SFS retains everything the likelihood needs for demographic inference.
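To make the compression concrete, here is a minimal sketch of computing an unfolded 1D SFS from a haplotype matrix. The function name `site_frequency_spectrum` and the toy data are illustrative assumptions, not part of Balance Wheel, and the sketch assumes derived alleles are coded as 1.

```python
def site_frequency_spectrum(haplotypes):
    """Unfolded SFS from a list of haplotypes, each a list of 0/1 alleles.

    Returns a list of length n - 1 whose entry j - 1 counts the sites at
    which the derived allele (coded 1) appears in exactly j of n haplotypes.
    """
    n = len(haplotypes)
    num_sites = len(haplotypes[0])
    sfs = [0] * (n - 1)
    for site in range(num_sites):
        derived = sum(h[site] for h in haplotypes)
        if 0 < derived < n:  # monomorphic sites carry no frequency information
            sfs[derived - 1] += 1
    return sfs


# Hypothetical toy data: n = 4 haplotypes, 6 sites.
haplotypes = [
    [0, 1, 0, 1, 1, 0],
    [0, 0, 0, 1, 1, 0],
    [1, 0, 0, 1, 1, 0],
    [0, 0, 0, 0, 1, 0],
]
sfs = site_frequency_spectrum(haplotypes)  # three entries for n = 4
```

Whatever the number of sites, the output has only \(n - 1\) entries, and the order of the sites no longer matters. That is exactly the spatial information the SFS gives up.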

The question is: can we replace dadi’s PDE solver and moments’ ODE integrator with a neural network, while keeping the same Poisson likelihood and the same demographic parameters?

Yes. The resulting predictor is faster to evaluate than the numerical solvers it replaces, differentiable end-to-end, and able to handle model classes that are intractable for dadi/moments (continuous \(N_e(t)\), high-dimensional joint SFS, complex multi-population topologies).

The name

The balance wheel is the oscillating component of a mechanical watch that regulates its timekeeping. It doesn’t track every molecule of air or every vibration of the case – it feels the aggregate force of the balance spring and responds with a precise oscillation. The SFS is the aggregate pressure of evolution on a genome: it doesn’t track every haplotype or every recombination event, but it captures the net effect of demographic history on allele frequencies. Balance Wheel, like its namesake, works with this aggregate signal.

The three modules of Balance Wheel:

  1. The Demography Encoder (the hairspring) – Encodes population size histories, split times, and migration rates into a dense vector representation. For piecewise-constant demography: a Transformer over (time, size) pairs. For continuous demography: a neural ODE. For multi-population models: a GNN over the population tree.

  2. The SFS Predictor (the balance wheel) – A neural network that maps demographic embeddings directly to the expected SFS. Replaces dadi’s PDE solver and moments’ ODE integrator. Trained with moments/dadi as a teacher – not on coalescent simulations. Output: \(M(\Theta) \in \mathbb{R}^{n-1}\) (1D SFS) or \(\mathbb{R}^{(n_1-1) \times (n_2-1)}\) (2D joint SFS).

  3. The Poisson Likelihood (the impulse pin) – The exact same Poisson log-likelihood that dadi and moments optimize. No approximation, no neural network. Given the predicted SFS and the observed SFS, compute the likelihood. Gradients flow through the SFS Predictor via backpropagation.
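The third module is small enough to write out in full. The sketch below computes the Poisson log-likelihood \(\ell(\Theta) = \sum_j [D_j \ln M_j - M_j - \ln(D_j!)]\) and its gradient with respect to the predicted SFS, \(\partial \ell / \partial M_j = D_j / M_j - 1\). That gradient is the quantity backpropagation would pass into the SFS Predictor. The function names are illustrative assumptions, and every expected entry is assumed to be strictly positive.

```python
import math


def poisson_log_likelihood(observed, expected):
    """Sum over SFS entries of D_j * ln(M_j) - M_j - ln(D_j!)."""
    return sum(
        d * math.log(m) - m - math.lgamma(d + 1)  # lgamma(d + 1) == ln(d!)
        for d, m in zip(observed, expected)
    )


def grad_wrt_expected(observed, expected):
    """Gradient of the log-likelihood with respect to each M_j."""
    return [d / m - 1.0 for d, m in zip(observed, expected)]


observed = [2, 0, 1]          # hypothetical observed SFS
expected = [2.0, 1.0, 1.0]    # hypothetical predicted SFS
ll = poisson_log_likelihood(observed, expected)
grad = grad_wrt_expected(observed, expected)
```

The gradient vanishes in every entry where the prediction matches the data, and is negative where the model predicts more sites than were observed.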

Observed SFS D ∈ Z^{n-1}
Demographic parameters Θ
                   |
                   v
         +--------------------------+
         |  DEMOGRAPHY ENCODER      |
         |  Transformer / NeuralODE |
         |  over (time, size) pairs |
         |                          |
         |  Multi-pop: GNN over     |
         |  population tree         |
         +--------------------------+
                   |
                   v
         +--------------------------+
         |  SFS PREDICTOR           |
         |  (replaces PDE/ODE)      |
         |                          |
         |  Θ → M(Θ)                |
         |  Trained with moments    |
         |  as teacher              |
         +--------------------------+
                   |
                   v
         +--------------------------+
         |  POISSON LIKELIHOOD      |
         |  (same as dadi/moments)  |
         |                          |
         |  ℓ(Θ) = Σ [D_j ln M_j    |
         |        - M_j - ln(D_j!)] |
         +--------------------------+
                   |
                   v (optimize Θ, or sample via HMC)
         Demographic parameters
         with full posterior
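The flow in the diagram can be traced end to end with a toy stand-in for the neural predictor. The sketch below is an assumption-laden illustration, not the Balance Wheel implementation: it replaces the Demography Encoder and SFS Predictor with the standard constant-size result \(M_j = \theta / j\), then maximizes the Poisson log-likelihood over a single parameter \(\theta\) by gradient ascent, applying the chain rule by hand where a real implementation would use automatic differentiation. The names `toy_sfs_predictor` and `fit_theta` are hypothetical.

```python
import math


def toy_sfs_predictor(theta, n):
    """Stand-in for the SFS Predictor: constant-size expectation M_j = theta / j."""
    return [theta / j for j in range(1, n)]


def fit_theta(observed, steps=200, lr=0.5):
    """Gradient ascent on the Poisson log-likelihood, parameterized by log(theta)."""
    n = len(observed) + 1
    total = sum(observed)
    log_theta = 0.0
    for _ in range(steps):
        expected = toy_sfs_predictor(math.exp(log_theta), n)
        # Chain rule: dl/dlog(theta) = sum_j (D_j / M_j - 1) * dM_j/dlog(theta),
        # and dM_j/dlog(theta) = M_j for this predictor.
        grad = sum((d / m - 1.0) * m for d, m in zip(observed, expected))
        # Dividing by the number of segregating sites is only a step-size choice.
        log_theta += lr * grad / total
    return math.exp(log_theta)


observed = [120, 55, 42, 31, 22]  # hypothetical observed SFS, n = 6
theta_hat = fit_theta(observed)

# For this toy model the maximum is known in closed form: S / sum_{j<n} 1/j.
closed_form = sum(observed) / sum(1.0 / j for j in range(1, len(observed) + 1))
```

Because the toy model has a closed-form maximum, `theta_hat` can be checked against `closed_form`. With a neural predictor no such check exists, which is why the likelihood itself is kept exact.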

Prerequisites for this Complication

Balance Wheel builds directly on the following Timepieces. Before starting, you should have worked through:

  • dadi – the Wright-Fisher diffusion PDE and numerical SFS computation. Balance Wheel learns to approximate this solver.

  • moments – the moment-equation ODE system. Used as the teacher during training.

  • momi2 – coalescent SFS computation for multi-population models. Alternative teacher for complex topologies.

Familiarity with function approximation, knowledge distillation, and Hamiltonian Monte Carlo is helpful but not strictly required.
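For readers new to distillation, the teacher-student setup can be sketched in miniature. Everything here is a simplifying assumption: the teacher is the analytic constant-size SFS \(M_j = \theta / j\) rather than moments, the student is a two-parameter model \(\ln M_j = a + b \ln j\) rather than a neural network, and training uses a single demography instead of a range of parameters.

```python
import math


def teacher_sfs(theta, n):
    """Stand-in teacher (the role moments plays in the text): M_j = theta / j."""
    return [theta / j for j in range(1, n)]


def distill(theta=50.0, n=10, steps=2000, lr=0.1):
    """Fit the student log M_j = a + b * log(j) to the teacher's log-SFS."""
    xs = [math.log(j) for j in range(1, n)]
    targets = [math.log(m) for m in teacher_sfs(theta, n)]
    a, b = 0.0, 0.0
    for _ in range(steps):
        residuals = [a + b * x - t for x, t in zip(xs, targets)]
        # Gradients of the mean squared error in log space.
        grad_a = 2.0 * sum(residuals) / len(xs)
        grad_b = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / len(xs)
        a -= lr * grad_a
        b -= lr * grad_b
    return a, b


a, b = distill()
```

The student never sees simulated genomes, only the teacher's expected SFS. Here it can recover the teacher exactly, with \(a \approx \ln \theta\) and \(b \approx -1\). A neural student trained over many demographies can only approximate its teacher, so its approximation error has to be tracked.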

Chapters