Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_index_and_infer.py: 100%

175 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5from collections import defaultdict 

6from unittest.mock import Mock, patch 

7 

8import networkx as nx 

9import pytest 

10 

11from intelligence_toolkit.detect_entity_networks.config import ( 

12 SIMILARITY_THRESHOLD_MAX, 

13 SIMILARITY_THRESHOLD_MIN, 

14) 

15from intelligence_toolkit.detect_entity_networks.index_and_infer import ( 

16 create_inferred_links, 

17 index_and_infer, 

18 index_nodes, 

19 infer_nodes, 

20) 

21from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

22 

23 

class TestIndexNodes:
    """Tests for ``index_nodes``: empty type list, too few samples, result shapes."""

    @pytest.fixture()
    def overall_graph_small(self):
        """Build a four-node graph (two TypeA, one TypeB, one TypeC) with two edges."""
        G = nx.Graph()
        G.add_node("Entity1", type="TypeA")
        G.add_node("Entity2", type="TypeB")
        G.add_node("Entity3", type="TypeA")
        G.add_node("Entity4", type="TypeC")
        G.add_edge("Entity1", "Entity2")
        G.add_edge("Entity3", "Entity4")
        return G

    @pytest.fixture()
    def overall_graph(self):
        """Build a 30-node graph whose node types cycle through TypeA..TypeD."""
        G = nx.Graph()
        for i in range(1, 31):
            # i % 4 yields 0..3, so the types are TypeA, TypeB, TypeC and TypeD.
            # The TypeD nodes (i % 4 == 3) exist so tests can leave them unindexed.
            G.add_node(f"Entity{i}", type=f"Type{chr(65 + (i % 4))}")
        # Link consecutive pairs: (Entity1, Entity2), (Entity3, Entity4), ...
        for i in range(1, 31, 2):
            G.add_edge(f"Entity{i}", f"Entity{i + 1}")
        return G

    async def test_index_nodes_empty_types(self, overall_graph):
        """An empty list of node types is rejected with a ValueError."""
        with pytest.raises(
            ValueError,
            match="No node types to index",
        ):
            await index_nodes([], overall_graph)

    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.OpenAIEmbedder")
    async def test_index_nodes_small_samples(self, mock_embedder, overall_graph_small):
        """Two embedded samples are too few for the 20-neighbour lookup."""

        # The stub returns embedding records (dicts), not bare vectors.
        async def embed_store_many(*args) -> list[dict]:
            return [
                {"vector": [0.1, 0.3], "text": "A"},
                {"vector": [0.3, 0.4], "text": "B"},
            ]

        mock_instance = mock_embedder.return_value
        mock_instance.embed_store_many.side_effect = embed_store_many

        indexed_node_types = ["TypeA", "TypeB"]
        with pytest.raises(
            ValueError,
            match="Expected n_neighbors <= n_samples_fit, but n_neighbors = 20, n_samples_fit = 2, n_samples = 2",
        ):
            await index_nodes(indexed_node_types, overall_graph_small)

    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.OpenAIEmbedder")
    async def test_index_nodes(self, mock_embedder, overall_graph):
        """Indexing TypeA..TypeC returns the sorted texts and (n_texts, 20) arrays."""

        # 23 records: one per node of type A, B or C (30 nodes minus 7 TypeD nodes).
        async def embed_store_many(*args) -> list[dict]:
            return [
                {"vector": [0.1, 0.3], "text": "A"},
            ] * 23

        mock_instance = mock_embedder.return_value
        mock_instance.embed_store_many.side_effect = embed_store_many

        indexed_node_types = ["TypeA", "TypeB", "TypeC"]
        (
            embedded_texts,
            nearest_text_distances,
            nearest_text_indices,
        ) = await index_nodes(indexed_node_types, overall_graph)

        # Rebuild the names of every node whose type was indexed, in sorted order.
        expected_texts = sorted(
            f"Entity{i}"
            for i in range(1, 31)
            if f"Type{chr(65 + (i % 4))}" in indexed_node_types
        )
        expected_shape = (len(expected_texts), 20)

        # Check the embedded texts.
        assert embedded_texts == expected_texts
        # Check the shape of the distances and indices.
        assert nearest_text_distances.shape == expected_shape
        assert nearest_text_indices.shape == expected_shape

104 

