Skip to content

API Reference

This page provides the detailed API documentation for EmbedRAG, automatically generated from the docstrings in the source code.

CLI

The command-line interface for managing and running EmbedRAG nodes.

CLI entry point for embedrag writer/query node.

This module provides the primary command-line interface (CLI) for managing EmbedRAG nodes. It uses a sub-command pattern to provide different functionalities such as starting servers, downloading remote snapshots, clustering input data, and migrating snapshots to the latest schema.

The embedrag command is the single entry point for both operators (running production nodes) and developers (migrating data or testing snapshots).

main()

The main entry point for the embedrag command-line tool.

This function parses command-line arguments using argparse and dispatches to the appropriate sub-command handler.

Available Sub-commands

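writer: Starts the Writer Node server for ingestion and indexing.
query: Starts the Query Node server for serving search traffic.
migrate: Upgrades a snapshot directory or embedrag.db file to the latest schema and manifest v3.
cluster: Clusters a set of inputs from a standalone file (.jsonl / .json / .csv) or a writer SQLite DB.
pull: Downloads and extracts snapshots from a remote URL.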

Source code in src/embedrag/cli.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def _build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``embedrag`` parser with every sub-command registered.

    Returns:
        argparse.ArgumentParser: Parser whose parsed ``command`` attribute names
        the chosen sub-command (``None`` when no sub-command was given).
    """
    parser = argparse.ArgumentParser(prog="embedrag", description="EmbedRAG server")
    sub = parser.add_subparsers(dest="command")

    # Writer. --host/--port default to None (not to the fallback values) so an
    # explicitly passed flag always wins, even when it equals the fallback.
    wp = sub.add_parser("writer", help="Run the writer node")
    wp.add_argument("--config", "-c", default=None, help="Path to writer config YAML")
    wp.add_argument(
        "--host",
        default=None,
        help="Bind address (default: server.host from config, else 0.0.0.0)",
    )
    wp.add_argument(
        "--port",
        type=int,
        default=None,
        help="Listen port (default: server.port from config, else 8001)",
    )

    # Query (same None-sentinel convention as the writer).
    qp = sub.add_parser("query", help="Run the query node")
    qp.add_argument("--config", "-c", default=None, help="Path to query config YAML")
    qp.add_argument(
        "--host",
        default=None,
        help="Bind address (default: server.host from config, else 0.0.0.0)",
    )
    qp.add_argument(
        "--port",
        type=int,
        default=None,
        help="Listen port (default: server.port from config, else 8000)",
    )

    # Migrate
    mp = sub.add_parser("migrate", help="Upgrade a snapshot to latest schema + manifest v3")
    mp.add_argument("path", help="Path to snapshot directory or embedrag.db file")
    mp.add_argument("--dry-run", action="store_true", help="Show current version without modifying")

    # Cluster
    cp = sub.add_parser("cluster", help="Cluster a set of inputs (standalone file or a writer DB)")
    src = cp.add_argument_group("input source (choose one)")
    src.add_argument("--input", help="Input file: .jsonl / .json / .csv of {id, text, [embedding]}")
    src.add_argument("--embeddings", help="Optional .npy of precomputed vectors aligned with --input rows")
    src.add_argument("--db", help="Writer SQLite DB path (reads exact vectors from chunk_embeddings)")
    cp.add_argument("--text-field", default="text", help="Field name for text (default: text)")
    cp.add_argument("--id-field", default="id", help="Field name for id (default: id)")
    cp.add_argument("--embedding-field", default="embedding", help="Field name for inline embedding")
    cp.add_argument("--space", default="text", help="Embedding space (DB source, default: text)")
    cp.add_argument("--filter", action="append", default=[], help="DB filter key=value (e.g. doc_type=complaint)")
    cp.add_argument("--algorithm", default="auto", help="auto|hdbscan|kmeans|agglomerative|dbscan|leiden")
    cp.add_argument("--reduce", default="auto", help="auto|none|pca|umap")
    cp.add_argument("--no-auto", action="store_true", help="Disable automatic parameter sweep")
    cp.add_argument("--min-cluster-size", type=int, help="HDBSCAN min_cluster_size")
    cp.add_argument("--k", type=int, help="KMeans/Agglomerative cluster count")
    cp.add_argument("--eps", type=float, help="DBSCAN eps")
    cp.add_argument("--embed-url", help="OpenAI-compatible embeddings URL (vectorize text)")
    cp.add_argument("--embed-model", default="", help="Embedding model name")
    cp.add_argument("--embed-key", default="", help="Embedding API key")
    cp.add_argument("--llm-url", help="OpenAI-compatible chat URL for cluster labeling")
    cp.add_argument("--llm-model", default="", help="LLM model name")
    cp.add_argument("--llm-key", default="", help="LLM API key")
    cp.add_argument("--llm-language", default="auto", help="Label language (default: auto)")
    cp.add_argument("-o", "--output", help="Write result JSON to this path")
    cp.add_argument("--viz", help="Write a self-contained interactive HTML report to this path")

    # Pull
    pp = sub.add_parser("pull", help="Download a snapshot from a URL (GitHub Release, CDN, etc.)")
    pp.add_argument("url", help="Snapshot URL: archive (.tar.zst) or base URL with latest.json")
    pp.add_argument(
        "--output",
        "-o",
        default="./snapshot/active",
        help="Output directory (default: ./snapshot/active)",
    )
    pp.add_argument(
        "--timeout",
        type=int,
        default=600,
        help="Download timeout in seconds (default: 600)",
    )
    return parser


def main() -> None:
    """The main entry point for the `embedrag` command-line tool.

    This function parses command-line arguments using `argparse` and dispatches
    to the appropriate sub-command handler.

    Available Sub-commands:
        writer: Starts the Writer Node server for ingestion and indexing.
        query: Starts the Query Node server for serving search traffic.
        migrate: Upgrades local data/manifests to the current version.
        cluster: Clusters inputs from a standalone file or a writer DB.
        pull: Downloads and extracts snapshots from a remote URL.

    For ``writer`` and ``query`` the bind address is resolved with the
    precedence: explicit ``--host``/``--port`` flag, then ``server.host`` /
    ``server.port`` from the loaded config, then ``0.0.0.0`` and port 8001
    (writer) / 8000 (query).

    Raises:
        SystemExit: With status 1 when no sub-command is given (after printing
            help), or as raised by argparse for ``--help`` / invalid arguments.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        sys.exit(1)

    if args.command == "cluster":
        _run_cluster(args)
        return

    if args.command == "pull":
        _run_pull(args.url, output=args.output, timeout=args.timeout)
        return

    if args.command == "migrate":
        _run_migrate(args.path, dry_run=args.dry_run)
        return

    # Deferred import: uvicorn is only needed by the server sub-commands below.
    import uvicorn

    if args.command == "writer":
        from embedrag.config import load_writer_config
        from embedrag.writer.app import create_writer_app

        server_cfg = load_writer_config(args.config).server
        app = create_writer_app(config_path=args.config)
        fallback_port = 8001
    elif args.command == "query":
        from embedrag.config import load_query_config
        from embedrag.query.app import create_query_app

        server_cfg = load_query_config(args.config).server
        app = create_query_app(config_path=args.config)
        fallback_port = 8000
    else:
        # Not reachable via argparse (unknown sub-commands are rejected), kept
        # so an unexpected command is a no-op as before.
        return

    # Precedence: explicit CLI flag > config file value > built-in fallback.
    host = args.host if args.host is not None else (server_cfg.host or "0.0.0.0")
    port = args.port if args.port is not None else (server_cfg.port or fallback_port)
    uvicorn.run(app, host=host, port=port)

Configuration

EmbedRAG's configuration system, including writer and query node settings.

Configuration for writer and query nodes, loaded from YAML with env var support.

EmbedRAG uses Pydantic models for configuration, providing automatic validation, type safety, and the ability to override settings via environment variables. The configuration is hierarchically structured into logical groups such as server settings, object store credentials, and search parameters.

Any field ending in _env in the YAML configuration is treated as a pointer to an environment variable, allowing secrets (like AWS keys) to be managed outside of the version-controlled configuration files.

DBConfig

Bases: BaseModel

Configuration for the writer's SQLite metadata database.

Attributes:

Name Type Description
path str

File path to the database. If empty, it's auto-resolved relative to NodeConfig.data_dir.

wal_autocheckpoint int

SQLite WAL checkpoint interval in pages.

cache_size_mb int

SQLite page cache size in megabytes.

Source code in src/embedrag/config.py
250
251
252
253
254
255
256
257
258
259
260
261
262
class DBConfig(BaseModel):
    """SQLite settings for the writer node's metadata database.

    Attributes:
        path (str): Location of the database file. Left empty, it is
            auto-resolved relative to ``NodeConfig.data_dir``.
        wal_autocheckpoint (int): WAL auto-checkpoint threshold, in pages.
        cache_size_mb (int): Size of SQLite's page cache, in megabytes.
    """

    path: str = ""  # empty string => resolve from NodeConfig.data_dir
    wal_autocheckpoint: int = 1_000  # pages between automatic checkpoints
    cache_size_mb: int = 64

EmbeddingConfig

Bases: BaseModel

Root configuration for all embedding services.

EmbedRAG supports multiple "spaces" (e.g., one for text, one for images). If the spaces dictionary is empty, the top-level fields are used to define a single default "text" space.

Attributes:

Name Type Description
service_url str

Default service URL.

api_format Literal['embedrag', 'openai']

Default API format.

api_key str

Default API key.

model str

Default model name.

batch_size int

Default batch size.

timeout_seconds int

Default timeout.

retry_count int

Default retry count.

spaces dict[str, EmbeddingSpaceConfig]

Dictionary mapping space names to their specific configurations.

Source code in src/embedrag/config.py
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
class EmbeddingConfig(BaseModel):
    """Top-level settings shared by every embedding service.

    Multiple named "spaces" (for instance one for text and one for images)
    can be declared in ``spaces``. When that mapping is empty, the top-level
    fields below describe a single implicit "text" space.

    Attributes:
        service_url (str): Fallback service URL.
        api_format (Literal["embedrag", "openai"]): Fallback API format.
        api_key (str): Fallback API key.
        model (str): Fallback model name.
        batch_size (int): Fallback batch size.
        timeout_seconds (int): Fallback request timeout.
        retry_count (int): Fallback retry count.
        spaces (dict[str, EmbeddingSpaceConfig]): Per-space overrides keyed
            by space name.
    """

    service_url: str = "http://localhost:8080/embed"
    api_format: Literal["embedrag", "openai"] = "embedrag"
    api_key: str = ""
    model: str = ""
    batch_size: int = 64
    timeout_seconds: int = 30
    retry_count: int = 3
    spaces: dict[str, EmbeddingSpaceConfig] = Field(default_factory=dict)

    def get_space_config(self, space: str = "text") -> EmbeddingSpaceConfig:
        """Resolve the configuration for one embedding space.

        An explicit entry in ``spaces`` wins; otherwise a fresh config is
        assembled from this object's top-level fallback fields.

        Args:
            space (str, optional): Name of the space to look up.
                Defaults to "text".

        Returns:
            EmbeddingSpaceConfig: The resolved configuration for ``space``.
        """
        if space in self.spaces:
            return self.spaces[space]
        # No explicit entry: synthesize one from the top-level defaults.
        fallback = {
            "service_url": self.service_url,
            "api_format": self.api_format,
            "api_key": self.api_key,
            "model": self.model,
            "batch_size": self.batch_size,
            "timeout_seconds": self.timeout_seconds,
            "retry_count": self.retry_count,
        }
        return EmbeddingSpaceConfig(**fallback)

    def get_all_spaces(self) -> list[str]:
        """List the names of every configured space.

        Returns:
            list[str]: The keys of ``spaces``, or ``["text"]`` when no
            explicit spaces are configured.
        """
        return list(self.spaces) if self.spaces else ["text"]

get_all_spaces()

Get a list of all configured space names.

Returns:

Type Description
list[str]

list[str]: A list of space identifiers.

Source code in src/embedrag/config.py
341
342
343
344
345
346
347
348
349
def get_all_spaces(self) -> list[str]:
    """List the names of every configured space.

    Returns:
        list[str]: The keys of ``spaces``, or ``["text"]`` when no explicit
        spaces are configured.
    """
    return list(self.spaces) if self.spaces else ["text"]

get_space_config(space='text')

Get the configuration for a specific embedding space.

If the space is not found in the spaces dictionary, returns a configuration built from the default top-level fields.

Parameters:

Name Type Description Default
space str

The name of the space to retrieve. Defaults to "text".

'text'

Returns:

Name Type Description
EmbeddingSpaceConfig EmbeddingSpaceConfig

The configuration for the requested space.

Source code in src/embedrag/config.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def get_space_config(self, space: str = "text") -> EmbeddingSpaceConfig:
    """Resolve the configuration for one embedding space.

    An explicit entry in ``spaces`` wins; otherwise a fresh config is
    assembled from the top-level fallback fields.

    Args:
        space (str, optional): Name of the space to look up.
            Defaults to "text".

    Returns:
        EmbeddingSpaceConfig: The resolved configuration for ``space``.
    """
    if space in self.spaces:
        return self.spaces[space]
    # No explicit entry: synthesize one from the top-level defaults.
    fallback = {
        "service_url": self.service_url,
        "api_format": self.api_format,
        "api_key": self.api_key,
        "model": self.model,
        "batch_size": self.batch_size,
        "timeout_seconds": self.timeout_seconds,
        "retry_count": self.retry_count,
    }
    return EmbeddingSpaceConfig(**fallback)

EmbeddingSpaceConfig

Bases: BaseModel

Configuration for a specific embedding space/model.

Attributes:

Name Type Description
service_url str

The endpoint of the external embedding service.

api_format Literal['embedrag', 'openai']

The API protocol used by the service.

api_key str

Optional API key for the service.

model str

The model identifier to send in the request.

batch_size int

Max number of texts to send in a single batch.

timeout_seconds int

Request timeout in seconds.

retry_count int

Number of retry attempts on network failure.

Source code in src/embedrag/config.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
class EmbeddingSpaceConfig(BaseModel):
    """Settings for a single embedding space / model endpoint.

    Attributes:
        service_url (str): URL of the external embedding service.
        api_format (Literal["embedrag", "openai"]): Wire protocol spoken by
            that service.
        api_key (str): API key for the service; empty when none is needed.
        model (str): Model identifier included in each request.
        batch_size (int): Upper bound on texts sent per request.
        timeout_seconds (int): Per-request timeout, in seconds.
        retry_count (int): How many times a failed request is retried.
    """

    # Endpoint and protocol.
    service_url: str = "http://localhost:8080/embed"
    api_format: Literal["embedrag", "openai"] = "embedrag"
    api_key: str = ""
    model: str = ""
    # Request shaping and resilience.
    batch_size: int = 64
    timeout_seconds: int = 30
    retry_count: int = 3

HotfixConfig

Bases: BaseModel

Configuration for real-time incremental updates (hotfixes).

Hotfixes allow inserting or deleting documents in the query node's memory between snapshot updates.

Attributes:

Name Type Description
enabled bool

Whether to enable the hotfix buffer.

max_vectors int

The maximum number of vectors to keep in the in-memory hotfix index. Defaults to 10,000.

Source code in src/embedrag/config.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class HotfixConfig(BaseModel):
    """Settings for the query node's real-time hotfix buffer.

    Hotfixes let documents be inserted or deleted in the query node's
    memory in between snapshot updates.

    Attributes:
        enabled (bool): Turns the hotfix buffer on or off.
        max_vectors (int): Capacity of the in-memory hotfix index, in
            vectors. Defaults to 10,000.
    """

    enabled: bool = True
    max_vectors: int = 10000  # ten thousand vectors

IndexBuildConfig

Bases: BaseModel

Configuration for building FAISS indexes on the writer node.

Attributes:

Name Type Description
num_shards int

Number of shards to split the index into.

ivf_nlist int

Number of IVF centroids to train.

pq_m int

Number of PQ sub-vectors.

train_sample_size int

Maximum number of vectors to use for training.

compression Literal['zstd', 'none']

Compression algorithm for shards.

compression_level int

Zstd compression level (1-22).

Source code in src/embedrag/config.py
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
class IndexBuildConfig(BaseModel):
    """FAISS index-building parameters used by the writer node.

    Attributes:
        num_shards (int): How many shards the index is split into.
        ivf_nlist (int): Count of IVF centroids to train.
        pq_m (int): Count of product-quantization sub-vectors.
        train_sample_size (int): Cap on vectors sampled for training.
        compression (Literal["zstd", "none"]): Codec applied to shards.
        compression_level (int): Zstd compression level (1-22).
    """

    # Index layout.
    num_shards: int = 4
    ivf_nlist: int = 4_096
    pq_m: int = 64
    train_sample_size: int = 500000
    # Shard compression.
    compression: Literal["zstd", "none"] = "zstd"
    compression_level: int = 3

IndexConfig

Bases: BaseModel

Configuration for FAISS index loading on the query node.

Attributes:

Name Type Description
num_shards int

The expected number of shards in the index.

nprobe int

Number of IVF cells to visit during search. Higher values increase recall but also increase latency. Defaults to 32.

mmap bool

Whether to use memory-mapped loading for FAISS indexes. Required for handling indexes larger than available RAM. Defaults to True.

Source code in src/embedrag/config.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class IndexConfig(BaseModel):
    """FAISS index-loading parameters used by the query node.

    Attributes:
        num_shards (int): Number of shards the index is expected to have.
        nprobe (int): IVF cells probed per search; raising it trades latency
            for recall. Defaults to 32.
        mmap (bool): Load FAISS indexes via memory mapping, which is needed
            for indexes bigger than RAM. Defaults to True.
    """

    num_shards: int = 4
    nprobe: int = 32  # recall / latency knob
    mmap: bool = True

LoggingConfig

Bases: BaseModel

Structured logging configuration.

Attributes:

Name Type Description
level str

Log level (DEBUG, INFO, WARNING, ERROR).

format Literal['json', 'console']

The log output format. Use "json" for production/ELK stacks and "console" for local development.

access_log bool

Whether to enable HTTP access logs.

Source code in src/embedrag/config.py
198
199
200
201
202
203
204
205
206
207
208
209
210
class LoggingConfig(BaseModel):
    """Structured-logging settings.

    Attributes:
        level (str): Minimum log level (DEBUG, INFO, WARNING, ERROR).
        format (Literal["json", "console"]): Output format; "json" suits
            production / ELK pipelines, "console" suits local development.
        access_log (bool): Emit HTTP access log lines when True.
    """

    level: str = "INFO"
    format: Literal["json", "console"] = "json"  # field name is part of the config schema
    access_log: bool = True

MetricsConfig

Bases: BaseModel

Prometheus metrics configuration.

Attributes:

Name Type Description
enabled bool

Whether to start the Prometheus metrics exporter.

port int

The port to expose the /metrics endpoint on. Defaults to 9090.

Source code in src/embedrag/config.py
213
214
215
216
217
218
219
220
221
222
class MetricsConfig(BaseModel):
    """Settings for the Prometheus metrics exporter.

    Attributes:
        enabled (bool): Start the exporter when True.
        port (int): Port serving the ``/metrics`` endpoint. Defaults to 9090.
    """

    enabled: bool = True
    port: int = 9090  # /metrics endpoint port

NodeConfig

Bases: BaseModel

Basic node identification and role configuration.

Attributes:

Name Type Description
role Literal['query', 'writer']

The operating mode of the node.

node_id str

A unique identifier for this node instance. Use "auto" to use the machine's hostname.

data_dir str

The local directory used for storing databases, shards, and temporary build files.

port int

Port override. If 0, uses ServerConfig.port.

Source code in src/embedrag/config.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
class NodeConfig(BaseModel):
    """Identity and role settings for a node.

    Attributes:
        role (Literal["query", "writer"]): Which mode the node runs in.
        node_id (str): Unique name of this node instance; the value "auto"
            is replaced with the machine's hostname after validation.
        data_dir (str): Local directory holding databases, shards and
            temporary build files.
        port (int): Port override; 0 means "use ServerConfig.port".
    """

    role: Literal["query", "writer"] = "query"
    node_id: str = "auto"
    data_dir: str = "/data/embedrag"
    port: int = 0  # 0 => defer to ServerConfig.port

    @model_validator(mode="after")
    def _auto_node_id(self) -> NodeConfig:
        """Swap the ``"auto"`` placeholder node_id for the host name."""
        if self.node_id != "auto":
            return self
        self.node_id = socket.gethostname()
        return self

ObjectStoreConfig

Bases: BaseModel

Configuration for S3-compatible object storage providers.

This class manages connection details for services like AWS S3, Google Cloud Storage (in S3-compatibility mode), MinIO, and ByteDance TOS. It is used by the writer to upload snapshots and by the query node to download them.

Attributes:

Name Type Description
provider Literal['tos', 's3', 'minio']

The storage provider type. Defaults to "s3".

endpoint str

Custom endpoint URL (e.g., "http://localhost:9000" for MinIO).

bucket str

The name of the bucket where snapshots are stored.

prefix str

A key prefix (folder) within the bucket for all EmbedRAG data.

access_key_env str

Name of the environment variable holding the access key.

secret_key_env str

Name of the environment variable holding the secret key.

region str

The AWS region identifier (e.g., "us-east-1").

Source code in src/embedrag/config.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class ObjectStoreConfig(BaseModel):
    """Connection settings for S3-compatible object storage.

    Covers services such as AWS S3, Google Cloud Storage in S3-compatibility
    mode, MinIO and ByteDance TOS. The writer uploads snapshots here and the
    query node downloads them from here.

    Credentials are not stored directly: ``access_key_env`` and
    ``secret_key_env`` name environment variables, which the ``access_key``
    and ``secret_key`` properties resolve on access.

    Attributes:
        provider (Literal["tos", "s3", "minio"]): Storage provider kind.
            Defaults to "s3".
        endpoint (str): Custom endpoint URL, e.g. "http://localhost:9000"
            for MinIO.
        bucket (str): Bucket that holds the snapshots.
        prefix (str): Key prefix (folder) for all EmbedRAG objects.
        access_key_env (str): Environment variable holding the access key.
        secret_key_env (str): Environment variable holding the secret key.
        region (str): Region identifier, e.g. "us-east-1".
    """

    provider: Literal["tos", "s3", "minio"] = "s3"
    endpoint: str = ""
    bucket: str = "embedrag-data"
    prefix: str = "snapshots/"
    access_key_env: str = "AWS_ACCESS_KEY_ID"
    secret_key_env: str = "AWS_SECRET_ACCESS_KEY"
    region: str = "us-east-1"

    @property
    def access_key(self) -> str:
        """str: Access key looked up through ``access_key_env``."""
        env_name = self.access_key_env
        return _resolve_env(env_name)

    @property
    def secret_key(self) -> str:
        """str: Secret key looked up through ``secret_key_env``."""
        env_name = self.secret_key_env
        return _resolve_env(env_name)

access_key property

str: The resolved access key from the environment.

secret_key property

str: The resolved secret key from the environment.

QueryNodeConfig

Bases: BaseModel

The root configuration for an EmbedRAG Query Node.

Attributes:

Name Type Description
node NodeConfig

Basic node settings.

object_store ObjectStoreConfig

Snapshot source settings.

snapshot SnapshotConfig

Local snapshot management.

sync SyncConfig

Background update settings.

index IndexConfig

FAISS loading parameters.

search SearchConfig

Retrieval and fusion settings.

hotfix HotfixConfig

Real-time update buffer.

embedding EmbeddingConfig

Embedding service settings.

server ServerConfig

FastAPI server settings.

logging LoggingConfig

Logging settings.

metrics MetricsConfig

Monitoring settings.

Source code in src/embedrag/config.py
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
class QueryNodeConfig(BaseModel):
    """Root configuration object for an EmbedRAG Query Node.

    Every section is built from its own model's defaults unless overridden.

    Attributes:
        node (NodeConfig): Node identity and role (role pinned to "query").
        object_store (ObjectStoreConfig): Where snapshots are fetched from.
        snapshot (SnapshotConfig): Local snapshot handling.
        sync (SyncConfig): Background snapshot updates.
        index (IndexConfig): FAISS loading options.
        search (SearchConfig): Retrieval and fusion options.
        hotfix (HotfixConfig): Real-time update buffer.
        embedding (EmbeddingConfig): Embedding service options.
        server (ServerConfig): FastAPI/Uvicorn server options.
        logging (LoggingConfig): Logging options.
        metrics (MetricsConfig): Monitoring options.
    """

    # Identity and data sources.
    node: NodeConfig = Field(default_factory=lambda: NodeConfig(role="query"))
    object_store: ObjectStoreConfig = Field(default_factory=ObjectStoreConfig)
    snapshot: SnapshotConfig = Field(default_factory=SnapshotConfig)
    sync: SyncConfig = Field(default_factory=SyncConfig)
    # Retrieval pipeline.
    index: IndexConfig = Field(default_factory=IndexConfig)
    search: SearchConfig = Field(default_factory=SearchConfig)
    hotfix: HotfixConfig = Field(default_factory=HotfixConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
    # Serving and observability.
    server: ServerConfig = Field(default_factory=ServerConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    metrics: MetricsConfig = Field(default_factory=MetricsConfig)

SearchConfig

Bases: BaseModel

High-level retrieval and ranking configuration.

Attributes:

Name Type Description
default_top_k int

Default number of results to return if not specified.

max_top_k int

Absolute maximum results allowed per request.

enable_sparse bool

Whether to include the keyword (BM25) search path.

enable_hierarchy_expand bool

Whether to automatically fetch parent context for retrieved chunks.

context_depth int

How many levels of parent context to traverse.

dense_weight float

Multiplier for dense scores in RRF fusion.

sparse_weight float

Multiplier for sparse scores in RRF fusion.

Source code in src/embedrag/config.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class SearchConfig(BaseModel):
    """Retrieval and ranking settings.

    Attributes:
        default_top_k (int): Result count used when a request omits one.
        max_top_k (int): Hard ceiling on results per request.
        enable_sparse (bool): Include the keyword (BM25) retrieval path.
        enable_hierarchy_expand (bool): Fetch parent context for matched
            chunks.
        context_depth (int): Number of parent levels to walk.
        dense_weight (float): Weight applied to dense scores in RRF fusion.
        sparse_weight (float): Weight applied to sparse scores in RRF fusion.
    """

    # Result limits.
    default_top_k: int = 10
    max_top_k: int = 100
    # Retrieval paths.
    enable_sparse: bool = True
    enable_hierarchy_expand: bool = True
    context_depth: int = 1
    # RRF fusion weights.
    dense_weight: float = 1.0
    sparse_weight: float = 1.0

ServerConfig

Bases: BaseModel

Web server (FastAPI/Uvicorn) configuration.

Attributes:

Name Type Description
host str

The interface to bind the server to.

port int

The port to listen on.

workers int

Number of worker processes. Note: For FAISS mmap stability, 1 worker is recommended.

readiness_delay_seconds float

Time to wait before reporting the node as ready during startup.

shutdown_drain_seconds int

Maximum time to wait for active requests to complete during shutdown.

Source code in src/embedrag/config.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class ServerConfig(BaseModel):
    """Web server (FastAPI/Uvicorn) configuration.

    Attributes:
        host (str): The interface to bind the server to.
        port (int): The port to listen on.
        workers (int): Number of worker processes. Note: For FAISS mmap
            stability, 1 worker is recommended.
        readiness_delay_seconds (float): Time to wait before reporting the
            node as ready during startup. Defaults to 0.0.
        shutdown_drain_seconds (int): Maximum time to wait for active
            requests to complete during shutdown.
    """

    host: str = "0.0.0.0"
    port: int = 8000
    workers: int = 1
    # Float literal on purpose: Pydantic does not validate defaults, so an int
    # literal here would leave an int in a float-annotated field when unset.
    readiness_delay_seconds: float = 0.0
    shutdown_drain_seconds: int = 30

SnapshotConfig

Bases: BaseModel

Configuration for local snapshot management on the query node.

Attributes:

Name Type Description
bootstrap_version str

The version ID to load on startup. Use "latest" to automatically pull the most recent version from the source.

poll_interval_seconds int

How often to check for new snapshots when using basic polling. Defaults to 300.

download_concurrency int

Max number of concurrent downloads for shards and data files. Defaults to 4.

download_timeout_seconds int

Timeout for individual file downloads. Defaults to 600.

disk_reserve_bytes int

Minimum free disk space (in bytes) to maintain on the data partition. Defaults to 5 GiB (5,368,709,120 bytes).

Source code in src/embedrag/config.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class SnapshotConfig(BaseModel):
    """Configuration for local snapshot management on the query node.

    Attributes:
        bootstrap_version (str): The version ID to load on startup. Use "latest"
            to automatically pull the most recent version from the source.
        poll_interval_seconds (int): How often to check for new snapshots when
            using basic polling. Defaults to 300.
        download_concurrency (int): Max number of concurrent downloads for
            shards and data files. Defaults to 4.
        download_timeout_seconds (int): Timeout for individual file downloads.
            Defaults to 600.
        disk_reserve_bytes (int): Minimum free disk space (in bytes) to maintain
            on the data partition. Defaults to 5 GiB (5,368,709,120 bytes).
    """

    bootstrap_version: str = "latest"
    poll_interval_seconds: int = 300
    download_concurrency: int = 4
    download_timeout_seconds: int = 600
    # Binary gigabytes: 5 * 1024**3 == 5_368_709_120 (5 GiB, not 5 GB).
    disk_reserve_bytes: int = 5 * 1024**3

SyncConfig

Bases: BaseModel

Background snapshot synchronization configuration.

When enabled, the query node will periodically poll the source (either object storage or an HTTP server) for newer snapshot versions and automatically perform zero-downtime hot-swaps.

Attributes:

Name Type Description
enabled bool

Whether to activate background synchronization.

source Literal['object_store', 'http']

The metadata source type.

http_url str

The base URL for fetching latest.json if source is "http".

cron str

An optional 5-field cron expression for scheduling checks.

poll_interval_seconds int

Interval between checks if cron is not set.

download_concurrency int

Max concurrent downloads for new snapshots.

download_timeout_seconds int

Timeout for sync downloads.

Source code in src/embedrag/config.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class SyncConfig(BaseModel):
    """Background snapshot synchronization configuration.

    When enabled, the query node will periodically poll the source (either
    object storage or an HTTP server) for newer snapshot versions and
    automatically perform zero-downtime hot-swaps.

    Attributes:
        enabled (bool): Whether to activate background synchronization.
            Defaults to False.
        source (Literal["object_store", "http"]): The metadata source type.
            Defaults to "object_store".
        http_url (str): The base URL for fetching `latest.json` if source is
            "http". Defaults to "" (unset).
        cron (str): An optional 5-field cron expression for scheduling checks.
            An empty string (the default) means "not set".
        poll_interval_seconds (int): Interval, in seconds, between checks if
            cron is not set. Defaults to 300.
        download_concurrency (int): Max concurrent downloads for new snapshots.
            Defaults to 4.
        download_timeout_seconds (int): Timeout, in seconds, for sync
            downloads. Defaults to 600.

    Note:
        This model does not cross-validate fields: ``http_url`` may be left
        empty even when ``source == "http"``, and ``cron`` is not parsed here.
    """

    enabled: bool = False
    source: Literal["object_store", "http"] = "object_store"
    http_url: str = ""
    cron: str = ""
    poll_interval_seconds: int = 300
    download_concurrency: int = 4
    download_timeout_seconds: int = 600

WriterNodeConfig

Bases: BaseModel

The root configuration for an EmbedRAG Writer Node.

Attributes:

Name Type Description
node NodeConfig

Basic node settings.

object_store ObjectStoreConfig

Snapshot upload settings.

db DBConfig

Metadata database settings.

embedding EmbeddingConfig

Embedding service settings.

index_build IndexBuildConfig

FAISS build parameters.

server ServerConfig

FastAPI server settings (defaults to port 8001).

logging LoggingConfig

Logging settings.

metrics MetricsConfig

Monitoring settings.

Source code in src/embedrag/config.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
class WriterNodeConfig(BaseModel):
    """The root configuration for an EmbedRAG Writer Node.

    Attributes:
        node (NodeConfig): Basic node settings. The default instance is
            created with ``role="writer"``.
        object_store (ObjectStoreConfig): Snapshot upload settings.
        db (DBConfig): Metadata database settings. If ``db.path`` is empty
            after validation, it is set to ``<node.data_dir>/db/writer.db``.
        embedding (EmbeddingConfig): Embedding service settings.
        index_build (IndexBuildConfig): FAISS build parameters.
        server (ServerConfig): FastAPI server settings (defaults to port 8001).
        logging (LoggingConfig): Logging settings.
        metrics (MetricsConfig): Monitoring settings.
    """

    # Writer-specific defaults: unlike a bare NodeConfig()/ServerConfig(), the
    # defaults here pin role="writer" and port 8001.
    node: NodeConfig = Field(default_factory=lambda: NodeConfig(role="writer"))
    object_store: ObjectStoreConfig = Field(default_factory=ObjectStoreConfig)
    db: DBConfig = Field(default_factory=DBConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
    index_build: IndexBuildConfig = Field(default_factory=IndexBuildConfig)
    server: ServerConfig = Field(default_factory=lambda: ServerConfig(port=8001))
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    metrics: MetricsConfig = Field(default_factory=MetricsConfig)

    @model_validator(mode="after")
    def _resolve_db_path(self) -> WriterNodeConfig:
        """Fill in ``db.path`` when it was not configured.

        Runs after field validation. If ``self.db.path`` is empty (falsy), it
        is set to ``<node.data_dir>/db/writer.db`` (as a string). An explicitly
        configured path is left untouched.

        Returns:
            WriterNodeConfig: ``self``, as required for "after" model validators.
        """
        if not self.db.path:
            self.db.path = str(Path(self.node.data_dir) / "db" / "writer.db")
        return self

load_config(path)

Load a configuration file and return the appropriate node config.

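The function reads the node.role field to decide whether to return a QueryNodeConfig or a WriterNodeConfig: a role of "writer" yields a WriterNodeConfig, while any other role (or a missing role) yields a QueryNodeConfig.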

Parameters:

Name Type Description Default
path str | Path

Path to the YAML configuration file.

required

Returns:

Type Description
QueryNodeConfig | WriterNodeConfig

The loaded and validated config: a WriterNodeConfig when node.role is "writer", otherwise a QueryNodeConfig.

Source code in src/embedrag/config.py
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
def load_config(path: str | Path) -> QueryNodeConfig | WriterNodeConfig:
    """Load a configuration file and return the appropriate node config.

    The function reads the `node.role` field to decide whether to return
    a `WriterNodeConfig` (role ``"writer"``) or a `QueryNodeConfig` (any
    other role, or no role at all).

    An empty YAML file, or a ``node:`` key with no value, is treated as
    "use the defaults" instead of failing with an ``AttributeError``.

    Args:
        path (str | Path): Path to the YAML configuration file (read as UTF-8).

    Returns:
        QueryNodeConfig | WriterNodeConfig: The loaded and validated config.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the top level of the YAML document is not a mapping.
    """
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    # safe_load returns None for an empty document; treat that as "all defaults".
    if raw is None:
        raw = {}
    if not isinstance(raw, dict):
        raise TypeError(
            f"Config file {path} must contain a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )

    # A bare `node:` parses to None, so fall back to an empty section. A
    # non-mapping `node` value is left for the model validation to reject.
    node_section = raw.get("node")
    if node_section is None:
        node_section = {}
    role = node_section.get("role", "query") if isinstance(node_section, dict) else "query"

    if role == "writer":
        return WriterNodeConfig(**raw)
    return QueryNodeConfig(**raw)

load_query_config(path=None)

Load configuration for a query node.

Parameters:

Name Type Description Default
path str | Path | None

Path to the YAML file. If None, returns the default configuration.

None

Returns:

Name Type Description
QueryNodeConfig QueryNodeConfig

The validated query node configuration.

Source code in src/embedrag/config.py
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
def load_query_config(path: str | Path | None = None) -> QueryNodeConfig:
    """Load configuration for a query node.

    An empty YAML file is treated the same as ``path=None`` (all defaults)
    rather than failing with ``TypeError: argument after ** must be a mapping``.

    Args:
        path (str | Path | None, optional): Path to the YAML file (read as
            UTF-8). If None, returns the default configuration.

    Returns:
        QueryNodeConfig: The validated query node configuration.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the top level of the YAML document is not a mapping.
    """
    if path is None:
        return QueryNodeConfig()
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    # safe_load returns None for an empty document; treat that as "all defaults".
    if raw is None:
        raw = {}
    if not isinstance(raw, dict):
        raise TypeError(
            f"Config file {path} must contain a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    return QueryNodeConfig(**raw)

load_writer_config(path=None)

Load configuration for a writer node.

Parameters:

Name Type Description Default
path str | Path | None

Path to the YAML file. If None, returns the default configuration.

None

Returns:

Name Type Description
WriterNodeConfig WriterNodeConfig

The validated writer node configuration.

Source code in src/embedrag/config.py
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
def load_writer_config(path: str | Path | None = None) -> WriterNodeConfig:
    """Load configuration for a writer node.

    An empty YAML file is treated the same as ``path=None`` (all defaults)
    rather than failing with ``TypeError: argument after ** must be a mapping``.

    Args:
        path (str | Path | None, optional): Path to the YAML file (read as
            UTF-8). If None, returns the default configuration.

    Returns:
        WriterNodeConfig: The validated writer node configuration.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the top level of the YAML document is not a mapping.
    """
    if path is None:
        return WriterNodeConfig()
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    # safe_load returns None for an empty document; treat that as "all defaults".
    if raw is None:
        raw = {}
    if not isinstance(raw, dict):
        raise TypeError(
            f"Config file {path} must contain a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    return WriterNodeConfig(**raw)

Models

Data structures used for communication and internal storage.

API Models

Request and response models for the FastAPI endpoints.

Request/Response Pydantic models for both writer and query APIs.

This module defines the data transfer objects (DTOs) that form the API contract for EmbedRAG. These models ensure type safety, provide automatic validation, and generate high-quality OpenAPI/Swagger documentation. They are used for everything from document ingestion and index building to complex hybrid searches and administrative synchronization tasks.

ArchiveRequest

Bases: BaseModel

Parameters for creating a portable snapshot archive.

Attributes:

Name Type Description
format str

The archive format (e.g., "tar.zst").

compression_level int

Zstd compression level (1-22). Defaults to 3.

Source code in src/embedrag/models/api.py
129
130
131
132
133
134
135
136
137
138
139
class ArchiveRequest(BaseModel):
    """Parameters for creating a portable snapshot archive.

    Attributes:
        format (str, optional): The archive format. Defaults to "tar.zst".
        compression_level (int, optional): Zstd compression level, nominally
            1-22. Defaults to 3.

    Note:
        Neither field is constrained by this model; an unsupported format or
        out-of-range level is only caught (if at all) by the archiving code.
    """

    format: str = "tar.zst"
    compression_level: int = 3

ArchiveResponse

Bases: BaseModel

Location and metadata for a created snapshot archive.

Attributes:

Name Type Description
version str

The snapshot version that was archived.

format str

The archive format used.

path str

The relative path to the archive file on the local disk.

size_bytes int

The size of the resulting archive file.

build_time_seconds float

Time taken to compress and archive.

Source code in src/embedrag/models/api.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class ArchiveResponse(BaseModel):
    """Location and metadata for a created snapshot archive.

    All fields are required.

    Attributes:
        version (str): The snapshot version that was archived.
        format (str): The archive format used.
        path (str): The relative path to the archive file on the local disk.
        size_bytes (int): The size of the resulting archive file, in bytes.
        build_time_seconds (float): Time taken to compress and archive, in
            seconds.
    """

    version: str
    format: str
    path: str
    size_bytes: int
    build_time_seconds: float

BuildRequest

Bases: BaseModel

Parameters for triggering a new FAISS index build.

Attributes:

Name Type Description
force_full_rebuild bool

If True, ignores incremental state and rebuilds the entire index from all documents in the database. Defaults to False.

Source code in src/embedrag/models/api.py
83
84
85
86
87
88
89
90
91
92
class BuildRequest(BaseModel):
    """Parameters for triggering a new FAISS index build.

    Attributes:
        force_full_rebuild (bool, optional): If True, ignores incremental
            state and rebuilds the entire index from all documents in the
            database. Defaults to False (an empty request body is valid).
    """

    force_full_rebuild: bool = False

BuildResponse

Bases: BaseModel

Detailed statistics for a successfully completed index build.

Attributes:

Name Type Description
version str

The new unique version ID for the built generation.

doc_count int

The total number of documents included in the index.

chunk_count int

The total number of chunks across all documents.

vector_count int

The total number of vectors embedded in FAISS.

num_shards int

The number of physical index shards created.

build_time_seconds float

The wall-clock time taken for the build.

Source code in src/embedrag/models/api.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
class BuildResponse(BaseModel):
    """Detailed statistics for a successfully completed index build.

    All fields are required.

    Attributes:
        version (str): The new unique version ID for the built generation.
        doc_count (int): The total number of documents included in the index.
        chunk_count (int): The total number of chunks across all documents.
        vector_count (int): The total number of vectors embedded in FAISS.
        num_shards (int): The number of physical index shards created.
        build_time_seconds (float): The wall-clock time taken for the build,
            in seconds.
    """

    version: str
    doc_count: int
    chunk_count: int
    vector_count: int
    num_shards: int
    build_time_seconds: float

BulkDeleteRequest

Bases: BaseModel

Parameters for deleting documents in bulk.

Attributes:

Name Type Description
doc_ids list[str]

A specific list of document IDs to delete.

doc_type str

If provided, deletes all documents of this type.

Source code in src/embedrag/models/api.py
246
247
248
249
250
251
252
253
254
255
class BulkDeleteRequest(BaseModel):
    """Parameters for deleting documents in bulk.

    Attributes:
        doc_ids (list[str], optional): A specific list of document IDs to
            delete. Defaults to an empty list.
        doc_type (str, optional): If provided (non-empty), deletes all
            documents of this type. Defaults to "".

    Note:
        This model does not require either field to be set, and does not say
        how the two combine when both are given.
        NOTE(review): confirm the endpoint's behaviour for an empty request and
        for doc_ids + doc_type together.
    """

    doc_ids: list[str] = Field(default_factory=list)
    doc_type: str = ""

BulkDeleteResponse

Bases: BaseModel

Summary of a bulk deletion operation.

Attributes:

Name Type Description
deleted_docs int

The number of documents successfully removed.

deleted_chunks int

The total number of chunks removed.

Source code in src/embedrag/models/api.py
258
259
260
261
262
263
264
265
266
267
class BulkDeleteResponse(BaseModel):
    """Summary of a bulk deletion operation.

    Attributes:
        deleted_docs (int): The number of documents successfully removed.
        deleted_chunks (int): The total number of chunks removed across those
            documents.
    """

    deleted_docs: int
    deleted_chunks: int

ChunkResult

Bases: BaseModel

A single retrieved search result (a chunk of a document).

Attributes:

Name Type Description
chunk_id str

Unique identifier for the chunk.

doc_id str

Identifier of the parent document.

text str

The text content of this chunk.

score float

The final relevance score (e.g., an RRF or cosine-similarity score).

level int

Hierarchical level (0 for base chunks).

level_type str

Name of the level (e.g., "chunk", "section").

metadata dict

Metadata associated with the document/chunk.

parent_text str | None

Content of the parent node if context expansion was performed.

Source code in src/embedrag/models/api.py
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
class ChunkResult(BaseModel):
    """A single retrieved search result (a chunk of a document).

    Attributes:
        chunk_id (str): Unique identifier for the chunk.
        doc_id (str): Identifier of the parent document.
        text (str): The text content of this chunk.
        score (float): The final relevance score (e.g., an RRF or cosine
            similarity score).
        level (int, optional): Hierarchical level (0 for base chunks).
            Defaults to 0.
        level_type (str, optional): Name of the level (e.g., "chunk",
            "section"). Defaults to "chunk".
        metadata (dict, optional): Metadata associated with the
            document/chunk. Defaults to an empty dict.
        parent_text (str | None, optional): Content of the parent node if
            context expansion was performed; None otherwise (the default).
    """

    chunk_id: str
    doc_id: str
    text: str
    score: float
    level: int = 0
    level_type: str = "chunk"
    metadata: dict = Field(default_factory=dict)
    parent_text: str | None = None

ClusterRequest

Bases: BaseModel

Run clustering over the loaded corpus (or a filtered subset).

Source code in src/embedrag/models/api.py
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
class ClusterRequest(BaseModel):
    """Run clustering over the loaded corpus (or a filtered subset).

    Attributes:
        algorithm (str, optional): Clustering algorithm: one of "auto",
            "hdbscan", "kmeans", "agglomerative", "dbscan", "leiden".
            Defaults to "auto".
        reduce (str, optional): Dimensionality-reduction method: one of
            "auto", "none", "pca", "umap". Defaults to "auto".
        auto (bool, optional): Defaults to True. NOTE(review): presumably
            enables automatic parameter selection; confirm against the
            clustering service.
        space (str, optional): The embedding space to cluster. Defaults to
            "text".
        filters (dict | None, optional): Optional subset filter (keys such
            as doc_type / doc_id). Defaults to None (whole corpus).
        max_items (int, optional): Defaults to 20000; presumably an upper
            bound on the number of items clustered. TODO confirm.
        params (dict, optional): Algorithm-specific parameters, e.g.
            min_cluster_size / k / eps. Defaults to an empty dict.
        top_keywords (int, optional): Defaults to 10; presumably keywords
            reported per cluster.
        top_representatives (int, optional): Defaults to 5; presumably
            representative items reported per cluster.
        label_with_llm (bool, optional): Whether to generate cluster labels
            with an LLM. Defaults to False.
        llm_url (str, optional): LLM endpoint URL used for labelling.
            Defaults to "".
        llm_model (str, optional): LLM model name used for labelling.
            Defaults to "".
        llm_key (str, optional): Presumably the API key for the LLM endpoint.
            Defaults to "". NOTE(review): this is a credential carried in the
            request body; avoid logging raw requests.
        llm_language (str, optional): Label language. Defaults to "auto".
        persist (bool, optional): Defaults to True; presumably whether the
            clustering result is stored. TODO confirm.

    Note:
        The string options above are not validated by this model.
    """

    algorithm: str = "auto"  # auto|hdbscan|kmeans|agglomerative|dbscan|leiden
    reduce: str = "auto"  # auto|none|pca|umap
    auto: bool = True
    space: str = "text"
    filters: dict | None = None  # doc_type / doc_id
    max_items: int = 20000
    params: dict = Field(default_factory=dict)  # min_cluster_size / k / eps / ...
    top_keywords: int = 10
    top_representatives: int = 5
    label_with_llm: bool = False
    llm_url: str = ""
    llm_model: str = ""
    llm_key: str = ""
    llm_language: str = "auto"
    persist: bool = True

DebugDenseHit

Bases: BaseModel

An intermediate hit from the dense retrieval path.

Attributes:

Name Type Description
chunk_id str

Chunk identifier.

score float

Original dense similarity score.

Source code in src/embedrag/models/api.py
696
697
698
699
700
701
702
703
704
705
class DebugDenseHit(BaseModel):
    """An intermediate hit from the dense (vector) retrieval path.

    Attributes:
        chunk_id (str): Chunk identifier.
        score (float): Original dense similarity score, before fusion.
    """

    chunk_id: str
    score: float

DebugFusedHit

Bases: BaseModel

Detailed ranking state for a hit after RRF fusion.

Attributes:

Name Type Description
chunk_id str

Chunk identifier.

rrf_score float

The final calculated RRF score.

dense_score float

Score contribution from dense path.

sparse_score float

Score contribution from sparse path.

dense_rank int

Rank of this chunk in the dense list (-1 if not found).

sparse_rank int

Rank of this chunk in the sparse list (-1 if not found).

Source code in src/embedrag/models/api.py
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
class DebugFusedHit(BaseModel):
    """Detailed ranking state for a hit after RRF fusion.

    Attributes:
        chunk_id (str): Chunk identifier.
        rrf_score (float): The final calculated RRF score.
        dense_score (float): Score contribution from dense path.
        sparse_score (float): Score contribution from sparse path.
        dense_rank (int, optional): Rank of this chunk in the dense list.
            Defaults to -1, meaning "not found in the dense list".
        sparse_rank (int, optional): Rank of this chunk in the sparse list.
            Defaults to -1, meaning "not found in the sparse list".
    """

    chunk_id: str
    rrf_score: float
    dense_score: float
    sparse_score: float
    dense_rank: int = -1
    sparse_rank: int = -1

DebugSearchRequest

Bases: BaseModel

A detailed search request that returns internal ranking state.

Attributes:

Name Type Description
query_text str

The search query in plain text.

top_k int

Number of results.

filters dict

Metadata filters.

expand_context bool

Whether to fetch adjacent chunks.

context_depth int

Context window size.

mode str

Search algorithm.

space str

Target embedding space.

Source code in src/embedrag/models/api.py
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
class DebugSearchRequest(BaseModel):
    """A detailed search request that returns internal ranking state.

    Attributes:
        query_text (str): The search query in plain text (required).
        top_k (int, optional): Number of results. Defaults to 10.
        filters (dict | None, optional): Metadata filters. Defaults to None.
        expand_context (bool, optional): Whether to fetch adjacent chunks.
            Defaults to True.
        context_depth (int, optional): Context window size. Defaults to 1.
        mode (str, optional): Search algorithm. Defaults to "hybrid".
        space (str, optional): Target embedding space. Defaults to "text".
    """

    query_text: str
    top_k: int = 10
    filters: dict | None = None
    expand_context: bool = True
    context_depth: int = 1
    mode: str = "hybrid"
    space: str = "text"

DebugSearchResponse

Bases: BaseModel

Detailed diagnostic information for a search query.

Attributes:

Name Type Description
query_text str

The original search query.

mode str

The search mode used.

fts_query str

The actual FTS5 query string generated.

embedding_time_ms float

Embedding latency.

score_type str

The fusion algorithm used.

dense_results list[DebugDenseHit]

The raw top hits from dense path.

sparse_results list[DebugSparseHit]

The raw top hits from sparse path.

fused_results list[DebugFusedHit]

Rank details after fusion.

final_chunks list[ChunkResult]

The final hydrated results returned.

timing DebugTiming

Detailed latency breakdown.

config_snapshot dict

A snapshot of search-relevant config settings.

Source code in src/embedrag/models/api.py
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
class DebugSearchResponse(BaseModel):
    """Detailed diagnostic information for a search query.

    Only ``query_text`` and ``mode`` are required; every other field has an
    empty/zero default so partially-populated diagnostics remain valid.

    Attributes:
        query_text (str): The original search query.
        mode (str): The search mode used.
        fts_query (str): The actual FTS5 query string generated. Defaults to "".
        embedding_time_ms (float): Embedding latency, in milliseconds.
            Defaults to 0.
        score_type (str): The fusion algorithm used. Defaults to "".
        dense_results (list[DebugDenseHit]): The raw top hits from dense path.
        sparse_results (list[DebugSparseHit]): The raw top hits from sparse path.
        fused_results (list[DebugFusedHit]): Rank details after fusion.
        final_chunks (list[ChunkResult]): The final hydrated results returned.
        timing (DebugTiming): Detailed latency breakdown; defaults to an
            all-zero DebugTiming.
        config_snapshot (dict): A snapshot of search-relevant config settings.
    """

    query_text: str
    mode: str
    fts_query: str = ""
    embedding_time_ms: float = 0
    score_type: str = ""
    dense_results: list[DebugDenseHit] = Field(default_factory=list)
    sparse_results: list[DebugSparseHit] = Field(default_factory=list)
    fused_results: list[DebugFusedHit] = Field(default_factory=list)
    final_chunks: list[ChunkResult] = Field(default_factory=list)
    timing: DebugTiming = Field(default_factory=DebugTiming)
    config_snapshot: dict = Field(default_factory=dict)
    """Snapshot of relevant configuration settings at query time."""

config_snapshot = Field(default_factory=dict) class-attribute instance-attribute

Snapshot of relevant configuration settings at query time.

DebugSparseHit

Bases: BaseModel

An intermediate hit from the sparse retrieval path.

Attributes:

Name Type Description
chunk_id str

Chunk identifier.

score float

Original sparse relevance score.

Source code in src/embedrag/models/api.py
708
709
710
711
712
713
714
715
716
717
class DebugSparseHit(BaseModel):
    """An intermediate hit from the sparse (full-text) retrieval path.

    Attributes:
        chunk_id (str): Chunk identifier.
        score (float): Original sparse relevance score, before fusion.
    """

    chunk_id: str
    score: float

DebugTiming

Bases: BaseModel

Fine-grained breakdown of search latency components.

Attributes:

Name Type Description
embedding_ms float

Latency of external embedding call.

dense_ms float

Latency of FAISS search.

sparse_ms float

Latency of SQLite FTS5 search.

fusion_ms float

Latency of the fusion algorithm.

fetch_ms float

Latency of fetching text from the database.

expand_ms float

Latency of hierarchical context expansion.

total_ms float

Total end-to-end request time.

Source code in src/embedrag/models/api.py
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
class DebugTiming(BaseModel):
    """Fine-grained breakdown of search latency components.

    All values are in milliseconds and default to 0, so a stage that did not
    run simply reports 0.

    Attributes:
        embedding_ms (float): Latency of external embedding call.
        dense_ms (float): Latency of FAISS search.
        sparse_ms (float): Latency of SQLite FTS5 search.
        fusion_ms (float): Latency of the fusion algorithm.
        fetch_ms (float): Latency of fetching text from the database.
        expand_ms (float): Latency of hierarchical context expansion.
        total_ms (float): Total end-to-end request time.
    """

    embedding_ms: float = 0
    dense_ms: float = 0
    sparse_ms: float = 0
    fusion_ms: float = 0
    fetch_ms: float = 0
    expand_ms: float = 0
    total_ms: float = 0

DeleteDocumentResponse

Bases: BaseModel

Response confirming document deletion.

Attributes:

Name Type Description
doc_id str

The identifier of the deleted document.

chunks_deleted int

The number of related chunks removed from storage.

Source code in src/embedrag/models/api.py
160
161
162
163
164
165
166
167
168
169
class DeleteDocumentResponse(BaseModel):
    """Response confirming deletion of a single document.

    Attributes:
        doc_id (str): The identifier of the deleted document.
        chunks_deleted (int): The number of related chunks removed from storage.
    """

    doc_id: str
    chunks_deleted: int

DocumentDetailResponse

Bases: BaseModel

Full detail view of a single document and its structure.

Attributes:

Name Type Description
doc_id str

Unique document identifier.

title str

Document title.

source str

Origin identifier.

doc_type str

Category/type.

metadata dict

Arbitrary key-value metadata.

chunk_ids list[str]

List of all chunk IDs belonging to this document.

Source code in src/embedrag/models/api.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class DocumentDetailResponse(BaseModel):
    """Full detail view of a single document and its structure.

    All fields are required.

    Attributes:
        doc_id (str): Unique document identifier.
        title (str): Document title.
        source (str): Origin identifier.
        doc_type (str): Category/type.
        metadata (dict): Arbitrary key-value metadata.
        chunk_ids (list[str]): List of all chunk IDs belonging to this document.
    """

    doc_id: str
    title: str
    source: str
    doc_type: str
    metadata: dict
    chunk_ids: list[str]

DocumentInput

Bases: BaseModel

Input structure for a single document during the ingestion process.

This model represents the raw data that will be chunked, embedded, and stored in the writer's database.

Attributes:

Name Type Description
doc_id str

A unique global identifier for the document.

title str

The title of the document. Defaults to "".

text str

The raw text content of the document to be indexed.

source str

An identifier for the document's origin (e.g., a URL or file path). Defaults to "".

doc_type str

A category for the document (e.g., "wiki", "manual"). Useful for filtering. Defaults to "".

chunking str

The chunking strategy to use (e.g., "auto", "character", "none"). Defaults to "auto".

chunk_size int

Override for the default number of characters per chunk.

chunk_overlap int

Override for the default overlap between adjacent chunks.

metadata dict

Arbitrary key-value pairs to store with the document for filtering and display.

modality str

The data modality (e.g., "text", "image"). Defaults to "text".

content_ref str

A reference to external raw content if the index only stores metadata/vectors. Defaults to "".

Source code in src/embedrag/models/api.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class DocumentInput(BaseModel):
    """Input structure for a single document during the ingestion process.

    This model represents the raw data that will be chunked, embedded, and
    stored in the writer's database. Only ``doc_id`` and ``text`` are required.

    Attributes:
        doc_id (str): A unique global identifier for the document.
        title (str, optional): The title of the document. Defaults to "".
        text (str): The raw text content of the document to be indexed.
        source (str, optional): An identifier for the document's origin
            (e.g., a URL or file path). Defaults to "".
        doc_type (str, optional): A category for the document (e.g., "wiki",
            "manual"). Useful for filtering. Defaults to "".
        chunking (str, optional): The chunking strategy to use (e.g., "auto",
            "character", "none"). Defaults to "auto". Not validated here.
        chunk_size (int | None, optional): Override for the default number of
            characters per chunk. Leave as None (the default) to keep the
            default.
        chunk_overlap (int | None, optional): Override for the default overlap
            between adjacent chunks. Leave as None (the default) to keep the
            default.
        metadata (dict, optional): Arbitrary key-value pairs to store with
            the document for filtering and display. Defaults to an empty dict.
        modality (str, optional): The data modality (e.g., "text", "image").
            Defaults to "text".
        content_ref (str, optional): A reference to external raw content if
            the index only stores metadata/vectors. Defaults to "".
    """

    doc_id: str
    title: str = ""
    text: str
    source: str = ""
    doc_type: str = ""
    chunking: str = "auto"
    chunk_size: int | None = None
    chunk_overlap: int | None = None
    metadata: dict = Field(default_factory=dict)
    modality: str = "text"
    content_ref: str = ""

DocumentListResponse

Bases: BaseModel

Paginated list of documents stored on the writer node.

Attributes:

Name Type Description
documents list[DocumentSummary]

The list of document summaries.

total int

The total number of documents in the system.

limit int

The pagination limit used.

offset int

The pagination offset used.

Source code in src/embedrag/models/api.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class DocumentListResponse(BaseModel):
    """Paginated list of documents stored on the writer node.

    Attributes:
        documents (list[DocumentSummary]): The document summaries for this page.
        total (int): The total number of documents in the system (not just
            this page).
        limit (int): The pagination limit used.
        offset (int): The pagination offset used.
    """

    documents: list[DocumentSummary]
    total: int
    limit: int
    offset: int

DocumentSummary

Bases: BaseModel

A minimal representation of a document for list views.

Attributes:

Name Type Description
doc_id str

Unique document identifier.

title str

Document title.

source str

Origin identifier.

doc_type str

Category/type.

created_at str

ISO timestamp of ingestion.

Source code in src/embedrag/models/api.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
class DocumentSummary(BaseModel):
    """A minimal representation of a document for list views.

    All fields are required.

    Attributes:
        doc_id (str): Unique document identifier.
        title (str): Document title.
        source (str): Origin identifier.
        doc_type (str): Category/type.
        created_at (str): ISO timestamp of ingestion, kept as a plain string
            (not parsed into a datetime by this model).
    """

    doc_id: str
    title: str
    source: str
    doc_type: str
    created_at: str

HealthResponse

Bases: BaseModel

Basic health check status.

Attributes:

Name Type Description
status str

The node status ("ok", "starting", etc.).

node_type str

Either "query" or "writer".

version str

The code version of the running node.

Source code in src/embedrag/models/api.py
465
466
467
468
469
470
471
472
473
474
475
476
class HealthResponse(BaseModel):
    """Basic health check status.

    Attributes:
        status (str): The node status ("ok", "starting", etc.). Defaults to "ok".
        node_type (str): Either "query" or "writer". Defaults to "".
        version (str): The code version of the running node. Defaults to "".
    """

    status: str = "ok"
    node_type: str = ""
    version: str = ""

HotfixAddRequest

Bases: BaseModel

Request to add a new chunk to the query node's real-time buffer.

Attributes:

Name Type Description
chunk_id str

Unique identifier for the new chunk.

doc_id str

Identifier for the parent document.

text str

The chunk's text content.

embedding list[float]

The pre-calculated embedding vector.

metadata dict

Key-value metadata for the chunk.

space str

The target embedding space. Defaults to "text".

Source code in src/embedrag/models/api.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
class HotfixAddRequest(BaseModel):
    """Request to add a new chunk to the query node's real-time buffer.

    Attributes:
        chunk_id (str): Unique identifier for the new chunk.
        doc_id (str): Identifier for the parent document.
        text (str): The chunk's text content.
        embedding (list[float]): The pre-calculated embedding vector. Its
            length is not checked by this model; presumably it must match the
            target space's dimension (TODO confirm where that is enforced).
        metadata (dict, optional): Key-value metadata for the chunk.
            Defaults to an empty dict.
        space (str, optional): The target embedding space. Defaults to "text".
    """

    chunk_id: str
    doc_id: str
    text: str
    embedding: list[float]
    metadata: dict = Field(default_factory=dict)
    space: str = "text"

HotfixDeleteRequest

Bases: BaseModel

Request to logically delete chunks from the query node's search results.

Attributes:

Name Type Description
chunk_ids list[str]

List of identifiers to exclude from search results.

space str

The target embedding space. Defaults to "text".

Source code in src/embedrag/models/api.py
439
440
441
442
443
444
445
446
447
448
class HotfixDeleteRequest(BaseModel):
    """Request to logically delete chunks from the query node's search results.

    Attributes:
        chunk_ids (list[str]): List of identifiers to exclude from search
            results (required).
        space (str, optional): The target embedding space. Defaults to "text".
    """

    chunk_ids: list[str]
    space: str = "text"

HotfixResponse

Bases: BaseModel

Confirmation of a real-time hotfix operation.

Attributes:

Name Type Description
operation str

The operation type ("add" or "delete").

affected int

Number of chunks modified in the buffer.

buffer_size int

The new total size of the hotfix buffer for the space.

Source code in src/embedrag/models/api.py
451
452
453
454
455
456
457
458
459
460
461
462
class HotfixResponse(BaseModel):
    """Confirmation of a real-time hotfix operation.

    Attributes:
        operation (str): The operation type ("add" or "delete").
        affected (int): Number of chunks modified in the buffer.
        buffer_size (int): The new total size of the hotfix buffer for the
            space, after the operation.
    """

    operation: str
    affected: int
    buffer_size: int

IngestRequest

Bases: BaseModel

Request payload for the bulk ingestion endpoint.

Attributes:

Name Type Description
documents list[DocumentInput]

A list of one or more documents to be processed and stored.

Source code in src/embedrag/models/api.py
58
59
60
61
62
63
64
65
66
class IngestRequest(BaseModel):
    """Request payload for the bulk ingestion endpoint.

    Attributes:
        documents (list[DocumentInput]): A list of one or more documents to
            be processed and stored.
    """

    # NOTE(review): the docstring says "one or more", but no minimum-length
    # constraint is declared here, so an empty list passes validation —
    # confirm the endpoint handles ``documents == []`` sensibly.
    documents: list[DocumentInput]

IngestResponse

Bases: BaseModel

Success response from the ingestion endpoint.

Attributes:

Name Type Description
ingested int

The number of documents successfully processed.

chunk_count int

The total number of chunks generated and stored.

doc_ids list[str]

The list of document IDs that were processed.

Source code in src/embedrag/models/api.py
69
70
71
72
73
74
75
76
77
78
79
80
class IngestResponse(BaseModel):
    """Success response from the ingestion endpoint.

    All fields are required; none declares a default.

    Attributes:
        ingested (int): The number of documents successfully processed.
        chunk_count (int): The total number of chunks generated and stored.
        doc_ids (list[str]): The list of document IDs that were processed.
    """

    # Document-level count.
    ingested: int
    # Chunk-level count across all ingested documents.
    chunk_count: int
    doc_ids: list[str]

MultiSpaceSearchRequest

Bases: BaseModel

Advanced search across multiple model generations or modalities.

Attributes:

Name Type Description
queries list[SpaceQuery]

List of individual space queries to execute.

top_k int

Number of final results. Defaults to 10.

fusion str

The fusion algorithm used to combine per-space results. Defaults to "rrf".

filters dict

Optional metadata filters applied across all spaces. Defaults to None.

expand_context bool

Whether to fetch adjacent chunks. Defaults to True.

context_depth int

Adjacent context window size. Defaults to 1.

Source code in src/embedrag/models/api.py
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
class MultiSpaceSearchRequest(BaseModel):
    """Advanced search across multiple model generations or modalities.

    Each entry in ``queries`` targets one embedding space. The per-space
    results are combined using the ``fusion`` algorithm.

    Attributes:
        queries (list[SpaceQuery]): List of individual space queries to execute.
            Required.
        top_k (int, optional): Number of final results. Defaults to 10.
        fusion (str, optional): The fusion algorithm. Defaults to "rrf".
        filters (dict, optional): Metadata filters applied across all spaces.
            Defaults to None.
        expand_context (bool, optional): Whether to fetch adjacent chunks.
            Defaults to True.
        context_depth (int, optional): Adjacent context window size.
            Defaults to 1.
    """

    queries: list[SpaceQuery]
    top_k: int = 10
    # Plain ``str``: the accepted algorithm names are not enforced here.
    fusion: str = "rrf"
    filters: dict | None = None
    expand_context: bool = True
    context_depth: int = 1

MultiSpaceSearchResponse

Bases: BaseModel

The unified results from a multi-space search operation.

Attributes:

Name Type Description
chunks list[ChunkResult]

The final ranked and fused hits.

total int

The total number of hits found.

per_space dict[str, int]

Count of matching documents found per space.

total_time_ms float

End-to-end request latency.

Source code in src/embedrag/models/api.py
608
609
610
611
612
613
614
615
616
617
618
619
620
621
class MultiSpaceSearchResponse(BaseModel):
    """The unified results from a multi-space search operation.

    Attributes:
        chunks (list[ChunkResult]): The final ranked and fused hits. Required.
        total (int): The total number of hits found. Required.
        per_space (dict[str, int], optional): Count of matches found per
            space, keyed by space name. Defaults to an empty dict.
        total_time_ms (float, optional): End-to-end request latency in
            milliseconds. Defaults to 0.
    """

    chunks: list[ChunkResult]
    total: int
    # default_factory gives each instance its own dict.
    per_space: dict[str, int] = Field(default_factory=dict)
    total_time_ms: float = 0

PublishResponse

Bases: BaseModel

Metadata for a newly published snapshot generation.

Attributes:

Name Type Description
version str

The generation version that was published.

upload_time_seconds float

Time taken to transfer files to storage.

snapshot_size_bytes int

Total size of the published files.

Source code in src/embedrag/models/api.py
115
116
117
118
119
120
121
122
123
124
125
126
class PublishResponse(BaseModel):
    """Metadata for a newly published snapshot generation.

    All fields are required; none declares a default.

    Attributes:
        version (str): The generation version that was published.
        upload_time_seconds (float): Time taken to transfer files to storage,
            in seconds.
        snapshot_size_bytes (int): Total size of the published files, in bytes.
    """

    version: str
    upload_time_seconds: float
    snapshot_size_bytes: int

ReadinessResponse

Bases: BaseModel

Detailed probe to determine if the node can serve traffic.

Attributes:

Name Type Description
ready bool

True if the node is fully initialized (e.g., index loaded).

active_version str

The snapshot version currently being served.

vector_count int

Total vectors available for search.

doc_count int

Total documents available for search.

Source code in src/embedrag/models/api.py
479
480
481
482
483
484
485
486
487
488
489
490
491
492
class ReadinessResponse(BaseModel):
    """Detailed probe to determine if the node can serve traffic.

    Attributes:
        ready (bool): True if the node is fully initialized (e.g., index loaded).
            Required.
        active_version (str, optional): The snapshot version currently being
            served. Defaults to "".
        vector_count (int, optional): Total vectors available for search.
            Defaults to 0.
        doc_count (int, optional): Total documents available for search.
            Defaults to 0.
    """

    ready: bool
    # Empty-string / zero defaults let a not-yet-ready node answer with just
    # ``ready=False``; presumably "" means "no snapshot loaded" — confirm.
    active_version: str = ""
    vector_count: int = 0
    doc_count: int = 0

RerankRequest

Bases: BaseModel

Input for an external reranking service.

Attributes:

Name Type Description
query str

The search query text.

texts list[str]

The candidate texts to be reranked.

url str

Override for the reranker service URL.

model str

The model name for reranking.

Source code in src/embedrag/models/api.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
class RerankRequest(BaseModel):
    """Input for an external reranking service.

    Attributes:
        query (str): The search query text. Required.
        texts (list[str]): The candidate texts to be reranked. Required.
        url (str, optional): Override for the reranker service URL.
            Defaults to "".
        model (str, optional): The model name for reranking. Defaults to "".
    """

    query: str
    texts: list[str]
    # NOTE(review): empty string presumably means "use the configured
    # reranker URL/model" rather than an explicit override — confirm in the
    # handler that consumes this model.
    url: str = ""
    model: str = ""

RerankResponse

Bases: BaseModel

Result of a cross-encoder reranking operation.

Attributes:

Name Type Description
results list[RerankResult]

Results sorted by descending score.

elapsed_ms float

Time taken for the reranking operation.

Source code in src/embedrag/models/api.py
301
302
303
304
305
306
307
308
309
310
class RerankResponse(BaseModel):
    """Result of a cross-encoder reranking operation.

    Both fields are required.

    Attributes:
        results (list[RerankResult]): Results sorted by descending score.
        elapsed_ms (float): Time taken for the reranking operation, in
            milliseconds.
    """

    # Ordering (descending score) is a producer-side convention; the model
    # itself does not sort or validate it.
    results: list[RerankResult]
    elapsed_ms: float

RerankResult

Bases: BaseModel

A single item from a reranking operation.

Attributes:

Name Type Description
index int

The index of the text in the original input list.

score float

The new relevance score assigned by the reranker.

Source code in src/embedrag/models/api.py
289
290
291
292
293
294
295
296
297
298
class RerankResult(BaseModel):
    """A single item from a reranking operation.

    Attributes:
        index (int): The index of the text in the original input list
            (``RerankRequest.texts``). Required.
        score (float): The new relevance score assigned by the reranker.
            Required.
    """

    # Position into the request's ``texts`` list, so callers can map the
    # score back to the candidate text.
    index: int
    score: float

SearchRequest

Bases: BaseModel

Standard vector search request parameters.

Attributes:

Name Type Description
query_embedding list[float]

The query vector, pre-embedded in the appropriate space.

query_text str

The raw query text. Required for hybrid and sparse search modes.

top_k int

The number of results to return. Defaults to 10.

filters dict

Metadata filters to apply (e.g., {"doc_type": "manual"}).

expand_context bool

If True, retrieves adjacent chunks to provide broader context for each hit. Defaults to True.

context_depth int

Number of surrounding chunks to fetch per result. Defaults to 1.

space str

The name of the embedding space to search. Defaults to "text".

Source code in src/embedrag/models/api.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
class SearchRequest(BaseModel):
    """Standard vector search request parameters.

    Attributes:
        query_embedding (list[float]): The query vector, pre-embedded in
            the appropriate space. Required.
        query_text (str, optional): The raw query text. Required for hybrid
            and sparse search modes. Defaults to None.
        top_k (int, optional): The number of results to return. Defaults to 10.
        filters (dict, optional): Metadata filters to apply (e.g., `{"doc_type": "manual"}`).
            Defaults to None.
        expand_context (bool, optional): If True, retrieves adjacent chunks
            to provide broader context for each hit. Defaults to True.
        context_depth (int, optional): Number of surrounding chunks to
            fetch per result. Defaults to 1.
        space (str, optional): The name of the embedding space to search.
            Defaults to "text".
        cluster_run_id (str, optional): Identifier of a clustering run.
            Defaults to None. (Presumably used with ``cluster_id`` to
            restrict results to one cluster — TODO confirm.)
        cluster_id (int, optional): Cluster identifier within
            ``cluster_run_id``. Defaults to None. (Semantics to be confirmed
            against the search handler.)
    """

    query_embedding: list[float]
    query_text: str | None = None
    top_k: int = 10
    filters: dict | None = None
    expand_context: bool = True
    context_depth: int = 1
    space: str = "text"
    # NOTE(review): these two fields are not described in the original
    # docstring; presumably they scope the search to one cluster of a
    # clustering run — confirm against the handler.
    cluster_run_id: str | None = None
    cluster_id: int | None = None

SearchResponse

Bases: BaseModel

The standard response containing search hits and performance timing.

Attributes:

Name Type Description
chunks list[ChunkResult]

The ranked list of matching chunks.

total int

The total number of hits matching the query.

score_type str

The type of scoring used (e.g., "rrf").

embedding_time_ms float

Time taken to embed the query.

dense_time_ms float

Time taken for the dense FAISS search.

sparse_time_ms float

Time taken for the FTS5 sparse search.

fusion_time_ms float

Time taken to fuse ranked lists.

total_time_ms float

Total end-to-end request latency.

Source code in src/embedrag/models/api.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
class SearchResponse(BaseModel):
    """The standard response containing search hits and performance timing.

    All ``*_time_ms`` fields are in milliseconds and default to 0.

    Attributes:
        chunks (list[ChunkResult]): The ranked list of matching chunks. Required.
        total (int): The total number of hits matching the query. Required.
        score_type (str, optional): The type of scoring used (e.g., "rrf").
            Defaults to "rrf".
        embedding_time_ms (float, optional): Time taken to embed the query.
        dense_time_ms (float, optional): Time taken for the dense FAISS search.
        sparse_time_ms (float, optional): Time taken for the FTS5 sparse search.
        fusion_time_ms (float, optional): Time taken to fuse ranked lists.
        total_time_ms (float, optional): Total end-to-end request latency.
    """

    chunks: list[ChunkResult]
    total: int
    score_type: str = "rrf"
    # Per-stage latency breakdown; stages that did not run presumably stay 0.
    embedding_time_ms: float = 0
    dense_time_ms: float = 0
    sparse_time_ms: float = 0
    fusion_time_ms: float = 0
    total_time_ms: float = 0

SpaceQuery

Bases: BaseModel

A sub-query targeting a specific embedding space.

Attributes:

Name Type Description
space str

The identifier of the space (e.g., "v1", "v2", "images").

query_embedding list[float]

The pre-calculated vector for this space.

query_text str

The raw text for sparse path in this space.

weight float

Contribution weight during fusion. Defaults to 1.0.

Source code in src/embedrag/models/api.py
572
573
574
575
576
577
578
579
580
581
582
583
584
585
class SpaceQuery(BaseModel):
    """A sub-query targeting a specific embedding space.

    Attributes:
        space (str): The identifier of the space (e.g., "v1", "v2", "images").
            Required.
        query_embedding (list[float]): The pre-calculated vector for this space.
            Required.
        query_text (str, optional): The raw text for sparse path in this space.
            Defaults to None.
        weight (float, optional): Contribution weight during fusion. Defaults to 1.0.
    """

    space: str
    query_embedding: list[float]
    query_text: str | None = None
    # Relative weight of this space's ranked list when results are fused.
    weight: float = 1.0

StatsResponse

Bases: BaseModel

Comprehensive health and size statistics for the writer node.

Attributes:

Name Type Description
doc_count int

Total documents in the SQLite database.

chunk_count int

Total chunks in the SQLite database.

embedding_spaces list[str]

Names of configured embedding spaces.

vectors_per_space dict[str, int]

Count of vectors stored for each space.

current_version str

The version ID of the last successful build.

db_size_bytes int

Size of the SQLite database file on disk.

Source code in src/embedrag/models/api.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
class StatsResponse(BaseModel):
    """Comprehensive health and size statistics for the writer node.

    All fields are required; none declares a default.

    Attributes:
        doc_count (int): Total documents in the SQLite database.
        chunk_count (int): Total chunks in the SQLite database.
        embedding_spaces (list[str]): Names of configured embedding spaces.
        vectors_per_space (dict[str, int]): Count of vectors stored for each
            space, keyed by space name.
        current_version (str): The version ID of the last successful build.
        db_size_bytes (int): Size of the SQLite database file on disk, in bytes.
    """

    doc_count: int
    chunk_count: int
    embedding_spaces: list[str]
    vectors_per_space: dict[str, int]
    current_version: str
    db_size_bytes: int

SyncStatusResponse

Bases: BaseModel

Real-time monitoring information for the snapshot sync process.

Attributes:

Name Type Description
enabled bool

Whether background sync is active.

source str

The source type ("object_store" or "http").

cron str

The cron expression being used for scheduling.

poll_interval_seconds int

Polling interval in use.

last_check_at float

Unix timestamp of the last check for updates.

last_sync_at float

Unix timestamp of the last successful index swap.

last_result str

Outcome of the last sync check.

last_version str

The snapshot version found during the last check.

next_check_at float

Unix timestamp of the next scheduled sync check.

consecutive_errors int

Number of failed sync attempts in a row.

current_version str

The snapshot version currently being served.

Source code in src/embedrag/models/api.py
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
class SyncStatusResponse(BaseModel):
    """Real-time monitoring information for the snapshot sync process.

    Every field has a default, so an empty instance is valid (e.g. when sync
    is disabled).

    Attributes:
        enabled (bool): Whether background sync is active. Defaults to False.
        source (str): The source type ("object_store" or "http").
        cron (str): The cron expression being used for scheduling.
        poll_interval_seconds (int): Polling interval in use, in seconds.
        last_check_at (float): Unix timestamp of the last check for updates.
        last_sync_at (float): Unix timestamp of the last successful index swap.
        last_result (str): Outcome of the last sync check. Defaults to "none".
        last_version (str): The snapshot version found during the last check.
        next_check_at (float): Unix timestamp of the next scheduled sync check.
        consecutive_errors (int): Number of failed sync attempts in a row.
        current_version (str): The snapshot version currently being served.
    """

    enabled: bool = False
    source: str = ""
    cron: str = ""
    poll_interval_seconds: int = 0
    # Timestamps default to 0; presumably 0 means "never happened / not
    # scheduled" — confirm with the syncer that populates them.
    last_check_at: float = 0
    last_sync_at: float = 0
    last_result: str = "none"
    last_version: str = ""
    next_check_at: float = 0
    consecutive_errors: int = 0
    current_version: str = ""

SyncTriggerRequest

Bases: BaseModel

Manual request to trigger a snapshot pull or swap.

Attributes:

Name Type Description
source_url str

A specific URL to pull a snapshot from, bypassing the configured global source.

snapshot_dir str

A local directory path to swap to directly from the filesystem.

Source code in src/embedrag/models/api.py
657
658
659
660
661
662
663
664
665
666
667
668
class SyncTriggerRequest(BaseModel):
    """Manual request to trigger a snapshot pull or swap.

    Both fields default to "", so an empty body is a valid request.

    Attributes:
        source_url (str, optional): A specific URL to pull a snapshot from,
            bypassing the configured global source.
        snapshot_dir (str, optional): A local directory path to swap to
            directly from the filesystem.
    """

    # NOTE(review): "" presumably means "not provided" (fall back to the
    # configured source); the model does not reject setting both — confirm
    # which one the handler prefers.
    source_url: str = ""
    snapshot_dir: str = ""

TextSearchRequest

Bases: BaseModel

Natural language search request where the node handles embedding.

Attributes:

Name Type Description
query_text str

The search query in plain text.

top_k int

Number of results. Defaults to 10.

filters dict

Metadata filters.

expand_context bool

Whether to fetch adjacent chunks. Defaults to True.

context_depth int

Surrounding context window size. Defaults to 1.

mode str

Search algorithm to use ("dense", "sparse", "hybrid"). Defaults to "hybrid".

space str

The embedding space to target. Defaults to "text".

Source code in src/embedrag/models/api.py
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
class TextSearchRequest(BaseModel):
    """Natural language search request where the node handles embedding.

    Attributes:
        query_text (str): The search query in plain text. Required.
        top_k (int, optional): Number of results. Defaults to 10.
        filters (dict, optional): Metadata filters. Defaults to None.
        expand_context (bool, optional): Whether to fetch adjacent chunks.
            Defaults to True.
        context_depth (int, optional): Surrounding context window size.
            Defaults to 1.
        mode (str, optional): Search algorithm to use ("dense", "sparse",
            "hybrid"). Defaults to "hybrid".
        space (str, optional): The embedding space to target. Defaults to "text".
        cluster_run_id (str, optional): Identifier of a clustering run.
            Defaults to None. (Presumably used with ``cluster_id`` to
            restrict results to one cluster — TODO confirm.)
        cluster_id (int, optional): Cluster identifier within
            ``cluster_run_id``. Defaults to None.
    """

    query_text: str
    top_k: int = 10
    filters: dict | None = None
    expand_context: bool = True
    context_depth: int = 1
    # Plain ``str``: values outside "dense"/"sparse"/"hybrid" are not
    # rejected by validation.
    mode: str = "hybrid"
    space: str = "text"
    # NOTE(review): undocumented in the original docstring; mirrors the same
    # pair on SearchRequest — confirm semantics against the handler.
    cluster_run_id: str | None = None
    cluster_id: int | None = None

Chunking Models

Internal models for representing document chunks.

Core document and chunk data models.

This module defines the internal data structures used to represent documents and their hierarchical decompositions within the EmbedRAG system. These models are foundational for the chunking, embedding, and retrieval processes, allowing for a rich, tree-like representation of content.

ChunkNode dataclass

A single node in the document's hierarchical chunk tree.

EmbedRAG represents documents as trees of chunks. A ChunkNode can represent anything from the entire document at the root level down to a single sentence or fixed-size window at the leaf level. This hierarchical structure enables features like "hierarchical expansion," where small chunks are retrieved but their parent context (e.g., the containing paragraph or section) is also returned to the user.

Attributes:

Name Type Description
chunk_id str

A globally unique identifier for this specific chunk.

doc_id str

The identifier of the parent Document.

text str

The text content of this chunk.

parent_chunk_id str

The ID of the parent node in the hierarchy. If None, this is a root node.

level int

The depth in the tree. Conventionally: 0=document, 1=section, 2=paragraph, 3=leaf chunk. Defaults to 0.

level_type str

A descriptive name for the level (e.g., 'chunk', 'section', 'document'). Defaults to 'chunk'.

seq_in_parent int

The 0-indexed position of this chunk among its siblings under the same parent. Defaults to 0.

metadata dict

Key-value pairs specific to this chunk.

embedding list[float]

The pre-calculated vector embedding for this chunk's text.

children list[ChunkNode]

A list of child ChunkNode instances.

Source code in src/embedrag/models/chunk.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@dataclass
class ChunkNode:
    """A single node in the document's hierarchical chunk tree.

    EmbedRAG represents documents as trees of chunks. A `ChunkNode` can represent
    anything from the entire document at the root level down to a single
    sentence or fixed-size window at the leaf level. This hierarchical structure
    enables features like "hierarchical expansion," where small chunks are
    retrieved but their parent context (e.g., the containing paragraph or section)
    is also returned to the user.

    Attributes:
        chunk_id (str): A globally unique identifier for this specific chunk.
        doc_id (str): The identifier of the parent `Document`.
        text (str): The text content of this chunk.
        parent_chunk_id (str, optional): The ID of the parent node in the
            hierarchy. If None, this is a root node. Defaults to None.
        level (int): The depth in the tree. Conventionally: 0=document,
            1=section, 2=paragraph, 3=leaf chunk. Defaults to 0.
        level_type (str): A descriptive name for the level (e.g., 'chunk',
            'section', 'document'). Defaults to 'chunk'.
        seq_in_parent (int): The 0-indexed position of this chunk among its
            siblings under the same parent. Defaults to 0.
        metadata (dict): Key-value pairs specific to this chunk. Defaults to
            a new empty dict per instance.
        embedding (list[float], optional): The pre-calculated vector
            embedding for this chunk's text. Defaults to None.
        children (list[ChunkNode]): A list of child `ChunkNode` instances.
            Defaults to a new empty list per instance.
    """

    # Field order defines the generated __init__'s positional parameters;
    # do not reorder.
    chunk_id: str
    doc_id: str
    text: str
    parent_chunk_id: str | None = None
    level: int = 0
    level_type: str = "chunk"
    seq_in_parent: int = 0
    # default_factory avoids a single dict/list shared by every instance.
    metadata: dict = field(default_factory=dict)
    embedding: list[float] | None = None
    # Self-referencing annotation; requires postponed annotation evaluation
    # (presumably `from __future__ import annotations` at module top).
    children: list[ChunkNode] = field(default_factory=list)

Document dataclass

Represents a complete, logical document ingested into the system.

A Document is the top-level unit of information. It contains the raw content and global metadata that applies to all of its child chunks.

Attributes:

Name Type Description
doc_id str

A globally unique identifier for the document.

title str

The human-readable title of the document.

source str

The origin of the document (e.g., a URL, filepath, or database key).

doc_type str

A category or classification for the document (e.g., 'technical_manual', 'news_article').

metadata dict

A dictionary of arbitrary key-value pairs stored with the document.

Source code in src/embedrag/models/chunk.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@dataclass
class Document:
    """Represents a complete, logical document ingested into the system.

    A `Document` is the top-level unit of information. It contains the raw
    content and global metadata that applies to all of its child chunks.

    Attributes:
        doc_id (str): A globally unique identifier for the document. Required.
        title (str, optional): The human-readable title of the document.
            Defaults to "".
        source (str, optional): The origin of the document (e.g., a URL,
            filepath, or database key). Defaults to "".
        doc_type (str, optional): A category or classification for the
            document (e.g., 'technical_manual', 'news_article'). Defaults to "".
        metadata (dict): A dictionary of arbitrary key-value pairs stored
            with the document. Defaults to a new empty dict per instance.
    """

    # Field order defines the generated __init__'s positional parameters.
    doc_id: str
    title: str = ""
    source: str = ""
    doc_type: str = ""
    # default_factory so instances never share one metadata dict.
    metadata: dict = field(default_factory=dict)

Manifest

Models for the snapshot manifest that coordinates index loading.

Manifest v3: self-describing snapshot metadata with per-space indexes and checksums.

This module defines the schema for the EmbedRAG snapshot manifest. The manifest is the central piece of metadata that coordinates how the query node loads an index, verifying checksums and mapping shards to the correct embedding spaces.

DeltaInfo dataclass

Information about the difference between this and a previous version.

Attributes:

Name Type Description
from_version str

The base version this delta is calculated from.

unchanged_files list[str]

List of files that haven't changed.

changed_files list[str]

List of files that were added or modified.

delta_compressed_size int

Total size of the changed files when compressed.

Source code in src/embedrag/models/manifest.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@dataclass
class DeltaInfo:
    """Information about the difference between this and a previous version.

    Every field has a default, so ``DeltaInfo()`` is a valid empty delta.

    Attributes:
        from_version (str): The base version this delta is calculated from.
            Defaults to "".
        unchanged_files (list[str]): List of files that haven't changed.
        changed_files (list[str]): List of files that were added or modified.
        delta_compressed_size (int): Total size of the changed files when
            compressed, presumably in bytes (consistent with FileEntry sizes).
    """

    from_version: str = ""
    # default_factory gives each instance its own list.
    unchanged_files: list[str] = field(default_factory=list)
    changed_files: list[str] = field(default_factory=list)
    delta_compressed_size: int = 0

FileEntry dataclass

Metadata for a single non-index file in the snapshot (e.g., SQLite DB).

Attributes:

Name Type Description
file str

The relative path to the uncompressed file.

compressed_file str

The relative path to the compressed version.

sha256_raw str

The SHA256 checksum of the uncompressed file.

sha256_compressed str

The SHA256 checksum of the compressed file.

raw_size int

The size of the uncompressed file in bytes.

compressed_size int

The size of the compressed file in bytes.

doc_count int

For databases, the total number of documents.

chunk_count int

For databases, the total number of chunks.

Source code in src/embedrag/models/manifest.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@dataclass
class FileEntry:
    """Metadata for a single non-index file in the snapshot (e.g., SQLite DB).

    Only ``file`` is required; all other fields default to "" or 0.

    Attributes:
        file (str): The relative path to the uncompressed file.
        compressed_file (str, optional): The relative path to the compressed version.
        sha256_raw (str, optional): The SHA256 checksum of the uncompressed file.
        sha256_compressed (str, optional): The SHA256 checksum of the compressed file.
        raw_size (int, optional): The size of the uncompressed file in bytes.
        compressed_size (int, optional): The size of the compressed file in bytes.
        doc_count (int, optional): For databases, the total number of documents.
        chunk_count (int, optional): For databases, the total number of chunks.
    """

    file: str
    # NOTE(review): "" / 0 defaults presumably mean "not yet computed";
    # consumers verifying checksums should treat an empty hash as missing.
    compressed_file: str = ""
    sha256_raw: str = ""
    sha256_compressed: str = ""
    raw_size: int = 0
    compressed_size: int = 0
    # Only meaningful for database entries; left at 0 for other files.
    doc_count: int = 0
    chunk_count: int = 0

IndexInfo dataclass

Metadata describing the FAISS index for a specific embedding space.

Attributes:

Name Type Description
type str

The FAISS index factory string (e.g., 'IVF4096,PQ64').

dim int

The dimensionality of the vectors in this index.

metric str

The distance metric used (e.g., 'IP' for Inner Product).

nprobe_default int

The default nprobe value for search.

num_shards int

The number of shards the index is split into.

total_vectors int

The total number of vectors across all shards.

shards list[ShardEntry]

A list of individual shard metadata.

Source code in src/embedrag/models/manifest.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@dataclass
class IndexInfo:
    """Metadata describing the FAISS index for a specific embedding space.

    Every field has a default, so ``IndexInfo()`` describes the default
    index layout.

    Attributes:
        type (str): The FAISS index factory string (e.g., 'IVF4096,PQ64').
            Defaults to 'IVF4096,PQ64'.
        dim (int): The dimensionality of the vectors in this index.
            Defaults to 1024.
        metric (str): The distance metric used (e.g., 'IP' for Inner Product).
            Defaults to 'IP'.
        nprobe_default (int): The default nprobe value for search. Defaults to 32.
        num_shards (int): The number of shards the index is split into.
            Defaults to 4.
        total_vectors (int): The total number of vectors across all shards.
            Defaults to 0.
        shards (list[ShardEntry]): A list of individual shard metadata.
            Defaults to a new empty list per instance.
    """

    # Field named ``type`` shadows the builtin only as an attribute name,
    # which is harmless here; kept for manifest compatibility.
    type: str = "IVF4096,PQ64"
    dim: int = 1024
    metric: str = "IP"
    nprobe_default: int = 32
    # NOTE(review): num_shards defaults to 4 while ``shards`` starts empty;
    # nothing here enforces len(shards) == num_shards — confirm the builder
    # keeps them in sync.
    num_shards: int = 4
    total_vectors: int = 0
    shards: list[ShardEntry] = field(default_factory=list)

Manifest dataclass

The root metadata object for an EmbedRAG snapshot.

Attributes:

Name Type Description
manifest_version int

The version of the manifest schema itself (currently 3).

snapshot_version str

The unique version identifier for this snapshot.

created_at str

ISO timestamp of when the snapshot was created.

previous_version str

The version ID of the snapshot this one was built upon.

schema_version int

The version of the underlying database schema.

indexes dict[str, IndexInfo]

Mapping of space names to index metadata.

db FileEntry

Metadata for the primary SQLite database.

id_maps dict[str, FileEntry]

Mapping of space names to ID mapping files.

total_raw_size int

Sum of all uncompressed file sizes.

total_compressed_size int

Sum of all compressed file sizes.

delta DeltaInfo

Differential information relative to a previous version.

Source code in src/embedrag/models/manifest.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
@dataclass
class Manifest:
    """The root metadata object for an EmbedRAG snapshot.

    A manifest enumerates every file of a snapshot (index shards, the SQLite
    database and the per-space ID maps) with checksums and sizes, and
    round-trips through JSON via ``to_dict``/``from_dict``.

    Attributes:
        manifest_version (int): The version of the manifest schema itself (currently 3).
        snapshot_version (str): The unique version identifier for this snapshot.
        created_at (str): ISO timestamp of when the snapshot was created.
        previous_version (str): The version ID of the snapshot this one was built upon.
        schema_version (int): The version of the underlying database schema.
        indexes (dict[str, IndexInfo]): Mapping of space names to index metadata.
        db (FileEntry): Metadata for the primary SQLite database.
        id_maps (dict[str, FileEntry]): Mapping of space names to ID mapping files.
        total_raw_size (int): Sum of all uncompressed file sizes.
        total_compressed_size (int): Sum of all compressed file sizes.
        delta (DeltaInfo | None): Differential information relative to a previous
            version, or None when no delta is recorded.
    """

    manifest_version: int = 3
    snapshot_version: str = ""
    created_at: str = ""
    previous_version: str = ""
    schema_version: int = 3

    indexes: dict[str, IndexInfo] = field(default_factory=lambda: {"text": IndexInfo()})
    db: FileEntry = field(default_factory=lambda: FileEntry(file="db/embedrag.db"))
    id_maps: dict[str, FileEntry] = field(default_factory=lambda: {"text": FileEntry(file="index/text/id_map.msgpack")})

    total_raw_size: int = 0
    total_compressed_size: int = 0

    delta: DeltaInfo | None = None

    @property
    def spaces(self) -> list[str]:
        """list[str]: A list of all embedding space names defined in this manifest."""
        return list(self.indexes.keys())

    def _iter_entries(self):
        """Yield every file entry in canonical order: shards, database, ID maps.

        ``all_compressed_files`` and both ``get_file_entry_by_*`` lookups walk
        this single sequence, so they agree on ordering and on which entry wins
        when two entries share a path (the first one yielded).
        """
        for idx_info in self.indexes.values():
            yield from idx_info.shards
        yield self.db
        yield from self.id_maps.values()

    def all_compressed_files(self) -> list[str]:
        """Return a list of all compressed file paths required for this snapshot.

        This method is used by the syncer to determine which files need to be
        downloaded from the object store. An entry with no compressed copy
        recorded contributes its raw path instead.

        Returns:
            list[str]: A list of relative file paths, in shard, database,
                ID-map order.
        """
        return [entry.compressed_file or entry.file for entry in self._iter_entries()]

    def get_file_entry_by_raw(self, raw_path: str) -> FileEntry | ShardEntry | None:
        """Look up a file or shard entry using its uncompressed (raw) path.

        Args:
            raw_path (str): The relative path to the uncompressed file.

        Returns:
            FileEntry | ShardEntry | None: The first matching metadata entry
                (shards are searched before the database and ID maps),
                or None if not found.
        """
        return next((entry for entry in self._iter_entries() if entry.file == raw_path), None)

    def get_file_entry_by_compressed(self, compressed_path: str) -> FileEntry | ShardEntry | None:
        """Look up a file or shard entry using the path it is stored under remotely.

        An entry without a compressed copy is matched by its raw path, mirroring
        ``all_compressed_files``.

        Args:
            compressed_path (str): The relative path of the compressed file
                (or the raw path for entries with no compressed copy).

        Returns:
            FileEntry | ShardEntry | None: The first matching metadata entry,
                or None if not found.
        """
        return next(
            (entry for entry in self._iter_entries() if (entry.compressed_file or entry.file) == compressed_path),
            None,
        )

    def save(self, path: str | Path) -> None:
        """Write the manifest to ``path`` as indented JSON.

        Parent directories are created if missing. The file is written as
        UTF-8 explicitly instead of the platform's locale encoding.

        Args:
            path (str | Path): Destination file path.
        """
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str | Path) -> Manifest:
        """Read a manifest previously written by ``save``.

        Args:
            path (str | Path): Path to the UTF-8 encoded JSON manifest.

        Returns:
            Manifest: The parsed manifest.
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)

    @classmethod
    def from_json(cls, raw: str) -> Manifest:
        """Parse a manifest from a JSON string.

        Args:
            raw (str): JSON text in the ``to_dict`` layout.

        Returns:
            Manifest: The parsed manifest.
        """
        return cls.from_dict(json.loads(raw))

    def to_json(self) -> str:
        """Serialize the manifest to an indented JSON string."""
        return json.dumps(self.to_dict(), indent=2)

    def to_dict(self) -> dict:
        """Convert the manifest into plain JSON-serializable data.

        The ``delta`` key is emitted only when a delta is present. Database
        entries include ``doc_count``/``chunk_count``; ID-map entries do not.

        Returns:
            dict: The manifest as nested dicts and lists.
        """

        def _idx_to_dict(idx: IndexInfo) -> dict:
            return {
                "type": idx.type,
                "dim": idx.dim,
                "metric": idx.metric,
                "nprobe_default": idx.nprobe_default,
                "num_shards": idx.num_shards,
                "total_vectors": idx.total_vectors,
                "shards": [
                    {
                        "file": s.file,
                        "compressed_file": s.compressed_file,
                        "sha256_raw": s.sha256_raw,
                        "sha256_compressed": s.sha256_compressed,
                        "raw_size": s.raw_size,
                        "compressed_size": s.compressed_size,
                        "num_vectors": s.num_vectors,
                    }
                    for s in idx.shards
                ],
            }

        def _fe_to_dict(fe: FileEntry) -> dict:
            return {
                "file": fe.file,
                "compressed_file": fe.compressed_file,
                "sha256_raw": fe.sha256_raw,
                "sha256_compressed": fe.sha256_compressed,
                "raw_size": fe.raw_size,
                "compressed_size": fe.compressed_size,
            }

        d: dict = {
            "manifest_version": self.manifest_version,
            "snapshot_version": self.snapshot_version,
            "created_at": self.created_at,
            "previous_version": self.previous_version,
            "schema_version": self.schema_version,
            "indexes": {space: _idx_to_dict(idx) for space, idx in self.indexes.items()},
            "db": {
                "file": self.db.file,
                "compressed_file": self.db.compressed_file,
                "sha256_raw": self.db.sha256_raw,
                "sha256_compressed": self.db.sha256_compressed,
                "raw_size": self.db.raw_size,
                "compressed_size": self.db.compressed_size,
                "doc_count": self.db.doc_count,
                "chunk_count": self.db.chunk_count,
            },
            "id_maps": {space: _fe_to_dict(fe) for space, fe in self.id_maps.items()},
            "total_raw_size": self.total_raw_size,
            "total_compressed_size": self.total_compressed_size,
        }
        # Test for presence explicitly rather than relying on the truthiness
        # of the DeltaInfo object.
        if self.delta is not None:
            d["delta"] = {
                "from_version": self.delta.from_version,
                "unchanged_files": self.delta.unchanged_files,
                "changed_files": self.delta.changed_files,
                "delta_compressed_size": self.delta.delta_compressed_size,
            }
        return d

    @classmethod
    def from_dict(cls, data: dict) -> Manifest:
        """Build a manifest from a dict in the ``to_dict`` layout.

        Missing keys fall back to the defaults below. For the container fields
        (``indexes``, ``id_maps``, ``db``, per-index ``shards`` and the
        individual index / file entries), a JSON ``null`` is treated the same
        as a missing key instead of raising.

        Args:
            data (dict): Parsed manifest data.

        Returns:
            Manifest: The reconstructed manifest.
        """

        def _parse_index(d: dict | None) -> IndexInfo:
            d = d or {}
            shards = [ShardEntry(**s) for s in (d.get("shards") or [])]
            return IndexInfo(
                type=d.get("type", "IVF4096,PQ64"),
                dim=d.get("dim", 1024),
                metric=d.get("metric", "IP"),
                nprobe_default=d.get("nprobe_default", 32),
                num_shards=d.get("num_shards", 4),
                total_vectors=d.get("total_vectors", 0),
                shards=shards,
            )

        def _parse_file_entry(d: dict | None, default_file: str = "") -> FileEntry:
            d = d or {}
            return FileEntry(
                file=d.get("file", default_file),
                compressed_file=d.get("compressed_file", ""),
                sha256_raw=d.get("sha256_raw", ""),
                sha256_compressed=d.get("sha256_compressed", ""),
                raw_size=d.get("raw_size", 0),
                compressed_size=d.get("compressed_size", 0),
                doc_count=d.get("doc_count", 0),
                chunk_count=d.get("chunk_count", 0),
            )

        # An empty (or null) mapping falls back to the single default "text" space.
        indexes = {space: _parse_index(d) for space, d in (data.get("indexes") or {}).items()} or {
            "text": IndexInfo()
        }

        id_maps = {
            space: _parse_file_entry(d, f"index/{space}/id_map.msgpack")
            for space, d in (data.get("id_maps") or {}).items()
        } or {"text": FileEntry(file="index/text/id_map.msgpack")}

        db = _parse_file_entry(data.get("db"), "db/embedrag.db")

        delta_data = data.get("delta")
        delta = None
        if delta_data:
            delta = DeltaInfo(
                from_version=delta_data.get("from_version", ""),
                unchanged_files=delta_data.get("unchanged_files", []),
                changed_files=delta_data.get("changed_files", []),
                delta_compressed_size=delta_data.get("delta_compressed_size", 0),
            )

        return cls(
            manifest_version=data.get("manifest_version", 3),
            snapshot_version=data.get("snapshot_version", ""),
            created_at=data.get("created_at", ""),
            previous_version=data.get("previous_version", ""),
            schema_version=data.get("schema_version", 3),
            indexes=indexes,
            db=db,
            id_maps=id_maps,
            total_raw_size=data.get("total_raw_size", 0),
            total_compressed_size=data.get("total_compressed_size", 0),
            delta=delta,
        )

spaces property

list[str]: A list of all embedding space names defined in this manifest.

all_compressed_files()

Return a list of all compressed file paths required for this snapshot.

This method is used by the syncer to determine which files need to be downloaded from the object store.

Returns:

Type Description
list[str]

list[str]: A list of relative file paths.

Source code in src/embedrag/models/manifest.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def all_compressed_files(self) -> list[str]:
    """List the remote file paths a node must fetch to materialize this snapshot.

    The syncer uses this to decide what to download from the object store.
    Each entry contributes its compressed path, or its raw path when no
    compressed copy is recorded.

    Returns:
        list[str]: Relative paths ordered as index shards (per space), then
            the database, then the ID maps.
    """
    shard_paths = [
        shard.compressed_file or shard.file
        for info in self.indexes.values()
        for shard in info.shards
    ]
    db_path = self.db.compressed_file or self.db.file
    id_map_paths = [entry.compressed_file or entry.file for entry in self.id_maps.values()]
    return [*shard_paths, db_path, *id_map_paths]

get_file_entry_by_raw(raw_path)

Look up a file or shard entry using its uncompressed (raw) path.

Parameters:

Name Type Description Default
raw_path str

The relative path to the uncompressed file.

required

Returns:

Type Description
FileEntry | ShardEntry | None

FileEntry | ShardEntry | None: The matching metadata entry, or None if not found.

Source code in src/embedrag/models/manifest.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def get_file_entry_by_raw(self, raw_path: str) -> FileEntry | ShardEntry | None:
    """Find the manifest entry whose uncompressed path equals ``raw_path``.

    Shards are searched first (space by space), then the database entry,
    then the ID maps; the first hit is returned.

    Args:
        raw_path (str): Relative path of the uncompressed file.

    Returns:
        FileEntry | ShardEntry | None: The matching entry, or None when no
            entry has that raw path.
    """
    shard_hit = next(
        (shard for info in self.indexes.values() for shard in info.shards if shard.file == raw_path),
        None,
    )
    if shard_hit is not None:
        return shard_hit
    if raw_path == self.db.file:
        return self.db
    return next((entry for entry in self.id_maps.values() if entry.file == raw_path), None)

ShardEntry dataclass

Represents a single FAISS index shard file within a snapshot.

Attributes:

Name Type Description
file str

The relative path to the uncompressed shard file.

compressed_file str

The relative path to the compressed version of the shard.

sha256_raw str

The SHA256 checksum of the uncompressed file.

sha256_compressed str

The SHA256 checksum of the compressed file.

raw_size int

The size of the uncompressed file in bytes.

compressed_size int

The size of the compressed file in bytes.

num_vectors int

The number of vectors contained in this shard.

Source code in src/embedrag/models/manifest.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@dataclass
class ShardEntry:
    """Represents a single FAISS index shard file within a snapshot.

    Instances are built from the per-shard dicts of a serialized manifest, so
    the field names below double as the JSON keys of a shard record.

    Attributes:
        file (str): The relative path to the uncompressed shard file.
        compressed_file (str, optional): The relative path to the compressed
            version of the shard.
        sha256_raw (str, optional): The SHA256 checksum of the uncompressed file.
        sha256_compressed (str, optional): The SHA256 checksum of the compressed file.
        raw_size (int, optional): The size of the uncompressed file in bytes.
        compressed_size (int, optional): The size of the compressed file in bytes.
        num_vectors (int, optional): The number of vectors contained in this shard.
    """

    # Required and declared first: dataclass fields without a default must
    # precede those with one, and this order is the positional constructor order.
    file: str
    # Empty string means no compressed copy is recorded for this shard.
    compressed_file: str = ""
    # SHA256 checksums of the raw and compressed files; empty by default.
    sha256_raw: str = ""
    sha256_compressed: str = ""
    # File sizes in bytes; 0 by default.
    raw_size: int = 0
    compressed_size: int = 0
    # Number of vectors stored in this shard; 0 by default.
    num_vectors: int = 0

Shared Utilities

Core utilities shared across the project.

Object Store

Abstraction layer for S3, TOS, and MinIO storage.

Abstraction over S3-compatible object storage (S3, TOS, MinIO).

This module provides a unified interface for interacting with various object storage providers. It is primarily used for uploading snapshots from the writer node and downloading them on the query node for synchronization.

ObjectStoreClient

Thin wrapper around boto3 S3 client for snapshot upload/download.

This client abstracts away provider-specific configurations (like custom endpoints for MinIO or ByteDance TOS) while providing a simplified API for common file and JSON operations.

Source code in src/embedrag/shared/object_store.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class ObjectStoreClient:
    """Thin wrapper around boto3 S3 client for snapshot upload/download.

    This client abstracts away provider-specific configurations (like custom
    endpoints for MinIO or ByteDance TOS) while providing a simplified API
    for common file and JSON operations. Remote paths given to the public
    methods are relative to the configured key prefix.
    """

    # Error codes S3-compatible stores report for a missing key on HEAD.
    # HEAD responses carry no body, so S3 itself usually reports just "404".
    _NOT_FOUND_CODES = frozenset({"404", "NoSuchKey", "NotFound"})

    def __init__(self, config: ObjectStoreConfig):
        """Initialize the ObjectStoreClient.

        Args:
            config (ObjectStoreConfig): The configuration object containing
                credentials, bucket name, and provider settings.
        """
        self._bucket = config.bucket
        self._prefix = config.prefix.rstrip("/")
        kwargs: dict = {
            "aws_access_key_id": config.access_key,
            "aws_secret_access_key": config.secret_key,
            "config": BotoConfig(
                retries={"max_attempts": 3, "mode": "adaptive"},
                max_pool_connections=10,
            ),
        }
        # Only pass endpoint/region when configured so boto3 keeps its defaults.
        if config.endpoint:
            kwargs["endpoint_url"] = config.endpoint
        if config.region:
            kwargs["region_name"] = config.region

        self._client = boto3.client("s3", **kwargs)

    def _key(self, path: str) -> str:
        """Helper to prepend the configured prefix to a path."""
        return f"{self._prefix}/{path}" if self._prefix else path

    @staticmethod
    def _error_code(exc: BaseException) -> str:
        """Return the S3 error code carried by a botocore-style error, or ""."""
        response = getattr(exc, "response", None)
        if not isinstance(response, dict):
            return ""
        return str((response.get("Error") or {}).get("Code", ""))

    def upload_file(self, local_path: str | Path, remote_path: str) -> None:
        """Upload a local file to the object store.

        Args:
            local_path (str | Path): The path to the file on the local filesystem.
            remote_path (str): The destination path (key) in the bucket, relative
                to the configured prefix.
        """
        key = self._key(remote_path)
        logger.info("object_store_upload", key=key, local=str(local_path))
        self._client.upload_file(str(local_path), self._bucket, key)

    def download_file(self, remote_path: str, local_path: str | Path) -> None:
        """Download a file from the object store to the local filesystem.

        Args:
            remote_path (str): The source path (key) in the bucket, relative
                to the configured prefix.
            local_path (str | Path): The local destination path. Parent directories
                will be created if they don't exist.
        """
        key = self._key(remote_path)
        Path(local_path).parent.mkdir(parents=True, exist_ok=True)
        logger.info("object_store_download", key=key, local=str(local_path))
        self._client.download_file(self._bucket, key, str(local_path))

    def put_json(self, remote_path: str, data: dict) -> None:
        """Serialize a dictionary to JSON and upload it to the object store.

        The object is stored with ``Content-Type: application/json``.

        Args:
            remote_path (str): The destination path (key) in the bucket.
            data (dict): The dictionary to serialize and upload.
        """
        key = self._key(remote_path)
        body = json.dumps(data, indent=2).encode()
        # Without an explicit type S3 stores binary/octet-stream, which misleads
        # HTTP clients that fetch the object directly.
        self._client.put_object(
            Bucket=self._bucket,
            Key=key,
            Body=body,
            ContentType="application/json",
        )

    def get_json(self, remote_path: str) -> dict | None:
        """Download a JSON file and deserialize it into a dictionary.

        Args:
            remote_path (str): The source path (key) in the bucket.

        Returns:
            dict | None: The deserialized dictionary, or None if the key
                does not exist.
        """
        key = self._key(remote_path)
        try:
            resp = self._client.get_object(Bucket=self._bucket, Key=key)
            return json.loads(resp["Body"].read())
        except self._client.exceptions.NoSuchKey:
            return None

    def head_object(self, remote_path: str) -> dict | None:
        """Retrieve metadata for an object without downloading its content.

        This is best-effort: any failure yields None. A missing key is silent.
        Any other error (credentials, network, throttling) is logged as
        ``object_store_head_failed`` so it can be told apart from absence.

        Args:
            remote_path (str): The path (key) in the bucket.

        Returns:
            dict | None: The object metadata, or None if an error occurs
                (e.g., object not found).
        """
        key = self._key(remote_path)
        try:
            return self._client.head_object(Bucket=self._bucket, Key=key)
        except Exception as exc:
            if self._error_code(exc) not in self._NOT_FOUND_CODES:
                logger.warning("object_store_head_failed", key=key, error=repr(exc))
            return None

    def list_prefix(self, prefix: str) -> list[str]:
        """List all object keys under a specific prefix.

        Args:
            prefix (str): The prefix to list objects from, relative to the
                configured key prefix.

        Returns:
            list[str]: A list of full object keys (including the configured
                key prefix), across all result pages.
        """
        key = self._key(prefix)
        paginator = self._client.get_paginator("list_objects_v2")
        keys: list[str] = []
        for page in paginator.paginate(Bucket=self._bucket, Prefix=key):
            for obj in page.get("Contents", []):
                keys.append(obj["Key"])
        return keys

__init__(config)

Initialize the ObjectStoreClient.

Parameters:

Name Type Description Default
config ObjectStoreConfig

The configuration object containing credentials, bucket name, and provider settings.

required
Source code in src/embedrag/shared/object_store.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(self, config: ObjectStoreConfig):
    """Create the underlying boto3 S3 client from ``config``.

    Args:
        config (ObjectStoreConfig): Credentials, bucket, key prefix and the
            optional endpoint/region of the target store.
    """
    self._bucket = config.bucket
    # Trailing slashes are dropped so `_key` can join with exactly one "/".
    self._prefix = config.prefix.rstrip("/")
    credentials = {
        "aws_access_key_id": config.access_key,
        "aws_secret_access_key": config.secret_key,
    }
    tuning = BotoConfig(
        retries={"max_attempts": 3, "mode": "adaptive"},
        max_pool_connections=10,
    )
    kwargs: dict = {**credentials, "config": tuning}
    # Endpoint and region are forwarded only when set (truthy).
    for option, value in (("endpoint_url", config.endpoint), ("region_name", config.region)):
        if value:
            kwargs[option] = value

    self._client = boto3.client("s3", **kwargs)

download_file(remote_path, local_path)

Download a file from the object store to the local filesystem.

Parameters:

Name Type Description Default
remote_path str

The source path (key) in the bucket, relative to the configured prefix.

required
local_path str | Path

The local destination path. Parent directories will be created if they don't exist.

required
Source code in src/embedrag/shared/object_store.py
70
71
72
73
74
75
76
77
78
79
80
81
82
def download_file(self, remote_path: str, local_path: str | Path) -> None:
    """Fetch ``remote_path`` from the bucket and write it to ``local_path``.

    Args:
        remote_path (str): Object path relative to the configured prefix.
        local_path (str | Path): Destination on disk; any missing parent
            directories are created first.
    """
    object_key = self._key(remote_path)
    destination = Path(local_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    target = str(local_path)
    logger.info("object_store_download", key=object_key, local=target)
    self._client.download_file(self._bucket, object_key, target)

get_json(remote_path)

Download a JSON file and deserialize it into a dictionary.

Parameters:

Name Type Description Default
remote_path str

The source path (key) in the bucket.

required

Returns:

Type Description
dict | None

dict | None: The deserialized dictionary, or None if the key does not exist.

Source code in src/embedrag/shared/object_store.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def get_json(self, remote_path: str) -> dict | None:
    """Fetch a JSON object from the bucket and decode it.

    Args:
        remote_path (str): Object path relative to the configured prefix.

    Returns:
        dict | None: The decoded JSON document, or None when the store
            reports that the key does not exist.
    """
    object_key = self._key(remote_path)
    try:
        body = self._client.get_object(Bucket=self._bucket, Key=object_key)["Body"]
        return json.loads(body.read())
    except self._client.exceptions.NoSuchKey:
        return None

head_object(remote_path)

Retrieve metadata for an object without downloading its content.

Parameters:

Name Type Description Default
remote_path str

The path (key) in the bucket.

required

Returns:

Type Description
dict | None

dict | None: The object metadata, or None if an error occurs (e.g., object not found).

Source code in src/embedrag/shared/object_store.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def head_object(self, remote_path: str) -> dict | None:
    """Retrieve metadata for an object without downloading its content.

    This is best-effort: any failure yields None. A missing key is silent.
    Any other error (credentials, network, throttling) is logged as
    ``object_store_head_failed`` so it can be told apart from absence.

    Args:
        remote_path (str): The path (key) in the bucket.

    Returns:
        dict | None: The object metadata, or None if an error occurs
            (e.g., object not found).
    """
    key = self._key(remote_path)
    try:
        return self._client.head_object(Bucket=self._bucket, Key=key)
    except Exception as exc:
        # botocore ClientError exposes the S3 error code under
        # response["Error"]["Code"]; other exceptions have no response.
        response = getattr(exc, "response", None)
        code = ""
        if isinstance(response, dict):
            code = str((response.get("Error") or {}).get("Code", ""))
        # HEAD has no response body, so S3 usually reports a missing key as "404".
        if code not in ("404", "NoSuchKey", "NotFound"):
            logger.warning("object_store_head_failed", key=key, error=repr(exc))
        return None

list_prefix(prefix)

List all object keys under a specific prefix.

Parameters:

Name Type Description Default
prefix str

The prefix to list objects from.

required

Returns:

Type Description
list[str]

list[str]: A list of full object keys (including the configured key prefix).

Source code in src/embedrag/shared/object_store.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def list_prefix(self, prefix: str) -> list[str]:
    """Collect every object key that starts with ``prefix``.

    Args:
        prefix (str): Prefix relative to the configured key prefix.

    Returns:
        list[str]: Full object keys as stored in the bucket (including the
            configured key prefix), in listing order across all pages.
    """
    search = self._key(prefix)
    pages = self._client.get_paginator("list_objects_v2").paginate(
        Bucket=self._bucket, Prefix=search
    )
    return [item["Key"] for page in pages for item in page.get("Contents", [])]

put_json(remote_path, data)

Serialize a dictionary to JSON and upload it to the object store.

Parameters:

Name Type Description Default
remote_path str

The destination path (key) in the bucket.

required
data dict

The dictionary to serialize and upload.

required
Source code in src/embedrag/shared/object_store.py
84
85
86
87
88
89
90
91
92
93
def put_json(self, remote_path: str, data: dict) -> None:
    """Serialize a dictionary to JSON and upload it to the object store.

    The object is stored with ``Content-Type: application/json``.

    Args:
        remote_path (str): The destination path (key) in the bucket.
        data (dict): The dictionary to serialize and upload.
    """
    key = self._key(remote_path)
    body = json.dumps(data, indent=2).encode()
    # Without an explicit type S3 stores binary/octet-stream, which misleads
    # HTTP clients that fetch the object directly.
    self._client.put_object(
        Bucket=self._bucket,
        Key=key,
        Body=body,
        ContentType="application/json",
    )

upload_file(local_path, remote_path)

Upload a local file to the object store.

Parameters:

Name Type Description Default
local_path str | Path

The path to the file on the local filesystem.

required
remote_path str

The destination path (key) in the bucket, relative to the configured prefix.

required
Source code in src/embedrag/shared/object_store.py
58
59
60
61
62
63
64
65
66
67
68
def upload_file(self, local_path: str | Path, remote_path: str) -> None:
    """Push a file from disk to ``remote_path`` under the configured prefix.

    Args:
        local_path (str | Path): File on the local filesystem to upload.
        remote_path (str): Destination key, relative to the configured prefix.
    """
    object_key = self._key(remote_path)
    source = str(local_path)
    logger.info("object_store_upload", key=object_key, local=source)
    self._client.upload_file(source, self._bucket, object_key)

Metrics

Prometheus metrics collection and reporting.

Prometheus metrics for both writer and query nodes.

This module defines the Prometheus metrics used to monitor the performance, health, and internal state of EmbedRAG nodes. These metrics can be scraped by a Prometheus server and used for building Grafana dashboards or setting up operational alerts.

Writer Node

Components specific to the writer node, which handles ingestion and indexing.

Ingestion & Build

The writer's FastAPI application and lifespan management.

Writer node FastAPI application.

This module defines the web application and runtime state management for the EmbedRAG Writer Node. The writer node is responsible for the "write" side of the system: ingesting documents, managing the persistent SQLite database, communicating with external embedding services, and building the final FAISS indexes that are published as snapshots.

WriterState

Holds all runtime state for the writer node.

This class serves as a central registry for shared resources such as database connection pools and embedding clients. It is initialized once during the application startup and made available to all API routes via the app.state object.

Attributes:

Name Type Description
config WriterNodeConfig

The validated configuration for this node.

db WriterSQLitePool

The connection pool to the primary SQLite database.

embedding_clients dict[str, EmbeddingClient]

A mapping of embedding space names to their respective API clients.

build_dir Path

The local directory where new index versions are built before being published.

current_version str

The version ID of the most recently built index.

Source code in src/embedrag/writer/app.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class WriterState:
    """Holds all runtime state for the writer node.

    This class serves as a central registry for shared resources such as
    database connection pools and embedding clients. It is initialized once
    during the application startup and made available to all API routes
    via the `app.state` object.

    Attributes:
        config (WriterNodeConfig): The validated configuration for this node.
        db (WriterSQLitePool): The connection pool to the primary SQLite database.
        embedding_clients (dict[str, EmbeddingClient]): A mapping of embedding
            space names to their respective API clients.
        build_dir (Path): The local directory where new index versions are built
            before being published.
        current_version (str): The version ID of the most recently built index.
    """

    def __init__(self, config: WriterNodeConfig):
        """Initialize the writer state.

        Args:
            config (WriterNodeConfig): The writer node configuration.
        """
        self.config = config
        self.db = WriterSQLitePool(
            db_path=config.db.path,
            wal_autocheckpoint=config.db.wal_autocheckpoint,
            cache_size_mb=config.db.cache_size_mb,
        )
        # One client per configured embedding space.
        self.embedding_clients: dict[str, EmbeddingClient] = {}
        for space in config.embedding.get_all_spaces():
            space_cfg = config.embedding.get_space_config(space)
            self.embedding_clients[space] = EmbeddingClient(space_cfg)
        self.build_dir = Path(config.node.data_dir) / "builds"
        self.build_dir.mkdir(parents=True, exist_ok=True)
        self.current_version: str = ""
        self._last_manifest = None

    def get_embedding_client(self, space: str = "text") -> EmbeddingClient:
        """Retrieve the embedding client for a specific space.

        Args:
            space (str, optional): The identifier of the embedding space.
                Defaults to "text".

        Returns:
            EmbeddingClient: The client configured for the requested space.

        Raises:
            KeyError: If no client is configured for the given space name.
        """
        if space in self.embedding_clients:
            return self.embedding_clients[space]
        available = list(self.embedding_clients.keys())
        raise KeyError(f"No embedding client for space '{space}'. Available: {available}")

    @property
    def last_manifest(self):
        """The manifest from the most recent successful build (None until set)."""
        return self._last_manifest

    @last_manifest.setter
    def last_manifest(self, val):
        self._last_manifest = val

    async def close(self) -> None:
        """Gracefully shut down all embedding clients and the database pool.

        Every embedding client is closed even if an earlier one fails, and the
        database pool is always closed afterwards. If any client raised, the
        first such exception is re-raised once cleanup has finished, so the
        caller still sees the failure.

        This method should be called during the application's shutdown sequence.
        """
        first_error: Exception | None = None
        try:
            for client in self.embedding_clients.values():
                try:
                    await client.close()
                except Exception as exc:
                    # Keep going: one broken client must not leak the others.
                    if first_error is None:
                        first_error = exc
        finally:
            # Also runs on cancellation (BaseException), so the pool never leaks.
            self.db.close()
        if first_error is not None:
            raise first_error

last_manifest property writable

The manifest from the most recent successful build.

__init__(config)

Initialize the writer state.

Parameters:

Name Type Description Default
config WriterNodeConfig

The writer node configuration.

required
Source code in src/embedrag/writer/app.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def __init__(self, config: WriterNodeConfig):
    """Build the writer's shared resources from ``config``.

    Creates the SQLite connection pool, one embedding client per configured
    space, and the ``<data_dir>/builds`` directory. The node starts with an
    empty current version and no recorded manifest.

    Args:
        config (WriterNodeConfig): The writer node configuration.
    """
    self.config = config
    db_settings = config.db
    self.db = WriterSQLitePool(
        db_path=db_settings.path,
        wal_autocheckpoint=db_settings.wal_autocheckpoint,
        cache_size_mb=db_settings.cache_size_mb,
    )
    embedding = config.embedding
    self.embedding_clients: dict[str, EmbeddingClient] = {
        name: EmbeddingClient(embedding.get_space_config(name))
        for name in embedding.get_all_spaces()
    }
    self.build_dir = Path(config.node.data_dir).joinpath("builds")
    self.build_dir.mkdir(parents=True, exist_ok=True)
    self.current_version: str = ""
    self._last_manifest = None

close() async

Gracefully shut down all database connections and network clients.

This method should be called during the application's shutdown sequence.

Source code in src/embedrag/writer/app.py
 93
 94
 95
 96
 97
 98
 99
100
async def close(self) -> None:
    """Gracefully shut down all database connections and network clients.

    This method should be called during the application's shutdown sequence.

    Every embedding client is closed, then the SQLite pool. A failing
    client no longer prevents the remaining clients or the database from
    being closed: the first ``Exception`` raised by a client is re-raised
    only after all cleanup has been attempted. Cancellation
    (``BaseException``) is not intercepted, but the database is still
    closed via ``finally``.

    Raises:
        Exception: The first exception raised by an embedding client's
            ``close()``, if any.
    """
    first_error: Exception | None = None
    try:
        for client in self.embedding_clients.values():
            try:
                await client.close()
            except Exception as exc:  # keep closing the rest; report later
                if first_error is None:
                    first_error = exc
    finally:
        # Always release the DB, even if a client close failed or we were cancelled.
        self.db.close()
    if first_error is not None:
        raise first_error

get_embedding_client(space='text')

Retrieve the embedding client for a specific space.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `space` | `str` | The identifier of the embedding space. Defaults to `"text"`. | `'text'` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `EmbeddingClient` | `EmbeddingClient` | The client configured for the requested space. |

Raises:

| Type | Description |
|------|-------------|
| `KeyError` | If no client is configured for the given space name. |

Source code in src/embedrag/writer/app.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def get_embedding_client(self, space: str = "text") -> EmbeddingClient:
    """Look up the embedding client registered for ``space``.

    Args:
        space (str, optional): Name of the embedding space. Defaults to "text".

    Returns:
        EmbeddingClient: The client registered under that space name.

    Raises:
        KeyError: If ``space`` has no registered client; the message lists
            the space names that are available.
    """
    clients = self.embedding_clients
    if space not in clients:
        raise KeyError(f"No embedding client for space '{space}'. Available: {list(clients)}")
    return clients[space]

create_writer_app(config_path=None)

Factory function to create and configure the Writer FastAPI application.

This function sets up the basic FastAPI app, attaches the lifespan manager, and registers all functional routes.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `config_path` | `str \| None` | An optional file path to a YAML configuration file. | `None` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `FastAPI` | `FastAPI` | The fully configured web application instance. |

Source code in src/embedrag/writer/app.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def create_writer_app(config_path: str | None = None) -> FastAPI:
    """Build the Writer FastAPI application.

    The app is wired to ``writer_lifespan`` for startup/shutdown, exposes a
    Prometheus ``/metrics`` endpoint (hidden from the OpenAPI schema), and
    mounts the writer routes.

    Args:
        config_path (str, optional): Path to a YAML configuration file. It is
            stored on ``app.state`` and read by the lifespan handler.

    Returns:
        FastAPI: The configured application.
    """
    application = FastAPI(title="EmbedRAG Writer", version="0.5.1", lifespan=writer_lifespan)
    # Consumed by writer_lifespan when the app starts.
    application.state.config_path = config_path

    @application.get("/metrics", include_in_schema=False)
    async def metrics() -> PlainTextResponse:
        """Prometheus metrics endpoint."""
        from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

        payload = generate_latest()
        return PlainTextResponse(payload, media_type=CONTENT_TYPE_LATEST)

    # Registered after /metrics so that route keeps matching precedence.
    from embedrag.writer.routes import router as writer_router

    application.include_router(writer_router)
    return application

writer_lifespan(app) async

Manages the lifecycle of the Writer FastAPI application.

This context manager handles the startup and shutdown phases. On startup, it loads the configuration, initializes the WriterState, and sets up structured logging. On shutdown, it ensures that all resources (DB connections, clients) are released properly.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `app` | `FastAPI` | The FastAPI application instance. | *required* |

Yields:

| Name | Type | Description |
|------|------|-------------|
| `None` | `AsyncIterator[None]` | Control is returned to the FastAPI framework to start serving. |

Source code in src/embedrag/writer/app.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@asynccontextmanager
async def writer_lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Manages the lifecycle of the Writer FastAPI application.

    On startup it loads the configuration from ``app.state.config_path``,
    sets up structured logging, creates the `WriterState` and stores it on
    ``app.state.writer``. On shutdown it closes that state (DB connections,
    embedding clients). The shutdown step runs in a ``finally`` block, so
    resources are released even when the lifespan is exited with an exception
    or cancellation.

    Args:
        app (FastAPI): The FastAPI application instance.

    Yields:
        None: Control is returned to the FastAPI framework to start serving.
    """
    config_path = app.state.config_path
    config = load_writer_config(config_path)
    setup_logging(level=config.logging.level, fmt=config.logging.format, node_type="writer")
    state = WriterState(config)
    app.state.writer = state
    logger.info("writer_started", data_dir=config.node.data_dir)
    try:
        yield
    finally:
        # Without try/finally an exception thrown in at `yield` would skip cleanup.
        await state.close()
        logger.info("writer_stopped")

Index Builder

The logic for constructing FAISS indexes from document vectors.

FAISS IVF_PQ index builder: training, sharding, and serialization.

This module provides the logic for constructing sharded FAISS indexes from document embeddings. To achieve high performance and support datasets that exceed single-machine memory, EmbedRAG splits the vector space into multiple independent shards. This builder automatically selects the most appropriate index type (Flat, IVF-Flat, or IVF-PQ) based on the dataset size, and assigns chunks to shards deterministically so that a given chunk ID always lands in the same shard for a fixed shard count.

IndexBuilder

Builds sharded FAISS IVF_PQ indexes from chunk embeddings.

The builder orchestrates the entire index creation pipeline:

1. Distributing chunk IDs and embeddings into shards using a deterministic MD5-modulo hash of each chunk ID (not consistent hashing: changing the shard count reassigns most IDs).
2. Analyzing the dataset size to select an optimal FAISS index factory string.
3. Training IVF centroids and PQ sub-quantizers if necessary.
4. Building and serializing individual shards to disk.
5. Generating the corresponding ID mapping files.

Source code in src/embedrag/writer/index_builder.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
class IndexBuilder:
    """Builds sharded FAISS indexes from chunk embeddings.

    Depending on dataset size and configuration, shards are exact ``Flat``
    indexes, ``IVF{nlist},Flat`` indexes, or ``IVF{nlist},PQ{m}`` indexes
    (see ``_determine_index_type``).

    The builder orchestrates the entire index creation pipeline:

    1. Distributing chunk IDs and embeddings into shards using a deterministic
       MD5-modulo hash of each chunk ID. This is not consistent hashing:
       changing ``num_shards`` reassigns most IDs.
    2. Analyzing the dataset size to select a FAISS index factory string.
    3. Training IVF centroids and PQ sub-quantizers when the index type needs it.
    4. Building and serializing individual shards to disk.
    5. Writing the ID mapping file (global position -> chunk ID).
    """

    def __init__(self, config: IndexBuildConfig, dim: int = 1024):
        """Initialize the IndexBuilder.

        Args:
            config (IndexBuildConfig): Configuration parameters for the build
                process, including shard count and quantization settings.
            dim (int, optional): The dimensionality of the input vectors.
                Defaults to 1024.
        """
        self._config = config
        self._dim = dim

    def build(
        self,
        chunk_ids: list[str],
        embeddings: np.ndarray,
        output_dir: str,
        space: str = "text",
    ) -> tuple[IndexInfo, str]:
        """Build a complete sharded FAISS index for a specific embedding space.

        This is the primary entry point for index construction. It takes a
        flat list of embeddings and their IDs, partitions them, builds the
        physical shard files, and returns the metadata required for the manifest.

        Args:
            chunk_ids (list[str]): A list of globally unique chunk identifiers.
            embeddings (np.ndarray): A 2D float32 numpy array of shape (N, dim)
                containing the vectors to be indexed.
            output_dir (str): The root directory where the built files will
                be stored.
            space (str, optional): The name of the embedding space (e.g., 'text').
                Files will be stored under ``<output_dir>/index/<space>/``.
                Defaults to "text".

        Returns:
            tuple[IndexInfo, str]: A tuple containing:
                - IndexInfo: A dataclass describing the built index (type, shards, etc.).
                - str: The filesystem path to the generated ID mapping file.

        Raises:
            AssertionError: If ``embeddings.shape != (len(chunk_ids), dim)``
                (only when assertions are enabled).
        """
        n_vectors = len(chunk_ids)
        assert embeddings.shape == (
            n_vectors,
            self._dim,
        ), f"Shape mismatch: {embeddings.shape} vs ({n_vectors}, {self._dim})"

        output = Path(output_dir) / "index" / space
        output.mkdir(parents=True, exist_ok=True)

        logger.info("index_build_start", space=space, n_vectors=n_vectors, dim=self._dim)
        t0 = time.monotonic()

        num_shards = self._config.num_shards
        # One index type is chosen from the *total* vector count and applied to
        # every shard; _build_single_shard may still fall back to an exact Flat
        # index for a shard that is too small to train.
        index_type = self._determine_index_type(n_vectors)

        shard_assignments = self._assign_shards(chunk_ids, num_shards)
        shard_entries: list[ShardEntry] = []
        global_id_map: dict[int, str] = {}
        # Global positions are handed out shard by shard in ascending shard
        # order, i.e. they index the shards concatenated in that order.
        global_offset = 0

        for shard_idx in range(num_shards):
            mask = shard_assignments == shard_idx
            shard_ids = [cid for cid, m in zip(chunk_ids, mask) if m]
            shard_vecs = embeddings[mask]

            if len(shard_ids) == 0:
                continue

            index = self._build_single_shard(shard_vecs, index_type, n_vectors)

            shard_file = f"shard_{shard_idx}.faiss"
            shard_path = output / shard_file
            faiss.write_index(index, str(shard_path))
            shard_size = shard_path.stat().st_size

            for local_idx, cid in enumerate(shard_ids):
                global_id_map[global_offset + local_idx] = cid
            global_offset += len(shard_ids)

            shard_entries.append(
                ShardEntry(
                    file=f"index/{space}/{shard_file}",
                    raw_size=shard_size,
                    num_vectors=len(shard_ids),
                )
            )
            logger.info(
                "shard_built",
                space=space,
                shard=shard_idx,
                vectors=len(shard_ids),
                size_mb=round(shard_size / 1024 / 1024, 1),
            )

        str_id_map = {str(k): v for k, v in global_id_map.items()}
        id_map_path = str(output / "id_map.msgpack")
        with open(id_map_path, "wb") as f:
            msgpack.pack(str_id_map, f)

        elapsed = time.monotonic() - t0
        logger.info("index_build_done", space=space, elapsed_s=round(elapsed, 1), shards=len(shard_entries))

        index_info = IndexInfo(
            type=index_type,
            dim=self._dim,
            metric="IP",
            nprobe_default=32,
            num_shards=len(shard_entries),
            total_vectors=n_vectors,
            shards=shard_entries,
        )
        return index_info, id_map_path

    def _determine_index_type(self, n_vectors: int) -> str:
        """Heuristically choose the FAISS index factory string.

        The selection logic is:
            - < 1,000 vectors: 'Flat' (exact search, no training).
            - < 50,000 vectors: 'IVF{n},Flat' (inverted file with exact residuals).
            - >= 50,000 vectors: 'IVF{n},PQ{m}' (product quantization), but only
              when ``dim`` is a positive multiple of ``pq_m``; otherwise
              'IVF{n},Flat'.

        ``n`` (nlist) is roughly sqrt(n_vectors), at least 4 and capped at
        ``config.ivf_nlist``.

        Args:
            n_vectors (int): The total number of vectors in the dataset.

        Returns:
            str: A FAISS factory string compatible with `faiss.index_factory`.
        """
        if n_vectors < 1_000:
            return "Flat"
        nlist = min(self._config.ivf_nlist, max(4, int(n_vectors**0.5)))
        pq_m = self._config.pq_m
        # FAISS PQ splits each vector into pq_m equal sub-vectors, so dim must be
        # divisible by pq_m or index_factory fails. (dim < pq_m is a special case.)
        pq_supported = pq_m > 0 and self._dim % pq_m == 0
        if n_vectors < 50_000 or not pq_supported:
            return f"IVF{nlist},Flat"
        return f"IVF{nlist},PQ{pq_m}"

    def _build_single_shard(self, vectors: np.ndarray, index_type: str, total_vectors: int) -> faiss.Index:
        """Construct, train, and populate a single FAISS index shard.

        If the shard has too few training points for the requested index type
        (fewer than ``nlist`` for IVF, fewer than 256 for PQ), an exact
        ``IndexFlatIP`` is returned instead.

        Args:
            vectors (np.ndarray): The subset of vectors to be added to this shard.
            index_type (str): The FAISS factory string to use.
            total_vectors (int): The total number of vectors across all shards.
                Currently unused by this method; kept for interface compatibility.

        Returns:
            faiss.Index: A fully built (and, if needed, trained) FAISS index object.
        """
        n = vectors.shape[0]
        dim = vectors.shape[1]

        if index_type == "Flat":
            index = faiss.IndexFlatIP(dim)
            index.add(vectors)
            return index

        train_size = min(n, self._config.train_sample_size)
        if train_size < n:
            # Fixed seed keeps the training sample (and thus the index) reproducible.
            rng = np.random.RandomState(42)
            indices = rng.choice(n, train_size, replace=False)
            train_data = vectors[indices]
        else:
            train_data = vectors

        # FAISS k-means refuses to train with fewer points than centroids:
        # IVF needs >= nlist points, and PQ{m} (8-bit codes by default) trains
        # 256 centroids per sub-quantizer, so it needs >= 256 points.
        nlist = int(index_type.split("IVF")[1].split(",")[0])
        min_train = nlist
        if ",PQ" in index_type:
            min_train = max(min_train, 256)
        if train_size < min_train:
            # NOTE(review): this shard becomes exact Flat while the manifest still
            # reports `index_type` for all shards; confirm readers tolerate that.
            index = faiss.IndexFlatIP(dim)
            index.add(vectors)
            return index

        index = faiss.index_factory(dim, index_type, faiss.METRIC_INNER_PRODUCT)
        index.train(train_data)
        index.add(vectors)
        return index

    def _assign_shards(self, chunk_ids: list[str], num_shards: int) -> np.ndarray:
        """Deterministically assign each chunk ID to a shard index.

        Uses ``MD5(chunk_id) % num_shards`` so a given chunk ID always maps to
        the same shard for a fixed ``num_shards``. Changing ``num_shards``
        reassigns most IDs.

        Args:
            chunk_ids (list[str]): The list of IDs to assign.
            num_shards (int): The total number of shards.

        Returns:
            np.ndarray: An int32 array of shard assignments (0 to num_shards-1).
        """
        assignments = np.zeros(len(chunk_ids), dtype=np.int32)
        for i, cid in enumerate(chunk_ids):
            h = int(hashlib.md5(cid.encode()).hexdigest(), 16)
            assignments[i] = h % num_shards
        return assignments

__init__(config, dim=1024)

Initialize the IndexBuilder.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `config` | `IndexBuildConfig` | Configuration parameters for the build process, including shard count and quantization settings. | *required* |
| `dim` | `int` | The dimensionality of the input vectors. Defaults to 1024. | `1024` |
Source code in src/embedrag/writer/index_builder.py
39
40
41
42
43
44
45
46
47
48
49
def __init__(self, config: IndexBuildConfig, dim: int = 1024):
    """Remember the build configuration and the expected vector width.

    Args:
        config (IndexBuildConfig): Build settings such as the shard count and
            quantization parameters.
        dim (int, optional): Number of components in every input vector.
            Defaults to 1024.
    """
    self._config, self._dim = config, dim

build(chunk_ids, embeddings, output_dir, space='text')

Build a complete sharded FAISS index for a specific embedding space.

This is the primary entry point for index construction. It takes a flat list of embeddings and their IDs, partitions them, builds the physical shard files, and returns the metadata required for the manifest.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `chunk_ids` | `list[str]` | A list of globally unique chunk identifiers. | *required* |
| `embeddings` | `ndarray` | A 2D float32 numpy array of shape (N, dim) containing the vectors to be indexed. | *required* |
| `output_dir` | `str` | The root directory where the built files will be stored. | *required* |
| `space` | `str` | The name of the embedding space (e.g., `'text'`). Files will be stored in a sub-folder named after the space. Defaults to `"text"`. | `'text'` |

Returns:

| Type | Description |
|------|-------------|
| `tuple[IndexInfo, str]` | A tuple containing: (1) `IndexInfo`, a dataclass describing the built index (type, shards, etc.); (2) `str`, the filesystem path to the generated ID mapping file. |

Source code in src/embedrag/writer/index_builder.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def build(
    self,
    chunk_ids: list[str],
    embeddings: np.ndarray,
    output_dir: str,
    space: str = "text",
) -> tuple[IndexInfo, str]:
    """Build every shard of the FAISS index for one embedding space.

    Chunks are split into shards by hash, each non-empty shard is built and
    written to ``<output_dir>/index/<space>/shard_<k>.faiss``, and an
    ``id_map.msgpack`` mapping global positions to chunk IDs is written next
    to them.

    Args:
        chunk_ids (list[str]): Globally unique chunk identifiers, one per row
            of ``embeddings``.
        embeddings (np.ndarray): Matrix of shape (N, dim) holding the vectors.
        output_dir (str): Root directory for the generated files.
        space (str, optional): Embedding space name; also the sub-folder under
            ``index/``. Defaults to "text".

    Returns:
        tuple[IndexInfo, str]: The index description for the manifest and the
        path of the written ID map file.
    """
    total = len(chunk_ids)
    assert embeddings.shape == (total, self._dim), f"Shape mismatch: {embeddings.shape} vs ({total}, {self._dim})"

    space_dir = Path(output_dir) / "index" / space
    space_dir.mkdir(parents=True, exist_ok=True)

    logger.info("index_build_start", space=space, n_vectors=total, dim=self._dim)
    started = time.monotonic()

    shard_count = self._config.num_shards
    index_type = self._determine_index_type(total)
    assignments = self._assign_shards(chunk_ids, shard_count)

    entries: list[ShardEntry] = []
    id_map: dict[str, str] = {}  # stringified global position -> chunk id
    next_position = 0

    for shard_no in range(shard_count):
        selected = assignments == shard_no
        members = [cid for cid, keep in zip(chunk_ids, selected) if keep]
        if not members:
            continue  # empty shards produce no file and no manifest entry
        vectors = embeddings[selected]

        index = self._build_single_shard(vectors, index_type, total)
        shard_file = f"shard_{shard_no}.faiss"
        shard_path = space_dir / shard_file
        faiss.write_index(index, str(shard_path))

        # Positions continue from the previous shard's last one.
        id_map.update((str(next_position + offset), cid) for offset, cid in enumerate(members))
        next_position += len(members)

        shard_bytes = shard_path.stat().st_size
        entries.append(
            ShardEntry(
                file=f"index/{space}/{shard_file}",
                raw_size=shard_bytes,
                num_vectors=len(members),
            )
        )
        logger.info(
            "shard_built",
            space=space,
            shard=shard_no,
            vectors=len(members),
            size_mb=round(shard_bytes / 1024 / 1024, 1),
        )

    id_map_path = str(space_dir / "id_map.msgpack")
    with open(id_map_path, "wb") as fh:
        msgpack.pack(id_map, fh)

    duration = time.monotonic() - started
    logger.info("index_build_done", space=space, elapsed_s=round(duration, 1), shards=len(entries))

    info = IndexInfo(
        type=index_type,
        dim=self._dim,
        metric="IP",
        nprobe_default=32,
        num_shards=len(entries),
        total_vectors=total,
        shards=entries,
    )
    return info, id_map_path

Storage

SQLite-based storage for document text and metadata.

SQLite WAL-mode read/write connection pool for the writer node.

This module provides a robust connection pool for the EmbedRAG writer node, supporting concurrent readers and a single serialized writer using SQLite's Write-Ahead Logging (WAL) mode. It handles schema initialization and efficient storage of documents and chunks.

WriterSQLitePool

Read/write split connection pool with WAL mode for the writer node.

EmbedRAG uses a single-writer, multiple-reader pattern. This pool manages a single persistent writer connection protected by an asyncio.Lock, and a queue of read-only connections. WAL mode is used to allow readers to proceed while a write is in progress.

Source code in src/embedrag/writer/storage.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
class WriterSQLitePool:
    """Read/write split connection pool with WAL mode for the writer node.

    EmbedRAG uses a single-writer, multiple-reader pattern. This pool manages
    a single persistent writer connection protected by an `asyncio.Lock`,
    and a queue of read-only connections. WAL mode is used to allow readers
    to proceed while a write is in progress.

    If the body of a ``write_conn()`` / ``write_conn_sync()`` block raises,
    the writer's open transaction is rolled back, so partial work from a
    failed write can never be committed later by an unrelated operation that
    shares the single writer connection.
    """

    # Maximum number of IDs bound into one ``IN (...)`` list. SQLite builds
    # older than 3.32 cap host parameters at 999, so longer ID lists are
    # processed in slices of this size.
    _MAX_SQL_PARAMS = 500

    def __init__(
        self,
        db_path: str,
        max_readers: int = 4,
        wal_autocheckpoint: int = 1000,
        cache_size_mb: int = 64,
    ):
        """Initialize the connection pool.

        Args:
            db_path (str): The file path to the SQLite database. Its parent
                directory is created if it does not exist.
            max_readers (int, optional): The number of read-only connections to maintain
                in the pool. Must be at least 1. Defaults to 4.
            wal_autocheckpoint (int, optional): The WAL autocheckpoint interval in pages.
                Defaults to 1000.
            cache_size_mb (int, optional): The SQLite page cache size in megabytes.
                Defaults to 64.

        Raises:
            ValueError: If ``max_readers`` is less than 1. With no readers the
                queue would be unbounded and empty, so every ``read_conn()``
                call would wait forever.
        """
        if max_readers < 1:
            raise ValueError(f"max_readers must be >= 1, got {max_readers}")
        self._db_path = db_path
        self._cache_size_mb = cache_size_mb
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)

        # The writer is opened first with mode=rwc so the file exists (and is
        # switched to WAL) before the read-only connections open it.
        self._writer = self._create_conn(readonly=False)
        self._writer.execute(f"PRAGMA wal_autocheckpoint={wal_autocheckpoint}")
        self._write_lock = asyncio.Lock()

        self._readers: asyncio.Queue[sqlite3.Connection] = asyncio.Queue(maxsize=max_readers)
        for _ in range(max_readers):
            self._readers.put_nowait(self._create_conn(readonly=True))
        initialize_schema(self._writer)
        logger.info("writer_pool_init", db=db_path, readers=max_readers)

    def _create_conn(self, readonly: bool) -> sqlite3.Connection:
        """Create a new SQLite connection with the standard EmbedRAG pragmas.

        Args:
            readonly: Open with ``mode=ro`` when True, otherwise ``mode=rwc``
                (read/write, create if missing).

        Returns:
            A connection using WAL journaling, ``synchronous=NORMAL``, the
            configured page cache size, a 5 s busy timeout, foreign keys
            enabled and ``sqlite3.Row`` rows.
        """
        mode = "ro" if readonly else "rwc"
        uri = f"file:{self._db_path}?mode={mode}"
        # check_same_thread=False lets the connection be used from threads
        # other than the one that created it.
        conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")
        # A negative cache_size is interpreted by SQLite as KiB, not pages.
        conn.execute(f"PRAGMA cache_size=-{self._cache_size_mb * 1024}")
        conn.execute("PRAGMA busy_timeout=5000")
        conn.execute("PRAGMA foreign_keys=ON")
        conn.row_factory = sqlite3.Row
        return conn

    @asynccontextmanager
    async def read_conn(self) -> AsyncIterator[sqlite3.Connection]:
        """Acquire a read-only connection from the pool.

        Waits until a reader is free and returns it to the pool on exit,
        even if the body raises.

        Yields:
            sqlite3.Connection: A read-only SQLite connection.
        """
        conn = await self._readers.get()
        try:
            yield conn
        finally:
            self._readers.put_nowait(conn)

    @asynccontextmanager
    async def write_conn(self) -> AsyncIterator[sqlite3.Connection]:
        """Acquire the exclusive writer connection.

        The async write lock is held for the duration of the block. If the
        body raises (including cancellation), the open transaction is rolled
        back before the exception propagates.

        Yields:
            sqlite3.Connection: The read-write SQLite connection.
        """
        async with self._write_lock:
            try:
                yield self._writer
            except BaseException:
                # Without this, uncommitted statements from the failed block
                # would be committed by the next caller's commit().
                self._writer.rollback()
                raise

    @contextmanager
    def write_conn_sync(self) -> Iterator[sqlite3.Connection]:
        """Acquire the writer connection in a synchronous context.

        This does NOT take the async write lock, so it should not be used
        while a coroutine is inside ``write_conn()``. If the body raises, the
        open transaction is rolled back.

        Yields:
            sqlite3.Connection: The read-write SQLite connection.
        """
        try:
            yield self._writer
        except BaseException:
            self._writer.rollback()
            raise

    def checkpoint(self) -> None:
        """Manually trigger a SQLite WAL checkpoint (TRUNCATE)."""
        self._writer.execute("PRAGMA wal_checkpoint(TRUNCATE)")

    def close(self) -> None:
        """Close all connections in the pool and trigger a final checkpoint.

        The writer and every reader currently in the queue are closed even if
        the checkpoint fails; the checkpoint error is then re-raised.
        Readers checked out via ``read_conn()`` at this moment are not in the
        queue and therefore are not closed here.
        """
        try:
            self.checkpoint()
        finally:
            self._writer.close()
            while True:
                try:
                    conn = self._readers.get_nowait()
                except asyncio.QueueEmpty:
                    break
                conn.close()

    # ── Document operations ──

    async def insert_document(self, doc: Document) -> None:
        """Insert or replace a single document row and commit.

        Args:
            doc: The document; its ``metadata`` is stored as JSON.
        """
        async with self.write_conn() as conn:
            conn.execute(
                """INSERT OR REPLACE INTO documents
                   (doc_id, title, source, doc_type, metadata_json)
                   VALUES (?, ?, ?, ?, ?)""",
                (doc.doc_id, doc.title, doc.source, doc.doc_type, json.dumps(doc.metadata)),
            )
            conn.commit()

    async def insert_documents_batch(self, docs: list[Document]) -> None:
        """Insert or replace many document rows in one transaction.

        Args:
            docs: The documents; each ``metadata`` is stored as JSON.
        """
        async with self.write_conn() as conn:
            conn.executemany(
                """INSERT OR REPLACE INTO documents
                   (doc_id, title, source, doc_type, metadata_json)
                   VALUES (?, ?, ?, ?, ?)""",
                [(d.doc_id, d.title, d.source, d.doc_type, json.dumps(d.metadata)) for d in docs],
            )
            conn.commit()

    async def get_document(self, doc_id: str) -> Document | None:
        """Fetch a document by ID.

        Args:
            doc_id: The document identifier.

        Returns:
            The ``Document`` with its metadata decoded from JSON, or ``None``
            if no row matches.
        """
        async with self.read_conn() as conn:
            row = conn.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,)).fetchone()
            if not row:
                return None
            return Document(
                doc_id=row["doc_id"],
                title=row["title"],
                source=row["source"],
                doc_type=row["doc_type"],
                metadata=json.loads(row["metadata_json"]),
            )

    # ── Chunk operations ──

    async def insert_chunks_batch(self, chunks: list[ChunkNode], space: str = "text") -> None:
        """Insert or replace chunk rows and their embeddings in one transaction.

        Args:
            chunks: The chunks to store. An embedding row is written only for
                chunks whose ``embedding`` is truthy.
            space: The embedding space name stored alongside each embedding.
        """
        async with self.write_conn() as conn:
            chunk_rows = []
            emb_rows = []
            for c in chunks:
                chunk_rows.append(
                    (
                        c.chunk_id,
                        c.doc_id,
                        c.parent_chunk_id,
                        c.level,
                        c.level_type,
                        c.seq_in_parent,
                        c.text,
                        json.dumps(c.metadata),
                    )
                )
                if c.embedding:
                    emb_rows.append((c.chunk_id, space, _embed_to_blob(c.embedding)))
            conn.executemany(
                """INSERT OR REPLACE INTO chunks
                   (chunk_id, doc_id, parent_chunk_id, level, level_type,
                    seq_in_parent, text, metadata_json)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                chunk_rows,
            )
            if emb_rows:
                conn.executemany(
                    "INSERT OR REPLACE INTO chunk_embeddings (chunk_id, space, embedding) VALUES (?, ?, ?)",
                    emb_rows,
                )
            conn.commit()

    async def cleanup_before_upsert(self, doc_ids: list[str]) -> int:
        """Remove stale FTS, closure, and embedding rows for docs about to be re-ingested.

        Called before insert_chunks_batch on re-ingest so that FTS5
        (which doesn't participate in CASCADE) stays consistent. The
        ``chunks`` rows themselves are left in place.

        Chunk IDs are resolved with sub-queries and ``doc_ids`` is processed
        in slices, so the number of bound parameters per statement stays
        below SQLite's host-parameter limit regardless of chunk count.

        Args:
            doc_ids: IDs of the documents being re-ingested. Duplicates are ignored.

        Returns:
            The number of existing chunks whose dependent rows were removed.
        """
        if not doc_ids:
            return 0
        unique_ids = list(dict.fromkeys(doc_ids))
        stale_total = 0
        async with self.write_conn() as conn:
            for start in range(0, len(unique_ids), self._MAX_SQL_PARAMS):
                batch = unique_ids[start : start + self._MAX_SQL_PARAMS]
                placeholders = ",".join("?" * len(batch))
                stale = conn.execute(
                    f"SELECT count(*) FROM chunks WHERE doc_id IN ({placeholders})",
                    batch,
                ).fetchone()[0]
                if not stale:
                    continue
                stale_total += stale
                owned = f"SELECT chunk_id FROM chunks WHERE doc_id IN ({placeholders})"
                conn.execute(f"DELETE FROM chunks_fts WHERE chunk_id IN ({owned})", batch)
                conn.execute(f"DELETE FROM chunk_closure WHERE descendant_id IN ({owned})", batch)
                conn.execute(f"DELETE FROM chunk_embeddings WHERE chunk_id IN ({owned})", batch)
            conn.commit()
        return stale_total

    async def insert_fts_batch(self, chunks: list[ChunkNode], doc_titles: dict[str, str]) -> None:
        """Insert or replace full-text-search rows for the given chunks.

        Args:
            chunks: The chunks to index. A ``tags`` list in chunk metadata is
                joined with spaces.
            doc_titles: Mapping of ``doc_id`` to title; missing docs get ``""``.
        """
        from embedrag.text.normalize import normalize_for_fts

        async with self.write_conn() as conn:
            rows = []
            for c in chunks:
                title = doc_titles.get(c.doc_id, "")
                tags = c.metadata.get("tags", "")
                if isinstance(tags, list):
                    tags = " ".join(tags)
                rows.append(
                    (
                        c.chunk_id,
                        c.text,
                        normalize_for_fts(c.text),
                        title,
                        normalize_for_fts(title),
                        tags,
                    )
                )
            conn.executemany(
                "INSERT OR REPLACE INTO chunks_fts "
                "(chunk_id, text, text_norm, title, title_norm, tags) "
                "VALUES (?, ?, ?, ?, ?, ?)",
                rows,
            )
            conn.commit()

    async def insert_closure_batch(self, relations: list[tuple[str, str, int]]) -> None:
        """Insert closure table entries: (ancestor_id, descendant_id, depth).

        Existing ``(ancestor_id, descendant_id)`` pairs are left unchanged
        (``INSERT OR IGNORE``).

        Args:
            relations: A list of ``(ancestor_id, descendant_id, depth)`` tuples
                as produced by ``build_closure_entries()``.
        """
        async with self.write_conn() as conn:
            sql = "INSERT OR IGNORE INTO chunk_closure (ancestor_id, descendant_id, depth) VALUES (?, ?, ?)"
            conn.executemany(sql, relations)
            conn.commit()

    async def get_all_chunks_with_embeddings(self, space: str = "text") -> list[tuple[str, np.ndarray]]:
        """Read all ``(chunk_id, embedding)`` pairs for a given embedding space.

        Args:
            space: The embedding space name (default ``"text"``).

        Returns:
            A list of ``(chunk_id, float32_array)`` tuples.
        """
        async with self.read_conn() as conn:
            rows = conn.execute(
                "SELECT chunk_id, embedding FROM chunk_embeddings WHERE space = ?",
                (space,),
            ).fetchall()
            return [(r["chunk_id"], _blob_to_embed(r["embedding"])) for r in rows]

    async def get_embedding_spaces(self) -> list[str]:
        """Return all distinct embedding space names in the database.

        Returns:
            An alphabetically sorted list of space names.
        """
        async with self.read_conn() as conn:
            rows = conn.execute("SELECT DISTINCT space FROM chunk_embeddings ORDER BY space").fetchall()
            return [r[0] for r in rows]

    async def get_chunk_count(self) -> int:
        """Return the total number of chunk rows in the database."""
        async with self.read_conn() as conn:
            row = conn.execute("SELECT count(*) FROM chunks").fetchone()
            return row[0]

    async def get_doc_count(self) -> int:
        """Return the total number of document rows in the database."""
        async with self.read_conn() as conn:
            row = conn.execute("SELECT count(*) FROM documents").fetchone()
            return row[0]

    async def get_per_space_vector_counts(self) -> dict[str, int]:
        """Return a mapping of embedding space to vector count.

        Returns:
            A dict like ``{"text": 1234, "image": 567}``.
        """
        async with self.read_conn() as conn:
            rows = conn.execute(
                "SELECT space, count(*) AS cnt FROM chunk_embeddings GROUP BY space ORDER BY space"
            ).fetchall()
            return {r["space"]: r["cnt"] for r in rows}

    async def list_documents(
        self,
        limit: int = 50,
        offset: int = 0,
        doc_type: str | None = None,
        source: str | None = None,
    ) -> tuple[list[dict], int]:
        """Return a paginated document list and total count matching optional filters.

        Args:
            limit: Maximum number of documents per page.
            offset: Number of documents to skip.
            doc_type: Optional document type filter.
            source: Optional document source filter.

        Returns:
            A tuple of ``(document_dicts, total_count)``; documents are
            ordered by ``doc_id`` and ``total_count`` ignores pagination.
        """
        async with self.read_conn() as conn:
            where_parts: list[str] = []
            params: list = []
            if doc_type:
                where_parts.append("doc_type = ?")
                params.append(doc_type)
            if source:
                where_parts.append("source = ?")
                params.append(source)
            # Only fixed column predicates are interpolated; values stay bound.
            where_sql = (" WHERE " + " AND ".join(where_parts)) if where_parts else ""

            total = conn.execute(f"SELECT count(*) FROM documents{where_sql}", params).fetchone()[0]

            rows = conn.execute(
                f"SELECT doc_id, title, source, doc_type, created_at "
                f"FROM documents{where_sql} ORDER BY doc_id LIMIT ? OFFSET ?",
                params + [limit, offset],
            ).fetchall()

            docs = [
                {
                    "doc_id": r["doc_id"],
                    "title": r["title"],
                    "source": r["source"],
                    "doc_type": r["doc_type"],
                    "created_at": r["created_at"],
                }
                for r in rows
            ]
            return docs, total

    async def get_chunk_ids_for_doc(self, doc_id: str) -> list[str]:
        """Return all chunk IDs belonging to a document, ordered by level, then seq_in_parent.

        Args:
            doc_id: The document identifier.

        Returns:
            An ordered list of chunk IDs.
        """
        async with self.read_conn() as conn:
            rows = conn.execute(
                "SELECT chunk_id FROM chunks WHERE doc_id = ? ORDER BY level, seq_in_parent",
                (doc_id,),
            ).fetchall()
            return [r["chunk_id"] for r in rows]

    async def delete_documents_bulk(self, doc_ids: list[str]) -> tuple[int, int]:
        """Delete multiple documents and their associated chunks.

        Each document is deleted (and committed) separately via
        ``delete_document``, so the operation is not atomic across documents.

        Args:
            doc_ids: The document identifiers to delete.

        Returns:
            A tuple of ``(docs_deleted, chunks_deleted)``. ``docs_deleted`` is
            the number of IDs requested, not the number of rows that existed.
        """
        if not doc_ids:
            return 0, 0
        total_chunks = 0
        for doc_id in doc_ids:
            total_chunks += await self.delete_document(doc_id)
        return len(doc_ids), total_chunks

    async def get_doc_ids_by_type(self, doc_type: str) -> list[str]:
        """Return all document IDs matching a given document type.

        Args:
            doc_type: The document type to filter by.

        Returns:
            A list of matching document IDs.
        """
        async with self.read_conn() as conn:
            rows = conn.execute("SELECT doc_id FROM documents WHERE doc_type = ?", (doc_type,)).fetchall()
            return [r["doc_id"] for r in rows]

    def get_db_size_bytes(self) -> int:
        """Return the on-disk size of the database file in bytes.

        Only the file at ``db_path`` is measured; WAL/SHM side files are not
        included.
        """
        return Path(self._db_path).stat().st_size

    async def delete_document(self, doc_id: str) -> int:
        """Delete a single document and all its associated chunks.

        Removes the chunks' closure, FTS, and embedding rows explicitly
        (before the chunk rows, which the sub-queries depend on), then the
        chunks and the document, in one transaction. Chunk IDs are selected
        via sub-queries, so no per-chunk parameters are bound.

        Args:
            doc_id: The document identifier to delete.

        Returns:
            The number of chunk rows deleted.
        """
        owned = "SELECT chunk_id FROM chunks WHERE doc_id = ?"
        async with self.write_conn() as conn:
            chunk_count = conn.execute(
                "SELECT count(*) FROM chunks WHERE doc_id = ?", (doc_id,)
            ).fetchone()[0]
            if chunk_count:
                conn.execute(f"DELETE FROM chunk_closure WHERE descendant_id IN ({owned})", (doc_id,))
                conn.execute(f"DELETE FROM chunks_fts WHERE chunk_id IN ({owned})", (doc_id,))
                conn.execute(f"DELETE FROM chunk_embeddings WHERE chunk_id IN ({owned})", (doc_id,))
            conn.execute("DELETE FROM chunks WHERE doc_id = ?", (doc_id,))
            conn.execute("DELETE FROM documents WHERE doc_id = ?", (doc_id,))
            conn.commit()
        return chunk_count

    def export_query_db(self, output_path: str) -> tuple[int, int]:
        """Export a lean read-only SQLite database for query nodes.

        The exported database excludes the embedding column and includes only
        the tables required for serving queries (documents, chunks, closure,
        FTS, schema version). Rows are read directly from the writer
        connection without taking the async write lock. The destination
        connection is always closed, even if the export fails.

        Args:
            output_path: Filesystem path for the exported database.

        Returns:
            A tuple of ``(doc_count, chunk_count)`` in the exported DB.
        """
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        dst = sqlite3.connect(output_path)
        try:
            # Rollback journal + FULL sync: the exported file is self-contained
            # (no -wal side file) and durable once written.
            dst.execute("PRAGMA journal_mode=DELETE")
            dst.execute("PRAGMA synchronous=FULL")

            dst.executescript("""
                CREATE TABLE IF NOT EXISTS documents (
                    doc_id TEXT PRIMARY KEY,
                    title TEXT NOT NULL DEFAULT '',
                    source TEXT NOT NULL DEFAULT '',
                    doc_type TEXT NOT NULL DEFAULT '',
                    metadata_json TEXT NOT NULL DEFAULT '{}',
                    created_at TEXT NOT NULL DEFAULT '',
                    updated_at TEXT NOT NULL DEFAULT ''
                );
                CREATE TABLE IF NOT EXISTS chunks (
                    chunk_id TEXT PRIMARY KEY,
                    doc_id TEXT NOT NULL,
                    parent_chunk_id TEXT,
                    level INTEGER NOT NULL DEFAULT 0,
                    level_type TEXT NOT NULL DEFAULT 'chunk',
                    seq_in_parent INTEGER NOT NULL DEFAULT 0,
                    text TEXT NOT NULL,
                    metadata_json TEXT NOT NULL DEFAULT '{}',
                    created_at TEXT NOT NULL DEFAULT ''
                );
                CREATE INDEX IF NOT EXISTS idx_chunks_doc ON chunks(doc_id);
                CREATE INDEX IF NOT EXISTS idx_chunks_parent ON chunks(parent_chunk_id);
                CREATE INDEX IF NOT EXISTS idx_chunks_level ON chunks(level);
                CREATE TABLE IF NOT EXISTS chunk_closure (
                    ancestor_id TEXT NOT NULL,
                    descendant_id TEXT NOT NULL,
                    depth INTEGER NOT NULL,
                    PRIMARY KEY (ancestor_id, descendant_id)
                );
                CREATE INDEX IF NOT EXISTS idx_closure_desc ON chunk_closure(descendant_id, depth);
                CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
                    chunk_id UNINDEXED, text, text_norm, title, title_norm, tags,
                    tokenize='trigram case_sensitive 0'
                );
                CREATE TABLE IF NOT EXISTS schema_version (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL DEFAULT (datetime('now')),
                    description TEXT NOT NULL DEFAULT ''
                );
            """)
            dst.execute(
                "INSERT INTO schema_version (version, description) VALUES (?, ?)",
                (3, "exported by writer"),
            )

            src = self._writer
            # NOTE(review): relies on the writer's documents table having
            # exactly these 7 columns in this order — confirm against schema.
            dst.executemany(
                "INSERT INTO documents VALUES (?,?,?,?,?,?,?)",
                (tuple(row) for row in src.execute("SELECT * FROM documents")),
            )
            dst.executemany(
                "INSERT INTO chunks VALUES (?,?,?,?,?,?,?,?,?)",
                (
                    tuple(row)
                    for row in src.execute(
                        "SELECT chunk_id, doc_id, parent_chunk_id, level, level_type, "
                        "seq_in_parent, text, metadata_json, created_at FROM chunks"
                    )
                ),
            )
            dst.executemany(
                "INSERT INTO chunk_closure VALUES (?,?,?)",
                (tuple(row) for row in src.execute("SELECT * FROM chunk_closure")),
            )
            dst.executemany(
                "INSERT INTO chunks_fts VALUES (?,?,?,?,?,?)",
                (
                    tuple(row)
                    for row in src.execute(
                        "SELECT chunk_id, text, text_norm, title, title_norm, tags FROM chunks_fts"
                    )
                ),
            )

            doc_count = dst.execute("SELECT count(*) FROM documents").fetchone()[0]
            chunk_count = dst.execute("SELECT count(*) FROM chunks").fetchone()[0]

            dst.commit()
            # VACUUM must run outside a transaction, hence after commit().
            dst.execute("VACUUM")
        finally:
            dst.close()
        return doc_count, chunk_count

__init__(db_path, max_readers=4, wal_autocheckpoint=1000, cache_size_mb=64)

Initialize the connection pool.

Parameters:

Name Type Description Default
db_path str

The file path to the SQLite database.

required
max_readers int

The number of read-only connections to maintain in the pool. Defaults to 4.

4
wal_autocheckpoint int

The WAL autocheckpoint interval in pages. Defaults to 1000.

1000
cache_size_mb int

The SQLite page cache size in megabytes. Defaults to 64.

64
Source code in src/embedrag/writer/storage.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def __init__(
    self,
    db_path: str,
    max_readers: int = 4,
    wal_autocheckpoint: int = 1000,
    cache_size_mb: int = 64,
):
    """Initialize the connection pool.

    Args:
        db_path (str): The file path to the SQLite database. Its parent
            directory is created if it does not exist.
        max_readers (int, optional): The number of read-only connections to maintain
            in the pool. Must be at least 1. Defaults to 4.
        wal_autocheckpoint (int, optional): The WAL autocheckpoint interval in pages.
            Defaults to 1000.
        cache_size_mb (int, optional): The SQLite page cache size in megabytes.
            Defaults to 64.

    Raises:
        ValueError: If ``max_readers`` is less than 1. With no readers the
            queue would be unbounded and empty, so every ``read_conn()`` call
            would wait forever.
    """
    if max_readers < 1:
        raise ValueError(f"max_readers must be >= 1, got {max_readers}")
    self._db_path = db_path
    self._cache_size_mb = cache_size_mb
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    # The writer is opened first (mode=rwc) so the file exists before the
    # read-only connections open it.
    self._writer = self._create_conn(readonly=False)
    self._writer.execute(f"PRAGMA wal_autocheckpoint={wal_autocheckpoint}")
    self._write_lock = asyncio.Lock()

    self._readers: asyncio.Queue[sqlite3.Connection] = asyncio.Queue(maxsize=max_readers)
    for _ in range(max_readers):
        self._readers.put_nowait(self._create_conn(readonly=True))
    initialize_schema(self._writer)
    logger.info("writer_pool_init", db=db_path, readers=max_readers)

checkpoint()

Manually trigger a SQLite WAL checkpoint (TRUNCATE).

Source code in src/embedrag/writer/storage.py
138
139
140
def checkpoint(self) -> None:
    """Force a TRUNCATE-mode WAL checkpoint on the writer connection."""
    writer = self._writer
    writer.execute("PRAGMA wal_checkpoint(TRUNCATE)")

cleanup_before_upsert(doc_ids) async

Remove stale FTS, closure, and embedding rows for docs about to be re-ingested.

Called before insert_chunks_batch on re-ingest so that FTS5 (which doesn't participate in CASCADE) stays consistent. Returns the number of existing chunks whose dependent rows were removed; the chunk rows themselves are left in place.

Source code in src/embedrag/writer/storage.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
async def cleanup_before_upsert(self, doc_ids: list[str]) -> int:
    """Remove stale FTS, closure, and embedding rows for docs about to be re-ingested.

    Called before insert_chunks_batch on re-ingest so that FTS5
    (which doesn't participate in CASCADE) stays consistent. The ``chunks``
    rows themselves are left in place.

    Chunk IDs are resolved with sub-queries and ``doc_ids`` is processed in
    slices, so the number of bound parameters per statement stays below
    SQLite's host-parameter limit regardless of how many chunks are affected.

    Args:
        doc_ids: IDs of the documents being re-ingested. Duplicates are ignored.

    Returns:
        The number of existing chunks whose dependent rows were removed.
    """
    if not doc_ids:
        return 0
    # Older SQLite builds allow at most 999 bound parameters per statement.
    batch_size = 500
    unique_ids = list(dict.fromkeys(doc_ids))
    stale_total = 0
    async with self.write_conn() as conn:
        for start in range(0, len(unique_ids), batch_size):
            batch = unique_ids[start : start + batch_size]
            placeholders = ",".join("?" * len(batch))
            stale = conn.execute(
                f"SELECT count(*) FROM chunks WHERE doc_id IN ({placeholders})",
                batch,
            ).fetchone()[0]
            if not stale:
                continue
            stale_total += stale
            owned = f"SELECT chunk_id FROM chunks WHERE doc_id IN ({placeholders})"
            conn.execute(f"DELETE FROM chunks_fts WHERE chunk_id IN ({owned})", batch)
            conn.execute(f"DELETE FROM chunk_closure WHERE descendant_id IN ({owned})", batch)
            conn.execute(f"DELETE FROM chunk_embeddings WHERE chunk_id IN ({owned})", batch)
        conn.commit()
    return stale_total

close()

Close all connections in the pool and trigger a final checkpoint.

Source code in src/embedrag/writer/storage.py
142
143
144
145
146
147
148
149
150
151
def close(self) -> None:
    """Close all connections in the pool and trigger a final checkpoint.

    The writer and every reader currently in the queue are closed even if the
    checkpoint fails; the checkpoint error is then re-raised. Readers that are
    checked out at this moment are not in the queue and are not closed here.
    """
    try:
        self.checkpoint()
    finally:
        self._writer.close()
        while True:
            try:
                conn = self._readers.get_nowait()
            except asyncio.QueueEmpty:
                break
            conn.close()

delete_document(doc_id) async

Delete a single document and all its associated chunks.

Cascades through closure, FTS, and embedding tables.

Parameters:

Name Type Description Default
doc_id str

The document identifier to delete.

required

Returns:

Type Description
int

The number of chunk rows deleted.

Source code in src/embedrag/writer/storage.py
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
async def delete_document(self, doc_id: str) -> int:
    """Delete a single document and all its associated chunks.

    Removes the chunks' closure, FTS, and embedding rows explicitly (before
    the chunk rows, which the sub-queries depend on), then the chunks and the
    document, in one transaction. Chunk IDs are selected via sub-queries, so
    no per-chunk parameters are bound.

    Args:
        doc_id: The document identifier to delete.

    Returns:
        The number of chunk rows deleted.
    """
    owned = "SELECT chunk_id FROM chunks WHERE doc_id = ?"
    async with self.write_conn() as conn:
        chunk_count = conn.execute(
            "SELECT count(*) FROM chunks WHERE doc_id = ?", (doc_id,)
        ).fetchone()[0]
        if chunk_count:
            conn.execute(f"DELETE FROM chunk_closure WHERE descendant_id IN ({owned})", (doc_id,))
            conn.execute(f"DELETE FROM chunks_fts WHERE chunk_id IN ({owned})", (doc_id,))
            conn.execute(f"DELETE FROM chunk_embeddings WHERE chunk_id IN ({owned})", (doc_id,))
        conn.execute("DELETE FROM chunks WHERE doc_id = ?", (doc_id,))
        conn.execute("DELETE FROM documents WHERE doc_id = ?", (doc_id,))
        conn.commit()
    return chunk_count

delete_documents_bulk(doc_ids) async

Delete multiple documents and their associated chunks.

Parameters:

Name Type Description Default
doc_ids list[str]

The document identifiers to delete.

required

Returns:

Type Description
tuple[int, int]

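A tuple of (docs_deleted, chunks_deleted). docs_deleted is the number of IDs requested, not the number of document rows that actually existed.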

Source code in src/embedrag/writer/storage.py
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
async def delete_documents_bulk(self, doc_ids: list[str]) -> tuple[int, int]:
    """Delete several documents, one ``delete_document`` call per ID, in order.

    Args:
        doc_ids: The document identifiers to delete.

    Returns:
        ``(docs_deleted, chunks_deleted)`` where ``docs_deleted`` is simply the
        number of IDs requested and ``chunks_deleted`` sums the per-document
        chunk counts.
    """
    if not doc_ids:
        return 0, 0
    per_doc_chunks = [await self.delete_document(identifier) for identifier in doc_ids]
    return len(doc_ids), sum(per_doc_chunks)

export_query_db(output_path)

Export a lean read-only SQLite database for query nodes.

The exported database excludes the embedding column and includes only the tables required for serving queries (documents, chunks, closure, FTS, schema version).

Parameters:

Name Type Description Default
output_path str

Filesystem path for the exported database.

required

Returns:

Type Description
tuple[int, int]

A tuple of (doc_count, chunk_count) in the exported DB.

Source code in src/embedrag/writer/storage.py
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
def export_query_db(self, output_path: str) -> tuple[int, int]:
    """Export a lean read-only SQLite database for query nodes.

    The exported database excludes the embedding column and includes only
    the tables required for serving queries (documents, chunks, closure,
    FTS, schema version). Rows are read directly from the writer connection.
    The destination connection is always closed, even if the export fails.

    Args:
        output_path: Filesystem path for the exported database.

    Returns:
        A tuple of ``(doc_count, chunk_count)`` in the exported DB.
    """
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    dst = sqlite3.connect(output_path)
    try:
        # Rollback journal + FULL sync: the exported file is self-contained
        # (no -wal side file) and durable once written.
        dst.execute("PRAGMA journal_mode=DELETE")
        dst.execute("PRAGMA synchronous=FULL")

        dst.executescript("""
            CREATE TABLE IF NOT EXISTS documents (
                doc_id TEXT PRIMARY KEY,
                title TEXT NOT NULL DEFAULT '',
                source TEXT NOT NULL DEFAULT '',
                doc_type TEXT NOT NULL DEFAULT '',
                metadata_json TEXT NOT NULL DEFAULT '{}',
                created_at TEXT NOT NULL DEFAULT '',
                updated_at TEXT NOT NULL DEFAULT ''
            );
            CREATE TABLE IF NOT EXISTS chunks (
                chunk_id TEXT PRIMARY KEY,
                doc_id TEXT NOT NULL,
                parent_chunk_id TEXT,
                level INTEGER NOT NULL DEFAULT 0,
                level_type TEXT NOT NULL DEFAULT 'chunk',
                seq_in_parent INTEGER NOT NULL DEFAULT 0,
                text TEXT NOT NULL,
                metadata_json TEXT NOT NULL DEFAULT '{}',
                created_at TEXT NOT NULL DEFAULT ''
            );
            CREATE INDEX IF NOT EXISTS idx_chunks_doc ON chunks(doc_id);
            CREATE INDEX IF NOT EXISTS idx_chunks_parent ON chunks(parent_chunk_id);
            CREATE INDEX IF NOT EXISTS idx_chunks_level ON chunks(level);
            CREATE TABLE IF NOT EXISTS chunk_closure (
                ancestor_id TEXT NOT NULL,
                descendant_id TEXT NOT NULL,
                depth INTEGER NOT NULL,
                PRIMARY KEY (ancestor_id, descendant_id)
            );
            CREATE INDEX IF NOT EXISTS idx_closure_desc ON chunk_closure(descendant_id, depth);
            CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
                chunk_id UNINDEXED, text, text_norm, title, title_norm, tags,
                tokenize='trigram case_sensitive 0'
            );
            CREATE TABLE IF NOT EXISTS schema_version (
                version INTEGER PRIMARY KEY,
                applied_at TEXT NOT NULL DEFAULT (datetime('now')),
                description TEXT NOT NULL DEFAULT ''
            );
        """)
        dst.execute(
            "INSERT INTO schema_version (version, description) VALUES (?, ?)",
            (3, "exported by writer"),
        )

        src = self._writer
        # NOTE(review): relies on the writer's documents table having exactly
        # these 7 columns in this order — confirm against the writer schema.
        dst.executemany(
            "INSERT INTO documents VALUES (?,?,?,?,?,?,?)",
            (tuple(row) for row in src.execute("SELECT * FROM documents")),
        )
        dst.executemany(
            "INSERT INTO chunks VALUES (?,?,?,?,?,?,?,?,?)",
            (
                tuple(row)
                for row in src.execute(
                    "SELECT chunk_id, doc_id, parent_chunk_id, level, level_type, "
                    "seq_in_parent, text, metadata_json, created_at FROM chunks"
                )
            ),
        )
        dst.executemany(
            "INSERT INTO chunk_closure VALUES (?,?,?)",
            (tuple(row) for row in src.execute("SELECT * FROM chunk_closure")),
        )
        dst.executemany(
            "INSERT INTO chunks_fts VALUES (?,?,?,?,?,?)",
            (
                tuple(row)
                for row in src.execute(
                    "SELECT chunk_id, text, text_norm, title, title_norm, tags FROM chunks_fts"
                )
            ),
        )

        doc_count = dst.execute("SELECT count(*) FROM documents").fetchone()[0]
        chunk_count = dst.execute("SELECT count(*) FROM chunks").fetchone()[0]

        dst.commit()
        # VACUUM cannot run inside a transaction, hence after commit().
        dst.execute("VACUUM")
    finally:
        dst.close()
    return doc_count, chunk_count

get_all_chunks_with_embeddings(space='text') async

Read all (chunk_id, embedding) pairs for a given embedding space.

Parameters:

Name Type Description Default
space str

The embedding space name (default "text").

'text'

Returns:

Type Description
list[tuple[str, ndarray]]

A list of (chunk_id, float32_array) tuples.

Source code in src/embedrag/writer/storage.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
async def get_all_chunks_with_embeddings(self, space: str = "text") -> list[tuple[str, np.ndarray]]:
    """Load every stored embedding that belongs to one embedding space.

    Args:
        space: Name of the embedding space to read (``"text"`` by default).

    Returns:
        One ``(chunk_id, vector)`` tuple per ``chunk_embeddings`` row whose
        ``space`` column equals *space*; each vector is decoded from its
        stored blob with ``_blob_to_embed``.
    """
    query = "SELECT chunk_id, embedding FROM chunk_embeddings WHERE space = ?"
    async with self.read_conn() as conn:
        records = conn.execute(query, (space,)).fetchall()
        pairs: list[tuple[str, np.ndarray]] = []
        for record in records:
            pairs.append((record["chunk_id"], _blob_to_embed(record["embedding"])))
        return pairs

get_chunk_count() async

Return the total number of chunk rows in the database.

Source code in src/embedrag/writer/storage.py
325
326
327
328
329
async def get_chunk_count(self) -> int:
    """Count the rows currently stored in the ``chunks`` table.

    Returns:
        The result of ``SELECT count(*) FROM chunks``.
    """
    async with self.read_conn() as conn:
        (total,) = conn.execute("SELECT count(*) FROM chunks").fetchone()
        return total

get_chunk_ids_for_doc(doc_id) async

Return all chunk IDs belonging to a document, ordered by level and then by seq_in_parent within each level.

Parameters:

Name Type Description Default
doc_id str

The document identifier.

required

Returns:

Type Description
list[str]

An ordered list of chunk IDs.

Source code in src/embedrag/writer/storage.py
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
async def get_chunk_ids_for_doc(self, doc_id: str) -> list[str]:
    """Return all chunk IDs belonging to a document, sorted by level then position.

    Rows are ordered by ``level`` first and by ``seq_in_parent`` within each
    level, so chunks with a lower ``level`` value always come first.

    Args:
        doc_id: The document identifier.

    Returns:
        The matching chunk IDs in ``(level, seq_in_parent)`` order; an empty
        list when the document has no chunks.
    """
    sql = "SELECT chunk_id FROM chunks WHERE doc_id = ? ORDER BY level, seq_in_parent"
    async with self.read_conn() as conn:
        rows = conn.execute(sql, (doc_id,)).fetchall()
        return [r["chunk_id"] for r in rows]

get_db_size_bytes()

Return the on-disk size of the database file in bytes.

Source code in src/embedrag/writer/storage.py
443
444
445
def get_db_size_bytes(self) -> int:
    """Report how large the database file at ``self._db_path`` is on disk.

    Returns:
        The file's ``st_size`` in bytes. Only that single path is measured;
        any other files stored alongside it are not counted.
    """
    db_file = Path(self._db_path)
    file_stats = db_file.stat()
    return file_stats.st_size

get_doc_count() async

Return the total number of document rows in the database.

Source code in src/embedrag/writer/storage.py
331
332
333
334
335
async def get_doc_count(self) -> int:
    """Count the rows currently stored in the ``documents`` table.

    Returns:
        The result of ``SELECT count(*) FROM documents``.
    """
    sql = "SELECT count(*) FROM documents"
    async with self.read_conn() as conn:
        first_row = conn.execute(sql).fetchone()
    return first_row[0]

get_doc_ids_by_type(doc_type) async

Return all document IDs matching a given document type.

Parameters:

Name Type Description Default
doc_type str

The document type to filter by.

required

Returns:

Type Description
list[str]

A list of matching document IDs.

Source code in src/embedrag/writer/storage.py
430
431
432
433
434
435
436
437
438
439
440
441
async def get_doc_ids_by_type(self, doc_type: str) -> list[str]:
    """List the IDs of every document whose ``doc_type`` equals *doc_type*.

    Args:
        doc_type: Exact document type value to match.

    Returns:
        The matching ``doc_id`` values (no explicit ordering is requested).
    """
    query = "SELECT doc_id FROM documents WHERE doc_type = ?"
    async with self.read_conn() as conn:
        matches = conn.execute(query, (doc_type,)).fetchall()
        doc_ids = [match["doc_id"] for match in matches]
        return doc_ids

get_embedding_spaces() async

Return all distinct embedding space names in the database.

Returns:

Type Description
list[str]

An alphabetically sorted list of space names.

Source code in src/embedrag/writer/storage.py
315
316
317
318
319
320
321
322
323
async def get_embedding_spaces(self) -> list[str]:
    """List each embedding space that has at least one stored vector.

    Returns:
        Distinct ``space`` values from ``chunk_embeddings`` in ascending order.
    """
    async with self.read_conn() as conn:
        cursor = conn.execute("SELECT DISTINCT space FROM chunk_embeddings ORDER BY space")
        spaces = [record[0] for record in cursor.fetchall()]
    return spaces

get_per_space_vector_counts() async

Return a mapping of embedding space to vector count.

Returns:

Type Description
dict[str, int]

A dict like {"text": 1234, "image": 567}.

Source code in src/embedrag/writer/storage.py
337
338
339
340
341
342
343
344
345
346
347
async def get_per_space_vector_counts(self) -> dict[str, int]:
    """Count stored vectors per embedding space.

    Returns:
        Mapping of space name to its number of ``chunk_embeddings`` rows,
        with keys inserted in ascending space order, e.g.
        ``{"image": 567, "text": 1234}``.
    """
    sql = "SELECT space, count(*) AS cnt FROM chunk_embeddings GROUP BY space ORDER BY space"
    async with self.read_conn() as conn:
        counts: dict[str, int] = {}
        for record in conn.execute(sql).fetchall():
            counts[record["space"]] = record["cnt"]
        return counts

insert_closure_batch(relations) async

Insert closure table entries: (ancestor_id, descendant_id, depth).

Parameters:

Name Type Description Default
relations list[tuple[str, str, int]]

A list of (ancestor_id, descendant_id, depth) tuples as produced by build_closure_entries().

required
Source code in src/embedrag/writer/storage.py
287
288
289
290
291
292
293
294
295
296
297
async def insert_closure_batch(self, relations: list[tuple[str, str, int]]) -> None:
    """Insert closure table entries: (ancestor_id, descendant_id, depth).

    Rows that violate a constraint on ``chunk_closure`` are silently skipped
    (``INSERT OR IGNORE``). The batch is committed only if every row is
    processed; on any error the transaction is rolled back, so no partial
    batch is left pending on the shared writer connection.

    Args:
        relations: A list of ``(ancestor_id, descendant_id, depth)`` tuples
            as produced by ``build_closure_entries()``.

    Raises:
        sqlite3.Error: Propagated from ``executemany`` after the rollback.
    """
    sql = "INSERT OR IGNORE INTO chunk_closure (ancestor_id, descendant_id, depth) VALUES (?, ?, ?)"
    async with self.write_conn() as conn:
        # The connection context manager commits on success and rolls back on
        # error. Without it, a failure mid-batch would leave the already
        # inserted rows in an open transaction that the next writer's
        # commit() would persist by accident.
        with conn:
            conn.executemany(sql, relations)

list_documents(limit=50, offset=0, doc_type=None, source=None) async

Return a paginated document list and total count matching optional filters.

Parameters:

Name Type Description Default
limit int

Maximum number of documents per page.

50
offset int

Number of documents to skip.

0
doc_type str | None

Optional document type filter.

None
source str | None

Optional document source filter.

None

Returns:

Type Description
tuple[list[dict], int]

A tuple of (document_dicts, total_count).

Source code in src/embedrag/writer/storage.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
async def list_documents(
    self,
    limit: int = 50,
    offset: int = 0,
    doc_type: str | None = None,
    source: str | None = None,
) -> tuple[list[dict], int]:
    """Return a paginated document list and total count matching optional filters.

    Args:
        limit: Maximum number of documents per page.
        offset: Number of documents to skip.
        doc_type: Optional document type filter.
        source: Optional document source filter.

    Returns:
        A tuple of ``(document_dicts, total_count)``.
    """
    async with self.read_conn() as conn:
        where_parts: list[str] = []
        params: list = []
        if doc_type:
            where_parts.append("doc_type = ?")
            params.append(doc_type)
        if source:
            where_parts.append("source = ?")
            params.append(source)
        where_sql = (" WHERE " + " AND ".join(where_parts)) if where_parts else ""

        total = conn.execute(f"SELECT count(*) FROM documents{where_sql}", params).fetchone()[0]

        rows = conn.execute(
            f"SELECT doc_id, title, source, doc_type, created_at "
            f"FROM documents{where_sql} ORDER BY doc_id LIMIT ? OFFSET ?",
            params + [limit, offset],
        ).fetchall()

        docs = [
            {
                "doc_id": r["doc_id"],
                "title": r["title"],
                "source": r["source"],
                "doc_type": r["doc_type"],
                "created_at": r["created_at"],
            }
            for r in rows
        ]
        return docs, total

read_conn() async

Acquire a read-only connection from the pool.

Yields:

Type Description
AsyncIterator[Connection]

sqlite3.Connection: A read-only SQLite connection.

Source code in src/embedrag/writer/storage.py
106
107
108
109
110
111
112
113
114
115
116
117
@asynccontextmanager
async def read_conn(self) -> AsyncIterator[sqlite3.Connection]:
    """Borrow a reader connection from ``self._readers`` for the ``async with`` body.

    Waits until a connection is available in the pool, and returns it to the
    pool when the block exits, whether normally or via an exception.

    Yields:
        sqlite3.Connection: A pooled reader SQLite connection.
    """
    pool = self._readers
    borrowed = await pool.get()
    try:
        yield borrowed
    finally:
        # Always hand the connection back so the pool never shrinks.
        pool.put_nowait(borrowed)

write_conn() async

Acquire the exclusive writer connection.

Yields:

Type Description
AsyncIterator[Connection]

sqlite3.Connection: The read-write SQLite connection.

Source code in src/embedrag/writer/storage.py
119
120
121
122
123
124
125
126
127
@asynccontextmanager
async def write_conn(self) -> AsyncIterator[sqlite3.Connection]:
    """Hold ``self._write_lock`` and expose the shared writer connection.

    Only one ``async with write_conn()`` body runs at a time; the lock is
    released when the block exits.

    Yields:
        sqlite3.Connection: The read-write SQLite connection.
    """
    lock = self._write_lock
    async with lock:
        # The writer connection is shared; the lock serialises its users.
        yield self._writer

write_conn_sync()

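Return the writer connection for use in a synchronous context. Unlike write_conn(), this does not acquire the async write lock, so callers must guarantee exclusive access themselves.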

Yields:

Type Description
Connection

sqlite3.Connection: The read-write SQLite connection.

Source code in src/embedrag/writer/storage.py
129
130
131
132
133
134
135
136
@contextmanager
def write_conn_sync(self) -> Iterator[sqlite3.Connection]:
    """Yield the writer connection for use in a synchronous context.

    Unlike ``write_conn()``, which enters ``async with self._write_lock``,
    this does NOT take any lock; an async lock cannot be awaited from
    synchronous code. Nothing here stops a concurrent ``write_conn()``
    holder from using the same connection, so callers must guarantee
    exclusive access themselves.

    Yields:
        sqlite3.Connection: The shared read-write SQLite connection.
    """
    yield self._writer

Query Node

Components specific to the query node, which handles search and retrieval.

Search & Retrieval

The query node's FastAPI application and retrieval logic.

Query node FastAPI application with lifespan for bootstrap and shutdown.

This module defines the web application and runtime state management for the EmbedRAG Query Node. The query node is responsible for the "read" side of the system: serving high-concurrency search requests, performing hybrid retrieval (dense + sparse), and automatically synchronizing with new index snapshots in the background.

QueryState

Holds all runtime state for the query node.

This class serves as a central registry for shared resources such as the GenerationManager (which handles hot-swapping indexes), embedding clients, and the background synchronization task.

Attributes:

Name Type Description
config QueryNodeConfig

The validated configuration for this node.

gen_manager GenerationManager

The component that manages the lifecycle of loaded FAISS shards and SQLite search connections.

embedding_clients dict[str, EmbeddingClient]

A mapping of embedding space names to their respective API clients.

syncer Any

The background task (if enabled) that polls for new snapshots.

Source code in src/embedrag/query/app.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class QueryState:
    """Holds all runtime state for the query node.

    This class serves as a central registry for shared resources such as the
    `GenerationManager` (which handles hot-swapping indexes), embedding clients,
    and the background synchronization task.

    Attributes:
        config (QueryNodeConfig): The validated configuration for this node.
        gen_manager (GenerationManager): The component that manages the lifecycle
            of loaded FAISS shards and SQLite search connections.
        embedding_clients (dict[str, EmbeddingClient]): A mapping of embedding
            space names to their respective API clients.
        syncer (Any): The background syncer (if enabled) that polls for new snapshots.
    """

    def __init__(self, config: QueryNodeConfig):
        """Initialize the query state.

        Args:
            config (QueryNodeConfig): The query node configuration. One
                ``EmbeddingClient`` is created per space returned by
                ``config.embedding.get_all_spaces()``.
        """
        self.config = config
        self.gen_manager = GenerationManager()
        self.embedding_clients: dict[str, EmbeddingClient] = {}
        for space in config.embedding.get_all_spaces():
            space_cfg = config.embedding.get_space_config(space)
            self.embedding_clients[space] = EmbeddingClient(space_cfg)
        self.syncer: Any = None  # set by lifecycle/bootstrap

    def get_embedding_client(self, space: str = "text") -> EmbeddingClient:
        """Retrieve the embedding client for a specific space.

        Args:
            space (str, optional): The identifier of the embedding space.
                Defaults to "text".

        Returns:
            EmbeddingClient: The client configured for the requested space.

        Raises:
            KeyError: If no client is configured for the given space name.
        """
        try:
            return self.embedding_clients[space]
        except KeyError:
            available = list(self.embedding_clients.keys())
            raise KeyError(f"No embedding client for space '{space}'. Available: {available}") from None

    async def close(self) -> None:
        """Shut down the generation manager and every embedding client.

        A close is attempted on every resource even if an earlier one fails,
        so one failing component cannot leak the others. Once all attempts
        have been made, the first exception encountered (if any) is re-raised.
        """
        first_error: Exception | None = None
        try:
            await self.gen_manager.close()
        except Exception as exc:
            # Remember the failure but keep going so the clients still close.
            first_error = exc
        for client in self.embedding_clients.values():
            try:
                await client.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

__init__(config)

Initialize the query state.

Parameters:

Name Type Description Default
config QueryNodeConfig

The query node configuration.

required
Source code in src/embedrag/query/app.py
44
45
46
47
48
49
50
51
52
53
54
55
56
def __init__(self, config: QueryNodeConfig):
    """Set up the runtime state for a query node.

    Args:
        config (QueryNodeConfig): Validated query node configuration; one
            ``EmbeddingClient`` is built for each space returned by
            ``config.embedding.get_all_spaces()``.
    """
    self.config = config
    self.gen_manager = GenerationManager()
    embedding_cfg = config.embedding
    self.embedding_clients: dict[str, EmbeddingClient] = {
        space: EmbeddingClient(embedding_cfg.get_space_config(space))
        for space in embedding_cfg.get_all_spaces()
    }
    # Filled in later by the lifecycle/bootstrap code when a syncer is started.
    self.syncer: Any = None

close() async

Gracefully shut down the generation manager and network clients.

This method ensures all index resources (mmap files, DB connections) are released and all pending network requests are cancelled.

Source code in src/embedrag/query/app.py
76
77
78
79
80
81
82
83
84
async def close(self) -> None:
    """Shut down the generation manager and every embedding client.

    A close is attempted on every resource even if an earlier one fails, so
    one failing component cannot leak the others. Once all attempts have
    been made, the first exception encountered (if any) is re-raised.
    """
    first_error: Exception | None = None
    try:
        await self.gen_manager.close()
    except Exception as exc:
        # Remember the failure but keep going so the clients still close.
        first_error = exc
    for client in self.embedding_clients.values():
        try:
            await client.close()
        except Exception as exc:
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error

get_embedding_client(space='text')

Retrieve the embedding client for a specific space.

Parameters:

Name Type Description Default
space str

The identifier of the embedding space. Defaults to "text".

'text'

Returns:

Name Type Description
EmbeddingClient EmbeddingClient

The client configured for the requested space.

Raises:

Type Description
KeyError

If no client is configured for the given space name.

Source code in src/embedrag/query/app.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def get_embedding_client(self, space: str = "text") -> EmbeddingClient:
    """Look up the embedding client registered for *space*.

    Args:
        space (str, optional): Name of the embedding space. Defaults to "text".

    Returns:
        EmbeddingClient: The client stored under ``space`` in
        ``self.embedding_clients``.

    Raises:
        KeyError: If no client is registered for ``space``; the message lists
            the spaces that are available.
    """
    try:
        return self.embedding_clients[space]
    except KeyError:
        available = list(self.embedding_clients.keys())
        raise KeyError(f"No embedding client for space '{space}'. Available: {available}") from None

create_query_app(config_path=None)

Factory function to create and configure the Query FastAPI application.

This function sets up the basic FastAPI app, attaches the lifespan manager, configures CORS, and registers all functional search and admin routes.

Parameters:

Name Type Description Default
config_path str

An optional file path to a YAML configuration file.

None

Returns:

Name Type Description
FastAPI FastAPI

The fully configured web application instance.

Source code in src/embedrag/query/app.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def create_query_app(config_path: str | None = None) -> FastAPI:
    """Build the query-node FastAPI application.

    The app is wired with:

    - ``query_lifespan`` for startup and shutdown;
    - request-context middleware and permissive CORS middleware;
    - a no-op ``/sw.js`` handler and a Prometheus ``/metrics`` endpoint;
    - the query API router;
    - the static WebUI (``/ui``) and cluster visualization (``/cluster``)
      pages, when their directories exist in the package root.

    Args:
        config_path (str, optional): Path to a YAML configuration file. It is
            stored as ``app.state.config_path``, where ``query_lifespan``
            reads it at startup.

    Returns:
        FastAPI: The configured application instance.
    """
    from pathlib import Path

    from fastapi.staticfiles import StaticFiles
    from starlette.middleware.cors import CORSMiddleware
    from starlette.responses import Response

    from embedrag.query.routes import router

    app = FastAPI(title="EmbedRAG Query", version="0.5.1", lifespan=query_lifespan)
    app.state.config_path = config_path

    # Middleware registration order is kept exactly as before.
    app.add_middleware(RequestContextMiddleware)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/sw.js", include_in_schema=False)
    async def _no_service_worker():
        """Prevent service worker interference in simple deployments."""
        return Response(status_code=204)

    @app.get("/metrics", include_in_schema=False)
    async def metrics() -> PlainTextResponse:
        """Prometheus metrics endpoint."""
        from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

        return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)

    app.include_router(router)

    # Optional static front-ends: (directory name == mount name, URL prefix).
    package_root = Path(__file__).parent.parent
    static_sites = (
        ("webui", "/ui"),  # search WebUI
        ("clusterui", "/cluster"),  # plotly-powered cluster visualization
    )
    for dirname, mount_path in static_sites:
        site_dir = package_root / dirname
        if site_dir.exists():
            app.mount(mount_path, StaticFiles(directory=str(site_dir), html=True), name=dirname)

    return app

query_lifespan(app) async

Manages the complex lifecycle of the Query FastAPI application.

This context manager handles the critical startup sequence:

1. Loading the node configuration.
2. Initializing the QueryState.
3. Bootstrapping the node (downloading/loading the initial snapshot).
4. Starting the background synchronization process if configured.

On shutdown, it ensures that all search resources are released and the syncer task is stopped cleanly.

Parameters:

Name Type Description Default
app FastAPI

The FastAPI application instance.

required

Yields:

Name Type Description
None AsyncIterator[None]

Control is returned to the FastAPI framework to start serving.

Raises:

Type Description
BootstrapError

If the node fails to load its initial index snapshot.

RuntimeError

If bootstrap completes without a valid index loaded.

Source code in src/embedrag/query/app.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
@asynccontextmanager
async def query_lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Manages the complex lifecycle of the Query FastAPI application.

    This context manager handles the critical startup sequence:

    1. Loading the node configuration.
    2. Initializing the `QueryState`.
    3. Bootstrapping the node (downloading/loading the initial snapshot).
    4. Starting the background synchronization process if configured.

    On shutdown it stops the syncer and runs ``graceful_shutdown``. This
    happens even when serving ends with an exception (e.g. cancellation),
    because the cleanup sits in a ``finally`` clause around the ``yield``.

    Args:
        app (FastAPI): The FastAPI application instance; its
            ``state.config_path`` selects the configuration file.

    Yields:
        None: Control is returned to the FastAPI framework to start serving.

    Raises:
        BootstrapError: If the node fails to load its initial index snapshot.
        RuntimeError: If bootstrap completes without a valid index loaded.
    """
    config_path = app.state.config_path
    config = load_query_config(config_path)
    setup_logging(level=config.logging.level, fmt=config.logging.format, node_type="query")

    state = QueryState(config)
    app.state.query = state

    from embedrag.query.lifecycle.bootstrap import BootstrapError, bootstrap_query_node

    try:
        await bootstrap_query_node(state)
    except BootstrapError as exc:
        logger.critical("startup_aborted", reason=str(exc))
        import sys

        banner = "=" * 60
        print(f"\n{banner}", file=sys.stderr)
        print(f"STARTUP FAILED: {exc}", file=sys.stderr)
        print(f"{banner}\n", file=sys.stderr)
        raise

    if not state.gen_manager.is_loaded:
        msg = "Bootstrap completed but no generation loaded -- this should not happen"
        logger.critical("startup_aborted", reason=msg)
        raise RuntimeError(msg)

    if config.sync.enabled:
        _start_background_syncer(state, config)

    logger.info("query_started", version=state.gen_manager.active_version)
    try:
        yield
    finally:
        # Graceful shutdown; runs even if serving was torn down by an exception.
        try:
            if state.syncer:
                state.syncer.stop()
        finally:
            from embedrag.query.lifecycle.shutdown import graceful_shutdown

            await graceful_shutdown(state)
        logger.info("query_stopped")


def _start_background_syncer(state: "QueryState", config) -> None:
    """Create and start the snapshot syncer described by ``config.sync``.

    For the ``"http"`` source, an empty ``sync.http_url`` only logs a warning
    and leaves sync disabled; client construction errors propagate. For any
    other source, an object-store client is used, and any construction or
    start-up error is logged and swallowed so the node keeps serving its
    bootstrap snapshot.
    """
    from pathlib import Path as _Path

    from embedrag.query.sync.downloader import SnapshotDownloader
    from embedrag.query.sync.syncer import SnapshotSyncer

    sync_cfg = config.sync

    def _launch(client) -> None:
        # Shared wiring for both sources: staging dir -> downloader -> syncer.
        staging = str(_Path(config.node.data_dir) / "staging")
        downloader = SnapshotDownloader(
            client,
            staging,
            concurrency=sync_cfg.download_concurrency,
            timeout=sync_cfg.download_timeout_seconds,
        )
        state.syncer = SnapshotSyncer(
            state,
            client,
            downloader,
            cron_expr=sync_cfg.cron,
            poll_interval=sync_cfg.poll_interval_seconds,
        )
        state.syncer.start()

    if sync_cfg.source == "http":
        if not sync_cfg.http_url:
            logger.warning("sync_disabled_no_url", reason="sync.http_url is empty")
            return
        from embedrag.shared.http_snapshot_client import HttpSnapshotClient

        _launch(HttpSnapshotClient(sync_cfg.http_url, timeout=sync_cfg.download_timeout_seconds))
        return

    from embedrag.shared.object_store import ObjectStoreClient

    try:
        _launch(ObjectStoreClient(config.object_store))
    except Exception:
        logger.exception("sync_init_failed")

Dense Retrieval

FAISS-based vector search implementation.

Dense retriever: parallel shard search with result merging.

This module provides the core vector search functionality for the query node. It manages a pool of FAISS shard workers, dispatches queries to them in parallel, and merges the partial results into a final ranked list.

DenseResult dataclass

A single hit from the dense vector search.

Attributes:

Name Type Description
chunk_id str

The unique identifier of the retrieved chunk.

score float

The similarity score (usually inner product/dot product) between the query vector and the chunk's vector. Higher is more similar.

Source code in src/embedrag/query/retrieval/dense.py
23
24
25
26
27
28
29
30
31
32
33
34
@dataclass
class DenseResult:
    """One candidate produced by dense (vector) retrieval.

    Attributes:
        chunk_id (str): Identifier of the matched chunk.
        score (float): Similarity between the query vector and the chunk's
            vector as reported by the index; higher means more similar.
    """

    # Declaration order fixes the generated __init__ signature: (chunk_id, score).
    chunk_id: str
    score: float

DenseRetriever

High-level dense retrieval interface.

Wraps the ShardManager to provide a clean search API, handling timing and the filtering of deleted chunks (hotfixes) before returning the final results.

Source code in src/embedrag/query/retrieval/dense.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
class DenseRetriever:
    """High-level dense retrieval interface.

    Wraps the `ShardManager` to provide a clean search API, handling timing
    and the filtering of deleted chunks (hotfixes) before returning the final results.
    """

    def __init__(self, shard_manager: ShardManager):
        """Initialize the DenseRetriever.

        Args:
            shard_manager (ShardManager): The active `ShardManager` handling index shards.
        """
        self._shard_manager = shard_manager

    def search(
        self,
        query_vector: np.ndarray,
        top_k: int,
        deleted_ids: set[str] | None = None,
    ) -> tuple[list[DenseResult], float]:
        """Execute a dense search and filter out logically deleted chunks.

        The shards are first queried for ``top_k * 2`` candidates, leaving
        headroom for deleted chunks to be dropped. If filtering still leaves
        fewer than ``top_k`` hits, the candidate window is doubled and the
        search repeated. This continues until ``top_k`` hits survive, the
        shards return fewer candidates than requested, or the window covers
        ``ShardManager.total_vectors``. Without deletions, a single search is
        performed.

        Args:
            query_vector (np.ndarray): The query embedding vector.
            top_k (int): The final number of desired results.
            deleted_ids (set[str], optional): An optional set of `chunk_id` strings that
                should be excluded from the search results (typically used for hot-swapping
                deletes before the next snapshot).

        Returns:
            tuple[list[DenseResult], float]: The filtered results (at most
            ``top_k``), and the elapsed time of the whole operation in
            milliseconds.
        """
        t0 = time.monotonic()
        fetch_k = top_k * 2
        while True:
            raw_results = self._shard_manager.search(query_vector, fetch_k)
            if not deleted_ids:
                hits = raw_results
                break
            hits = [r for r in raw_results if r.chunk_id not in deleted_ids]
            # Stop when we have enough survivors, or when widening the window
            # cannot help: the shards returned fewer than requested, or the
            # window already spans every indexed vector.
            if (
                len(hits) >= top_k
                or len(raw_results) < fetch_k
                or fetch_k >= self._shard_manager.total_vectors
            ):
                break
            fetch_k *= 2
        elapsed = (time.monotonic() - t0) * 1000
        return hits[:top_k], elapsed

__init__(shard_manager)

Initialize the DenseRetriever.

Parameters:

Name Type Description Default
shard_manager ShardManager

The active ShardManager handling index shards.

required
Source code in src/embedrag/query/retrieval/dense.py
143
144
145
146
147
148
149
def __init__(self, shard_manager: ShardManager):
    """Bind the retriever to the shard manager it will query.

    Args:
        shard_manager (ShardManager): Manager whose ``search`` method fans each
            query out over the loaded index shards.
    """
    manager = shard_manager
    self._shard_manager = manager

search(query_vector, top_k, deleted_ids=None)

Execute a dense search and filter out logically deleted chunks.

To leave headroom for filtering, this method first queries the underlying shards for top_k * 2 results, filters out any chunk IDs present in deleted_ids, and then truncates to top_k.

Parameters:

Name Type Description Default
query_vector ndarray

The query embedding vector.

required
top_k int

The final number of desired results.

required
deleted_ids set[str]

An optional set of chunk_id strings that should be excluded from the search results (typically used for hot-swapping deletes before the next snapshot).

None

Returns:

Type Description
tuple[list[DenseResult], float]

A tuple containing the list of filtered DenseResult objects and the elapsed time of the search operation in milliseconds.

Source code in src/embedrag/query/retrieval/dense.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def search(
    self,
    query_vector: np.ndarray,
    top_k: int,
    deleted_ids: set[str] | None = None,
) -> tuple[list[DenseResult], float]:
    """Execute a dense search and filter out logically deleted chunks.

    The shards are first queried for ``top_k * 2`` candidates, leaving
    headroom for deleted chunks to be dropped. If filtering still leaves
    fewer than ``top_k`` hits, the candidate window is doubled and the
    search repeated. This continues until ``top_k`` hits survive, the shards
    return fewer candidates than requested, or the window covers
    ``ShardManager.total_vectors``. Without deletions, a single search is
    performed.

    Args:
        query_vector (np.ndarray): The query embedding vector.
        top_k (int): The final number of desired results.
        deleted_ids (set[str], optional): An optional set of `chunk_id` strings that
            should be excluded from the search results (typically used for hot-swapping
            deletes before the next snapshot).

    Returns:
        tuple[list[DenseResult], float]: The filtered results (at most
        ``top_k``), and the elapsed time of the whole operation in
        milliseconds.
    """
    t0 = time.monotonic()
    fetch_k = top_k * 2
    while True:
        raw_results = self._shard_manager.search(query_vector, fetch_k)
        if not deleted_ids:
            hits = raw_results
            break
        hits = [r for r in raw_results if r.chunk_id not in deleted_ids]
        # Stop when we have enough survivors, or when widening the window
        # cannot help: the shards returned fewer than requested, or the
        # window already spans every indexed vector.
        if (
            len(hits) >= top_k
            or len(raw_results) < fetch_k
            or fetch_k >= self._shard_manager.total_vectors
        ):
            break
        fetch_k *= 2
    elapsed = (time.monotonic() - t0) * 1000
    return hits[:top_k], elapsed

ShardManager

Manages multiple FAISS shard workers and dispatches parallel searches.

The index is split into multiple shards during the build phase. This manager holds references to the loaded ShardWorker instances and uses a thread pool to execute searches across all shards concurrently, minimizing latency.

Source code in src/embedrag/query/retrieval/dense.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class ShardManager:
    """Manages multiple FAISS shard workers and dispatches parallel searches.

    The index is split into multiple shards during the build phase. This manager
    holds references to the loaded `ShardWorker` instances and uses a thread pool
    to execute searches across all shards concurrently, minimizing latency.
    """

    def __init__(self, workers: list[ShardWorker], id_mapper: IDMapper):
        """Initialize the ShardManager.

        Args:
            workers (list[ShardWorker]): A list of loaded `ShardWorker` instances, one for each index shard.
            id_mapper (IDMapper): An `IDMapper` instance used to translate FAISS internal
                integer IDs back to string `chunk_id`s.
        """
        self._workers = workers
        self._id_mapper = id_mapper
        # One thread per shard so all shards can be searched at once; at least
        # one because ThreadPoolExecutor rejects max_workers=0.
        self._executor = ThreadPoolExecutor(max_workers=max(1, len(workers)))

    @property
    def total_vectors(self) -> int:
        """int: The total number of vectors across all managed shards."""
        return sum(w.ntotal for w in self._workers)

    @property
    def num_shards(self) -> int:
        """int: The number of active shards being managed."""
        return len(self._workers)

    def search(self, query_vector: np.ndarray, top_k: int) -> list[DenseResult]:
        """Search all shards in parallel and merge the results.

        This method dispatches the query to all workers via a thread pool. Once all
        workers return their local top-k results, the lists are concatenated,
        sorted globally by score, and truncated to the final `top_k`.

        Args:
            query_vector (np.ndarray): A 1D or 2D float32 numpy array representing the query embedding.
                If 1D, it will be reshaped to (1, dim).
            top_k (int): The maximum number of total results to return. Non-positive
                values return an empty list without touching the shards.

        Returns:
            list[DenseResult]: At most `top_k` `DenseResult` objects, sorted by score
                in descending order.
        """
        # Nothing can be returned for k <= 0; bail out before handing such a k
        # to the shard workers (and before `[:negative]` trims from the end).
        if top_k <= 0:
            return []

        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)

        futures = [
            self._executor.submit(self._search_one, shard_idx, worker, query_vector, top_k)
            for shard_idx, worker in enumerate(self._workers)
        ]

        all_results: list[DenseResult] = []
        for fut in futures:
            # result() re-raises any exception raised inside a shard search.
            all_results.extend(fut.result())

        all_results.sort(key=lambda r: r.score, reverse=True)
        return all_results[:top_k]

    def _search_one(self, shard_idx: int, worker: ShardWorker, query: np.ndarray, top_k: int) -> list[DenseResult]:
        """Execute a search on a single shard worker and resolve IDs.

        Only row 0 of the returned distance/index arrays is read, i.e. a single
        query is assumed.
        """
        distances, indices = worker.search(query, top_k)
        results: list[DenseResult] = []
        for dist, fid in zip(distances[0], indices[0]):
            # Negative ids are empty result slots (FAISS pads with -1 when a
            # shard has fewer than top_k hits).
            if fid < 0:
                continue
            chunk_id = self._id_mapper.resolve_single(shard_idx, int(fid))
            if chunk_id:
                results.append(DenseResult(chunk_id=chunk_id, score=float(dist)))
        return results

    def reconstruct_all(self) -> tuple[list[str], np.ndarray]:
        """Reconstruct every stored vector with its chunk id.

        Returns ``(chunk_ids, vectors)``. Exact for Flat/IVF-Flat shards,
        approximate for IVF,PQ. Vectors whose ids cannot be resolved are
        skipped. When nothing resolves, an empty ``(0, 0)`` float32 array is
        returned.
        """
        all_ids: list[str] = []
        chunks: list[np.ndarray] = []
        for shard_idx, worker in enumerate(self._workers):
            vecs = worker.reconstruct_all()
            # NOTE(review): assumes row i of the reconstructed matrix is the
            # vector with local FAISS id i in this shard — confirm.
            for local_idx in range(vecs.shape[0]):
                chunk_id = self._id_mapper.resolve_single(shard_idx, local_idx)
                if chunk_id:
                    all_ids.append(chunk_id)
                    chunks.append(vecs[local_idx])
        if not chunks:
            return [], np.empty((0, 0), dtype=np.float32)
        return all_ids, np.stack(chunks).astype(np.float32)

    def shutdown(self) -> None:
        """Shut down the thread pool and release all worker resources.

        Waits for any in-flight shard searches to finish before the workers are
        shut down, so no pool task uses a worker after its resources have been
        released.
        """
        self._executor.shutdown(wait=True)
        for w in self._workers:
            w.shutdown()

num_shards property

int: The number of active shards being managed.

total_vectors property

int: The total number of vectors across all managed shards.

__init__(workers, id_mapper)

Initialize the ShardManager.

Parameters:

Name Type Description Default
workers list[ShardWorker]

A list of loaded ShardWorker instances, one for each index shard.

required
id_mapper IDMapper

An IDMapper instance used to translate FAISS internal integer IDs back to string chunk_ids.

required
Source code in src/embedrag/query/retrieval/dense.py
45
46
47
48
49
50
51
52
53
54
55
def __init__(self, workers: list[ShardWorker], id_mapper: IDMapper):
    """Store the shard workers and ID mapper, and build the search thread pool.

    Args:
        workers (list[ShardWorker]): Loaded shard workers, one per index shard.
        id_mapper (IDMapper): Translates a shard's integer FAISS ids into
            string `chunk_id`s.
    """
    # Size the pool to one thread per shard, but never below one thread.
    pool_size = max(1, len(workers))
    self._workers = workers
    self._id_mapper = id_mapper
    self._executor = ThreadPoolExecutor(max_workers=pool_size)

reconstruct_all()

Reconstruct every stored vector with its chunk id.

Returns (chunk_ids, vectors). Exact for Flat/IVF-Flat shards, approximate for IVF,PQ. Vectors whose ids cannot be resolved are skipped.

Source code in src/embedrag/query/retrieval/dense.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def reconstruct_all(self) -> tuple[list[str], np.ndarray]:
    """Rebuild all stored vectors, paired with their resolved chunk ids.

    The result is ``(chunk_ids, vectors)`` in shard order. Reconstruction is
    exact for Flat/IVF-Flat shards and approximate for IVF,PQ. Rows whose id
    does not resolve are dropped; if none resolve, an empty ``(0, 0)``
    float32 array is returned.
    """
    resolved_ids: list[str] = []
    resolved_rows: list[np.ndarray] = []
    for shard_no, shard in enumerate(self._workers):
        shard_vecs = shard.reconstruct_all()
        for row in range(shard_vecs.shape[0]):
            cid = self._id_mapper.resolve_single(shard_no, row)
            if not cid:
                continue
            resolved_ids.append(cid)
            resolved_rows.append(shard_vecs[row])
    if resolved_rows:
        return resolved_ids, np.stack(resolved_rows).astype(np.float32)
    return [], np.empty((0, 0), dtype=np.float32)

search(query_vector, top_k)

Search all shards in parallel and merge the results.

This method dispatches the query to all workers via a thread pool. Once all workers return their local top-k results, the lists are concatenated, sorted globally by score, and truncated to the final top_k.

Parameters:

Name Type Description Default
query_vector ndarray

A 1D or 2D float32 numpy array representing the query embedding. If 1D, it will be reshaped to (1, dim).

required
top_k int

The maximum number of total results to return.

required

Returns:

Type Description
list[DenseResult]

A list of DenseResult objects, sorted by score in descending order.

Source code in src/embedrag/query/retrieval/dense.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def search(self, query_vector: np.ndarray, top_k: int) -> list[DenseResult]:
    """Search all shards in parallel and merge the results.

    This method dispatches the query to all workers via a thread pool. Once all
    workers return their local top-k results, the lists are concatenated,
    sorted globally by score, and truncated to the final `top_k`.

    Args:
        query_vector (np.ndarray): A 1D or 2D float32 numpy array representing the query embedding.
            If 1D, it will be reshaped to (1, dim).
        top_k (int): The maximum number of total results to return. Non-positive
            values return an empty list without touching the shards.

    Returns:
        list[DenseResult]: At most `top_k` `DenseResult` objects, sorted by score
            in descending order.
    """
    # Nothing can be returned for k <= 0; bail out before handing such a k
    # to the shard workers (and before `[:negative]` trims from the end).
    if top_k <= 0:
        return []

    if query_vector.ndim == 1:
        query_vector = query_vector.reshape(1, -1)

    futures = [
        self._executor.submit(self._search_one, shard_idx, worker, query_vector, top_k)
        for shard_idx, worker in enumerate(self._workers)
    ]

    all_results: list[DenseResult] = []
    for fut in futures:
        # result() re-raises any exception raised inside a shard search.
        all_results.extend(fut.result())

    all_results.sort(key=lambda r: r.score, reverse=True)
    return all_results[:top_k]

shutdown()

Shut down the thread pool and release all worker resources.

Source code in src/embedrag/query/retrieval/dense.py
129
130
131
132
133
def shutdown(self) -> None:
    """Shut down the thread pool and release all worker resources.

    Waits for any in-flight shard searches to finish before the workers are
    shut down, so no pool task uses a worker after its resources have been
    released.
    """
    self._executor.shutdown(wait=True)
    for w in self._workers:
        w.shutdown()

Sparse Retrieval

SQLite FTS5-based keyword search implementation.

Sparse retrieval via SQLite FTS5 trigram index.

This module implements a hybrid keyword search strategy designed for both space-delimited languages (e.g., English) and scriptio-continua languages (e.g., Chinese, Japanese). It uses a tiered approach combining SQLite's FTS5 trigram index for fast BM25-ranked matches and a LIKE-based fallback for short terms and bigrams.

Tiered retrieval strategy
  1. FTS5 MATCH (primary, fast, BM25-ranked): Uses trigram-based indexing. For scriptio-continua segments, the query is decomposed into sliding windows to handle punctuation breaks.
  2. LIKE fallback (secondary, slower, length-ranked): Activated for short segments (< 3 characters) and bigrams extracted from long segments to bridge punctuation boundaries.

Requires schema v3 which includes the text_norm column in the FTS table.

SparseResult dataclass

A single hit from the sparse keyword search.

Attributes:

Name Type Description
chunk_id str

The unique identifier of the retrieved chunk.

score float

The relevance score (higher is better). For FTS matches, this is the negated FTS5 BM25 rank. For LIKE matches, it is a length-based heuristic: 1 / (length of the normalized text), so shorter chunks score higher.

Source code in src/embedrag/query/retrieval/sparse.py
58
59
60
61
62
63
64
65
66
67
68
69
@dataclass
class SparseResult:
    """One hit produced by the sparse (keyword) retriever.

    Attributes:
        chunk_id (str): Identifier of the chunk that matched.
        score (float): Relevance of the hit. FTS hits carry the negated
            FTS5 rank; LIKE-fallback hits carry a length-based heuristic.
    """

    # Identifier of the chunk that matched.
    chunk_id: str
    # Relevance value; its meaning depends on which tier produced the hit.
    score: float

SparseRetriever

Keyword search via SQLite FTS5 trigram index with optional metadata filters.

This retriever handles the complexities of multilingual keyword search by splitting queries into FTS-eligible segments and short fallback segments.

Source code in src/embedrag/query/retrieval/sparse.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
class SparseRetriever:
    """Keyword search via SQLite FTS5 trigram index with optional metadata filters.

    This retriever handles the complexities of multilingual keyword search
    by splitting queries into FTS-eligible segments and short fallback segments.
    Segments of at least ``TRIGRAM_MIN_LEN`` characters go to an FTS5 ``MATCH``
    query. Shorter segments, plus bigrams taken from scriptio-continua segments,
    go to a ``LIKE`` scan of the FTS content table.
    """

    def __init__(self, pool: ReadOnlySQLitePool):
        """Initialize the SparseRetriever.

        Args:
            pool (ReadOnlySQLitePool): The connection pool to the query node's
                read-only SQLite database.
        """
        self._pool = pool

    def search(
        self,
        query_text: str,
        top_k: int,
        filters: dict | None = None,
    ) -> tuple[list[SparseResult], float]:
        """Search for chunks using a combination of FTS5 and LIKE fallback.

        FTS5 hits come first, ordered by FTS5 ``rank``. LIKE-fallback hits that
        are not already present are appended after them, and the merged list is
        cut to ``top_k``. Query errors are logged, not raised: whatever was
        collected before the error is returned.

        Args:
            query_text (str): The raw keyword query string.
            top_k (int): The maximum number of results to return. Non-positive
                values return no results.
            filters (dict, optional): Metadata filters to apply (e.g., `doc_type`, `doc_id`).

        Returns:
            tuple[list[SparseResult], float]: A tuple containing:
                - list[SparseResult]: At most ``top_k`` merged search results.
                - float: The elapsed time in milliseconds.
        """
        # top_k <= 0 would otherwise become "LIMIT 0" / "LIMIT -1"; SQLite
        # treats a negative LIMIT as "no limit".
        if not query_text or not query_text.strip() or top_k <= 0:
            return [], 0.0

        t0 = time.monotonic()
        fts_segs, short_segs = self._split_segments(query_text)

        with self._pool.connection() as conn:
            results: list[SparseResult] = []
            try:
                if fts_segs:
                    fts_query = self._segments_to_fts(fts_segs)
                    if filters:
                        results = self._search_with_filters(conn, fts_query, top_k, filters)
                    else:
                        results = self._search_simple(conn, fts_query, top_k)

                # LIKE hits are only appended after FTS hits, so once FTS has
                # filled top_k slots they would all be truncated away.
                if short_segs and len(results) < top_k:
                    like_results = self._search_like_fallback(conn, short_segs, top_k, filters)
                    seen = {r.chunk_id for r in results}
                    for r in like_results:
                        if r.chunk_id not in seen:
                            results.append(r)
                            seen.add(r.chunk_id)
            except Exception:
                # Best effort: log and return whatever was collected so far.
                logger.warning(
                    "sparse_query_error",
                    fts_segs=fts_segs[:3],
                    short_segs=short_segs[:3],
                    exc_info=True,
                )

        elapsed = (time.monotonic() - t0) * 1000
        return results[:top_k], elapsed

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE wildcards so *value* matches literally under ``ESCAPE '\\'``."""
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @classmethod
    def _filter_clauses(cls, filters: dict | None) -> tuple[list[str], list]:
        """Build the SQL predicates and bound parameters for metadata filters.

        Supports ``doc_type`` (substring match against ``c.metadata_json``) and
        ``doc_id`` (exact match on ``c.doc_id``). Other keys are ignored.

        Returns:
            tuple[list[str], list]: ``(clauses, params)``, in matching order.
        """
        clauses: list[str] = []
        params: list = []
        if not filters:
            return clauses, params
        if "doc_type" in filters:
            # NOTE(review): relies on metadata_json being serialized with
            # '": "' separators (json.dumps default) — confirm on the writer side.
            clauses.append("c.metadata_json LIKE ? ESCAPE '\\'")
            params.append(f'%"doc_type": "{cls._escape_like(str(filters["doc_type"]))}"%')
        if "doc_id" in filters:
            clauses.append("c.doc_id = ?")
            params.append(filters["doc_id"])
        return clauses, params

    def _search_simple(self, conn, fts_query: str, top_k: int) -> list[SparseResult]:
        """Execute a simple FTS5 MATCH query.

        The score is the negated FTS5 ``rank``, so a larger score means a better hit.
        """
        rows = conn.execute(
            "SELECT chunk_id, rank FROM chunks_fts WHERE text_norm MATCH ? ORDER BY rank LIMIT ?",
            (fts_query, top_k),
        ).fetchall()
        return [SparseResult(chunk_id=r["chunk_id"], score=-r["rank"]) for r in rows]

    def _search_with_filters(
        self,
        conn,
        fts_query: str,
        top_k: int,
        filters: dict,
    ) -> list[SparseResult]:
        """Execute an FTS5 MATCH query joined with the chunks table for filtering."""
        filter_clauses, filter_params = self._filter_clauses(filters)
        filter_sql = (" AND " + " AND ".join(filter_clauses)) if filter_clauses else ""
        # Parameter order must follow placeholder order: MATCH, filters, LIMIT.
        params: list = [fts_query, *filter_params, top_k]

        rows = conn.execute(
            f"SELECT f.chunk_id, f.rank "
            f"FROM chunks_fts f "
            f"JOIN chunks c ON f.chunk_id = c.chunk_id "
            f"WHERE f.text_norm MATCH ? {filter_sql} "
            f"ORDER BY f.rank "
            f"LIMIT ?",
            params,
        ).fetchall()
        return [SparseResult(chunk_id=r["chunk_id"], score=-r["rank"]) for r in rows]

    def _search_like_fallback(
        self,
        conn,
        terms: list[str],
        top_k: int,
        filters: dict | None = None,
    ) -> list[SparseResult]:
        """LIKE-based fallback for terms shorter than the trigram minimum.

        Queries the FTS5 content backing table's `c2` column (text_norm). A chunk
        matches if it contains ANY of the terms literally (``%`` and ``_`` in a
        term are escaped). Hits are ordered by text length, shortest first, and
        scored ``1 / length``.

        Args:
            conn: SQLite connection.
            terms (list[str]): List of short terms or bigrams; capped at
                ``MAX_LIKE_TERMS``.
            top_k (int): Result limit.
            filters (dict, optional): Metadata filters.

        Returns:
            list[SparseResult]: Ranked hits.
        """
        if not terms:
            return []

        if len(terms) > MAX_LIKE_TERMS:
            logger.warning(
                "like_terms_capped",
                original=len(terms),
                cap=MAX_LIKE_TERMS,
            )
            terms = terms[:MAX_LIKE_TERMS]

        text_clause = " OR ".join("fc.c2 LIKE ? ESCAPE '\\'" for _ in terms)
        filter_clauses, filter_params = self._filter_clauses(filters)
        filter_sql = (" AND " + " AND ".join(filter_clauses)) if filter_clauses else ""

        # Parameter order must follow placeholder order: terms, filters, LIMIT.
        params: list = [f"%{self._escape_like(t)}%" for t in terms]
        params.extend(filter_params)
        params.append(top_k)

        rows = conn.execute(
            f"SELECT fc.c0 AS chunk_id, length(fc.c2) AS tlen "
            f"FROM chunks_fts_content fc "
            f"JOIN chunks c ON fc.c0 = c.chunk_id "
            f"WHERE ({text_clause}) {filter_sql} "
            f"ORDER BY tlen ASC "
            f"LIMIT ?",
            params,
        ).fetchall()
        # max(.., 1) guards against division by zero for empty text.
        return [SparseResult(chunk_id=r[0], score=1.0 / max(r[1], 1)) for r in rows]

    def _split_segments(self, query_text: str) -> tuple[list[str], list[str]]:
        """Split query into FTS-eligible segments and short segments.

        Normalization (NFKC + casefold + trad->simp) is applied. For
        scriptio-continua segments, 2-char bigrams are extracted for fallback.
        A query that reduces to nothing but whitespace after stripping FTS5
        special characters yields no segments at all.

        Args:
            query_text (str): The raw input query.

        Returns:
            tuple[list[str], list[str]]: (fts_eligible_segments, short_fallback_segments).
        """
        text = normalize_query(query_text.strip())
        text = _FTS5_SPECIAL.sub(" ", text)

        fts_segs: list[str] = []
        short_segs: list[str] = []

        # str.split() drops every whitespace run, so each seg is non-empty.
        for seg in text.split():
            if len(seg) < TRIGRAM_MIN_LEN:
                # Too short for the trigram index: only LIKE can find it.
                short_segs.append(seg)
                continue

            fts_segs.append(seg)
            if not _SCRIPTIO_CONTINUA.search(seg):
                continue

            # Bigrams of adjacent scriptio-continua characters feed the LIKE
            # tier, which (unlike the trigram index) can match 2-char terms.
            bigrams = [
                a + b
                for a, b in zip(seg, seg[1:])
                if _SCRIPTIO_CONTINUA.match(a) and _SCRIPTIO_CONTINUA.match(b)
            ]
            remaining = MAX_LIKE_TERMS - len(short_segs)
            if not bigrams or remaining <= 0:
                continue
            if len(bigrams) <= remaining:
                short_segs.extend(bigrams)
            else:
                # Over budget: keep bigrams from both ends of the segment.
                head = remaining // 2 or 1
                tail = remaining - head
                short_segs.extend(bigrams[:head])
                if tail:
                    short_segs.extend(bigrams[-tail:])

        if short_segs:
            # Order-preserving de-duplication.
            short_segs = list(dict.fromkeys(short_segs))

        return fts_segs, short_segs

    @staticmethod
    def _segments_to_fts(segments: list[str]) -> str:
        """Build an FTS5 MATCH expression from normalized segments.

        For CJK segments, it emits the full phrase AND overlapping 3-char
        sliding windows to improve recall across punctuation boundaries.
        All terms are OR-ed together. Each term is quoted as an FTS5 string,
        with embedded double quotes doubled; windows are cut from the
        unescaped segment so an escape pair is never split.

        Args:
            segments (list[str]): List of normalized query segments.

        Returns:
            str: An FTS5 MATCH expression string.
        """

        def quote(s: str) -> str:
            return '"' + s.replace('"', '""') + '"'

        terms: list[str] = []
        for seg in segments:
            terms.append(quote(seg))
            # A segment of exactly TRIGRAM_MIN_LEN chars is its own only window.
            if len(seg) <= TRIGRAM_MIN_LEN:
                continue
            sc_count = sum(1 for c in seg if _SCRIPTIO_CONTINUA.match(c))
            if sc_count < 2:
                continue
            for i in range(len(seg) - TRIGRAM_MIN_LEN + 1):
                window = seg[i : i + TRIGRAM_MIN_LEN]
                if any(_SCRIPTIO_CONTINUA.match(c) for c in window):
                    terms.append(quote(window))
        return " OR ".join(terms) if terms else '""'

__init__(pool)

Initialize the SparseRetriever.

Parameters:

Name Type Description Default
pool ReadOnlySQLitePool

The connection pool to the query node's read-only SQLite database.

required
Source code in src/embedrag/query/retrieval/sparse.py
79
80
81
82
83
84
85
86
def __init__(self, pool: ReadOnlySQLitePool):
    """Keep a handle to the SQLite pool that every keyword search draws from.

    Args:
        pool (ReadOnlySQLitePool): Pool handing out connections to the query
            node's read-only SQLite database.
    """
    self._pool = pool

search(query_text, top_k, filters=None)

Search for chunks using a combination of FTS5 and LIKE fallback.

Parameters:

Name Type Description Default
query_text str

The raw keyword query string.

required
top_k int

The maximum number of results to return.

required
filters dict

Metadata filters to apply (e.g., doc_type, doc_id).

None

Returns:

Type Description
tuple[list[SparseResult], float]

A tuple containing (1) the merged search results, with FTS5 hits first followed by any LIKE-fallback hits not already present, and (2) the elapsed time in milliseconds.

Source code in src/embedrag/query/retrieval/sparse.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def search(
    self,
    query_text: str,
    top_k: int,
    filters: dict | None = None,
) -> tuple[list[SparseResult], float]:
    """Search for chunks using a combination of FTS5 and LIKE fallback.

    FTS5 hits come first, ordered by FTS5 ``rank``. LIKE-fallback hits that are
    not already present are appended after them, and the merged list is cut to
    ``top_k``. Query errors are logged, not raised: whatever was collected
    before the error is returned.

    Args:
        query_text (str): The raw keyword query string.
        top_k (int): The maximum number of results to return. Non-positive
            values return no results.
        filters (dict, optional): Metadata filters to apply (e.g., `doc_type`, `doc_id`).

    Returns:
        tuple[list[SparseResult], float]: A tuple containing:
            - list[SparseResult]: At most ``top_k`` merged search results.
            - float: The elapsed time in milliseconds.
    """
    # top_k <= 0 would otherwise become "LIMIT 0" / "LIMIT -1"; SQLite treats
    # a negative LIMIT as "no limit".
    if not query_text or not query_text.strip() or top_k <= 0:
        return [], 0.0

    t0 = time.monotonic()
    fts_segs, short_segs = self._split_segments(query_text)

    with self._pool.connection() as conn:
        results: list[SparseResult] = []
        try:
            if fts_segs:
                fts_query = self._segments_to_fts(fts_segs)
                if filters:
                    results = self._search_with_filters(conn, fts_query, top_k, filters)
                else:
                    results = self._search_simple(conn, fts_query, top_k)

            # LIKE hits are only appended after FTS hits, so once FTS has
            # filled top_k slots they would all be truncated away.
            if short_segs and len(results) < top_k:
                like_results = self._search_like_fallback(conn, short_segs, top_k, filters)
                seen = {r.chunk_id for r in results}
                for r in like_results:
                    if r.chunk_id not in seen:
                        results.append(r)
                        seen.add(r.chunk_id)
        except Exception:
            # Best effort: log and return whatever was collected so far.
            logger.warning(
                "sparse_query_error",
                fts_segs=fts_segs[:3],
                short_segs=short_segs[:3],
                exc_info=True,
            )

    elapsed = (time.monotonic() - t0) * 1000
    return results[:top_k], elapsed

Fusion

Reciprocal Rank Fusion (RRF) for combining dense and sparse results.

Reciprocal Rank Fusion (RRF) for merging dense and sparse results.

This module provides an implementation of the Reciprocal Rank Fusion algorithm, which is used to combine multiple ranked result lists into a single, unified ranking without requiring score normalization.

FusedResult dataclass

A single hit from the fused search results.

Attributes:

Name Type Description
chunk_id str

The unique identifier of the retrieved chunk.

rrf_score float

The calculated RRF score for this chunk.

dense_score float

The original score from the dense retriever.

sparse_score float

The original score from the sparse retriever.

Source code in src/embedrag/query/retrieval/fusion.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
@dataclass
class FusedResult:
    """One entry of the merged (dense + sparse) ranking.

    Attributes:
        chunk_id (str): Identifier of the chunk.
        rrf_score (float): Combined Reciprocal Rank Fusion score; results are
            ordered by this value.
        dense_score (float): Score reported by the dense retriever for this chunk.
        sparse_score (float): Score reported by the sparse retriever for this chunk.
    """

    # Identifier of the chunk.
    chunk_id: str
    # Fused RRF score used for ordering.
    rrf_score: float
    # Raw dense-retriever score.
    dense_score: float
    # Raw sparse-retriever score.
    sparse_score: float

rrf_fuse(dense_results, sparse_results, top_k, k=60, dense_weight=1.0, sparse_weight=1.0)

Merge dense and sparse results using Reciprocal Rank Fusion.

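The RRF score for a document d is calculated as

$$\mathrm{RRF}(d) = \sum_{i} \frac{w_i}{k + \mathrm{rank}_i(d)}$$

where $\mathrm{rank}_i(d)$ is the 1-based rank of document $d$ in the $i$-th ranking list and $w_i$ is that list's weight (dense_weight or sparse_weight).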

RRF is highly effective because it does not require the underlying scores (e.g., dot product for dense and BM25 for sparse) to be on the same scale.

Parameters:

Name Type Description Default
dense_results list[DenseResult]

Ranked results from the dense retriever.

required
sparse_results list[SparseResult]

Ranked results from the sparse retriever.

required
top_k int

The number of final fused results to return.

required
k int

The smoothing constant used in the RRF formula. Defaults to 60, which is the value recommended in the original RRF paper.

60
dense_weight float

A multiplier for the dense ranking's contribution to the final score. Defaults to 1.0.

1.0
sparse_weight float

A multiplier for the sparse ranking's contribution to the final score. Defaults to 1.0.

1.0

Returns:

Type Description
list[FusedResult]

A list of FusedResult objects, sorted by rrf_score in descending order.

Source code in src/embedrag/query/retrieval/fusion.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def rrf_fuse(
    dense_results: list[DenseResult],
    sparse_results: list[SparseResult],
    top_k: int,
    k: int = 60,
    dense_weight: float = 1.0,
    sparse_weight: float = 1.0,
) -> list[FusedResult]:
    """Combine dense and sparse rankings with (weighted) Reciprocal Rank Fusion.

    Each list contributes ``weight / (k + position)`` to a chunk's fused score,
    where ``position`` is the chunk's 1-based rank in that list. Only ranks are
    used, so the two retrievers' raw scores need not share a scale.

    Args:
        dense_results (list[DenseResult]): Dense hits, best first.
        sparse_results (list[SparseResult]): Sparse hits, best first.
        top_k (int): How many fused results to keep.
        k (int, optional): RRF smoothing constant. Defaults to 60.
        dense_weight (float, optional): Multiplier on the dense list's
            contribution. Defaults to 1.0.
        sparse_weight (float, optional): Multiplier on the sparse list's
            contribution. Defaults to 1.0.

    Returns:
        list[FusedResult]: Fused hits ordered by ``rrf_score``, highest first.
        Each hit also carries the best raw dense and sparse score seen for it,
        or 0.0 for a retriever that did not return the chunk.
    """
    # chunk_id -> [rrf, dense, sparse]. Insertion order (dense hits first) is
    # what the stable sort below falls back on for equal rrf scores.
    acc: dict[str, list[float]] = {}

    for idx, hit in enumerate(dense_results):
        entry = acc.setdefault(hit.chunk_id, [0.0, hit.score, 0.0])
        entry[0] += dense_weight / (k + idx + 1)
        entry[1] = max(entry[1], hit.score)

    for idx, hit in enumerate(sparse_results):
        entry = acc.setdefault(hit.chunk_id, [0.0, 0.0, hit.score])
        entry[0] += sparse_weight / (k + idx + 1)
        entry[2] = max(entry[2], hit.score)

    ranked = sorted(
        (
            FusedResult(chunk_id=cid, rrf_score=rrf, dense_score=dense, sparse_score=sparse)
            for cid, (rrf, dense, sparse) in acc.items()
        ),
        key=lambda fr: fr.rrf_score,
        reverse=True,
    )
    return ranked[:top_k]

Clustering

The standalone, reusable clustering module that powers the embedrag cluster CLI and the integrated query-node cluster API.

Pipeline & Library API

Orchestration and the public cluster_vectors / cluster_items entry points.

End-to-end clustering pipeline.

Wires the stages together: normalize -> (optional) reduce -> select params / cluster -> evaluate -> explain -> label -> visualize, returning a ClusterResult. The sync core (cluster_vectors) needs no network; text vectorization and LLM labeling are layered on top by callers (CLI / HTTP).

apply_llm_labels(result, chat_url, model='', api_key='', language='auto') async

Replace keyword labels with LLM-generated topic names (in place).

Source code in src/embedrag/cluster/pipeline.py
102
103
104
105
106
107
108
109
110
111
112
async def apply_llm_labels(
    result: ClusterResult,
    chat_url: str,
    model: str = "",
    api_key: str = "",
    language: str = "auto",
) -> None:
    """Swap each cluster's keyword label for an LLM-generated topic name.

    Delegates to ``label.label_clusters_llm`` with ``result.clusters``; per this
    function's contract the labels are updated in place and nothing is returned.

    Args:
        result (ClusterResult): Clustering result whose clusters get relabeled.
        chat_url (str): Endpoint of the chat-completion service.
        model (str): Model name forwarded to the labeler.
        api_key (str): API key forwarded to the labeler.
        language (str): Label language hint, forwarded as-is.
    """
    llm_options = {
        "chat_url": chat_url,
        "model": model,
        "api_key": api_key,
        "language": language,
    }
    await label.label_clusters_llm(result.clusters, **llm_options)

cluster_vectors(vectors, items, *, algorithm='auto', reduce='auto', n_components=0, auto=True, params=None, top_keywords=10, top_reps=5, ground_truth=None, run_id=None, space='text', source='')

Cluster a matrix of vectors and return a fully explained result.

This is the pure, synchronous core (no embedding service, no LLM calls).

Source code in src/embedrag/cluster/pipeline.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def cluster_vectors(
    vectors: np.ndarray,
    items: list[Item],
    *,
    algorithm: str = "auto",
    reduce: str = "auto",
    n_components: int = 0,
    auto: bool = True,
    params: dict | None = None,
    top_keywords: int = 10,
    top_reps: int = 5,
    ground_truth: list | None = None,
    run_id: str | None = None,
    space: str = "text",
    source: str = "",
) -> ClusterResult:
    """Cluster a matrix of vectors and return a fully explained result.

    This is the pure, synchronous core (no embedding service, no LLM calls).

    Args:
        vectors: 2D array with one row per item; cast to float32.
        items: Items aligned row-for-row with ``vectors``.
        algorithm: Backend name, or ``"auto"`` to let ``evaluate.select_params``
            compare backends.
        reduce: Dimensionality-reduction method passed to ``reduce_dims``.
        n_components: Target dimensionality passed to ``reduce_dims``.
        auto: Whether ``evaluate.select_params`` may sweep parameters.
        params: Backend parameter overrides.
        top_keywords: Keywords to extract per cluster.
        top_reps: Representative examples to keep per cluster.
        ground_truth: Optional reference labels aligned with ``items``; when
            given, external metrics are added under ``metrics["external"]``.
        run_id: Identifier for the run; generated via ``store.make_run_id``
            when omitted.
        space: Vector-space name recorded on the result.
        source: Free-form description of where the data came from.

    Returns:
        A ``ClusterResult`` carrying clusters, members, metrics, the parameter
        sweep, a 2D projection and visualization panels.

    Raises:
        ValueError: If ``vectors`` is not 2D, or if ``items`` / ``ground_truth``
            do not have exactly one entry per row of ``vectors``.
    """
    vectors = np.asarray(vectors, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"vectors must be 2D, got shape {vectors.shape}")
    if vectors.shape[0] != len(items):
        raise ValueError(f"vectors ({vectors.shape[0]}) and items ({len(items)}) length mismatch")
    if ground_truth is not None and len(ground_truth) != len(items):
        # Fail fast instead of discovering the mismatch only after the
        # (potentially expensive) clustering and explanation steps have run.
        raise ValueError(f"ground_truth ({len(ground_truth)}) and items ({len(items)}) length mismatch")

    run_id = run_id or store.make_run_id()
    # Explanations use the normalized full-dimensional vectors; clustering,
    # internal metrics and the projection use the (possibly) reduced ones.
    normalized = l2_normalize(vectors)
    clustering_vecs, reduce_used = reduce_dims(normalized, method=reduce, n_components=n_components)

    algo_name, assignment, chosen_params, sweep = evaluate.select_params(
        clustering_vecs, algorithm, overrides=params or {}, auto=auto
    )
    labels = assignment.labels

    clusters, members, _centroids, sim_matrix = explain.explain(
        normalized, labels, assignment.probabilities, items, top_keywords=top_keywords, top_reps=top_reps
    )
    label.apply_keyword_labels(clusters)

    metrics = evaluate.internal_metrics(clustering_vecs, labels)
    if ground_truth is not None:
        metrics["external"] = evaluate.external_metrics(labels, ground_truth)

    # A fresh backend instance is only used here for its list of viz panels.
    backend = make_backend(algo_name, **chosen_params)
    projection = visualize.build_projection(clustering_vecs, labels, assignment.probabilities, items)
    projection["method"] = reduce_used if reduce_used != "none" else "raw"
    panels = visualize.build_panels(backend.panels, clustering_vecs, assignment, clusters, sim_matrix, sweep)

    noise_count = int(np.sum(labels == -1))
    result = ClusterResult(
        run_id=run_id,
        algorithm=algo_name,
        params=chosen_params,
        space=space,
        created_at=datetime.now(UTC).isoformat(),
        source=source,
        n_items=len(items),
        n_clusters=len(clusters),
        noise_count=noise_count,
        clusters=clusters,
        members=members,
        metrics=metrics,
        sweep=sweep,
        projection=projection,
        viz=panels,
    )
    logger.info(
        "cluster_done",
        run_id=run_id,
        algorithm=algo_name,
        n_clusters=len(clusters),
        noise=noise_count,
        silhouette=metrics.get("silhouette"),
    )
    return result

Sources

Loading vectors and items from files, .npy, a writer DB, or TF-IDF.

Vector + item acquisition for clustering.

Supports several co-equal sources so the tool works standalone or against the EmbedRAG vector store:

  • Files: .jsonl / .csv of {id, text, [embedding]}, or a .npy matrix of precomputed vectors.
  • Passed-in python objects: a list of texts, or a numpy/array of vectors.
  • A writer SQLite DB: exact vectors from chunk_embeddings + text from chunks (with optional filters).
  • A loaded query-node generation: vectors reconstructed from the FAISS index (exact for Flat/IVF-Flat, approximate for IVF,PQ).

When no vectors are available, callers fall back to embedding the text via an embedding service, or to a local TF-IDF representation (no service needed).

items_from_texts(texts, ids=None)

Build items from a list of texts.

Source code in src/embedrag/cluster/source.py
120
121
122
123
124
def items_from_texts(texts: list[str], ids: list[str] | None = None) -> list[Item]:
    """Build items from a list of texts.

    Args:
        texts: Item texts; falsy entries (e.g. ``None``) become ``""``.
        ids: Optional ids aligned with ``texts``; defaults to the positional
            ids ``"0"``, ``"1"``, ... Every id is converted with ``str``.

    Raises:
        ValueError: If ``ids`` is given and its length differs from ``texts``.
    """
    if ids is None:
        ids = [str(i) for i in range(len(texts))]
    elif len(ids) != len(texts):
        # zip() would silently drop the unmatched tail; fail loudly instead.
        raise ValueError(f"ids ({len(ids)}) and texts ({len(texts)}) length mismatch")
    return [Item(id=str(i), text=str(t or "")) for i, t in zip(ids, texts)]

load_items_from_file(path, text_field='text', id_field='id', embedding_field='embedding')

Load items (and optional inline embeddings) from a .jsonl/.ndjson, .json, or .csv file.

Returns (items, vectors_or_None). vectors is returned only when every row carries an embedding under embedding_field.

Source code in src/embedrag/cluster/source.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def load_items_from_file(
    path: str,
    text_field: str = "text",
    id_field: str = "id",
    embedding_field: str = "embedding",
) -> tuple[list[Item], np.ndarray | None]:
    """Load items (and optional inline embeddings) from a .jsonl/.ndjson, .json or .csv file.

    Returns ``(items, vectors_or_None)``. ``vectors`` is returned only when
    every row carries a non-empty embedding under ``embedding_field``. An
    embedding may be a sequence of numbers or a JSON-encoded string of one
    (as stored in CSV cells). Rows whose id is missing, ``None`` or ``""``
    get their 0-based row index as id.

    Raises:
        ValueError: For an unsupported file suffix, or when the inline
            embeddings do not all have the same dimension.
    """
    p = Path(path)
    suffix = p.suffix.lower()
    if suffix in (".jsonl", ".ndjson"):
        rows = _read_jsonl(p)
    elif suffix == ".json":
        rows = _read_json_array(p)
    elif suffix == ".csv":
        rows = _read_csv(p)
    else:
        raise ValueError(f"Unsupported input file type: {suffix} (use .jsonl, .json, or .csv)")

    items: list[Item] = []
    vectors: list[list[float]] = []
    have_all_vectors = True
    for i, row in enumerate(rows):
        raw_id = row.get(id_field)
        # Fall back to the row index for an absent id, a JSON null, or a blank
        # CSV cell (str(None) would otherwise produce the id "None").
        rid = str(i) if raw_id is None or raw_id == "" else str(raw_id)
        text = str(row.get(text_field, "") or "")
        items.append(Item(id=rid, text=text))
        emb = row.get(embedding_field)
        # Blank CSV cells typically arrive as ""; treat them as "no embedding"
        # rather than letting json.loads("") raise.
        if emb is None or (isinstance(emb, str) and not emb.strip()):
            have_all_vectors = False
        elif have_all_vectors:
            if isinstance(emb, str):
                emb = json.loads(emb)
            vectors.append([float(x) for x in emb])

    if have_all_vectors and vectors:
        dims = {len(v) for v in vectors}
        if len(dims) != 1:
            # Ragged rows would otherwise fail inside numpy with an opaque error.
            raise ValueError(f"Inconsistent embedding dimensions in {p}: {sorted(dims)}")
        arr = np.asarray(vectors, dtype=np.float32)
        logger.info("cluster_source_file", path=str(p), items=len(items), with_embeddings=True, dim=arr.shape[1])
        return items, arr

    logger.info("cluster_source_file", path=str(p), items=len(items), with_embeddings=False)
    return items, None

load_vectors_npy(path, items=None)

Load a .npy matrix of vectors; synthesize ids if no items given.

Source code in src/embedrag/cluster/source.py
108
109
110
111
112
113
114
115
116
117
def load_vectors_npy(path: str, items: list[Item] | None = None) -> tuple[list[Item], np.ndarray]:
    """Read a 2D ``.npy`` array (as float32) and pair every row with an item.

    Without ``items``, placeholder items with ids ``"0"``, ``"1"``, ... are
    created; otherwise ``items`` must have exactly one entry per row.

    Raises:
        ValueError: If the array is not 2D or ``items`` has the wrong length.
    """
    matrix = np.load(path).astype(np.float32, copy=False)
    if matrix.ndim != 2:
        raise ValueError(f"Expected a 2D array in {path}, got shape {matrix.shape}")
    n_rows = matrix.shape[0]
    if items is not None:
        if len(items) != n_rows:
            raise ValueError(f"items ({len(items)}) and vectors ({n_rows}) length mismatch")
        return items, matrix
    placeholders = [Item(id=str(row)) for row in range(n_rows)]
    return placeholders, matrix

read_writer_db(db_path, space='text', filters=None, limit=None)

Read exact vectors + text from a writer DB's chunk_embeddings table.

filters may contain doc_type and/or doc_id to restrict the set. Raises if the DB has no populated chunk_embeddings table.

Source code in src/embedrag/cluster/source.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def read_writer_db(
    db_path: str,
    space: str = "text",
    filters: dict | None = None,
    limit: int | None = None,
) -> tuple[list[Item], np.ndarray]:
    """Read exact vectors + text from a writer DB's ``chunk_embeddings`` table.

    The database is opened read-only. ``filters`` may contain ``doc_type``
    and/or ``doc_id`` to restrict the set; a truthy ``limit`` caps the row
    count. Rows are ordered by ``chunk_id`` and item ids are returned as
    strings (like the other loaders in this module).

    Raises:
        ValueError: If the DB has no ``chunk_embeddings`` table, no rows match,
            or the stored embeddings do not all have the same dimension.
    """
    # Build the URI from a percent-encoded absolute path so that file names
    # containing '?', '#' or '%' are not misparsed as URI syntax.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    conn = sqlite3.connect(uri, uri=True)
    conn.row_factory = sqlite3.Row
    try:
        if not _table_exists(conn, "chunk_embeddings"):
            raise ValueError(
                f"{db_path} has no chunk_embeddings table (a query-node export drops it). "
                "Use a writer DB, or cluster a loaded generation via FAISS reconstruction."
            )
        # Filter values are bound as parameters; only fixed clauses are interpolated.
        where = ["ce.space = ?"]
        params: list = [space]
        filters = filters or {}
        if filters.get("doc_type"):
            where.append("d.doc_type = ?")
            params.append(filters["doc_type"])
        if filters.get("doc_id"):
            where.append("c.doc_id = ?")
            params.append(filters["doc_id"])
        sql = (
            "SELECT c.chunk_id AS id, c.text AS text, ce.embedding AS embedding "
            "FROM chunk_embeddings ce "
            "JOIN chunks c ON c.chunk_id = ce.chunk_id "
            "LEFT JOIN documents d ON c.doc_id = d.doc_id "
            f"WHERE {' AND '.join(where)} "
            "ORDER BY c.chunk_id"
        )
        if limit:
            sql += f" LIMIT {int(limit)}"
        rows = conn.execute(sql, params).fetchall()
    finally:
        conn.close()

    if not rows:
        raise ValueError(f"No vectors found in {db_path} for space '{space}' with the given filters")

    items = [Item(id=str(r["id"]), text=r["text"] or "") for r in rows]
    # Embeddings are stored as raw float32 bytes.
    embeddings = [np.frombuffer(r["embedding"], dtype=np.float32) for r in rows]
    dims = {e.shape[0] for e in embeddings}
    if len(dims) != 1:
        # np.stack would otherwise fail with an opaque shape error.
        raise ValueError(f"Inconsistent embedding dimensions in {db_path} for space '{space}': {sorted(dims)}")
    vectors = np.stack(embeddings)
    logger.info("cluster_source_writer_db", db=db_path, space=space, items=len(items), dim=vectors.shape[1])
    return items, vectors

tfidf_vectors(texts, max_features=4096)

Build a dense TF-IDF representation for the no-embedding-service path.

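When any text contains CJK characters, character n-grams are used so no word segmentation is needed; otherwise word unigrams and bigrams (with English stop words removed) are used.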

Source code in src/embedrag/cluster/source.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def tfidf_vectors(texts: list[str], max_features: int = 4096) -> np.ndarray:
    """Return a dense float32 TF-IDF matrix with one row per text.

    This is the fallback representation when no embedding service is used.
    If ``_contains_cjk`` flags any text, character 2-3-grams (``char_wb``)
    are used so no word segmentation is needed; otherwise word uni/bigrams
    with English stop words removed. At most ``max_features`` columns are kept.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer

    if any(_contains_cjk(t) for t in texts):
        analyzer_options = {"analyzer": "char_wb", "ngram_range": (2, 3)}
    else:
        analyzer_options = {"ngram_range": (1, 2), "stop_words": "english"}
    vectorizer = TfidfVectorizer(max_features=max_features, **analyzer_options)
    sparse_matrix = vectorizer.fit_transform(texts)
    return sparse_matrix.toarray().astype(np.float32)

Algorithms

Pluggable clustering backends behind a single interface.

Pluggable clustering backends behind a single interface.

Each backend takes preprocessed (normalized, optionally reduced) vectors and produces integer labels (-1 == noise) plus optional per-point membership probabilities. Backends also declare which visualization panels make sense for them, so the UI/exports can adapt per algorithm.

AgglomerativeBackend

Bases: ClusterBackend

Hierarchical (Ward); supports a dendrogram and threshold/K cut.

Source code in src/embedrag/cluster/algorithms.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class AgglomerativeBackend(ClusterBackend):
    """Hierarchical (Ward); supports a dendrogram and threshold/K cut.

    Settings read from ``self.params``:
        ``distance_threshold`` -- cut the tree at this height (wins over a count).
        ``n_clusters`` / ``k`` -- number of clusters to cut into; when neither a
            count nor a threshold is given, ``max(2, round(sqrt(n / 2)))`` is used.
            A count is clamped to ``[2, n]``.
        ``linkage`` -- linkage method for both sklearn and scipy (default ``"ward"``).

    After fitting, ``self.params["n_clusters"]`` holds the number of distinct
    labels produced, and ``extra["linkage"]`` holds a scipy linkage matrix when
    ``n <= 4000`` and scipy succeeds.
    """

    name = "agglomerative"

    def fit(self, vectors: np.ndarray) -> ClusterAssignment:
        from sklearn.cluster import AgglomerativeClustering

        n_points = vectors.shape[0]
        method = self.params.get("linkage", "ward")
        threshold = self.params.get("distance_threshold")

        if threshold is not None:
            # Height-based cut: sklearn requires n_clusters=None alongside a threshold.
            cut = {"n_clusters": None, "distance_threshold": float(threshold)}
        else:
            wanted = self.params.get("n_clusters") or self.params.get("k")
            if wanted is None:
                wanted = max(2, int(round((n_points / 2) ** 0.5)))
            cut = {"n_clusters": max(2, min(int(wanted), n_points))}

        labels = AgglomerativeClustering(linkage=method, **cut).fit_predict(vectors)

        extra: dict = {}
        # Dendrogram data is only built up to 4000 points (presumably to bound
        # scipy's cost); failures are logged, not raised, as scipy is optional.
        if n_points <= 4000:
            try:
                from scipy.cluster.hierarchy import linkage

                extra["linkage"] = linkage(vectors, method=method).tolist()
            except Exception as exc:  # scipy optional
                logger.warn("dendrogram_unavailable", error=str(exc))
        self.params["n_clusters"] = int(len(set(labels)))
        return ClusterAssignment(labels=labels.astype(int), extra=extra)

ClusterAssignment dataclass

Output of a clustering backend.

Source code in src/embedrag/cluster/algorithms.py
29
30
31
32
33
34
35
@dataclass
class ClusterAssignment:
    """Output of a clustering backend.

    ``labels`` holds one integer cluster id per input row, with ``-1`` marking
    noise. ``probabilities`` is an optional per-row membership/confidence
    score; backends that do not produce one leave it as ``None``. ``extra``
    carries algorithm-specific artifacts.
    """

    labels: np.ndarray  # per-row cluster id; -1 == noise
    probabilities: np.ndarray | None = None  # optional per-row membership/confidence score
    extra: dict = field(default_factory=dict)  # algo-specific artifacts (e.g. linkage matrix)

ClusterBackend

Base interface for a clustering algorithm.

Source code in src/embedrag/cluster/algorithms.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class ClusterBackend:
    """Common base class for the clustering algorithms.

    Subclasses set ``name`` and implement ``fit``. Constructor keyword
    arguments are kept in ``self.params``; ``fit`` implementations read their
    settings from there and may write back the values they actually used.
    """

    name: str = "base"

    def __init__(self, **params):
        # ``**params`` always builds a fresh dict, so subclasses may mutate it.
        self.params = params

    def fit(self, vectors: np.ndarray) -> ClusterAssignment:  # pragma: no cover - abstract
        """Cluster ``vectors``; concrete backends must override this."""
        raise NotImplementedError

    @property
    def panels(self) -> list[str]:
        """Visualization panels registered for this algorithm.

        Falls back to a scatter plot plus a cluster-size bar chart.
        """
        default_panels = ["scatter", "size_bar"]
        return ALGORITHM_PANELS.get(self.name, default_panels)

DBSCANBackend

Bases: ClusterBackend

Density-based with explicit eps; good for the no-embedding/TF-IDF path.

Source code in src/embedrag/cluster/algorithms.py
78
79
80
81
82
83
84
85
86
87
88
89
90
class DBSCANBackend(ClusterBackend):
    """Density-based with explicit eps; good for the no-embedding/TF-IDF path.

    Reads ``eps`` (estimated via ``_estimate_eps`` when unset or falsy) and
    ``min_samples`` (default 5) from ``self.params``, and records the values
    actually used back into ``self.params``.
    """

    name = "dbscan"

    def fit(self, vectors: np.ndarray) -> ClusterAssignment:
        from sklearn.cluster import DBSCAN

        # `or 5` also covers an explicit None (e.g. from a JSON/CLI override),
        # which int() would otherwise reject; read once and reuse below.
        min_samples = int(self.params.get("min_samples") or 5)
        eps = float(self.params.get("eps") or _estimate_eps(vectors, min_samples))
        labels = DBSCAN(eps=eps, min_samples=min_samples, metric="euclidean").fit_predict(vectors)
        self.params["eps"] = round(eps, 4)
        self.params["min_samples"] = min_samples
        return ClusterAssignment(labels=labels)

HDBSCANBackend

Bases: ClusterBackend

Density-based, auto cluster count, native noise + membership probability.

Source code in src/embedrag/cluster/algorithms.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class HDBSCANBackend(ClusterBackend):
    """Density-based, auto cluster count, native noise + membership probability.

    ``min_cluster_size`` defaults to ``max(5, round(sqrt(n) / 2))`` and is
    clamped to ``[2, max(2, n // 2)]``; the value used is written back to
    ``self.params``. ``min_samples`` is passed through when truthy.
    """

    name = "hdbscan"

    def fit(self, vectors: np.ndarray) -> ClusterAssignment:
        from sklearn.cluster import HDBSCAN

        n_points = vectors.shape[0]
        requested_size = self.params.get("min_cluster_size")
        heuristic_size = max(5, int(round(n_points**0.5 / 2)) or 5)
        size = int(requested_size or heuristic_size)
        size = max(2, min(size, max(2, n_points // 2)))

        raw_min_samples = self.params.get("min_samples")
        model = HDBSCAN(
            min_cluster_size=size,
            min_samples=int(raw_min_samples) if raw_min_samples else None,
            metric="euclidean",  # inputs are L2-normalized, so euclidean ranks like cosine
            cluster_selection_method=self.params.get("cluster_selection_method", "eom"),
        )
        labels = model.fit_predict(vectors)
        membership = getattr(model, "probabilities_", None)
        self.params["min_cluster_size"] = size
        return ClusterAssignment(labels=labels, probabilities=membership)

KMeansBackend

Bases: ClusterBackend

Centroid-based (spherical via normalized inputs). Scales via MiniBatch.

Source code in src/embedrag/cluster/algorithms.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class KMeansBackend(ClusterBackend):
    """Centroid-based (spherical via normalized inputs). Scales via MiniBatch.

    ``k`` (alias ``n_clusters``) defaults to ``round(sqrt(n / 2))`` (at least
    2) and is clamped to ``[2, n]``. FAISS k-means is used above 200k points
    unless ``use_faiss`` is set explicitly; otherwise scikit-learn's
    ``KMeans`` (or ``MiniBatchKMeans`` above 20k points). Probabilities are
    only produced when the FAISS path returns distances.
    """

    name = "kmeans"

    def fit(self, vectors: np.ndarray) -> ClusterAssignment:
        n_points = vectors.shape[0]
        default_k = max(2, int(round((n_points / 2) ** 0.5)))
        k = int(self.params.get("k") or self.params.get("n_clusters") or default_k)
        k = max(2, min(k, n_points))

        probabilities = None
        if self.params.get("use_faiss", n_points > 200_000):
            labels, distances = _faiss_kmeans(vectors, k)
            self.params["k"] = k
            if distances is not None:
                # nearest-centroid distance -> soft confidence in (0, 1]
                probabilities = 1.0 / (1.0 + distances)
        else:
            from sklearn.cluster import KMeans, MiniBatchKMeans

            if n_points > 20_000:
                model = MiniBatchKMeans(n_clusters=k, random_state=42, n_init=3, batch_size=2048)
            else:
                model = KMeans(n_clusters=k, random_state=42, n_init=10)
            labels = model.fit_predict(vectors)
            self.params["k"] = k
        return ClusterAssignment(labels=labels.astype(int), probabilities=probabilities)

LeidenBackend

Bases: ClusterBackend

Community detection on a FAISS kNN graph (optional deps).

Source code in src/embedrag/cluster/algorithms.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class LeidenBackend(ClusterBackend):
    """Community detection on a FAISS kNN graph (optional deps).

    Settings read from ``self.params``:
        ``knn`` -- neighbours per node when building the graph (default 15).
        ``resolution`` -- resolution of the RB-configuration quality function
            (default 1.0, leidenalg's own default); per the leidenalg docs,
            higher values yield more, smaller communities.

    Every graph node receives a community id from the partition's membership.
    """

    name = "leiden"

    def fit(self, vectors: np.ndarray) -> ClusterAssignment:
        try:
            import igraph as ig
            import leidenalg
        except ImportError as exc:  # pragma: no cover - optional
            raise RuntimeError("leiden backend needs `igraph` and `leidenalg` installed") from exc

        k = int(self.params.get("knn", 15))
        raw_resolution = self.params.get("resolution")
        resolution = 1.0 if raw_resolution is None else float(raw_resolution)
        edges, weights = _knn_graph(vectors, k)
        g = ig.Graph(n=vectors.shape[0], edges=edges, edge_attrs={"weight": weights})
        part = leidenalg.find_partition(
            g,
            leidenalg.RBConfigurationVertexPartition,
            weights="weight",
            resolution_parameter=resolution,
            seed=42,
        )
        labels = np.asarray(part.membership, dtype=int)
        return ClusterAssignment(labels=labels)

available_algorithms()

Names of all registered clustering backends.

Source code in src/embedrag/cluster/algorithms.py
188
189
190
def available_algorithms() -> list[str]:
    """Return the registered backend names in the registry's iteration order."""
    return [*_BACKENDS]

make_backend(name, **params)

Instantiate a clustering backend by name.

Source code in src/embedrag/cluster/algorithms.py
193
194
195
196
197
def make_backend(name: str, **params) -> ClusterBackend:
    """Look up ``name`` in the backend registry and construct it with ``params``.

    Raises:
        ValueError: If no backend is registered under ``name``.
    """
    backend_cls = _BACKENDS.get(name)
    if backend_cls is None:
        raise ValueError(f"Unknown algorithm '{name}'. Available: {available_algorithms()}")
    return backend_cls(**params)

Evaluation

Internal/external metrics and the automatic parameter sweep.

Evaluation harness and automatic parameter selection.

This is the "honesty layer": every run reports internal quality metrics so the result can be judged rather than blindly trusted. --auto sweeps the key parameter for an algorithm and returns the full score curve, and the auto algorithm compares backends on a composite score.

composite_score(metrics)

Single objective for model selection: silhouette penalized by noise.

Source code in src/embedrag/cluster/evaluate.py
83
84
85
86
87
88
def composite_score(metrics: dict) -> float:
    """Score a clustering for model selection (higher is better).

    The score is the silhouette minus half the noise ratio. It is ``-1.0``
    when no silhouette is available or fewer than two clusters were found.
    """
    silhouette = metrics.get("silhouette")
    if silhouette is None:
        return -1.0
    if metrics.get("n_clusters", 0) < 2:
        return -1.0
    noise_penalty = 0.5 * float(metrics.get("noise_ratio", 0.0))
    return float(silhouette) - noise_penalty

external_metrics(labels, ground_truth)

Compare predicted labels to ground-truth labels (when available).

Source code in src/embedrag/cluster/evaluate.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def external_metrics(labels: np.ndarray, ground_truth: list) -> dict:
    """Score predicted ``labels`` against reference ``ground_truth`` labels.

    Returns ``ari``, ``nmi`` and ``v_measure``, each rounded to 4 decimals.
    Labels are passed through unchanged, so noise (``-1``) is scored as one
    more cluster.
    """
    from sklearn.metrics import (
        adjusted_rand_score,
        normalized_mutual_info_score,
        v_measure_score,
    )

    reference = np.asarray(ground_truth)
    scorers = (
        ("ari", adjusted_rand_score),
        ("nmi", normalized_mutual_info_score),
        ("v_measure", v_measure_score),
    )
    return {key: round(float(score_fn(reference, labels)), 4) for key, score_fn in scorers}

internal_metrics(vectors, labels, sample=5000)

Compute clustering quality metrics that need no ground truth.

Silhouette / Davies-Bouldin / Calinski-Harabasz are computed on non-noise points only. Returns a dict including n_clusters and noise_ratio.

Source code in src/embedrag/cluster/evaluate.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def internal_metrics(vectors: np.ndarray, labels: np.ndarray, sample: int = 5000) -> dict:
    """Compute clustering quality metrics that need no ground truth.

    Silhouette / Davies-Bouldin / Calinski-Harabasz are computed on non-noise
    points only (label ``-1`` is noise). With more than ``sample`` core
    points, a fixed-seed random subset of ``sample`` points is scored instead
    (silhouette is quadratic in the number of points). Each score is computed
    independently: one that cannot be computed is logged and left as
    ``None`` without affecting the others.

    Returns:
        dict with ``n_clusters``, ``noise_ratio``, ``silhouette``,
        ``davies_bouldin`` and ``calinski_harabasz``.
    """
    from sklearn.metrics import (
        calinski_harabasz_score,
        davies_bouldin_score,
        silhouette_score,
    )

    labels = np.asarray(labels)
    n = labels.shape[0]
    noise_mask = labels == -1
    noise_ratio = float(noise_mask.mean()) if n else 0.0
    core = ~noise_mask
    core_labels = labels[core]
    n_clusters = int(len(np.unique(core_labels)))

    out: dict = {
        "n_clusters": n_clusters,
        "noise_ratio": round(noise_ratio, 4),
        "silhouette": None,
        "davies_bouldin": None,
        "calinski_harabasz": None,
    }
    # All three scores require 2 <= n_labels <= n_samples - 1.
    if n_clusters < 2 or core.sum() <= n_clusters:
        return out

    feats = vectors[core]
    y = core_labels
    if feats.shape[0] > sample:
        rng = np.random.RandomState(42)
        idx = rng.choice(feats.shape[0], sample, replace=False)
        feats, y = feats[idx], y[idx]
        # Re-check the label-count constraint on the subsample.
        n_sampled_labels = len(np.unique(y))
        if n_sampled_labels < 2 or n_sampled_labels >= feats.shape[0]:
            return out

    scorers = (
        ("silhouette", lambda: silhouette_score(feats, y, metric="euclidean"), 4),
        ("davies_bouldin", lambda: davies_bouldin_score(feats, y), 4),
        ("calinski_harabasz", lambda: calinski_harabasz_score(feats, y), 2),
    )
    for key, compute, ndigits in scorers:
        try:
            out[key] = round(float(compute()), ndigits)
        except Exception as exc:
            # Best-effort: a failing metric stays None instead of voiding the rest.
            logger.warn("internal_metrics_failed", metric=key, error=str(exc))
    return out

select_params(vectors, algorithm, overrides=None, auto=True)

Pick the best parameters for an algorithm (or compare backends for 'auto').

Returns (chosen_algorithm, assignment, chosen_params, sweep) where sweep is the list of evaluated candidates (the score curve).

Source code in src/embedrag/cluster/evaluate.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def select_params(
    vectors: np.ndarray,
    algorithm: str,
    overrides: dict | None = None,
    auto: bool = True,
) -> tuple[str, ClusterAssignment, dict, list[dict]]:
    """Pick the best parameters for an algorithm (or compare backends for 'auto').

    - ``algorithm == "auto"`` delegates to ``_auto_algorithm``.
    - If ``auto`` is false or ``_has_pinned_param`` reports a pinned
      controlling parameter, the backend is fitted once with ``overrides``.
    - Otherwise every candidate from ``_param_grid`` (merged over
      ``overrides``) is fitted and scored with ``composite_score``; the
      highest score wins and ties keep the earliest candidate. Candidates
      whose fit raises are logged and skipped; if all fail, a plain fit with
      ``overrides`` is returned.

    Sweep entries built here carry ``params``, ``metrics`` and ``score``.

    Returns ``(chosen_algorithm, assignment, chosen_params, sweep)`` where
    ``sweep`` is the list of evaluated candidates (the score curve) and
    ``chosen_params`` is a copy independent of any backend instance.
    """
    overrides = dict(overrides or {})

    if algorithm == "auto":
        return _auto_algorithm(vectors, overrides)

    # If the user pinned the controlling parameter, skip the sweep.
    if not auto or _has_pinned_param(algorithm, overrides):
        backend = make_backend(algorithm, **overrides)
        assignment = backend.fit(vectors)
        metrics = internal_metrics(vectors, assignment.labels)
        chosen = dict(backend.params)
        # Same entry shape as sweep candidates below (including "score").
        entry = {"params": dict(chosen), "metrics": metrics, "score": round(composite_score(metrics), 4)}
        return algorithm, assignment, chosen, [entry]

    grid = _param_grid(algorithm, vectors.shape[0])
    sweep: list[dict] = []
    best = None  # (score, assignment, params)
    for value in grid:
        params = {**overrides, **value}
        backend = make_backend(algorithm, **params)
        try:
            assignment = backend.fit(vectors)
        except Exception as exc:
            logger.warn("sweep_candidate_failed", algorithm=algorithm, params=value, error=str(exc))
            continue
        metrics = internal_metrics(vectors, assignment.labels)
        score = composite_score(metrics)
        sweep.append({"params": dict(backend.params), "metrics": metrics, "score": round(score, 4)})
        if best is None or score > best[0]:
            best = (score, assignment, dict(backend.params))

    if best is None:
        # Every grid candidate failed: fall back to a plain fit so callers still
        # get an assignment (errors from this fit propagate).
        backend = make_backend(algorithm, **overrides)
        assignment = backend.fit(vectors)
        return algorithm, assignment, dict(backend.params), sweep

    return algorithm, best[1], best[2], sweep

Explainability

c-TF-IDF keywords, medoids, cohesion/separation, and attribution.

Cluster explainability: keywords, representatives, stats, attribution.

Produces the human-facing description of each cluster:

  • distinctive keywords via class-based TF-IDF (c-TF-IDF),
  • representative example texts (the points nearest the centroid),
  • cohesion (mean cosine similarity to the centroid) and separation (cosine similarity to the nearest other centroid),
  • an inter-cluster similarity matrix,
  • and per-point "why this cluster" attribution.

compute_centroids(vectors, labels)

Mean (then renormalized) vector per non-noise cluster.

Source code in src/embedrag/cluster/explain.py
26
27
28
29
30
31
32
33
34
35
36
37
38
def compute_centroids(vectors: np.ndarray, labels: np.ndarray) -> dict[int, np.ndarray]:
    """Return ``{cluster_id: unit-length mean vector}`` for each non-noise cluster.

    Noise (label ``-1``) is skipped. A mean with zero norm is returned as-is
    instead of being divided by zero. Keys are plain ``int`` in ascending order.
    """
    centroids: dict[int, np.ndarray] = {}
    for raw_id in np.unique(labels):
        cid = int(raw_id)
        if cid == -1:
            continue
        members = vectors[labels == cid]
        if not members.shape[0]:
            continue
        mean_vec = members.mean(axis=0)
        length = np.linalg.norm(mean_vec)
        if length > 0:
            mean_vec = mean_vec / length
        centroids[cid] = mean_vec
    return centroids

ctfidf_keywords(texts, labels, top_n=10)

Distinctive keywords per cluster via class-based TF-IDF (c-TF-IDF).

Each cluster is treated as a single document. Char bigrams are used for CJK-heavy corpora (no word segmentation needed), words otherwise.

Source code in src/embedrag/cluster/explain.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def ctfidf_keywords(texts: list[str], labels: np.ndarray, top_n: int = 10) -> dict[int, list[str]]:
    """Distinctive keywords per cluster via class-based TF-IDF (c-TF-IDF).

    Each non-noise cluster is treated as a single document: its members'
    texts joined with spaces, in input order. Char bigrams are used when more
    than half of those cluster documents contain CJK (no word segmentation
    needed), word uni/bigrams with English stop words otherwise. Clusters get
    an empty keyword list when no vocabulary can be extracted.
    """
    from sklearn.feature_extraction.text import CountVectorizer

    cluster_ids = sorted(c for c in {int(x) for x in labels} if c != -1)
    if not cluster_ids:
        return {}

    # Group texts by cluster in a single pass (instead of rescanning every
    # text once per cluster), preserving input order within each cluster.
    grouped: dict[int, list[str]] = {}
    for text, lab in zip(texts, labels):
        cid = int(lab)
        if cid != -1:
            grouped.setdefault(cid, []).append(text)
    docs = [" ".join(grouped.get(cid, ())) for cid in cluster_ids]

    if not any(d.strip() for d in docs):
        return {cid: [] for cid in cluster_ids}

    cjk = sum(1 for d in docs if _CJK_RE.search(d)) > len(docs) / 2
    if cjk:
        vec = CountVectorizer(analyzer="char_wb", ngram_range=(2, 2), max_features=5000)
    else:
        vec = CountVectorizer(ngram_range=(1, 2), max_features=5000, stop_words="english")
    try:
        counts = vec.fit_transform(docs).toarray().astype(np.float64)
    except ValueError:
        # e.g. every token was a stop word -> empty vocabulary
        return {cid: [] for cid in cluster_ids}
    vocab = np.array(vec.get_feature_names_out())

    # c-TF-IDF: tf within class * log(1 + A / f_t), A = mean tokens per class
    tf = counts / np.maximum(counts.sum(axis=1, keepdims=True), 1)
    f_t = counts.sum(axis=0)
    avg_words = counts.sum(axis=1).mean()
    idf = np.log(1.0 + avg_words / np.maximum(f_t, 1e-9))
    weights = tf * idf

    result: dict[int, list[str]] = {}
    for row, cid in enumerate(cluster_ids):
        top_idx = np.argsort(-weights[row])[:top_n]
        kws = [str(vocab[j]).strip() for j in top_idx if weights[row, j] > 0]
        result[cid] = [k for k in kws if k]
    return result

explain(vectors, labels, probabilities, items, top_keywords=10, top_reps=5)

Build per-cluster info, per-member attribution, and similarity matrix.

Source code in src/embedrag/cluster/explain.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def explain(
    vectors: np.ndarray,
    labels: np.ndarray,
    probabilities: np.ndarray | None,
    items: list[Item],
    top_keywords: int = 10,
    top_reps: int = 5,
) -> tuple[list[ClusterInfo], list[ClusterMember], dict[int, np.ndarray], dict]:
    """Build per-cluster info, per-member attribution, and similarity matrix.

    Vectors are L2-normalized and centroids renormalized, so dot products are
    cosine similarities. For every non-noise cluster this records its
    c-TF-IDF keywords, the ``top_reps`` members closest to the centroid as
    representatives, cohesion (mean cosine of members to the centroid) and
    separation (highest cosine to any other centroid; 0.0 for a lone cluster).

    Returns:
        ``(clusters, members, centroids, sim_matrix)`` where ``sim_matrix`` is
        ``{"cluster_ids": [...], "matrix": [[...]]}`` holding centroid cosine
        similarities rounded to 4 decimals.
    """
    norm_vecs = l2_normalize(vectors)
    centroids = compute_centroids(norm_vecs, labels)
    texts = [it.text for it in items]
    keywords_by_cluster = ctfidf_keywords(texts, labels, top_keywords)

    cluster_ids = sorted(centroids.keys())
    cmatrix = np.stack([centroids[c] for c in cluster_ids]) if cluster_ids else np.empty((0, 0))

    # Pairwise centroid similarities: reported as the similarity matrix and
    # reused below for separation (avoids a Python-level k^2 dot-product loop).
    sims = cmatrix @ cmatrix.T if cluster_ids else np.empty((0, 0))
    sim_matrix: dict = {"cluster_ids": cluster_ids, "matrix": []}
    if cluster_ids:
        sim_matrix["matrix"] = np.round(sims, 4).tolist()

    clusters: list[ClusterInfo] = []
    for pos, cid in enumerate(cluster_ids):
        idx = np.where(labels == cid)[0]
        member_vecs = norm_vecs[idx]
        centroid = centroids[cid]
        sims_to_centroid = member_vecs @ centroid
        order = np.argsort(-sims_to_centroid)
        rep_local = order[:top_reps]
        rep_ids = [items[idx[i]].id for i in rep_local]
        rep_texts = [items[idx[i]].text for i in rep_local]

        # separation: highest cosine to any other centroid (own entry removed)
        separation = 0.0
        if len(cluster_ids) > 1:
            separation = float(np.delete(sims[pos], pos).max())

        clusters.append(
            ClusterInfo(
                cluster_id=int(cid),
                size=int(idx.shape[0]),
                keywords=keywords_by_cluster.get(cid, []),
                cohesion=round(float(sims_to_centroid.mean()), 4),
                separation=round(separation, 4),
                representatives=rep_ids,
                representative_texts=rep_texts,
                member_ids=[items[i].id for i in idx],
                centroid=centroid.astype(float).round(6).tolist(),
            )
        )

    members = _attribute_members(norm_vecs, labels, probabilities, items, centroids, cluster_ids, cmatrix)
    return clusters, members, centroids, sim_matrix

Persistence

Side-file storage for cluster runs.

Side-file persistence for cluster runs.

A cluster run is stored as a single JSON file under <data_dir>/cluster_runs/<run_id>.json. This deliberately avoids any snapshot DB schema change: runs can be created, listed, read, and deleted independently of the index lifecycle.

delete_run(data_dir, run_id)

Delete a run file; returns True if it existed.

Source code in src/embedrag/cluster/store.py
53
54
55
56
57
58
59
def delete_run(data_dir: str, run_id: str) -> bool:
    """Delete a run file; returns True if it existed.

    ``run_id`` must be a bare file stem. An id containing a path separator
    or an absolute path would resolve outside the runs directory, so it is
    treated as "no such run" and False is returned without touching the
    filesystem.
    """
    path = runs_dir(data_dir) / f"{run_id}.json"
    # Reject path traversal (e.g. "../x", "a/b", "/etc/x"): joining such an id
    # yields a different final path component than the expected file name.
    if path.name != f"{run_id}.json":
        return False
    try:
        path.unlink()
    except (FileNotFoundError, ValueError):
        # Missing file (also covers a concurrent delete) or an unrepresentable path.
        return False
    return True

list_runs(data_dir)

List run summaries (without the full member/projection payload).

Source code in src/embedrag/cluster/store.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def list_runs(data_dir: str) -> list[dict]:
    """List run summaries (without the full member/projection payload).

    Run files are visited in reverse lexicographic order of their file names.
    Files that cannot be read or decoded (I/O errors, invalid UTF-8, invalid
    JSON), or whose top-level JSON value is not an object, are skipped so one
    bad file cannot break the whole listing.

    Args:
        data_dir: Node data directory containing the cluster-runs folder.

    Returns:
        One summary dict per readable run. Each carries top-level run fields
        (with defaults for missing keys) and a condensed ``clusters`` list of
        ``cluster_id``/``label``/``size``/``keywords``.
    """
    d = runs_dir(data_dir)
    summaries: list[dict] = []
    for path in sorted(d.glob("*.json"), reverse=True):
        try:
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        except (ValueError, OSError):
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError (non-UTF-8 file contents).
            continue
        if not isinstance(data, dict):
            # Valid JSON but not a run object (e.g. a list); .get() would fail.
            continue
        raw_clusters = data.get("clusters", [])
        if not isinstance(raw_clusters, list):
            raw_clusters = []
        summaries.append(
            {
                "run_id": data.get("run_id", path.stem),
                "algorithm": data.get("algorithm", ""),
                "space": data.get("space", "text"),
                "created_at": data.get("created_at", ""),
                "n_items": data.get("n_items", 0),
                "n_clusters": data.get("n_clusters", 0),
                "noise_count": data.get("noise_count", 0),
                "source": data.get("source", ""),
                "metrics": data.get("metrics", {}),
                "clusters": [
                    {
                        "cluster_id": c.get("cluster_id"),
                        "label": c.get("label", ""),
                        "size": c.get("size", 0),
                        "keywords": c.get("keywords", []),
                    }
                    for c in raw_clusters
                    if isinstance(c, dict)
                ],
            }
        )
    return summaries

load_run(data_dir, run_id)

Load a single run by id; returns None if the run file is missing.

Source code in src/embedrag/cluster/store.py
44
45
46
47
48
49
50
def load_run(data_dir: str, run_id: str) -> ClusterResult | None:
    """Load a single run by id, or None if missing.

    Args:
        data_dir: Node data directory containing the cluster-runs folder.
        run_id: Id of the run; the file read is ``<run_id>.json``.

    Returns:
        The deserialized ``ClusterResult``, or None when no file exists for
        ``run_id``. Errors reading or parsing an existing file propagate.
    """
    path = runs_dir(data_dir) / f"{run_id}.json"
    # EAFP: checking exists() first races with a concurrent delete_run.
    try:
        f = open(path, encoding="utf-8")
    except FileNotFoundError:
        return None
    with f:
        return ClusterResult.from_dict(json.load(f))

make_run_id(prefix='run')

Generate a sortable, unique run id.

Source code in src/embedrag/cluster/store.py
30
31
32
def make_run_id(prefix: str = "run") -> str:
    """Generate a sortable, unique run id.

    The id has the form ``<prefix>-<epoch_ms>-<hex>``. The 13-digit
    millisecond timestamp keeps ids minted at different times in
    lexicographic order, which is the order ``list_runs`` relies on when it
    sorts run file names. The random hex suffix prevents two runs created
    within the same millisecond from sharing an id, and so from overwriting
    each other's file on save. Relative order of ids minted in the same
    millisecond is arbitrary.
    """
    import secrets  # local import: keeps the module import block unchanged

    millis = time.time_ns() // 1_000_000
    return f"{prefix}-{millis:013d}-{secrets.token_hex(4)}"

runs_dir(data_dir)

Return (and create) the cluster-runs directory under data_dir.

Source code in src/embedrag/cluster/store.py
23
24
25
26
27
def runs_dir(data_dir: str) -> Path:
    """Resolve the cluster-runs directory under ``data_dir``, creating it on demand.

    Missing parent directories are created as well; an already existing
    directory is left untouched.
    """
    target = Path(data_dir).joinpath(RUNS_DIRNAME)
    target.mkdir(exist_ok=True, parents=True)
    return target

save_run(data_dir, result)

Persist a cluster run; returns the file path.

Source code in src/embedrag/cluster/store.py
35
36
37
38
39
40
41
def save_run(data_dir: str, result: ClusterResult) -> str:
    """Persist a cluster run; returns the file path.

    The JSON is first written to a temporary sibling file, which is then
    moved over ``<run_id>.json`` with ``os.replace``. A crash or serialization
    error mid-write therefore never leaves a truncated run file behind, and
    never destroys a previously saved run with the same id. Readers only ever
    see a complete file. The temporary name ends in ``.tmp``, so the
    ``*.json`` glob used to list runs does not pick it up.

    Args:
        data_dir: Node data directory containing the cluster-runs folder.
        result: The run to serialize via ``result.to_dict()``.

    Returns:
        The final file path as a string.
    """
    import os  # local import: keeps the module import block unchanged

    path = runs_dir(data_dir) / f"{result.run_id}.json"
    # The pid in the name keeps concurrent writer processes from sharing a temp file.
    tmp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(result.to_dict(), f, ensure_ascii=False)
        os.replace(tmp_path, path)
    except BaseException:
        # Drop the partial temp file, then surface the original error.
        tmp_path.unlink(missing_ok=True)
        raise
    logger.info("cluster_run_saved", run_id=result.run_id, path=str(path), clusters=result.n_clusters)
    return str(path)