Coverage for intelligence_toolkit/detect_entity_networks/prepare_model.py: 93%
96 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5import re
6from collections import defaultdict
7from typing import Any
9import networkx as nx
10import polars as pl
12from intelligence_toolkit.detect_entity_networks.classes import FlagAggregatorType
13from intelligence_toolkit.detect_entity_networks.config import ENTITY_LABEL
14from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR
15from intelligence_toolkit.helpers.texts import clean_text_for_csv
def clean_text(text: str | int) -> str:
    """Normalise one cell value for use as part of a graph node label.

    The value is first passed through ``clean_text_for_csv`` and stripped of
    surrounding whitespace; every remaining run of whitespace is then
    collapsed into a single space.

    Args:
        text: Raw cell value (string or integer).

    Returns:
        The cleaned, whitespace-normalised string.
    """
    normalised = clean_text_for_csv(text)
    normalised = normalised.strip()
    # Collapse runs of spaces/tabs/newlines into one space.
    return re.sub(r"\s+", " ", normalised)
def format_data_columns(
    values_df: pl.DataFrame, columns_to_link: list[str], entity_id_column: str | int
) -> pl.DataFrame:
    """Clean the entity-ID column and every linked attribute column.

    Each column is replaced under its own name by applying ``clean_text``
    element-wise with a ``Utf8`` result dtype. The entity-ID column is
    processed first, then ``columns_to_link`` in order; every listed column is
    cleaned once per occurrence, so a column named twice is cleaned twice.

    Args:
        values_df: Input table.
        columns_to_link: Attribute columns to clean.
        entity_id_column: Column holding entity identifiers.

    Returns:
        A new DataFrame with the cleaned columns.
    """
    cleaned_df = values_df
    for column_name in (entity_id_column, *columns_to_link):
        cleaned_expr = (
            pl.col(column_name)
            .map_elements(clean_text, return_dtype=pl.Utf8)
            .alias(column_name)
        )
        cleaned_df = cleaned_df.with_columns([cleaned_expr])
    return cleaned_df
def generate_attribute_links(
    data_df: pl.DataFrame,
    entity_id_column: str,
    columns_to_link: list[str],
    existing_links: list | None = None,
) -> list:
    """
    Turn attribute columns into lists of (entity, attribute, value) links.

    For every column in ``columns_to_link`` one sub-list is appended to the
    result; each of its rows is ``[entity_id, column_name, cell_value]``.

    Args:
        data_df (pl.DataFrame): Table holding entities and their attributes.
        entity_id_column (str): Column containing the entity identifiers.
        columns_to_link (list[str]): Attribute columns to turn into links.
        existing_links (list, optional): Links to extend. A non-empty list is
            appended to in place; ``None`` or an empty list yields a fresh list.

    Returns:
        list: One list of ``[entity, attribute, value]`` rows per column.
    """
    links = existing_links or []

    labelled_df = data_df
    for attribute_name in columns_to_link:
        # Constant column carrying the attribute's name, so each row knows
        # which column its value came from.
        labelled_df = labelled_df.with_columns(
            [pl.lit(attribute_name).alias("attribute_col")]
        )
        rows = labelled_df.select([entity_id_column, "attribute_col", attribute_name])
        links.append(rows.to_numpy().tolist())

    return links
def build_main_graph(
    attribute_links: list[Any] | None = None,
) -> nx.Graph:
    """Build the undirected entity/attribute graph from attribute links.

    Each link ``[entity, attribute, value, ...]`` contributes:

    * a node ``ENTITY_LABEL<sep>entity`` with ``type=ENTITY_LABEL``;
    * a node ``attribute<sep>value`` with ``type=attribute``;
    * an edge between those two nodes with ``type=attribute``.

    Here ``<sep>`` is ``ATTRIBUTE_VALUE_SEPARATOR``.

    Afterwards, every pair of attribute nodes created from the same raw
    ``value`` is joined by an edge with ``type="equality"``. The pairing is
    quadratic in the number of attribute nodes sharing a value.

    Args:
        attribute_links: Iterable of link lists, as produced by
            ``generate_attribute_links``. ``None`` yields an empty graph.

    Returns:
        The constructed ``networkx.Graph``.
    """
    graph = nx.Graph()
    if attribute_links is None:
        return graph

    nodes_by_value = defaultdict(set)
    for link_list in attribute_links:
        for link in link_list:
            entity_id, attribute_name, attribute_value = link[0], link[1], link[2]
            entity_node = f"{ENTITY_LABEL}{ATTRIBUTE_VALUE_SEPARATOR}{entity_id}"
            attribute_node = (
                f"{attribute_name}{ATTRIBUTE_VALUE_SEPARATOR}{attribute_value}"
            )
            first, second = sorted((entity_node, attribute_node))
            graph.add_edge(first, second, type=attribute_name)
            graph.add_node(entity_node, type=ENTITY_LABEL)
            graph.add_node(attribute_node, type=attribute_name)
            nodes_by_value[attribute_value].add(attribute_node)

    # Link attribute nodes (possibly of different attribute types) whose
    # underlying values are identical.
    for attribute_nodes in nodes_by_value.values():
        ordered_nodes = list(attribute_nodes)
        for position, left in enumerate(ordered_nodes):
            for right in ordered_nodes[position + 1 :]:
                low, high = sorted((left, right))
                graph.add_edge(low, high, type="equality")
    return graph
