Coverage for intelligence_toolkit/query_text_data/relevance_assessor.py: 0%
142 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
4import asyncio
5from json import loads
6from collections import defaultdict
7import numpy as np
8import scipy.spatial.distance
9import tiktoken
10import intelligence_toolkit.AI.utils as utils
11import intelligence_toolkit.query_text_data.helper_functions as helper_functions
12import intelligence_toolkit.query_text_data.prompts as prompts
async def assess_relevance(
    ai_configuration,
    search_label,
    search_cids,
    cid_to_text,
    query,
    logit_bias,
    relevance_test_budget,
    num_adjacent,
    relevance_test_batch_size,
    test_history,
    progress_callback,
    chunk_callback,
    commentary
):
    """Ask the model, batch by batch, whether chunks in ``search_cids`` are relevant to ``query``.

    Chunks are processed in batches of ``relevance_test_batch_size``. Each chunk is
    rendered with ``prompts.chunk_relevance_prompt`` and sent through
    ``utils.map_generate_text`` with ``max_tokens=1``. The verdicts are recorded via
    ``process_relevance_responses``, which appends to ``test_history`` in place.

    Budget handling: when ``len(test_history) + batch size + num_adjacent`` would exceed
    ``relevance_test_budget``, the batch is cut to the room left before
    ``len(test_history)`` reaches the budget. If no room is left, testing stops without
    calling the model.

    Args:
        ai_configuration: Passed through to ``utils.map_generate_text``.
        search_label: Label stored with each new ``test_history`` entry.
        search_cids: Ordered, sliceable sequence of chunk ids to test.
        cid_to_text: Mapping from chunk id to chunk text.
        query: The user query inserted into the relevance prompt.
        logit_bias: Token-id -> bias mapping passed to the model call.
        relevance_test_budget: Maximum total number of entries in ``test_history``.
        num_adjacent: Count added to the budget check (used by the caller to reserve room).
        relevance_test_batch_size: Number of chunks per model batch; must be positive.
        test_history: List of ``(label, cid, response)`` tuples; mutated in place.
        progress_callback: Optional callable forwarded to ``process_relevance_responses``.
        chunk_callback: Optional callable forwarded to ``process_relevance_responses``.
        commentary: Optional object forwarded to ``process_relevance_responses``.

    Returns:
        True if the last batch actually sent to the model produced at least one newly
        recorded "Yes" verdict; False otherwise (including when nothing was tested).
        Testing stops early at the first batch with no new "Yes" verdicts.
    """
    is_relevant = False
    for start in range(0, len(search_cids), relevance_test_batch_size):
        cid_batch = search_cids[start : start + relevance_test_batch_size]
        # Prompts are built lazily: later batches are often never reached because of the
        # early stop below, so there is no point preparing them up front.
        batch_messages = [
            utils.prepare_messages(
                prompts.chunk_relevance_prompt,
                {"chunk": cid_to_text[cid], "query": query},
            )
            for cid in cid_batch
        ]
        if (
            len(test_history) + len(batch_messages) + num_adjacent
            > relevance_test_budget
        ):
            # Clamp at zero: a negative slice bound would keep almost the whole batch.
            remaining = max(0, relevance_test_budget - len(test_history))
            batch_messages = batch_messages[:remaining]
        if not batch_messages:
            # Budget exhausted; do not issue an empty model call.
            break
        mapped_responses = await utils.map_generate_text(
            ai_configuration, batch_messages, logit_bias=logit_bias, max_tokens=1
        )
        # zip() inside process_relevance_responses pairs each response with the
        # leading cids of the batch, so a truncated batch is handled correctly.
        num_relevant = process_relevance_responses(
            search_label,
            cid_batch,
            cid_to_text,
            mapped_responses,
            test_history,
            progress_callback,
            chunk_callback,
            commentary,
        )
        is_relevant = num_relevant > 0
        if not is_relevant:  # No relevant chunks found in this batch; terminate early
            break
    return is_relevant
