Coverage for intelligence_toolkit/query_text_data/relevance_assessor.py: 0%

142 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3 

4import asyncio 

5from json import loads 

6from collections import defaultdict 

7import numpy as np 

8import scipy.spatial.distance 

9import tiktoken 

10import intelligence_toolkit.AI.utils as utils 

11import intelligence_toolkit.query_text_data.helper_functions as helper_functions 

12import intelligence_toolkit.query_text_data.prompts as prompts 

13 

async def assess_relevance(
    ai_configuration,
    search_label,
    search_cids,
    cid_to_text,
    query,
    logit_bias,
    relevance_test_budget,
    num_adjacent,
    relevance_test_batch_size,
    test_history,
    progress_callback,
    chunk_callback,
    commentary
):
    """Ask the LLM whether each chunk in ``search_cids`` is relevant to ``query``.

    Chunks are tested in consecutive batches of ``relevance_test_batch_size``.
    Each chunk gets a single-token completion (``max_tokens=1``) with
    ``logit_bias`` applied. Verdicts are recorded by
    ``process_relevance_responses``, which appends to ``test_history`` in place
    and fires the callbacks. Testing stops at the first batch that yields no
    newly relevant chunk, or when the test budget is used up.

    Args:
        ai_configuration: Passed through to ``utils.map_generate_text``.
        search_label: Label stored with each new ``test_history`` entry.
        search_cids: Ordered, sliceable sequence of chunk ids to test.
        cid_to_text: Mapping from chunk id to chunk text.
        query: The user query inserted into the relevance prompt.
        logit_bias: Token-id -> bias mapping for the completion call.
        relevance_test_budget: Upper bound on the number of entries
            ``test_history`` may reach.
        num_adjacent: Number of adjacent chunks counted against the budget
            when deciding whether a batch must be trimmed.
        relevance_test_batch_size: Number of chunks sent per LLM round-trip.
        test_history: List of ``(label, cid, response)`` tuples; mutated.
        progress_callback, chunk_callback, commentary: Optional listeners
            forwarded to ``process_relevance_responses``.

    Returns:
        True if the last batch tested produced at least one newly relevant
        chunk; False otherwise, including when nothing could be tested.
    """
    is_relevant = False
    for start in range(0, len(search_cids), relevance_test_batch_size):
        cid_batch = search_cids[start : start + relevance_test_batch_size]
        if len(test_history) + len(cid_batch) + num_adjacent > relevance_test_budget:
            # Trim the batch to what the budget still allows. Clamp at zero:
            # a negative slice bound would keep most of the batch, not none.
            remaining = max(relevance_test_budget - len(test_history), 0)
            cid_batch = cid_batch[:remaining]
        if not cid_batch:
            # Budget exhausted: skip an empty LLM round-trip.
            is_relevant = False
            break
        # Prompts are built lazily per batch, so early termination skips the rest.
        mapped_messages = [
            utils.prepare_messages(
                prompts.chunk_relevance_prompt,
                {"chunk": cid_to_text[cid], "query": query},
            )
            for cid in cid_batch
        ]
        mapped_responses = await utils.map_generate_text(
            ai_configuration, mapped_messages, logit_bias=logit_bias, max_tokens=1
        )
        num_relevant = process_relevance_responses(
            search_label,
            cid_batch,
            cid_to_text,
            mapped_responses,
            test_history,
            progress_callback,
            chunk_callback,
            commentary,
        )
        is_relevant = num_relevant > 0
        if not is_relevant:
            # No relevant chunks found in this batch; terminate early.
            break
    return is_relevant

70 

71 

def process_relevance_responses(
    search_label,
    search_cids,
    cid_to_text,
    mapped_responses,
    test_history,
    progress_callback,
    chunk_callback,
    commentary,
):
    """Record new relevance verdicts in ``test_history`` and notify listeners.

    Responses are paired with chunk ids positionally (``zip`` stops at the
    shorter sequence). Each chunk id not already present in ``test_history``
    is appended as ``(search_label, cid, response)``. A chunk id that is
    already recorded, including one repeated earlier in this same batch, is
    skipped.

    After recording:
      * ``progress_callback(test_history)`` is called if provided;
      * ``chunk_callback`` receives the texts of every chunk id in the full
        history whose response is ``"Yes"``, if provided;
      * ``commentary.add_chunks`` receives ``{cid: text}`` for the chunks
        newly marked ``"Yes"`` in this call, if any and if provided.

    Returns:
        The number of chunk ids newly recorded with a ``"Yes"`` response.
    """
    # Set of already-tested ids, kept in sync with appends, avoids rebuilding
    # a list of all history ids for every response.
    already_tested = {entry[1] for entry in test_history}
    tested_relevant = []
    for response, cid in zip(mapped_responses, search_cids):
        if cid in already_tested:
            continue
        already_tested.add(cid)
        test_history.append((search_label, cid, response))
        if response == "Yes":
            tested_relevant.append(cid)

    if progress_callback is not None:
        progress_callback(test_history)
    if chunk_callback is not None:
        relevant_list = [entry[1] for entry in test_history if entry[2] == "Yes"]
        chunk_callback([cid_to_text[cid] for cid in relevant_list])

    if commentary is not None and tested_relevant:
        commentary.add_chunks({cid: cid_to_text[cid] for cid in tested_relevant})
    return len(tested_relevant)

