Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_index_and_infer.py: 100%
175 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5from collections import defaultdict
6from unittest.mock import Mock, patch
8import networkx as nx
9import pytest
11from intelligence_toolkit.detect_entity_networks.config import (
12 SIMILARITY_THRESHOLD_MAX,
13 SIMILARITY_THRESHOLD_MIN,
14)
15from intelligence_toolkit.detect_entity_networks.index_and_infer import (
16 create_inferred_links,
17 index_and_infer,
18 index_nodes,
19 infer_nodes,
20)
21from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
class TestIndexNodes:
    """Tests for ``index_nodes`` with the OpenAI embedder mocked out."""

    @pytest.fixture()
    def overall_graph_small(self):
        """Minimal four-node, two-edge graph used for the too-few-samples case."""
        G = nx.Graph()
        G.add_node("Entity1", type="TypeA")
        G.add_node("Entity2", type="TypeB")
        G.add_node("Entity3", type="TypeA")
        G.add_node("Entity4", type="TypeC")
        G.add_edge("Entity1", "Entity2")
        G.add_edge("Entity3", "Entity4")
        return G

    @pytest.fixture()
    def overall_graph(self):
        """Thirty-node graph whose node types cycle through TypeA-TypeD."""
        G = nx.Graph()
        for i in range(1, 31):
            # chr(65 + i % 4) cycles over "A".."D", so FOUR types occur:
            # i % 4 == 1 -> TypeB, 2 -> TypeC, 3 -> TypeD, 0 -> TypeA.
            G.add_node(f"Entity{i}", type=f"Type{chr(65 + (i % 4))}")
        # Pair consecutive entities: (1, 2), (3, 4), ..., (29, 30).
        for i in range(1, 31, 2):
            G.add_edge(f"Entity{i}", f"Entity{i + 1}")
        return G

    async def test_index_nodes_empty_types(self, overall_graph):
        """An empty list of node types raises ValueError."""
        with pytest.raises(ValueError, match="No node types to index"):
            await index_nodes([], overall_graph)

    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.OpenAIEmbedder")
    async def test_index_nodes_small_samples(self, mock_embedder, overall_graph_small):
        """Only two embeddings for a 20-neighbour search surfaces a ValueError."""

        async def embed_store_many(*args, **kwargs) -> list[dict]:
            # Two embedding records, fewer than the 20 neighbours requested.
            return [
                {"vector": [0.1, 0.3], "text": "A"},
                {"vector": [0.3, 0.4], "text": "B"},
            ]

        mock_instance = mock_embedder.return_value
        mock_instance.embed_store_many.side_effect = embed_store_many

        indexed_node_types = ["TypeA", "TypeB"]
        with pytest.raises(
            ValueError,
            match="Expected n_neighbors <= n_samples_fit, but n_neighbors = 20, n_samples_fit = 2, n_samples = 2",
        ):
            await index_nodes(indexed_node_types, overall_graph_small)

    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.OpenAIEmbedder")
    async def test_index_nodes(self, mock_embedder, overall_graph):
        """Requested-type nodes are returned sorted, each with 20 neighbours."""
        indexed_node_types = ["TypeA", "TypeB", "TypeC"]
        # Every node whose type was requested (TypeD nodes drop out), sorted as
        # strings ("Entity1", "Entity10", ...), which is the order asserted below.
        expected_texts = sorted(
            f"Entity{i}"
            for i in range(1, 31)
            if f"Type{chr(65 + (i % 4))}" in indexed_node_types
        )

        async def embed_store_many(*args, **kwargs) -> list[dict]:
            # One (identical) embedding record per text that should be indexed;
            # derived from expected_texts instead of a hard-coded count.
            return [{"vector": [0.1, 0.3], "text": "A"}] * len(expected_texts)

        mock_instance = mock_embedder.return_value
        mock_instance.embed_store_many.side_effect = embed_store_many

        (
            embedded_texts,
            nearest_text_distances,
            nearest_text_indices,
        ) = await index_nodes(indexed_node_types, overall_graph)

        # One row per embedded text, 20 neighbours per row.
        expected_shape = (len(expected_texts), 20)
        assert embedded_texts == expected_texts
        assert nearest_text_distances.shape == expected_shape
        assert nearest_text_indices.shape == expected_shape