def process_relevance_responses(
    search_label,
    search_cids,
    cid_to_text,
    mapped_responses,
    test_history,
    progress_callback,
    chunk_callback,
    commentary,
):
    """Record model relevance verdicts and notify any listeners.

    Each response is paired (via ``zip``) with the chunk id at the same position in
    ``search_cids``. Chunk ids already present in ``test_history`` (including duplicates
    earlier in this same call) are skipped. Every other pair is appended to
    ``test_history`` as ``(search_label, cid, response)``.

    Args:
        search_label: Label stored with each new history entry.
        search_cids: Chunk ids, aligned positionally with ``mapped_responses``.
        cid_to_text: Mapping from chunk id to chunk text.
        mapped_responses: Model outputs; the exact string "Yes" marks a relevant chunk.
        test_history: List of ``(label, cid, response)`` tuples; mutated in place.
        progress_callback: If not None, called once with ``test_history``.
        chunk_callback: If not None, called once with the texts of every chunk in
            ``test_history`` whose response is "Yes", in history order.
        commentary: If not None and this call recorded any new "Yes" verdicts, its
            ``add_chunks`` is called with a ``{cid: text}`` dict of those chunks.

    Returns:
        The number of newly recorded "Yes" verdicts.
    """
    # Set of cids already tested, kept in sync with appends below so that membership
    # checks are O(1) instead of rescanning the whole history per response.
    already_tested = {entry[1] for entry in test_history}
    tested_relevant = []
    for response, cid in zip(mapped_responses, search_cids):
        if cid in already_tested:
            continue
        test_history.append((search_label, cid, response))
        already_tested.add(cid)
        if response == "Yes":
            tested_relevant.append(cid)

    if progress_callback is not None:
        progress_callback(test_history)
    if chunk_callback is not None:
        chunk_callback(
            [cid_to_text[entry[1]] for entry in test_history if entry[2] == "Yes"]
        )

    if commentary is not None and tested_relevant:
        commentary.add_chunks({cid: cid_to_text[cid] for cid in tested_relevant})
    return len(tested_relevant)
def _build_level_index(processed_chunks, semantic_search_cids, community_ranking_chunks):
    """Build per-level community orderings and chunk assignments for the search.

    Communities at each level of ``processed_chunks.hierarchical_communities`` are
    ranked as follows. For each community, take its member chunks that appear in
    ``semantic_search_cids``, keep the best ``community_ranking_chunks`` positions, and
    average them. Lower means closer to the query. Each chunk is then assigned, per
    level, to its best-ranked community. Level -1 is a synthetic root community "1"
    containing every chunk in ``semantic_search_cids``.

    Returns:
        ``(level_to_community_sequence, level_to_community_to_cids, community_to_parent)``,
        where community lists are sorted best-first and chunk lists follow
        ``semantic_search_cids`` order.
    """
    # Position of each chunk in the semantic ordering (0 = closest to the query).
    # semantic_search_cids comes from dict keys, so each cid appears once.
    cid_to_rank = {cid: rank for rank, cid in enumerate(semantic_search_cids)}

    concept_to_level_to_community = defaultdict(dict)
    community_to_parent = {}
    for hc in processed_chunks.hierarchical_communities:
        concept_to_level_to_community[hc.node][hc.level] = (
            processed_chunks.community_to_label[hc.cluster]
        )
        if hc.parent_cluster is not None:
            community_to_parent[processed_chunks.community_to_label[hc.cluster]] = (
                processed_chunks.community_to_label[hc.parent_cluster]
            )
    # default=-1: with no communities only the synthetic root level exists.
    max_level = max(
        (hc.level for hc in processed_chunks.hierarchical_communities), default=-1
    )

    level_to_community_sequence = {}
    level_to_community_to_cids = defaultdict(lambda: defaultdict(list))
    cid_to_level_to_communities = defaultdict(lambda: defaultdict(set))
    for level in range(0, max_level + 1):
        community_to_candidate_cids = defaultdict(set)
        for cid, concepts in processed_chunks.cid_to_concepts.items():
            for concept in concepts:
                level_to_community = concept_to_level_to_community.get(concept)
                if level_to_community is None:
                    continue
                if level in level_to_community:
                    community = level_to_community[level]
                elif level - 1 in level_to_community:
                    # Concept has no community at this level: fall back to level - 1.
                    community = level_to_community[level - 1]
                else:
                    continue
                cid_to_level_to_communities[cid][level].add(community)
                community_to_candidate_cids[community].add(cid)

        community_mean_rank = []
        for community, cids in community_to_candidate_cids.items():
            top_ranks = sorted(cid_to_rank[c] for c in cids if c in cid_to_rank)[
                :community_ranking_chunks
            ]
            # np.mean([]) would be NaN, and NaN keys make sorted() order unreliable;
            # communities with no searchable chunks go last instead.
            mean_rank = float(np.mean(top_ranks)) if top_ranks else float("inf")
            community_mean_rank.append((community, mean_rank))
        community_sequence = [
            community for community, _ in sorted(community_mean_rank, key=lambda x: x[1])
        ]
        level_to_community_sequence[level] = community_sequence

        # Assign each chunk to its best-ranked community at this level. Iterating in
        # semantic order keeps every community's chunk list sorted by rank.
        community_position = {c: i for i, c in enumerate(community_sequence)}
        for cid in semantic_search_cids:
            chunk_communities = cid_to_level_to_communities[cid][level]
            if chunk_communities:
                assigned_community = min(
                    chunk_communities, key=community_position.__getitem__
                )
                level_to_community_to_cids[level][assigned_community].append(cid)

    # Level -1 is a single root community holding everything in the dataset.
    level_to_community_sequence[-1] = ["1"]
    level_to_community_to_cids[-1]["1"] = semantic_search_cids
    return level_to_community_sequence, level_to_community_to_cids, community_to_parent


