.. _balance_wheel_multi_pop:

================================
Handling Multiple Populations
================================

   *The grande complication of watchmaking is not a single mechanism but many --
   chronograph, perpetual calendar, minute repeater -- each interacting with the
   others through shared wheels and cams. The artistry lies not in making each
   complication work alone, but in making them work together without interference.
   The multi-population SFS is the grande complication of demographic inference:
   each population has its own size history, but the populations interact through
   splits, merges, and migration.*

This is where Balance Wheel truly distinguishes itself from the classical
methods it replaces. For two or more populations, the joint SFS is a
multi-dimensional tensor, and the classical solvers scale exponentially with the
number of populations. Balance Wheel's neural SFS Predictor scales as
:math:`O(1)` -- a single forward pass regardless of the number of populations.


The Curse of Dimensionality
==============================

For :math:`k` populations with sample sizes :math:`n_1, n_2, \ldots, n_k`, the
joint SFS is a :math:`k`-dimensional tensor with
:math:`\prod_{i=1}^k (n_i - 1)` entries. Each entry
:math:`M_{j_1, j_2, \ldots, j_k}` is the expected number of SNPs where
population :math:`i` has derived allele frequency :math:`j_i / n_i`.
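
To make the growth concrete, the entry count is just a product over
populations. A quick sketch (plain arithmetic, independent of any Balance
Wheel code) reproduces the figures in the table below:

.. code-block:: python

   from math import prod

   def joint_sfs_entries(sample_sizes):
       """Entries in the joint SFS tensor: prod_i (n_i - 1)."""
       return prod(n - 1 for n in sample_sizes)

   print(joint_sfs_entries([20]))          # 19
   print(joint_sfs_entries([20, 20]))      # 361
   print(joint_sfs_entries([20, 20, 20]))  # 6859
   print(joint_sfs_entries([10] * 5))      # 59049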

.. list-table:: Joint SFS dimensions
   :header-rows: 1
   :widths: 15 20 25 20 20

   * - Populations
     - Sample sizes
     - SFS entries
     - dadi cost
     - moments cost
   * - 1
     - :math:`n = 20`
     - 19
     - :math:`O(G)`
     - :math:`O(n)`
   * - 2
     - :math:`n_1 = n_2 = 20`
     - :math:`19 \times 19 = 361`
     - :math:`O(G^2)`
     - :math:`O(n^2)`
   * - 3
     - :math:`n_i = 20`
     - :math:`19^3 = 6{,}859`
     - :math:`O(G^3)`
     - :math:`O(n^3)`
   * - 4
     - :math:`n_i = 10`
     - :math:`9^4 = 6{,}561`
     - Impractical
     - :math:`O(n^4)` -- hours
   * - 5
     - :math:`n_i = 10`
     - :math:`9^5 = 59{,}049`
     - Impractical
     - Impractical

For :ref:`dadi <dadi_timepiece>`, the cost per time step scales as
:math:`O(G^k)`, where :math:`G` is the grid resolution (typically 40--60 for
accurate results). For :math:`k = 3` with :math:`G = 40`, this is :math:`64{,}000`
operations per step -- feasible but slow. For :math:`k = 4`, it is
:math:`2{,}560{,}000` -- hours per evaluation. dadi's authors explicitly
recommend :math:`k \leq 3`.
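
The :math:`O(G^k)` blow-up is simple exponentiation; a two-line sketch
reproduces the operation counts quoted above:

.. code-block:: python

   def dadi_ops_per_step(G, k):
       """Grid operations per time step for a k-dimensional diffusion solve."""
       return G ** k

   print(dadi_ops_per_step(40, 3))  # 64000
   print(dadi_ops_per_step(40, 4))  # 2560000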

For :ref:`moments <moments_timepiece>`, the cost per ODE step scales as
:math:`O(\prod n_i)` times the cost of applying the drift operator along each
axis. For :math:`k = 3` with :math:`n = 20`, a single ODE step processes
:math:`\sim 7{,}000` SFS entries along three axes -- roughly a minute
(~60 s) per likelihood evaluation for typical models. For :math:`k = 4`,
this grows to tens of minutes.

For Balance Wheel, the cost is effectively :math:`O(1)`: the output layer
grows with the number of SFS entries, but the hidden layers of the MLP are the
same size regardless of the number of populations. A 5-population joint SFS
takes roughly the same ~0.1 ms as a 1-population SFS, with only a small
constant-factor overhead from the larger output layer.


Population Tree Encoding
===========================

Multi-population demographic models have a tree structure: populations split
from common ancestors, may exchange migrants, and have independent size
histories. The Demography Encoder must capture this structure.

Graph representation
----------------------

We represent the demographic model as a directed graph:

- **Nodes**: each population at a given time (e.g., "European, present" or
  "Ancestral African-European, 70 kya").
- **Descent edges**: from ancestral population to descendant population at a
  split event. Weight: 1.0 (all lineages move).
- **Migration edges**: between contemporary populations. Weight: migration
  rate :math:`m_{ij}`.

Node features encode the population's size and time:

.. math::

   \mathbf{x}_v = [\log N_{e,v},\; \log t_v,\; \text{pop\_id}_v]

The GNN from :ref:`balance_wheel_architecture` processes this graph to produce
a single embedding :math:`\mathbf{z}_\Theta` that captures the full topology.

.. code-block:: python

   import math

   import torch

   def build_population_graph(demo_model):
       """Convert a demographic model to a graph for the GNN encoder.

       demo_model: dict with keys
           'pop_sizes': {pop_id: [(time, size), ...]}
           'splits': [(time, parent_pop, child1, child2)]
           'migration': [(time_start, time_end, source, target, rate)]
       """
       nodes = []
       edges = []
       edge_attrs = []

       node_id = 0
       pop_node_map = {}

       for pop_id, epochs in demo_model['pop_sizes'].items():
           for time, size in epochs:
               nodes.append([
                   math.log(size),
                   math.log(max(time, 1.0)),
                   float(pop_id)])
               pop_node_map[(pop_id, time)] = node_id
               node_id += 1

       def nearest_node(pop_id, time):
           """Node for pop_id at the epoch nearest `time` (None if absent).

           Epoch times rarely coincide exactly with event times, so an
           exact-key lookup would silently drop most edges.
           """
           candidates = [(abs(t - time), nid)
                         for (p, t), nid in pop_node_map.items()
                         if p == pop_id]
           return min(candidates)[1] if candidates else None

       for time, parent, child1, child2 in demo_model['splits']:
           p_node = nearest_node(parent, time)
           for child in (child1, child2):
               c_node = nearest_node(child, time)
               if p_node is not None and c_node is not None:
                   edges.append([p_node, c_node])
                   edge_attrs.append([1.0])  # descent: all lineages move

       for t_start, t_end, src, tgt, rate in demo_model['migration']:
           s_node = nearest_node(src, t_start)
           t_node = nearest_node(tgt, t_start)
           if s_node is not None and t_node is not None:
               edges.append([s_node, t_node])
               edge_attrs.append([rate])  # migration edge weighted by rate

       return {
           'node_features': torch.tensor(nodes),
           'edge_index': (torch.tensor(edges, dtype=torch.long).T
                          if edges else torch.zeros(2, 0, dtype=torch.long)),
           'edge_attr': (torch.tensor(edge_attrs)
                         if edge_attrs else torch.zeros(0, 1)),
       }


The MultiPopSFSPredictor
===========================

The multi-population SFS Predictor extends Module 2 to output a
multi-dimensional tensor. The architecture is similar to the 1D predictor, but
the output is reshaped to match the joint SFS dimensions.

.. code-block:: python

   import torch
   import torch.nn as nn
   import torch.nn.functional as F

   class MultiPopSFSPredictor(nn.Module):
       def __init__(self, d_model=128, hidden=512, n_layers=5,
                    max_n_per_pop=20, max_pops=5):
           super().__init__()
           self.max_n = max_n_per_pop
           self.max_pops = max_pops

           # Each SFS axis has at most max_n_per_pop - 1 entries, so the
           # shared output layer needs (max_n_per_pop - 1) ** max_pops
           # neurons to cover the largest joint SFS.
           max_output = (max_n_per_pop - 1) ** max_pops
           layers = [nn.Linear(d_model + max_pops, hidden), nn.GELU()]
           for _ in range(n_layers - 2):
               layers.extend([nn.Linear(hidden, hidden), nn.GELU()])
           layers.append(nn.Linear(hidden, max_output))
           self.mlp = nn.Sequential(*layers)

       def forward(self, z, sample_sizes, theta_L):
           """
           z: (batch, d_model) — demographic embedding from GNN
           sample_sizes: list of int — [n1, n2, ...] per population
           theta_L: float — θ · L scaling factor
           """
           k = len(sample_sizes)
           n_input = torch.zeros(z.shape[0], self.max_pops,
                                 device=z.device)
           for i, n in enumerate(sample_sizes):
               n_input[:, i] = n / self.max_n

           raw = self.mlp(torch.cat([z, n_input], dim=-1))

           sfs_shape = tuple(n - 1 for n in sample_sizes)
           n_entries = 1
           for s in sfs_shape:
               n_entries *= s

           raw = raw[:, :n_entries]
           sfs = F.softmax(raw, dim=-1) * theta_L
           sfs = sfs.reshape(-1, *sfs_shape)
           return sfs
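
The truncate-then-softmax-then-reshape step in ``forward`` can be exercised
on its own. The sketch below uses illustrative shapes, with random ``raw``
standing in for the MLP output, and confirms that each predicted SFS sums to
:math:`\theta L`:

.. code-block:: python

   import torch
   import torch.nn.functional as F

   batch = 4
   max_output = 19 ** 5                   # oversized shared output layer
   raw = torch.randn(batch, max_output)   # stand-in for self.mlp(...)

   sample_sizes = [20, 20, 20]
   sfs_shape = tuple(n - 1 for n in sample_sizes)   # (19, 19, 19)
   n_entries = 1
   for s in sfs_shape:
       n_entries *= s                     # 6859 entries actually used

   theta_L = 5000.0
   sfs = F.softmax(raw[:, :n_entries], dim=-1) * theta_L
   sfs = sfs.reshape(-1, *sfs_shape)

   print(sfs.shape)               # torch.Size([4, 19, 19, 19])
   print(sfs.sum(dim=(1, 2, 3)))  # each entry ≈ 5000.0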

.. admonition:: Output dimension scaling

   The maximum output dimension is :math:`n_{\max}^k`, which can be large for
   many populations with large sample sizes. In practice, we limit
   :math:`n_{\max} = 20` and :math:`k \leq 5`, giving at most
   :math:`19^5 \approx 2.5 \times 10^6` output entries. For :math:`k = 5`,
   the MLP's last layer has 2.5M output neurons -- large but feasible on a
   modern GPU. For larger problems, the output can be factored using a
   low-rank decomposition (see below).


Factored Output for High Dimensions
--------------------------------------

For :math:`k \geq 4` populations, the joint SFS can be factored into a sum of
rank-one tensors, reducing the output dimension from :math:`O(n^k)` to
:math:`O(R \cdot k \cdot n)` where :math:`R` is the rank:

.. math::

   \hat{M}_{j_1, \ldots, j_k} \approx \sum_{r=1}^{R}
   \prod_{i=1}^{k} f_r^{(i)}(j_i)

where :math:`f_r^{(i)}` are rank components predicted by the MLP. This
CP-decomposition approach trades accuracy for scalability.

.. code-block:: python

   class FactoredMultiPopSFSPredictor(nn.Module):
       def __init__(self, d_model=128, hidden=256, rank=32,
                    max_n_per_pop=50, max_pops=5):
           super().__init__()
           self.rank = rank
           self.max_n = max_n_per_pop
           self.max_pops = max_pops
           self.shared_mlp = nn.Sequential(
               nn.Linear(d_model + max_pops, hidden), nn.GELU(),
               nn.Linear(hidden, hidden), nn.GELU())
           self.factor_heads = nn.ModuleList([
               nn.Linear(hidden, rank * max_n_per_pop)
               for _ in range(max_pops)])

       def forward(self, z, sample_sizes, theta_L):
           k = len(sample_sizes)
           batch = z.shape[0]
           n_input = torch.zeros(batch, self.max_pops,
                                 device=z.device)
           for i, n in enumerate(sample_sizes):
               n_input[:, i] = n / self.max_n

           shared = self.shared_mlp(torch.cat([z, n_input], dim=-1))

           factors = []
           for i in range(k):
               raw = self.factor_heads[i](shared)
               raw = raw.reshape(-1, self.rank, self.max_n)
               raw = raw[:, :, :sample_sizes[i] - 1]
               factors.append(F.softmax(raw, dim=-1))

           sfs = torch.zeros(
               batch, *[n - 1 for n in sample_sizes],
               device=z.device)
           for r in range(self.rank):
               component = factors[0][:, r, :]
               for i in range(1, k):
                   # Outer product along a new trailing axis: reshape the
                   # next factor to (batch, 1, ..., 1, n_i - 1) so it
                   # broadcasts cleanly against the accumulated component
                   # for any k, not just k = 2.
                   f = factors[i][:, r, :]
                   component = component.unsqueeze(-1) * \
                       f.reshape(batch, *([1] * (component.dim() - 1)), -1)
               sfs = sfs + component

           # Each rank-one component sums to 1 (a product of softmaxes),
           # so the accumulated tensor sums to `rank`; rescale so the total
           # expected SNP count is theta_L, matching the dense predictor.
           return sfs * (theta_L / self.rank)
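
The reduction in output dimension is easy to quantify. A back-of-envelope
sketch, using ``rank=32`` as in the constructor default above:

.. code-block:: python

   def dense_output_dim(n, k):
       """Output neurons for the dense predictor: (n - 1) ** k."""
       return (n - 1) ** k

   def factored_output_dim(n, k, rank):
       """Output neurons for the CP-factored predictor: rank * k * (n - 1)."""
       return rank * k * (n - 1)

   n, k, rank = 20, 5, 32
   print(dense_output_dim(n, k))           # 2476099
   print(factored_output_dim(n, k, rank))  # 3040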


Training for Multi-Population Models
=======================================

Training the multi-population predictor requires computing joint SFS values
with the teacher for many random topologies. This is the most expensive part of
Balance Wheel's training pipeline, but it is a one-time cost.

Training data generation
--------------------------

.. code-block:: python

   import moments
   import numpy as np

   def sample_multi_pop_demography(rng, k=2):
       """Sample a random k-population demographic model."""
       pop_sizes = {}
       for i in range(k):
           n_epochs = rng.integers(1, 4)
           sizes = np.exp(rng.normal(0, 1, size=n_epochs))
           times = np.sort(np.concatenate([[0],
               rng.exponential(0.5, size=n_epochs - 1)]))
           pop_sizes[i] = list(zip(times, sizes))

       n_splits = k - 1
       split_times = np.sort(rng.exponential(1.0, size=n_splits))[::-1]
        splits = []
        available = list(range(k))
        for s in range(n_splits):
            c1, c2 = available[:2]
            parent = k + s
            # Give each ancestral population its own size epoch, anchored
            # at the split time -- otherwise the model has no size defined
            # upstream of the split.
            pop_sizes[parent] = [(float(split_times[s]),
                                  float(np.exp(rng.normal(0, 1))))]
            splits.append((split_times[s], parent, c1, c2))
            available = [parent] + available[2:]

       migration = []
       n_mig = rng.integers(0, 3)
       for _ in range(n_mig):
           src, tgt = rng.choice(k, size=2, replace=False)
           rate = 10 ** rng.uniform(-4, -1)
           migration.append((0, split_times[0], int(src), int(tgt),
                             rate))

       return {
           'pop_sizes': pop_sizes,
           'splits': splits,
           'migration': migration,
       }

   def compute_2pop_teacher_sfs(demo, n1=20, n2=20, theta=1.0):
       """Compute exact 2D joint SFS using moments."""
       ns = [n1, n2]
       fs = moments.Spectrum(np.zeros([n1 + 1, n2 + 1]))
       # ... (build moments demographic model from demo dict)
       # ... (integrate to compute joint SFS)
       return fs.data[1:-1, 1:-1]

The training loop is identical to Phase 1 for single populations, but with
multi-dimensional SFS targets:

.. list-table:: Training cost for multi-population models
   :header-rows: 1
   :widths: 15 25 25 20 15

   * - :math:`k`
     - Teacher cost per example
     - Examples needed
     - Training data time
     - Network training
   * - 2
     - ~500 ms (moments)
     - 200k
     - ~28 hours
     - ~2 hours
   * - 3
     - ~60 s (moments)
     - 500k
     - ~1 year (parallelize!)
     - ~4 hours
   * - 4
     - ~30 min (moments)
     - 500k
     - Impractical directly
     - ~8 hours

For :math:`k \geq 3`, computing the teacher SFS is expensive. The strategy is
to generate the training data in parallel across many CPUs (moments is
single-threaded but embarrassingly parallelizable across different
:math:`\Theta` draws). For :math:`k = 3`, 500k examples at ~60 s each amount
to nearly a year of CPU time; distributed across 100 CPUs, the wall-clock
time drops to ~3.5 days.
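
The parallelization arithmetic is worth writing down explicitly. A sketch,
plugging in the per-example costs and example counts from the table above:

.. code-block:: python

   def data_gen_days(teacher_seconds, n_examples, n_cpus=1):
       """Wall-clock days to generate training data across n_cpus workers."""
       return teacher_seconds * n_examples / n_cpus / 86_400

   print(round(data_gen_days(0.5, 200_000) * 24))      # ~28 hours serial (k = 2)
   print(round(data_gen_days(60.0, 500_000)))          # ~347 days serial (k = 3)
   print(round(data_gen_days(60.0, 500_000, 100), 1))  # ~3.5 days on 100 CPUs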

.. admonition:: The one-time investment

   The multi-population training cost is high -- but it is paid once. After
   training, the neural predictor evaluates the joint SFS in ~0.1 ms for any
   topology within the training distribution. For :math:`k = 2`, the ~28-hour
   data generation is repaid after ~200,000 moments evaluations (~500 ms
   each) -- about 20 full HMC runs. For :math:`k = 3`, each moments
   evaluation costs ~60 s, so even several days of parallelized data
   generation are repaid within a few thousand likelihood evaluations --
   less than a single HMC run.


Example: 3-Population Demographic Inference
==============================================

Consider a three-population model (Africa, Europe, East Asia) with:

- :math:`n_1 = n_2 = n_3 = 20` samples per population.
- Split times :math:`T_{\text{Eur-EAs}}` (European-East Asian split) and
  :math:`T_{\text{Afr-nonAfr}}` (African-non-African split).
- Population sizes: ancestral :math:`N_a`, African :math:`N_{\text{Afr}}`,
  bottleneck :math:`N_b`, European :math:`N_{\text{Eur}}`, East Asian
  :math:`N_{\text{EAs}}`.
- Migration: :math:`m_{\text{Eur-EAs}}` between Europe and East Asia after
  their split.

The joint SFS is a :math:`19 \times 19 \times 19 = 6{,}859`-entry tensor.

With moments, each SFS evaluation takes ~60 s. HMC with 10,000 steps would
require 600,000 s = ~7 days. Profile likelihood on a 7-parameter grid is
barely feasible.
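
The comparison is just per-step cost times step count. A sketch, using the
~2 ms Balance Wheel HMC step quoted later in this example:

.. code-block:: python

   def hmc_wall_clock_seconds(eval_seconds, n_steps=10_000):
       """Total sampling time for an HMC run at a fixed per-step cost."""
       return eval_seconds * n_steps

   print(hmc_wall_clock_seconds(60.0) / 86_400)  # ~7 days with moments
   print(hmc_wall_clock_seconds(0.002))          # ~20 s at ~2 ms per step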

With Balance Wheel:

.. code-block:: python

   model = BalanceWheelMultiPop(d_model=128, max_pops=3, max_n=20)

   observed_joint_sfs = load_joint_sfs("three_pop_data.fs")

   result = run_balance_wheel_hmc(
       model, observed_joint_sfs,
       theta_L=5000.0,
       n_epochs_model=3,
       n_samples=10000,
       warmup=2000)

   print(f"N_ancestral: {result['median'][0]:.0f} "
         f"({result['ci_95'][0][0]:.0f} - {result['ci_95'][0][1]:.0f})")
   print(f"T_Afr-nonAfr: {result['median'][-2]:.0f} generations")
   print(f"T_Eur-EAs: {result['median'][-1]:.0f} generations")

Each HMC step takes ~2 ms. Ten thousand steps: 20 seconds. Full posterior over
all 7 demographic parameters -- including credible intervals, correlations, and
posterior predictive checks -- in under a minute. This is the task that would
take a week with moments.

.. admonition:: The scaling advantage in context

   The 3-population example illustrates Balance Wheel's fundamental advantage:
   the neural forward pass does not care about the number of populations. The
   MLP has a larger output layer (6,859 vs. 19 entries), but the forward pass
   through the hidden layers is the same size. The cost increase from 1 to 3
   populations is ~2× (larger output layer), not 19²× (as for moments).
   This sub-exponential scaling is what makes multi-population Bayesian
   inference practical.

   For 4 or 5 populations, where moments and dadi are completely impractical,
   Balance Wheel may be the only viable path to the joint SFS. But recall the
   limitation: the network must be trained on topologies representative of the
   true model. If you train on 3-population models and then analyze a
   4-population dataset, the network will fail. Retraining is required when
   the number of populations changes.