class TestInferNodes:
    """Tests for ``infer_nodes``, which derives links from nearest-neighbour distances."""

    # Error text infer_nodes raises for an out-of-range similarity threshold.
    _THRESHOLD_RANGE_MESSAGE = f"Similarity threshold must be between {SIMILARITY_THRESHOLD_MIN} and {SIMILARITY_THRESHOLD_MAX}"
    # First-row distance variants: with LOW, "Entity==ABCDE" is close to both
    # neighbours; with HIGH, its first neighbour is at 0.8.
    _LOW_FIRST_ROW = (0.1, 0.1, 0.3)
    _HIGH_FIRST_ROW = (0.8, 0.1, 0.3)

    @staticmethod
    def _sample_inputs(first_row_distances):
        """Return ``(embedded_texts, nearest_text_indices, nearest_text_distances)``.

        Only the distances of the first text ("Entity==ABCDE") vary between tests;
        fresh lists are built on every call so tests cannot affect each other.
        """
        embedded_texts = [
            "Entity==ABCDE",
            "Entity==PLUS_ONE",
            "Entity==PLUS ONE",
        ]
        # NOTE(review): index rows hold two entries while distance rows hold
        # three; presumably infer_nodes only reads as many distances as there
        # are indices -- confirm against the implementation.
        nearest_text_indices = [[1, 2], [2, 1], [1, 2]]
        nearest_text_distances = [
            list(first_row_distances),
            [0.02, 0.1, 0.3],
            [0.02, 0.1, 0.3],
        ]
        return embedded_texts, nearest_text_indices, nearest_text_distances

    @staticmethod
    def _expected_links_07():
        """Links expected at threshold 0.7 with the HIGH first-row distances."""
        expected_links = defaultdict(set)
        expected_links["Entity==PLUS ONE"].add("Entity==PLUS_ONE")
        expected_links["Entity==PLUS ONE"].add("Entity==ABCDE")
        expected_links["Entity==PLUS_ONE"].add("Entity==PLUS ONE")
        expected_links["Entity==ABCDE"].add("Entity==PLUS ONE")
        return expected_links

    @staticmethod
    def _mocked_callback():
        """Return a ProgressBatchCallback whose ``on_batch_change`` is a Mock, plus that Mock."""
        callback = ProgressBatchCallback()
        on_batch_change = Mock()
        callback.on_batch_change = on_batch_change
        return callback, on_batch_change

    def test_infer_nodes_min_threshold(self):
        """A negative threshold (-0.1) is rejected as out of range."""
        with pytest.raises(ValueError, match=self._THRESHOLD_RANGE_MESSAGE):
            infer_nodes(-0.1, [], [], [])

    def test_infer_nodes_max_threshold(self):
        """A threshold of 2 is rejected as out of range."""
        with pytest.raises(ValueError, match=self._THRESHOLD_RANGE_MESSAGE):
            infer_nodes(2, [], [], [])

    def test_infer_nodes_005(self):
        """At 0.05 only the two PLUS ONE variants (distance 0.02) are linked."""
        embedded_texts, indices, distances = self._sample_inputs(self._LOW_FIRST_ROW)

        inferred_links = infer_nodes(0.05, embedded_texts, indices, distances)

        expected_links = defaultdict(set)
        expected_links["Entity==PLUS ONE"].add("Entity==PLUS_ONE")
        expected_links["Entity==PLUS_ONE"].add("Entity==PLUS ONE")
        assert inferred_links == expected_links

    def test_infer_nodes_1(self):
        """At threshold 1 every text is linked to every other text."""
        embedded_texts, indices, distances = self._sample_inputs(self._LOW_FIRST_ROW)

        inferred_links = infer_nodes(1, embedded_texts, indices, distances)

        expected_links = defaultdict(set)
        expected_links["Entity==PLUS ONE"].add("Entity==PLUS_ONE")
        expected_links["Entity==PLUS ONE"].add("Entity==ABCDE")
        expected_links["Entity==PLUS_ONE"].add("Entity==PLUS ONE")
        expected_links["Entity==PLUS_ONE"].add("Entity==ABCDE")
        expected_links["Entity==ABCDE"].add("Entity==PLUS ONE")
        expected_links["Entity==ABCDE"].add("Entity==PLUS_ONE")
        assert inferred_links == expected_links

    def test_infer_nodes_07(self):
        """At 0.7 the ABCDE/PLUS_ONE pair (distance 0.8) is not linked."""
        embedded_texts, indices, distances = self._sample_inputs(self._HIGH_FIRST_ROW)

        inferred_links = infer_nodes(0.7, embedded_texts, indices, distances)

        assert inferred_links == self._expected_links_07()

    def test_infer_nodes_progress_callbacks_empty(self):
        """An empty callback list gives the same links as passing none."""
        embedded_texts, indices, distances = self._sample_inputs(self._HIGH_FIRST_ROW)

        inferred_links = infer_nodes(
            0.7,
            embedded_texts,
            indices,
            distances,
            progress_callbacks=[],
        )

        assert inferred_links == self._expected_links_07()

    def test_infer_nodes_one_progress_callback(self):
        """A single callback's last progress report is (2, 3, "Infering links...")."""
        embedded_texts, indices, distances = self._sample_inputs(self._HIGH_FIRST_ROW)
        callback, on_batch_change = self._mocked_callback()

        infer_nodes(
            0.7,
            embedded_texts,
            indices,
            distances,
            progress_callbacks=[callback],
        )

        on_batch_change.assert_called_with(2, 3, "Infering links...")

    def test_infer_nodes_two_progress_callback(self):
        """Every registered callback receives the same last progress report."""
        embedded_texts, indices, distances = self._sample_inputs(self._HIGH_FIRST_ROW)
        first_callback, first_on_batch_change = self._mocked_callback()
        second_callback, second_on_batch_change = self._mocked_callback()

        infer_nodes(
            0.7,
            embedded_texts,
            indices,
            distances,
            progress_callbacks=[first_callback, second_callback],
        )

        first_on_batch_change.assert_called_with(2, 3, "Infering links...")
        second_on_batch_change.assert_called_with(2, 3, "Infering links...")
