Coverage for intelligence_toolkit/detect_entity_networks/prepare_model.py: 93%

96 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5import re 

6from collections import defaultdict 

7from typing import Any 

8 

9import networkx as nx 

10import polars as pl 

11 

12from intelligence_toolkit.detect_entity_networks.classes import FlagAggregatorType 

13from intelligence_toolkit.detect_entity_networks.config import ENTITY_LABEL 

14from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR 

15from intelligence_toolkit.helpers.texts import clean_text_for_csv 

16 

17 

def clean_text(text: str | int) -> str:
    """Normalize a raw cell value for use as a node label.

    The value is passed through ``clean_text_for_csv`` (presumably this drops
    punctuation while keeping letters and digits in any script — confirm
    against that helper). The result is trimmed, and each internal run of
    whitespace is collapsed to a single space.

    Args:
        text: The raw cell value.

    Returns:
        The normalized string.
    """
    trimmed = clean_text_for_csv(text).strip()
    whitespace_run = re.compile(r"\s+")
    return whitespace_run.sub(" ", trimmed)

23 

24 

def format_data_columns(
    values_df: pl.DataFrame, columns_to_link: list[str], entity_id_column: str | int
) -> pl.DataFrame:
    """Apply ``clean_text`` to the entity-id column and to every linked column.

    Args:
        values_df: The input frame. A new frame is returned.
        columns_to_link: The attribute columns whose values are normalized.
        entity_id_column: The column holding entity identifiers.

    Returns:
        A frame in which each processed column is replaced by its cleaned
        ``pl.Utf8`` version. The entity column is processed first, then the
        linked columns in the given order. A column named more than once is
        cleaned more than once.
    """
    cleaned_df = values_df
    for column_name in (entity_id_column, *columns_to_link):
        normalized = pl.col(column_name).map_elements(clean_text, return_dtype=pl.Utf8)
        cleaned_df = cleaned_df.with_columns([normalized.alias(column_name)])
    return cleaned_df

44 

45 

def generate_attribute_links(
    data_df: pl.DataFrame,
    entity_id_column: str,
    columns_to_link: list[str],
    existing_links: list | None = None,
) -> list:
    """
    Generate attribute links for the given entity and columns.

    For every column in ``columns_to_link``, one inner list is appended. That
    list holds one ``[entity_id, column_name, value]`` row per row of
    ``data_df``.

    Args:
        data_df (pl.DataFrame): The DataFrame containing the data.
        entity_id_column (str): The name of the column containing entity IDs.
        columns_to_link (list[str]): The column names to link as attributes.
        existing_links (list, optional): Links to append to. A non-empty list
            is extended in place. ``None`` or an empty list starts a new list.

    Returns:
        list: The list of per-column link lists.
    """
    links = existing_links if existing_links else []

    frame = data_df
    for linked_col in columns_to_link:
        # Tag each row with the attribute name as a constant column.
        frame = frame.with_columns([pl.lit(linked_col).alias("attribute_col")])
        triples = frame.select([entity_id_column, "attribute_col", linked_col])
        links.append(triples.to_numpy().tolist())

    return links

76 

77 

def build_main_graph(
    attribute_links: list[Any] | None = None,
) -> nx.Graph:
    """Build the undirected entity/attribute graph from attribute link lists.

    Args:
        attribute_links: A list of link lists, such as those returned by
            ``generate_attribute_links`` or ``build_groups``. Each link is
            indexed as ``(entity_id, attribute_name, value)``. ``None`` yields
            an empty graph.

    Returns:
        An ``nx.Graph`` containing:

        - one node ``transform_entity(entity_id)`` per entity, with
          ``type=ENTITY_LABEL``;
        - one node ``f"{attribute_name}{ATTRIBUTE_VALUE_SEPARATOR}{value}"``
          per attribute value, with ``type=attribute_name``;
        - an entity-to-attribute-node edge with ``type=attribute_name``;
        - an edge with ``type="equality"`` between every pair of distinct
          attribute nodes whose raw values compare equal.

    NOTE(review): a ``None`` value is formatted into the node name as the text
    ``None``. All entities with a missing value would then share one node —
    confirm that nulls are dropped upstream.
    """
    from itertools import combinations

    graph = nx.Graph()
    if attribute_links is None:
        return graph

    value_to_atts: defaultdict[Any, set[str]] = defaultdict(set)
    for link_list in attribute_links:
        for link in link_list:
            entity_id, attribute, value = link[0], link[1], link[2]
            n1 = transform_entity(entity_id)
            n2 = f"{attribute}{ATTRIBUTE_VALUE_SEPARATOR}{value}"
            # Order the endpoints so node insertion order does not depend on
            # which endpoint is the entity.
            u, v = (n1, n2) if n1 < n2 else (n2, n1)
            graph.add_edge(u, v, type=attribute)
            graph.add_node(n1, type=ENTITY_LABEL)
            graph.add_node(n2, type=attribute)
            value_to_atts[value].add(n2)

    # Sort each group before pairing. Iteration order over a set of strings
    # depends on hash randomization, which would otherwise make the
    # equality-edge insertion order vary between runs. Sorted pairs already
    # satisfy att1 < att2.
    for atts in value_to_atts.values():
        for att1, att2 in combinations(sorted(atts), 2):
            graph.add_edge(att1, att2, type="equality")
    return graph