async def detect_relevant_chunks(
    ai_configuration,
    query,
    processed_chunks,
    cid_to_vector,
    embedder,
    embedding_cache,
    chunk_search_config,
    chunk_progress_callback=None,
    chunk_callback=None,
    commentary=None
):
    """Find chunks relevant to ``query`` using embeddings, communities and model tests.

    Steps:
    1. Rank all chunks in ``cid_to_vector`` by cosine distance to the query embedding.
    2. Build community orderings for each level of the concept hierarchy (see
       ``_build_level_index``).
    3. Starting at the synthetic root level -1, test up to
       ``chunk_search_config.community_relevance_tests`` unseen chunks per community
       with ``assess_relevance``. At levels >= 0:
       - a community with no relevant result is eliminated, and so are communities
         whose parent was eliminated;
       - after ``irrelevant_community_restart`` consecutive irrelevant communities, the
         current pass over the level ends;
       - a pass with nothing relevant stops the search.
    4. Finally, test the ``adjacent`` chunks reported by
       ``helper_functions.test_history_elements`` (label "neighbours").

    Testing is bounded by ``chunk_search_config.relevance_test_budget``.

    Args:
        ai_configuration: Passed through to the model calls.
        query: The user query.
        processed_chunks: Provides ``cid_to_text``, ``cid_to_concepts``,
            ``hierarchical_communities``, ``community_to_label``, ``previous_cid``
            and ``next_cid``.
        cid_to_vector: Mapping from chunk id to embedding vector.
        embedder: Object whose ``embed_store_one(query, embedding_cache)`` embeds the query.
        embedding_cache: Passed to ``embedder.embed_store_one``.
        chunk_search_config: Search parameters (budget, batch size, ranking/test counts).
        chunk_progress_callback: Optional; called with the test history as it grows.
        chunk_callback: Optional; called with texts of relevant chunks as they are found.
        commentary: Optional; receives relevant chunks and a final ``complete_analysis()``.

    Returns:
        ``(relevant, progress)``: ``relevant`` is the sorted relevant-chunk collection
        from ``helper_functions.test_history_elements``; ``progress`` is
        ``helper_functions.get_test_progress(test_history)``.
    """
    test_history = []

    def history_elements():
        # (relevant, seen, adjacent) derived from the current test history.
        return helper_functions.test_history_elements(
            test_history,
            processed_chunks.previous_cid,
            processed_chunks.next_cid,
            chunk_search_config.adjacent_test_steps,
        )

    all_units = sorted(cid_to_vector.items(), key=lambda x: x[0])

    # Bias the single generated token toward "Yes"/"No".
    encoding = tiktoken.get_encoding("o200k_base")
    yes_id = encoding.encode("Yes")[0]
    no_id = encoding.encode("No")[0]
    select_logit_bias = 5
    logit_bias = {yes_id: select_logit_bias, no_id: select_logit_bias}

    if chunk_progress_callback is not None:
        chunk_progress_callback(test_history)

    aq_embedding = np.array(embedder.embed_store_one(query, embedding_cache))
    relevant, seen, adjacent = history_elements()
    # Stable sort: ties in distance keep cid order from all_units.
    cosine_distances = sorted(
        (
            (cid, scipy.spatial.distance.cosine(aq_embedding, vector))
            for cid, vector in all_units
            if cid not in seen
        ),
        key=lambda x: x[1],
    )
    semantic_search_cids = [cid for cid, _ in cosine_distances]

    (
        level_to_community_sequence,
        level_to_community_to_cids,
        community_to_parent,
    ) = _build_level_index(
        processed_chunks,
        semantic_search_cids,
        chunk_search_config.community_ranking_chunks,
    )

    budget = chunk_search_config.relevance_test_budget
    successive_irrelevant = 0
    eliminated_communities = set()
    current_level = -1

    while len(test_history) + len(adjacent) < budget:
        relevant_this_loop = False
        tests_before_loop = len(test_history)

        # Drop communities whose parent was eliminated, propagating the elimination.
        community_sequence = []
        for community in level_to_community_sequence[current_level]:
            if (
                community in community_to_parent
                and community_to_parent[community] in eliminated_communities
            ):
                eliminated_communities.add(community)
            else:
                community_sequence.append(community)

        community_to_cids = level_to_community_to_cids[current_level]
        for community in community_sequence:
            relevant, seen, adjacent = history_elements()
            unseen_cids = [c for c in community_to_cids[community] if c not in seen][
                : chunk_search_config.community_relevance_tests
            ]
            if not unseen_cids:
                continue
            is_relevant = await assess_relevance(
                ai_configuration=ai_configuration,
                search_label=f"topic {community}",
                search_cids=unseen_cids,
                cid_to_text=processed_chunks.cid_to_text,
                query=query,
                logit_bias=logit_bias,
                relevance_test_budget=budget,
                num_adjacent=len(adjacent),
                relevance_test_batch_size=chunk_search_config.relevance_test_batch_size,
                test_history=test_history,
                progress_callback=chunk_progress_callback,
                chunk_callback=chunk_callback,
                commentary=commentary
            )
            if len(test_history) + len(adjacent) >= budget:
                break
            relevant_this_loop |= is_relevant
            # Failures at the root level never eliminate anything.
            if current_level > -1 and not is_relevant:
                eliminated_communities.add(community)
                successive_irrelevant += 1
                if successive_irrelevant == chunk_search_config.irrelevant_community_restart:
                    successive_irrelevant = 0
                    break
            else:
                successive_irrelevant = 0

        if current_level > -1 and not relevant_this_loop:  # don't stop at the root level
            break
        if current_level + 1 in level_to_community_sequence:
            current_level += 1
        elif len(test_history) == tests_before_loop:
            # Final level and nothing was left to test: stop instead of looping forever
            # (reachable at root level -1 when there are no hierarchical communities).
            break

    relevant, seen, adjacent = history_elements()
    await assess_relevance(
        ai_configuration=ai_configuration,
        search_label="neighbours",
        search_cids=adjacent,
        cid_to_text=processed_chunks.cid_to_text,
        query=query,
        logit_bias=logit_bias,
        relevance_test_budget=budget,
        num_adjacent=len(adjacent),
        relevance_test_batch_size=chunk_search_config.relevance_test_batch_size,
        test_history=test_history,
        progress_callback=chunk_progress_callback,
        chunk_callback=chunk_callback,
        commentary=commentary
    )
    relevant, seen, adjacent = history_elements()
    relevant.sort()
    # commentary defaults to None; guard so callers without commentary don't crash.
    if commentary is not None:
        commentary.complete_analysis()
    return relevant, helper_functions.get_test_progress(test_history)