Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_explore_networks.py: 100%

288 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5import networkx as nx 

6import polars as pl 

7import pytest 

8from networkx import Graph 

9 

10from intelligence_toolkit.detect_entity_networks.explore_networks import ( 

11 _build_fuzzy_neighbors, 

12 _integrate_flags, 

13 _merge_condition, 

14 _merge_node_list, 

15 _merge_nodes, 

16 get_entity_graph, 

17 get_type_color, 

18 hsl_to_hex, 

19 simplify_entities_graph, 

20) 

21 

22 

@pytest.fixture()
def graph() -> Graph:
    """Return a six-node graph whose nodes carry ``type`` and ``flags`` data.

    Nodes and edges are inserted in a fixed order so that tests relying on
    iteration order see the same graph every time.
    """
    attributed_nodes = [
        ("A", {"type": "TypeA", "flags": 1}),
        ("B", {"type": "TypeB", "flags": 0}),
        ("C", {"type": "TypeC", "flags": 1}),
        ("D", {"type": "TypeC", "flags": 1}),
        ("E", {"type": "TypeC", "flags": 0}),
        ("F", {"type": "TypeC", "flags": 1}),
    ]
    links = [
        ("A", "B"),
        ("B", "F"),
        ("B", "C"),
        ("E", "F"),
        ("D", "F"),
        ("A", "F"),
    ]

    fixture_graph = nx.Graph()
    fixture_graph.add_nodes_from(attributed_nodes)
    fixture_graph.add_edges_from(links)
    return fixture_graph

40 

41 

@pytest.fixture()
def simple_graph() -> Graph:
    """Return a small graph of ``ENTITY==n`` nodes and ``<attribute>==<value>`` nodes.

    Node and edge insertion order is kept fixed because some tests compare
    against ordered node/edge lists.
    """
    node_ids = [
        "ENTITY==1",
        "ENTITY==2",
        "ENTITY==3",
        "ENTITY==4",
        "ENTITY==5",
        "Attr==Type1",
        "Attr==Type2",
        "Attr==Type35",
        "AttributeABCD==Type35",
        "AttributeABCD==Type37",
        "AttributeABCD==Type38",
        "AttributeABCD==Type47",
    ]
    links = [
        ("ENTITY==1", "ENTITY==2"),
        ("Attr==Type1", "ENTITY==1"),
        ("AttributeABCD==Type37", "Attr==Type1"),
        ("AttributeABCD==Type47", "Attr==Type2"),
        ("ENTITY==3", "Attr==Type1"),
        ("ENTITY==3", "AttributeABCD==Type35"),
        ("AttributeABCD==Type35", "AttributeABCD==Type38"),
        ("ENTITY==3", "ENTITY==5"),
    ]

    fixture_graph = nx.Graph()
    fixture_graph.add_nodes_from(node_ids)
    fixture_graph.add_edges_from(links)
    return fixture_graph

67 

68 

class TestFuzzyNeighbors:
    """Tests for ``_build_fuzzy_neighbors``."""

    @pytest.fixture()
    def existing_network_graph(self):
        """Return a network graph that already holds three nodes and two edges."""
        G = nx.Graph()
        G.add_nodes_from(["ENTITY==47", "Attr==Type108", "Attr==Type222"])
        G.add_edges_from(
            [
                ("Attr==Type108", "ENTITY==47"),
                ("Attr==Type108", "Attr==Type222"),
            ]
        )
        return G

    def test_empty_graph(self):
        """An empty source graph produces a result without nodes."""
        # Pass an empty dict instance for inferred_links, not the ``dict`` type.
        result = _build_fuzzy_neighbors(nx.Graph(), nx.Graph(), "ENTITY==1", set(), {})
        assert result.nodes == nx.Graph().nodes

    def test_node_inexistent(self, simple_graph):
        """Asking for a node missing from the source graph raises ``ValueError``."""
        with pytest.raises(ValueError, match="Node ENTITY==78 not in graph"):
            _build_fuzzy_neighbors(simple_graph, nx.Graph(), "ENTITY==78", set(), {})

    def test_fuzzy_neighbor_basic(self, simple_graph):
        """Only the attribute-to-attribute link of ``Attr==Type1`` is added."""
        result = _build_fuzzy_neighbors(
            simple_graph, nx.Graph(), "Attr==Type1", set(), {}
        )

        assert list(result.nodes()) == ["AttributeABCD==Type37", "Attr==Type1"]
        assert list(result.edges()) == [("AttributeABCD==Type37", "Attr==Type1")]

    def test_fuzzy_neighbor_inferred(self, simple_graph):
        """An inferred link contributes an extra edge next to the direct one."""
        inferred_links = {"Attr==Type1": {"AttributeABCD==Type47"}}
        result = _build_fuzzy_neighbors(
            simple_graph, nx.Graph(), "Attr==Type1", set(), inferred_links
        )

        assert len(result.edges()) == 2
        assert ("AttributeABCD==Type47", "Attr==Type1") in result.edges()
        assert ("AttributeABCD==Type37", "Attr==Type1") in result.edges()

    def test_fuzzy_neighbor_trimmed(self, simple_graph):
        """A neighbour listed in the trimmed node set is not added."""
        trimmed_nodeset = {"AttributeABCD==Type38"}
        result = _build_fuzzy_neighbors(
            simple_graph, nx.Graph(), "AttributeABCD==Type35", trimmed_nodeset, {}
        )

        assert len(result.nodes()) == 0
        assert len(result.edges()) == 0

    def test_fuzzy_neighbor_graph_existent(self, simple_graph, existing_network_graph):
        """New links are added on top of the content of an existing network graph."""
        result = _build_fuzzy_neighbors(
            simple_graph,
            existing_network_graph,
            "Attr==Type1",
            set(),
            {},
        )

        assert len(result.edges()) == 3
        assert len(result.nodes()) == 5
        assert ("AttributeABCD==Type37", "Attr==Type1") in result.edges()
        assert ("Attr==Type108", "Attr==Type222") in result.edges()

