Coverage for intelligence_toolkit/AI/base_embedder.py: 30%

115 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4import asyncio 

5import json 

6import logging 

7from abc import ABC, abstractmethod 

8from typing import Any 

9 

10import numpy as np 

11import pyarrow as pa 

12from tqdm.asyncio import tqdm_asyncio 

13 

14from intelligence_toolkit.AI.base_batch_async import BaseBatchAsync 

15from intelligence_toolkit.AI.classes import VectorData 

16from intelligence_toolkit.AI.defaults import ( 

17 DEFAULT_CONCURRENT_COROUTINES, 

18 DEFAULT_LLM_MAX_TOKENS, 

19 EMBEDDING_BATCHES_NUMBER, 

20) 

21from intelligence_toolkit.AI.utils import get_token_count, hash_text 

22from intelligence_toolkit.AI.vector_store import VectorStore 

23from intelligence_toolkit.helpers.constants import CACHE_PATH 

24from intelligence_toolkit.helpers.decorators import retry_with_backoff 

25from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

26 

logger = logging.getLogger(name=__name__)

# Arrow schema handed to ``VectorStore``: the text hash, the raw text, the
# embedding vector (list of float64) and a JSON-encoded string of extra details.
_SCHEMA_COLUMNS = [
    ("hash", pa.string()),
    ("text", pa.string()),
    ("vector", pa.list_(pa.float64())),
    ("additional_details", pa.string()),
]
schema = pa.schema(_SCHEMA_COLUMNS)

37 

38 

