Posterior Inference via HMC

The chronometer test does not ask whether a watch keeps perfect time. It asks whether the watchmaker knows how imperfect the time is. A movement rated to ±2 seconds per day is more trustworthy than one claimed to be exact – the first acknowledges its uncertainty, the second conceals it.

dadi and moments find the maximum likelihood estimate (MLE) of demographic parameters. They report a point estimate – the single \(\hat{\Theta}\) that maximizes the Poisson log-likelihood – and, at best, profile likelihood intervals for uncertainty and AIC scores for model comparison. They do not produce a full posterior distribution over \(\Theta\).

The reason is computational: posterior sampling requires thousands of likelihood evaluations, and each evaluation costs 10–100 ms with moments or dadi. For Hamiltonian Monte Carlo (HMC), which needs gradients, the cost per step is far higher. With moments' automatic differentiation, a gradient costs roughly twice an evaluation, so a step of 10 leapfrog updates costs about 200 ms. With dadi's finite differences, each gradient costs two evaluations per parameter: for 8 parameters, 10 × 2 × 8 × 100 ms = 16 s per HMC step. Running 10,000 HMC steps would take from half an hour to nearly two days.

Balance Wheel changes this arithmetic. At 0.1 ms per evaluation with exact backpropagation gradients, each HMC step costs about 2 ms. Ten thousand steps take 20 seconds. Full Bayesian posterior inference becomes routine.

Why Bayesian Inference Matters

The MLE \(\hat{\Theta}\) tells you the most likely parameter values, but not how confident you should be. Two datasets might yield the same MLE but very different posterior widths – one tightly constraining the bottleneck time, the other barely identifying it. Without the posterior, you cannot distinguish these cases.

Bayesian inference gives you:

  1. Credible intervals on each parameter. The 95% credible interval \([a, b]\) means: given the data and the model, the parameter falls in \([a, b]\) with 95% posterior probability.

  2. Posterior correlations. Demographic parameters are often correlated – e.g., bottleneck depth and duration trade off against each other. The joint posterior reveals these correlations; the MLE hides them.

  3. Posterior predictive checks. Sample parameters from the posterior, compute the expected SFS for each sample, and compare to the observed SFS. Systematic discrepancies indicate model misspecification.

  4. Model comparison via marginal likelihood. The marginal likelihood \(P(\mathbf{D} \mid \mathcal{M})\) integrates the likelihood over the prior – a principled metric for comparing models of different complexity (e.g., two-epoch vs. three-epoch models).

The Log-Posterior

The posterior is proportional to the likelihood times the prior:

\[\log p(\Theta \mid \mathbf{D}) = \underbrace{\ell(\Theta)}_{\text{Poisson log-likelihood}} + \underbrace{\log \pi(\Theta)}_{\text{prior}} - \underbrace{\log P(\mathbf{D})}_{\text{evidence (constant)}}\]

The log-likelihood \(\ell(\Theta)\) is evaluated through Balance Wheel’s neural SFS Predictor. The prior \(\pi(\Theta)\) encodes our beliefs about plausible demographic parameters before seeing the data.

Priors on demographic parameters

We use weakly informative priors that constrain parameters to physically plausible ranges without being overly prescriptive:

\[\begin{split}\log(N_e / N_{\text{ref}}) &\sim \mathcal{N}(0, 2) \\ \log(t_k) &\sim \mathcal{N}(\mu_t, 1) \\ m_{ij} &\sim \text{Exponential}(10)\end{split}\]

where \(N_{\text{ref}}\) is a reference population size and \(\mu_t\) is a prior mean for the log-time (typically set to the log of the expected TMRCA).

import torch
from torch.distributions import Normal, Exponential

def log_prior(log_sizes, log_times, migration_rates=None):
    """Weakly informative prior on demographic parameters (mu_t = 0)."""
    lp = Normal(0.0, 2.0).log_prob(log_sizes).sum()
    lp += Normal(0.0, 1.0).log_prob(log_times).sum()
    if migration_rates is not None:
        # Optional Exponential(10) prior on migration rates, per the spec.
        lp += Exponential(10.0).log_prob(migration_rates).sum()
    return lp

def log_posterior(params, observed_sfs, model, theta_L):
    """Unnormalized log-posterior for HMC sampling."""
    # First half of params holds log size ratios, second half log times.
    log_sizes, log_times = params[:len(params)//2], params[len(params)//2:]
    # Sort epoch times so any ordering of the raw parameters maps to a
    # consistently ordered demographic history.
    log_times_sorted = torch.sort(log_times)[0]
    log_lik = model.log_likelihood(
        log_sizes.unsqueeze(0), log_times_sorted.unsqueeze(0),
        observed_sfs, theta_L)
    lp = log_prior(log_sizes, log_times_sorted)
    return log_lik + lp

Speed Comparison

The advantage of Balance Wheel for posterior sampling is quantitative:

Cost per HMC step (10 leapfrog steps, 8 parameters)

Method               Per eval   Gradient          Per HMC step   10k steps
-------------------  ---------  ----------------  -------------  ---------
dadi (finite diff)   100 ms     2 × 8 × 100 ms    16 s           44 hours
moments (AD)         10 ms      20 ms             200 ms         33 min
Balance Wheel        0.1 ms     0.2 ms            2 ms           20 s

Balance Wheel is 100× faster than moments and 8,000× faster than dadi for HMC sampling. This is the difference between “theoretically possible but never done” and “run it during a coffee break.”

HMC/NUTS Implementation

Hamiltonian Monte Carlo (see Markov Chain Monte Carlo for the prerequisite) samples from the posterior by simulating Hamiltonian dynamics on the log-posterior surface. The No-U-Turn Sampler (NUTS) automatically tunes the number of leapfrog steps, eliminating the need for manual tuning.

The key requirement is a differentiable log-posterior – which Balance Wheel provides via backpropagation through the neural SFS Predictor.

import torch

class HMCSampler:
    def __init__(self, log_prob_fn, step_size=0.01, n_leapfrog=10):
        self.log_prob_fn = log_prob_fn
        self.step_size = step_size
        self.n_leapfrog = n_leapfrog

    def _leapfrog(self, q, p):
        # Half step for momentum using the log-posterior gradient.
        q = q.detach().requires_grad_(True)
        log_prob = self.log_prob_fn(q)
        grad = torch.autograd.grad(log_prob, q)[0]
        p = p + 0.5 * self.step_size * grad

        # Alternate full steps for position and momentum.
        for _ in range(self.n_leapfrog - 1):
            q = q + self.step_size * p
            q = q.detach().requires_grad_(True)
            log_prob = self.log_prob_fn(q)
            grad = torch.autograd.grad(log_prob, q)[0]
            p = p + self.step_size * grad

        # Final full position step, then a half step for momentum.
        q = q + self.step_size * p
        q = q.detach().requires_grad_(True)
        log_prob = self.log_prob_fn(q)
        grad = torch.autograd.grad(log_prob, q)[0]
        p = p + 0.5 * self.step_size * grad

        return q, p, log_prob

    def sample(self, q_init, n_samples=5000, warmup=1000):
        q = q_init.clone()
        current_log_prob = self.log_prob_fn(q).detach()
        samples = []
        log_probs = []
        n_accept = 0

        for i in range(n_samples + warmup):
            # Resample momentum; the kinetic energy is Gaussian.
            p = torch.randn_like(q)
            current_K = 0.5 * (p ** 2).sum()

            q_new, p_new, new_log_prob = self._leapfrog(q, p)
            new_K = 0.5 * (p_new ** 2).sum()

            # Metropolis correction for leapfrog discretization error.
            log_accept = (new_log_prob - current_log_prob
                          + current_K - new_K)

            if torch.log(torch.rand((), device=q.device)) < log_accept:
                q = q_new.detach()
                current_log_prob = new_log_prob.detach()
                n_accept += 1

            if i >= warmup:
                samples.append(q.clone())
                # Record the retained state's log-posterior, not the
                # (possibly rejected) proposal's.
                log_probs.append(current_log_prob.item())

        accept_rate = n_accept / (n_samples + warmup)
        return torch.stack(samples), log_probs, accept_rate

The pipeline below wires HMCSampler into a complete inference run. For production use, we recommend a NUTS implementation (e.g., from NumPyro or PyTorch’s ecosystem) that automatically tunes the step size during warmup; a sketch follows the pipeline:

def run_balance_wheel_hmc(model, observed_sfs, theta_L,
                          n_epochs_model=4, n_samples=5000,
                          warmup=1000, device='cuda'):
    """Full posterior inference pipeline."""
    observed_sfs = observed_sfs.float().to(device)
    model.eval().to(device)

    # Initialize sizes at the reference (log ratio 0) and spread epoch
    # boundaries across the prior's log-time range.
    init_sizes = torch.zeros(n_epochs_model, device=device)
    init_times = torch.linspace(-2, 2, n_epochs_model, device=device)
    q_init = torch.cat([init_sizes, init_times])

    def log_prob(params):
        return log_posterior(params, observed_sfs, model, theta_L)

    sampler = HMCSampler(log_prob, step_size=0.005, n_leapfrog=10)
    samples, log_probs, accept_rate = sampler.sample(
        q_init, n_samples=n_samples, warmup=warmup)

    # Map back from the unconstrained sampling space to natural units.
    half = n_epochs_model
    size_samples = torch.exp(samples[:, :half])
    time_samples = torch.exp(torch.sort(samples[:, half:])[0])

    return {
        'size_samples': size_samples.cpu(),
        'time_samples': time_samples.cpu(),
        'log_probs': log_probs,
        'accept_rate': accept_rate,
    }
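As one concrete route – sketched under the assumption that Pyro is installed; check the details against Pyro’s documentation – the NUTS kernel accepts a potential_fn (the negative log-posterior), so the log_posterior defined earlier can be reused directly:

import torch
from pyro.infer import MCMC, NUTS

# Same unconstrained initialization as in run_balance_wheel_hmc.
n_epochs_model = 4
q_init = torch.cat([torch.zeros(n_epochs_model),
                    torch.linspace(-2, 2, n_epochs_model)])

def potential_fn(params):
    # Pyro minimizes the potential, i.e. the negative log-posterior.
    return -log_posterior(params['q'], observed_sfs, model, theta_L)

kernel = NUTS(potential_fn=potential_fn, adapt_step_size=True)
mcmc = MCMC(kernel, num_samples=5000, warmup_steps=1000,
            initial_params={'q': q_init})
mcmc.run()
posterior_q = mcmc.get_samples()['q']  # (5000, 2 * n_epochs_model)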

What You Get from the Posterior

Credible intervals

The 95% credible interval for each parameter is simply the 2.5th and 97.5th percentiles of the posterior samples:

def credible_intervals(samples, level=0.95):
    alpha = (1 - level) / 2
    lower = torch.quantile(samples, alpha, dim=0)
    upper = torch.quantile(samples, 1 - alpha, dim=0)
    median = torch.quantile(samples, 0.5, dim=0)
    return {'median': median, 'lower': lower, 'upper': upper}
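Applied to the output of run_balance_wheel_hmc above, a minimal usage sketch:

result = run_balance_wheel_hmc(model, observed_sfs, theta_L)
ci = credible_intervals(result['size_samples'])
for j in range(ci['median'].numel()):
    # Per-epoch size ratio N_j / N_ref with its 95% credible interval.
    print(f"epoch {j}: {ci['median'][j]:.2f} "
          f"[{ci['lower'][j]:.2f}, {ci['upper'][j]:.2f}]")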

Posterior predictive checks

For each posterior sample \(\Theta^{(s)}\), compute the expected SFS and compare to the observed SFS. Systematic discrepancies indicate model misspecification:

def posterior_predictive_check(model, samples, observed_sfs, theta_L):
    """Generate posterior predictive SFS distribution."""
    n = observed_sfs.shape[0] + 1
    predictive_sfs = []

    # Use up to 1,000 posterior draws to build the predictive distribution.
    for i in range(min(1000, len(samples['size_samples']))):
        log_s = torch.log(samples['size_samples'][i]).unsqueeze(0)
        log_t = torch.log(samples['time_samples'][i]).unsqueeze(0)
        with torch.no_grad():
            pred = model(log_s, log_t, n, theta_L)
        predictive_sfs.append(pred.squeeze(0))

    predictive_sfs = torch.stack(predictive_sfs)
    return {
        'mean': predictive_sfs.mean(dim=0),
        'std': predictive_sfs.std(dim=0),
        'quantile_025': torch.quantile(predictive_sfs, 0.025, dim=0),
        'quantile_975': torch.quantile(predictive_sfs, 0.975, dim=0),
    }

If the observed SFS falls outside the 95% predictive interval for many frequency classes, the model is likely misspecified – the demographic model cannot produce an SFS that looks like the data.
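To make this check concrete, one can count the frequency classes where the observed SFS escapes the 95% predictive interval – a usage sketch, reusing result from the credible-interval example and assuming all tensors share a device:

ppc = posterior_predictive_check(model, result, observed_sfs, theta_L)
outside = ((observed_sfs < ppc['quantile_025']) |
           (observed_sfs > ppc['quantile_975']))
print(f"{int(outside.sum())} of {len(observed_sfs)} frequency classes "
      f"lie outside the 95% predictive interval")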

Model comparison

The marginal likelihood \(P(\mathbf{D} \mid \mathcal{M})\) can be estimated from the posterior samples using the harmonic mean estimator (crude but fast) or bridge sampling (more reliable):

\[P(\mathbf{D} \mid \mathcal{M}) \approx \left[ \frac{1}{S} \sum_{s=1}^{S} \frac{1}{P(\mathbf{D} \mid \Theta^{(s)})} \right]^{-1}\]

where \(\Theta^{(s)} \sim p(\Theta \mid \mathbf{D})\) are posterior samples. This enables Bayes factor comparison between, e.g., a two-epoch and a three-epoch model – a principled alternative to AIC that accounts for posterior uncertainty.

Marginal likelihood vs. AIC

dadi and moments typically use AIC for model comparison: \(\text{AIC} = -2\ell(\hat{\Theta}) + 2k\), where \(k\) is the number of parameters. AIC is a frequentist criterion computed from the point estimate alone; it is trustworthy when the likelihood is approximately Gaussian near the MLE and the sample size is large relative to \(k\). For complex demographic models with correlated parameters and multimodal likelihoods, AIC can be misleading. The marginal likelihood, estimated from posterior samples, is more reliable – but it requires the posterior samples that only Balance Wheel can provide efficiently.
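For concreteness, both criteria side by side, with purely hypothetical numbers (the log marginal likelihoods would come from, e.g., log_marginal_harmonic above):

# Hypothetical values for a two-epoch (k=4) vs. three-epoch (k=6) fit.
ll_hat_2, k_2 = -1523.4, 4
ll_hat_3, k_3 = -1518.9, 6
aic_2 = -2 * ll_hat_2 + 2 * k_2   # lower AIC is preferred
aic_3 = -2 * ll_hat_3 + 2 * k_3

# Bayes factor from estimated log marginal likelihoods (hypothetical).
log_ml_2, log_ml_3 = -1531.2, -1529.8
log_bf_32 = log_ml_3 - log_ml_2   # > 0 favors the three-epoch model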

Comparison with Profile Likelihood

Profile likelihood is what dadi and moments offer as an approximation to posterior uncertainty. For a single parameter \(\Theta_i\), the profile likelihood is:

\[\ell_{\text{profile}}(\Theta_i) = \max_{\Theta_{-i}} \ell(\Theta_i, \Theta_{-i})\]

where \(\Theta_{-i}\) denotes all parameters except \(\Theta_i\). A 95% confidence interval is constructed by finding the values of \(\Theta_i\) where the profile likelihood drops by \(\chi^2_{1,0.95}/2 = 1.92\) from the maximum.
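Given a precomputed profile – a grid of \(\Theta_i\) values with the log-likelihood maximized over the remaining parameters at each grid point – the interval can be read off as below; profile_ci is a hypothetical helper, not a dadi or moments function:

import torch

def profile_ci(grid, profile_ll, drop=1.92):
    """Outermost grid points whose profile log-likelihood lies within
    `drop` of the maximum (1.92 = chi-squared_{1,0.95} / 2)."""
    inside = profile_ll >= profile_ll.max() - drop
    return grid[inside].min(), grid[inside].max()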

Profile likelihood vs. full posterior

Property                Profile likelihood (dadi/moments)     Full posterior (Balance Wheel)
----------------------  ------------------------------------  --------------------------------
What it estimates       Confidence interval (frequentist)     Credible interval (Bayesian)
Captures correlations   No (marginalizes by maximization)     Yes (joint posterior)
Model comparison        AIC (point estimate)                  Marginal likelihood (integrated)
Computational cost      Moderate (grid search per parameter)  Low (~20 s)
Handles multimodality   Poorly (finds local maximum)          Yes (HMC explores modes)
Prior information       Not incorporated                      Naturally incorporated

The full posterior is strictly more informative than the profile likelihood. It provides everything the profile likelihood provides (marginal intervals) plus joint distributions, correlations, predictive checks, and marginal likelihoods. The only reason it was not used before is computational cost – and Balance Wheel removes that barrier.

When profile likelihood is sufficient

For well-identified models with approximately Gaussian posteriors (e.g., a two-population split-time model with large sample sizes), the profile likelihood and the posterior credible interval will agree closely. In this regime, the extra cost of HMC sampling may not be justified. Use Balance Wheel’s full posterior when (1) parameters are correlated, (2) the posterior is multimodal or skewed, (3) you need model comparison, or (4) you want posterior predictive checks. Use profile likelihood (via moments directly) when the model is simple and you trust the Gaussian approximation.