Coverage for intelligence_toolkit/AI/base_embedder.py: 30%
115 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4import asyncio
5import json
6import logging
7from abc import ABC, abstractmethod
8from typing import Any
10import numpy as np
11import pyarrow as pa
12from tqdm.asyncio import tqdm_asyncio
14from intelligence_toolkit.AI.base_batch_async import BaseBatchAsync
15from intelligence_toolkit.AI.classes import VectorData
16from intelligence_toolkit.AI.defaults import (
17 DEFAULT_CONCURRENT_COROUTINES,
18 DEFAULT_LLM_MAX_TOKENS,
19 EMBEDDING_BATCHES_NUMBER,
20)
21from intelligence_toolkit.AI.utils import get_token_count, hash_text
22from intelligence_toolkit.AI.vector_store import VectorStore
23from intelligence_toolkit.helpers.constants import CACHE_PATH
24from intelligence_toolkit.helpers.decorators import retry_with_backoff
25from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
logger = logging.getLogger(__name__)

# Arrow schema handed to VectorStore by BaseEmbedder.__init__: one record per
# embedded text (its hash, the text itself, the float64 vector, and a string
# column holding serialized extra details).
_SCHEMA_COLUMNS = [
    ("hash", pa.string()),
    ("text", pa.string()),
    ("vector", pa.list_(pa.float64())),
    ("additional_details", pa.string()),
]
schema = pa.schema(_SCHEMA_COLUMNS)
class BaseEmbedder(ABC, BaseBatchAsync):
    """Abstract base for text embedders.

    Subclasses implement :meth:`_generate_embedding` (sync) and
    :meth:`_generate_embedding_async` (async). This class adds:

    * optional truncation of over-long texts (``max_tokens``),
    * a semaphore bounding how many async embedding calls run at once,
    * batched embedding of many records, with a VectorStore cache that is
      currently forced off inside the ``embed_store_*`` methods.
    """

    def __init__(
        self,
        db_name: str = "embeddings",
        db_path=CACHE_PATH,
        max_tokens=DEFAULT_LLM_MAX_TOKENS,
        concurrent_coroutines=DEFAULT_CONCURRENT_COROUTINES,
        check_token_count=True,
    ) -> None:
        """Create the embedder.

        Args:
            db_name: Name passed to ``VectorStore`` for the embedding cache.
            db_path: Location passed to ``VectorStore`` for the embedding cache.
            max_tokens: Token budget per text; texts whose token count exceeds
                it are cut to ``max_tokens`` characters.
            concurrent_coroutines: Maximum number of ``embed_one_async`` calls
                allowed inside the semaphore at the same time.
            check_token_count: When False, token counting and truncation are
                skipped entirely.
        """
        self.vector_store = VectorStore(db_name, db_path, schema)
        self.max_tokens = max_tokens
        self.semaphore = asyncio.Semaphore(concurrent_coroutines)
        self.check_token_count = check_token_count

    def _truncate_to_max_tokens(self, text: str) -> str:
        """Return *text*, cut to ``self.max_tokens`` characters if it is over budget.

        Token counting is best-effort: if ``get_token_count`` raises (the
        original code notes it errors when running locally), the text is
        returned unchanged.
        """
        if not self.check_token_count:
            return text
        try:
            tokens = get_token_count(text)
        except Exception:
            logger.debug("Token count failed; leaving text unchanged", exc_info=True)
            return text
        if tokens > self.max_tokens:
            logger.info("Truncated text to max tokens")
            # NOTE: slices by characters, not tokens.
            return text[: self.max_tokens]
        return text

    @retry_with_backoff()
    async def embed_one_async(
        self,
        data: VectorData,
        has_callback=False,
    ) -> tuple[list[float], VectorData]:
        """Embed one record asynchronously, updating it in place.

        ``data`` is mutated: a missing/empty ``"hash"`` is computed from
        ``"text"``, ``"text"`` may be truncated, ``"additional_details"`` is
        JSON-encoded (``{}`` when absent) and ``"vector"`` receives the embedding.

        Args:
            data: Record containing at least a ``"text"`` key.
            has_callback: When truthy, ``self.progress_callback()`` is called
                after a successful embedding.

        Returns:
            ``(embedding, data)``.

        Raises:
            Exception: If the embedding call exceeds 90 seconds or fails.
        """
        async with self.semaphore:
            if not data.get("hash"):
                data["hash"] = hash_text(data["text"])
            data["text"] = self._truncate_to_max_tokens(data["text"])
            try:
                embedding = await asyncio.wait_for(
                    self._generate_embedding_async(data["text"]), timeout=90
                )
            except asyncio.TimeoutError as e:
                msg = f"Timeout in embedding generation. {e} Please try again."
                raise Exception(msg) from e
            except Exception as e:
                msg = f"Problem in embedding generation. {e}"
                raise Exception(msg) from e
            data["additional_details"] = json.dumps(data.get("additional_details", {}))
            data["vector"] = embedding

        if has_callback:
            self.progress_callback()
        return embedding, data

    @retry_with_backoff()
    def embed_store_one(
        self, text: str, cache_data=True, additional_detail: Any = "{}"
    ) -> Any | list[float]:
        """Embed a single text synchronously.

        Args:
            text: Text to embed; truncated first if it exceeds the token budget.
            cache_data: Currently ignored; caching is disabled and the value is
                overridden to False.
            additional_detail: JSON-serialized into the cached record (only used
                when caching is enabled).

        Returns:
            The embedding from ``_generate_embedding`` (or the cached vector when
            caching is enabled and the text's hash is found).

        Raises:
            Exception: Wrapping any failure of embedding generation or of saving
                to the cache.
        """
        cache_data = False  # caching disabled for now; the argument is ignored
        text_hashed = hash_text(text)
        if cache_data:
            existing_embedding = self.vector_store.search_by_column(text_hashed, "hash")
            if len(existing_embedding) > 0:
                return existing_embedding.get("vector")[0]

        text = self._truncate_to_max_tokens(text)

        try:
            embedding = self._generate_embedding(text)
            if cache_data:
                # NOTE(review): json.dumps on the default "{}" string stores '"{}"'
                # (double-encoded JSON); confirm intent before re-enabling caching.
                self.vector_store.save(
                    [
                        {
                            "hash": text_hashed,
                            "text": text,
                            "vector": embedding,
                            "additional_details": json.dumps(additional_detail),
                        }
                    ]
                )
        except Exception as e:
            msg = f"Problem in embedding generation. {e}"
            raise Exception(msg) from e
        return embedding

    def _load_cached_batch(self, batch_data: list[VectorData]) -> list[dict[str, Any]]:
        """Return cached records matching the hashes of *batch_data*, sorted by text."""
        hash_all_texts = [hash_text(item["text"]) for item in batch_data]
        existing = self.vector_store.search_by_column(hash_all_texts, "hash")
        if len(existing.get("vector")) == 0:
            return []
        # Column positions assume the store follows the module-level schema:
        # hash, text, vector, additional_details.
        return [
            {
                "hash": row[0],
                "text": row[1],
                "vector": row[2],
                "additional_details": row[3] if len(row) > 3 else "{}",
            }
            for row in existing.sort_values("text").to_numpy()
        ]

    async def _embed_batch(
        self,
        items: list[VectorData],
        callbacks: list[ProgressBatchCallback] | None,
    ) -> list[VectorData]:
        """Embed *items* concurrently and return the updated records in order."""
        tasks = [
            asyncio.create_task(self.embed_one_async(item, bool(callbacks)))
            for item in items
        ]
        progress_task = (
            asyncio.create_task(self.track_progress(tasks, callbacks))
            if callbacks
            else None
        )
        results = await tqdm_asyncio.gather(*tasks)
        if progress_task is not None:
            await progress_task
        return [record for _embedding, record in results]

    @retry_with_backoff()
    async def embed_store_many(
        self,
        data: list[VectorData],
        callbacks: list[ProgressBatchCallback] | None = None,
        cache_data=True,
    ) -> list[VectorData]:
        """Embed *data* in batches of ``EMBEDDING_BATCHES_NUMBER`` records.

        Newly embedded records are the input dicts, updated in place by
        :meth:`embed_one_async`. ``self.total_tasks`` is set to ``len(data)``
        (presumably read by the progress tracking in ``BaseBatchAsync`` —
        verify there).

        Args:
            data: Records to embed, each with at least a ``"text"`` key.
            callbacks: Optional progress callbacks; when given, progress is
                tracked via ``self.track_progress``.
            cache_data: Currently ignored; caching is disabled and the value is
                overridden to False.

        Returns:
            All records, batch by batch (cached records first within a batch
            when caching is enabled, then newly embedded ones).
        """
        cache_data = False  # caching disabled for now; the argument is ignored
        self.total_tasks = len(data)
        all_data: list[VectorData] = []
        loaded_texts: list[str] = []  # texts served from the cache (duplicates counted)
        loaded_text_set: set[str] = set()  # same texts, for O(1) membership tests
        new_count = 0

        for start in range(0, len(data), EMBEDDING_BATCHES_NUMBER):
            batch_data = data[start : start + EMBEDDING_BATCHES_NUMBER]

            if cache_data:
                for record in self._load_cached_batch(batch_data):
                    all_data.append(record)
                    loaded_texts.append(record["text"])
                    loaded_text_set.add(record["text"])

            new_items = [
                item for item in batch_data if item["text"] not in loaded_text_set
            ]
            if not new_items:
                continue

            new_data = await self._embed_batch(new_items, callbacks)
            all_data.extend(new_data)
            new_count += len(new_data)
            if cache_data:
                self.vector_store.save(new_data)
                self.vector_store.update_duckdb_data()

        logger.info("Got %s existing texts", len(loaded_texts))
        logger.info("Got %s new texts", new_count)
        return all_data

    @abstractmethod
    def _generate_embedding(self, text: str) -> list[float]:
        """Generate an embedding for a single text (synchronous)."""

    @abstractmethod
    async def _generate_embedding_async(self, text: str) -> list:
        """Generate an embedding for a single text (asynchronous)."""