Coverage for intelligence_toolkit/detect_entity_networks/identify_networks.py: 88%
152 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4from collections import defaultdict
5from typing import Any
7import networkx as nx
8import polars as pl
9from graspologic.partition import hierarchical_leiden
11from intelligence_toolkit.detect_entity_networks.config import (
12 DEFAULT_MAX_ATTRIBUTE_DEGREE,
13 ENTITY_LABEL,
14)
15from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR
18# ruff: noqa
def trim_nodeset(
    graph: nx.Graph,
    max_attribute_degree: int,
    additional_trimmed_attributes: list[str] | None = None,
) -> tuple[set, set[Any | str]]:
    """Select the attribute nodes that should be excluded from network building.

    A node is trimmed for degree when its name does not start with
    ``ENTITY_LABEL`` and its degree in ``graph`` exceeds
    ``max_attribute_degree``. Every name in ``additional_trimmed_attributes``
    is trimmed unconditionally.

    Args:
        graph: The full entity/attribute graph.
        max_attribute_degree: Highest degree an attribute node may have
            before it is trimmed.
        additional_trimmed_attributes: Extra node names to trim regardless
            of degree.

    Returns:
        ``(trimmed_degrees, trimmed_nodes)``: the set of ``(node, degree)``
        pairs trimmed for degree, and the set of every trimmed node name.
    """
    extra_nodes = (
        [] if additional_trimmed_attributes is None else additional_trimmed_attributes
    )

    trimmed_degrees = {
        (name, degree)
        for name, degree in graph.degree()
        if not name.startswith(ENTITY_LABEL) and degree > max_attribute_degree
    }
    trimmed_nodes = {name for name, _ in trimmed_degrees}.union(extra_nodes)
    return trimmed_degrees, trimmed_nodes
def get_entity_neighbors(overall_graph, inferred_links, trimmed_nodeset, node) -> list:
    """Return the sorted neighbours of ``node``, including inferred links.

    The neighbourhood is the union of:
    * the graph neighbours of ``node``;
    * ``inferred_links[node]``, when ``node`` is a key of ``inferred_links``;
    * every key of ``inferred_links`` whose linked set contains ``node``.

    Nodes in ``trimmed_nodeset`` and ``node`` itself are removed.

    Returns:
        A sorted list of neighbour names, or ``[]`` if the graph has no nodes.

    Raises:
        ValueError: If the graph is non-empty and ``node`` is not in it.
    """
    if len(overall_graph.nodes()) == 0:
        return []

    if node not in overall_graph.nodes():
        raise ValueError(f"Node {node} not in graph")

    candidates = set(overall_graph.neighbors(node))

    if inferred_links:
        # Outgoing inferred links.
        if node in inferred_links:
            candidates = candidates.union(inferred_links[node])
        # Incoming inferred links: any source that lists this node.
        sources = [
            source for source in inferred_links if node in inferred_links[source]
        ]
        candidates = candidates.union(sources)

    candidates = candidates.difference(trimmed_nodeset)
    return [candidate for candidate in sorted(candidates) if candidate != node]
def neighbor_is_valid(
    neighbor_node, supporting_attribute_types, trimmed_nodeset
) -> bool:
    """Decide whether ``neighbor_node`` may be used to link entities.

    A node is valid when it is truthy, and its type prefix (the text before the
    first ``ATTRIBUTE_VALUE_SEPARATOR``) is not one of
    ``supporting_attribute_types``, and it is not in ``trimmed_nodeset``.
    """
    if not neighbor_node:
        return False

    node_type, *_ = neighbor_node.split(ATTRIBUTE_VALUE_SEPARATOR)
    is_supporting = node_type in supporting_attribute_types
    is_trimmed = neighbor_node in trimmed_nodeset
    return not (is_supporting or is_trimmed)
def _linked_nodes(graph: nx.Graph, inferred_links: dict, node) -> set:
    """Return ``node``'s graph neighbours plus its outgoing inferred links, if any."""
    linked = set(graph.neighbors(node))
    if node in inferred_links:
        linked = linked.union(inferred_links[node])
    return linked


