Coverage for intelligence_toolkit/detect_entity_networks/api.py: 0%
144 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5from collections import defaultdict
6from typing import Any
8import networkx as nx
9import polars as pl
11from intelligence_toolkit.AI import utils
12from intelligence_toolkit.AI.client import OpenAIClient
13from intelligence_toolkit.detect_entity_networks import prompts
14from intelligence_toolkit.detect_entity_networks.classes import (
15 FlagAggregatorType,
16 SummaryData,
17)
18from intelligence_toolkit.detect_entity_networks.config import (
19 DEFAULT_MAX_ATTRIBUTE_DEGREE,
20 ENTITY_LABEL,
21)
22from intelligence_toolkit.detect_entity_networks.explore_networks import (
23 build_network_from_entities,
24 get_entity_graph,
25 simplify_entities_graph,
26)
27from intelligence_toolkit.detect_entity_networks.exposure_report import (
28 build_exposure_report,
29)
30from intelligence_toolkit.detect_entity_networks.identify_networks import (
31 build_entity_records,
32 build_networks,
33 trim_nodeset,
34)
35from intelligence_toolkit.detect_entity_networks.index_and_infer import (
36 index_nodes,
37 infer_nodes,
38)
39from intelligence_toolkit.detect_entity_networks.prepare_model import (
40 build_flag_links,
41 build_flags,
42 build_groups,
43 build_main_graph,
44 format_data_columns,
45 generate_attribute_links,
46)
47from intelligence_toolkit.helpers.classes import IntelligenceWorkflow
48from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR
49from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
class DetectEntityNetworks(IntelligenceWorkflow):
    """Build an entity/attribute graph from tabular data and detect entity networks.

    Links are accumulated through the ``add_*`` methods. Optionally, similar
    entity nodes are embedded (``index_nodes``) and inferred (``infer_nodes``).
    ``identify`` partitions the graph into networks. The remaining methods
    produce summaries, tables, graphs and an AI-generated report for a
    selected network.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Link lists accumulated across repeated add_* calls.
        self.attribute_links = []
        self.attributes_list = []
        self.flag_links = []
        self.group_links = []
        # Attribute nodes the user explicitly chose to exclude (see remove_attributes).
        self.additional_trimmed_attributes = []
        self.graph = nx.Graph()
        self.integrated_flags = pl.DataFrame()
        # Flag statistics returned by build_flags(). They start at 0 so that
        # generate_report() does not fail with AttributeError when no flag
        # links were added.
        self.max_entity_flags = 0
        self.mean_flagged_flags = 0
        self.node_types = set()
        self.inferred_links = {}
        # Results of identify(); empty until it has run, so summary methods
        # can check them instead of raising AttributeError.
        self.community_nodes = []
        self.entity_to_community_ix = {}
        self.entity_records = []
        self.trimmed_attributes = pl.DataFrame()
        self.exposure_report = ""

    def format_links_added(
        self, values_df: pl.DataFrame, entity_id: int | str, columns: list[str]
    ) -> pl.DataFrame:
        """Format ``columns`` of ``values_df`` via ``format_data_columns``.

        Callers in this class pass the entity-id *column name* as ``entity_id``.
        """
        return format_data_columns(values_df, columns, entity_id)

    def get_entity_types(self) -> list[str]:
        """Return ENTITY_LABEL plus every attribute node type seen so far, sorted."""
        return sorted(
            [
                ENTITY_LABEL,
                *list(self.node_types),
            ]
        )

    def get_attributes(self) -> pl.DataFrame:
        """Return the attribute node names as a one-column ("Attribute") frame.

        ``attributes_list`` is filled in by ``get_model_summary_data``.
        """
        # polars takes column names via a dict/`schema`; `columns=` is pandas-only.
        return pl.DataFrame({"Attribute": self.attributes_list})

    def remove_attributes(self, selected_rows: pl.DataFrame) -> list[str]:
        """Store the "Attribute" values of ``selected_rows`` for extra trimming.

        The stored list is passed to ``trim_nodeset`` by ``identify``.

        Returns:
            The list of attribute names that was stored.
        """
        # `to_list` exists on both polars and pandas Series (`tolist` is pandas-only).
        self.additional_trimmed_attributes = selected_rows["Attribute"].to_list()
        return self.additional_trimmed_attributes

    def add_attribute_links(
        self, data_df: pl.DataFrame, entity_id_column: str, columns_to_link: list[str]
    ) -> list:
        """Add entity-to-attribute links for ``columns_to_link`` and rebuild the graph.

        Rows with a null in any of ``columns_to_link`` are dropped first.

        Returns:
            All attribute links accumulated so far (not only the new ones).
        """
        for column in columns_to_link:
            data_df = data_df.filter(pl.col(column).is_not_null())

        data_df_formatted = self.format_links_added(
            data_df, entity_id_column, columns_to_link
        )
        links = generate_attribute_links(
            data_df_formatted, entity_id_column, columns_to_link
        )
        self.attribute_links.extend(links)
        # link[0][1] is recorded as a node type; the link structure comes from
        # generate_attribute_links (not visible here).
        for attribute_link in links:
            self.node_types.add(attribute_link[0][1])
        # The main graph is rebuilt from every accumulated link, not incrementally.
        self.graph = build_main_graph(self.attribute_links)

        return self.attribute_links

    def add_flag_links(
        self,
        data_df: pl.DataFrame,
        entity_id_column: str,
        flag_columns: list[str],
        flag_format: FlagAggregatorType,
    ) -> list:
        """Add flag links from ``flag_columns`` and recompute flag statistics.

        Updates ``integrated_flags``, ``max_entity_flags`` and
        ``mean_flagged_flags`` from all flag links accumulated so far.

        Returns:
            All flag links accumulated so far.
        """
        data_df = self.format_links_added(data_df, entity_id_column, flag_columns)

        links = build_flag_links(
            data_df,
            entity_id_column,
            flag_format,
            flag_columns,
        )
        self.flag_links.extend(links)

        (
            self.integrated_flags,
            self.max_entity_flags,
            self.mean_flagged_flags,
        ) = build_flags(self.flag_links)
        return self.flag_links

    def add_group_links(
        self,
        data_df: pl.DataFrame,
        entity_id_column: str,
        group_cols: list[str],
    ) -> list:
        """Add group links built from ``group_cols``.

        Returns:
            All group links accumulated so far. Elsewhere in this class each
            element is treated as a list of (entity_id, group, value) links.
        """
        data_df = self.format_links_added(data_df, entity_id_column, group_cols)

        links = build_groups(
            group_cols,
            data_df,
            entity_id_column,
        )
        self.group_links.extend(links)

        return self.group_links

    def get_model_summary_data(self) -> SummaryData:
        """Count entities, attributes, flags, distinct groups and links in the model.

        Side effect: refreshes ``attributes_list`` with all non-entity nodes.
        """
        num_entities = 0
        num_attributes = 0
        num_flags = 0
        # A group is identified by its "<group><SEP><value>" pair.
        groups = {
            f"{link[1]}{ATTRIBUTE_VALUE_SEPARATOR}{link[2]}"
            for link_list in self.group_links
            for link in link_list
        }

        if len(self.graph.nodes) > 0:
            all_nodes = self.graph.nodes()
            entity_nodes = [node for node in all_nodes if node.startswith(ENTITY_LABEL)]
            self.attributes_list = [
                node for node in all_nodes if not node.startswith(ENTITY_LABEL)
            ]
            num_entities = len(entity_nodes)
            num_attributes = len(all_nodes) - num_entities

        if len(self.integrated_flags) > 0:
            num_flags = self.integrated_flags["count"].sum()
        return SummaryData(
            entities=num_entities,
            attributes=num_attributes,
            flags=num_flags,
            groups=len(groups),
            links=len(self.graph.edges()),
        )

    def get_model_summary_value(self):
        """Return the model summary as a single human-readable line."""
        summary = self.get_model_summary_data()
        return f"Number of entities: {summary.entities}, Number of attributes: {summary.attributes}, Number of flags: {summary.flags}, Number of groups: {summary.groups}, Number of links: {summary.links}"

    async def index_nodes(self, node_types: list[str]) -> None:
        """Embed graph nodes of ``node_types`` and store their nearest-neighbour data."""
        # `index_nodes` here resolves to the imported module-level function,
        # not to this method.
        (
            self.embedded_texts,
            self.nearest_text_distances,
            self.nearest_text_indices,
        ) = await index_nodes(
            node_types,
            self.graph,
        )

    def infer_nodes(
        self,
        similarity_threshold: float,
        progress_callbacks: list[ProgressBatchCallback] | None = None,
    ) -> defaultdict[Any, set]:
        """Infer links between similar nodes; requires ``index_nodes`` to have run.

        Returns:
            The mapping of node -> set of similar nodes, also stored on
            ``inferred_links``.
        """
        self.inferred_links = infer_nodes(
            similarity_threshold,
            self.embedded_texts,
            self.nearest_text_indices,
            self.nearest_text_distances,
            progress_callbacks,
        )
        return self.inferred_links

    def clear_inferred_links(self) -> None:
        """Discard inferred node links."""
        self.inferred_links = {}

    def clear_data_model(self) -> None:
        """Reset all links, the graph, flags and the values derived from them."""
        self.attribute_links = []
        self.attributes_list = []
        self.flag_links = []
        self.group_links = []
        self.graph = nx.Graph()
        self.integrated_flags = pl.DataFrame()
        # These statistics are derived from integrated_flags, so clear them too.
        self.max_entity_flags = 0
        self.mean_flagged_flags = 0
        self.node_types = set()
        self.inferred_links = {}

    def inferred_nodes_df(self) -> pl.DataFrame:
        """Return inferred links as a (text, similar) frame, one row per pair.

        Each unordered pair is emitted once (``text < similar``), and the
        ``ENTITY_LABEL + ATTRIBUTE_VALUE_SEPARATOR`` prefix is stripped.
        """
        link_list = [
            (text, n)
            for text, near in self.inferred_links.items()
            for n in near
            if text < n
        ]
        # orient="row": without it polars guesses, and picks column orientation
        # when there happen to be exactly two pairs.
        inferred_df = pl.DataFrame(
            link_list, schema=["text", "similar"], orient="row"
        )
        prefix = ENTITY_LABEL + ATTRIBUTE_VALUE_SEPARATOR
        return inferred_df.with_columns(
            [
                # The cast keeps string ops valid on an empty (Null-dtype) frame;
                # literal=True stops separator characters from acting as regex.
                pl.col("text").cast(pl.Utf8).str.replace(prefix, "", literal=True),
                pl.col("similar")
                .cast(pl.Utf8)
                .str.replace(prefix, "", literal=True),
            ]
        ).sort(["text", "similar"])

    def identify(
        self,
        max_network_entities: int | None = 20,
        max_attribute_degree: int | None = DEFAULT_MAX_ATTRIBUTE_DEGREE,
        supporting_attribute_types: list[str] | None = None,
    ) -> list:
        """Trim high-degree or excluded attributes, then detect entity networks.

        Populates ``trimmed_attributes``, ``community_nodes``,
        ``entity_to_community_ix`` and ``entity_records``.

        Returns:
            The entity records built by ``build_entity_records``.
        """
        (trimmed_degrees, trimmed_nodes) = trim_nodeset(
            self.graph, max_attribute_degree, self.additional_trimmed_attributes
        )

        self.trimmed_attributes = pl.DataFrame(
            list(trimmed_degrees),
            schema=["Attribute", "Linked Entities"],
            orient="row",
        ).sort("Linked Entities")

        (
            self.community_nodes,
            self.entity_to_community_ix,
        ) = build_networks(
            self.graph,
            trimmed_nodes,
            self.inferred_links,
            supporting_attribute_types,
            max_network_entities,
        )

        self.entity_records = build_entity_records(
            self.community_nodes,
            self.integrated_flags,
            self.inferred_links,
        )
        return self.entity_records

    def get_community_sizes(self) -> list[int]:
        """Return the sizes of networks that contain more than one node."""
        return [len(comm) for comm in self.community_nodes if len(comm) > 1]

    def get_records_summary(self) -> str:
        """Describe the identified networks, or return "" if there are none."""
        if len(self.community_nodes) > 0:
            comm_sizes = self.get_community_sizes()
            # The max over all networks equals the max over multi-entity networks
            # whenever any exist, and avoids max([]) raising ValueError when every
            # network is a singleton.
            max_comm_size = max(len(comm) for comm in self.community_nodes)

            return f"Networks identified: {len(self.community_nodes)} ({len(comm_sizes)} with multiple entities, maximum {max_comm_size})"
        return ""

    def get_entity_df(self) -> pl.DataFrame:
        """Return entity records as a frame sorted by "flagged/unflagged", descending."""
        entity_df = pl.DataFrame(
            self.entity_records,
            schema=[
                "entity_id",
                "entity_flags",
                "network_id",
                "network_entities",
                "network_flags",
                "flagged",
                "flags/entity",
                "flagged/unflagged",
            ],
            # Records are rows; without this polars guesses columns when there
            # happen to be exactly as many records as fields.
            orient="row",
        )
        return entity_df.sort(by=["flagged/unflagged"], descending=True)

    def get_grouped_df(self) -> pl.DataFrame:
        """Return ``get_entity_df()`` left-joined with one column per group.

        Each group link list is pivoted to wide form (first value per entity
        and group) before the join.
        """
        show_df = self.get_entity_df()

        for group_links in self.group_links:
            selected_df = pl.DataFrame(
                group_links, schema=["entity_id", "group", "value"], orient="row"
            )

            selected_df = selected_df.filter(pl.col("value").is_not_null())

            selected_df = selected_df.pivot(
                values="value",
                index="entity_id",
                columns="group",
                aggregate_function="first",
            )

            show_df = show_df.join(selected_df, on="entity_id", how="left")
        return show_df

    def get_exposure_report(
        self, selected_entity: str, selected_network: int
    ) -> pl.DataFrame:
        """Build and store the exposure report for an entity in a network.

        Returns the value produced by ``build_exposure_report``, which is also
        stored on ``exposure_report``.
        NOTE(review): ``__init__`` initializes ``exposure_report`` to "", so the
        ``pl.DataFrame`` annotation may be inaccurate. Confirm against
        build_exposure_report.
        """
        c_nodes = self.community_nodes[selected_network]

        self.exposure_report = build_exposure_report(
            self.integrated_flags,
            selected_entity,
            c_nodes,
            self.get_entities_graph(selected_network),
            self.inferred_links,
        )
        return self.exposure_report

    def get_entities_graph(self, selected_network: int) -> nx.Graph:
        """Build the graph for the network at index ``selected_network``."""
        c_nodes = self.community_nodes[selected_network]
        return build_network_from_entities(
            self.graph,
            self.entity_to_community_ix,
            self.integrated_flags,
            self.trimmed_attributes,
            self.inferred_links,
            c_nodes,
        )

    def get_single_entity_graph(
        self, entities_graph: nx.Graph, selected_entity: str
    ) -> tuple[list, list]:
        """Extract the subgraph around ``selected_entity`` from ``entities_graph``."""
        entity_name = f"{ENTITY_LABEL}{ATTRIBUTE_VALUE_SEPARATOR}{selected_entity}"
        return get_entity_graph(entities_graph, entity_name, self.get_entity_types())

    def get_merged_graph_df(
        self, selected_network: int
    ) -> tuple[pl.DataFrame, pl.DataFrame]:
        """Simplify a network's graph and return its (nodes, links) frames.

        ``nodes`` has columns node/type/flags. ``links`` has source/target plus
        ``attribute``, which is the part of ``target`` before the first
        ATTRIBUTE_VALUE_SEPARATOR.
        """
        self.simplified_graph = simplify_entities_graph(
            self.get_entities_graph(selected_network),
        )
        nodes = pl.DataFrame(
            [
                (n, d["type"], d["flags"])
                for n, d in self.simplified_graph.nodes(data=True)
            ],
            schema=["node", "type", "flags"],
            orient="row",
        )
        links = pl.DataFrame(
            list(self.simplified_graph.edges()),
            schema=["source", "target"],
            orient="row",
        )
        # Vectorized equivalent of `x.split(SEP)[0]`. Expr.apply was removed
        # from polars. The cast keeps this valid when there are no edges.
        links = links.with_columns(
            pl.col("target")
            .cast(pl.Utf8)
            .str.split(ATTRIBUTE_VALUE_SEPARATOR)
            .list.first()
            .alias("attribute")
        )

        return nodes, links

    def generate_report(
        self,
        selected_network,
        selected_entity: str | None = "",
        ai_instructions: str | None = prompts.user_prompt,
    ):
        """Ask the configured OpenAI client for a report on a network.

        The prompt includes flag statistics, the stored exposure report, and
        the network's nodes and edges as CSV.
        """
        nodes_merged, links_merged = self.get_merged_graph_df(selected_network)
        variables = {
            "entity_id": selected_entity,
            "network_id": selected_network,
            "max_flags": self.max_entity_flags,
            "mean_flags": self.mean_flagged_flags,
            "exposure": self.exposure_report,
            "network_nodes": nodes_merged.write_csv(),
            "network_edges": links_merged.write_csv(),
        }
        messages = utils.generate_messages(
            ai_instructions,
            prompts.list_prompts["report_prompt"],
            variables,
            prompts.list_prompts["safety_prompt"],
        )
        return OpenAIClient(self.ai_configuration).generate_chat(messages, stream=False)