Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_identify_networks.py: 100%
400 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4from collections import defaultdict
6import networkx as nx
7import polars as pl
8import pytest
10from intelligence_toolkit.detect_entity_networks.identify_networks import (
11 build_entity_records,
12 get_community_nodes,
13 get_entity_neighbors,
14 get_integrated_flags,
15 get_subgraph,
16 neighbor_is_valid,
17 project_entity_graph,
18 trim_nodeset,
19)
class TestTrimNodeset:
    """Exercise ``trim_nodeset`` against a small graph of typed entities."""

    @pytest.fixture()
    def overall_graph(self):
        """Build Entity1..Entity9 with a type, then wire the odd-numbered ones.

        Every odd entity is linked to the next two entity numbers, so the
        edges of Entity9 implicitly create Entity10 and Entity11, which carry
        no ``type`` attribute.
        """
        graph = nx.Graph()
        # 65 is ord("A"); i % 3 cycles through 0..2 -> TypeA, TypeB, TypeC.
        graph.add_nodes_from(
            (f"Entity{i}", {"type": f"Type{chr(65 + (i % 3))}"})
            for i in range(1, 10)
        )
        for i in range(1, 10, 2):
            graph.add_edges_from(
                [
                    (f"Entity{i}", f"Entity{i + 1}"),
                    (f"Entity{i}", f"Entity{i + 2}"),
                ]
            )
        return graph

    def test_trim_nodeset_additional_empty(self, overall_graph):
        """No extra attributes: only the odd entities (degree > 1) are trimmed."""
        degrees, nodes = trim_nodeset(overall_graph, 1, set())

        assert nodes == {"Entity1", "Entity3", "Entity5", "Entity7", "Entity9"}
        assert degrees == {
            ("Entity1", 2),
            ("Entity3", 3),
            ("Entity5", 3),
            ("Entity7", 3),
            ("Entity9", 3),
        }

    def test_trim_nodeset_additional(self, overall_graph):
        """An extra attribute joins the trimmed nodes but not the degree set."""
        degrees, nodes = trim_nodeset(overall_graph, 1, {"Entity2"})

        assert nodes == {
            "Entity1",
            "Entity2",
            "Entity3",
            "Entity5",
            "Entity7",
            "Entity9",
        }
        assert degrees == {
            ("Entity1", 2),
            ("Entity3", 3),
            ("Entity5", 3),
            ("Entity7", 3),
            ("Entity9", 3),
        }
