Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_explore_networks.py: 100%
288 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5import networkx as nx
6import polars as pl
7import pytest
8from networkx import Graph
10from intelligence_toolkit.detect_entity_networks.explore_networks import (
11 _build_fuzzy_neighbors,
12 _integrate_flags,
13 _merge_condition,
14 _merge_node_list,
15 _merge_nodes,
16 get_entity_graph,
17 get_type_color,
18 hsl_to_hex,
19 simplify_entities_graph,
20)
@pytest.fixture()
def graph() -> Graph:
    """Provide a six-node, six-edge graph whose nodes carry ``type`` and ``flags``."""
    typed_nodes = [
        ("A", {"type": "TypeA", "flags": 1}),
        ("B", {"type": "TypeB", "flags": 0}),
        ("C", {"type": "TypeC", "flags": 1}),
        ("D", {"type": "TypeC", "flags": 1}),
        ("E", {"type": "TypeC", "flags": 0}),
        ("F", {"type": "TypeC", "flags": 1}),
    ]
    links = [
        ("A", "B"),
        ("B", "F"),
        ("B", "C"),
        ("E", "F"),
        ("D", "F"),
        ("A", "F"),
    ]

    # Nodes first, then edges, in the listed order.
    fixture_graph = nx.Graph()
    fixture_graph.add_nodes_from(typed_nodes)
    fixture_graph.add_edges_from(links)
    return fixture_graph
@pytest.fixture()
def simple_graph() -> Graph:
    """Provide a twelve-node, eight-edge graph of ``<type>==<value>`` node ids.

    The nodes are five ``ENTITY`` ids, three ``Attr`` ids and four
    ``AttributeABCD`` ids; none of them carries attributes.
    """
    node_ids = [
        "ENTITY==1",
        "ENTITY==2",
        "ENTITY==3",
        "ENTITY==4",
        "ENTITY==5",
        "Attr==Type1",
        "Attr==Type2",
        "Attr==Type35",
        "AttributeABCD==Type35",
        "AttributeABCD==Type37",
        "AttributeABCD==Type38",
        "AttributeABCD==Type47",
    ]
    links = [
        ("ENTITY==1", "ENTITY==2"),
        ("Attr==Type1", "ENTITY==1"),
        ("AttributeABCD==Type37", "Attr==Type1"),
        ("AttributeABCD==Type47", "Attr==Type2"),
        ("ENTITY==3", "Attr==Type1"),
        ("ENTITY==3", "AttributeABCD==Type35"),
        ("AttributeABCD==Type35", "AttributeABCD==Type38"),
        ("ENTITY==3", "ENTITY==5"),
    ]

    # Nodes first, then edges, in the listed order.
    fixture_graph = nx.Graph()
    fixture_graph.add_nodes_from(node_ids)
    fixture_graph.add_edges_from(links)
    return fixture_graph