148 

149 

class TestIntegrateFlags:
    """Tests for ``_integrate_flags``."""

    @pytest.fixture()
    def graph_flags(self):
        """Return a graph with the six attribute-less nodes ``A`` to ``F``."""
        G = nx.Graph()
        G.add_nodes_from(["A", "B", "C", "D", "E", "F"])
        return G

    @staticmethod
    def _assert_unflagged(result, node_ids) -> None:
        """Assert that none of ``node_ids`` received a ``flags`` attribute."""
        for node_id in node_ids:
            assert "flags" not in result.nodes[node_id]

    def test_empty_graph(self):
        """An empty graph combined with an empty frame yields no nodes."""
        result = _integrate_flags(nx.Graph(), pl.DataFrame())

        assert len(result.nodes()) == 0

    def test_empty_flags(self, graph_flags):
        """With an empty flags frame the result has no nodes, although the input graph has six."""
        result = _integrate_flags(graph_flags, pl.DataFrame())

        assert len(result.nodes()) == 0

    def test_integration(self, graph_flags):
        """Positive counts become ``flags``; a zero count leaves the node unflagged."""
        flags = pl.DataFrame(
            {
                "qualified_entity": ["A", "C", "D", "F"],
                "count": [1, 2, 3, 0],
            }
        )

        result = _integrate_flags(graph_flags, flags)

        assert result.nodes["A"]["flags"] == 1
        assert result.nodes["C"]["flags"] == 2
        assert result.nodes["D"]["flags"] == 3
        self._assert_unflagged(result, ["B", "E", "F"])

    def test_sum(self, graph_flags):
        """Counts of an entity listed twice are added together (1 + 5 for ``A``)."""
        flags = pl.DataFrame(
            {
                "qualified_entity": ["A", "C", "D", "F", "A"],
                "count": [1, 2, 3, 0, 5],
            }
        )

        result = _integrate_flags(graph_flags, flags)

        assert result.nodes["A"]["flags"] == 6
        assert result.nodes["C"]["flags"] == 2
        assert result.nodes["D"]["flags"] == 3
        self._assert_unflagged(result, ["B", "E", "F"])

    def test_node_not_in_graph(self, graph_flags):
        """A flagged entity that is absent from the graph is not added to it."""
        flags = pl.DataFrame(
            {
                "qualified_entity": ["A", "C", "D", "F", "Z"],
                "count": [1, 2, 3, 0, 3],
            }
        )

        result = _integrate_flags(graph_flags, flags)

        assert result.nodes["A"]["flags"] == 1
        assert result.nodes["C"]["flags"] == 2
        assert result.nodes["D"]["flags"] == 3
        self._assert_unflagged(result, ["B", "E", "F"])
        assert "Z" not in result.nodes()

228 

229 

class TestMergeCondition:
    """Checks ``_merge_condition`` on pairs of ``;``-separated ``key==value`` strings."""

    # Left-hand operand shared by every case below.
    _LEFT = "A==1;B==2;C==3"

    def test_merge(self) -> None:
        """The value ``1`` appears on both sides, so the result is ``True``."""
        right = "D==4;E==1;F==5"
        outcome = _merge_condition(self._LEFT, right)
        assert outcome is True

    def test_merge_common_attr(self) -> None:
        """The value ``2`` appears on both sides, so the result is ``True``."""
        right = "D==4;E==5;F==2"
        outcome = _merge_condition(self._LEFT, right)
        assert outcome is True

    def test_merge_not_common(self) -> None:
        """No value is shared between the two sides, so the result is ``False``."""
        right = "D==4;E==5;F==6"
        outcome = _merge_condition(self._LEFT, right)
        assert outcome is False

    def test_merge_multiple_common(self) -> None:
        """Two values (``1`` and ``2``) are shared, and the result is still ``True``."""
        right = "D==4;E==1;F==2"
        outcome = _merge_condition(self._LEFT, right)
        assert outcome is True

