.. _mainspring_architecture:

==============
Architecture
==============

   *The movement of a grande complication is built in stages. Each stage transforms
   the energy of the mainspring into a more refined form, until the final stage
   moves the hands with perfect precision.*

Mainspring processes a genotype matrix through four stages, each producing an
increasingly refined representation of the evolutionary history encoded in the data.
The stages mirror the factorization of the full posterior:

.. math::

   p(\mathcal{A}, N_e \mid \mathbf{D}) = \underbrace{p(\mathcal{T} \mid \mathbf{D})}_{\text{topology}}
   \;\cdot\; \underbrace{p(\mathbf{t} \mid \mathcal{T}, \mathbf{D})}_{\text{dating}}
   \;\cdot\; \underbrace{p(N_e \mid \mathcal{T}, \mathbf{t})}_{\text{demography}}

where :math:`\mathcal{A} = (\mathcal{T}, \mathbf{t})` is the ARG decomposed into
topology :math:`\mathcal{T}` and node times :math:`\mathbf{t}`, and the genomic
encoder provides the shared representation from which all three factors are decoded.

.. code-block:: text

   ┌─────────────────────────────────────────────────────────┐
   │                                                         │
   │  STAGE 1: GENOMIC ENCODER                               │
   │  ┌─────────────┐   ┌─────────────┐   ┌──────────────┐  │
   │  │  Embedding   │──▶│ Set Transf. │──▶│ Sliding-Win  │  │
   │  │  (per site)  │   │ (samples)   │   │ Attn (pos.)  │  │
   │  └─────────────┘   └─────────────┘   └──────────────┘  │
   │           D ∈ {0,1}^{n×L}  ──▶  Z ∈ R^{n×L×d}         │
   │                                                         │
   ├─────────────────────────────────────────────────────────┤
   │                                                         │
   │  STAGE 2: TOPOLOGY DECODER                              │
   │  ┌──────────────┐   ┌──────────────┐  ┌─────────────┐  │
   │  │ Cross-attn   │──▶│ Breakpoint   │──▶│ Hard attn   │  │
   │  │ (Li&Stephens)│   │ detector     │   │ (Gumbel-SM) │  │
   │  └──────────────┘   └──────────────┘  └─────────────┘  │
   │           Z  ──▶  T = {(parent[], left, right)}         │
   │                                                         │
   ├─────────────────────────────────────────────────────────┤
   │                                                         │
   │  STAGE 3: DATING GNN                                    │
   │  ┌──────────────┐   ┌──────────────┐  ┌─────────────┐  │
   │  │ Node/edge    │──▶│ UP/DOWN msg  │──▶│ Gamma heads │  │
   │  │ features     │   │ passing (×K) │   │ (α_v, β_v)  │  │
   │  └──────────────┘   └──────────────┘  └─────────────┘  │
   │           T + muts  ──▶  t_v ~ Gamma(α_v, β_v)         │
   │                                                         │
   ├─────────────────────────────────────────────────────────┤
   │                                                         │
   │  STAGE 4: DEMOGRAPHIC DECODER                           │
   │  ┌──────────────┐   ┌──────────────┐  ┌─────────────┐  │
   │  │ Coalescence  │──▶│ Normalizing  │──▶│ SFS loss    │  │
   │  │ time hist.   │   │ flow         │   │ (physics)   │  │
   │  └──────────────┘   └──────────────┘  └─────────────┘  │
   │           t_v  ──▶  q(N_e(t))                           │
   │                                                         │
   └─────────────────────────────────────────────────────────┘


Stage 1: Genomic Encoder
==========================

The encoder transforms the raw genotype matrix :math:`\mathbf{D} \in \{0,1\}^{n \times L}`
into a dense representation :math:`\mathbf{Z} \in \mathbb{R}^{n \times L \times d}`
that captures both inter-sample relationships and spatial correlations along the
genome.

Embedding Layer
-----------------

Each site is embedded independently. The input at site :math:`\ell` is the column
vector :math:`\mathbf{d}_\ell = (d_{1,\ell}, \ldots, d_{n,\ell})^\top \in \{0,1\}^n`.
Each sample's binary allele is embedded into :math:`\mathbb{R}^d`:

.. math::

   \mathbf{e}_{i,\ell} = \mathbf{W}_{\text{allele}}[d_{i,\ell}] + \text{RFF}(\ell)
   + \mathbf{W}_{\text{freq}} \cdot \hat{f}_\ell

where :math:`\mathbf{W}_{\text{allele}} \in \mathbb{R}^{2 \times d}` is an allele
embedding table, :math:`\text{RFF}(\ell)` is a random Fourier feature positional
encoding (see :ref:`mainspring_design_principles`, Principle 8), and
:math:`\hat{f}_\ell = \frac{1}{n}\sum_i d_{i,\ell}` is the sample allele frequency
at site :math:`\ell`, projected through :math:`\mathbf{W}_{\text{freq}} \in \mathbb{R}^d`.

.. code-block:: python

   import torch
   import torch.nn as nn
   import torch.nn.functional as F

   class GenomicEmbedding(nn.Module):
       def __init__(self, d_model, rff_sigma=10.0):
           super().__init__()
           self.allele_embed = nn.Embedding(2, d_model)
           self.freq_proj = nn.Linear(1, d_model, bias=False)
           self.rff = RandomFourierPositionalEncoding(d_model, sigma=rff_sigma)

       def forward(self, D):
           # D: (B, n, L) integer genotype matrix with entries in {0, 1}
           B, n, L = D.shape
           allele = self.allele_embed(D)                              # (B, n, L, d)
           positions = torch.arange(L, device=D.device).float()
           pos_enc = self.rff(positions)                              # (L, d)
           freq = D.float().mean(dim=1, keepdim=True).unsqueeze(-1)   # (B, 1, L, 1)
           freq_enc = self.freq_proj(freq)                            # (B, 1, L, d)
           return allele + pos_enc.unsqueeze(0).unsqueeze(0) + freq_enc
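
``RandomFourierPositionalEncoding`` is referenced above but not defined in this
section. A minimal sketch, assuming fixed frequencies drawn once at
initialization and a ``sin``/``cos`` feature pair per frequency (the exact
parameterization of ``sigma`` is an assumption):

.. code-block:: python

   import math
   import torch
   import torch.nn as nn

   class RandomFourierPositionalEncoding(nn.Module):
       def __init__(self, d_model, sigma=10.0):
           super().__init__()
           # fixed (non-learned) frequencies; sigma controls their spread
           self.register_buffer('freqs', torch.randn(d_model // 2) / sigma)

       def forward(self, positions):                   # positions: (L,)
           angles = 2.0 * math.pi * positions[:, None] * self.freqs[None, :]
           return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (L, d)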

Set Transformer over Samples
-------------------------------

At each site, the :math:`n` sample embeddings are processed by an induced set
attention block (ISAB) that is permutation-equivariant over the sample dimension.
This implements Principle 2 from :ref:`mainspring_design_principles`.

The ISAB uses :math:`m` inducing points to reduce the :math:`O(n^2)` cost of
full self-attention to :math:`O(nm)`:

.. math::

   \mathbf{H} = \text{ISAB}_m(\mathbf{E}_\ell) = \text{MAB}(\mathbf{E}_\ell,\;
   \text{MAB}(\mathbf{I}, \mathbf{E}_\ell))

where :math:`\text{MAB}(\mathbf{X}, \mathbf{Y}) = \text{LayerNorm}(\mathbf{X} +
\text{MultiheadAttention}(\mathbf{X}, \mathbf{Y}, \mathbf{Y}))` is a multihead
attention block, :math:`\mathbf{I} \in \mathbb{R}^{m \times d}` are learned inducing
points, and :math:`\mathbf{E}_\ell \in \mathbb{R}^{n \times d}` are the sample
embeddings at site :math:`\ell`.
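
``InducedSetAttention`` (used by ``SampleEncoder`` below) is not defined in this
section. A minimal sketch following the ISAB equations above, with learned
inducing points and two multihead attention blocks (the class name and exact
normalization placement are assumptions):

.. code-block:: python

   import torch
   import torch.nn as nn

   class InducedSetAttention(nn.Module):
       def __init__(self, d_model, n_heads, n_inducing):
           super().__init__()
           self.inducing = nn.Parameter(torch.randn(n_inducing, d_model))
           self.mab1 = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
           self.mab2 = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
           self.norm1 = nn.LayerNorm(d_model)
           self.norm2 = nn.LayerNorm(d_model)

       def forward(self, x):                           # x: (batch, n, d)
           i = self.inducing.unsqueeze(0).expand(x.size(0), -1, -1)
           # MAB(I, E): m inducing points attend to the n samples -> O(nm)
           h, _ = self.mab1(i, x, x)
           h = self.norm1(i + h)
           # MAB(E, H): samples attend to the m-point summary -> O(nm)
           out, _ = self.mab2(x, h, h)
           return self.norm2(x + out)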

.. code-block:: python

   class SampleEncoder(nn.Module):
       def __init__(self, d_model, n_heads, n_inducing, n_layers):
           super().__init__()
           self.layers = nn.ModuleList([
               InducedSetAttention(d_model, n_heads, n_inducing)
               for _ in range(n_layers)
           ])
           self.norm = nn.LayerNorm(d_model)

       def forward(self, x):
           for layer in self.layers:
               x = x + layer(x)
           return self.norm(x)

Sliding-Window Positional Attention
--------------------------------------

After the Set Transformer has processed each site independently over samples, we
apply sliding-window self-attention along the genomic axis (Principle 1). Each
sample's :math:`L` site embeddings form a sequence, and attention within it is
restricted to a window of :math:`w` positions:

.. code-block:: python

   class GenomicEncoder(nn.Module):
       def __init__(self, d_model, n_heads, n_layers,
                    n_inducing=32, window_size=512):
           super().__init__()
           self.embedding = GenomicEmbedding(d_model)
           self.sample_encoder = SampleEncoder(d_model, n_heads, n_inducing, 2)
           self.positional_layers = nn.ModuleList([
               SlidingWindowAttention(d_model, n_heads, window_size)
               for _ in range(n_layers)
           ])
           self.ffn_layers = nn.ModuleList([
               nn.Sequential(
                   nn.LayerNorm(d_model),
                   nn.Linear(d_model, 4 * d_model),
                   nn.GELU(),
                   nn.Linear(4 * d_model, d_model),
               )
               for _ in range(n_layers)
           ])

       def forward(self, D):
           Z = self.embedding(D)                     # (B, n, L, d)
           B, n, L, d = Z.shape
           Z = Z.permute(0, 2, 1, 3).reshape(B * L, n, d)
           Z = self.sample_encoder(Z)                # Set Transformer over samples
           Z = Z.reshape(B, L, n, d).permute(0, 2, 1, 3)
           Z = Z.reshape(B * n, L, d)
           for attn, ffn in zip(self.positional_layers, self.ffn_layers):
               Z = Z + attn(Z)                       # sliding-window attention
               Z = Z + ffn(Z)
           Z = Z.reshape(B, n, L, d)
           return Z
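
``SlidingWindowAttention`` above is likewise not defined here. A minimal sketch
that applies a dense banded mask to standard multihead attention (a production
implementation would use a block-sparse kernel to avoid materializing the
:math:`O(L^2)` mask; the class name and pre-norm placement are assumptions):

.. code-block:: python

   import torch
   import torch.nn as nn

   class SlidingWindowAttention(nn.Module):
       def __init__(self, d_model, n_heads, window_size):
           super().__init__()
           self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
           self.norm = nn.LayerNorm(d_model)
           self.window_size = window_size

       def forward(self, x):                           # x: (batch, L, d)
           L = x.size(1)
           pos = torch.arange(L, device=x.device)
           # True entries are disallowed: positions more than window_size apart
           mask = (pos[None, :] - pos[:, None]).abs() > self.window_size
           h = self.norm(x)
           out, _ = self.attn(h, h, h, attn_mask=mask)
           return out                                   # residual added by caller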


Stage 2: Topology Decoder
===========================

The topology decoder converts the encoder's latent representation into a sequence
of local tree topologies with breakpoints. This is the most structurally novel
component: it implements a **learned Li & Stephens model** (Principle 5).

Cross-Attention as Copying
----------------------------

At each genomic position :math:`\ell`, every sample :math:`i` computes attention
weights over all other samples. The attention weights represent the probability
that sample :math:`i` is "copying from" sample :math:`j` at this position -- the
neural analogue of the Li & Stephens transition probabilities.

.. math::

   \mathbf{q}_i^\ell = \mathbf{W}_Q \mathbf{z}_{i,\ell}, \quad
   \mathbf{k}_j^\ell = \mathbf{W}_K \mathbf{z}_{j,\ell}, \quad
   \alpha_{ij}^\ell = \text{softmax}_j\!\left(\frac{\mathbf{q}_i^{\ell\top}
   \mathbf{k}_j^\ell}{\sqrt{d}}\right)

The attention matrix :math:`\mathbf{A}^\ell \in \mathbb{R}^{n \times n}` at each
position encodes the copying relationships. In a true Li & Stephens model, each
row of this matrix would be a one-hot vector (each sample copies from exactly one
source). We relax this to soft attention during training and gradually harden it.

Breakpoint Detection
-----------------------

Tree topology changes at recombination breakpoints. The breakpoint detector is a
1D convolution along the genomic axis that identifies positions where the
attention pattern changes significantly:

.. math::

   b_\ell = \sigma\!\left(\mathbf{w}_b^\top
   \text{Conv1D}\!\bigl(\|\mathbf{A}^\ell - \mathbf{A}^{\ell-1}\|_F,\;
   \ldots\bigr) + c_b\right)

where :math:`b_\ell \in [0, 1]` is the breakpoint probability at position
:math:`\ell` and :math:`\|\cdot\|_F` is the Frobenius norm of the change in
attention pattern.

.. code-block:: python

   class BreakpointDetector(nn.Module):
       def __init__(self, d_model, kernel_size=5):
           super().__init__()
           self.proj = nn.Linear(d_model, 1)
           self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2)

       def forward(self, Z_diff):
           # Z_diff: (B, L - 1, d) pooled change in encoder embeddings between
           # adjacent sites, a proxy for the change in the attention pattern
           x = self.proj(Z_diff).squeeze(-1).unsqueeze(1)    # (B, 1, L - 1)
           return torch.sigmoid(self.conv(x)).squeeze(1)     # (B, L - 1)

Hard Attention via Gumbel-Softmax
-----------------------------------

To produce discrete tree topologies, we need hard parent assignments. During
training, we use the Gumbel-softmax trick to maintain differentiability:

.. code-block:: python

   class TopologyDecoder(nn.Module):
       def __init__(self, d_model, n_heads):
           super().__init__()
           self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
           self.breakpoint_det = BreakpointDetector(d_model)
           self.tau = 1.0  # annealed during training

       def forward(self, Z, hard=False):
           B, n, L, d = Z.shape
           parent_logits, parent_onehot = [], []
           for ell in range(L):
               q = Z[:, :, ell, :]          # (B, n, d)
               k = Z[:, :, ell, :]          # (B, n, d)
               scores = torch.bmm(q, k.transpose(1, 2)) / (d ** 0.5)
               # a sample never copies from itself
               scores.diagonal(dim1=1, dim2=2).fill_(float('-inf'))

               if hard:
                   # inference: deterministic argmax parent assignments
                   parents = F.one_hot(scores.argmax(dim=-1), n).float()
               else:
                   # training: straight-through Gumbel-softmax (one-hot forward
                   # pass, soft gradients), temperature tau annealed over training
                   parents = F.gumbel_softmax(scores, tau=self.tau, hard=True, dim=-1)
               parent_logits.append(scores)
               parent_onehot.append(parents)

           parent_logits = torch.stack(parent_logits, dim=2)   # (B, n, L, n)
           parent_onehot = torch.stack(parent_onehot, dim=2)   # (B, n, L, n)

           # breakpoint scores from the change in embeddings between adjacent sites
           Z_diff = Z[:, :, 1:, :] - Z[:, :, :-1, :]
           Z_diff_pooled = Z_diff.mean(dim=1)                  # pool over samples
           breakpoints = self.breakpoint_det(Z_diff_pooled)    # (B, L - 1)

           return parent_logits, parent_onehot, breakpoints

At inference time (:math:`\tau \to 0` or ``hard=True``), the Gumbel-softmax collapses
to argmax, producing deterministic parent assignments. The topology is then assembled
into contiguous tree segments separated by breakpoints.
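
The annealing schedule for :math:`\tau` is a training detail that the
architecture does not fix. A simple exponential decay toward a floor, sketched
here purely as an assumption (the function name and constants are illustrative):

.. code-block:: python

   import math

   def anneal_tau(step, tau_start=1.0, tau_min=0.1, decay_rate=1e-4):
       # decay the Gumbel-softmax temperature as training progresses
       return max(tau_min, tau_start * math.exp(-decay_rate * step))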

.. admonition:: From attention to tree topology

   The attention matrix :math:`\mathbf{A}^\ell` does not directly encode a valid
   tree. To obtain a tree, we apply a **greedy bottom-up construction**: starting
   from the leaves, we iteratively merge the pair with the highest mutual attention
   weight, creating an internal node. The process continues until all samples are
   connected. This is reminiscent of hierarchical clustering, but the similarity
   metric is learned end-to-end.
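
A minimal sketch of this greedy construction, assuming a single (already
hardened) :math:`n \times n` attention matrix for one tree segment and average
attention between clusters as the merge criterion (the function name and the
average-linkage choice are assumptions):

.. code-block:: python

   import numpy as np

   def greedy_tree_from_attention(A):
       """Return a parent array of length 2n - 1 (root has parent -1)."""
       n = A.shape[0]
       sim = (A + A.T) / 2.0                       # symmetrise copying weights
       parent = np.full(2 * n - 1, -1, dtype=int)
       members = {i: [i] for i in range(n)}        # leaves under each cluster
       active = list(range(n))
       next_node = n
       while len(active) > 1:
           best, pair = -np.inf, None
           for ai in range(len(active)):
               for bi in range(ai + 1, len(active)):
                   a, b = active[ai], active[bi]
                   score = sim[np.ix_(members[a], members[b])].mean()
                   if score > best:
                       best, pair = score, (a, b)
           a, b = pair
           parent[a] = parent[b] = next_node       # merge into a new internal node
           members[next_node] = members[a] + members[b]
           active = [c for c in active if c not in (a, b)] + [next_node]
           next_node += 1
       return parent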


Stage 3: Dating GNN
======================

Given the predicted topology, the dating GNN assigns times to internal nodes using
learned message passing. This is the neural analogue of tsdate's inside-outside
algorithm (Principle 4), with gamma output heads (Principle 7) and per-segment
sufficient statistics as input features (Principle 9).

Node and Edge Features
------------------------

Each node :math:`v` in a local tree is initialized with a feature vector:

.. math::

   \mathbf{h}_v^{(0)} = \begin{cases}
   \mathbf{W}_{\text{leaf}} \mathbf{z}_{v,\ell} & \text{if } v \text{ is a leaf} \\
   \mathbf{W}_{\text{init}} [\hat{t}_v;\; \log \hat{t}_v;\; \mathbf{0}] &
   \text{if } v \text{ is internal}
   \end{cases}

where :math:`\mathbf{z}_{v,\ell}` is the encoder output for leaf :math:`v` at the
midpoint of the tree's genomic span, and :math:`\hat{t}_v` is an initial time
estimate from the Threads-style natural estimator.

Each edge :math:`(u, v)` carries features:

.. math::

   \mathbf{f}_{uv} = \mathbf{W}_{\text{edge}} [m_{uv};\; s_{uv};\; \hat{t}_{uv};\;
   \log m_{uv};\; \log s_{uv};\; n_{uv}]

where :math:`m_{uv}` is the mutation count, :math:`s_{uv}` the genomic span,
:math:`\hat{t}_{uv}` the natural time estimate, and :math:`n_{uv}` the number of
descendant leaves.
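
A minimal sketch of assembling these raw edge features before the learned
projection :math:`\mathbf{W}_{\text{edge}}` (the function name and the
:math:`\epsilon` guard for zero-mutation edges are assumptions; :math:`\hat{t}`
is assumed to be computed upstream):

.. code-block:: python

   import torch

   def build_edge_features(mut_count, span, t_hat, n_desc, eps=1e-8):
       # all inputs: 1-D tensors indexed by the child node of each edge
       return torch.stack([
           mut_count,
           span,
           t_hat,
           torch.log(mut_count + eps),      # eps guards edges with no mutations
           torch.log(span + eps),
           n_desc,
       ], dim=-1)                            # (n_edges, 6), fed to edge_encoder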

UP/DOWN Message Passing
--------------------------

The GNN alternates between **UP passes** (children to parent, analogous to tsdate's
inside pass) and **DOWN passes** (parent to children, analogous to the outside pass):

.. math::

   \mathbf{m}_{c \to p}^{(k)} = \text{MLP}_{\text{up}}\!\bigl(
   [\mathbf{h}_c^{(k)};\; \mathbf{f}_{cp}]\bigr) \qquad \text{(UP message)}

.. math::

   \mathbf{m}_{p \to c}^{(k)} = \text{MLP}_{\text{down}}\!\bigl(
   [\mathbf{h}_p^{(k)};\; \mathbf{f}_{pc}]\bigr) \qquad \text{(DOWN message)}

.. math::

   \mathbf{h}_v^{(k+1)} = \text{GRU}\!\left(\mathbf{h}_v^{(k)},\;
   \sum_{u \in \text{children}(v)} \mathbf{m}_{u \to v}^{(k)} +
   \mathbf{m}_{\text{parent}(v) \to v}^{(k)}\right)

The GRU (gated recurrent unit) update prevents the node features from drifting too
far from their initial values while allowing iterative refinement. After :math:`K`
rounds (typically :math:`K = 6`), the node features are decoded into gamma
parameters.

.. code-block:: python

   class DatingGNN(nn.Module):
       def __init__(self, d_model, n_rounds=6):
           super().__init__()
           self.n_rounds = n_rounds
           self.node_init = nn.Linear(d_model, d_model)
           self.edge_encoder = nn.Linear(6, d_model)
           self.up_msg = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU(),
                                       nn.Linear(d_model, d_model))
           self.down_msg = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU(),
                                         nn.Linear(d_model, d_model))
           self.gru = nn.GRUCell(d_model, d_model)
           self.alpha_head = nn.Linear(d_model, 1)
           self.beta_head = nn.Linear(d_model, 1)

       def forward(self, node_features, edge_features, parent_array):
           # node_features: (n_nodes, d_model); edge_features: (n_nodes, 6),
           # indexed by child node; parent_array[v] = parent of v, or -1 at the root
           h = self.node_init(node_features)
           f = self.edge_encoder(edge_features)

           for k in range(self.n_rounds):
               msg = torch.zeros_like(h)
               # illustrative Python loop; production code would vectorize the
               # UP/DOWN messages with index_add_ / scatter operations
               for child, parent in enumerate(parent_array):
                   if parent < 0:
                       continue
                   up = self.up_msg(torch.cat([h[child], f[child]], dim=-1))
                   msg[parent] += up
                   down = self.down_msg(torch.cat([h[parent], f[child]], dim=-1))
                   msg[child] += down
               h = self.gru(msg, h)                 # gated update of node states

           alpha = F.softplus(self.alpha_head(h)) + 1.0   # concentration > 1
           beta = torch.exp(self.beta_head(h))            # rate > 0
           return alpha, beta

Gamma Output Heads
--------------------

The final node features :math:`\mathbf{h}_v^{(K)}` are decoded into gamma parameters
:math:`(\alpha_v, \beta_v)`:

.. math::

   \alpha_v = \text{softplus}(\mathbf{w}_\alpha^\top \mathbf{h}_v^{(K)}) + 1, \qquad
   \beta_v = \exp(\mathbf{w}_\beta^\top \mathbf{h}_v^{(K)})

The predicted time distribution for node :math:`v` is then
:math:`t_v \sim \text{Gamma}(\alpha_v, \beta_v)`, with mean
:math:`\mathbb{E}[t_v] = \alpha_v / \beta_v` and variance
:math:`\text{Var}(t_v) = \alpha_v / \beta_v^2`.
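
How the gamma heads are supervised is not spelled out in this section. One
natural per-node objective, sketched here as an assumption, is the negative
log-likelihood of reference node times under the predicted gamma distributions
(the function name is illustrative):

.. code-block:: python

   import torch
   from torch.distributions import Gamma

   def node_time_nll(alpha, beta, true_times, eps=1e-8):
       # alpha, beta: (n_nodes, 1) gamma parameters; true_times: (n_nodes,)
       dist = Gamma(concentration=alpha.squeeze(-1), rate=beta.squeeze(-1))
       return -dist.log_prob(true_times.clamp_min(eps)).mean()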

Cross-Tree Consistency
------------------------

Adjacent local trees share most of their topology and node times. To enforce
consistency, we add a **cross-tree regularizer** that penalizes large changes in
predicted node times between adjacent trees:

.. math::

   \mathcal{L}_{\text{consistency}} = \sum_{\ell=1}^{T-1} \sum_{v \in \mathcal{V}_\ell
   \cap \mathcal{V}_{\ell+1}} \left(\log \frac{\alpha_v^\ell}{\beta_v^\ell} -
   \log \frac{\alpha_v^{\ell+1}}{\beta_v^{\ell+1}}\right)^2

where :math:`T` is the number of local trees, :math:`\mathcal{V}_\ell` is the set
of nodes in local tree :math:`\ell`, and the intersection picks out the nodes
shared between adjacent trees; the summand compares the log gamma means
:math:`\log(\alpha_v/\beta_v)` of the same node in the two trees.
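
A minimal sketch of this regularizer for one pair of adjacent trees, assuming
the caller has already aligned the gamma parameters of the shared nodes (the
function name is illustrative):

.. code-block:: python

   import torch

   def consistency_loss(alpha_l, beta_l, alpha_r, beta_r):
       # log gamma means of the same nodes in trees ell and ell + 1
       log_mean_l = torch.log(alpha_l) - torch.log(beta_l)
       log_mean_r = torch.log(alpha_r) - torch.log(beta_r)
       return ((log_mean_l - log_mean_r) ** 2).sum()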


Stage 4: Demographic Decoder
===============================

The final stage maps the inferred coalescence-time distribution to a posterior over
:math:`N_e(t)` trajectories. This is where the ARG's status as a sufficient statistic
(Principle 3) pays off: the demographic decoder operates entirely on the predicted
coalescence times, not on the raw genotype matrix.

Coalescence-Time Histogram
-----------------------------

From the dated ARG, we extract a histogram of coalescence times. For each internal
node :math:`v` at time :math:`t_v` with :math:`k_v` children, we record
:math:`k_v - 1` coalescence events at time :math:`t_v`. Binning these into
:math:`B` logarithmically-spaced time intervals gives a vector
:math:`\mathbf{c} \in \mathbb{R}^B`:

.. math::

   c_b = \sum_{v \in \text{internal nodes}} (k_v - 1) \cdot \mathbf{1}[t_v \in
   \text{bin } b]

This histogram, together with the predicted SFS from the ARG, forms the input to
the normalizing flow.
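
A minimal sketch of the binning step, assuming fixed histogram bounds and hard
assignment of each node to a bin (a soft, differentiable binning may be
preferable for end-to-end training; the function name and bounds are assumptions):

.. code-block:: python

   import math
   import torch

   def coalescence_histogram(node_times, n_children, t_min, t_max, n_bins):
       # node_times, n_children: 1-D tensors over internal nodes
       edges = torch.logspace(math.log10(t_min), math.log10(t_max), n_bins + 1)
       edges = edges.to(node_times.device)
       bins = torch.bucketize(node_times, edges[1:-1])        # bin index per node
       hist = torch.zeros(n_bins, device=node_times.device)
       hist.index_add_(0, bins, (n_children - 1).float())     # weight by k_v - 1
       return hist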

Normalizing Flow
------------------

The demographic decoder is a **conditional normalizing flow** that transforms a
simple base distribution (standard normal) into a posterior over :math:`N_e(t)`
functions, conditioned on the coalescence-time histogram and SFS:

.. math::

   \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \quad
   N_e(t) = g_\phi(\mathbf{z} \mid \mathbf{c}, \text{SFS})

where :math:`g_\phi` is an invertible neural network parameterized by :math:`\phi`.
The :math:`N_e(t)` trajectory is represented as a vector of :math:`B` values on
the same log-spaced time grid, with linear interpolation between grid points.

.. code-block:: python

   class DemographicDecoder(nn.Module):
       def __init__(self, n_time_bins, n_flow_layers, d_cond):
           super().__init__()
           self.condition_net = nn.Sequential(
               nn.Linear(2 * n_time_bins, d_cond),
               nn.ReLU(),
               nn.Linear(d_cond, d_cond),
           )
           self.flow_layers = nn.ModuleList([
               AffineCouplingLayer(n_time_bins, d_cond)
               for _ in range(n_flow_layers)
           ])

       def forward(self, coal_histogram, sfs, n_samples=1):
           # conditioning vector from the coalescence histogram and observed SFS
           cond = self.condition_net(torch.cat([coal_histogram, sfs], dim=-1))
           # draw from the standard-normal base distribution on the flow's device
           z = torch.randn(n_samples, coal_histogram.size(-1), device=cond.device)
           log_det = 0.0
           for layer in self.flow_layers:
               z, ld = layer(z, cond)
               log_det += ld
           ne_trajectory = F.softplus(z)          # N_e(t) must be positive
           return ne_trajectory, log_det
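
``AffineCouplingLayer`` is not defined in this section. A minimal sketch of a
conditional affine coupling step (the hidden width, the ``tanh`` stabilization
of the log-scale, and the assumption that ``cond`` shares the leading dimension
of ``z`` are all illustrative; a full flow would also permute or alternate the
halves between layers):

.. code-block:: python

   import torch
   import torch.nn as nn

   class AffineCouplingLayer(nn.Module):
       def __init__(self, dim, d_cond, d_hidden=128):
           super().__init__()
           self.half = dim // 2
           self.net = nn.Sequential(
               nn.Linear(self.half + d_cond, d_hidden),
               nn.ReLU(),
               nn.Linear(d_hidden, 2 * (dim - self.half)),
           )

       def forward(self, z, cond):
           z1, z2 = z[:, :self.half], z[:, self.half:]
           log_scale, shift = self.net(torch.cat([z1, cond], dim=-1)).chunk(2, dim=-1)
           log_scale = torch.tanh(log_scale)                 # numerical stability
           z2 = z2 * torch.exp(log_scale) + shift            # affine transform of z2
           return torch.cat([z1, z2], dim=-1), log_scale.sum(dim=-1)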

SFS Auxiliary Loss
---------------------

The predicted ARG implies a predicted SFS, which must be consistent with the
observed SFS. This consistency check is the physics-informed regularizer from
Principle 6:

.. math::

   \mathcal{L}_{\text{SFS}} = \sum_{k=1}^{n-1} \left(
   \text{SFS}_{\text{pred}}[k] \cdot \mu - \text{SFS}_{\text{obs}}[k]
   \right)^2 / \text{SFS}_{\text{obs}}[k]

where the predicted SFS is computed differentiably from the ARG branch lengths and
descendant counts, and the observed SFS is computed directly from the genotype matrix.
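
A minimal sketch of the differentiable predicted SFS, accumulating the branch
length above each node into the frequency class given by its number of
descendant leaves (the function name, the per-node Python loop, and the handling
of the genomic span are assumptions; production code would vectorize):

.. code-block:: python

   import torch

   def predicted_sfs(times, parent, n_desc, n_samples, span=1.0):
       # times: node times; parent[v]: parent of node v (-1 at the root);
       # n_desc[v]: number of descendant leaves below node v
       sfs = torch.zeros(n_samples - 1, device=times.device)
       for v in range(len(parent)):
           p = parent[v]
           if p < 0:
               continue                         # root has no branch above it
           branch = times[p] - times[v]         # branch length above node v
           k = int(n_desc[v])
           if 1 <= k <= n_samples - 1:
               sfs[k - 1] = sfs[k - 1] + branch * span
       return sfs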

.. admonition:: Why the SFS loss matters

   Without the SFS loss, the network can produce ARGs that correctly reconstruct the
   topology and approximate the node times but systematically miscount the number of
   lineages at each frequency class. The SFS loss acts as a global consistency check:
   it catches errors in the predicted ARG that local losses (topology accuracy, node
   time likelihood) might miss. This is analogous to how a watchmaker, after
   assembling each gear individually, checks that the overall gear train produces the
   correct time -- a global test that catches assembly errors invisible at the
   component level.


Putting It All Together
=========================

The complete Mainspring model chains all four stages:

.. code-block:: python

   class Mainspring(nn.Module):
       def __init__(self, d_model=256, n_heads=8, n_encoder_layers=6,
                    n_gnn_rounds=6, n_time_bins=64, n_flow_layers=8):
           super().__init__()
           self.encoder = GenomicEncoder(d_model, n_heads, n_encoder_layers)
           self.topology_decoder = TopologyDecoder(d_model, n_heads)
           self.dating_gnn = DatingGNN(d_model, n_gnn_rounds)
           self.demographic_decoder = DemographicDecoder(
               n_time_bins, n_flow_layers, d_cond=128
           )

       def forward(self, D, hard=False):
           Z = self.encoder(D)
           parent_logits, parent_onehot, breakpoints = self.topology_decoder(
               Z, hard=hard
           )
           topology = self.extract_trees(parent_onehot, breakpoints)
           node_feats, edge_feats, parent_arrays = self.build_gnn_input(
               Z, topology
           )
           alphas, betas = self.dating_gnn(node_feats, edge_feats, parent_arrays)
           times = alphas / betas  # point estimate = gamma mean
           coal_hist = self.build_coalescence_histogram(times, topology)
           pred_sfs = self.compute_sfs(times, topology)
           ne_posterior, log_det = self.demographic_decoder(coal_hist, pred_sfs)
           return {
               'topology': topology,
               'parent_logits': parent_logits,
               'breakpoints': breakpoints,
               'alpha': alphas,
               'beta': betas,
               'times': times,
               'ne_posterior': ne_posterior,
               'flow_log_det': log_det,
               'predicted_sfs': pred_sfs,
           }
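
A minimal smoke test, assuming the helper methods referenced in ``forward``
(``extract_trees``, ``build_gnn_input``, ``build_coalescence_histogram``,
``compute_sfs``) are implemented:

.. code-block:: python

   import torch

   model = Mainspring(d_model=128, n_heads=4, n_encoder_layers=2)
   D = torch.randint(0, 2, (1, 50, 1000))       # 1 batch, 50 samples, 1000 sites
   out = model(D, hard=True)
   print(out['breakpoints'].shape)              # breakpoint probabilities
   print(out['times'].shape)                    # gamma-mean node times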

Computational Complexity
--------------------------

.. list-table:: Per-stage computational complexity
   :header-rows: 1
   :widths: 25 30 45

   * - Stage
     - Complexity
     - Bottleneck
   * - Genomic Encoder
     - :math:`O(n m L d + n L w d)`
     - ISAB with :math:`m` inducing points (:math:`nm` per site) + sliding-window
       attention (:math:`w` per position)
   * - Topology Decoder
     - :math:`O(n^2 L d)`
     - Cross-attention at each site
   * - Dating GNN
     - :math:`O(K n L d)`
     - :math:`K` message-passing rounds on trees with :math:`O(n)` nodes
   * - Demographic Decoder
     - :math:`O(B^2)`
     - Normalizing flow on :math:`B` time bins (negligible)

Total: :math:`O(n^2 L d)`, dominated by the topology decoder's per-site pairwise
attention: linear in sequence length and quadratic in sample count. For typical
applications (:math:`n \leq 100`, :math:`L \sim 10^4`), this is feasible on a
single GPU.