103 

104 

def build_flag_links(
    df_flag: pl.DataFrame,
    entity_col: str,
    flag_agg: FlagAggregatorType,
    flag_columns: list[str],
    existing_flag_links: list | None = None,
) -> list[Any]:
    """Turn per-entity flag columns into flag link records.

    Each column in ``flag_columns`` is cast to ``pl.Int32`` and summed per
    ``entity_col``. The per-entity totals become records whose shape depends
    on ``flag_agg``:

    - ``FlagAggregatorType.Instance``: ``[entity, column, total, 1]``
    - ``FlagAggregatorType.Count``: ``[entity, column, column, total]``

    Any other ``flag_agg`` value adds no records.

    Args:
        df_flag: The frame holding the entity column and the flag columns.
        entity_col: The column identifying the entity.
        flag_agg: Selects the record shape, as described above.
        flag_columns: The columns to aggregate, processed in order.
        existing_flag_links: Records to extend. A non-empty list is extended
            in place. ``None`` or an empty list starts a new list.

    Returns:
        The extended list of records. Within each column, entities appear in
        order of first appearance.

    Raises:
        ValueError: If ``entity_col`` or any flag column is missing. All
            columns are validated before any record is added.
    """
    flag_links = existing_flag_links or []

    if entity_col not in df_flag.columns:
        msg = f"Column {entity_col} not found in the DataFrame."
        raise ValueError(msg)
    # Validate every column before aggregating. A missing column is then
    # reported as ValueError, and the caller's list is never partially extended.
    for value_col in flag_columns:
        if value_col not in df_flag.columns:
            msg = f"Column {value_col} not found in the DataFrame."
            raise ValueError(msg)

    for value_col in flag_columns:
        # Aggregate only the column being processed. maintain_order keeps the
        # output order deterministic.
        totals = df_flag.group_by(entity_col, maintain_order=True).agg(
            pl.col(value_col).cast(pl.Int32).sum()
        )
        vals = totals.select([entity_col, value_col]).to_numpy().tolist()
        if flag_agg == FlagAggregatorType.Instance:
            flag_links.extend([[val[0], value_col, val[1], 1] for val in vals])
        elif flag_agg == FlagAggregatorType.Count:
            flag_links.extend([[val[0], value_col, value_col, val[1]] for val in vals])

    return flag_links

141 

142 

def transform_entity(entity):
    """Return the graph node name for an entity.

    The name is ``ENTITY_LABEL``, then ``ATTRIBUTE_VALUE_SEPARATOR``, then the
    formatted ``entity`` value.
    """
    prefix = f"{ENTITY_LABEL}{ATTRIBUTE_VALUE_SEPARATOR}"
    return f"{prefix}{entity}"

145 

146 