250 

251 

class TestMergeNodeList:
    """Tests for ``_merge_node_list``."""

    def test_merge_node_list(self, graph):
        """Merging ``A`` and ``C`` yields one ``A;C`` node that inherits their data and edges."""
        merged_graph = _merge_node_list(graph, ["A", "C"])

        # The merged node exists, with joined type names and the expected flags value.
        assert merged_graph.has_node("A;C")
        assert merged_graph.has_node("B")
        assert merged_graph.nodes["A;C"]["type"] == "TypeA;TypeC"
        assert merged_graph.nodes["A;C"]["flags"] == 1
        assert merged_graph.has_edge("B", "A;C")

        # The original nodes and their edges are gone.
        assert not merged_graph.has_node("A")
        assert not merged_graph.has_node("C")
        assert not merged_graph.has_edge("A", "B")
        assert not merged_graph.has_edge("B", "C")

269 

270 

class TestMergeNodes:
    """Checks how ``_merge_nodes`` reacts to always-true and always-false conditions."""

    def test_merge_all_nodes(self, graph):
        """A condition that always holds leaves fewer nodes than the input had."""

        def always(_x, _y):
            return True

        merged = _merge_nodes(graph, always)
        assert len(merged.nodes) < len(graph.nodes)

    def test_merge_no_nodes(self, graph):
        """A condition that never holds keeps the node count unchanged."""

        def never(_x, _y):
            return False

        merged = _merge_nodes(graph, never)
        assert len(merged.nodes) == len(graph.nodes)

    def test_empty_graph(self):
        """An empty graph stays empty even when every pair would merge."""

        def always(_x, _y):
            return True

        merged = _merge_nodes(nx.Graph(), always)
        assert len(merged.nodes) == 0

284 

285 

class TestSimplifyEntitiesGraph:
    """Tests for ``simplify_entities_graph`` with ``_merge_nodes`` patched out."""

    # Dotted path of the helper replaced by a mock in every test of this class.
    _MERGE_NODES_TARGET = (
        "intelligence_toolkit.detect_entity_networks.explore_networks._merge_nodes"
    )

    def test_simplify_condition_false(self, mocker, graph):
        """Three nodes remain when ``_merge_nodes`` hands back the fixture graph."""
        mocker.patch(self._MERGE_NODES_TARGET, return_value=graph)

        simplified = simplify_entities_graph(graph)
        assert len(simplified.nodes()) == 3

    def test_simplify_condition_true(self, mocker, graph):
        """No nodes remain when ``_merge_nodes`` hands back an empty graph."""
        mocker.patch(self._MERGE_NODES_TARGET, return_value=nx.Graph())

        simplified = simplify_entities_graph(graph)
        assert len(simplified.nodes()) == 0

304 

305 

class TestHslToHex:
    """Checks ``hsl_to_hex`` against black, white and the three primary colours."""

    def test_colors(self):
        """Each (hue, saturation, lightness) triple maps to its expected hex string."""
        expectations = [
            ((0, 0, 0), "#000000"),
            ((0, 0, 100), "#ffffff"),
            ((0, 100, 50), "#ff0000"),
            ((120, 100, 50), "#00ff00"),
            ((240, 100, 50), "#0000ff"),
        ]
        for (hue, saturation, lightness), expected_hex in expectations:
            assert hsl_to_hex(hue, saturation, lightness) == expected_hex

313 

314 

class TestGetTypeColor:
    """Checks the hex colour ``get_type_color`` returns for a node type."""

    @pytest.fixture()
    def attribute_types(self) -> list[str]:
        """Provide the three attribute type names used by the tests."""
        return ["Type1", "Type2", "Type3"]

    def test_not_flagged(self, attribute_types) -> None:
        """An unflagged ``Type1`` node gets ``#a8b4ef``."""
        color = get_type_color("Type1", False, attribute_types)

        assert color == "#a8b4ef"

    def test_flagged(self, attribute_types) -> None:
        """A flagged ``Type1`` node gets ``#efa8a8``."""
        color = get_type_color("Type1", True, attribute_types)

        assert color == "#efa8a8"

    def test_type_2_3_types(self, attribute_types) -> None:
        """``Type2`` gets ``#efd3a8`` when three attribute types are known."""
        color = get_type_color("Type2", False, attribute_types)

        assert color == "#efd3a8"

    def test_type_2_2_types(self, attribute_types) -> None:
        """``Type2`` gets ``#d1efa8`` when only the first two attribute types are known."""
        without_last = attribute_types[:-1]
        color = get_type_color("Type2", False, without_last)

        assert color == "#d1efa8"