def project_entity_graph(
    overall_graph: nx.Graph,
    trimmed_nodeset: set,
    inferred_links: dict[str, set[str]] | None = None,
    supporting_attribute_types: list[str] | None = None,
) -> nx.Graph:
    """Project the entity/attribute graph onto an entity-only graph.

    For every node whose name starts with ``ENTITY_LABEL``, an edge is added to
    each entity reachable by:
    * a direct link (graph edge or inferred link, see ``get_entity_neighbors``);
    * one valid non-entity neighbour (a shared attribute);
    * two hops through valid non-entity nodes (an attribute linked to another
      attribute, e.g. via ``inferred_links``).

    Validity is decided by ``neighbor_is_valid``. Self-loops are never added.
    Entities that end up with no links are not present in the result.

    Args:
        overall_graph: The full entity/attribute graph.
        trimmed_nodeset: Node names to ignore entirely.
        inferred_links: Optional mapping from a node to the set of nodes it is
            additionally linked to.
        supporting_attribute_types: Attribute types that must not be used to
            link entities.

    Returns:
        A new ``nx.Graph`` whose nodes are entities.
    """
    if supporting_attribute_types is None:
        supporting_attribute_types = []
    # Use an empty mapping, matching the declared type. Like the empty list
    # previously used, it is falsy and supports `in`, so downstream
    # behaviour is unchanged.
    if inferred_links is None:
        inferred_links = {}

    def is_valid(candidate) -> bool:
        return neighbor_is_valid(
            candidate, supporting_attribute_types, trimmed_nodeset
        )

    P = nx.Graph()
    entity_nodes = [
        node for node in overall_graph.nodes() if node.startswith(ENTITY_LABEL)
    ]

    for node in entity_nodes:
        neighbors = get_entity_neighbors(
            overall_graph, inferred_links, trimmed_nodeset, node
        )
        for ent_neighbor in neighbors:
            # Direct entity-to-entity link.
            if ent_neighbor.startswith(ENTITY_LABEL):
                P.add_edge(node, ent_neighbor)
                continue
            if not is_valid(ent_neighbor):
                continue

            # ent_neighbor is a valid attribute: link to the entities around it.
            for att_neighbor in _linked_nodes(
                overall_graph, inferred_links, ent_neighbor
            ):
                if not is_valid(att_neighbor):
                    continue
                if att_neighbor.startswith(ENTITY_LABEL):
                    if att_neighbor != node:
                        P.add_edge(node, att_neighbor)
                    continue

                # Fuzzy attribute link: att_neighbor is a non-entity node
                # reached from another non-entity node. Link to the entities
                # on its far side.
                for fuzzy_neighbor in _linked_nodes(
                    overall_graph, inferred_links, att_neighbor
                ):
                    if (
                        is_valid(fuzzy_neighbor)
                        and fuzzy_neighbor.startswith(ENTITY_LABEL)
                        and fuzzy_neighbor != node
                    ):
                        P.add_edge(node, fuzzy_neighbor)
    return P
def get_subgraph(
    entity_graph: nx.Graph,
    nodes: list[str | int],
    random_seed: int = 42,
    max_network_entities: int = 20,
) -> tuple[list, dict]:
    """Split ``nodes`` of ``entity_graph`` into smaller networks with Leiden.

    The subgraph induced by ``nodes`` is clustered with ``hierarchical_leiden``
    (resolution 1.0, cluster size capped at ``max_network_entities``). The
    final-level clustering is used.

    Returns:
        ``(community_nodes, entity_to_community)``: a list of networks, each a
        list of nodes, and a mapping from node to that network's position in
        the list. Positions start at 0 for every call. If ``nodes`` or the
        graph is empty, both are empty.
    """
    if not nodes or not entity_graph.nodes():
        return [], {}

    clustering = hierarchical_leiden(
        nx.subgraph(entity_graph, nodes),
        resolution=1.0,
        random_seed=random_seed,
        max_cluster_size=max_network_entities,
    ).final_level_hierarchical_clustering()

    members_by_cluster = defaultdict(set)
    for member, cluster_id in clustering.items():
        members_by_cluster[cluster_id].add(member)

    community_nodes = [list(members) for members in members_by_cluster.values()]
    entity_to_community = {
        member: position
        for position, members in enumerate(community_nodes)
        for member in members
    }
    return community_nodes, entity_to_community
def get_community_nodes(
    entity_graph: nx.Graph,
    max_network_entities: int | None = 20,
) -> tuple[list, dict]:
    """Partition ``entity_graph`` into networks of related entities.

    Connected components are processed largest first. Each component becomes
    one network, unless it has more than ``max_network_entities`` nodes; then
    ``get_subgraph`` splits it further. ``None`` disables splitting.

    Returns:
        ``(community_nodes, entity_to_community_ix)``:
        * ``community_nodes`` is a list of networks. Unsplit components are
          sets; split ones are lists.
        * ``entity_to_community_ix`` maps each entity to the index of its
          network in ``community_nodes``.
    """
    components = sorted(
        nx.components.connected_components(entity_graph),
        key=len,
        reverse=True,
    )

    entity_to_community_ix = {}
    community_nodes = []
    for component in components:
        too_large = (
            max_network_entities is not None
            and len(component) > max_network_entities
        )
        if too_large:
            # Pass by keyword: get_subgraph's third positional parameter is
            # random_seed, not the cluster-size limit.
            sub_networks, sub_index = get_subgraph(
                entity_graph, component, max_network_entities=max_network_entities
            )
            # get_subgraph numbers its networks from 0; shift them past the
            # networks already collected so indices stay globally unique.
            offset = len(community_nodes)
            community_nodes.extend(sub_networks)
            for entity, local_ix in sub_index.items():
                entity_to_community_ix[entity] = offset + local_ix
        else:
            community_nodes.append(component)
            network_ix = len(community_nodes) - 1
            for entity in component:
                entity_to_community_ix[entity] = network_ix
    return community_nodes, entity_to_community_ix