class TestCreateInferredLinks:
    """Tests for ``create_inferred_links``, which turns a link mapping into an edge list."""

    def test_create_links(self):
        """Links recorded in both directions yield one edge per unordered pair."""
        inferred_links = defaultdict(set)
        inferred_links["Entity==PLUS ONE"].add("Entity==PLUS_ONE")
        inferred_links["Entity==PLUS ONE"].add("Entity==ABCDE")
        inferred_links["Entity==PLUS_ONE"].add("Entity==PLUS ONE")
        inferred_links["Entity==ABCDE"].add("Entity==PLUS ONE")

        created_links = create_inferred_links(inferred_links)

        expected_links = [
            ("Entity==PLUS ONE", "Entity==PLUS_ONE"),
            ("Entity==ABCDE", "Entity==PLUS ONE"),
        ]
        # The mapping's values are sets of strings, whose iteration order depends
        # on PYTHONHASHSEED. Unless create_inferred_links sorts its output, edge
        # order can differ between runs, so compare order-insensitively (sorted()
        # still catches missing or duplicated edges).
        assert isinstance(created_links, list)
        assert sorted(created_links) == sorted(expected_links)

    def test_create_links_return_empty(self):
        """An empty mapping yields no edges on repeated calls and is left empty."""
        inferred_links = defaultdict(set)

        # Call twice: the result must stay [] across calls.
        for _ in range(2):
            assert create_inferred_links(inferred_links) == []
        # Reading the defaultdict must not have inserted any keys.
        assert not inferred_links
class TestIndexAndInfer:
    """Tests for ``index_and_infer``, which calls ``index_nodes`` then ``infer_nodes``."""

    @pytest.fixture()
    def overall_graph(self):
        """Thirty nodes typed TypeA-TypeD (cycling on i % 4), paired by edges."""
        G = nx.Graph()
        for i in range(1, 31):
            G.add_node(f"Entity{i}", type=f"Type{chr(65 + (i % 4))}")
        # Pair consecutive entities: (1, 2), (3, 4), ..., (29, 30).
        for i in range(1, 31, 2):
            G.add_edge(f"Entity{i}", f"Entity{i + 1}")
        return G

    async def test_empty_graph(self):
        """An empty graph raises ValueError."""
        indexed_node_types = ["TypeA", "TypeB", "TypeC"]
        with pytest.raises(ValueError, match="Graph is empty"):
            await index_and_infer(indexed_node_types, nx.Graph(), 0)

    # Stacked @patch decorators inject mocks bottom-up: infer_nodes first, then index_nodes.
    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.index_nodes")
    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.infer_nodes")
    async def test_index_and_infer(
        self, mock_infer_nodes, mock_index_nodes, overall_graph
    ) -> None:
        """index_nodes' output is forwarded to infer_nodes and its links returned."""
        indexed_node_types = ["TypeA", "TypeB", "TypeC"]
        embedded_texts = ["Entity1", "Entity2", "Entity3"]
        # One distance row and one index row per embedded text, so the mocked
        # index output is internally consistent.
        nearest_text_distances = [[0.1, 0.3], [0.3, 0.4], [0.2, 0.5]]
        nearest_text_indices = [[0, 1], [1, 0], [2, 1]]

        # index_nodes returns (texts, distances, indices) ...
        mock_index_nodes.return_value = (
            embedded_texts,
            nearest_text_distances,
            nearest_text_indices,
        )

        inferred_links = defaultdict(set)
        inferred_links["Entity1"].add("Entity2")
        mock_infer_nodes.return_value = inferred_links

        link_list, _ = await index_and_infer(indexed_node_types, overall_graph, 0.5)

        assert link_list == inferred_links

        # NOTE(review): the trailing positional values are presumably
        # index_and_infer's defaults forwarded to index_nodes -- their meaning
        # is not visible from this file.
        mock_index_nodes.assert_called_once_with(
            indexed_node_types, overall_graph, None, None, None, True
        )
        # ... while infer_nodes takes indices BEFORE distances.
        mock_infer_nodes.assert_called_once_with(
            0.5, embedded_texts, nearest_text_indices, nearest_text_distances, None
        )