348 

349 

class TestGetEntityGraph:
    """Tests for ``get_entity_graph``.

    The expected node and edge payloads are identical across several tests,
    so they are produced by private helpers instead of being repeated as
    literals in every test.
    """

    # (node id, size, colour) of every node expected when nothing is selected.
    _UNSELECTED_NODES = (
        ("ENTITY==1", 12, "#a8b4ef"),
        ("ENTITY==3", 12, "#a8b4ef"),
        ("AttributeABCD==Type37", 8, "#efd3a8"),
        ("ENTITY==2", 12, "#a8b4ef"),
        ("Attr==Type1", 8, "#ebefa8"),
        ("AttributeABCD==Type35", 8, "#efd3a8"),
    )

    # Size expected for the node passed as the selected one.
    _SELECTED_SIZE = 20

    # (source, target) of every expected edge, ordered by (source, target).
    _EDGE_PAIRS = (
        ("Attr==Type1", "AttributeABCD==Type37"),
        ("ENTITY==1", "Attr==Type1"),
        ("ENTITY==1", "ENTITY==2"),
        ("ENTITY==3", "Attr==Type1"),
        ("ENTITY==3", "AttributeABCD==Type35"),
    )

    @pytest.fixture()
    def simple_graph(self) -> Graph:
        """Return the graph under test; overrides the module-level ``simple_graph``."""
        G = nx.Graph()
        G.add_nodes_from(
            [
                "ENTITY==1",
                "ENTITY==2",
                "ENTITY==3",
                "Attr==Type1",
                "Attr==Type2",
                "AttributeABCD==Type35",
                "AttributeABCD==Type37",
            ]
        )
        G.add_edges_from(
            [
                ("ENTITY==1", "ENTITY==2"),
                ("Attr==Type1", "ENTITY==1"),
                ("AttributeABCD==Type37", "Attr==Type1"),
                ("ENTITY==3", "Attr==Type1"),
                ("ENTITY==3", "AttributeABCD==Type35"),
            ]
        )
        return G

    @classmethod
    def _expected_nodes(cls, selected: str = "") -> list[dict]:
        """Build the expected node payloads, sorted by id.

        Args:
            selected: Id of the node expected to be drawn with the selected
                size; the empty string means no node is enlarged.
        """
        payloads = []
        for node_id, default_size, color in cls._UNSELECTED_NODES:
            attribute, value = node_id.split("==")
            size = cls._SELECTED_SIZE if node_id == selected else default_size
            payloads.append(
                {
                    "title": f"{node_id}\nFlags: 0",
                    "id": node_id,
                    "label": f"{value}\n({attribute})",
                    "size": size,
                    "color": color,
                    # In every expected payload the offset is the size plus 10, negated.
                    "font": {"vadjust": -(size + 10), "size": 5},
                }
            )
        return sorted(payloads, key=lambda node: node["id"])

    @classmethod
    def _expected_edges(cls) -> list[dict]:
        """Build the expected edge payloads, sorted by (source, target)."""
        payloads = [
            {"source": source, "target": target, "color": "mediumgray", "size": 1}
            for source, target in cls._EDGE_PAIRS
        ]
        return sorted(payloads, key=lambda edge: (edge["source"], edge["target"]))

    def test_empty(self) -> None:
        """An empty graph produces neither nodes nor edges."""
        nodes, edges = get_entity_graph(nx.Graph(), "", [])
        assert len(nodes) == 0
        assert len(edges) == 0

    def test_none_selected_nodes(self, simple_graph) -> None:
        """Without a selection every expected node keeps its default size."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        nodes, _ = get_entity_graph(simple_graph, "", attribute_types)

        nodes = sorted(nodes, key=lambda x: x["id"])
        assert self._expected_nodes() == nodes

    def test_none_selected_edges(self, simple_graph):
        """Without a selection the five graph edges are returned."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        _, edges = get_entity_graph(simple_graph, "", attribute_types)

        edges = sorted(edges, key=lambda x: (x["source"], x["target"]))
        assert edges == self._expected_edges()

    def test_selected_nodes(self, simple_graph):
        """The selected node ``ENTITY==3`` is returned with size 20."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        selected = "ENTITY==3"
        nodes, _ = get_entity_graph(simple_graph, selected, attribute_types)

        nodes = sorted(nodes, key=lambda x: x["id"])
        assert nodes == self._expected_nodes(selected)

    def test_selected_edges(self, simple_graph):
        """Selecting a node does not change the returned edges."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        _, edges = get_entity_graph(simple_graph, "ENTITY==3", attribute_types)

        edges = sorted(edges, key=lambda x: (x["source"], x["target"]))
        assert edges == self._expected_edges()

578 assert edges == expected_edges