late-interaction-kernels · design walkthrough
how it works

Scoring every query–document token pair without materialising the table

Late-interaction retrieval models compare every query token to every document token. A direct implementation builds a four-dimensional similarity tensor that overflows GPU memory at modern context lengths. The kernels in this library compute the same result without ever storing that tensor, by streaming tiles through on-chip memory and folding the reductions into the same pass.

1. The MaxSim score

MaxSim is the scoring function ColBERT-style late-interaction models use. The inputs are two batches of token-level embeddings:

  • $\Qm \in \mathbb{R}^{N_q \times L_q \times d}$: $N_q$ queries, each with $L_q$ token vectors of dimension $d$.
  • $\Dm \in \mathbb{R}^{N_d \times L_d \times d}$: $N_d$ documents, each with $L_d$ token vectors of the same dimension $d$.

For every (query $i$, document $j$) pair, the score is:

$$\operatorname{score}[i, j] \;=\; \sum_{s=1}^{L_q} \; \max_{t=1}^{L_d} \; \langle \Qm_{i, s}, \Dm_{j, t} \rangle.$$

For each query token, find the document token it matches best (highest inner product), then sum those best matches across all query tokens. The result is a score matrix in $\mathbb{R}^{N_q \times N_d}$. Optional boolean masks drop padding tokens from the sum and from the inner max.
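Read literally, the definition is a triple loop over query tokens, document tokens and the inner max. The plain-Python sketch below spells that out, including where each mask applies; it is meant as a correctness oracle, not as anything fast. The function name maxsim_reference and the mask layout (one boolean per token, True for a real token) are illustrative assumptions, not this library's API.

```python
from typing import List, Optional

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    """Inner product of two token vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_reference(
    Q: List[List[Vec]],                          # [Nq][Lq][d]
    D: List[List[Vec]],                          # [Nd][Ld][d]
    q_mask: Optional[List[List[bool]]] = None,   # [Nq][Lq], True = real token (hypothetical layout)
    d_mask: Optional[List[List[bool]]] = None,   # [Nd][Ld], True = real token (hypothetical layout)
) -> List[List[float]]:
    """score[i][j] = sum over active s of max over active t of <Q[i][s], D[j][t]>."""
    scores = []
    for i, query in enumerate(Q):
        row = []
        for j, doc in enumerate(D):
            total = 0.0
            for s, q_vec in enumerate(query):
                if q_mask is not None and not q_mask[i][s]:
                    continue                      # padded query token: dropped from the sum
                total += max(
                    (dot(q_vec, d_vec) for t, d_vec in enumerate(doc)
                     if d_mask is None or d_mask[j][t]),   # padded doc tokens skip the max
                    default=float("-inf"),        # only reached for a fully padded document
                )
            row.append(total)
        scores.append(row)
    return scores


Q = [[[1.0, 0.0], [0.0, 1.0]]]     # one query, two tokens, d = 2
D = [[[0.5, 0.5], [1.0, -1.0]]]    # one document, two tokens
print(maxsim_reference(Q, D))      # [[1.5]]
```

The first query token matches the second document token best (1.0), the second query token matches the first (0.5), giving 1.5.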

The problem to solve. Computing this naively builds the full $[N_q, N_d, L_q, L_d]$ similarity tensor first, then reduces. At ColPali scale ($N_q = N_d = 64$, $L_q = L_d = 1024$) that tensor is 8 GB in fp16. Each element costs $2d$ FLOPs to compute but must be written to HBM and read back, so the GPU spends most of its time on HBM traffic, not compute. The kernels in this library instead stream $B_q \times B_d$ tiles through on-chip memory and fold both reductions into the same pass, so the tensor is never materialised.

2. Programs and the launch grid

A GPU kernel is a small function the device runs in many parallel instances. Each instance is a program in Triton's vocabulary, or a thread block in CUDA's. Writing a kernel comes down to two design choices: how many programs to launch, and what one program computes.

The MaxSim forward uses a grid of $N_q \cdot N_d$ programs, one per query–document pair. Program $(i, j)$ pulls the embeddings of query $i$ and document $j$ from HBM, accumulates a single fp32 scalar, and writes it to position $(i, j)$ of the score matrix. The programs are independent, so the hardware schedules them onto whichever streaming multiprocessors (SMs) are free and runs as many concurrently as the SMs can hold.

[Figure: launch grid with rows $i = 0, \dots, 3$ and columns $j = 0, \dots, 7$; cell $(i, j)$ is program $(i, j)$.]
$N_q = 4$, $N_d = 8$: 32 programs, one per output cell, all running in parallel. Each program produces a single fp32 number.
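A sequential CPU stand-in makes the grid concrete. The hypothetical launch_grid helper below enumerates flat program ids in shuffled order, decodes each into its $(i, j)$ cell, and lets the per-program function write exactly one output element. The shuffle mimics the fact that the hardware may run programs in any order; the row-major decoding from program id to $(i, j)$ is an assumption for illustration, not necessarily this library's ordering.

```python
import random
from typing import Callable, List


def launch_grid(n_q: int, n_d: int,
                program: Callable[[int, int], float],
                seed: int = 0) -> List[List[float]]:
    """Run one 'program' per (i, j) cell; each writes exactly one element of the output."""
    scores = [[0.0] * n_d for _ in range(n_q)]
    pids = list(range(n_q * n_d))
    random.Random(seed).shuffle(pids)   # programs are independent, so order must not matter
    for pid in pids:
        i, j = divmod(pid, n_d)         # row-major decoding of the flat program id (assumed)
        scores[i][j] = program(i, j)    # one pair in, one scalar out
    return scores


def toy_program(i: int, j: int) -> float:
    return 10.0 * i + j                 # encodes its own coordinates so the mapping is visible


grid = launch_grid(4, 8, toy_program)   # N_q = 4, N_d = 8: 32 programs
print(grid[3][7])                       # 37.0
```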

3. Inside one program

Fix a single program at $(i, j)$. It has the $i$-th query and the $j$-th document available in HBM. The most direct implementation of its scalar output looks like

$$\Sm \;=\; \Qm_{i,\,\cdot}\, \Dm_{j,\,\cdot}^{\top} \;\in\; \mathbb{R}^{L_q \times L_d}, \qquad \operatorname{score}[i,j] \;=\; \sum_{s} \max_{t} \Sm_{s, t}.$$

Building $\Sm$ even for a single pair wastes most of the values: each row of $\Sm$ contributes only one number to the final score, its maximum. The kernel therefore never builds it. Instead the program walks $\Sm$ in tiles. Pick block sizes $B_q$ and $B_d$ (usually 32 or 64). The program loops over $\lceil L_q / B_q \rceil$ query tiles and, inside each, over $\lceil L_d / B_d \rceil$ document tiles. The largest piece of $\Sm$ that ever exists is one $B_q \times B_d$ slice, held in registers for the duration of a single tile-matmul.

[Figure: one inner-loop iteration. A $B_q \times d$ Q tile (loaded once per outer iteration) multiplied by a $d \times B_d$ D tile (loaded once per inner iteration) gives a $B_q \times B_d$ S tile; its row max, tile_max ($B_q$ values), is merged into the running max $m$ ($B_q$ fp32 values) by max(m, ·). The S tile and tile_max live only in registers and are recomputed each step; $m$ persists across inner-loop iterations.]

Each inner-loop iteration loads a fresh $D$ tile, multiplies it against the resident $Q$ tile on the tensor cores, takes the row-wise max of the resulting $S$ tile, and merges it into $m$. After the inner loop finishes, $m$ holds the true per-row maximum over the full $L_d$ axis, and the program adds $\sum m$ to its scalar accumulator. After the outer loop finishes, that accumulator is written to HBM. One pair, one scalar, one store.

This is the entire optimisation. A naive implementation writes the full $\Sm$ tensor to HBM and reads it back to take the row maxima. The fused program reads $\Qm$ and $\Dm$ from HBM (streaming the $\Dm$ tiles once per query tile), keeps every $\Sm$ slice on-chip, and writes one fp32 scalar per pair. The next section makes the cost difference concrete.
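The loop structure of one program can be sketched on the CPU. In this plain-Python version (function names and the b_q, b_d defaults are illustrative; lists stand in for tiles, and nothing models registers or tensor cores), the tiled running max reproduces the untiled score, including on ragged last tiles. The two results can differ only in the last bits, because the final sum is associated differently; the maxima themselves are identical.

```python
import math
import random
from typing import List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def program_untiled(q: List[Vec], doc: List[Vec]) -> float:
    """Direct form: every row of S, then the sum of row maxima."""
    return sum(max(dot(qs, dt) for dt in doc) for qs in q)


def program_tiled(q: List[Vec], doc: List[Vec], b_q: int = 2, b_d: int = 3) -> float:
    """Tiled form: at most one b_q x b_d slice of S exists at a time."""
    acc = 0.0                                              # the program's scalar accumulator
    for q0 in range(0, len(q), b_q):
        q_tile = q[q0:q0 + b_q]                            # resident for the whole inner loop
        m = [-math.inf] * len(q_tile)                      # running row max
        for d0 in range(0, len(doc), b_d):
            d_tile = doc[d0:d0 + b_d]                      # a fresh D tile each iteration
            s_tile = [[dot(qs, dt) for dt in d_tile] for qs in q_tile]
            m = [max(mr, max(row)) for mr, row in zip(m, s_tile)]
        acc += sum(m)                                      # m now spans the full L_d axis
    return acc                                             # the one value the program stores


rng = random.Random(0)
q = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]     # L_q = 5, d = 4
doc = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(7)]   # L_d = 7 (ragged last tile)
assert math.isclose(program_tiled(q, doc), program_untiled(q, doc), rel_tol=1e-12, abs_tol=1e-12)
```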

4. Why the naive form is bandwidth-bound

The textbook implementation of MaxSim, the one inside PyLate's reference code, separates the two reductions:

import torch

# Q: [Nq, Lq, d], D: [Nd, Ld, d]
S = torch.einsum("nsd,mtd->nmst", Q, D)   # [Nq, Nd, Lq, Ld]  ← lives in HBM
M = S.max(dim=-1).values                  # [Nq, Nd, Lq]      ← max over t
scores = M.sum(dim=-1)                    # [Nq, Nd]          ← sum over s

The full similarity tensor $S$ is written to HBM and then read back to take the row-wise max; the sum that follows touches only the much smaller $M$. The footprint of $S$ is

$$\text{bytes}(S) \;=\; N_q \cdot N_d \cdot L_q \cdot L_d \cdot b,$$
where $b$ is the size of one element in bytes (4 in fp32, 2 in fp16).

Scaling is quadratic in the sequence length. At ColPali shapes ($N_q = N_d = 64$, $L_q = L_d = 1024$, fp32) the tensor is $16$ GB; in fp16 it is $8$ GB. Training builds three or four such intermediates per step. The chart shows the same number plotted as a function of the sequence length:

[Figure: peak HBM for one $S$ against sequence length $L$ ($L_q = L_d$, 256 to 2048), fp16, $N_q = N_d = 64$. The naive curve (materialises $S$) grows quadratically, with the 80 GB H100 limit marked; the fused kernel stores only the $[N_q, N_d]$ score matrix.]
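The footprint formula is easy to tabulate. The hypothetical s_bytes helper below reproduces the ColPali figures above, taking the document's "GB" as binary gibibytes ($2^{30}$ bytes), and shows that doubling $L$ quadruples the tensor.

```python
GIB = 2 ** 30


def s_bytes(n_q: int, n_d: int, l_q: int, l_d: int, bytes_per_elem: int) -> int:
    """Size in bytes of the materialised similarity tensor S = [Nq, Nd, Lq, Ld]."""
    return n_q * n_d * l_q * l_d * bytes_per_elem


print(s_bytes(64, 64, 1024, 1024, 4) / GIB)    # fp32: 16.0
print(s_bytes(64, 64, 1024, 1024, 2) / GIB)    # fp16: 8.0
for L in (256, 512, 1024, 2048):
    print(L, s_bytes(64, 64, L, L, 2) / GIB)   # 256 0.5, 512 2.0, 1024 8.0, 2048 32.0
```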

Beyond the memory pressure, the naive version is bandwidth-bound: each element of $S$ costs $2d$ FLOPs to compute but 4 bytes of HBM traffic (one fp16 write, one fp16 read), an arithmetic intensity of about $d/2$ FLOPs per byte. That is far below the H100 ridge point, so the SMs spend most of their time waiting on memory. Appendix B works the FLOP-per-byte numbers out and places both implementations on the roofline.

5. Fused forward: tile, stream, accumulate

Spelling out the loop nest sketched in section 3: one Triton program per $(q\text{-batch}, d\text{-batch})$ pair, with document tiles of $\text{BLOCK\_D}$ rows nested inside query tiles of $\text{BLOCK\_Q}$ rows. The full similarity tensor $S$ never exists in HBM; at any moment only one $\text{BLOCK\_Q} \times \text{BLOCK\_D}$ slice exists, and it stays on-chip for the duration of one tile-matmul.

# one program per (q_batch, d_batch)
score ← 0                                                  # fp32 scalar accumulator
for q_start in range(0, Lq, BLOCK_Q):
    Q_tile   ← Q[q_batch, q_start : q_start + BLOCK_Q]       # SRAM
    q_active ← valid (non-padding) query rows of Q_tile
    m        ← −∞                                            # [BLOCK_Q] in registers
    for d_start in range(0, Ld, BLOCK_D):
        D_tile   ← D[d_batch, d_start : d_start + BLOCK_D]   # SRAM
        d_active ← valid (non-padding) doc rows of D_tile
        S_tile   ← Q_tile @ D_tileᵀ                          # tensor cores, fp32 acc, registers
        S_tile   ← mask(S_tile, d_active, −∞)                # fused-in masking
        m        ← maximum(m, rowmax(S_tile))                # online max
    score += sum(m[q_active])                                # contribution from this Q tile
scores[q_batch, d_batch] ← score                             # only this scalar is written

The structure is the same outer-product tiling that FlashAttention uses, with one simplification: where FlashAttention has to maintain a running max and a running sum of exponentials for softmax, MaxSim only has a running max. The recurrence is exact and there is nothing to rescale.
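For readers who want to execute the loop nest, here is a direct CPU transcription in plain Python. Assumptions: masks are per-token booleans (True = real token), maxsim_program is a hypothetical name, and Python slicing stands in for the bounds masking a Triton kernel would apply to the ragged last tile. The usage lines pad a document with junk tokens that would win any unmasked max and confirm the score does not change.

```python
from typing import List

Vec = List[float]
NEG_INF = float("-inf")


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def maxsim_program(q: List[Vec], doc: List[Vec], q_mask: List[bool], d_mask: List[bool],
                   block_q: int = 4, block_d: int = 4) -> float:
    """One (q_batch, d_batch) program, transcribed from the pseudocode onto CPU lists."""
    score = 0.0
    for q_start in range(0, len(q), block_q):
        Q_tile = q[q_start:q_start + block_q]
        q_active = q_mask[q_start:q_start + block_q]
        m = [NEG_INF] * len(Q_tile)
        for d_start in range(0, len(doc), block_d):
            D_tile = doc[d_start:d_start + block_d]
            d_active = d_mask[d_start:d_start + block_d]
            # S tile with masked doc positions written as -inf before the row max
            S_tile = [[dot(qs, dt) if ok else NEG_INF for dt, ok in zip(D_tile, d_active)]
                      for qs in Q_tile]
            m = [max(mr, max(row)) for mr, row in zip(m, S_tile)]    # online max
        score += sum(mr for mr, ok in zip(m, q_active) if ok)        # skip padded query rows
    return score


q = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]     # third query token is padding
q_mask = [True, True, False]
doc = [[0.5, 0.5], [1.0, -1.0]]
padded = doc + [[9.0, 9.0]] * 3               # junk tokens that would win any unmasked max
d_mask = [True, True, False, False, False]
print(maxsim_program(q, doc, q_mask, [True, True]))   # 1.5
print(maxsim_program(q, padded, q_mask, d_mask))      # 1.5
```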

Peak on-chip storage per program is $(\text{BLOCK\_Q} + \text{BLOCK\_D}) \cdot d$ elements for the operand tiles plus $\text{BLOCK\_Q} \cdot \text{BLOCK\_D}$ for the score tile, at 2 bytes each in low precision, with a $[\text{BLOCK\_Q}]$ fp32 accumulator on top. With $\text{BLOCK\_Q} = \text{BLOCK\_D} = 64$ and $d = 128$ that comes to roughly 40 KiB, leaving room for four concurrent programs per H100 SM.
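The budget can be checked with a few lines of arithmetic. The sketch below follows the accounting in the paragraph above (operand tiles and score tile at 2 bytes per element, fp32 running max); the block sizes of 64 and $d = 128$ are assumed values, the helper name is hypothetical, and the 228 KiB figure is the per-SM shared-memory size from Appendix A.

```python
KIB = 1024


def program_onchip_bytes(block_q: int, block_d: int, d: int, elem_bytes: int = 2) -> int:
    """On-chip bytes per program, following the accounting in the text."""
    operand_tiles = (block_q + block_d) * d * elem_bytes   # Q tile + D tile, bf16/fp16
    score_tile = block_q * block_d * elem_bytes            # one BLOCK_Q x BLOCK_D slice of S
    running_max = block_q * 4                              # fp32 accumulator m
    return operand_tiles + score_tile + running_max


per_program = program_onchip_bytes(64, 64, 128)
print(per_program, per_program / KIB)      # 41216 40.25
print(4 * per_program <= 228 * KIB)        # True: four programs fit in 228 KiB
```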

Step through the algorithm

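Consider a toy execution with $L_q = 4$, $L_d = 12$, $d = 4$, $\text{BLOCK\_Q} = 4$ (one outer iteration) and $\text{BLOCK\_D} = 4$ (three inner iterations). The state of interest at each step is the current q_start and d_start, the running max $m$, the partial sum $\sum m$ and, at the end, the final score; alongside it, one can count the HBM bytes the program reads and writes and compare them with the naive baseline.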
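The same toy run can be traced in plain Python. The embeddings are random placeholders (hypothetical data), elements are assumed to be 2-byte bf16/fp16, and the naive count follows Appendix B's accounting: the same operand reads and final write, plus one 2-byte write and one 2-byte read of every element of the $4 \times 12$ matrix $S$.

```python
import random

L_q, L_d, d = 4, 12, 4
BLOCK_Q, BLOCK_D = 4, 4
ELEM = 2            # bytes per bf16/fp16 element (assumed input precision)
SCORE_BYTES = 4     # the single fp32 score the program writes

rng = random.Random(0)                        # hypothetical toy embeddings
Q = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(L_q)]
D = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(L_d)]

reads = 0
score = 0.0
for q_start in range(0, L_q, BLOCK_Q):
    Q_tile = Q[q_start:q_start + BLOCK_Q]
    reads += len(Q_tile) * d * ELEM           # Q tile: one HBM read per outer iteration
    m = [float("-inf")] * len(Q_tile)
    for d_start in range(0, L_d, BLOCK_D):
        D_tile = D[d_start:d_start + BLOCK_D]
        reads += len(D_tile) * d * ELEM       # D tile: one HBM read per inner iteration
        S_tile = [[sum(a * b for a, b in zip(qs, dt)) for dt in D_tile] for qs in Q_tile]
        m = [max(mr, max(row)) for mr, row in zip(m, S_tile)]
        print(f"q_start={q_start} d_start={d_start} m={[round(x, 3) for x in m]} sum(m)={sum(m):.3f}")
    score += sum(m)

writes = SCORE_BYTES
fused = reads + writes
naive = fused + 2 * L_q * L_d * ELEM          # plus writing S to HBM and reading it back once
print(f"reads={reads} B, writes={writes} B, fused={fused} B, naive={naive} B, ratio={naive / fused:.2f}")
# reads=128 B, writes=4 B, fused=132 B, naive=324 B, ratio=2.45
```

At this toy size the saving is modest. The $S$ term grows with $L_q L_d$ while the operand reads grow with $(L_q + L_d)\,d$, so the gap widens quickly at real sequence lengths.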


6. The online-max recurrence

To keep only one tile of $S$ alive at a time, the inner loop folds each tile into a per-row running maximum as soon as it is computed. After consuming the $k$-th tile, the running max obeys

$$m^{(k)} \;=\; \max\!\bigl(m^{(k-1)},\;\; \operatorname{rowmax}(\Sm^{(k)})\bigr),$$

with $m^{(0)} = -\infty$. After all tiles are consumed, $m$ is the true per-row maximum and $\sum_s m_s$ is the score contribution of the current query tile. Both quantities are computed in fp32 in registers; the GEMM uses bf16/fp16 operands with fp32 accumulation, matching the tensor-core native path.

Why this is simpler than FlashAttention. Online softmax has to track $m$ and the running sum of exponentials $\ell$, and rescale $\ell$ (and the output accumulator) whenever the running max changes: the rescalers are the $\exp(m^{(k-1)} - m^{(k)})$ factors in FA-1. Online max needs no rescaler: $\max$ is associative, commutative and idempotent, so the update is a plain elementwise max, and because it involves no rounding the result is bitwise identical to the offline computation.
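The difference is easy to see side by side. The sketch below streams the same three tiles through a running max and through an online-softmax normaliser (written as a log-sum-exp for compactness). Only the latter needs the exp(m_old - m_new) rescale, and only the former is bit-identical to its offline counterpart. Plain Python with illustrative function names; this is not FlashAttention's kernel.

```python
import math
import random


def online_max(tiles):
    """Running max over tiles: no correction term is ever applied to earlier tiles."""
    m = -math.inf
    for tile in tiles:
        m = max(m, max(tile))
    return m


def online_logsumexp(tiles):
    """Online-softmax normaliser: running max m and running sum l of exp(x - m)."""
    m, l = -math.inf, 0.0
    for tile in tiles:
        m_new = max(m, max(tile))
        # the rescale step: terms accumulated under the old max are multiplied by exp(m - m_new)
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    return m + math.log(l)


rng = random.Random(0)
xs = [rng.uniform(-3.0, 3.0) for _ in range(12)]
tiles = [xs[k:k + 4] for k in range(0, 12, 4)]        # three tiles of four, as in the text

assert online_max(tiles) == max(xs)                   # exact: max never rounds
assert math.isclose(online_logsumexp(tiles), math.log(sum(math.exp(x) for x in xs)))
```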

Worked example

Three document tiles of four tokens each, one query row. The argmax index is recorded for the backward pass.

| Tile | $S$ row (values seen this tile) | tile max | running max after tile | argmax (global $t$) |
|---|---|---|---|---|
| 0 | [0.42, 0.11, 0.30, 0.18] | 0.42 | 0.42 | 0 |
| 1 | [0.20, 0.55, 0.05, 0.31] | 0.55 | 0.55 | 5 |
| 2 | [0.49, 0.40, 0.50, 0.22] | 0.50 | 0.55 | 5 |
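The table can be reproduced by merging each tile's max and local argmax into a running pair. The sketch assumes tiles of four tokens, as above, and breaks ties toward the lowest index (first occurrence within a tile, strict > across tiles), matching the lowest-index tie-break described in section 7; the function name is hypothetical.

```python
def online_max_with_argmax(tiles, tile_width):
    """Merge each tile's (max, first argmax) into a running (max, global argmax) pair."""
    m, arg = float("-inf"), -1
    history = []
    for k, tile in enumerate(tiles):
        tile_max = max(tile)
        local_t = tile.index(tile_max)          # first occurrence: lowest index within the tile
        if tile_max > m:                        # strict '>': on a tie the earlier tile keeps it
            m, arg = tile_max, k * tile_width + local_t
        history.append((tile_max, m, arg))      # one row of the worked-example table
    return m, arg, history


tiles = [
    [0.42, 0.11, 0.30, 0.18],
    [0.20, 0.55, 0.05, 0.31],
    [0.49, 0.40, 0.50, 0.22],
]
m, arg, history = online_max_with_argmax(tiles, tile_width=4)
for k, row in enumerate(history):
    print(k, row)
# 0 (0.42, 0.42, 0)
# 1 (0.55, 0.55, 5)
# 2 (0.5, 0.55, 5)
```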

Masked positions are written as $-\infty$ before the reduction, so they cannot influence the max or the argmax, even when every real score in the row is negative; a $0/1$ multiplicative mask would instead turn masked similarities into $0$, which then wins the max. This is stricter than PyLate's reference, which post-multiplies by a $0/1$ mask, and matches the masking discipline used by flash-maxsim and FlashAttention.
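The failure mode of a multiplicative mask shows up on a single row whose real scores are all negative. The two helpers below are illustrative only; they contrast the two masking disciplines and are not PyLate's or this library's code.

```python
NEG_INF = float("-inf")


def rowmax_neg_inf_mask(row, mask):
    """Masked positions become -inf before the max, so they can never win."""
    return max(x if keep else NEG_INF for x, keep in zip(row, mask))


def rowmax_multiplicative_mask(row, mask):
    """Masked positions are multiplied by 0, so they enter the max as 0.0."""
    return max(x * float(keep) for x, keep in zip(row, mask))


row = [-0.30, -0.50, 0.90]    # last entry is a padding token's (meaningless) similarity
mask = [True, True, False]
print(rowmax_neg_inf_mask(row, mask))         # -0.3: the best real token
print(rowmax_multiplicative_mask(row, mask))  # 0.0: the padding slot wins
```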

7. Backward: argmax-only gradients

$\max$ is sub-differentiable: only the argmax position carries a gradient, so the backward pass is simpler than for softmax. With $g \in \mathbb{R}^{N_q \times N_d}$ the upstream gradient on the score:

$$(\nabla_Q)_{i, s} \;=\; q\text{-}\text{mask}_{i,s} \cdot \sum_{j} g_{i,j} \cdot D_{j,\; \operatorname{argmax}_t \langle Q_{i,s}, D_{j,t}\rangle}.$$

$\nabla_D$ has the symmetric form: $(\nabla_D)_{j,t}$ is the sum of $g_{i,j} \cdot Q_{i,s}$ over the unmasked $(i, s)$ pairs whose argmax in document $j$ is $t$. To avoid recomputing the forward, the forward kernel optionally writes an $[N_q \cdot N_d, L_q]$ int32 buffer of argmax indices (4 MB for a typical training batch).
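Because each score is piecewise linear in $Q$, the argmax-gather formula for $\nabla_Q$ can be checked against a central finite difference. The plain-Python sketch below assumes no masks and recomputes the argmax instead of reading a saved index buffer; all names are hypothetical. While no argmax flips under the perturbation, the difference quotient matches the gathered gradient up to rounding.

```python
import copy
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def maxsim_scores(Q, D):
    """Unmasked forward: scores[i][j] = sum_s max_t <Q[i][s], D[j][t]>."""
    return [[sum(max(dot(qs, dt) for dt in doc) for qs in query) for doc in D] for query in Q]


def grad_Q(Q, D, g):
    """(grad_Q)[i][s] = sum_j g[i][j] * D[j][argmax_t <Q[i][s], D[j][t]>], no masks."""
    grads = []
    for i, query in enumerate(Q):
        rows = []
        for qs in query:
            acc = [0.0] * len(qs)
            for j, doc in enumerate(D):
                dots = [dot(qs, dt) for dt in doc]
                t_star = dots.index(max(dots))                              # lowest-index tie-break
                acc = [a + g[i][j] * x for a, x in zip(acc, doc[t_star])]   # gather the winner
            rows.append(acc)
        grads.append(rows)
    return grads


def loss(Q, D, g):
    """Scalar whose gradient with respect to the score matrix is exactly g."""
    S = maxsim_scores(Q, D)
    return sum(g[i][j] * S[i][j] for i in range(len(Q)) for j in range(len(D)))


def rand_batch(rng, n, L, d):
    return [[[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(L)] for _ in range(n)]


rng = random.Random(0)
Q, D = rand_batch(rng, 2, 3, 4), rand_batch(rng, 3, 5, 4)   # N_q=2, L_q=3, N_d=3, L_d=5, d=4
g = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(2)]
G = grad_Q(Q, D, g)

# Central finite difference on one coordinate of Q.
eps = 1e-6
i, s, k = 1, 2, 3
Q_plus, Q_minus = copy.deepcopy(Q), copy.deepcopy(Q)
Q_plus[i][s][k] += eps
Q_minus[i][s][k] -= eps
fd = (loss(Q_plus, D, g) - loss(Q_minus, D, g)) / (2 * eps)
print(abs(fd - G[i][s][k]))
```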

$\nabla_Q$ is embarrassingly parallel: one program per $(i, s)$ gathers $D_{j, \text{argmax}}$ across $j$ and produces a single output row. $\nabla_D$ has output contention: many $(i, s)$ pairs can land on the same $(j, t)$. The library offers three reduction strategies:

| Method | Bitwise reproducible | When it is picked |
|---|---|---|
| unified | no (fp32 atomics, ≤ 1e-6 relative difference) | Default. Single-pass fused $\nabla_Q + \nabla_D$, fp32 atomic add. |
| csr | yes | Picked automatically at very high contention ($N_q, N_d \geq 256$, $L_q \leq 64$). Radix-sorts $(i,s)$ by argmax into per-$j$ CSR buckets so each $(j, t)$ output is reduced in a single program. |
| atomic | no | Legacy two-pass; retained as a fallback on hardware with degraded fp32 atomics. |

All three paths use Triton's stable argmax (lowest-index tie-break) on the forward, so only the $\nabla_D$ reduction order distinguishes them numerically.
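The idea behind the deterministic csr path can be sketched on the CPU: turn every $(i, s, j)$ argmax into a record keyed by its output slot $(j, t)$, sort, and let a single owner reduce each bucket in a fixed order. In the sketch, Python's sorted stands in for the radix sort and a shuffled accumulation order stands in for fp32 atomics; the shapes are chosen to force heavy contention, and every name is hypothetical rather than this library's implementation.

```python
import random
from itertools import groupby


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def argmax_records(Q, D):
    """One (j, t, i, s) record per (query token, document): t is the winning doc token."""
    records = []
    for i, query in enumerate(Q):
        for s, qs in enumerate(query):
            for j, doc in enumerate(D):
                dots = [dot(qs, dt) for dt in doc]
                records.append((j, dots.index(max(dots)), i, s))
    return records


def empty_like(D):
    return [[[0.0] * len(dt) for dt in doc] for doc in D]


def grad_D_bucketed(Q, D, g):
    """Deterministic: sort records by output slot (j, t), reduce each bucket in a fixed order."""
    grad = empty_like(D)
    records = sorted(argmax_records(Q, D))          # stand-in for the radix sort into buckets
    for (j, t), bucket in groupby(records, key=lambda r: (r[0], r[1])):
        acc = [0.0] * len(D[j][t])                  # a single owner reduces output row (j, t)
        for _, _, i, s in bucket:
            acc = [a + g[i][j] * x for a, x in zip(acc, Q[i][s])]
        grad[j][t] = acc
    return grad


def grad_D_scatter(Q, D, g, seed):
    """Stand-in for fp32 atomics: same contributions, accumulated in an arbitrary order."""
    grad = empty_like(D)
    records = argmax_records(Q, D)
    random.Random(seed).shuffle(records)
    for j, t, i, s in records:
        grad[j][t] = [a + g[i][j] * x for a, x in zip(grad[j][t], Q[i][s])]
    return grad


rng = random.Random(0)
Q = [[[rng.gauss(0, 1) for _ in range(3)] for _ in range(4)] for _ in range(6)]  # N_q=6, L_q=4, d=3
D = [[[rng.gauss(0, 1) for _ in range(3)] for _ in range(2)] for _ in range(2)]  # N_d=2, L_d=2
g = [[rng.gauss(0, 1) for _ in range(2)] for _ in range(6)]

bucketed = grad_D_bucketed(Q, D, g)
scattered = grad_D_scatter(Q, D, g, seed=1)
max_diff = max(abs(x - y) for bj, sj in zip(bucketed, scattered)
               for bt, st in zip(bj, sj) for x, y in zip(bt, st))
print(max_diff)   # rounding-level at most; the bucketed result is identical on every run
```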

8. Performance summary

Measured on a single H100 80 GB SXM with bf16/fp16 inputs and fp32 accumulators, 50-iteration median; baselines are the same operation in plain PyTorch. Full shapes and reproduction commands are in docs/benchmarks.md.

| Workload | Speedup vs PyTorch |
|---|---|
| Reranking / inference (small to medium) | 7–23× |
| Long-context ($L_d \geq 8\text{k}$) reranking | runs; naive PyTorch OOMs |
| PyLate cached-contrastive MaxSim + backward | up to 13.8× |
| PLAID rerank (baseline: fast_plaid.engine.search()) | 19–30× |
| Fused $D$-side head (training) | 1.5–4.6× |
| FP8 MaxSim inference (Hopper) | up to 1.4× |
| End-to-end ModernColBERT training (149 M params) | 1.0–1.06× (free swap) |

In full encoder training, step time is dominated by the transformer forward and backward passes, so the kernel is effectively a free drop-in. Wherever MaxSim is not negligible (inference, reranking, long documents, knowledge distillation, compressed indices, small encoders), the fused path delivers the large speedups in the table.

9. Variants in the library

The streaming-and-never-storing structure is the entire optimisation. The rest of the library applies the same idea to a few adjacent workloads:

  • Variable-length / packed input (maxsim_varlen). Real corpora have ragged lengths; padding to $L_d^{\max}$ wastes roughly half the FLOPs. The packed kernel consumes $[\text{total\_tokens}, d]$ tensors with FlashAttention-style cu_seqlens offsets and runs the same tiling.
  • Fused $D$-side projection (maxsim_from_hidden). For corpora stored as ModernBERT hidden states, the kernel folds $\texttt{Linear} \rightarrow \texttt{L2-normalize} \rightarrow \texttt{MaxSim}$ into a single pass. The projected embedding tensor is never written to HBM. The training backward gathers only the winning positions and recomputes the projection locally.
  • PLAID / ColBERTv2 (maxsim_residual, maxsim_residual_varlen). The doc embeddings live on disk as $(\text{centroid index}, \text{quantised residual})$. The kernel decompresses on the fly in SRAM, optionally L2-normalises, and runs MaxSim, all in one program with no dense decompressed tensor ever materialised.
  • FP8 MaxSim (maxsim_inference_fp8). On Hopper, the GEMM uses the FP8 tensor cores at twice the bf16 throughput; the reduction stays in fp32 to preserve range.
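The packed layout of the variable-length path can be illustrated with the FlashAttention convention for cu_seqlens: a prefix sum of the sequence lengths with a leading zero, so sequence $b$ occupies rows cu_seqlens[b] to cu_seqlens[b+1] of the packed $[\text{total\_tokens}, d]$ tensor. The plain-Python sketch below is an illustration under that assumption, not this library's API.

```python
from itertools import accumulate
from typing import List


def cu_seqlens(lengths: List[int]) -> List[int]:
    """Prefix sum with a leading zero: sequence b owns rows [cu[b], cu[b + 1])."""
    return [0] + list(accumulate(lengths))


lengths = [3, 1, 4]                                        # three ragged documents
cu = cu_seqlens(lengths)
packed = [[float(r), float(r)] for r in range(cu[-1])]     # [total_tokens, d] with d = 2
docs = [packed[cu[b]:cu[b + 1]] for b in range(len(lengths))]

padded_rows = len(lengths) * max(lengths)                  # rows in a padded [N_d, L_max, d] layout
print(cu, cu[-1], padded_rows)                             # [0, 3, 4, 8] 8 12
```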

Different inputs, same recipe: keep large intermediates on-chip, run the reductions in registers, and ship only the final scores to HBM.

A. GPU memory hierarchy

Modern GPUs are dramatically faster at arithmetic than at moving data. An H100 sustains roughly $990$ TFLOP/s of bf16 tensor-core throughput against $3.3$ TB/s of HBM bandwidth, a ratio of about $300$ FLOPs per byte. Any kernel that performs fewer FLOPs than that per byte it moves to or from HBM becomes bandwidth-bound, and the SMs sit idle waiting for memory. The trick to writing a fast kernel is therefore not faster arithmetic, but fewer round-trips to HBM. The hierarchy in play:

| Level | Size | Bandwidth | Where it lives |
|---|---|---|---|
| Registers | 256 KB / SM | ~20 TB/s | Inside each thread |
| SRAM (shared memory) | 228 KB / SM (H100) | ~19 TB/s | On-chip, per SM |
| L2 cache | 50 MB (H100) | ~5 TB/s | On-chip, chip-wide |
| HBM | 80 GB (H100) | ~3.3 TB/s | Off-chip, on the package |

By these figures SRAM is roughly six times faster than HBM, and registers, private to each thread, are faster still. A fused kernel chains several logical operations back-to-back inside SRAM and registers, writing only the final result to HBM. That is the optimisation strategy used throughout this library.

B. Arithmetic intensity

Whether a kernel is bound by compute or by memory is captured by a single number, the arithmetic intensity, $\mathrm{AI} = \text{FLOPs} / \text{HBM bytes moved}$. Plot achievable throughput against $\mathrm{AI}$ and you get the roofline: a bandwidth slope on the left, a compute plateau on the right, meeting at the ridge point $\mathrm{AI}^{*} = \text{peak compute} / \text{peak bandwidth}$. For an H100 in bf16, $\mathrm{AI}^{*} \approx 295$ FLOPs/byte. Below that you are memory-bound, above it compute-bound.

Scoring one $(Q, D)$ pair costs $2 L_q L_d \, d$ FLOPs of matmul plus a few cheap reductions. The two implementations differ only in what they push through HBM. The naive form materialises the $L_q \times L_d$ score matrix $S$, so its traffic is dominated by writing and reading back that surface (2 bytes each way in fp16):

$$\mathrm{AI}_{\text{naive}} \;\approx\; \frac{2 L_q L_d \, d}{4 L_q L_d} \;=\; \frac{d}{2}.$$

The fused form never writes $S$; the only HBM traffic is the operand reads $Q$ and $D$:

$$\mathrm{AI}_{\text{fused}} \;\approx\; \frac{2 L_q L_d \, d}{2 (L_q + L_d) \, d} \;=\; \frac{L_q L_d}{L_q + L_d}.$$

At ColPali shapes ($L_q = L_d = 1024$, $d = 128$), the naive form sits at about 64 FLOPs/byte, four to five times below the ridge, so the tensor cores idle while HBM works flat out. The fused form sits at about 512 FLOPs/byte, comfortably above the ridge, where the kernel is finally limited by compute. The same workload moves between regimes purely by changing what crosses HBM:

[Figure: roofline over arithmetic intensity (FLOPs per HBM byte, log scale, 1 to 1000). Memory-bound to the left of the ridge at AI ≈ 295, compute-bound to the right; naive ≈ 64 FLOPs/B, fused ≈ 512 FLOPs/B.]
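The two intensity formulas are simple enough to evaluate directly. The sketch below (helper names hypothetical) uses the same byte accounting as the equations above, fp16 traffic for $S$ or for the operands, and the ridge point $\mathrm{AI}^{*} \approx 295$ quoted for an H100 in bf16.

```python
def ai_naive(l_q: int, l_d: int, d: int) -> float:
    flops = 2 * l_q * l_d * d          # the matmul behind S
    hbm_bytes = 4 * l_q * l_d          # S written once and read once, 2 bytes each way (fp16)
    return flops / hbm_bytes            # = d / 2


def ai_fused(l_q: int, l_d: int, d: int) -> float:
    flops = 2 * l_q * l_d * d
    hbm_bytes = 2 * (l_q + l_d) * d    # only the fp16 operand reads of Q and D
    return flops / hbm_bytes            # = l_q * l_d / (l_q + l_d)


RIDGE = 295.0                           # H100 bf16 ridge point quoted in the text

for name, ai in (("naive", ai_naive(1024, 1024, 128)), ("fused", ai_fused(1024, 1024, 128))):
    regime = "compute-bound" if ai >= RIDGE else "memory-bound"
    print(name, ai, regime)
# naive 64.0 memory-bound
# fused 512.0 compute-bound
```

Note that $\mathrm{AI}_{\text{naive}} = d/2$ does not depend on sequence length at all, whereas $\mathrm{AI}_{\text{fused}}$ grows with it.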

Fusing is not faster because the arithmetic is cheaper. It is the same matmul. It is faster because the matmul never has to wait for memory.