class BaseEmbedder(ABC, BaseBatchAsync):
    """Abstract embedder with batching, bounded concurrency and optional caching.

    Subclasses supply ``_generate_embedding`` (sync) and
    ``_generate_embedding_async`` (async). This base class hashes input text,
    truncates overly long text on a best-effort basis, limits concurrent async
    requests with a semaphore, reports progress, and can look up / persist
    embeddings in a ``VectorStore``. Caching is currently forced off inside
    both ``embed_store_one`` and ``embed_store_many``.
    """

    # Seconds a single async embedding request may take before it is abandoned.
    # Subclasses may override this class attribute.
    EMBEDDING_TIMEOUT_SECONDS = 90

    def __init__(
        self,
        db_name: str = "embeddings",
        db_path=CACHE_PATH,
        max_tokens=DEFAULT_LLM_MAX_TOKENS,
        concurrent_coroutines=DEFAULT_CONCURRENT_COROUTINES,
        check_token_count=True,
    ) -> None:
        """Create the embedder.

        Args:
            db_name: Name of the ``VectorStore`` used for cached embeddings.
            db_path: Location passed to ``VectorStore``.
            max_tokens: Token count above which input text is truncated.
            concurrent_coroutines: Maximum number of ``embed_one_async`` calls
                allowed inside the semaphore at once.
            check_token_count: Whether to count tokens and truncate long text.
        """
        # NOTE(review): BaseBatchAsync.__init__ is not called here; confirm it
        # needs no initialisation of its own.
        self.vector_store = VectorStore(db_name, db_path, schema)
        self.max_tokens = max_tokens
        self.semaphore = asyncio.Semaphore(concurrent_coroutines)
        self.check_token_count = check_token_count

    def _truncate_to_max_tokens(self, text: str) -> str:
        """Return ``text``, shortened if it exceeds ``self.max_tokens`` tokens.

        Does nothing when ``check_token_count`` is disabled. Token counting is
        best-effort: if ``get_token_count`` raises (the original code notes
        this happens when running locally), the text is returned unchanged.
        """
        if not self.check_token_count:
            return text
        try:
            tokens = get_token_count(text)
        except Exception:
            logger.debug("Token count unavailable; skipping truncation", exc_info=True)
            return text
        if tokens > self.max_tokens:
            # NOTE(review): this slices by characters, not tokens, so the result
            # is usually much shorter than max_tokens tokens.
            text = text[: self.max_tokens]
            logger.info("Truncated text to max tokens")
        return text

    @retry_with_backoff()
    async def embed_one_async(
        self,
        data: VectorData,
        has_callback=False,
    ) -> tuple[Any, VectorData]:
        """Embed one record in place and return ``(embedding, data)``.

        ``data`` is mutated: ``hash`` is filled in when missing or falsy,
        ``text`` may be truncated, ``additional_details`` is replaced by its
        JSON encoding (``{}`` when absent) and ``vector`` receives the embedding.

        Args:
            data: Record with at least a ``text`` key.
            has_callback: When truthy, ``self.progress_callback()`` is invoked
                once the embedding has been produced.

        Returns:
            A tuple of the embedding and the (mutated) ``data`` record.

        Raises:
            Exception: If the request exceeds ``EMBEDDING_TIMEOUT_SECONDS`` or
                embedding generation fails; the original error is chained.
        """
        async with self.semaphore:
            if not data.get("hash"):
                data["hash"] = hash_text(data["text"])
            data["text"] = self._truncate_to_max_tokens(data["text"])
            try:
                embedding = await asyncio.wait_for(
                    self._generate_embedding_async(data["text"]),
                    timeout=self.EMBEDDING_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError as e:
                msg = f"Timeout in embedding generation. {e} Please try again."
                raise Exception(msg) from e
            except Exception as e:
                msg = f"Problem in embedding generation. {e}"
                raise Exception(msg) from e
            # NOTE(review): if the same dict is embedded twice (e.g. when a
            # caller is retried), an already-encoded string is encoded again.
            data["additional_details"] = json.dumps(data.get("additional_details", {}))
            data["vector"] = embedding

        if has_callback:
            self.progress_callback()
        return embedding, data

    @retry_with_backoff()
    def embed_store_one(
        self, text: str, cache_data=True, additional_detail: Any = "{}"
    ) -> Any | list[float]:
        """Synchronously embed a single text and return its embedding.

        Args:
            text: Text to embed; truncated first if it is too long.
            cache_data: Requested caching flag. Currently ignored: caching is
                forced off below.
            additional_detail: Extra details JSON-encoded into the cached row.
                NOTE(review): the default ``"{}"`` is a string, so it is stored
                as the JSON string ``'"{}"'`` rather than an empty object.

        Returns:
            The embedding produced by ``_generate_embedding`` (or a cached one).

        Raises:
            Exception: If embedding generation or caching fails; the original
                error is chained.
        """
        cache_data = False  # disable for now
        text_hashed = hash_text(text)
        if cache_data:
            existing_embedding = self.vector_store.search_by_column(text_hashed, "hash")
            if len(existing_embedding) > 0:
                return existing_embedding.get("vector")[0]

        text = self._truncate_to_max_tokens(text)

        try:
            embedding = self._generate_embedding(text)
            if cache_data:
                self.vector_store.save(
                    [
                        {
                            "hash": text_hashed,
                            "text": text,
                            "vector": embedding,
                            "additional_details": json.dumps(additional_detail),
                        }
                    ]
                )
        except Exception as e:
            msg = f"Problem in embedding generation. {e}"
            raise Exception(msg) from e
        return embedding

    @retry_with_backoff()
    async def embed_store_many(
        self,
        data: list[VectorData],
        callbacks: list[ProgressBatchCallback] | None = None,
        cache_data=True,
    ) -> list[VectorData]:
        """Embed records in batches and return the embedded records.

        Records are processed in slices of ``EMBEDDING_BATCHES_NUMBER``. When
        caching is enabled, rows already in the vector store are reused and
        appended, sorted by text, ahead of that batch's freshly embedded rows.
        The output order therefore need not match the input order. Caching is
        currently forced off regardless of ``cache_data``.

        Args:
            data: Records to embed; each needs a ``text`` key and is mutated in
                place by ``embed_one_async``.
            callbacks: Optional progress callbacks. When given,
                ``self.track_progress`` (presumably from ``BaseBatchAsync``)
                runs alongside the embedding tasks.
            cache_data: Requested caching flag (currently ignored).

        Returns:
            The embedded records, each carrying a ``vector``.
        """
        cache_data = False  # disable for now
        self.total_tasks = len(data)
        final_embeddings = []
        loaded_texts = []
        loaded_text_set = set()  # O(1) membership test for cached texts
        all_data = []

        for i in range(0, len(data), EMBEDDING_BATCHES_NUMBER):
            batch_data = data[i : i + EMBEDDING_BATCHES_NUMBER]

            if cache_data:
                hash_all_texts = [hash_text(item["text"]) for item in batch_data]
                existing = self.vector_store.search_by_column(hash_all_texts, "hash")

                if len(existing.get("vector")) > 0:
                    # Rows are positional: hash, text, vector, additional_details.
                    for row in existing.sort_values("text").to_numpy():
                        all_data.append(
                            {
                                "hash": row[0],
                                "text": row[1],
                                "vector": row[2],
                                "additional_details": row[3] if len(row) > 3 else "{}",
                            }
                        )
                        loaded_texts.append(row[1])
                        loaded_text_set.add(row[1])
                        final_embeddings.append(row[2])

            new_items = [
                item for item in batch_data if item["text"] not in loaded_text_set
            ]
            if new_items:
                tasks = [
                    asyncio.create_task(
                        self.embed_one_async(item, has_callback=bool(callbacks))
                    )
                    for item in new_items
                ]
                progress_task = (
                    asyncio.create_task(self.track_progress(tasks, callbacks))
                    if callbacks
                    else None
                )
                try:
                    result = await tqdm_asyncio.gather(*tasks)
                except BaseException:
                    # One failed task does not stop its siblings; cancel them so a
                    # retry is not competing with orphaned requests for the semaphore.
                    for task in tasks:
                        task.cancel()
                    if progress_task is not None:
                        progress_task.cancel()
                    raise
                if progress_task is not None:
                    await progress_task

                embeddings = [embedding for embedding, _ in result]
                new_data = [record for _, record in result]
                all_data.extend(new_data)
                final_embeddings.extend(embeddings)
                if cache_data:
                    self.vector_store.save(new_data)

        # NOTE(review): confirm this refresh belongs after the loop and is
        # wanted even when cache_data is False.
        self.vector_store.update_duckdb_data()

        new_count = len(final_embeddings) - len(loaded_texts)
        print(f"Got {len(loaded_texts)} existing texts")
        logger.info("Got %s existing texts", len(loaded_texts))
        print(f"Got {new_count} new texts")
        logger.info("Got %s new texts", new_count)

        return all_data

    @abstractmethod
    def _generate_embedding(self, text: str) -> list[float]:
        """Generate an embedding for a single text synchronously.

        Implemented by subclasses; called by ``embed_store_one``.
        """

    @abstractmethod
    async def _generate_embedding_async(self, text: str) -> list:
        """Generate an embedding for a single text asynchronously.

        Implemented by subclasses; awaited (under a timeout) by
        ``embed_one_async``.
        """