105 

class TestInferNodes:
    """Tests for ``infer_nodes``: threshold validation, links, progress callbacks."""

    @staticmethod
    def _neighbour_inputs(first_distance):
        """Return fresh ``(texts, indices, distances)`` lists for one call.

        ``first_distance`` is the first value of the first distance row, the
        only number that differs between the scenarios in this class.
        """
        texts = [
            "Entity==ABCDE",
            "Entity==PLUS_ONE",
            "Entity==PLUS ONE",
        ]
        indices = [[1, 2], [2, 1], [1, 2]]
        distances = [
            [first_distance, 0.1, 0.3],
            [0.02, 0.1, 0.3],
            [0.02, 0.1, 0.3],
        ]
        return texts, indices, distances

    @staticmethod
    def _tracked_callback():
        """Return a ``ProgressBatchCallback`` and the ``Mock`` set as its hook."""
        callback = ProgressBatchCallback()
        hook = Mock()
        callback.on_batch_change = hook
        return callback, hook

    def test_infer_nodes_min_threshold(self):
        """A threshold of -0.1 raises the out-of-range ValueError."""
        with pytest.raises(
            ValueError,
            match=f"Similarity threshold must be between {SIMILARITY_THRESHOLD_MIN} and {SIMILARITY_THRESHOLD_MAX}",
        ):
            infer_nodes(-0.1, [], [], [])

    def test_infer_nodes_max_threshold(self):
        """A threshold of 2 raises the out-of-range ValueError."""
        with pytest.raises(
            ValueError,
            match=f"Similarity threshold must be between {SIMILARITY_THRESHOLD_MIN} and {SIMILARITY_THRESHOLD_MAX}",
        ):
            infer_nodes(2, [], [], [])

    def test_infer_nodes_005(self):
        """At threshold 0.05 only the PLUS_ONE / PLUS ONE pair is expected."""
        texts, indices, distances = self._neighbour_inputs(0.1)

        result = infer_nodes(0.05, texts, indices, distances)

        wanted = defaultdict(
            set,
            {
                "Entity==PLUS ONE": {"Entity==PLUS_ONE"},
                "Entity==PLUS_ONE": {"Entity==PLUS ONE"},
            },
        )
        assert result == wanted

    def test_infer_nodes_1(self):
        """At threshold 1 every text is expected to link to both others."""
        texts, indices, distances = self._neighbour_inputs(0.1)

        result = infer_nodes(1, texts, indices, distances)

        wanted = defaultdict(
            set,
            {
                "Entity==PLUS ONE": {"Entity==PLUS_ONE", "Entity==ABCDE"},
                "Entity==PLUS_ONE": {"Entity==PLUS ONE", "Entity==ABCDE"},
                "Entity==ABCDE": {"Entity==PLUS ONE", "Entity==PLUS_ONE"},
            },
        )
        assert result == wanted

    def test_infer_nodes_07(self):
        """At threshold 0.7 with a 0.8 first distance, ABCDE links to PLUS ONE only."""
        texts, indices, distances = self._neighbour_inputs(0.8)

        result = infer_nodes(0.7, texts, indices, distances)

        wanted = defaultdict(
            set,
            {
                "Entity==PLUS ONE": {"Entity==PLUS_ONE", "Entity==ABCDE"},
                "Entity==PLUS_ONE": {"Entity==PLUS ONE"},
                "Entity==ABCDE": {"Entity==PLUS ONE"},
            },
        )
        assert result == wanted

    def test_infer_nodes_progress_callbacks_empty(self):
        """An empty ``progress_callbacks`` list gives the same links as the 0.7 case."""
        texts, indices, distances = self._neighbour_inputs(0.8)

        result = infer_nodes(
            0.7,
            texts,
            indices,
            distances,
            progress_callbacks=[],
        )

        wanted = defaultdict(
            set,
            {
                "Entity==PLUS ONE": {"Entity==PLUS_ONE", "Entity==ABCDE"},
                "Entity==PLUS_ONE": {"Entity==PLUS ONE"},
                "Entity==ABCDE": {"Entity==PLUS ONE"},
            },
        )
        assert result == wanted

    def test_infer_nodes_one_progress_callback(self):
        """A single callback's hook is last called with (2, 3, message)."""
        texts, indices, distances = self._neighbour_inputs(0.8)
        callback, hook = self._tracked_callback()

        infer_nodes(
            0.7,
            texts,
            indices,
            distances,
            progress_callbacks=[callback],
        )

        hook.assert_called_with(2, 3, "Infering links...")

    def test_infer_nodes_two_progress_callback(self):
        """With two callbacks, both hooks are last called with (2, 3, message)."""
        texts, indices, distances = self._neighbour_inputs(0.8)
        first_callback, first_hook = self._tracked_callback()
        second_callback, second_hook = self._tracked_callback()

        infer_nodes(
            0.7,
            texts,
            indices,
            distances,
            progress_callbacks=[first_callback, second_callback],
        )

        first_hook.assert_called_with(2, 3, "Infering links...")
        second_hook.assert_called_with(2, 3, "Infering links...")