class TestFuzzyNeighbors:
    """Tests for ``_build_fuzzy_neighbors``."""

    @pytest.fixture()
    def existing_network_graph(self):
        """Return a network graph that already holds three nodes and two edges."""
        G = nx.Graph()
        G.add_node("ENTITY==47")
        G.add_node("Attr==Type108")
        G.add_node("Attr==Type222")

        G.add_edge("Attr==Type108", "ENTITY==47")
        G.add_edge("Attr==Type108", "Attr==Type222")

        return G

    def test_empty_graph(self):
        """With an empty overall graph the result has no nodes."""
        # inferred_links is a mapping everywhere else in this class, so pass
        # an empty dict instance rather than the ``dict`` class itself.
        result = _build_fuzzy_neighbors(
            nx.Graph(), nx.Graph(), "ENTITY==1", set(), {}
        )
        assert result.nodes == nx.Graph().nodes

    def test_node_inexistent(self, simple_graph):
        """A node id that is absent from the overall graph raises ``ValueError``."""
        with pytest.raises(ValueError, match="Node ENTITY==78 not in graph"):
            _build_fuzzy_neighbors(simple_graph, nx.Graph(), "ENTITY==78", set(), {})

    def test_fuzzy_neighbor_basic(self, simple_graph):
        """``Attr==Type1`` gets linked to its ``AttributeABCD==Type37`` neighbor."""
        network_graph = nx.Graph()
        att_neighbor = "Attr==Type1"
        trimmed_nodeset = set()
        inferred_links = {}
        result = _build_fuzzy_neighbors(
            simple_graph, network_graph, att_neighbor, trimmed_nodeset, inferred_links
        )

        assert list(result.nodes()) == ["AttributeABCD==Type37", "Attr==Type1"]
        assert list(result.edges()) == [("AttributeABCD==Type37", "Attr==Type1")]

    def test_fuzzy_neighbor_inferred(self, simple_graph):
        """An inferred link adds a second edge next to the direct neighbor's edge."""
        network_graph = nx.Graph()
        att_neighbor = "Attr==Type1"
        trimmed_nodeset = set()
        inferred_links = {"Attr==Type1": {"AttributeABCD==Type47"}}
        result = _build_fuzzy_neighbors(
            simple_graph, network_graph, att_neighbor, trimmed_nodeset, inferred_links
        )

        assert len(result.edges()) == 2
        assert ("AttributeABCD==Type47", "Attr==Type1") in result.edges()
        assert ("AttributeABCD==Type37", "Attr==Type1") in result.edges()

    def test_fuzzy_neighbor_trimmed(self, simple_graph):
        """With ``AttributeABCD==Type38`` trimmed, the result graph stays empty."""
        network_graph = nx.Graph()
        att_neighbor = "AttributeABCD==Type35"
        trimmed_nodeset = {"AttributeABCD==Type38"}
        inferred_links = {}
        result = _build_fuzzy_neighbors(
            simple_graph, network_graph, att_neighbor, trimmed_nodeset, inferred_links
        )

        assert len(result.nodes()) == 0
        assert len(result.edges()) == 0

    def test_fuzzy_neighbor_graph_existent(self, simple_graph, existing_network_graph):
        """Content already in the network graph is kept alongside the new edge."""
        att_neighbor = "Attr==Type1"
        trimmed_nodeset = set()
        inferred_links = {}
        result = _build_fuzzy_neighbors(
            simple_graph,
            existing_network_graph,
            att_neighbor,
            trimmed_nodeset,
            inferred_links,
        )

        assert len(result.edges()) == 3
        assert len(result.nodes()) == 5
        assert ("AttributeABCD==Type37", "Attr==Type1") in result.edges()
        assert ("Attr==Type108", "Attr==Type222") in result.edges()
class TestIntegrateFlags:
    """Tests for ``_integrate_flags``."""

    @pytest.fixture()
    def graph_flags(self):
        """Return a graph with six attribute-free, unconnected nodes ``A``-``F``."""
        G = nx.Graph()
        G.add_node("A")

        G.add_node("B")
        G.add_node("C")
        G.add_node("D")
        G.add_node("E")
        G.add_node("F")
        return G

    def test_empty_graph(self):
        """An empty graph with an empty flags frame yields no nodes."""
        result = _integrate_flags(nx.Graph(), pl.DataFrame())

        assert len(result.nodes()) == 0

    def test_empty_flags(self, graph_flags):
        """An empty flags frame yields a result without nodes."""
        result = _integrate_flags(graph_flags, pl.DataFrame())

        assert len(result.nodes()) == 0

    def test_integration(self, graph_flags):
        """Non-zero counts become ``flags``; a zero count sets no attribute."""
        flags = pl.DataFrame(
            {
                "qualified_entity": ["A", "C", "D", "F"],
                "count": [1, 2, 3, 0],
            }
        )

        result = _integrate_flags(graph_flags, flags)

        assert result.nodes["A"]["flags"] == 1
        assert result.nodes["C"]["flags"] == 2
        assert result.nodes["D"]["flags"] == 3
        assert "flags" not in result.nodes["B"]
        assert "flags" not in result.nodes["E"]
        assert "flags" not in result.nodes["F"]

    def test_sum(self, graph_flags):
        """Two rows for ``A`` (counts 1 and 5) are summed to ``flags == 6``."""
        flags = pl.DataFrame(
            {
                "qualified_entity": ["A", "C", "D", "F", "A"],
                "count": [1, 2, 3, 0, 5],
            }
        )

        result = _integrate_flags(graph_flags, flags)

        assert result.nodes["A"]["flags"] == 6
        assert result.nodes["C"]["flags"] == 2
        assert result.nodes["D"]["flags"] == 3
        assert "flags" not in result.nodes["B"]
        assert "flags" not in result.nodes["E"]
        assert "flags" not in result.nodes["F"]

    def test_node_not_in_graph(self, graph_flags):
        """A flagged entity absent from the graph (``Z``) is not added to it."""
        flags = pl.DataFrame(
            {
                "qualified_entity": ["A", "C", "D", "F", "Z"],
                "count": [1, 2, 3, 0, 3],
            }
        )

        result = _integrate_flags(graph_flags, flags)

        assert result.nodes["A"]["flags"] == 1
        assert result.nodes["C"]["flags"] == 2
        assert result.nodes["D"]["flags"] == 3
        assert "flags" not in result.nodes["B"]
        assert "flags" not in result.nodes["E"]
        assert "flags" not in result.nodes["F"]
        assert "Z" not in result.nodes()