def build_flags(
    network_flag_links: list | None = None,
) -> tuple:
    """Aggregate flag link records into a flag table and summary statistics.

    Args:
        network_flag_links: Records indexed as ``[entity, type, flag, count]``,
            as produced by ``build_flag_links``. ``None`` or an empty list
            yields an empty result.

    Returns:
        ``(flags, max_entity_flags, mean_flagged_flags)``:

        - ``flags``: counts summed per ``(entity, type, flag)``, plus a
          ``qualified_entity`` column built with ``transform_entity``;
        - ``max_entity_flags``: the largest total count of any qualified
          entity;
        - ``mean_flagged_flags``: the mean total over entities whose total is
          positive, rounded to 2 decimals. It is 0 when no entity has a
          positive total.
    """
    # An empty list previously crashed further down, so treat it like None.
    if not network_flag_links:
        return pl.DataFrame(), 0, 0

    flags = pl.DataFrame(
        {
            "entity": [item[0] for item in network_flag_links],
            "type": [item[1] for item in network_flag_links],
            "flag": [item[2] for item in network_flag_links],
            "count": [item[3] for item in network_flag_links],
        }
    )
    flags = flags.group_by(["entity", "type", "flag"]).agg(pl.sum("count"))

    # transform_entity always returns a str, so declare the dtype instead of
    # letting polars infer it.
    flags = flags.with_columns(
        [
            flags["entity"]
            .map_elements(transform_entity, return_dtype=pl.Utf8)
            .alias("qualified_entity")
        ]
    )
    overall_df = flags.group_by("qualified_entity").agg(pl.sum("count"))
    max_entity_flags = overall_df["count"].max()

    # mean() returns None when no entity has a positive total (e.g. all-zero
    # counts). Rounding None would raise TypeError.
    positive_mean = overall_df.filter(pl.col("count") > 0)["count"].mean()
    mean_flagged_flags = round(positive_mean, 2) if positive_mean is not None else 0

    return flags, max_entity_flags, mean_flagged_flags

173 

174 

def build_groups(
    value_cols: list[str],
    df_groups: pl.DataFrame,
    entity_col: str,
    existing_group_links: list | None = None,
) -> list[Any]:
    """Build group link lists, one per column in ``value_cols``.

    Each appended inner list holds ``[entity, column_name, value]`` rows, in
    the same layout as ``generate_attribute_links``.

    Args:
        value_cols: The group columns to link.
        df_groups: The frame holding the entity column and the group columns.
        entity_col: The column identifying the entity.
        existing_group_links: Links to append to. A non-empty list is extended
            in place. ``None`` or an empty list starts a new list.

    Returns:
        The list of link lists. It is returned unchanged when ``df_groups``
        is empty.

    Raises:
        ValueError: If a value column, or then the entity column, is missing.
            Checks happen column by column as the loop proceeds.
    """
    links = existing_group_links if existing_group_links else []
    if df_groups.is_empty():
        return links

    def _require(column: str, frame: pl.DataFrame) -> None:
        # Raise the module's standard "column missing" error.
        if column not in frame.columns:
            msg = f"Column {column} not found in the DataFrame."
            raise ValueError(msg)

    frame = df_groups
    for group_col in value_cols:
        _require(group_col, frame)
        frame = frame.with_columns([pl.lit(group_col).alias("attribute_col")])
        _require(entity_col, frame)
        triples = frame.select([entity_col, "attribute_col", group_col]).to_numpy()
        links.append(triples.tolist())

    return links

204 

205 

def build_model_with_attributes(
    input_dataframe: pl.DataFrame, entity_id_column: str, columns_to_link: list[str]
) -> nx.Graph:
    """Clean the input columns, derive attribute links, and build the graph.

    Args:
        input_dataframe: The raw entity/attribute frame.
        entity_id_column: The column identifying the entity.
        columns_to_link: The attribute columns to link entities through.

    Returns:
        The graph produced by ``build_main_graph``.
    """
    cleaned = format_data_columns(
        values_df=input_dataframe,
        columns_to_link=columns_to_link,
        entity_id_column=entity_id_column,
    )
    return build_main_graph(
        attribute_links=generate_attribute_links(
            data_df=cleaned,
            entity_id_column=entity_id_column,
            columns_to_link=columns_to_link,
        )
    )

215 

216 

def get_flags(
    flags_dataframe: pl.DataFrame,
    entity_col: str,
    flag_agg: FlagAggregatorType,
    value_cols: list[str],
) -> tuple[pl.DataFrame, int, float]:
    """Build flag records from a frame and summarize them.

    Args:
        flags_dataframe: The frame holding the entity column and the flag
            columns.
        entity_col: The column identifying the entity.
        flag_agg: How flag columns are turned into records (see
            ``build_flag_links``).
        value_cols: The flag columns to process.

    Returns:
        ``(flags, max_entity_flags, mean_flagged_flags)`` from ``build_flags``.
        The mean is rounded to 2 decimals, so it is a float rather than an
        int.

    Raises:
        ValueError: If ``entity_col`` or a flag column is missing.
    """
    flag_links = build_flag_links(
        flags_dataframe,
        entity_col,
        flag_agg,
        value_cols,
    )
    return build_flags(flag_links)