Metadata-Version: 2.4
Name: pylate
Version: 1.5.0
Summary: A library for training and retrieval with ColBERT.
Author: LightOn
License-Expression: MIT
Project-URL: Homepage, https://github.com/lightonai/pylate.git
Project-URL: Documentation, https://lightonai.github.io/pylate/
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: sentence-transformers==5.3.0
Requires-Dist: datasets>=2.20.0
Requires-Dist: accelerate>=0.31.0
Requires-Dist: pandas>=2.2.1
Requires-Dist: transformers<=5.3.0,>=4.41.0
Requires-Dist: ujson==5.10.0
Requires-Dist: ninja==1.11.1.4
Requires-Dist: fastkmeans==0.5.0
Requires-Dist: fast-plaid<=1.3.0.290,>=1.2.4.260
Provides-Extra: dev
Requires-Dist: ruff>=0.4.9; extra == "dev"
Requires-Dist: pytest-cov>=5.0.0; extra == "dev"
Requires-Dist: pytest-xdist>=3.6.0; extra == "dev"
Requires-Dist: pytest-rerunfailures>=15.0.0; extra == "dev"
Requires-Dist: pytest>=8.2.1; extra == "dev"
Requires-Dist: pandas>=2.2.1; extra == "dev"
Requires-Dist: einops>=0.8.1; extra == "dev"
Requires-Dist: pre-commit>=4.1.0; extra == "dev"
Requires-Dist: ranx>=0.3.16; extra == "dev"
Requires-Dist: beir>=2.0.0; extra == "dev"
Requires-Dist: fastapi>=0.114.1; extra == "dev"
Requires-Dist: uvicorn>=0.30.6; extra == "dev"
Requires-Dist: batched>=0.1.2; extra == "dev"
Requires-Dist: voyager>=2.0.9; extra == "dev"
Requires-Dist: typos>=0.11.0; extra == "dev"
Provides-Extra: eval
Requires-Dist: ranx>=0.3.16; extra == "eval"
Requires-Dist: beir>=2.0.0; extra == "eval"
Provides-Extra: api
Requires-Dist: fastapi>=0.114.1; extra == "api"
Requires-Dist: uvicorn>=0.30.6; extra == "api"
Requires-Dist: batched>=0.1.2; extra == "api"
Provides-Extra: voyager
Requires-Dist: voyager>=2.0.9; extra == "voyager"
Provides-Extra: scann
Requires-Dist: scann>=1.4.2; extra == "scann"
Provides-Extra: warp
Requires-Dist: xtr-warp-rs==2.0.2.*; extra == "warp"
Dynamic: license-file

<div align="center">
  <h1>PyLate</h1>
  <p>Flexible Training and Retrieval for Late Interaction Models</p>
</div>

<p align="center"><img width=500 src="https://raw.githubusercontent.com/lightonai/pylate/refs/heads/main/docs/img/logo.png"/></p>

<div align="center">
  <!-- Documentation -->
  <a href="https://lightonai.github.io/pylate/"><img src="https://img.shields.io/badge/Documentation-purple.svg?style=flat-square" alt="documentation"></a>
  <!-- License -->
  <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square" alt="license"></a>
</div>

&nbsp;

<p align="justify">
PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.
</p>

&nbsp;

## Installation

You can install PyLate using pip:

```bash
pip install pylate
```

For evaluation dependencies, use:

```bash
pip install "pylate[eval]"
```

## Documentation

The complete documentation is available [here](https://lightonai.github.io/pylate/), which includes in-depth guides, examples, and API references.

&nbsp;

## Training

### Contrastive Training

Here’s a simple example of training a ColBERT model on the MS MARCO triplet dataset with PyLate. The script trains with a contrastive loss and evaluates the model on a held-out split:

```python
import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from pylate import evaluation, losses, models, utils

# Define model parameters for contrastive training
model_name = "bert-base-uncased"  # Choose the pre-trained model you want to use as base
batch_size = 32  # Larger batch size often improves results, but requires more memory

num_train_epochs = 1  # Adjust based on your requirements
# Set the run name for logging and output directory
run_name = "contrastive-bert-base-uncased"
output_dir = f"output/{run_name}"

# 1. Define the ColBERT model. If the base model is not already a ColBERT model, a linear projection layer is added on top of the encoder.
model = models.ColBERT(model_name_or_path=model_name)

# Compiling the model makes the training faster
model = torch.compile(model)

# Load dataset
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
# Split the dataset (this dataset does not have a validation set, so we split the training set)
splits = dataset.train_test_split(test_size=0.01)
train_dataset = splits["train"]
eval_dataset = splits["test"]

# Define the loss function
train_loss = losses.Contrastive(model=model)

# Initialize the evaluator
dev_evaluator = evaluation.ColBERTTripletEvaluator(
    anchors=eval_dataset["query"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
)

# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    learning_rate=3e-6,
)

# Initialize the trainer for the contrastive training
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
    data_collator=utils.ColBERTCollator(model.tokenize),
)
# Start the training process
trainer.train()
```
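Triplet evaluators usually report an accuracy: the fraction of held-out triplets for which the query scores higher against its positive than against its negative. The plain-Python sketch below shows that metric, assuming `ColBERTTripletEvaluator` reports a metric of this kind, as Sentence Transformers' `TripletEvaluator` does. The `triplet_accuracy` helper and the toy scores are hypothetical and for illustration only.

```python
def triplet_accuracy(positive_scores, negative_scores):
    """Fraction of triplets where the positive outscores the negative.

    Hypothetical helper: scores are assumed to be precomputed
    query-document relevance scores, one pair per triplet.
    """
    if len(positive_scores) != len(negative_scores):
        raise ValueError("Expected one positive and one negative score per triplet")
    if not positive_scores:
        raise ValueError("Need at least one triplet")
    # True counts as 1 and False as 0 when summed
    correct = sum(pos > neg for pos, neg in zip(positive_scores, negative_scores))
    return correct / len(positive_scores)


# Toy scores for three held-out triplets: the second one is ranked wrongly
print(triplet_accuracy([12.1, 9.8, 10.4], [10.3, 11.0, 8.7]))
```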

After training, the model can be loaded using the output directory path:

```python
from pylate import models

model = models.ColBERT(model_name_or_path="output/contrastive-bert-base-uncased")
```

Note that the temperature parameter is [very important in contrastive learning](https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Understanding_the_Behaviour_of_Contrastive_Loss_CVPR_2021_paper.pdf), and a temperature of around 0.02 is often used in the literature:

```python
train_loss = losses.Contrastive(model=model, temperature=0.02)
```
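To see why the temperature matters, here is a plain-Python sketch of a generic InfoNCE-style contrastive loss: a softmax cross-entropy over one query's candidate scores divided by the temperature. It is illustrative only and not PyLate's implementation, and the scores are toy values. Dividing by a small temperature amplifies score gaps, which sharpens the softmax.

```python
import math


def contrastive_loss(scores, positive_index, temperature):
    """Softmax cross-entropy over candidate scores divided by a temperature.

    Illustrative sketch: `scores` holds one query's scores against all
    candidates (its positive plus negatives); `positive_index` marks the positive.
    """
    logits = [score / temperature for score in scores]
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(logit - max_logit) for logit in logits))
    return log_sum_exp - logits[positive_index]


scores = [0.9, 0.5, 0.4]  # toy values: positive first, then two negatives
for temperature in (1.0, 0.02):
    print(temperature, contrastive_loss(scores, positive_index=0, temperature=temperature))
```

With these toy scores, a temperature of 0.02 turns the 0.4 gap between the positive and the best negative into a gap of 20 in logit space, so the loss is close to zero; at a temperature of 1.0 the same gap barely registers.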

Because gradient accumulation does not enlarge the pool of in-batch negatives that contrastive learning relies on, you can instead leverage [GradCache](https://arxiv.org/abs/2101.06983) to emulate larger batch sizes without requiring more memory. Use the `CachedContrastive` loss to set a `mini_batch_size` while increasing `per_device_train_batch_size`:

```python
mini_batch_size = 8  # Example value: number of samples per forward/backward chunk

train_loss = losses.CachedContrastive(
    model=model, mini_batch_size=mini_batch_size
)
```
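The plain-Python illustration below shows the chunking idea behind GradCache. A large batch is embedded in chunks of `mini_batch_size`, so peak memory depends on the chunk size, yet the cached embeddings are the same as a single full-batch pass. The loss can therefore still use every in-batch candidate. Here `embed` is a hypothetical toy encoder. The real method also caches the gradients of the loss with respect to the embeddings and re-runs each chunk with gradients enabled, and this sketch omits that step.

```python
def chunks(items, mini_batch_size):
    """Yield consecutive slices of at most `mini_batch_size` items."""
    for start in range(0, len(items), mini_batch_size):
        yield items[start:start + mini_batch_size]


def embed(texts):
    """Hypothetical stand-in for an encoder: one toy 2-d vector per text."""
    return [[float(len(text)), float(text.count(" "))] for text in texts]


batch = ["query one", "a longer query two", "q3", "yet another query four", "q5"]

# Stage 1 of GradCache: embed the large batch chunk by chunk (no gradients kept)
cached = [vector for chunk in chunks(batch, mini_batch_size=2) for vector in embed(chunk)]

# Each example is encoded independently, so the chunked result matches a
# single full-batch pass and the loss can still see all in-batch candidates.
assert cached == embed(batch)
```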

Finally, in a multi-GPU setting, you can gather the batch elements from all GPUs to create even larger batches by setting `gather_across_devices` to `True` (supported by both the `Contrastive` and `CachedContrastive` losses):

```python
train_loss = losses.Contrastive(model=model, gather_across_devices=True)
```

&nbsp;

### Knowledge Distillation

To get the best performance when training a ColBERT model, you should use knowledge distillation to train the model using the scores of a strong teacher model.
Here's a simple example of how to train a model using knowledge distillation in PyLate on MS MARCO:

```python
import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from pylate import losses, models, utils

# Load the datasets required for knowledge distillation (train, queries, documents)
train = load_dataset(
    path="lightonai/ms-marco-en-bge",
    name="train",
)

queries = load_dataset(
    path="lightonai/ms-marco-en-bge",
    name="queries",
)

documents = load_dataset(
    path="lightonai/ms-marco-en-bge",
    name="documents",
)

# Set the transformation to load the documents/queries texts using the corresponding ids on the fly
train.set_transform(
    utils.KDProcessing(queries=queries, documents=documents).transform,
)

# Define the base model, training parameters, and output directory
model_name = "bert-base-uncased"  # Choose the pre-trained model you want to use as base
batch_size = 16
num_train_epochs = 1
# Set the run name for logging and output directory
run_name = "knowledge-distillation-bert-base"
output_dir = f"output/{run_name}"

# Initialize the ColBERT model from the base model
model = models.ColBERT(model_name_or_path=model_name)

# Compiling the model to make the training faster
model = torch.compile(model)

# Configure the training arguments (e.g., epochs, batch size, learning rate)
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    run_name=run_name,
    learning_rate=1e-5,
)

# Use the Distillation loss function for training
train_loss = losses.Distillation(model=model)

# Initialize the trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train,
    loss=train_loss,
    data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
)

# Start the training process
trainer.train()
```
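ColBERT distillation, as popularized by ColBERTv2, typically minimizes the KL divergence between the teacher's and the student's score distributions over each query's candidate documents. The plain-Python sketch below shows that objective. It assumes, rather than confirms, that `losses.Distillation` follows this formulation, and all scores are toy values.

```python
import math


def softmax(values):
    max_value = max(values)  # subtract the max for numerical stability
    exps = [math.exp(value - max_value) for value in values]
    total = sum(exps)
    return [e / total for e in exps]


def kl_distillation_loss(teacher_scores, student_scores):
    """Illustrative KL(teacher || student) over softmax-normalized candidate scores."""
    teacher = softmax(teacher_scores)
    student = softmax(student_scores)
    return sum(p * math.log(p / q) for p, q in zip(teacher, student) if p > 0)


teacher_scores = [0.45, 0.66, 0.27]  # toy teacher scores for three documents
print(kl_distillation_loss(teacher_scores, [10.1, 12.3, 8.0]))  # student disagrees: positive loss
print(kl_distillation_loss(teacher_scores, teacher_scores))  # identical distributions: 0.0
```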

#### NanoBEIR evaluator

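If you are training an English retrieval model, you can use the [NanoBEIR evaluator](https://huggingface.co/collections/zeta-alpha-ai/nanobeir-66e1a0af21dfd93e620cd9f6), which runs small versions of the BEIR datasets to get quick validation results. Pass it to the trainer through its `evaluator` argument:

```python
from pylate import evaluation
dev_evaluator = evaluation.NanoBEIREvaluator()  # then use evaluator=dev_evaluator in SentenceTransformerTrainer
```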

&nbsp;

## Datasets

PyLate supports Hugging Face [Datasets](https://huggingface.co/docs/datasets/en/index), enabling seamless triplet-based and knowledge distillation training. For contrastive training, you can use any of the existing Sentence Transformers triplet datasets. Below is an example of creating a custom triplet dataset for training:

```python
from datasets import Dataset

dataset = [
    {
        "query": "example query 1",
        "positive": "example positive document 1",
        "negative": "example negative document 1",
    },
    {
        "query": "example query 2",
        "positive": "example positive document 2",
        "negative": "example negative document 2",
    },
    {
        "query": "example query 3",
        "positive": "example positive document 3",
        "negative": "example negative document 3",
    },
]

dataset = Dataset.from_list(mapping=dataset)

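# train_test_split returns a DatasetDict with "train" and "test" splits
splits = dataset.train_test_split(test_size=0.3)
train_dataset, test_dataset = splits["train"], splits["test"]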
```

PyLate supports more than one negative per query: simply add the additional negatives as extra columns after the first one in the row.

```python
{
    "query": "example query 1",
    "positive": "example positive document 1",
    "negative_1": "example negative document 1",
    "negative_2": "example negative document 2",
}
```
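If your negatives are stored as a Python list, a small helper can spread them into the column layout shown above, keeping the query first and the positive second. The resulting rows can then be passed to `Dataset.from_list`. The `expand_negatives` helper is hypothetical and not part of PyLate.

```python
def expand_negatives(query, positive, negatives):
    """Hypothetical helper: build one row with columns negative_1, negative_2, ..."""
    row = {"query": query, "positive": positive}
    for i, negative in enumerate(negatives, start=1):
        row[f"negative_{i}"] = negative
    return row


row = expand_negatives(
    "example query 1",
    "example positive document 1",
    ["example negative document 1", "example negative document 2"],
)
print(list(row))  # ['query', 'positive', 'negative_1', 'negative_2']
```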

To create a knowledge distillation dataset, you can use the following snippet:

```python
from datasets import Dataset

dataset = [
    {
        "query_id": 54528,
        "document_ids": [
            6862419,
            335116,
            339186,
        ],
        "scores": [
            0.4546215673141326,
            0.6575686537173476,
            0.26825184192900203,
        ],
    },
    {
        "query_id": 749480,
        "document_ids": [
            6862419,
            335116,
            339186,
        ],
        "scores": [
            0.2546215673141326,
            0.7575686537173476,
            0.96825184192900203,
        ],
    },
]


dataset = Dataset.from_list(mapping=dataset)

documents = [
    {"document_id": 6862419, "text": "example doc 1"},
    {"document_id": 335116, "text": "example doc 2"},
    {"document_id": 339186, "text": "example doc 3"},
]

queries = [
    {"query_id": 54528, "text": "example query 1"},
    {"query_id": 749480, "text": "example query 2"},
]

documents = Dataset.from_list(mapping=documents)

queries = Dataset.from_list(mapping=queries)
```
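The training rows only store ids, and the `utils.KDProcessing` transform from the training example resolves them to texts on the fly. The plain-Python sketch below shows that lookup for intuition only. `resolve_texts` and its output keys are hypothetical and do not reflect `KDProcessing`'s actual output format. It also shows why every `query_id` and document id used in the training rows needs a matching entry in `queries` and `documents`, since a missing id would raise a `KeyError` here.

```python
def resolve_texts(example, query_texts, document_texts):
    """Hypothetical helper: replace the ids of a distillation row with their texts."""
    return {
        "query": query_texts[example["query_id"]],
        "documents": [document_texts[doc_id] for doc_id in example["document_ids"]],
        "scores": example["scores"],
    }


document_texts = {6862419: "example doc 1", 335116: "example doc 2", 339186: "example doc 3"}
query_texts = {749480: "example query"}

row = {"query_id": 749480, "document_ids": [6862419, 335116, 339186], "scores": [0.25, 0.76, 0.97]}
print(resolve_texts(row, query_texts, document_texts))
```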

&nbsp;

## Retrieval

PyLate provides an efficient index with [FastPLAID](https://github.com/lightonai/fast-plaid). Simply load a ColBERT model and initialize the index to perform retrieval.

```python
from pylate import indexes, models, retrieve

model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
)

index = indexes.PLAID(
    index_folder="pylate-index",
    index_name="index",
    override=True,
)

retriever = retrieve.ColBERT(index=index)
```

Once the model and index are set up, we can add documents to the index using their embeddings and corresponding ids:

```python
documents_ids = ["1", "2", "3"]

documents = [
    "ColBERT’s late-interaction keeps token-level embeddings to deliver cross-encoder-quality ranking at near-bi-encoder speed, enabling fine-grained relevance, robustness across domains, and hardware-friendly scalable search.",

    "PLAID compresses ColBERT token vectors via product quantization to shrink storage by 10×, uses two-stage centroid scoring for sub-200 ms latency, and plugs directly into existing ColBERT pipelines.",

    "PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.",
]

# Encode the documents
documents_embeddings = model.encode(
    documents,
    batch_size=32,
    is_query=False, # Encoding documents
    show_progress_bar=True,
)

# Add the documents ids and embeddings to the PLAID index
index.add_documents(
    documents_ids=documents_ids,
    documents_embeddings=documents_embeddings,
)
```

Then we can retrieve the top-k documents for a given set of queries:

```python
queries_embeddings = model.encode(
    ["query for document 3", "query for document 1"],
    batch_size=32,
    is_query=True, # Encoding queries
    show_progress_bar=True,
)

scores = retriever.retrieve(
    queries_embeddings=queries_embeddings,
    k=10,
)

print(scores)
```

Sample Output:

```python
[
    [
        {"id": "3", "score": 11.266985893249512},
        {"id": "1", "score": 10.303335189819336},
        {"id": "2", "score": 9.502392768859863},
    ],
    [
        {"id": "1", "score": 10.88800048828125},
        {"id": "3", "score": 9.950843811035156},
        {"id": "2", "score": 9.602447509765625},
    ],
]
```
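As the sample suggests, `retrieve` returns one list of `{"id", "score"}` dicts per query, in the same order as the queries. The short sketch below post-processes that structure. It reuses the sample output above as a literal, whereas in practice you would use `scores` directly. It picks the best match with `max` so it does not rely on the lists being pre-sorted.

```python
results = [
    [
        {"id": "3", "score": 11.266985893249512},
        {"id": "1", "score": 10.303335189819336},
        {"id": "2", "score": 9.502392768859863},
    ],
    [
        {"id": "1", "score": 10.88800048828125},
        {"id": "3", "score": 9.950843811035156},
        {"id": "2", "score": 9.602447509765625},
    ],
]

# Highest-scoring document id for each query
top_ids = [max(matches, key=lambda match: match["score"])["id"] for matches in results]
print(top_ids)  # ['3', '1']
```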

Note that a [WARP](https://github.com/pau-mensa/xtr-warp-rs) backend is also available via `indexes.WARP` (install with `pip install "pylate[warp]"`). WARP is designed for [XTR-trained models](https://lightonai.github.io/pylate/documentation/training/#xtr-training): it relies on approximations that make retrieval faster and cheaper but can degrade performance on models that were not trained for them. See the [retrieval documentation](https://lightonai.github.io/pylate/documentation/retrieval/#xtr-retrieval) for model compatibility notes and a concrete example.

&nbsp;

## Reranking

If you want to use a ColBERT model to rerank the output of your first-stage retrieval pipeline without building an index, you can use the `rank.rerank` function, which takes the query and document embeddings along with the document ids and reranks the documents for each query:

```python
from pylate import models, rank

model = models.ColBERT(model_name_or_path="lightonai/GTE-ModernColBERT-v1")

queries = [
    "query A",
    "query B",
]

documents = [
    ["document A", "document B"],
    ["document A", "document C", "document B"],
]

documents_ids = [
    [1, 2],
    [1, 3, 2],
]

queries_embeddings = model.encode(
    queries,
    is_query=True,
)

documents_embeddings = model.encode(
    documents,
    is_query=False,
)

reranked_documents = rank.rerank(
    documents_ids=documents_ids,
    queries_embeddings=queries_embeddings,
    documents_embeddings=documents_embeddings,
)
```
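Conceptually, reranking sorts each query's candidates by ColBERT's late-interaction (MaxSim) score. For each query token embedding, MaxSim takes the highest dot product with any document token embedding, then sums these maxima over the query tokens. The plain-Python sketch below uses toy 2-d embeddings to illustrate the scoring rule. It is not `rank.rerank`'s implementation, and the `maxsim` and `rerank_candidates` helpers are hypothetical.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def maxsim(query_embeddings, document_embeddings):
    """Sum, over query tokens, of the best dot product with any document token."""
    return sum(
        max(dot(q_token, d_token) for d_token in document_embeddings)
        for q_token in query_embeddings
    )


def rerank_candidates(query_embeddings, candidates):
    """Hypothetical helper: sort (document_id, token_embeddings) pairs by descending MaxSim."""
    scored = [(doc_id, maxsim(query_embeddings, doc_emb)) for doc_id, doc_emb in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


query = [[1.0, 0.0], [0.0, 1.0]]  # two query tokens (toy embeddings)
candidates = [
    (1, [[0.9, 0.1], [0.2, 0.3]]),  # MaxSim = 0.9 + 0.3
    (2, [[0.1, 0.8], [0.7, 0.2]]),  # MaxSim = 0.7 + 0.8
]
print(rerank_candidates(query, candidates))
```

Document 2 wins even though document 1 has the single best match for the first query token, because each query token gets to pick its own best-matching document token.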

&nbsp;

## Contributing

We welcome contributions! To get started:

1. Install the development dependencies:

```bash
pip install "pylate[dev]"
```

2. Run tests:

```bash
make test
```

3. Format code with Ruff:

```bash
make lint
```

## Citation

You can refer to the library with this BibTeX:

```bibtex
@inproceedings{DBLP:conf/cikm/ChaffinS25,
  author       = {Antoine Chaffin and
                  Rapha{\"{e}}l Sourty},
  editor       = {Meeyoung Cha and
                  Chanyoung Park and
                  Noseong Park and
                  Carl Yang and
                  Senjuti Basu Roy and
                  Jessie Li and
                  Jaap Kamps and
                  Kijung Shin and
                  Bryan Hooi and
                  Lifang He},
  title        = {PyLate: Flexible Training and Retrieval for Late Interaction Models},
  booktitle    = {Proceedings of the 34th {ACM} International Conference on Information
                  and Knowledge Management, {CIKM} 2025, Seoul, Republic of Korea, November
                  10-14, 2025},
  pages        = {6334--6339},
  publisher    = {{ACM}},
  year         = {2025},
  url          = {https://github.com/lightonai/pylate},
  doi          = {10.1145/3746252.3761608},
}
```

## DeepWiki

PyLate is indexed on [DeepWiki](https://deepwiki.com/lightonai/pylate), so you can ask an LLM questions about the codebase (including with Deep Research) and get help adding new features.