class TestMergeCondition:
    """Tests for ``_merge_condition`` on ``;``-joined ``<type>==<value>`` strings."""

    def test_merge(self) -> None:
        """Value ``1`` appears on both sides, so the pair merges."""
        left, right = "A==1;B==2;C==3", "D==4;E==1;F==5"
        outcome = _merge_condition(left, right)
        assert outcome is True

    def test_merge_common_attr(self) -> None:
        """Value ``2`` appears on both sides, so the pair merges."""
        left, right = "A==1;B==2;C==3", "D==4;E==5;F==2"
        outcome = _merge_condition(left, right)
        assert outcome is True

    def test_merge_not_common(self) -> None:
        """No value is shared between the two sides, so the pair does not merge."""
        left, right = "A==1;B==2;C==3", "D==4;E==5;F==6"
        outcome = _merge_condition(left, right)
        assert outcome is False

    def test_merge_multiple_common(self) -> None:
        """Values ``1`` and ``2`` both appear on both sides; the pair merges."""
        left, right = "A==1;B==2;C==3", "D==4;E==1;F==2"
        outcome = _merge_condition(left, right)
        assert outcome is True
class TestMergeNodeList:
    """Tests for ``_merge_node_list``."""

    def test_merge_node_list(self, graph):
        """Merging ``A`` and ``C`` replaces them with one combined ``A;C`` node."""
        merge_list = ["A", "C"]
        merged_graph = _merge_node_list(graph, merge_list)

        # The combined node exists, carries the joined type and keeps the link to B.
        assert merged_graph.has_node("A;C")
        assert merged_graph.has_node("B")
        assert merged_graph.nodes["A;C"]["type"] == "TypeA;TypeC"
        assert merged_graph.nodes["A;C"]["flags"] == 1
        assert merged_graph.has_edge("B", "A;C")

        # The original nodes and their edges are gone.
        assert not merged_graph.has_node("A")
        assert not merged_graph.has_node("C")
        assert not merged_graph.has_edge("A", "B")
        assert not merged_graph.has_edge("B", "C")
class TestMergeNodes:
    """Tests for ``_merge_nodes`` driven by constant merge predicates."""

    def test_merge_all_nodes(self, graph):
        """A predicate that always accepts shrinks the node count."""

        def always_merge(_x, _y):
            return True

        merged = _merge_nodes(graph, always_merge)
        merged_count = len(merged.nodes())
        assert merged_count < len(graph.nodes())

    def test_merge_no_nodes(self, graph):
        """A predicate that always rejects leaves the node count unchanged."""

        def never_merge(_x, _y):
            return False

        merged = _merge_nodes(graph, never_merge)
        merged_count = len(merged.nodes())
        assert merged_count == len(graph.nodes())

    def test_empty_graph(self):
        """An empty graph stays empty."""

        def always_merge(_x, _y):
            return True

        merged = _merge_nodes(nx.Graph(), always_merge)
        assert len(merged.nodes()) == 0
