Handling Multiple Populations

The grande complication of watchmaking is not a single mechanism but many – chronograph, perpetual calendar, minute repeater – each interacting with the others through shared wheels and cams. The artistry lies not in making each complication work alone, but in making them work together without interference. The multi-population SFS is the grande complication of demographic inference: each population has its own size history, but the populations interact through splits, merges, and migration.

This is where Balance Wheel truly distinguishes itself from the classical methods it replaces. For two or more populations, the joint SFS is a multi-dimensional tensor, and the classical solvers scale exponentially with the number of populations. Balance Wheel’s neural SFS Predictor scales as \(O(1)\) – a single forward pass regardless of the number of populations.

The Curse of Dimensionality

For \(k\) populations with sample sizes \(n_1, n_2, \ldots, n_k\), the joint SFS is a \(k\)-dimensional tensor with \(\prod_{i=1}^k (n_i - 1)\) entries. Each entry \(M_{j_1, j_2, \ldots, j_k}\) is the expected number of SNPs where population \(i\) has derived allele frequency \(j_i / n_i\).
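The growth of the entry count is easy to check directly. A stdlib-only sketch of the product formula, evaluated for the same configurations as the table below:

```python
from math import prod

def joint_sfs_entries(sample_sizes):
    # Number of polymorphic entries: product of (n_i - 1) over populations.
    return prod(n - 1 for n in sample_sizes)

print(joint_sfs_entries([20]))          # → 19    (1 population)
print(joint_sfs_entries([20, 20]))      # → 361   (2 populations)
print(joint_sfs_entries([20, 20, 20]))  # → 6859  (3 populations)
print(joint_sfs_entries([10] * 5))      # → 59049 (5 populations)
```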

Joint SFS dimensions

| Populations | Sample sizes | SFS entries | dadi cost | moments cost |
| --- | --- | --- | --- | --- |
| 1 | \(n = 20\) | \(19\) | \(O(G)\) | \(O(n)\) |
| 2 | \(n_1 = n_2 = 20\) | \(19 \times 19 = 361\) | \(O(G^2)\) | \(O(n^2)\) |
| 3 | \(n_i = 20\) | \(19^3 = 6{,}859\) | \(O(G^3)\) | \(O(n^3)\) |
| 4 | \(n_i = 10\) | \(9^4 = 6{,}561\) | Impractical | \(O(n^4)\) – hours |
| 5 | \(n_i = 10\) | \(9^5 = 59{,}049\) | Impractical | Impractical |

For dadi, the cost per time step scales as \(O(G^k)\), where \(G\) is the grid resolution (typically 40–60 for accurate results). For \(k = 3\) with \(G = 40\), this is \(64{,}000\) operations per step – feasible but slow. For \(k = 4\), it is \(2{,}560{,}000\) – hours per evaluation. dadi’s authors explicitly recommend \(k \leq 3\).

For moments, the cost per ODE step scales as \(O(\prod n_i)\) times the cost of applying the drift operator along each axis. For \(k = 3\) with \(n = 20\), a single ODE step processes \(\sim 7{,}000\) SFS entries along three axes; over the many steps of a full integration, this works out to roughly a minute per likelihood evaluation. For \(k = 4\), it becomes tens of minutes.

For Balance Wheel, the cost is effectively \(O(1)\) – the size of the output layer increases with the number of SFS entries, but the forward pass through the hidden layers is independent of the number of populations. A 5-population joint SFS takes roughly the same ~0.1 ms as a 1-population SFS.

Population Tree Encoding

Multi-population demographic models have a tree structure: populations split from common ancestors, may exchange migrants, and have independent size histories. The Demography Encoder must capture this structure.

Graph representation

We represent the demographic model as a directed graph:

  • Nodes: each population at a given time (e.g., “European, present” or “Ancestral African-European, 70 kya”).

  • Descent edges: from ancestral population to descendant population at a split event. Weight: 1.0 (all lineages move).

  • Migration edges: between contemporary populations. Weight: migration rate \(m_{ij}\).

Node features encode the population’s size and time:

\[\mathbf{x}_v = [\log N_{e,v},\; \log t_v,\; \text{pop\_id}_v]\]

The GNN from Architecture processes this graph to produce a single embedding \(\mathbf{z}_\Theta\) that captures the full topology.

import math

import torch

def build_population_graph(demo_model):
    """Convert a demographic model to a graph for the GNN encoder.

    demo_model: dict with keys
        'pop_sizes': {pop_id: [(time, size), ...]}
        'splits': [(time, parent_pop, child1, child2)]
        'migration': [(time_start, time_end, source, target, rate)]
    """
    nodes = []
    edges = []
    edge_attrs = []

    node_id = 0
    pop_node_map = {}

    # One node per (population, epoch start time).
    for pop_id, epochs in demo_model['pop_sizes'].items():
        for time, size in epochs:
            nodes.append([
                math.log(size),
                math.log(max(time, 1.0)),
                float(pop_id)])
            pop_node_map[(pop_id, time)] = node_id
            node_id += 1

    # Descent edges: ancestral -> descendant at each split (weight 1.0,
    # since all lineages move).
    for time, parent, child1, child2 in demo_model['splits']:
        p_node = pop_node_map.get((parent, time))
        for child in (child1, child2):
            c_node = pop_node_map.get((child, time))
            if p_node is not None and c_node is not None:
                edges.append([p_node, c_node])
                edge_attrs.append([1.0])

    # Migration edges between contemporary populations (weight m_ij).
    for t_start, t_end, src, tgt, rate in demo_model['migration']:
        s_node = pop_node_map.get((src, t_start))
        t_node = pop_node_map.get((tgt, t_start))
        if s_node is not None and t_node is not None:
            edges.append([s_node, t_node])
            edge_attrs.append([rate])

    return {
        'node_features': torch.tensor(nodes),
        'edge_index': (torch.tensor(edges, dtype=torch.long).T
                       if edges else torch.zeros(2, 0, dtype=torch.long)),
        'edge_attr': (torch.tensor(edge_attrs)
                      if edge_attrs else torch.zeros(0, 1)),
    }
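To make the expected input schema concrete, here is a toy two-population model written in the dict format the docstring describes, with a small schema check (`check_demo_model` is an illustrative helper introduced here, not part of Balance Wheel):

```python
def check_demo_model(demo_model):
    # Illustrative schema check (hypothetical helper, not Balance Wheel API).
    for pop_id, epochs in demo_model['pop_sizes'].items():
        assert all(len(e) == 2 for e in epochs), "epochs are (time, size)"
    for split in demo_model['splits']:
        assert len(split) == 4, "splits are (time, parent, child1, child2)"
    for mig in demo_model['migration']:
        assert len(mig) == 5, "migration is (t_start, t_end, src, tgt, rate)"
    return True

# Toy model: populations 0 and 1 split from ancestral population 2
# at time 0.1, with weak migration from 0 into 1 after the split.
toy_model = {
    'pop_sizes': {
        0: [(0.0, 1.0)],   # population 0: constant relative size
        1: [(0.0, 0.5)],   # population 1: half the reference size
        2: [(0.1, 1.0)],   # ancestral population
    },
    'splits': [(0.1, 2, 0, 1)],
    'migration': [(0.0, 0.1, 0, 1, 1e-3)],
}

check_demo_model(toy_model)
```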

The MultiPopSFSPredictor

The multi-population SFS Predictor extends Module 2 to output a multi-dimensional tensor. The architecture is similar to the 1D predictor, but the output is reshaped to match the joint SFS dimensions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiPopSFSPredictor(nn.Module):
    def __init__(self, d_model=128, hidden=512, n_layers=5,
                 max_n_per_pop=20, max_pops=5):
        super().__init__()
        self.max_n = max_n_per_pop
        self.max_pops = max_pops

        # Each SFS axis has (n - 1) polymorphic entries, so the largest
        # joint SFS the network must cover has (max_n - 1)^max_pops entries.
        max_output = (max_n_per_pop - 1) ** max_pops
        layers = [nn.Linear(d_model + max_pops, hidden), nn.GELU()]
        for _ in range(n_layers - 2):
            layers.extend([nn.Linear(hidden, hidden), nn.GELU()])
        layers.append(nn.Linear(hidden, max_output))
        self.mlp = nn.Sequential(*layers)

    def forward(self, z, sample_sizes, theta_L):
        """
        z: (batch, d_model) — demographic embedding from GNN
        sample_sizes: list of int — [n1, n2, ...] per population
        theta_L: float — θ · L scaling factor
        """
        n_input = torch.zeros(z.shape[0], self.max_pops,
                              device=z.device)
        for i, n in enumerate(sample_sizes):
            n_input[:, i] = n / self.max_n

        raw = self.mlp(torch.cat([z, n_input], dim=-1))

        sfs_shape = tuple(n - 1 for n in sample_sizes)
        n_entries = 1
        for s in sfs_shape:
            n_entries *= s

        # Softmax gives nonnegative entries summing to 1; scaling by
        # θ·L converts them to expected SNP counts.
        raw = raw[:, :n_entries]
        sfs = F.softmax(raw, dim=-1) * theta_L
        return sfs.reshape(-1, *sfs_shape)
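The final softmax-then-scale step is what guarantees a valid expected SFS: nonnegative entries whose total mass equals \(\theta \cdot L\). The same operation in isolation, as a numpy sketch:

```python
import numpy as np

def normalize_to_sfs(raw, sfs_shape, theta_L):
    # Numerically stable softmax over all entries, then scale so the
    # total expected SNP count equals theta * L.
    w = np.exp(raw - raw.max())
    return (w / w.sum()).reshape(sfs_shape) * theta_L

raw = np.random.default_rng(0).normal(size=19 * 19)
sfs = normalize_to_sfs(raw, (19, 19), theta_L=5000.0)
assert sfs.shape == (19, 19)
assert np.isclose(sfs.sum(), 5000.0)
assert (sfs >= 0).all()
```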

Output dimension scaling

The maximum output dimension is \(n_{\max}^k\), which can be large for many populations with large sample sizes. In practice, we limit \(n_{\max} = 20\) and \(k \leq 5\), giving at most \(19^5 \approx 2.5 \times 10^6\) output entries. For \(k = 5\), the MLP’s last layer has 2.5M output neurons – large but feasible on a modern GPU. For larger problems, the output can be factored using a low-rank decomposition (see below).

Factored Output for High Dimensions

For \(k \geq 4\) populations, the joint SFS can be factored into a sum of rank-one tensors, reducing the output dimension from \(O(n^k)\) to \(O(R \cdot k \cdot n)\) where \(R\) is the rank:

\[\hat{M}_{j_1, \ldots, j_k} \approx \sum_{r=1}^{R} \prod_{i=1}^{k} f_r^{(i)}(j_i)\]

where \(f_r^{(i)}\) are rank components predicted by the MLP. This CP-decomposition approach trades accuracy for scalability.
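The parameter savings are easy to see concretely. A numpy sketch of the rank-\(R\) reconstruction for \(k = 3\) (with `np.einsum` performing the sum of outer products):

```python
import numpy as np

rng = np.random.default_rng(0)
R, n = 8, 19          # rank and per-axis SFS length (n_i - 1 entries)

# One factor matrix per population axis, shape (R, n).
f1, f2, f3 = (rng.random((R, n)) for _ in range(3))

# Sum over r of the outer products f1[r] ⊗ f2[r] ⊗ f3[r].
sfs_hat = np.einsum('ra,rb,rc->abc', f1, f2, f3)
assert sfs_hat.shape == (n, n, n)

full_params = n ** 3          # 6,859 entries predicted directly
factored_params = 3 * R * n   # 456 numbers for the factored form
assert factored_params < full_params
```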

class FactoredMultiPopSFSPredictor(nn.Module):
    def __init__(self, d_model=128, hidden=256, rank=32,
                 max_n_per_pop=20, max_pops=5):
        super().__init__()
        self.rank = rank
        self.max_n = max_n_per_pop
        self.max_pops = max_pops
        self.shared_mlp = nn.Sequential(
            nn.Linear(d_model + max_pops, hidden), nn.GELU(),
            nn.Linear(hidden, hidden), nn.GELU())
        # One head per population, predicting R factors of length max_n - 1.
        self.factor_heads = nn.ModuleList([
            nn.Linear(hidden, rank * (max_n_per_pop - 1))
            for _ in range(max_pops)])

    def forward(self, z, sample_sizes, theta_L):
        k = len(sample_sizes)
        n_input = torch.zeros(z.shape[0], self.max_pops,
                              device=z.device)
        for i, n in enumerate(sample_sizes):
            n_input[:, i] = n / self.max_n

        shared = self.shared_mlp(torch.cat([z, n_input], dim=-1))

        factors = []
        for i in range(k):
            raw = self.factor_heads[i](shared)
            raw = raw.reshape(-1, self.rank, self.max_n - 1)
            raw = raw[:, :, :sample_sizes[i] - 1]
            factors.append(F.softmax(raw, dim=-1))

        sfs = torch.zeros(
            z.shape[0], *[n - 1 for n in sample_sizes],
            device=z.device)
        for r in range(self.rank):
            # Outer product of the k rank-r factors, built by inserting
            # singleton axes so broadcasting aligns each factor with its
            # own SFS dimension.
            component = factors[0][:, r, :]
            for i in range(1, k):
                f = factors[i][:, r, :]
                f = f.reshape(f.shape[0], *([1] * i), f.shape[-1])
                component = component.unsqueeze(-1) * f
            sfs = sfs + component

        # Each rank-one component sums to 1, so the sum has mass R;
        # dividing by R keeps the total expected count at θ·L.
        return sfs * theta_L / self.rank

Training for Multi-Population Models

Training the multi-population predictor requires computing joint SFS values with the teacher for many random topologies. This is the most expensive part of Balance Wheel’s training pipeline, but it is a one-time cost.

Training data generation

import moments
import numpy as np

def sample_multi_pop_demography(rng, k=2):
    """Sample a random k-population demographic model."""
    pop_sizes = {}
    for i in range(k):
        n_epochs = rng.integers(1, 4)
        sizes = np.exp(rng.normal(0, 1, size=n_epochs))
        times = np.sort(np.concatenate([[0],
            rng.exponential(0.5, size=n_epochs - 1)]))
        pop_sizes[i] = list(zip(times, sizes))

    n_splits = k - 1
    split_times = np.sort(rng.exponential(1.0, size=n_splits))[::-1]
    splits = []
    available = list(range(k))
    for s in range(n_splits):
        c1, c2 = available[:2]
        parent = k + s
        splits.append((split_times[s], parent, c1, c2))
        available = [parent] + available[2:]

    migration = []
    n_mig = rng.integers(0, 3)
    for _ in range(n_mig):
        src, tgt = rng.choice(k, size=2, replace=False)
        rate = 10 ** rng.uniform(-4, -1)
        migration.append((0, split_times[0], int(src), int(tgt),
                          rate))

    return {
        'pop_sizes': pop_sizes,
        'splits': splits,
        'migration': migration,
    }

def compute_2pop_teacher_sfs(demo, n1=20, n2=20, theta=1.0):
    """Compute exact 2D joint SFS using moments."""
    ns = [n1, n2]
    fs = moments.Spectrum(np.zeros([n1 + 1, n2 + 1]))
    # ... (build moments demographic model from demo dict)
    # ... (integrate to compute joint SFS)
    return fs.data[1:-1, 1:-1]

The training loop is identical to Phase 1 for single populations, but with multi-dimensional SFS targets.
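One plausible choice of training loss (assumed here; the text does not pin it down) is the Poisson negative log-likelihood between predicted and teacher SFS entries, which matches the Poisson likelihood used at inference time and applies unchanged to tensors of any dimension. A torch-free numpy sketch:

```python
import numpy as np

def poisson_nll(pred, target, eps=1e-12):
    # Poisson negative log-likelihood summed over SFS entries, up to
    # the target-dependent constant log(target!). Minimized entrywise
    # when pred == target, and dimension-agnostic, so the same loss
    # covers 1D and joint SFS targets.
    pred = np.clip(pred, eps, None)
    return float(np.sum(pred - target * np.log(pred)))

teacher = np.array([[30.0, 12.0], [12.0, 5.0]])   # toy 2x2 joint SFS
good = teacher.copy()
bad = teacher[::-1]                               # rows swapped
assert poisson_nll(good, teacher) < poisson_nll(bad, teacher)
```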

Training cost for multi-population models

| \(k\) | Teacher cost per example | Examples needed | Training data time | Network training |
| --- | --- | --- | --- | --- |
| 2 | ~500 ms (moments) | 200k | ~28 hours | ~2 hours |
| 3 | ~60 s (moments) | 50k | ~1 month (parallelize!) | ~4 hours |
| 4 | ~30 min (moments) | 500k | Impractical directly | ~8 hours |

For \(k \geq 3\), computing the teacher SFS is expensive. The strategy is to generate the training data in parallel across many CPUs (moments is single-threaded but embarrassingly parallelizable across different \(\Theta\) values). For \(k = 3\), distributing across 100 CPUs reduces the data generation time from 1 month to ~8 hours.

The one-time investment

The multi-population training cost is high – but it is paid once. After training, the neural predictor evaluates the joint SFS in ~0.1 ms for any topology within the training distribution. A single 3-population likelihood evaluation with moments takes ~60 s; with Balance Wheel, 0.1 ms. The ~12-hour training investment (8 hours of parallelized data generation plus 4 hours of network training) is repaid after ~700 likelihood evaluations – a small fraction of a single HMC run.

Example: 3-Population Demographic Inference

Consider a three-population model (Africa, Europe, East Asia) with:

  • \(n_1 = n_2 = n_3 = 20\) samples per population.

  • Split times \(T_{\text{Eur-EAs}}\) (European-East Asian split) and \(T_{\text{Afr-nonAfr}}\) (African-non-African split).

  • Population sizes: ancestral \(N_a\), African \(N_{\text{Afr}}\), bottleneck \(N_b\), European \(N_{\text{Eur}}\), East Asian \(N_{\text{EAs}}\).

  • Migration: \(m_{\text{Eur-EAs}}\) between Europe and East Asia after their split.

The joint SFS is a \(19 \times 19 \times 19 = 6{,}859\)-entry tensor.

With moments, each SFS evaluation takes ~60 s. HMC with 10,000 steps would require 600,000 s = ~7 days. Profile likelihood on a 7-parameter grid is barely feasible.

With Balance Wheel:

model = BalanceWheelMultiPop(d_model=128, max_pops=3, max_n=20)

observed_joint_sfs = load_joint_sfs("three_pop_data.fs")

result = run_balance_wheel_hmc(
    model, observed_joint_sfs,
    theta_L=5000.0,
    n_epochs_model=3,
    n_samples=10000,
    warmup=2000)

print(f"N_ancestral: {result['median'][0]:.0f} "
      f"({result['ci_95'][0][0]:.0f} - {result['ci_95'][0][1]:.0f})")
print(f"T_Afr-nonAfr: {result['median'][-2]:.0f} generations")
print(f"T_Eur-EAs: {result['median'][-1]:.0f} generations")

Each HMC step takes ~2 ms. Ten thousand steps: 20 seconds. Full posterior over all 7 demographic parameters – including credible intervals, correlations, and posterior predictive checks – in under a minute. This is the task that would take a week with moments.

The scaling advantage in context

The 3-population example illustrates Balance Wheel’s fundamental advantage: the neural forward pass barely depends on the number of populations. The MLP’s output layer is larger (6,859 entries vs. 19), but the hidden layers are unchanged. The cost increase from 1 to 3 populations is ~2× (from the larger output layer), not the ~19²× ≈ 360× that moments pays. This near-constant scaling is what makes multi-population Bayesian inference practical.

For 4 or 5 populations, where moments and dadi are completely impractical, Balance Wheel may be the only viable path to the joint SFS. But recall the limitation: the network must be trained on topologies representative of the true model. If you train on 3-population models and then analyze a 4-population dataset, the network will fail. Retraining is required when the number of populations changes.