Coverage for intelligence_toolkit/detect_entity_networks/explore_networks.py: 76%
153 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
6import colorsys
7from typing import Any
9import networkx as nx
10import polars as pl
12from intelligence_toolkit.detect_entity_networks.config import (
13 ENTITY_LABEL,
14 LIST_SEPARATOR,
15)
16from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR
def _integrate_flags(graph: nx.Graph, df_integrated_flags: pl.DataFrame) -> nx.Graph:
    """
    Write per-entity flag totals onto the matching nodes of ``graph``.

    Rows with ``count > 0`` are summed per ``qualified_entity``. Every node of
    ``graph`` whose name equals one of those entities gets its ``flags``
    attribute set to that total. Nodes with no positive flag rows keep
    whatever ``flags`` value they already had.

    Args:
        graph: Graph to annotate. It is mutated in place and also returned.
        df_integrated_flags: Frame with at least ``qualified_entity`` and
            ``count`` columns.

    Returns:
        ``graph`` itself. If ``graph`` has no nodes or the frame is empty,
        ``graph`` is returned unchanged.
    """
    if graph.number_of_nodes() == 0 or df_integrated_flags.is_empty():
        return graph

    # Aggregate once instead of re-filtering the frame for every node.
    totals = (
        df_integrated_flags.filter(pl.col("count") > 0)
        .group_by("qualified_entity")
        .agg(pl.col("count").sum())
    )
    for entity, total in totals.iter_rows():
        if entity in graph:
            graph.nodes[entity]["flags"] = total
    return graph
def _build_fuzzy_neighbors(
    graph: nx.Graph,
    network_graph: nx.Graph,
    att_neighbor: str | int,
    trimmed_nodeset: set,
    inferred_links: dict,
) -> nx.Graph:
    """
    Add the non-entity neighbours of ``att_neighbor`` to ``network_graph``.

    Candidates are the neighbours of ``att_neighbor`` in ``graph`` plus
    ``inferred_links[att_neighbor]`` (when present). A candidate is skipped
    if it is in ``trimmed_nodeset`` or its name starts with ``ENTITY_LABEL``.
    Each remaining candidate is added as a node and linked to
    ``att_neighbor``. Its ``type`` is the text before the first
    ``ATTRIBUTE_VALUE_SEPARATOR`` and its ``flags`` is set to ``0``.

    Args:
        graph: Full graph the neighbours are read from (not modified).
        network_graph: Graph being built; it is mutated in place.
        att_neighbor: Node whose neighbours are added.
        trimmed_nodeset: Node names to leave out.
        inferred_links: Mapping from a node to extra nodes linked to it.

    Returns:
        ``network_graph``. It is returned unchanged when ``graph`` has no
        nodes.

    Raises:
        ValueError: if ``graph`` is non-empty and lacks ``att_neighbor``.
    """
    if graph.number_of_nodes() == 0:
        # Nothing to look up; keep what the caller has built so far.
        return network_graph
    if att_neighbor not in graph:
        msg = f"Node {att_neighbor} not in graph"
        raise ValueError(msg)

    candidates = set(graph.neighbors(att_neighbor))
    if att_neighbor in inferred_links:
        candidates.update(inferred_links[att_neighbor])

    for candidate in candidates:
        if candidate in trimmed_nodeset or candidate.startswith(ENTITY_LABEL):
            continue
        network_graph.add_node(
            candidate,
            type=candidate.split(ATTRIBUTE_VALUE_SEPARATOR)[0],
            flags=0,
        )
        network_graph.add_edge(att_neighbor, candidate)
    return network_graph
def _invert_inferred_links(inferred_links) -> dict:
    """Map each inferred-link target to the set of nodes that link to it."""
    inverted: dict = {}
    for source in inferred_links:
        for target in inferred_links[source]:
            inverted.setdefault(target, set()).add(source)
    return inverted


