Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_identify_networks.py: 100%

400 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4from collections import defaultdict 

5 

6import networkx as nx 

7import polars as pl 

8import pytest 

9 

10from intelligence_toolkit.detect_entity_networks.identify_networks import ( 

11 build_entity_records, 

12 get_community_nodes, 

13 get_entity_neighbors, 

14 get_integrated_flags, 

15 get_subgraph, 

16 neighbor_is_valid, 

17 project_entity_graph, 

18 trim_nodeset, 

19) 

20 

21 

class TestTrimNodeset:
    """Tests for ``trim_nodeset`` with a maximum attribute degree of 1."""

    # (node, degree) pairs expected back from ``trim_nodeset``: the odd-numbered
    # entities of ``overall_graph``, which all have degree 2 or 3 there.
    _HIGH_DEGREE = frozenset(
        {
            ("Entity1", 2),
            ("Entity3", 3),
            ("Entity5", 3),
            ("Entity7", 3),
            ("Entity9", 3),
        }
    )

    @pytest.fixture()
    def overall_graph(self):
        """Build a typed graph where each odd EntityN links to N + 1 and N + 2.

        Entity1..Entity9 carry a ``type`` of TypeA/TypeB/TypeC chosen by
        ``N % 3``. The edges leaving Entity9 implicitly add the untyped nodes
        Entity10 and Entity11.
        """
        graph = nx.Graph()
        graph.add_nodes_from(
            (f"Entity{n}", {"type": f"Type{chr(65 + n % 3)}"}) for n in range(1, 10)
        )
        graph.add_edges_from(
            (f"Entity{n}", f"Entity{n + offset}")
            for n in range(1, 10, 2)
            for offset in (1, 2)
        )
        return graph

    def test_trim_nodeset_additional_empty(self, overall_graph):
        """With no extra attributes only the high-degree entities are trimmed."""
        degrees, nodes = trim_nodeset(overall_graph, 1, set())

        assert nodes == {node for node, _ in self._HIGH_DEGREE}
        assert degrees == self._HIGH_DEGREE

    def test_trim_nodeset_additional(self, overall_graph):
        """Extra attributes join the trimmed nodes but add no degree entries."""
        degrees, nodes = trim_nodeset(overall_graph, 1, {"Entity2"})

        assert nodes == {node for node, _ in self._HIGH_DEGREE} | {"Entity2"}
        assert degrees == self._HIGH_DEGREE

89 

90 