def build_flag_links(
    df_flag: pl.DataFrame,
    entity_col: str,
    flag_agg: FlagAggregatorType,
    flag_columns: list[str],
    existing_flag_links: list | None = None,
) -> list[Any]:
    """Aggregate flag columns per entity into four-item flag link records.

    For each column in ``flag_columns``, the column is cast to ``Int32`` and
    summed per entity. Each ``(entity, total)`` pair then becomes one record:

    * ``FlagAggregatorType.Instance`` -> ``[entity, column, total, 1]``
    * ``FlagAggregatorType.Count``    -> ``[entity, column, column, total]``

    Any other ``flag_agg`` value produces no records.

    Args:
        df_flag: Table holding the entity column and the flag columns.
        entity_col: Column containing entity identifiers.
        flag_agg: How the per-entity totals are encoded in the records.
        flag_columns: Columns whose values are summed per entity.
        existing_flag_links: Records to extend. A non-empty list is extended
            in place; ``None`` or an empty list yields a fresh list.

    Returns:
        The list of flag link records.

    Raises:
        ValueError: If ``entity_col`` or any column in ``flag_columns`` is
            missing from ``df_flag``.
    """
    flag_links = existing_flag_links or []

    # Validate every required column up front, so a bad column name cannot
    # leave a caller-supplied ``existing_flag_links`` partially extended.
    for column in (entity_col, *flag_columns):
        if column not in df_flag.columns:
            msg = f"Column {column} not found in the DataFrame."
            raise ValueError(msg)

    for value_col in flag_columns:
        # Only the column being processed is aggregated. Summing every flag
        # column here was wasted work, and it could fail on columns that had
        # not been cast to an integer type.
        gdf = (
            df_flag.with_columns([pl.col(value_col).cast(pl.Int32).alias(value_col)])
            .group_by(entity_col, maintain_order=True)
            .agg(pl.col(value_col).sum())
        )
        vals = gdf.select([entity_col, value_col]).to_numpy().tolist()

        if flag_agg == FlagAggregatorType.Instance:
            flag_links.extend([entity, value_col, total, 1] for entity, total in vals)
        elif flag_agg == FlagAggregatorType.Count:
            flag_links.extend(
                [entity, value_col, value_col, total] for entity, total in vals
            )

    return flag_links
def transform_entity(entity):
    """Return the qualified graph-node name for an entity identifier.

    The name is ``ENTITY_LABEL`` followed by ``ATTRIBUTE_VALUE_SEPARATOR`` and
    the formatted ``entity``. This matches the entity node names created in
    ``build_main_graph``.
    """
    entity_prefix = f"{ENTITY_LABEL}{ATTRIBUTE_VALUE_SEPARATOR}"
    return f"{entity_prefix}{entity}"