class TestSimplifyEntitiesGraph:
    """Tests for ``simplify_entities_graph`` with ``_merge_nodes`` patched out."""

    def test_simplify_condition_false(self, mocker, graph):
        """With ``_merge_nodes`` returning the fixture graph, three nodes remain."""
        mocker.patch(
            "intelligence_toolkit.detect_entity_networks.explore_networks._merge_nodes",
            return_value=graph,
        )

        simplified = simplify_entities_graph(graph)
        assert len(simplified.nodes()) == 3

    def test_simplify_condition_true(self, mocker, graph):
        """With ``_merge_nodes`` returning an empty graph, the result is empty."""
        mocker.patch(
            "intelligence_toolkit.detect_entity_networks.explore_networks._merge_nodes",
            return_value=nx.Graph(),
        )

        simplified = simplify_entities_graph(graph)
        assert len(simplified.nodes()) == 0
class TestHslToHex:
    """Tests for ``hsl_to_hex`` on black, white and the three primary hues."""

    def test_colors(self):
        """Each (hue, saturation, lightness) triple maps to its hex string."""
        cases = (
            ((0, 0, 0), "#000000"),
            ((0, 0, 100), "#ffffff"),
            ((0, 100, 50), "#ff0000"),
            ((120, 100, 50), "#00ff00"),
            ((240, 100, 50), "#0000ff"),
        )
        for (hue, saturation, lightness), expected_hex in cases:
            assert hsl_to_hex(hue, saturation, lightness) == expected_hex
class TestGetTypeColor:
    """Tests for ``get_type_color`` across node types and the flagged state."""

    @pytest.fixture()
    def attribute_types(self) -> list[str]:
        """Return the three attribute type names shared by these tests."""
        return ["Type1", "Type2", "Type3"]

    def test_not_flagged(self, attribute_types) -> None:
        """Unflagged ``Type1`` with three types gets ``#a8b4ef``."""
        color = get_type_color("Type1", False, attribute_types)

        assert color == "#a8b4ef"

    def test_flagged(self, attribute_types) -> None:
        """Flagged ``Type1`` with three types gets ``#efa8a8``."""
        color = get_type_color("Type1", True, attribute_types)

        assert color == "#efa8a8"

    def test_type_2_3_types(self, attribute_types) -> None:
        """Unflagged ``Type2`` with three types gets ``#efd3a8``."""
        color = get_type_color("Type2", False, attribute_types)

        assert color == "#efd3a8"

    def test_type_2_2_types(self, attribute_types) -> None:
        """Unflagged ``Type2`` with only the first two types gets ``#d1efa8``."""
        first_two_types = attribute_types[:-1]
        color = get_type_color("Type2", False, first_two_types)

        assert color == "#d1efa8"