def _add_attribute_neighbors(
    graph,
    network_graph: nx.Graph,
    attribute,
    trimmed_nodeset: set,
    inferred_links,
) -> nx.Graph:
    """Add the non-entity neighbours of ``attribute`` and their fuzzy neighbours."""
    att_neighbors = set(graph.neighbors(attribute))
    if attribute in inferred_links:
        att_neighbors.update(inferred_links[attribute])

    for att_neighbor in att_neighbors:
        if att_neighbor in trimmed_nodeset or att_neighbor.startswith(ENTITY_LABEL):
            continue
        network_graph.add_node(
            att_neighbor,
            type=att_neighbor.split(ATTRIBUTE_VALUE_SEPARATOR)[0],
            flags=0,
        )
        # Connect the second-hop attribute to the attribute it came from;
        # otherwise a one-directional inferred link leaves it dangling.
        if att_neighbor != attribute:
            network_graph.add_edge(attribute, att_neighbor)
        network_graph = _build_fuzzy_neighbors(
            graph,
            network_graph,
            att_neighbor,
            trimmed_nodeset,
            inferred_links,
        )
    return network_graph


def build_network_from_entities(
    graph,
    entity_to_community,
    integrated_flags: pl.DataFrame | None = None,
    trimmed_attributes: list[tuple[str, int]] | None = None,
    inferred_links: Any | None = None,
    selected_nodes: list[str] | None = None,
) -> nx.Graph:
    """
    Build the network shown when exploring a set of selected entities.

    For every selected node the result contains:
      * the node itself, with ``type=ENTITY_LABEL``, ``flags=0`` and
        ``network`` set to its community id as a string (``""`` if none);
      * its neighbours in ``graph``, plus inferred links from it or to it,
        except names in ``trimmed_attributes``. Entity neighbours are linked
        directly. Each attribute neighbour is added together with its own
        non-entity neighbours and their fuzzy neighbours.

    Args:
        graph: Full entity/attribute graph (read only). Node names are
            split on ``ATTRIBUTE_VALUE_SEPARATOR`` to derive a ``type``.
        entity_to_community: Mapping from entity node to community id.
        integrated_flags: Optional frame with ``qualified_entity`` and
            ``count`` columns. When non-empty, flag totals are written to
            the ``flags`` node attribute.
        trimmed_attributes: Optional ``(name, count)`` pairs. Nodes whose
            name is the first element are excluded. ``None`` trims nothing.
        inferred_links: Optional mapping from a node to the nodes it is
            linked to by inference.
        selected_nodes: Nodes to build the network around. ``None`` yields
            an empty graph.

    Returns:
        A new ``nx.Graph``; ``graph`` itself is not modified.
    """
    network_graph = nx.Graph()
    trimmed_nodeset = {t[0] for t in trimmed_attributes or []}
    if inferred_links is None:
        inferred_links = {}
    # Built once so links pointing *to* a node can be followed in reverse.
    links_into = _invert_inferred_links(inferred_links)

    for node in selected_nodes or []:
        network_graph.add_node(
            node,
            type=ENTITY_LABEL,
            network=str(entity_to_community.get(node, "")),
            flags=0,
        )
        ent_neighbors = set(graph.neighbors(node))
        if node in inferred_links:
            ent_neighbors.update(inferred_links[node])
        ent_neighbors.update(links_into.get(node, ()))

        for ent_neighbor in ent_neighbors:
            if ent_neighbor in trimmed_nodeset:
                continue
            if ent_neighbor.startswith(ENTITY_LABEL):
                if ent_neighbor != node:
                    # Same string form as selected nodes, so the attribute
                    # does not depend on which node was visited first.
                    network_graph.add_node(
                        ent_neighbor,
                        type=ENTITY_LABEL,
                        network=str(entity_to_community.get(ent_neighbor, "")),
                    )
                    network_graph.add_edge(node, ent_neighbor)
                continue

            network_graph.add_node(
                ent_neighbor,
                type=ent_neighbor.split(ATTRIBUTE_VALUE_SEPARATOR)[0],
                flags=0,
            )
            network_graph.add_edge(node, ent_neighbor)
            network_graph = _add_attribute_neighbors(
                graph, network_graph, ent_neighbor, trimmed_nodeset, inferred_links
            )

    if integrated_flags is not None and len(integrated_flags) > 0:
        network_graph = _integrate_flags(network_graph, integrated_flags)
    return network_graph
