Coverage for intelligence_toolkit/AI/local_embedder.py: 100%
20 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5import asyncio
6from typing import Any
8from sentence_transformers import SentenceTransformer
10from intelligence_toolkit.AI.base_embedder import BaseEmbedder
11from intelligence_toolkit.AI.defaults import (
12 DEFAULT_CONCURRENT_COROUTINES,
13 DEFAULT_LLM_MAX_TOKENS,
14 DEFAULT_LOCAL_EMBEDDING_MODEL,
15)
16from intelligence_toolkit.helpers.constants import CACHE_PATH
class LocalEmbedder(BaseEmbedder):
    """Embedder backed by a locally loaded ``SentenceTransformer`` model."""

    def __init__(
        self,
        db_name: str = "embeddings",
        db_path=CACHE_PATH,
        max_tokens=DEFAULT_LLM_MAX_TOKENS,
        concurrent_coroutines: int | None = DEFAULT_CONCURRENT_COROUTINES + 100,
        model: str | None = DEFAULT_LOCAL_EMBEDDING_MODEL,
    ):
        """Initialise the base embedder and load the local model.

        Args:
            db_name: Forwarded to ``BaseEmbedder.__init__``.
            db_path: Forwarded to ``BaseEmbedder.__init__``.
            max_tokens: Forwarded to ``BaseEmbedder.__init__``.
            concurrent_coroutines: Forwarded to ``BaseEmbedder.__init__``;
                defaults to ``DEFAULT_CONCURRENT_COROUTINES + 100``.
            model: Name or path handed to ``SentenceTransformer``. ``None``
                falls back to ``DEFAULT_LOCAL_EMBEDDING_MODEL``.

        Raises:
            Exception: If ``SentenceTransformer`` cannot be constructed for
                ``model``. The original error is attached as ``__cause__``.
        """
        # NOTE(review): the trailing positional ``False`` is interpreted by
        # BaseEmbedder.__init__, which is not visible here -- confirm meaning.
        super().__init__(db_name, db_path, max_tokens, concurrent_coroutines, False)

        # Use default model if None is passed
        if model is None:
            model = DEFAULT_LOCAL_EMBEDDING_MODEL

        try:
            self.local_client = SentenceTransformer(model)
        except Exception as e:
            # Chain explicitly so the underlying load failure stays attached
            # as the direct cause of the user-facing error.
            raise Exception(
                f"Failed to load local embedding model '{model}': {e}. "
                "Please ensure the model is available or check your "
                "internet connection for download."
            ) from e

    def _generate_embedding(self, text: str | list[str]) -> list | Any:
        """Encode ``text`` with the local model and return ``.tolist()`` of the result."""
        return self.local_client.encode(text).tolist()

    async def _generate_embedding_async(self, text: str) -> list | Any:
        """Async wrapper around :meth:`_generate_embedding`.

        Yields to the event loop once via ``asyncio.sleep(0)`` and then runs
        the synchronous encode call directly in this coroutine.
        """
        await asyncio.sleep(0)
        return self._generate_embedding(text)