class TestProjectEntityGraph:
    """Tests for ``project_entity_graph`` on a small entity/attribute graph."""

    @pytest.fixture()
    def simple_graph(self):
        """Return a graph mixing ``ENTITY==`` nodes and attribute nodes.

        ENTITY==1 and ENTITY==3 both connect to Attr==Type1; ENTITY==4,
        Attr==Type2 and Attr==Type35 have no edges at all.
        """
        G = nx.Graph()
        G.add_node("ENTITY==1")
        G.add_node("ENTITY==2")
        G.add_node("ENTITY==3")
        G.add_node("ENTITY==4")
        G.add_node("ENTITY==5")
        G.add_node("Attr==Type1")
        G.add_node("Attr==Type2")
        G.add_node("Attr==Type35")
        G.add_node("AttributeABCD==Type35")

        G.add_edge("ENTITY==1", "ENTITY==2")
        G.add_edge("Attr==Type1", "ENTITY==1")
        G.add_edge("ENTITY==3", "Attr==Type1")
        G.add_edge("ENTITY==3", "AttributeABCD==Type35")
        G.add_edge("ENTITY==3", "ENTITY==5")
        return G

    @pytest.fixture()
    def trimmed_nodeset(self):
        return {"ENTITY==5"}

    @pytest.fixture()
    def inferred_links(self):
        expected_links = defaultdict(set)
        # An inferred link becomes an edge between the two entities in the
        # projected graph (asserted by the test_edges_inferred_* tests).
        expected_links["ENTITY==1"].add("ENTITY==5")
        return expected_links

    @pytest.fixture()
    def inferred_links_empty(self):
        return defaultdict(set)

    @pytest.fixture()
    def supporting_attribute_types(self):
        # Attribute types treated as "supporting" attributes. neighbor_is_valid
        # rejects nodes of these types (see TestValidNeighbor), so presumably
        # entities sharing only an Attr==... node are not linked through it.
        # NOTE(review): no test in this class asserts that; confirm.
        return ["Attr"]

    def test_empty_graph(
        self, trimmed_nodeset, inferred_links, supporting_attribute_types
    ):
        """An empty input graph projects to an empty graph, inferred links or not."""
        empty_graph = nx.Graph()
        projected = project_entity_graph(
            empty_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        assert len(projected.nodes()) == 0
        assert len(projected.edges()) == 0

    def test_edges_no_inferred(
        self,
        simple_graph,
        trimmed_nodeset,
        inferred_links_empty,
        supporting_attribute_types,
    ):
        """Direct entity-entity edges survive, even for a trimmed entity."""
        projected = project_entity_graph(
            simple_graph,
            trimmed_nodeset,
            inferred_links_empty,
            supporting_attribute_types,
        )
        assert ("ENTITY==1", "ENTITY==2") in projected.edges()
        assert ("ENTITY==5", "ENTITY==3") in projected.edges()

    def test_edges_no_trimmed(
        self,
        simple_graph,
        inferred_links_empty,
        supporting_attribute_types,
    ):
        """Direct entity-entity edges are kept when nothing is trimmed."""
        projected = project_entity_graph(
            simple_graph,
            [],
            inferred_links_empty,
            supporting_attribute_types,
        )
        assert ("ENTITY==1", "ENTITY==2") in projected.edges()
        assert ("ENTITY==3", "ENTITY==5") in projected.edges()

    def test_edges_inferred_trimmed(
        self, simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
    ):
        """Inferred links add edges on top of the direct ones (with trimming)."""
        projected = project_entity_graph(
            simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        assert ("ENTITY==1", "ENTITY==2") in projected.edges()
        assert ("ENTITY==1", "ENTITY==5") in projected.edges()
        assert ("ENTITY==5", "ENTITY==3") in projected.edges()

    def test_edges_inferred_no_trimmed(
        self, simple_graph, inferred_links, supporting_attribute_types
    ):
        """Inferred links add edges on top of the direct ones (no trimming)."""
        projected = project_entity_graph(
            simple_graph, [], inferred_links, supporting_attribute_types
        )
        assert ("ENTITY==1", "ENTITY==2") in projected.edges()
        assert ("ENTITY==1", "ENTITY==5") in projected.edges()
        assert ("ENTITY==5", "ENTITY==3") in projected.edges()

    def test_nodes_no_inferred(
        self,
        simple_graph,
        trimmed_nodeset,
        inferred_links_empty,
        supporting_attribute_types,
    ):
        """Connected entities appear in the projection; isolated ENTITY==4 does not."""
        projected = project_entity_graph(
            simple_graph,
            trimmed_nodeset,
            inferred_links_empty,
            supporting_attribute_types,
        )
        assert "ENTITY==1" in projected.nodes()
        assert "ENTITY==2" in projected.nodes()
        assert "ENTITY==3" in projected.nodes()
        assert "ENTITY==4" not in projected.nodes()

    def test_nodes_inferred(
        self, simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
    ):
        """With inferred links the trimmed ENTITY==5 still appears; ENTITY==4 does not."""
        projected = project_entity_graph(
            simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        assert "ENTITY==1" in projected.nodes()
        assert "ENTITY==2" in projected.nodes()
        assert "ENTITY==3" in projected.nodes()
        assert "ENTITY==5" in projected.nodes()
        assert "ENTITY==4" not in projected.nodes()

222 

223 

class TestValidNeighbor:
    """Tests for ``neighbor_is_valid``.

    A neighbour is rejected when its type prefix is a supporting attribute type
    or when it is in the trimmed node set.
    """

    @pytest.fixture()
    def supporting_attribute_types(self):
        return ["Node1"]

    @pytest.fixture()
    def trimmed_nodeset(self):
        return ["Node2==1"]

    @pytest.fixture()
    def node1(self):
        # Its "Node1" prefix matches the supporting attribute type above.
        return "Node1==2"

    @pytest.fixture()
    def node2(self):
        # Not a supporting type; only rejected when explicitly trimmed.
        return "Node2==1"

    def test_empty(self):
        assert neighbor_is_valid("", [], []) is False

    def test_is_supported(self, node1, supporting_attribute_types):
        verdict = neighbor_is_valid(node1, supporting_attribute_types, [])
        assert verdict is False

    def test_is_not_supported(self, node2, supporting_attribute_types):
        verdict = neighbor_is_valid(node2, supporting_attribute_types, [])
        assert verdict is True

    def test_is_trimmed(self, node2, supporting_attribute_types, trimmed_nodeset):
        verdict = neighbor_is_valid(
            node2, supporting_attribute_types, trimmed_nodeset
        )
        assert verdict is False

256 

257 

class TestGetEntityNeighbors:
    """Tests for ``get_entity_neighbors``: graph edges, inferred links, trimming."""

    @pytest.fixture()
    def graph(self):
        """Return node0..node7; node1 has a self-loop, node6 and node7 are isolated."""
        G = nx.Graph()
        G.add_nodes_from(f"node{index}" for index in range(8))
        G.add_edges_from(
            [
                ("node0", "node1"),
                ("node1", "node2"),
                ("node1", "node1"),
                ("node1", "node3"),
                ("node2", "node3"),
                ("node3", "node4"),
                ("node4", "node5"),
            ]
        )
        return G

    @staticmethod
    def _links(*pairs):
        """Build an inferred-links mapping (source -> set of targets) from pairs."""
        links = defaultdict(set)
        for source, target in pairs:
            links[source].add(target)
        return links

    def test_empty_graph(self):
        assert get_entity_neighbors(nx.Graph(), [], [], "") == []

    def test_node_not_int_graph(self, graph):
        with pytest.raises(ValueError, match="Node node77 not in graph"):
            get_entity_neighbors(graph, [], [], "node77")

    def test_no_inferred(self, graph):
        assert get_entity_neighbors(graph, [], [], "node5") == ["node4"]

    def test_inferred(self, graph):
        links = self._links(("node5", "node2"), ("node12", "node127"))
        assert get_entity_neighbors(graph, links, [], "node5") == ["node2", "node4"]

    def test_trimmed(self, graph):
        neighbours = get_entity_neighbors(graph, [], ["node2"], "node1")
        assert neighbours == ["node0", "node3"]

    def test_node_equals(self, graph):
        # The node1 self-loop must not list node1 as its own neighbour.
        neighbours = get_entity_neighbors(graph, [], [], "node1")
        assert neighbours == ["node0", "node2", "node3"]

    def test_inferred_mixed(self, graph) -> None:
        # node2 is only the *target* of an inferred link, yet node5 is returned.
        links = self._links(("node5", "node2"), ("node12", "node127"))
        neighbours = get_entity_neighbors(graph, links, [], "node2")
        assert neighbours == ["node1", "node3", "node5"]

    def test_inferred_mixed_contrary(self, graph) -> None:
        links = self._links(("node5", "node2"))
        assert get_entity_neighbors(graph, links, [], "node5") == ["node2", "node4"]

    def test_inferred_mixed_multiple(self, graph) -> None:
        links = self._links(
            ("node5", "node2"), ("node7", "node2"), ("node12", "node127")
        )
        neighbours = get_entity_neighbors(graph, links, [], "node2")
        assert neighbours == ["node1", "node3", "node5", "node7"]

330 

331 

class TestEntityGraph:
    """More ``project_entity_graph`` tests on an entity/attribute graph.

    Projected edges are compared as unordered node pairs. The order in which a
    networkx graph yields its edges reflects how the implementation happened
    to insert them, which is not what these tests are meant to pin down.
    """

    @pytest.fixture()
    def sample_graph(self):
        """Return entities 1, 2, 3, 5 linked directly or via ``attr==`` nodes.

        entity_1 and entity_2 share attr==att_2; entity_2 and entity_5 share a
        direct edge; entity_3 only touches attr==att_4.
        """
        G = nx.Graph()
        G.add_nodes_from(
            [
                "ENTITY==entity_1",
                "ENTITY==entity_2",
                "ENTITY==entity_3",
                "ENTITY==entity_5",
                "attr==att_1",
                "attr==att_2",
                "attr==att_3",
                "attr==att_4",
            ]
        )

        G.add_edges_from(
            [
                ("ENTITY==entity_1", "attr==att_1"),
                ("ENTITY==entity_5", "ENTITY==entity_2"),
                ("ENTITY==entity_1", "attr==att_2"),
                ("ENTITY==entity_2", "attr==att_2"),
                ("ENTITY==entity_2", "attr==att_3"),
                ("ENTITY==entity_3", "attr==att_4"),
            ]
        )
        return G

    @staticmethod
    def _edge_set(edges):
        """Return ``edges`` as a set of unordered (frozenset) node pairs."""
        return {frozenset(edge) for edge in edges}

    def test_project_entity_graph_empty_graph(self):
        G = nx.Graph()
        trimmed_nodeset = set()
        inferred_links = {}
        supporting_attribute_types = []
        P = project_entity_graph(
            G, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        assert len(P.nodes) == 0
        assert len(P.edges) == 0

    def test_project_entity_graph_basic(self, sample_graph):
        trimmed_nodeset = set()
        inferred_links = {}
        supporting_attribute_types = []

        P = project_entity_graph(
            sample_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        expected = [
            ("ENTITY==entity_1", "ENTITY==entity_2"),
            ("ENTITY==entity_2", "ENTITY==entity_5"),
        ]
        assert self._edge_set(P.edges()) == self._edge_set(expected)

    def test_project_entity_graph_trimmed(self, sample_graph):
        # Trimming entity_5 does not drop its direct edge to entity_2.
        trimmed_nodeset = {"ENTITY==entity_5"}
        inferred_links = {}
        supporting_attribute_types = []

        P = project_entity_graph(
            sample_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        expected = [
            ("ENTITY==entity_1", "ENTITY==entity_2"),
            ("ENTITY==entity_2", "ENTITY==entity_5"),
        ]
        assert self._edge_set(P.edges()) == self._edge_set(expected)

    def test_project_entity_graph_inferred(self, sample_graph):
        trimmed_nodeset = set()
        inferred_links = defaultdict(set)
        inferred_links["ENTITY==entity_1"].add("ENTITY==entity_3")
        supporting_attribute_types = []

        P = project_entity_graph(
            sample_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        expected = [
            ("ENTITY==entity_1", "ENTITY==entity_3"),
            ("ENTITY==entity_1", "ENTITY==entity_2"),
            ("ENTITY==entity_2", "ENTITY==entity_5"),
        ]
        assert self._edge_set(P.edges()) == self._edge_set(expected)

416 

417 

class TestSubgraph:
    """Tests for ``get_subgraph``: communities over the graph induced by ``nodes``."""

    def test_basic_functionality(self):
        graph = nx.Graph()
        graph.add_edges_from([(1, 2), (2, 3), (3, 4)])
        nodes = [1, 2, 3]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        # Node 4 is not listed in ``nodes`` and must not appear in the result.
        assert len(community_nodes) == 1
        assert set(community_nodes[0]) == set(nodes)
        assert entity_to_community == {1: 0, 2: 0, 3: 0}

    def test_disconnected_components(self):
        graph = nx.Graph()
        graph.add_edges_from([(1, 2), (3, 4)])
        nodes = [1, 2, 3, 4]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        # Either community order is accepted, but both components must be
        # present as two distinct communities.
        assert len(community_nodes) == 2
        assert {frozenset(community) for community in community_nodes} == {
            frozenset({1, 2}),
            frozenset({3, 4}),
        }
        assert len(entity_to_community) == 4
        assert entity_to_community[1] == entity_to_community[2]
        assert entity_to_community[3] == entity_to_community[4]
        assert entity_to_community[1] != entity_to_community[3]

    def test_max_network_entities(self):
        graph = nx.Graph()
        graph.add_edges_from(
            [
                (1, 2),
                (3, 4),
                (7, 9),
                (8, 90),
                (54, 66),
                (66, 44),
                (66, 1),
                (66, 2),
            ]
        )
        # 89 is deliberately absent from the graph.
        nodes = [1, 2, 3, 4, 66, 54, 89]
        max_network_entities = 2
        community_nodes, _ = get_subgraph(
            graph, nodes, max_network_entities=max_network_entities
        )

        # Without this guard the size loop below passes vacuously on [].
        assert community_nodes
        assert set().union(*community_nodes) <= set(nodes)
        for community in community_nodes:
            assert len(community) <= max_network_entities

    def test_max_network_entities_size_high(self):
        graph = nx.Graph()
        graph.add_edges_from(
            [
                (1, 2),
                (3, 4),
                (7, 9),
                (8, 90),
                (54, 66),
                (66, 44),
                (66, 1),
                (66, 2),
            ]
        )
        # 89 is deliberately absent from the graph.
        nodes = [1, 2, 3, 4, 66, 54, 89]
        max_network_entities = 10
        community_nodes, _ = get_subgraph(
            graph, nodes, max_network_entities=max_network_entities
        )

        # Without this guard the size loop below passes vacuously on [].
        assert community_nodes
        assert set().union(*community_nodes) <= set(nodes)
        for community in community_nodes:
            assert len(community) <= max_network_entities

    def test_graph_with_weights(self):
        graph = nx.Graph()
        graph.add_edge(1, 2, weight=1.0)
        graph.add_edge(2, 3, weight=2.0)
        graph.add_edge(3, 4, weight=3.0)
        nodes = [1, 2, 3, 4]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert len(community_nodes) > 0
        assert len(entity_to_community) == 4

    def test_empty_node_list(self):
        graph = nx.Graph()
        graph.add_edges_from([(1, 2), (2, 3)])
        nodes = []
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert community_nodes == []
        assert entity_to_community == {}

    def test_empty_graph(self):
        graph = nx.Graph()
        nodes = [1, 2, 3]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert community_nodes == []
        assert entity_to_community == {}

508 

509 

class TestNodes:
    """Tests for ``get_community_nodes``."""

    @staticmethod
    def _build_graph():
        """Return a 6-node tree (A-B-C-E plus C-D-P) and a separate V-X pair."""
        graph = nx.Graph()
        for path in (
            ["node_A", "node_B", "node_C", "node_E"],
            ["node_V", "node_X"],
            ["node_D", "node_P"],
            ["node_D", "node_C"],
        ):
            nx.add_path(graph, path)
        return graph

    def test_nodes(self):
        """With room for 10 entities each connected component is one community."""
        tree = {"node_A", "node_B", "node_C", "node_D", "node_E", "node_P"}
        pair = {"node_V", "node_X"}
        expected_mapping = dict.fromkeys(tree, 0) | dict.fromkeys(pair, 1)

        outcome = get_community_nodes(self._build_graph(), 10)

        assert outcome == ([tree, pair], expected_mapping)

    def test_max_size(self):
        """A size cap of 2 splits the tree into pairs; community order is free."""
        outcome = get_community_nodes(self._build_graph(), 2)
        communities, mapping = outcome[0], outcome[1]

        wanted = [
            {"node_A", "node_B"},
            {"node_C", "node_E"},
            {"node_D", "node_P"},
            {"node_V", "node_X"},
        ]
        found = list(map(set, communities))

        assert len(found) == len(wanted)
        for group in wanted:
            assert group in found

        assert mapping == {
            "node_A": 0,
            "node_B": 0,
            "node_C": 1,
            "node_D": 2,
            "node_E": 1,
            "node_P": 2,
            "node_V": 3,
            "node_X": 3,
        }

593 

594 

class TestIntegratedFlags:
    """Tests for ``get_integrated_flags``.

    Results are unpacked in the order (community_flags, flagged,
    flagged_per_unflagged, flags_per_entity, total_entities).
    """

    @pytest.fixture()
    def qualified_entities(self) -> list[str]:
        return [f"ENTITY=={number}" for number in (1, 2, 3)]

    @pytest.fixture()
    def integrated_flags(self, qualified_entities) -> pl.DataFrame:
        # ENTITY==1 has 1 flag, ENTITY==2 none, ENTITY==3 has 3.
        return pl.DataFrame(
            {"qualified_entity": qualified_entities, "count": [1, 0, 3]}
        )

    @staticmethod
    def _with_entity_five(flags: pl.DataFrame, count: int) -> pl.DataFrame:
        """Return ``flags`` plus an ENTITY==5 row carrying ``count`` flags."""
        extra_row = pl.DataFrame({"qualified_entity": ["ENTITY==5"], "count": [count]})
        return flags.vstack(extra_row)

    def test_empty_integrated_flags(self) -> None:
        assert get_integrated_flags(pl.DataFrame(), []) == (0, 0, 0, 0, 0)

    def test_no_entities(self, integrated_flags):
        assert get_integrated_flags(integrated_flags, []) == (0, 0, 0, 0, 0)

    def test_base_entities(self, integrated_flags, qualified_entities):
        stats = get_integrated_flags(integrated_flags, qualified_entities)
        (
            community_flags,
            flagged,
            flagged_per_unflagged,
            flags_per_entity,
            total_entities,
        ) = stats

        assert flags_per_entity == 1.33
        assert total_entities == 3
        assert flagged_per_unflagged == 2
        assert flagged == 2
        assert community_flags == 4

    def test_inferred_links(self, integrated_flags, qualified_entities) -> None:
        entities = [*qualified_entities, "ENTITY==5"]
        links = defaultdict(set)
        links["ENTITY==1"].add("ENTITY==5")
        flags = self._with_entity_five(integrated_flags, 0)

        (
            community_flags,
            flagged,
            flagged_per_unflagged,
            flags_per_entity,
            total_entities,
        ) = get_integrated_flags(flags, entities, links)

        assert flags_per_entity == 1.0
        assert total_entities == 4
        assert flagged_per_unflagged == 1.0
        assert flagged == 2.0
        assert community_flags == 4

    def test_inferred_links_flagged(self, integrated_flags, qualified_entities) -> None:
        entities = [*qualified_entities, "ENTITY==5"]
        links = defaultdict(set)
        links["ENTITY==1"].add("ENTITY==5")
        flags = self._with_entity_five(integrated_flags, 1)

        (
            community_flags,
            flagged,
            flagged_per_unflagged,
            flags_per_entity,
            total_entities,
        ) = get_integrated_flags(flags, entities, links)

        assert flags_per_entity == 1.25
        assert total_entities == 4
        assert flagged_per_unflagged == 3.0
        assert flagged == 3.0
        assert community_flags == 5

680 

681 

class TestBuildEntityRecords:
    """Tests for ``build_entity_records``.

    Judging from the expected values below, each record appears to be
    (entity id without the ``ENTITY==`` prefix, entity flag count, community
    index, community size, community flag total, flagged entities, flags per
    entity, flagged per unflagged). Inferred links seem to add the linked
    entity to each community's statistics.
    NOTE(review): column meanings inferred from the data; confirm.
    """

    @pytest.fixture()
    def community_nodes(self):
        return [[f"ENTITY=={n}" for n in group] for group in ((1, 2, 3), (4, 5, 6))]

    @pytest.fixture()
    def integrated_flags(self):
        return pl.DataFrame(
            {
                "qualified_entity": [f"ENTITY=={n}" for n in (1, 2, 3)],
                "count": [1, 0, 3],
            }
        )

    @pytest.fixture()
    def community_nodes_multiple(self) -> list[list[str]]:
        return [
            [f"ENTITY=={n}" for n in group]
            for group in ((1, 2, 3), (4, 5, 6), (8, 9, 11))
        ]

    @pytest.fixture()
    def integrated_flags_multiple(self) -> pl.DataFrame:
        return pl.DataFrame(
            {
                "qualified_entity": [f"ENTITY=={n}" for n in (1, 2, 3, 11, 9)],
                "count": [1, 0, 3, 2, 5],
            }
        )

    @staticmethod
    def _links(*pairs):
        """Build an inferred-links mapping (source -> set of targets) from pairs."""
        links = defaultdict(set)
        for source, target in pairs:
            links[source].add(target)
        return links

    def test_final_integrated(self, community_nodes, integrated_flags):
        records = build_entity_records(community_nodes, integrated_flags)

        assert records == [
            ("1", 1, 0, 3, 4, 2, 1.33, 2.0),
            ("2", 0, 0, 3, 4, 2, 1.33, 2.0),
            ("3", 3, 0, 3, 4, 2, 1.33, 2.0),
            ("4", 0, 1, 3, 0, 0, 0.0, 0.0),
            ("5", 0, 1, 3, 0, 0, 0.0, 0.0),
            ("6", 0, 1, 3, 0, 0, 0.0, 0.0),
        ]

    def test_final_count_inferred(self, community_nodes, integrated_flags) -> None:
        links = self._links(("ENTITY==1", "ENTITY==5"))
        records = build_entity_records(community_nodes, integrated_flags, links)

        assert records == [
            ("1", 1, 0, 4, 4, 2, 1.0, 1.0),
            ("2", 0, 0, 4, 4, 2, 1.0, 1.0),
            ("3", 3, 0, 4, 4, 2, 1.0, 1.0),
            ("4", 0, 1, 4, 1, 1, 0.25, 0.33),
            ("5", 0, 1, 4, 1, 1, 0.25, 0.33),
            ("6", 0, 1, 4, 1, 1, 0.25, 0.33),
        ]

    def test_final_count_inferred_existant(
        self, community_nodes, integrated_flags
    ) -> None:
        # A link inside the same community leaves every statistic unchanged.
        links = self._links(("ENTITY==1", "ENTITY==2"))
        records = build_entity_records(community_nodes, integrated_flags, links)

        assert records == [
            ("1", 1, 0, 3, 4, 2, 1.33, 2.0),
            ("2", 0, 0, 3, 4, 2, 1.33, 2.0),
            ("3", 3, 0, 3, 4, 2, 1.33, 2.0),
            ("4", 0, 1, 3, 0, 0, 0.0, 0.0),
            ("5", 0, 1, 3, 0, 0, 0.0, 0.0),
            ("6", 0, 1, 3, 0, 0, 0.0, 0.0),
        ]

    def test_final_count_inferred_with_flags(
        self, community_nodes, integrated_flags
    ) -> None:
        links = self._links(("ENTITY==3", "ENTITY==6"))
        records = build_entity_records(community_nodes, integrated_flags, links)

        assert records == [
            ("1", 1, 0, 4, 4, 2, 1.0, 1.0),
            ("2", 0, 0, 4, 4, 2, 1.0, 1.0),
            ("3", 3, 0, 4, 4, 2, 1.0, 1.0),
            ("4", 0, 1, 4, 3, 1, 0.75, 0.33),
            ("5", 0, 1, 4, 3, 1, 0.75, 0.33),
            ("6", 0, 1, 4, 3, 1, 0.75, 0.33),
        ]

    def test_final_count_inferred_both_with_flags(
        self, community_nodes, integrated_flags
    ) -> None:
        links = self._links(("ENTITY==3", "ENTITY==6"))
        extra_row = pl.DataFrame({"qualified_entity": ["ENTITY==6"], "count": [1]})
        flags = integrated_flags.vstack(extra_row)
        records = build_entity_records(community_nodes, flags, links)

        assert records == [
            ("1", 1, 0, 4, 5, 3, 1.25, 3.0),
            ("2", 0, 0, 4, 5, 3, 1.25, 3.0),
            ("3", 3, 0, 4, 5, 3, 1.25, 3.0),
            ("4", 0, 1, 4, 4, 2, 1.0, 1.0),
            ("5", 0, 1, 4, 4, 2, 1.0, 1.0),
            ("6", 1, 1, 4, 4, 2, 1.0, 1.0),
        ]

    def test_final_count_inferred_multiple(
        self, community_nodes_multiple, integrated_flags_multiple
    ) -> None:
        links = self._links(("ENTITY==3", "ENTITY==6"), ("ENTITY==3", "ENTITY==11"))
        records = build_entity_records(
            community_nodes_multiple, integrated_flags_multiple, links
        )

        assert records == [
            ("1", 1, 0, 5, 6, 3, 1.2, 1.5),
            ("2", 0, 0, 5, 6, 3, 1.2, 1.5),
            ("3", 3, 0, 5, 6, 3, 1.2, 1.5),
            ("4", 0, 1, 4, 3, 1, 0.75, 0.33),
            ("5", 0, 1, 4, 3, 1, 0.75, 0.33),
            ("6", 0, 1, 4, 3, 1, 0.75, 0.33),
            ("8", 0, 2, 4, 10, 3, 2.5, 3.0),
            ("9", 5, 2, 4, 10, 3, 2.5, 3.0),
            ("11", 2, 2, 4, 10, 3, 2.5, 3.0),
        ]

    def test_final_not_integrated(self, community_nodes):
        # Without a flags frame every flag-derived column is zero.
        records = build_entity_records(community_nodes, None)

        assert records == [
            ("1", 0, 0, 3, 0, 0, 0, 0),
            ("2", 0, 0, 3, 0, 0, 0, 0),
            ("3", 0, 0, 3, 0, 0, 0, 0),
            ("4", 0, 1, 3, 0, 0, 0, 0),
            ("5", 0, 1, 3, 0, 0, 0, 0),
            ("6", 0, 1, 3, 0, 0, 0, 0),
        ]