98 

99 

def _build_community_levels(processed_chunks, semantic_search_cids, community_ranking_chunks):
    """Order concept communities per hierarchy level and bucket chunks into them.

    For each level ``0..max_level`` of ``processed_chunks.hierarchical_communities``:

    * a chunk is a candidate member of every community that one of its
      concepts (``processed_chunks.cid_to_concepts``) belongs to at that
      level. If a concept has no community at that level, its community from
      the previous level is used.
    * communities are ordered by the mean semantic rank of their best
      ``community_ranking_chunks`` candidate chunks (lower is better).
      Communities with no ranked chunk sort last rather than getting a NaN key.
    * each chunk in ``semantic_search_cids`` is assigned to its earliest-ordered
      candidate community, so each community's list stays in semantic order.

    Level -1 is a synthetic root community ``"1"`` holding every chunk in
    ``semantic_search_cids``.

    Returns:
        ``(level_to_community_sequence, level_to_community_to_cids,
        community_to_parent)``.
    """
    cid_rank = {cid: rank for rank, cid in enumerate(semantic_search_cids)}

    concept_to_level_to_community = defaultdict(dict)
    community_to_parent = {}
    for hc in processed_chunks.hierarchical_communities:
        label = processed_chunks.community_to_label[hc.cluster]
        concept_to_level_to_community[hc.node][hc.level] = label
        if hc.parent_cluster is not None:
            community_to_parent[label] = processed_chunks.community_to_label[
                hc.parent_cluster
            ]
    # default=-1: with no communities at all, only the root level -1 exists.
    max_level = max(
        (hc.level for hc in processed_chunks.hierarchical_communities), default=-1
    )

    level_to_community_sequence = {}
    level_to_community_to_cids = defaultdict(lambda: defaultdict(list))
    for level in range(0, max_level + 1):
        community_to_candidate_cids = defaultdict(set)
        cid_to_communities = defaultdict(set)
        for cid, concepts in processed_chunks.cid_to_concepts.items():
            for concept in concepts:
                level_to_community = concept_to_level_to_community.get(concept)
                if level_to_community is None:
                    continue
                if level in level_to_community:
                    community = level_to_community[level]
                elif level - 1 in level_to_community:
                    # Use the community from the previous level.
                    community = level_to_community[level - 1]
                else:
                    continue
                cid_to_communities[cid].add(community)
                community_to_candidate_cids[community].add(cid)

        community_mean_rank = []
        for community, cids in community_to_candidate_cids.items():
            ranks = sorted(cid_rank[c] for c in cids if c in cid_rank)[
                :community_ranking_chunks
            ]
            mean_rank = float(np.mean(ranks)) if ranks else float("inf")
            community_mean_rank.append((community, mean_rank))
        community_sequence = [
            community
            for community, _ in sorted(community_mean_rank, key=lambda x: x[1])
        ]
        level_to_community_sequence[level] = community_sequence

        position = {community: i for i, community in enumerate(community_sequence)}
        community_to_cids = level_to_community_to_cids[level]
        # Iterating in semantic order keeps every community's list sorted by rank.
        for cid in semantic_search_cids:
            chunk_communities = cid_to_communities.get(cid)
            if chunk_communities:
                assigned_community = min(chunk_communities, key=position.__getitem__)
                community_to_cids[assigned_community].append(cid)

    # Set level -1 as everything in the dataset.
    level_to_community_sequence[-1] = ["1"]
    level_to_community_to_cids[-1]["1"] = semantic_search_cids
    return level_to_community_sequence, level_to_community_to_cids, community_to_parent