def _merge_condition(x: str, y: str) -> bool:
    """
    Merge condition function for merging nodes in the graph.

    Each name is one or more ``LIST_SEPARATOR``-joined parts. A part splits
    on ``ATTRIBUTE_VALUE_SEPARATOR`` into a type (piece 0) and a value
    (piece 1). The nodes merge when any part of ``x`` shares its type with
    any part of ``y``, or when any part of ``x`` shares its value with any
    part of ``y``. A part without a separator has a type but no value, so it
    can only match on type.

    Returns:
        True if ``x`` and ``y`` should be merged.
    """

    def types_and_values(name: str) -> tuple[set, set]:
        types, values = set(), set()
        for part in set(name.split(LIST_SEPARATOR)):
            pieces = part.split(ATTRIBUTE_VALUE_SEPARATOR)
            types.add(pieces[0])
            if len(pieces) > 1:  # previously an IndexError
                values.add(pieces[1])
        return types, values

    x_types, x_values = types_and_values(x)
    y_types, y_values = types_and_values(y)
    return not x_types.isdisjoint(y_types) or not x_values.isdisjoint(y_values)
def _merge_node_list(graph: nx.Graph, merge_list: list[str]) -> nx.Graph:
    """
    Return a copy of ``graph`` with the nodes in ``merge_list`` collapsed.

    The merged node is named by joining the sorted node names with
    ``LIST_SEPARATOR``. Its ``type`` is the sorted member types joined the
    same way. Its ``flags`` is the maximum member ``flags``, with a missing
    ``flags`` counting as 0. The merged node is linked to every neighbour of
    a member that is not itself a member. Edges between members and edge
    attributes are dropped. The input graph is left untouched.

    Args:
        graph: Source graph (copied, not modified).
        merge_list: Names of the nodes to merge. Duplicates are ignored.

    Returns:
        The new graph.
    """
    # A duplicate (e.g. from a self-loop) would make the second remove_node fail.
    members = list(dict.fromkeys(merge_list))
    member_set = set(members)

    merged = graph.copy()
    merged_node = LIST_SEPARATOR.join(sorted(members))
    merged_type = LIST_SEPARATOR.join(sorted(merged.nodes[n]["type"] for n in members))
    merged_risk = max(merged.nodes[n].get("flags", 0) for n in members)
    merged.add_node(merged_node, type=merged_type, flags=merged_risk)

    for member in members:
        for neighbour in merged.neighbors(member):
            if neighbour not in member_set:
                merged.add_edge(merged_node, neighbour)
        merged.remove_node(member)
    return merged
def _merge_nodes(graph: nx.Graph, should_merge=_merge_condition) -> nx.Graph:
    """
    Merge each node with the neighbours that ``should_merge`` accepts.

    Nodes are visited once, in the graph's original node order. For each
    node that is still present, the node is collapsed via
    ``_merge_node_list`` together with every neighbour ``n`` (other than the
    node itself) for which ``should_merge(node, n)`` is true. The merged node
    has a new name and is not itself revisited, so the result can depend on
    node order.

    Args:
        graph: Graph to simplify.
        should_merge: Predicate ``(node, neighbour) -> bool``.

    Returns:
        The graph after all merges, or the same object if nothing merged.
    """
    for node in list(graph.nodes()):  # snapshot: merges replace the graph
        if node not in graph:
            continue  # already absorbed by an earlier merge
        # Skip self-loops: a node listed twice would break _merge_node_list.
        neighbours = [n for n in graph.neighbors(node) if n != node]
        merge_list = [node, *(n for n in neighbours if should_merge(node, n))]
        if len(merge_list) > 1:
            graph = _merge_node_list(graph, merge_list)
    return graph
def simplify_entities_graph(entities_graph: nx.Graph) -> nx.Graph:
    """
    Return a simplified copy of an entity network for display.

    The steps are:
      1. drop non-entity nodes with degree below 2;
      2. merge adjacent nodes with ``_merge_nodes``;
      3. drop non-entity nodes with degree below 2 again, because merging
         changes node degrees.

    The input graph is copied first and is not modified.
    """

    def drop_dangling_attributes(target):
        # Single pass over a snapshot; degrees are read live, so a removal
        # can affect nodes checked later in the same pass.
        for candidate in list(target.nodes()):
            if candidate.startswith(ENTITY_LABEL):
                continue
            if target.degree(candidate) < 2:
                target.remove_node(candidate)

    simplified = entities_graph.copy()
    drop_dangling_attributes(simplified)
    simplified = _merge_nodes(simplified)
    drop_dangling_attributes(simplified)
    return simplified
