Coverage for intelligence_toolkit/AI/local_embedder.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5import asyncio 

6from typing import Any 

7 

8from sentence_transformers import SentenceTransformer 

9 

10from intelligence_toolkit.AI.base_embedder import BaseEmbedder 

11from intelligence_toolkit.AI.defaults import ( 

12 DEFAULT_CONCURRENT_COROUTINES, 

13 DEFAULT_LLM_MAX_TOKENS, 

14 DEFAULT_LOCAL_EMBEDDING_MODEL, 

15) 

16from intelligence_toolkit.helpers.constants import CACHE_PATH 

17 

18 

19 

class LocalEmbedder(BaseEmbedder):
    """Embedder that computes embeddings in-process with a SentenceTransformer.

    Cache and concurrency settings are handled by ``BaseEmbedder``; this class
    only loads the local model and implements the embedding hooks.
    """

    def __init__(
        self,
        db_name: str = "embeddings",
        db_path=CACHE_PATH,
        max_tokens=DEFAULT_LLM_MAX_TOKENS,
        concurrent_coroutines: int | None = DEFAULT_CONCURRENT_COROUTINES + 100,
        model: str | None = DEFAULT_LOCAL_EMBEDDING_MODEL,
    ):
        """Configure the base embedder and load the local model.

        Args:
            db_name: Embeddings cache name, forwarded to ``BaseEmbedder``.
            db_path: Embeddings cache location, forwarded to ``BaseEmbedder``.
            max_tokens: Token limit, forwarded to ``BaseEmbedder``.
            concurrent_coroutines: Concurrency setting, forwarded to
                ``BaseEmbedder``.
            model: Name or path passed to ``SentenceTransformer``. ``None``
                falls back to ``DEFAULT_LOCAL_EMBEDDING_MODEL``.

        Raises:
            Exception: If the model cannot be loaded. The underlying loader
                error is attached as ``__cause__``.
        """
        # The fifth positional argument is passed as False; its meaning is
        # defined by BaseEmbedder (not visible in this file).
        super().__init__(db_name, db_path, max_tokens, concurrent_coroutines, False)
        # Use default model if None is passed
        if model is None:
            model = DEFAULT_LOCAL_EMBEDDING_MODEL
        try:
            self.local_client = SentenceTransformer(model)
        except Exception as e:
            msg = (
                f"Failed to load local embedding model '{model}': {e}. "
                "Please ensure the model is available or check your internet "
                "connection for download."
            )
            # Chain explicitly so the original loader traceback is reported
            # as the direct cause rather than as an incidental context.
            raise Exception(msg) from e

    def _generate_embedding(self, text: str | list[str]) -> list | Any:
        """Encode ``text`` with the local model and return plain Python lists.

        ``encode`` returns an array; ``tolist()`` converts it into (possibly
        nested) Python lists so callers do not depend on the array type.
        """
        return self.local_client.encode(text).tolist()

    async def _generate_embedding_async(self, text: str) -> list | Any:
        """Async entry point that delegates to :meth:`_generate_embedding`.

        ``asyncio.sleep(0)`` yields to the event loop once before encoding.
        The encoding itself then runs synchronously on the event-loop thread
        and blocks it until it completes.
        """
        # NOTE(review): asyncio.to_thread would keep the loop responsive, but
        # concurrent encode() calls on one shared model/tokenizer are not
        # guaranteed thread-safe — confirm before offloading.
        await asyncio.sleep(0)

        return self._generate_embedding(text)