def build_networks(
    main_graph: nx.Graph,
    trimmed_nodes: set,
    inferred_links: dict[str, set[str]] | None = None,
    supporting_attribute_types: list[str] | None = None,
    max_network_entities: int | None = 20,
) -> tuple[list, dict]:
    """Project ``main_graph`` onto entities and partition the result into networks.

    Args:
        main_graph: The full entity/attribute graph.
        trimmed_nodes: Node names to ignore during projection.
        inferred_links: Optional mapping from a node to its additionally
            linked nodes. It is indexed like a dict downstream, so it is
            annotated as a mapping rather than a set.
        supporting_attribute_types: Attribute types that must not link
            entities.
        max_network_entities: Size limit forwarded to ``get_community_nodes``.

    Returns:
        ``(community_nodes, entity_to_community)``, as returned by
        ``get_community_nodes``.
    """
    entity_graph = project_entity_graph(
        main_graph, trimmed_nodes, inferred_links, supporting_attribute_types
    )
    community_nodes, entity_to_community = get_community_nodes(
        entity_graph,
        max_network_entities,
    )
    return community_nodes, entity_to_community
def get_integrated_flags(
    integrated_flags: pl.DataFrame,
    entities: list[str],
    inferred_links: dict[str, set[str]] | None = None,
) -> tuple[Any, float | int, float, float, int]:
    """Aggregate flag statistics for one network of entities.

    Network members are counted first. Then every node linked to a member
    through ``inferred_links`` is counted once as an extra member, in either
    direction: ``inferred_links[member]``, or a key whose set contains the
    member. Extra members add to the totals.

    Args:
        integrated_flags: Frame with ``qualified_entity`` and ``count``
            columns.
        entities: The network's members (a list or a set).
        inferred_links: Optional mapping from a node to its linked nodes.

    Returns:
        ``(community_flags, flagged, flagged_per_unflagged, flags_per_entity,
        total_entities)``. The two ratios are rounded to 2 decimals. If
        ``integrated_flags`` is empty, the result is
        ``(0, 0, 0, 0, len(entities))``.
    """
    total_entities = len(entities)
    if integrated_flags.is_empty():
        return 0, 0, 0, 0, total_entities

    flags_df = integrated_flags.filter(pl.col("qualified_entity").is_in(entities))
    community_flags = flags_df.get_column("count").sum()
    flagged = len(flags_df.filter(pl.col("count") > 0)["qualified_entity"].unique())

    if inferred_links:
        # A set gives O(1) membership checks. Unlike `entities.copy()` +
        # `.append`, it also works when `entities` is itself a set.
        processed = set(entities)
        for member in entities:
            # Outgoing links first, then incoming ones (same order as before).
            linked = list(inferred_links[member]) if member in inferred_links else []
            linked.extend(
                source
                for source in inferred_links
                if member in inferred_links[source]
            )
            for other in linked:
                if other in processed:
                    continue
                processed.add(other)
                flags = integrated_flags.filter(pl.col("qualified_entity") == other)[
                    "count"
                ].sum()
                community_flags += flags
                total_entities += 1
                if flags > 0:
                    flagged += 1

    unflagged = total_entities - flagged
    flagged_per_unflagged = round(flagged / unflagged if unflagged > 0 else 0, 2)
    flags_per_entity = round(
        community_flags / total_entities if total_entities > 0 else 0, 2
    )

    return (
        community_flags,
        flagged,
        flagged_per_unflagged,
        flags_per_entity,
        total_entities,
    )
def build_entity_records(
    community_nodes: list,
    integrated_flags: pl.DataFrame | None = None,
    inferred_links: dict[str, set[str]] | None = None,
) -> list[tuple[str, int, int, int, Any, int, float, float]]:
    """Build one summary record per entity in every network.

    Args:
        community_nodes: A list of networks. Each network is a collection of
            qualified entity names (``<label><separator><id>``).
        integrated_flags: Optional frame with ``qualified_entity`` and
            ``count`` columns. ``None`` means no flags.
        inferred_links: Optional mapping forwarded to ``get_integrated_flags``.

    Returns:
        A list of tuples ``(entity_id, flags, network_ix, total_entities,
        community_flags, flagged, flags_per_entity, flagged_per_unflagged)``.
    """
    if integrated_flags is None:
        integrated_flags = pl.DataFrame()
    # The frame never changes inside the loop, so check it once.
    has_flags = not integrated_flags.is_empty()

    entity_records = []
    for ix, entities in enumerate(community_nodes):
        (
            community_flags,
            flagged,
            flagged_per_unflagged,
            flags_per_entity,
            total_entities,
        ) = get_integrated_flags(integrated_flags, entities, inferred_links)

        for n in entities:  # entities from a network
            flags = 0
            if has_flags:
                flags = integrated_flags.filter(pl.col("qualified_entity") == n)[
                    "count"
                ].sum()

            # Split only on the first separator so that ids which themselves
            # contain the separator are kept whole instead of truncated.
            entity_id = n.split(ATTRIBUTE_VALUE_SEPARATOR, 1)[1]
            entity_records.append(
                (
                    entity_id,
                    flags,
                    ix,
                    total_entities,
                    community_flags,
                    flagged,
                    flags_per_entity,
                    flagged_per_unflagged,
                )
            )

    return entity_records