def hsl_to_hex(hue: int, saturation: int, lightness: int) -> str:
    """
    Convert an HSL colour to a ``#rrggbb`` hex string.

    Args:
        hue: Hue in degrees (divided by 360 before conversion).
        saturation: Saturation in percent (divided by 100).
        lightness: Lightness in percent (divided by 100).

    Returns:
        Lower-case ``#rrggbb`` string. Each channel is scaled to 0-255 by
        truncation (``int``), not rounding.
    """
    # colorsys takes the arguments in H, L, S order.
    red, green, blue = colorsys.hls_to_rgb(hue / 360, lightness / 100, saturation / 100)
    return "#" + "".join(f"{int(channel * 255):02x}" for channel in (red, green, blue))
def get_type_color(node_type: str, is_flagged: bool, attribute_types: list[Any]) -> str:
    """
    Pick a pastel hex colour for a node.

    Flagged nodes always get hue 0 (red). Every other node gets a hue chosen
    by the position of ``node_type`` in ``attribute_types``: hues are spread
    over 290 degrees starting at 230, then nudged out of the red band
    reserved for flags. Saturation 70 and lightness 80 are used in both cases.

    Raises:
        ValueError: if ``node_type`` is not in ``attribute_types``.
    """
    saturation, lightness = 70, 80
    if is_flagged:
        return hsl_to_hex(0, saturation, lightness)

    start, reserve = 230, 35
    position = attribute_types.index(node_type) / len(attribute_types)
    hue = (start + position * (360 - 2 * reserve)) % 360
    # Keep unflagged colours away from the reds used for flagged nodes.
    if hue < reserve:
        hue += 2 * reserve
    if hue > 360 - reserve:
        hue = (hue + 2 * reserve) % 360
    return hsl_to_hex(hue, saturation, lightness)
def _entity_graph_node(
    graph: nx.Graph, node: str, selected: str, attribute_types: list[str]
) -> dict:
    """Build the render dict for one node of an entity network."""
    size = 20 if node == selected else 12 if node.startswith(ENTITY_LABEL) else 8
    parts = [p.split(ATTRIBUTE_VALUE_SEPARATOR) for p in node.split(LIST_SEPARATOR)]
    # Merged nodes can repeat types/values; keep the first occurrence of each.
    atts = list(dict.fromkeys(p[0] for p in parts))
    vals = list(dict.fromkeys(p[1] for p in parts if len(p) > 1))
    flags = graph.nodes[node].get("flags", 0)
    return {
        "title": node + f"\nFlags: {flags}",
        "id": node,
        "label": "\n".join(vals) + "\n(" + LIST_SEPARATOR.join(atts) + ")",
        "size": size,
        "color": get_type_color(atts[0], flags > 0, attribute_types),
        "font": {"vadjust": -size - 10, "size": 5},
    }


def get_entity_graph(
    network_entities_graph: nx.Graph, selected: str, attribute_types: list[str]
) -> tuple[list, list]:
    """
    Implements the entity graph visualization after network selection.

    Only nodes that are part of at least one edge are emitted, in the
    graph's node order, so the output is stable from run to run. Each node
    dict has these keys:
      * ``title``: the node name plus its flag count;
      * ``id``: the node name;
      * ``label``: the distinct values, then the distinct types in
        parentheses;
      * ``size``: 20 for ``selected``, 12 for other entities, 8 otherwise;
      * ``color``: from ``get_type_color`` (red when flagged);
      * ``font``: a small font offset above the node.

    Each edge dict has ``source``, ``target``, a fixed colour and size 1.

    Args:
        network_entities_graph: Graph to render. Node names may be
            ``LIST_SEPARATOR``-joined merged names.
        selected: Name of the node to emphasise.
        attribute_types: Ordered types passed to ``get_type_color``, which
            looks each node's first type up by index.

    Returns:
        ``(nodes, edges)``. Both lists are empty when the graph has no edges.
    """
    if network_entities_graph.number_of_edges() == 0:
        return [], []

    nodes = [
        _entity_graph_node(network_entities_graph, node, selected, attribute_types)
        for node in network_entities_graph.nodes()
        if network_entities_graph.degree(node) > 0  # appears in some edge
    ]
    edges = [
        {"source": source, "target": target, "color": "mediumgray", "size": 1}
        for source, target in network_entities_graph.edges()
    ]
    return nodes, edges