async def detect_relevant_chunks(
    ai_configuration,
    query,
    processed_chunks,
    cid_to_vector,
    embedder,
    embedding_cache,
    chunk_search_config,
    chunk_progress_callback=None,
    chunk_callback=None,
    commentary=None
):
    """Find chunks relevant to ``query`` with a budgeted, community-guided LLM search.

    1. Rank all chunks by cosine distance between their vector
       (``cid_to_vector``) and the embedded query.
    2. Walk the concept-community hierarchy from the root (level -1) downward.
       At each level, test up to ``community_relevance_tests`` unseen chunks
       per community with yes/no LLM relevance checks.
    3. Below the root, a community with no relevant result is eliminated, and
       so are the communities whose parent was eliminated. After
       ``irrelevant_community_restart`` consecutive misses, the current pass
       ends early.
    4. Stop when the test budget is reached, or when a non-root pass finds
       nothing relevant.
    5. Finally test the chunks that ``helper_functions.test_history_elements``
       reports as adjacent.

    Returns:
        ``(relevant_cids_sorted, helper_functions.get_test_progress(test_history))``.
    """
    test_history = []
    budget = chunk_search_config.relevance_test_budget
    all_units = sorted(cid_to_vector.items(), key=lambda x: x[0])

    # Bias both answer tokens so the single-token completion lands on Yes/No.
    # NOTE(review): assumes the model tokenizes with o200k_base — confirm.
    encoding = tiktoken.get_encoding("o200k_base")
    select_logit_bias = 5
    logit_bias = {
        encoding.encode("Yes")[0]: select_logit_bias,
        encoding.encode("No")[0]: select_logit_bias,
    }

    def history_elements():
        # (relevant, seen, adjacent) derived from the tests recorded so far.
        return helper_functions.test_history_elements(
            test_history,
            processed_chunks.previous_cid,
            processed_chunks.next_cid,
            chunk_search_config.adjacent_test_steps,
        )

    async def run_tests(search_label, search_cids, num_adjacent):
        return await assess_relevance(
            ai_configuration=ai_configuration,
            search_label=search_label,
            search_cids=search_cids,
            cid_to_text=processed_chunks.cid_to_text,
            query=query,
            logit_bias=logit_bias,
            relevance_test_budget=budget,
            num_adjacent=num_adjacent,
            relevance_test_batch_size=chunk_search_config.relevance_test_batch_size,
            test_history=test_history,
            progress_callback=chunk_progress_callback,
            chunk_callback=chunk_callback,
            commentary=commentary,
        )

    if chunk_progress_callback is not None:
        chunk_progress_callback(test_history)

    query_embedding = np.array(embedder.embed_store_one(query, embedding_cache))
    relevant, seen, adjacent = history_elements()
    # Closest-first; all_units is pre-sorted by cid so ties resolve deterministically.
    cosine_distances = sorted(
        (
            (cid, scipy.spatial.distance.cosine(query_embedding, vector))
            for cid, vector in all_units
            if cid not in seen
        ),
        key=lambda x: x[1],
    )
    semantic_search_cids = [cid for cid, _ in cosine_distances]

    (
        level_to_community_sequence,
        level_to_community_to_cids,
        community_to_parent,
    ) = _build_community_levels(
        processed_chunks,
        semantic_search_cids,
        chunk_search_config.community_ranking_chunks,
    )

    successive_irrelevant = 0
    eliminated_communities = set()
    current_level = -1

    while len(test_history) + len(adjacent) < budget:
        relevant_this_loop = False
        tests_before_loop = len(test_history)

        # Drop communities whose parent was eliminated, and mark them
        # eliminated too so the pruning cascades to deeper levels.
        community_sequence = []
        for community in level_to_community_sequence[current_level]:
            if community in community_to_parent:
                if community_to_parent[community] not in eliminated_communities:
                    community_sequence.append(community)
                else:
                    eliminated_communities.add(community)
            else:
                community_sequence.append(community)

        community_to_cids = level_to_community_to_cids[current_level]
        for community in community_sequence:
            relevant, seen, adjacent = history_elements()
            unseen_cids = [c for c in community_to_cids[community] if c not in seen][
                : chunk_search_config.community_relevance_tests
            ]
            if not unseen_cids:
                continue
            is_relevant = await run_tests(f"topic {community}", unseen_cids, len(adjacent))
            if len(test_history) + len(adjacent) >= budget:
                break
            relevant_this_loop |= is_relevant
            # Don't eliminate or count failures at the root level.
            if current_level > -1 and not is_relevant:
                eliminated_communities.add(community)
                successive_irrelevant += 1
                if successive_irrelevant == chunk_search_config.irrelevant_community_restart:
                    successive_irrelevant = 0
                    break
            else:
                successive_irrelevant = 0

        # Don't stop after failure at the root level.
        if current_level > -1 and not relevant_this_loop:
            break
        if current_level + 1 in level_to_community_sequence:
            current_level += 1
        elif len(test_history) == tests_before_loop:
            # Final level and nothing new was tested: another pass would be
            # identical, so stop instead of spinning forever.
            break

    relevant, seen, adjacent = history_elements()
    await run_tests("neighbours", adjacent, len(adjacent))
    relevant, seen, adjacent = history_elements()
    relevant.sort()
    if commentary is not None:
        commentary.complete_analysis()
    return relevant, helper_functions.get_test_progress(test_history)