def build_flags(
    network_flag_links: list | None = None,
) -> tuple:
    """Summarise flag link records into a flag table and summary statistics.

    Args:
        network_flag_links: Four-item records ``[entity, type, flag, count]``,
            e.g. as produced by ``build_flag_links``.

    Returns:
        ``(flags, max_entity_flags, mean_flagged_flags)``:

        * ``flags`` has one row per ``(entity, type, flag)`` with the summed
          ``count``, plus a ``qualified_entity`` column built with
          ``transform_entity``.
        * ``max_entity_flags`` is the largest total ``count`` of any
          qualified entity.
        * ``mean_flagged_flags`` is the mean total ``count`` over qualified
          entities whose total is positive, rounded to 2 decimals. It is ``0``
          when no entity has a positive total.

        When there are no links (``None`` or empty), ``(pl.DataFrame(), 0, 0)``
        is returned.
    """
    # An empty list previously fell through and crashed on round(None, 2).
    if not network_flag_links:
        return pl.DataFrame(), 0, 0

    flags = pl.DataFrame(
        {
            "entity": [item[0] for item in network_flag_links],
            "type": [item[1] for item in network_flag_links],
            "flag": [item[2] for item in network_flag_links],
            "count": [item[3] for item in network_flag_links],
        }
    )
    flags = flags.group_by(["entity", "type", "flag"]).agg(pl.sum("count"))

    flags = flags.with_columns(
        [
            flags["entity"]
            .map_elements(transform_entity, return_dtype=pl.Utf8)
            .alias("qualified_entity")
        ]
    )
    overall_df = flags.group_by("qualified_entity").agg(pl.sum("count"))
    max_entity_flags = overall_df["count"].max()

    flagged_mean = overall_df.filter(pl.col("count") > 0)["count"].mean()
    # When no entity has a positive total the filtered series is empty and
    # its mean is None; report 0 instead of crashing in round().
    mean_flagged_flags = 0 if flagged_mean is None else round(flagged_mean, 2)

    return flags, max_entity_flags, mean_flagged_flags
def build_groups(
    value_cols: list[str],
    df_groups: pl.DataFrame,
    entity_col: str,
    existing_group_links: list | None = None,
) -> list[Any]:
    """Turn group columns into lists of (entity, group column, value) links.

    For every column in ``value_cols`` one sub-list is appended to the result;
    each of its rows is ``[entity, column_name, cell_value]``.

    Args:
        value_cols: Group columns to turn into links.
        df_groups: Table holding the entity column and group columns.
        entity_col: Column containing entity identifiers.
        existing_group_links: Links to extend. A non-empty list is appended to
            in place; ``None`` or an empty list yields a fresh list.

    Returns:
        The link list. If ``df_groups`` is empty it is returned unchanged,
        without validating any column names.

    Raises:
        ValueError: If a column in ``value_cols`` or ``entity_col`` is missing.
    """
    links = existing_group_links or []
    if df_groups.is_empty():
        return links

    # The labelled frame is carried across iterations (it keeps the
    # "attribute_col" column added in earlier passes).
    labelled_df = df_groups
    for group_col in value_cols:
        if group_col not in labelled_df.columns:
            msg = f"Column {group_col} not found in the DataFrame."
            raise ValueError(msg)

        labelled_df = labelled_df.with_columns(
            [pl.lit(group_col).alias("attribute_col")]
        )
        if entity_col not in labelled_df.columns:
            msg = f"Column {entity_col} not found in the DataFrame."
            raise ValueError(msg)

        rows = labelled_df.select([entity_col, "attribute_col", group_col])
        links.append(rows.to_numpy().tolist())

    return links
def build_model_with_attributes(
    input_dataframe: pl.DataFrame, entity_id_column: str, columns_to_link: list[str]
) -> nx.Graph:
    """Build the entity/attribute graph directly from a raw input table.

    The pipeline has three steps:

    1. Clean the entity-ID and linked columns with ``format_data_columns``.
    2. Convert them to attribute links with ``generate_attribute_links``.
    3. Assemble the graph with ``build_main_graph``.

    Args:
        input_dataframe: Raw input table.
        entity_id_column: Column containing entity identifiers.
        columns_to_link: Attribute columns to link entities through.

    Returns:
        The resulting ``networkx.Graph``.
    """
    return build_main_graph(
        generate_attribute_links(
            format_data_columns(input_dataframe, columns_to_link, entity_id_column),
            entity_id_column,
            columns_to_link,
        )
    )
def get_flags(
    flags_dataframe: pl.DataFrame,
    entity_col: str,
    flag_agg: FlagAggregatorType,
    value_cols: list[str],
) -> tuple[pl.DataFrame, int, float]:
    """Build flag links from a flag table and summarise them.

    This is a convenience wrapper: ``build_flag_links`` followed by
    ``build_flags``.

    Args:
        flags_dataframe: Table holding the entity column and flag columns.
        entity_col: Column containing entity identifiers.
        flag_agg: Aggregation mode passed to ``build_flag_links``.
        value_cols: Flag columns to aggregate.

    Returns:
        ``(flags, max_entity_flags, mean_flagged_flags)`` as returned by
        ``build_flags``. The mean is rounded to 2 decimals and is therefore a
        float, not an int as previously annotated.

    Raises:
        ValueError: Propagated from ``build_flag_links`` when a column is
            missing.
    """
    flag_links = build_flag_links(
        flags_dataframe,
        entity_col,
        flag_agg,
        value_cols,
    )
    return build_flags(flag_links)