.. _balance_wheel_overview:

===========================
Overview of Balance Wheel
===========================

   *The balance wheel doesn't track every molecule of air or every vibration of
   the case. It feels the aggregate force of the balance spring and responds with
   a precise oscillation. The SFS is the aggregate pressure of evolution on a
   genome: it doesn't track every haplotype, but it captures the net effect of
   demographic history on allele frequencies.*

:ref:`Mainspring <mainspring_complication>` inverts simulations to infer full ARGs.
:ref:`Escapement <escapement_complication>` uses the coalescent likelihood on
raw genotypes to infer genealogies and :math:`N_e(t)`. Both operate on
**sequence-level** data -- they see every site, every haplotype.

Balance Wheel takes a different path. It operates on the **Site Frequency
Spectrum (SFS)**, the same summary statistic that :ref:`dadi <dadi_timepiece>`
and :ref:`moments <moments_timepiece>` use. The SFS compresses the genome into
a histogram of allele frequencies -- a massive dimensionality reduction
(millions of sites :math:`\to` :math:`n - 1` counts for :math:`n` samples, or
:math:`n_1 \times n_2` entries for two populations). This compression discards
spatial information (LD, haplotype structure) but retains everything needed for
demographic inference under the Poisson Random Field model.
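The compression itself is mechanical: count derived alleles at each site, then histogram the counts. A minimal numpy sketch, assuming a sites-by-haplotypes 0/1 matrix oriented to the derived allele (the function name is illustrative):

```python
import numpy as np

def unfolded_sfs(genotypes: np.ndarray) -> np.ndarray:
    """Collapse a (sites x haplotypes) 0/1 matrix into the unfolded SFS.

    Entry j-1 counts the sites where exactly j of the n sampled
    haplotypes carry the derived allele (j = 1, ..., n-1); monomorphic
    classes 0 and n are dropped.
    """
    n = genotypes.shape[1]
    counts = genotypes.sum(axis=1)                    # derived count per site
    return np.bincount(counts, minlength=n + 1)[1:n]  # drop classes 0 and n

# Three sites over n = 4 haplotypes, with derived-allele counts 1, 1, 2:
G = np.array([[1, 0, 0, 0],
              [0, 1, 0, 0],
              [1, 1, 0, 0]])
print(unfolded_sfs(G))   # [2 1 0]
```

Millions of rows reduce to those :math:`n - 1` integers; for two populations, the same bookkeeping over paired counts yields the :math:`n_1 \times n_2` joint spectrum.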

The question is: can we replace dadi's PDE solver and moments' ODE integrator
with a neural network, while keeping the same Poisson likelihood and the same
demographic parameters?

**Yes.** And the result is faster, differentiable end-to-end, and handles model
classes that are intractable for dadi/moments -- continuous :math:`N_e(t)`,
high-dimensional joint SFS, complex multi-population topologies.


The SFS as a Sufficient Statistic
====================================

Under the Poisson Random Field (PRF) model, the SFS is a sufficient statistic
for demographic parameters :math:`\Theta`. The PRF model assumes:

1. **Infinitely many sites**: each SNP arises on an independent genealogy.
2. **Low mutation rate**: at most one mutation per site.
3. **Free recombination** between sites (no LD).

Under these assumptions, the number of SNPs in each frequency class is an
independent Poisson random variable:

.. math::

   D_j \sim \text{Poisson}(M_j(\Theta)), \qquad j = 1, \ldots, n-1

where :math:`D_j` is the observed count of SNPs with derived allele frequency
:math:`j/n` and :math:`M_j(\Theta)` is the expected count under demographic
model :math:`\Theta`. The log-likelihood factorizes:

.. math::

   \ell(\Theta) = \sum_{j=1}^{n-1} \left[ D_j \ln M_j(\Theta) - M_j(\Theta)
   - \ln(D_j!) \right]

Since the last term does not depend on :math:`\Theta`, maximizing the
log-likelihood requires only the mapping :math:`\Theta \to \mathbf{M}(\Theta)`
-- the expected SFS as a function of demographic parameters. This is precisely
what dadi and moments compute. And it is precisely what Balance Wheel learns to
approximate.
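The factorized likelihood is a few lines of code. A sketch dropping the :math:`\Theta`-independent :math:`\ln(D_j!)` term (names are illustrative):

```python
import numpy as np

def poisson_loglik(D, M):
    """Poisson log-likelihood of observed SFS counts D given expected
    counts M, up to the constant -sum_j ln(D_j!)."""
    D, M = np.asarray(D, float), np.asarray(M, float)
    return float(np.sum(D * np.log(M) - M))

D = np.array([3012.0, 1580.0, 1102.0])

# Each term D_j ln M_j - M_j peaks at M_j = D_j, so a model whose expected
# SFS matches the observed counts exactly scores highest.
assert poisson_loglik(D, D) > poisson_loglik(D, 1.5 * D)
assert poisson_loglik(D, D) > poisson_loglik(D, 0.5 * D)
```

Everything model-specific lives inside :math:`\mathbf{M}(\Theta)`; the likelihood itself never changes.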

.. admonition:: Why the SFS is enough

   The SFS discards haplotype structure, LD, and all spatial information along
   the genome. Yet for demographic inference under the PRF model, it loses
   *nothing*. The intuition: if sites are independent (free recombination, low
   mutation), then the frequency spectrum captures all the information about
   :math:`N_e(t)`, split times, and migration rates. The SFS is a low-dimensional
   projection of a high-dimensional dataset, and under the PRF model, it is the
   *optimal* such projection.

   This is not true for all questions. The SFS cannot distinguish between
   selective sweeps and bottlenecks (both shift the spectrum toward rare
   alleles). It cannot detect gene conversion (which affects LD patterns). It
   cannot resolve fine-scale recombination rate variation. For these questions,
   you need sequence-level methods like :ref:`Mainspring <mainspring_complication>`
   or :ref:`Escapement <escapement_complication>`.


What dadi and moments Actually Compute
=========================================

Both methods solve the same inference problem: given :math:`\Theta`, compute
:math:`\mathbf{M}(\Theta)`, then maximize the Poisson log-likelihood
:math:`\ell(\Theta)`. They differ only in *how* they compute
:math:`\mathbf{M}(\Theta)`.

**dadi** solves the Wright-Fisher diffusion PDE. For one population with
variable size :math:`\nu(t) = N_e(t) / N_{\text{ref}}`:

.. math::

   \frac{\partial \phi}{\partial t} = \frac{1}{2\nu(t)} \cdot
   \frac{\partial^2}{\partial x^2}\!\left[x(1-x)\,\phi\right]

where :math:`\phi(x, t)` is the density of alleles at frequency :math:`x` at
time :math:`t`. The expected SFS entry is obtained by integrating :math:`\phi`
against binomial sampling weights. This requires discretizing the frequency axis
on a grid of :math:`G` points and the time axis into piecewise-constant epochs.
Cost: :math:`O(G^k)` for :math:`k` populations.

**moments** derives and integrates ODEs for the SFS entries directly:

.. math::

   \frac{d\Phi_j}{dt} = \text{drift}(\Phi) + \text{mutation}(\Phi)

where the drift and mutation operators are sparse linear transformations on the
SFS vector. No frequency grid is needed. Cost: :math:`O(n^k)` per ODE step for
:math:`k` populations with sample size :math:`n`.

Both methods produce the same :math:`\mathbf{M}(\Theta)` (up to numerical
precision) and use the same Poisson likelihood. The bottleneck is the forward
computation: for :math:`k \geq 3` populations and large sample sizes, both
become prohibitively expensive. A full treatment of these computations is in
:ref:`What dadi and moments Actually Compute <balance_wheel_what_they_compute>`.
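Both solvers can be sanity-checked against the one case with a closed form: for a constant-size population, the expected unfolded SFS is :math:`M_j = \theta_L / j`, so singletons dominate and the spectrum decays monotonically. A quick check (the :math:`\theta_L` value is arbitrary):

```python
import numpy as np

theta_L = 5000.0   # population-scaled mutation rate over the region
n = 20             # haploid sample size

# Constant-size expectation: M_j = theta_L / j for j = 1, ..., n-1.
M = theta_L / np.arange(1, n)

assert M[0] == theta_L           # singletons are the largest class
assert np.all(np.diff(M) < 0)    # counts decay monotonically in j
assert len(M) == n - 1
```

Any PDE, ODE, or neural approximation of :math:`\mathbf{M}(\Theta)` should reproduce this curve when handed a constant-size model.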


Balance Wheel's Approach
==========================

Balance Wheel replaces the PDE/ODE solver with a neural function approximator.
The idea is simple:

.. math::

   \Theta \;\xrightarrow{\text{neural network}}\; \hat{\mathbf{M}}(\Theta)
   \;\approx\; \mathbf{M}(\Theta)

The network is trained to reproduce the exact SFS that moments or dadi would
compute, using those tools as a **teacher**. Once trained, the network produces
the expected SFS in a single forward pass -- no PDE to solve, no ODE to
integrate, no grid to refine. The Poisson likelihood is then evaluated on the
neural SFS prediction, exactly as dadi and moments would.

The three modules:

1. **Demography Encoder** (the hairspring): encodes population size histories,
   split times, and migration rates into a dense vector :math:`\mathbf{z}_\Theta`.

2. **SFS Predictor** (the balance wheel): a neural network that maps
   :math:`\mathbf{z}_\Theta \to \hat{\mathbf{M}}(\Theta)`. Replaces the PDE/ODE
   solver.

3. **Poisson Likelihood** (the impulse pin): the exact same Poisson
   log-likelihood that dadi and moments optimize. No approximation, no neural
   network. Gradients flow through the SFS Predictor via backpropagation.
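Wiring the three modules together is short. A numpy sketch in which an untrained, randomly initialized MLP stands in for the trained SFS Predictor (all names and the toy two-parameter demography are illustrative, not the book's actual API):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20                               # haploid sample size -> n - 1 SFS bins

# 1. Demography Encoder: pack parameters into a dense vector z_theta.
#    (Here just a log-size and a time, for a toy two-parameter model.)
def encode(theta):
    return np.array([np.log(theta["nu"]), theta["T"]])

# 2. SFS Predictor: a one-hidden-layer MLP; random weights stand in for
#    the trained network, and exp() keeps every expected count positive.
W1, b1 = rng.normal(size=(16, 2)), np.zeros(16)
W2, b2 = rng.normal(size=(n - 1, 16)), np.zeros(n - 1)

def predict_sfs(z):
    h = np.tanh(W1 @ z + b1)
    return np.exp(W2 @ h + b2)       # M_hat(theta), strictly positive

# 3. Poisson Likelihood: the classical objective, constant term dropped.
def loglik(D, M):
    return float(np.sum(D * np.log(M) - M))

M_hat = predict_sfs(encode({"nu": 2.0, "T": 0.1}))
D = rng.poisson(M_hat)               # fake observed counts for the demo
assert M_hat.shape == (n - 1,) and np.all(M_hat > 0)
assert np.isfinite(loglik(D, M_hat))
```

In the real system the middle block is a PyTorch module, so gradients of `loglik` with respect to the demographic parameters come for free via backpropagation.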


Three Reasons Balance Wheel Matters
======================================

Why not just use moments directly? Three reasons.

**1. Speed for complex models.** For :math:`k` populations with large sample
sizes, moments costs :math:`O(n^k)` per ODE step, and each likelihood evaluation
requires integrating the full ODE system. For three or more populations with
:math:`n > 50`, a single likelihood evaluation takes seconds. The neural network
is :math:`O(1)` -- a single forward pass through a small MLP, requiring ~0.1 ms
regardless of model complexity. This enables algorithms that require thousands of
evaluations: Bayesian posterior sampling, bootstrap confidence intervals,
exhaustive model comparison.

.. list-table:: Speed comparison per SFS evaluation
   :header-rows: 1
   :widths: 25 20 20 35

   * - Method
     - 1-pop (n=20)
     - 2-pop (n=20)
     - 3-pop (n=20)
   * - :ref:`dadi <dadi_timepiece>`
     - ~100 ms
     - ~10 s
     - Impractical
   * - :ref:`moments <moments_timepiece>`
     - ~10 ms
     - ~500 ms
     - ~60 s
   * - **Balance Wheel**
     - **~0.1 ms**
     - **~0.1 ms**
     - **~0.1 ms**

**2. Continuous demography.** dadi and moments require piecewise-constant
:math:`N_e(t)`. Real demography is continuous. Balance Wheel can parameterize
:math:`N_e(t)` as a neural spline or Gaussian process and still compute the
SFS, because the SFS Predictor learns a smooth mapping that generalizes beyond
the piecewise-constant training examples. This eliminates the need to choose the
number of epochs -- a model-selection problem that plagues classical approaches.
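For intuition, "continuous demography" means the model's input looks like this: :math:`\nu(t)` specified by a handful of knots and interpolated between them, rather than a stack of constant epochs. Linear interpolation stands in for the neural spline here, and the knot values are invented:

```python
import numpy as np

# Knots for nu(t) = N_e(t) / N_ref: a bottleneck followed by recovery.
knot_t  = np.array([0.0, 0.05, 0.2, 0.5, 1.0])   # time, coalescent units
knot_nu = np.array([1.0, 0.2,  0.4, 1.5, 1.0])   # relative population size

def nu(t):
    # Linear interpolation; a neural spline would additionally be
    # differentiable in the knot values, not just in t.
    return np.interp(t, knot_t, knot_nu)

# nu(t) is defined at every time, not only at epoch boundaries.
assert nu(0.0) == 1.0
assert abs(nu(0.125) - 0.3) < 1e-9   # halfway between the 0.2 and 0.4 knots
```

No epoch count to select: the curve's shape is just more parameters for the encoder.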

**3. Gradient quality.** dadi computes gradients via finite differences (perturb
each parameter, re-solve the PDE). moments uses automatic differentiation
through the ODE solver, which is better but can be numerically unstable for
stiff systems. Balance Wheel gives exact gradients via backpropagation through a
stable neural network -- no numerical issues, no step-size sensitivity, no
wasted function evaluations.
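The step-size sensitivity of finite differences is easy to see on a toy function: truncation error dominates when :math:`h` is large, floating-point cancellation when :math:`h` is tiny. (The function below is an arbitrary smooth stand-in for a likelihood surface, not anything dadi computes.)

```python
import math

def f(x):                        # arbitrary smooth stand-in
    return math.exp(-x) + x * x

def grad_exact(x):
    return -math.exp(-x) + 2 * x

def grad_fd(x, h):               # central difference
    return (f(x + h) - f(x - h)) / (2 * h)

x = 1.3
errs = {h: abs(grad_fd(x, h) - grad_exact(x)) for h in (1e-2, 1e-5, 1e-12)}

# A moderate h beats both extremes: shrinking h past the sweet spot makes
# the estimate *worse*, because f(x+h) - f(x-h) loses significant digits.
assert errs[1e-2] > errs[1e-5]
assert errs[1e-12] > errs[1e-5]
```

Backpropagation has no :math:`h` to tune: derivatives are computed analytically, layer by layer, to machine precision.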


Comparison Table
==================

.. list-table:: dadi vs. moments vs. Balance Wheel
   :header-rows: 1
   :widths: 24 24 24 28

   * - Feature
     - :ref:`dadi <dadi_timepiece>`
     - :ref:`moments <moments_timepiece>`
     - **Balance Wheel**
   * - SFS computation
     - PDE on frequency grid
     - ODE for SFS entries
     - Neural forward pass
   * - Speed per SFS eval
     - ~100 ms
     - ~10 ms
     - **~0.1 ms**
   * - Gradient method
     - Finite differences
     - AD through ODE
     - Backprop through MLP
   * - Continuous :math:`N_e(t)`
     - No (piecewise-constant)
     - No (piecewise-constant)
     - **Yes** (neural spline)
   * - Multi-pop scaling
     - :math:`O(G^k)`, impractical for :math:`k > 3`
     - :math:`O(n^k)`
     - :math:`O(1)` forward pass
   * - Uncertainty
     - Profile likelihood
     - Profile likelihood
     - **Full posterior** (HMC/NUTS)
   * - Model comparison
     - AIC from point estimate
     - AIC from point estimate
     - **Marginal likelihood** via importance sampling
   * - Training cost
     - None (classical solver)
     - None (classical solver)
     - One-time (moments evals)
   * - Accuracy guarantee
     - Numerical precision of PDE
     - Numerical precision of ODE
     - Teacher quality ceiling

.. admonition:: Reading the table

   No method dominates all rows. dadi and moments are classical, trusted, and
   require no training. For a single model with two populations analyzed once,
   running moments directly is simpler and more trustworthy. Balance Wheel wins
   when you need *many* likelihood evaluations -- posterior sampling, model
   comparison grids, bootstrap resampling, or when you need to handle
   :math:`k \geq 3` populations where the classical solvers become impractical.


Honest Limitations
====================

Balance Wheel is not a universal replacement for dadi and moments. It has four
fundamental limitations.

**1. It inherits the SFS's limitations.** The SFS discards linkage
disequilibrium, haplotype structure, and all spatial information along the
genome. Balance Wheel cannot detect recombination rate variation, recent
selective sweeps (which affect LD more than the SFS), or complex admixture
patterns that leave signatures in haplotype sharing but not allele frequencies.
For these questions, use :ref:`Escapement <escapement_complication>` or
:ref:`Mainspring <mainspring_complication>`.

**2. Teacher quality ceiling.** Balance Wheel can only be as accurate as the
moments/dadi computation it was trained on. If moments has numerical issues for
extreme parameter values (very large populations, very recent events), the neural
network will inherit those issues or extrapolate poorly. The student cannot
surpass the teacher.

**3. Generalization to unseen topologies.** The multi-population version must be
trained on a distribution of population tree topologies. If the true topology is
outside this distribution (e.g., a 6-population model when training only covered
up to 4), the network may fail silently. Classical methods handle any topology
that can be specified in their framework.

**4. Single-dataset analysis may prefer moments directly.** For a single dataset
analyzed once with a well-specified two-population model, running moments is
simpler, more transparent, and gives the exact answer. Balance Wheel's advantages
emerge only when you need thousands of likelihood evaluations or when the model
complexity exceeds what the classical solvers can handle.


The Road Ahead
================

The remaining chapters of this Complication build Balance Wheel from first
principles:

1. :ref:`What dadi and moments Actually Compute <balance_wheel_what_they_compute>`
   -- A deep dive into the PDE solver (dadi), the ODE system (moments), and the
   coalescent computation (momi2). Understanding what we are replacing.

2. :ref:`Architecture <balance_wheel_architecture>` -- The three modules in
   detail: Demography Encoder, SFS Predictor, and Poisson Likelihood. PyTorch
   code for each.

3. :ref:`Teacher-Student Training <balance_wheel_training>` -- How to train the
   SFS Predictor using moments as a teacher. Why this is not simulation-based
   inference. Validation strategies.

4. :ref:`Posterior Inference via HMC <balance_wheel_posterior>` -- Using the fast
   differentiable likelihood for Bayesian posterior sampling. Credible intervals,
   posterior predictive checks, model comparison.

5. :ref:`Handling Multiple Populations <balance_wheel_multi_pop>` -- The GNN
   encoder for population trees, the multi-dimensional SFS predictor, and why
   Balance Wheel scales where dadi/moments cannot.

6. :ref:`Comparison and Limitations <balance_wheel_comparison>` -- Systematic
   comparison across all three Complications, connections to every Timepiece, and
   a decision tree for choosing the right tool.

Each chapter follows the book's rhythm: motivation, math, code, verification.
The math here is the Poisson log-likelihood -- the same likelihood that dadi
and moments optimize, now evaluated 1000× faster through a neural network. The
verification is comparing neural SFS predictions against exact computations: does
the student match the teacher, and does the posterior make sense?

.. code-block:: python

   import torch
   from balance_wheel import BalanceWheel

   model = BalanceWheel(d_model=128, n_heads=4, n_layers=2,
                        max_epochs=10, max_n=100)

   # Phase 1: Train on moments evaluations (one-time cost)
   model.train_on_teacher(n_examples=100_000, teacher="moments")

   # Phase 2: Inference on real data
   observed_sfs = torch.tensor([3012, 1580, 1102, 845, ...])  # from VCF
   result = model.fit(observed_sfs, n=20, theta_L=5000.0, method="HMC")

   print(result.posterior_median)   # demographic parameters
   print(result.credible_intervals) # 95% CI on all parameters
   print(result.marginal_likelihood)  # for model comparison

No PDE to solve. No ODE to integrate. Just a forward pass, a Poisson
likelihood, and gradient descent -- or, for the full posterior, Hamiltonian
Monte Carlo through a landscape that the neural network makes smooth and fast
to traverse.