class TestGetEntityGraph:
    """Tests for the node and edge payloads returned by ``get_entity_graph``."""

    @staticmethod
    def _node_payload(node_id: str, size: int, color: str) -> dict:
        """Build the expected node dict for a ``<type>==<value>`` node id."""
        attr_type, value = node_id.split("==")
        return {
            "title": f"{node_id}\nFlags: 0",
            "id": node_id,
            "label": f"{value}\n({attr_type})",
            "size": size,
            "color": color,
            # Every expected payload in these tests has vadjust == -(size + 10).
            "font": {"vadjust": -(size + 10), "size": 5},
        }

    @staticmethod
    def _edge_payload(source: str, target: str) -> dict:
        """Build the expected edge dict between ``source`` and ``target``."""
        return {
            "source": source,
            "target": target,
            "color": "mediumgray",
            "size": 1,
        }

    @pytest.fixture()
    def simple_graph(self) -> Graph:
        """Return a class-local graph of seven nodes and five edges.

        Within this class it takes the place of the module-level fixture of
        the same name.
        """
        fixture_graph = nx.Graph()
        fixture_graph.add_nodes_from(
            [
                "ENTITY==1",
                "ENTITY==2",
                "ENTITY==3",
                "Attr==Type1",
                "Attr==Type2",
                "AttributeABCD==Type35",
                "AttributeABCD==Type37",
            ]
        )
        fixture_graph.add_edges_from(
            [
                ("ENTITY==1", "ENTITY==2"),
                ("Attr==Type1", "ENTITY==1"),
                ("AttributeABCD==Type37", "Attr==Type1"),
                ("ENTITY==3", "Attr==Type1"),
                ("ENTITY==3", "AttributeABCD==Type35"),
            ]
        )
        return fixture_graph

    def test_empty(self) -> None:
        """An empty graph produces no nodes and no edges."""
        nodes, edges = get_entity_graph(nx.Graph(), "", [])
        assert len(nodes) == 0
        assert len(edges) == 0

    def test_none_selected_nodes(self, simple_graph) -> None:
        """Without a selection, entities have size 12 and attributes size 8."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        nodes, _ = get_entity_graph(simple_graph, "", attribute_types)

        expected_nodes = [
            self._node_payload("ENTITY==1", 12, "#a8b4ef"),
            self._node_payload("ENTITY==3", 12, "#a8b4ef"),
            self._node_payload("AttributeABCD==Type37", 8, "#efd3a8"),
            self._node_payload("ENTITY==2", 12, "#a8b4ef"),
            self._node_payload("Attr==Type1", 8, "#ebefa8"),
            self._node_payload("AttributeABCD==Type35", 8, "#efd3a8"),
        ]

        # Compare independently of the order the nodes are returned in.
        expected_nodes = sorted(expected_nodes, key=lambda item: item["id"])
        nodes = sorted(nodes, key=lambda item: item["id"])
        assert expected_nodes == nodes

    def test_none_selected_edges(self, simple_graph):
        """Without a selection, all five edges are mediumgray with size 1."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        _, edges = get_entity_graph(simple_graph, "", attribute_types)

        # Listed in (source, target) order to match the sorted result below.
        expected_edges = [
            self._edge_payload("Attr==Type1", "AttributeABCD==Type37"),
            self._edge_payload("ENTITY==1", "Attr==Type1"),
            self._edge_payload("ENTITY==1", "ENTITY==2"),
            self._edge_payload("ENTITY==3", "Attr==Type1"),
            self._edge_payload("ENTITY==3", "AttributeABCD==Type35"),
        ]

        edges = sorted(edges, key=lambda item: (item["source"], item["target"]))
        assert edges == expected_edges

    def test_selected_nodes(self, simple_graph):
        """The selected node ``ENTITY==3`` is drawn with size 20."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        nodes, _ = get_entity_graph(simple_graph, "ENTITY==3", attribute_types)

        expected_nodes = [
            self._node_payload("ENTITY==1", 12, "#a8b4ef"),
            self._node_payload("ENTITY==3", 20, "#a8b4ef"),
            self._node_payload("AttributeABCD==Type37", 8, "#efd3a8"),
            self._node_payload("ENTITY==2", 12, "#a8b4ef"),
            self._node_payload("Attr==Type1", 8, "#ebefa8"),
            self._node_payload("AttributeABCD==Type35", 8, "#efd3a8"),
        ]

        # Compare independently of the order the nodes are returned in.
        nodes = sorted(nodes, key=lambda item: item["id"])
        expected_nodes = sorted(expected_nodes, key=lambda item: item["id"])
        assert nodes == expected_nodes

    def test_selected_edges(self, simple_graph):
        """Selecting ``ENTITY==3`` leaves the edge payloads unchanged."""
        attribute_types = ["ENTITY", "AttributeABCD", "Attr"]
        _, edges = get_entity_graph(simple_graph, "ENTITY==3", attribute_types)

        # Listed in (source, target) order to match the sorted result below.
        expected_edges = [
            self._edge_payload("Attr==Type1", "AttributeABCD==Type37"),
            self._edge_payload("ENTITY==1", "Attr==Type1"),
            self._edge_payload("ENTITY==1", "ENTITY==2"),
            self._edge_payload("ENTITY==3", "Attr==Type1"),
            self._edge_payload("ENTITY==3", "AttributeABCD==Type35"),
        ]

        edges = sorted(edges, key=lambda item: (item["source"], item["target"]))
        assert edges == expected_edges