Metadata-Version: 2.4
Name: late-interaction-kernels
Version: 0.2.0
Summary: Fused Triton kernels for late-interaction (MaxSim) scoring — ColBERT, ColPali, ModernColBERT.
Project-URL: Homepage, https://github.com/hcompai/late-interaction-kernels
Project-URL: Repository, https://github.com/hcompai/late-interaction-kernels
Project-URL: Issues, https://github.com/hcompai/late-interaction-kernels/issues
Project-URL: Changelog, https://github.com/hcompai/late-interaction-kernels/blob/main/CHANGELOG.md
Project-URL: Documentation, https://github.com/hcompai/late-interaction-kernels#readme
Author-email: Aurélien Lac <aurelien.lac@hcompany.ai>, Tony Wu <tony.wu@hcompany.ai>
Maintainer-email: Aurélien Lac <aurelien.lac@hcompany.ai>, Tony Wu <tony.wu@hcompany.ai>
License: Apache-2.0
License-File: LICENSE
Keywords: colbert,colpali,flash-attention,gpu-kernels,ir,late-interaction,maxsim,moderncolbert,rerank,retrieval,triton
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: torch<3,>=2.5
Requires-Dist: triton<4,>=3.0; platform_system == 'Linux'
Provides-Extra: dev
Requires-Dist: numpy<3,>=1.21; extra == 'dev'
Requires-Dist: pandas<4,>=1.5; extra == 'dev'
Requires-Dist: pytest-xdist<4,>=3; extra == 'dev'
Requires-Dist: pytest<10,>=7; extra == 'dev'
Requires-Dist: ruff<1,>=0.4; extra == 'dev'
Requires-Dist: tabulate<1,>=0.9; extra == 'dev'
Requires-Dist: ty<0.1,>=0.0.20; extra == 'dev'
Provides-Extra: pylate
Requires-Dist: pylate<2,>=1.3.3; extra == 'pylate'
Provides-Extra: torch-cpu
Requires-Dist: torch<3,>=2.5; extra == 'torch-cpu'
Provides-Extra: torch-cuda
Requires-Dist: torch<3,>=2.5; extra == 'torch-cuda'
Description-Content-Type: text/markdown

<div align="center">

# late-interaction-kernels

<img src="assets/banner.webp" alt="late-interaction-kernels banner" />

[![ColBERT](https://img.shields.io/badge/ColBERT-2004.12832-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2004.12832)
[![PyLate](https://img.shields.io/badge/PyLate-100000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/lightonai/pylate)
[![colpali-engine](https://img.shields.io/badge/colpali--engine-100000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/illuin-tech/colpali)
[![Hugging Face](https://img.shields.io/badge/Hcompany-FFD21E?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/Hcompany)

[![CI](https://github.com/hcompai/late-interaction-kernels/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/hcompai/late-interaction-kernels/actions/workflows/ci.yml)
[![Version](https://img.shields.io/pypi/v/late-interaction-kernels?color=%2334D058&label=pypi%20package)](https://pypi.org/project/late-interaction-kernels/)
[![Downloads](https://static.pepy.tech/badge/late-interaction-kernels)](https://pepy.tech/project/late-interaction-kernels)

---

[[Benchmarks]](docs/benchmarks.md)
[[Design]](docs/design.md)
[[Supported models]](docs/supported_models.md)
[[Changelog]](CHANGELOG.md)

</div>

> [!NOTE]
> The full algorithmic walkthrough, animations, and benchmark plots live on the docs site: **[hcompai.github.io/late-interaction-kernels](https://hcompai.github.io/late-interaction-kernels/how-it-works.html)**.

## Introduction

`late-interaction-kernels` provides fused Triton kernels for **MaxSim**, the late-interaction scoring used by ColBERT, ColPali, ModernColBERT, LateOn and ColBERTv2. The kernels are numerically identical to plain PyTorch and come with three APIs:

- a one-line PyLate drop-in (`patch_pylate()`),
- a stateless `nn.Module` (`MaxSimScorer`) for custom training loops,
- function-level entry points (`maxsim`, `maxsim_varlen`, `maxsim_padded`, ...) for everything else.

This is **not** a search engine. For end-to-end training or retrieval use [PyLate](https://github.com/lightonai/pylate), [FastPlaid](https://github.com/lightonai/fast-plaid) or [NextPlaid](https://github.com/lightonai/next-plaid). This library is the MaxSim math they compile down to.
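
For readers new to late interaction, MaxSim scores a query against a document by taking, for each query token, the maximum similarity over all document tokens, then summing over query tokens. The block below is a minimal unfused sketch of that definition in plain PyTorch, not this library's implementation. The name `maxsim_reference` and the `[Nq, Lq, dim]` / `[Nd, Ld, dim]` shapes are assumptions made for illustration, and padding masks are left out.

```python
import torch


def maxsim_reference(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Unfused MaxSim sketch: Q is [Nq, Lq, dim], D is [Nd, Ld, dim]; returns [Nq, Nd]."""
    # Token-level similarity for every (query, document) pair: [Nq, Nd, Lq, Ld].
    sim = torch.einsum("qid,njd->qnij", Q, D)
    # Max over document tokens, then sum over query tokens.
    return sim.max(dim=-1).values.sum(dim=-1)


Q = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])   # one query with two tokens
D = torch.tensor([[[1.0, 0.0], [0.5, 0.5]],    # document 0
                  [[0.0, 0.0], [0.0, 2.0]]])   # document 1
scores = maxsim_reference(Q, D)                # shape [1, 2]
```

The intermediate `[Nq, Nd, Lq, Ld]` tensor is what makes this naive form memory-hungry for long documents; a fused kernel can avoid materializing it.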

## Install

```bash
pip install late-interaction-kernels
```

| Platform                       | Backend                                                                       |
| ------------------------------ | ----------------------------------------------------------------------------- |
| Linux + CUDA (sm_75+)          | Fused Triton kernels (autotuned, FP8 on Hopper).                              |
| macOS (Apple Silicon, MPS)     | Fused Metal `simdgroup_matrix` for inference, `torch.compile` for training.   |
| CPU / Windows                  | Autograd-aware pure-PyTorch reference.                                        |

> [!NOTE]
> The PyLate drop-in targets PyLate >= 1.3.3 (the `pylate` extra pins `pylate>=1.3.3,<2`). The pure-PyTorch reference imports on every platform, so training and retrieval code can be unit-tested on a laptop before you rent a GPU.

## Quickstart

### Patch PyLate (one line)

```python
from late_interaction_kernels import patch_pylate

patch_pylate()
# PyLate training / rerank code is unchanged
```

Set `LIK_DISABLE=1` in the environment to fall back to vanilla PyLate at runtime.

### Custom training loop

```python
from late_interaction_kernels import MaxSimScorer

scorer = MaxSimScorer(normalize=True)                # nn.Module, no parameters
scores = scorer(Q, D, q_mask=q_mask, d_mask=d_mask)  # [Nq, Nd] fp32
scores.mean().backward()
```

### Top-k retrieval

```python
from late_interaction_kernels import retrieve

scores, indices = retrieve(Q, D, top_k=100, chunk=4096)
# both [Nq, 100]; chunk= bounds peak HBM at Nq * (chunk + top_k)
```
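
To see why chunking bounds memory at `Nq * (chunk + top_k)`, here is a rough sketch of a running top-k over document chunks in plain PyTorch. It is an illustration of the general technique, not the code behind `retrieve`; `chunked_topk` and `plain_maxsim` are hypothetical helpers, and the `[Nq, Lq, dim]` / `[Nd, Ld, dim]` shapes are assumed.

```python
import torch


def plain_maxsim(Q, D):
    """Naive MaxSim used only as the scoring function for this sketch."""
    return torch.einsum("qid,njd->qnij", Q, D).max(dim=-1).values.sum(dim=-1)


def chunked_topk(score_fn, Q, D, top_k, chunk):
    """Running top-k over document chunks; holds at most [Nq, chunk + top_k] scores."""
    best_scores = best_idx = None
    for start in range(0, D.shape[0], chunk):
        s = score_fn(Q, D[start:start + chunk])                  # [Nq, <= chunk]
        idx = torch.arange(start, start + s.shape[1], device=s.device).expand_as(s)
        if best_scores is not None:
            # Merge the previous winners with this chunk before re-selecting.
            s = torch.cat([best_scores, s], dim=1)
            idx = torch.cat([best_idx, idx], dim=1)
        k = min(top_k, s.shape[1])
        best_scores, pos = s.topk(k, dim=1)
        best_idx = idx.gather(1, pos)                            # map back to global doc ids
    return best_scores, best_idx


torch.manual_seed(0)
Q = torch.randn(3, 4, 8)     # 3 queries, 4 tokens each
D = torch.randn(50, 6, 8)    # 50 documents, 6 tokens each
scores, indices = chunked_topk(plain_maxsim, Q, D, top_k=5, chunk=16)
```

Only the current chunk's scores and the carried-over `top_k` winners are alive at any time, so the full `[Nq, Nd]` score matrix is never materialized.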

### PLAID / ColBERTv2 on compressed, ragged docs

```python
from late_interaction_kernels.plaid import maxsim_residual_varlen

scores = maxsim_residual_varlen(
    Q, codes_flat, residuals_flat, cu_seqlens_d,
    centroids=centroids, bucket_weights=bucket_weights,
    nbits=2, normalize=True,
)  # [Nd] fp32; one kernel does decompress + L2-normalize + MaxSim
```
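
The three steps named in the comment above (decompress, L2-normalize, MaxSim) can be sketched unfused in plain PyTorch. This is an illustration under simplifying assumptions, not the library's kernel: `residual_maxsim_unfused` is a hypothetical name, the residual bucket indices are assumed to be already unpacked to one integer per dimension (the packed `nbits` byte layout is not described here), a single query of shape `[Lq, dim]` is assumed, and only document tokens are normalized.

```python
import torch
import torch.nn.functional as F


def residual_maxsim_unfused(Q, codes, bucket_idx, cu_seqlens, centroids, bucket_weights):
    """Hypothetical unfused equivalent: decompress, L2-normalize, MaxSim per document.

    Q [Lq, dim]; codes [T]; bucket_idx [T, dim] (assumed already unpacked);
    cu_seqlens [Nd + 1]; centroids [C, dim]; bucket_weights [2 ** nbits].
    """
    # 1. Decompress: centroid plus the quantized per-dimension residual.
    tokens = centroids[codes] + bucket_weights[bucket_idx]       # [T, dim]
    # 2. L2-normalize each reconstructed token embedding.
    tokens = F.normalize(tokens, dim=-1)
    # 3. MaxSim over each ragged document slice.
    sim = Q @ tokens.T                                           # [Lq, T]
    bounds = cu_seqlens.tolist()
    return torch.stack([
        sim[:, lo:hi].max(dim=1).values.sum()
        for lo, hi in zip(bounds[:-1], bounds[1:])
    ])                                                           # [Nd]


centroids = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
bucket_weights = torch.tensor([-0.5, 0.0, 0.25, 0.5])   # 2 ** nbits entries for nbits=2
codes = torch.tensor([0, 1, 0])                          # three tokens in total
bucket_idx = torch.ones(3, 2, dtype=torch.long)          # bucket 1 has weight 0.0
cu_seqlens = torch.tensor([0, 1, 3])                     # doc 0: 1 token, doc 1: 2 tokens
Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
scores = residual_maxsim_unfused(Q, codes, bucket_idx, cu_seqlens, centroids, bucket_weights)
```

The unfused version materializes the full decompressed `[T, dim]` token matrix, which is the intermediate a fused kernel can skip.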

## Benchmarks

1× H100, bf16/fp16, median of 50 iterations, compared with the same op in plain PyTorch:

| Workload                                           | Speedup           |
| -------------------------------------------------- | ----------------- |
| Reranking / inference vs naive einsum              | 7-23x             |
| Long-context (`Ld >= 8k`) reranking                | runs; naive OOMs  |
| PyLate cached-contrastive MaxSim + backward        | up to 13.8x       |
| PLAID rerank vs `fast_plaid.engine.search()`       | 19-30x            |
| Fused D-side head (training)                       | 1.5-4.6x          |
| FP8 MaxSim inference (Hopper)                      | up to 1.4x        |
| End-to-end training of a 149M encoder              | 1.00-1.06x (free) |

Full tables and reproduction commands: [`docs/benchmarks.md`](docs/benchmarks.md).

## Choose a kernel

<table>
<tr>
<td width="55%" valign="middle">

Not sure which entry point fits your stack? The docs site ships an interactive decision tree that narrows the public API down to the right function in four questions (stack · phase · layout · workload):

**👉 [hcompai.github.io/late-interaction-kernels/choose-a-kernel.html](https://hcompai.github.io/late-interaction-kernels/choose-a-kernel.html#choose-a-kernel)**

</td>
<td width="45%" valign="middle" align="center">

<a href="https://hcompai.github.io/late-interaction-kernels/choose-a-kernel.html#choose-a-kernel">
  <img src="assets/kernel_picker_widget_preview.webp" alt="Pick a kernel · interactive decision tree" width="420">
</a>

</td>
</tr>
</table>

## API

| Symbol                                | What it does                                                          |
| ------------------------------------- | --------------------------------------------------------------------- |
| `patch_pylate()` / `unpatch_pylate()` | One-line PyLate drop-in. `LIK_DISABLE=1` kill switch.                 |
| `MaxSimScorer(normalize=, backward=)` | Stateless `nn.Module`, autograd-aware.                                |
| `retrieve(Q, D, top_k, chunk=)`       | Top-k retrieval, chunked for huge corpora.                            |
| `maxsim`                              | Core MaxSim, dense layout. Autograd-aware; auto-skips argmax save when no input requires grad. |
| `maxsim_varlen`                       | Packed (`cu_seqlens`) layout. Autograd-aware.                         |
| `maxsim_padded`                       | Padded reranking wrapper: packs internally, returns `[B, C]` fp32.    |

Other kernels are in submodules: `padded`, `score_pairs`, `fused_head`, `plaid`, `fp8`, `experimental`, `reference`. See [`docs/design.md`](docs/design.md) for details on every kernel, the autograd graph and the backward variants.
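
For the packed (`cu_seqlens`) layout, variable-length documents are concatenated along the token axis and their boundaries stored as cumulative lengths. The sketch below shows one way to build such inputs. It assumes the common varlen convention (a leading zero, `int32` offsets, `Nd + 1` entries); whether `maxsim_varlen` expects exactly this dtype and convention is not confirmed here, and `pack_ragged` is a hypothetical helper.

```python
import torch


def pack_ragged(docs):
    """Concatenate variable-length [Ld_i, dim] tensors and record their boundaries."""
    lengths = torch.tensor([d.shape[0] for d in docs], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(docs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)   # cumulative token offsets
    return torch.cat(docs, dim=0), cu_seqlens


docs = [torch.randn(3, 8), torch.randn(5, 8), torch.randn(2, 8)]
flat, cu_seqlens = pack_ragged(docs)
# Document i occupies flat[cu_seqlens[i]:cu_seqlens[i + 1]].
```

Compared with padding every document to the longest length, the packed form stores no padding tokens, which matters when document lengths vary widely.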

<details>
<summary><strong>🔽 Configuration knobs (env vars + kwargs)</strong></summary>

| Knob                                                              | Effect                                                            |
| ----------------------------------------------------------------- | ----------------------------------------------------------------- |
| `maxsim(..., backward="auto" \| "unified" \| "atomic" \| "csr")`  | Per-call `grad_D` strategy. `"auto"` picks per shape.             |
| `patch_colpali_engine()` / `unpatch_colpali_engine()`             | colpali_engine drop-in: loss + scoring route through the kernel.  |
| `LIK_DISABLE=1`                                                   | Patched entry points delegate to vanilla PyLate.                  |
| `LIK_SUPPRESS_NORM_WARN=1`                                        | Silence the "looks unnormalized" one-shot warning.                |
| `LIK_DISABLE_COMPILE=1`                                           | Skip `torch.compile` on the MPS path (eager fallback).            |
| `LIK_FORCE_MPS_BACKEND={metal,compile,reference}`                 | Pin the MPS dispatch.                                             |

</details>

## Development

```bash
git clone https://github.com/hcompai/late-interaction-kernels
cd late-interaction-kernels
uv sync --extra dev --extra pylate --extra torch-cuda   # GPU dev; use --extra torch-cpu on CPU-only boxes
uv run pytest -q                                        # CUDA tests auto-skip without a GPU
uv run ruff check . && uv run ruff format --check .
```

> [!NOTE]
> Pick exactly one of `--extra torch-cuda` (pulls torch from the CUDA index — `cu124`) or `--extra torch-cpu` (CPU-only wheel, what CI uses). The two are declared as conflicting in `pyproject.toml` so the lockfile resolves cleanly for both. On macOS, `--extra torch-cpu` falls back to PyPI's default (MPS-capable) wheel automatically.

GPU tests run automatically on every push to `main`. To run them on a PR, apply the `run-gpu-tests` label.

See [`CONTRIBUTING.md`](CONTRIBUTING.md) for the contribution workflow.

## Citation

```bibtex
@software{late_interaction_kernels_2026,
  author  = {Lac, Aurélien and Wu, Tony},
  title   = {{late-interaction-kernels}: Fused Triton kernels for late-interaction scoring},
  year    = {2026},
  url     = {https://github.com/hcompai/late-interaction-kernels},
}
```