class TestProjectEntityGraph:
    """Tests for ``project_entity_graph`` on a mixed entity/attribute graph."""

    @pytest.fixture()
    def simple_graph(self):
        """Five entities and four attribute nodes; ENTITY==4 stays isolated."""
        graph = nx.Graph()
        graph.add_nodes_from(
            [
                "ENTITY==1",
                "ENTITY==2",
                "ENTITY==3",
                "ENTITY==4",
                "ENTITY==5",
                "Attr==Type1",
                "Attr==Type2",
                "Attr==Type35",
                "AttributeABCD==Type35",
            ]
        )
        graph.add_edges_from(
            [
                ("ENTITY==1", "ENTITY==2"),
                ("Attr==Type1", "ENTITY==1"),
                ("ENTITY==3", "Attr==Type1"),
                ("ENTITY==3", "AttributeABCD==Type35"),
                ("ENTITY==3", "ENTITY==5"),
            ]
        )
        return graph

    @pytest.fixture()
    def trimmed_nodeset(self):
        """Node set marking ENTITY==5 as trimmed."""
        return {"ENTITY==5"}

    @pytest.fixture()
    def inferred_links(self):
        """Inferred link that should surface as an edge ENTITY==1 -- ENTITY==5."""
        return defaultdict(set, {"ENTITY==1": {"ENTITY==5"}})

    @pytest.fixture()
    def inferred_links_empty(self):
        """Inferred-link mapping with no entries."""
        return defaultdict(set)

    @pytest.fixture()
    def supporting_attribute_types(self):
        """Attribute types treated as supporting.

        Original author's note: entities connected to the same attribute get
        linked through it, and a type listed here shows up in the projected
        nodes and edges while an unlisted one does not.
        NOTE(review): confirm against ``project_entity_graph`` itself.
        """
        return ["Attr"]

    def test_empty_graph(
        self, trimmed_nodeset, inferred_links, supporting_attribute_types
    ):
        """Projecting an empty graph yields no nodes and no edges."""
        projected = project_entity_graph(
            nx.Graph(), trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        assert len(projected.nodes()) == 0
        assert len(projected.edges()) == 0

    def test_edges_no_inferred(
        self,
        simple_graph,
        trimmed_nodeset,
        inferred_links_empty,
        supporting_attribute_types,
    ):
        """Direct entity-entity edges survive when nothing is inferred."""
        projected = project_entity_graph(
            simple_graph,
            trimmed_nodeset,
            inferred_links_empty,
            supporting_attribute_types,
        )
        for edge in (("ENTITY==1", "ENTITY==2"), ("ENTITY==5", "ENTITY==3")):
            assert edge in projected.edges()

    def test_edges_no_trimmed(
        self,
        simple_graph,
        inferred_links_empty,
        supporting_attribute_types,
    ):
        """Same direct edges are present with an empty trimmed collection."""
        projected = project_entity_graph(
            simple_graph,
            [],
            inferred_links_empty,
            supporting_attribute_types,
        )
        for edge in (("ENTITY==1", "ENTITY==2"), ("ENTITY==3", "ENTITY==5")):
            assert edge in projected.edges()

    def test_edges_inferred_trimmed(
        self, simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
    ):
        """The inferred edge appears alongside the direct ones (trimmed set given)."""
        projected = project_entity_graph(
            simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        for edge in (
            ("ENTITY==1", "ENTITY==2"),
            ("ENTITY==1", "ENTITY==5"),
            ("ENTITY==5", "ENTITY==3"),
        ):
            assert edge in projected.edges()

    def test_edges_inferred_no_trimmed(
        self, simple_graph, inferred_links, supporting_attribute_types
    ):
        """The inferred edge appears alongside the direct ones (nothing trimmed)."""
        projected = project_entity_graph(
            simple_graph, [], inferred_links, supporting_attribute_types
        )
        for edge in (
            ("ENTITY==1", "ENTITY==2"),
            ("ENTITY==1", "ENTITY==5"),
            ("ENTITY==5", "ENTITY==3"),
        ):
            assert edge in projected.edges()

    def test_nodes_no_inferred(
        self,
        simple_graph,
        trimmed_nodeset,
        inferred_links_empty,
        supporting_attribute_types,
    ):
        """Connected entities are projected; the isolated ENTITY==4 is not."""
        projected = project_entity_graph(
            simple_graph,
            trimmed_nodeset,
            inferred_links_empty,
            supporting_attribute_types,
        )
        for node in ("ENTITY==1", "ENTITY==2", "ENTITY==3"):
            assert node in projected.nodes()
        assert "ENTITY==4" not in projected.nodes()

    def test_nodes_inferred(
        self, simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
    ):
        """With an inferred link ENTITY==5 is projected too; ENTITY==4 still is not."""
        projected = project_entity_graph(
            simple_graph, trimmed_nodeset, inferred_links, supporting_attribute_types
        )
        for node in ("ENTITY==1", "ENTITY==2", "ENTITY==3", "ENTITY==5"):
            assert node in projected.nodes()
        assert "ENTITY==4" not in projected.nodes()
class TestValidNeighbor:
    """Tests for ``neighbor_is_valid``."""

    @pytest.fixture()
    def supporting_attribute_types(self):
        """Attribute types considered supporting in these tests."""
        return ["Node1"]

    @pytest.fixture()
    def trimmed_nodeset(self):
        """Trimmed collection containing the ``node2`` fixture value."""
        return ["Node2==1"]

    @pytest.fixture()
    def node1(self):
        """Node whose type prefix is listed in ``supporting_attribute_types``."""
        return "Node1==2"

    @pytest.fixture()
    def node2(self):
        """Node whose type prefix is not a supporting attribute type."""
        return "Node2==1"

    def test_empty(self):
        """An empty node name with empty collections is not a valid neighbor."""
        assert neighbor_is_valid("", [], []) is False

    def test_is_supported(self, node1, supporting_attribute_types):
        """A node of a supporting attribute type is rejected."""
        assert neighbor_is_valid(node1, supporting_attribute_types, []) is False

    def test_is_not_supported(self, node2, supporting_attribute_types):
        """A node of a non-supporting type that is not trimmed is accepted."""
        assert neighbor_is_valid(node2, supporting_attribute_types, []) is True

    def test_is_trimmed(self, node2, supporting_attribute_types, trimmed_nodeset):
        """The same node is rejected once it is in the trimmed collection."""
        outcome = neighbor_is_valid(node2, supporting_attribute_types, trimmed_nodeset)
        assert outcome is False
class TestGetEntityNeighbors:
    """Tests for ``get_entity_neighbors`` with and without inferred links."""

    @pytest.fixture()
    def graph(self):
        """Eight nodes (node0..node7); node1 has a self-loop, node6/node7 are isolated."""
        g = nx.Graph()
        g.add_nodes_from(f"node{i}" for i in range(8))
        g.add_edges_from(
            [
                ("node0", "node1"),
                ("node1", "node2"),
                ("node1", "node1"),
                ("node1", "node3"),
                ("node2", "node3"),
                ("node3", "node4"),
                ("node4", "node5"),
            ]
        )
        return g

    def test_empty_graph(self):
        """An empty graph queried with an empty node name returns no neighbors."""
        assert get_entity_neighbors(nx.Graph(), [], [], "") == []

    def test_node_not_int_graph(self, graph):
        """Asking for a node that is absent from the graph raises ``ValueError``."""
        with pytest.raises(ValueError, match="Node node77 not in graph"):
            get_entity_neighbors(graph, [], [], "node77")

    def test_no_inferred(self, graph):
        """Without inferred links only the graph neighbor is returned."""
        assert get_entity_neighbors(graph, [], [], "node5") == ["node4"]

    def test_inferred(self, graph):
        """An inferred link from the queried node adds its target."""
        links = defaultdict(set)
        links["node5"].add("node2")
        links["node12"].add("node127")  # unrelated entry, must not leak in
        assert get_entity_neighbors(graph, links, [], "node5") == ["node2", "node4"]

    def test_trimmed(self, graph):
        """Trimmed nodes are left out of the neighbor list."""
        neighbors = get_entity_neighbors(graph, [], ["node2"], "node1")
        assert neighbors == ["node0", "node3"]

    def test_node_equals(self, graph):
        """The self-loop on node1 does not make node1 its own neighbor."""
        neighbors = get_entity_neighbors(graph, [], [], "node1")
        assert neighbors == ["node0", "node2", "node3"]

    def test_inferred_mixed(self, graph) -> None:
        """An inferred link is also seen from its target side (node2 sees node5)."""
        links = defaultdict(set)
        links["node5"].add("node2")
        links["node12"].add("node127")
        neighbors = get_entity_neighbors(graph, links, [], "node2")
        assert neighbors == ["node1", "node3", "node5"]

    def test_inferred_mixed_contrary(self, graph) -> None:
        """The same inferred link seen from its source side (node5 sees node2)."""
        links = defaultdict(set)
        links["node5"].add("node2")
        assert get_entity_neighbors(graph, links, [], "node5") == ["node2", "node4"]

    def test_inferred_mixed_multiple(self, graph) -> None:
        """Several inferred links pointing at node2 all contribute neighbors."""
        links = defaultdict(set)
        for source, target in (
            ("node5", "node2"),
            ("node7", "node2"),
            ("node12", "node127"),
        ):
            links[source].add(target)
        neighbors = get_entity_neighbors(graph, links, [], "node2")
        assert neighbors == ["node1", "node3", "node5", "node7"]
class TestEntityGraph:
    """Ordered-edge tests for ``project_entity_graph``."""

    @pytest.fixture()
    def sample_graph(self):
        """Four entities and four attributes.

        entity_1 and entity_2 share ``attr==att_2``; entity_5 is linked
        directly to entity_2. Insertion order is kept as-is because the tests
        compare ordered edge lists.
        """
        node_names = [
            "ENTITY==entity_1",
            "ENTITY==entity_2",
            "ENTITY==entity_3",
            "ENTITY==entity_5",
            "attr==att_1",
            "attr==att_2",
            "attr==att_3",
            "attr==att_4",
        ]
        edge_pairs = [
            ("ENTITY==entity_1", "attr==att_1"),
            ("ENTITY==entity_5", "ENTITY==entity_2"),
            ("ENTITY==entity_1", "attr==att_2"),
            ("ENTITY==entity_2", "attr==att_2"),
            ("ENTITY==entity_2", "attr==att_3"),
            ("ENTITY==entity_3", "attr==att_4"),
        ]
        graph = nx.Graph()
        graph.add_nodes_from(node_names)
        graph.add_edges_from(edge_pairs)
        return graph

    def test_project_entity_graph_empty_graph(self):
        """An empty input graph projects to a graph without nodes."""
        projected = project_entity_graph(nx.Graph(), set(), {}, [])
        assert len(projected.nodes) == 0

    def test_project_entity_graph_basic(self, sample_graph):
        """Shared attribute and direct link both become entity-entity edges."""
        projected = project_entity_graph(sample_graph, set(), {}, [])
        assert list(projected.edges()) == [
            ("ENTITY==entity_1", "ENTITY==entity_2"),
            ("ENTITY==entity_2", "ENTITY==entity_5"),
        ]

    def test_project_entity_graph_trimmed(self, sample_graph):
        """Putting entity_5 in the trimmed set leaves the projected edges unchanged."""
        projected = project_entity_graph(
            sample_graph, {"ENTITY==entity_5"}, {}, []
        )
        assert list(projected.edges()) == [
            ("ENTITY==entity_1", "ENTITY==entity_2"),
            ("ENTITY==entity_2", "ENTITY==entity_5"),
        ]

    def test_project_entity_graph_inferred(self, sample_graph):
        """An inferred link entity_1 -> entity_3 is projected as the first edge."""
        links = defaultdict(set)
        links["ENTITY==entity_1"].add("ENTITY==entity_3")

        projected = project_entity_graph(sample_graph, set(), links, [])
        assert list(projected.edges()) == [
            ("ENTITY==entity_1", "ENTITY==entity_3"),
            ("ENTITY==entity_1", "ENTITY==entity_2"),
            ("ENTITY==entity_2", "ENTITY==entity_5"),
        ]
class TestSubgraph:
    """Tests for ``get_subgraph`` community extraction."""

    @staticmethod
    def _multi_component_graph():
        """Return the graph shared by the two ``max_network_entities`` tests."""
        graph = nx.Graph()
        graph.add_edges_from(
            [
                (1, 2),
                (3, 4),
                (7, 9),
                (8, 90),
                (54, 66),
                (66, 44),
                (66, 1),
                (66, 2),
            ]
        )
        return graph

    def test_basic_functionality(self):
        """Restricting a path graph to three nodes yields one community."""
        graph = nx.Graph()
        graph.add_edges_from([(1, 2), (2, 3), (3, 4)])
        nodes = [1, 2, 3]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert len(community_nodes) == 1
        assert set(community_nodes[0]) == set(nodes)
        assert entity_to_community == {1: 0, 2: 0, 3: 0}

    def test_disconnected_components(self):
        """Two disconnected edges produce exactly the two expected communities."""
        graph = nx.Graph()
        graph.add_edges_from([(1, 2), (3, 4)])
        nodes = [1, 2, 3, 4]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert len(community_nodes) == 2
        # Compare as a set of frozensets so no community order is assumed,
        # while still requiring BOTH components to be present. Checking each
        # index with an `or` would also accept the same component twice.
        assert {frozenset(community) for community in community_nodes} == {
            frozenset({1, 2}),
            frozenset({3, 4}),
        }
        assert len(entity_to_community) == 4

    def test_max_network_entities(self):
        """No community exceeds a small ``max_network_entities`` limit."""
        graph = self._multi_component_graph()
        nodes = [1, 2, 3, 4, 66, 54, 89]  # 89 is not a node of the graph
        max_network_entities = 2
        community_nodes, _ = get_subgraph(
            graph, nodes, max_network_entities=max_network_entities
        )

        # Guard against the size loop passing vacuously on an empty result.
        assert community_nodes
        for community in community_nodes:
            assert len(community) <= max_network_entities

    def test_max_network_entities_size_high(self):
        """No community exceeds a limit larger than any component."""
        graph = self._multi_component_graph()
        nodes = [1, 2, 3, 4, 66, 54, 89]  # 89 is not a node of the graph
        max_network_entities = 10
        community_nodes, _ = get_subgraph(
            graph, nodes, max_network_entities=max_network_entities
        )

        # Guard against the size loop passing vacuously on an empty result.
        assert community_nodes
        for community in community_nodes:
            assert len(community) <= max_network_entities

    def test_graph_with_weights(self):
        """Weighted edges are accepted and every node is assigned a community."""
        graph = nx.Graph()
        graph.add_edge(1, 2, weight=1.0)
        graph.add_edge(2, 3, weight=2.0)
        graph.add_edge(3, 4, weight=3.0)
        nodes = [1, 2, 3, 4]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert len(community_nodes) > 0
        assert len(entity_to_community) == 4

    def test_empty_node_list(self):
        """An empty node selection yields no communities and an empty mapping."""
        graph = nx.Graph()
        graph.add_edges_from([(1, 2), (2, 3)])
        nodes = []
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert community_nodes == []
        assert entity_to_community == {}

    def test_empty_graph(self):
        """Selecting nodes from an empty graph yields empty results."""
        graph = nx.Graph()
        nodes = [1, 2, 3]
        community_nodes, entity_to_community = get_subgraph(graph, nodes)
        assert community_nodes == []
        assert entity_to_community == {}
class TestNodes:
    """Tests for ``get_community_nodes``."""

    @staticmethod
    def _path_graph():
        """Return a graph with a six-node component and a separate V-X pair."""
        graph = nx.Graph()
        for path in (
            ["node_A", "node_B", "node_C", "node_E"],
            ["node_V", "node_X"],
            ["node_D", "node_P"],
            ["node_D", "node_C"],
        ):
            nx.add_path(graph, path)
        return graph

    def test_nodes(self):
        """With a generous size limit each connected component is one community."""
        result = get_community_nodes(self._path_graph(), 10)

        expected_communities = [
            {"node_A", "node_B", "node_C", "node_D", "node_E", "node_P"},
            {"node_V", "node_X"},
        ]
        expected_entity_to_community = {
            "node_A": 0,
            "node_B": 0,
            "node_C": 0,
            "node_D": 0,
            "node_E": 0,
            "node_P": 0,
            "node_V": 1,
            "node_X": 1,
        }
        assert result == (expected_communities, expected_entity_to_community)

    def test_max_size(self):
        """A size limit of 2 splits the large component into pairs."""
        result = get_community_nodes(self._path_graph(), 2)

        expected_communities = [
            {"node_B", "node_A"},
            {"node_E", "node_C"},
            {"node_D", "node_P"},
            {"node_V", "node_X"},
        ]
        expected_entity_to_community = {
            "node_A": 0,
            "node_B": 0,
            "node_C": 1,
            "node_D": 2,
            "node_E": 1,
            "node_P": 2,
            "node_V": 3,
            "node_X": 3,
        }

        # Community contents are compared without assuming their list order.
        found = [set(members) for members in result[0]]
        assert len(found) == len(expected_communities)
        for wanted in expected_communities:
            assert wanted in found

        assert result[1] == expected_entity_to_community
class TestIntegratedFlags:
    """Tests for ``get_integrated_flags``.

    The result is read in the order (community_flags, flagged,
    flagged_per_unflagged, flags_per_entity, total_entities).
    """

    @pytest.fixture()
    def qualified_entities(self) -> list[str]:
        """Three qualified entity identifiers."""
        return ["ENTITY==1", "ENTITY==2", "ENTITY==3"]

    @pytest.fixture()
    def integrated_flags(self, qualified_entities) -> pl.DataFrame:
        """Flag counts 1, 0 and 3 for the three qualified entities."""
        columns = {
            "qualified_entity": qualified_entities,
            "count": [1, 0, 3],
        }
        return pl.DataFrame(columns)

    def test_empty_integrated_flags(self) -> None:
        """An empty frame with no entities yields all-zero statistics."""
        assert get_integrated_flags(pl.DataFrame(), []) == (0, 0, 0, 0, 0)

    def test_no_entities(self, integrated_flags):
        """A populated frame with no entities also yields all zeros."""
        assert get_integrated_flags(integrated_flags, []) == (0, 0, 0, 0, 0)

    def test_base_entities(self, integrated_flags, qualified_entities):
        """Three entities, two flagged, four flags in total."""
        stats = tuple(get_integrated_flags(integrated_flags, qualified_entities))
        assert stats == (4, 2, 2, 1.33, 3)

    def test_inferred_links(self, integrated_flags, qualified_entities) -> None:
        """Adding an unflagged ENTITY==5 with an inferred link to ENTITY==1."""
        qualified_entities.append("ENTITY==5")
        links = defaultdict(set)
        links["ENTITY==1"].add("ENTITY==5")

        extra_row = pl.DataFrame({"qualified_entity": ["ENTITY==5"], "count": [0]})
        integrated_flags = integrated_flags.vstack(extra_row)

        stats = tuple(
            get_integrated_flags(integrated_flags, qualified_entities, links)
        )
        assert stats == (4, 2.0, 1.0, 1.0, 4)

    def test_inferred_links_flagged(self, integrated_flags, qualified_entities) -> None:
        """Adding a flagged ENTITY==5 with an inferred link to ENTITY==1."""
        qualified_entities.append("ENTITY==5")
        links = defaultdict(set)
        links["ENTITY==1"].add("ENTITY==5")

        extra_row = pl.DataFrame({"qualified_entity": ["ENTITY==5"], "count": [1]})
        integrated_flags = integrated_flags.vstack(extra_row)

        stats = tuple(
            get_integrated_flags(integrated_flags, qualified_entities, links)
        )
        assert stats == (5, 3.0, 3.0, 1.25, 4)
class TestBuildEntityRecords:
    """Tests for ``build_entity_records``.

    Each expected record is (entity id, flag count, community index) followed
    by five per-community values.
    NOTE(review): those five look like (entity total, community flags, flagged
    entities, flags per entity, flagged per unflagged) -- confirm against
    ``build_entity_records``.
    """

    @staticmethod
    def _records(community_id, entity_counts, summary):
        """Expand (entity, count) pairs into full expected record tuples."""
        return [
            (entity, count, community_id, *summary)
            for entity, count in entity_counts
        ]

    @pytest.fixture()
    def community_nodes(self):
        """Two communities of three entities each."""
        first = ["ENTITY==1", "ENTITY==2", "ENTITY==3"]
        second = ["ENTITY==4", "ENTITY==5", "ENTITY==6"]
        return [first, second]

    @pytest.fixture()
    def integrated_flags(self):
        """Flag counts for the first community only."""
        columns = {
            "qualified_entity": ["ENTITY==1", "ENTITY==2", "ENTITY==3"],
            "count": [1, 0, 3],
        }
        return pl.DataFrame(columns)

    @pytest.fixture()
    def community_nodes_multiple(self) -> list[list[str]]:
        """Three communities of three entities each."""
        first = ["ENTITY==1", "ENTITY==2", "ENTITY==3"]
        second = ["ENTITY==4", "ENTITY==5", "ENTITY==6"]
        third = ["ENTITY==8", "ENTITY==9", "ENTITY==11"]
        return [first, second, third]

    @pytest.fixture()
    def integrated_flags_multiple(self) -> pl.DataFrame:
        """Flag counts covering the first and third communities."""
        columns = {
            "qualified_entity": [
                "ENTITY==1",
                "ENTITY==2",
                "ENTITY==3",
                "ENTITY==11",
                "ENTITY==9",
            ],
            "count": [1, 0, 3, 2, 5],
        }
        return pl.DataFrame(columns)

    def test_final_integrated(self, community_nodes, integrated_flags):
        """Without inferred links only the first community carries flags."""
        result = build_entity_records(community_nodes, integrated_flags)

        expected = self._records(
            0, [("1", 1), ("2", 0), ("3", 3)], (3, 4, 2, 1.33, 2.0)
        ) + self._records(1, [("4", 0), ("5", 0), ("6", 0)], (3, 0, 0, 0.0, 0.0))
        assert result == expected

    def test_final_count_inferred(self, community_nodes, integrated_flags) -> None:
        """An inferred link 1 -> 5 across communities changes both summaries."""
        links = defaultdict(set)
        links["ENTITY==1"].add("ENTITY==5")
        result = build_entity_records(community_nodes, integrated_flags, links)

        expected = self._records(
            0, [("1", 1), ("2", 0), ("3", 3)], (4, 4, 2, 1.0, 1.0)
        ) + self._records(1, [("4", 0), ("5", 0), ("6", 0)], (4, 1, 1, 0.25, 0.33))
        assert result == expected

    def test_final_count_inferred_existant(
        self, community_nodes, integrated_flags
    ) -> None:
        """An inferred link inside one community leaves the records unchanged."""
        links = defaultdict(set)
        links["ENTITY==1"].add("ENTITY==2")
        result = build_entity_records(community_nodes, integrated_flags, links)

        expected = self._records(
            0, [("1", 1), ("2", 0), ("3", 3)], (3, 4, 2, 1.33, 2.0)
        ) + self._records(1, [("4", 0), ("5", 0), ("6", 0)], (3, 0, 0, 0.0, 0.0))
        assert result == expected

    def test_final_count_inferred_with_flags(
        self, community_nodes, integrated_flags
    ) -> None:
        """An inferred link from the flagged entity 3 to the unflagged entity 6."""
        links = defaultdict(set)
        links["ENTITY==3"].add("ENTITY==6")
        result = build_entity_records(community_nodes, integrated_flags, links)

        expected = self._records(
            0, [("1", 1), ("2", 0), ("3", 3)], (4, 4, 2, 1.0, 1.0)
        ) + self._records(1, [("4", 0), ("5", 0), ("6", 0)], (4, 3, 1, 0.75, 0.33))
        assert result == expected

    def test_final_count_inferred_both_with_flags(
        self, community_nodes, integrated_flags
    ) -> None:
        """Inferred link 3 -> 6 where entity 6 has a flag of its own."""
        links = defaultdict(set)
        links["ENTITY==3"].add("ENTITY==6")
        extra_row = pl.DataFrame({"qualified_entity": ["ENTITY==6"], "count": [1]})
        integrated_flags = integrated_flags.vstack(extra_row)
        result = build_entity_records(community_nodes, integrated_flags, links)

        expected = self._records(
            0, [("1", 1), ("2", 0), ("3", 3)], (4, 5, 3, 1.25, 3.0)
        ) + self._records(1, [("4", 0), ("5", 0), ("6", 1)], (4, 4, 2, 1.0, 1.0))
        assert result == expected

    def test_final_count_inferred_multiple(
        self, community_nodes_multiple, integrated_flags_multiple
    ) -> None:
        """Entity 3 has inferred links into two other communities."""
        links = defaultdict(set)
        links["ENTITY==3"].add("ENTITY==6")
        links["ENTITY==3"].add("ENTITY==11")
        result = build_entity_records(
            community_nodes_multiple, integrated_flags_multiple, links
        )

        expected = (
            self._records(0, [("1", 1), ("2", 0), ("3", 3)], (5, 6, 3, 1.2, 1.5))
            + self._records(
                1, [("4", 0), ("5", 0), ("6", 0)], (4, 3, 1, 0.75, 0.33)
            )
            + self._records(2, [("8", 0), ("9", 5), ("11", 2)], (4, 10, 3, 2.5, 3.0))
        )
        assert result == expected

    def test_final_not_integrated(self, community_nodes):
        """With ``None`` as flags every count and summary value is zero."""
        result = build_entity_records(community_nodes, None)

        expected = self._records(
            0, [("1", 0), ("2", 0), ("3", 0)], (3, 0, 0, 0, 0)
        ) + self._records(1, [("4", 0), ("5", 0), ("6", 0)], (3, 0, 0, 0, 0))
        assert result == expected