296 

297 

class TestCreateInferredLinks:
    """Tests for ``create_inferred_links``."""

    def test_create_links(self):
        """A mapping with mirrored entries is expected to produce two link tuples."""
        neighbours = defaultdict(set)
        # Insert in a fixed sequence so the mapping is built the same way each run.
        for source, target in (
            ("Entity==PLUS ONE", "Entity==PLUS_ONE"),
            ("Entity==PLUS ONE", "Entity==ABCDE"),
            ("Entity==PLUS_ONE", "Entity==PLUS ONE"),
            ("Entity==ABCDE", "Entity==PLUS ONE"),
        ):
            neighbours[source].add(target)

        produced = create_inferred_links(neighbours)

        assert produced == [
            ("Entity==PLUS ONE", "Entity==PLUS_ONE"),
            ("Entity==ABCDE", "Entity==PLUS ONE"),
        ]

    def test_create_links_return_empty(self):
        """An empty mapping yields an empty list, on two consecutive calls."""
        no_neighbours = defaultdict(set)

        # Two passes on purpose: the same mapping is queried twice.
        for _ in range(2):
            produced = create_inferred_links(no_neighbours)
            assert produced == []

324 

325 

class TestIndexAndInfer:
    """Tests for ``index_and_infer``: empty-graph guard and delegation to helpers."""

    @pytest.fixture()
    def overall_graph(self):
        """Build a 30-node graph with types TypeA..TypeD and 15 pairwise edges."""
        graph = nx.Graph()
        # i % 4 selects one of four type letters (A-D) for each node.
        graph.add_nodes_from(
            (f"Entity{i}", {"type": f"Type{chr(65 + (i % 4))}"})
            for i in range(1, 31)
        )
        # Connect each odd-numbered entity to its successor.
        graph.add_edges_from(
            (f"Entity{i}", f"Entity{i + 1}") for i in range(1, 31, 2)
        )
        return graph

    async def test_empty_graph(self):
        """A graph without nodes raises ``ValueError("Graph is empty")``."""
        node_types = ["TypeA", "TypeB", "TypeC"]
        empty_graph = nx.Graph()
        with pytest.raises(ValueError, match="Graph is empty"):
            await index_and_infer(node_types, empty_graph, 0)

    # Decorators apply bottom-up, so the ``infer_nodes`` mock is the first argument.
    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.index_nodes")
    @patch("intelligence_toolkit.detect_entity_networks.index_and_infer.infer_nodes")
    async def test_index_and_infer(
        self, mock_infer_nodes, mock_index_nodes, overall_graph
    ) -> None:
        """The indexing result is forwarded to inference and its links are returned."""
        node_types = ["TypeA", "TypeB", "TypeC"]
        texts = ["Entity1", "Entity2", "Entity3"]
        distances = [[0.1, 0.3], [0.3, 0.4]]
        indices = [[0, 1], [1, 0]]
        mock_index_nodes.return_value = (texts, distances, indices)

        stub_links = defaultdict(set)
        stub_links["Entity1"].add("Entity2")
        mock_infer_nodes.return_value = stub_links

        returned = await index_and_infer(node_types, overall_graph, 0.5)
        link_list, _ = returned

        assert link_list == stub_links

        # Note the order swap: inference receives indices before distances.
        mock_index_nodes.assert_called_once_with(
            node_types, overall_graph, None, None, None, True
        )
        mock_infer_nodes.assert_called_once_with(
            0.5, texts, indices, distances, None
        )

371 )