Coverage for intelligence_toolkit/detect_entity_networks/api.py: 0%

144 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5from collections import defaultdict 

6from typing import Any 

7 

8import networkx as nx 

9import polars as pl 

10 

11from intelligence_toolkit.AI import utils 

12from intelligence_toolkit.AI.client import OpenAIClient 

13from intelligence_toolkit.detect_entity_networks import prompts 

14from intelligence_toolkit.detect_entity_networks.classes import ( 

15 FlagAggregatorType, 

16 SummaryData, 

17) 

18from intelligence_toolkit.detect_entity_networks.config import ( 

19 DEFAULT_MAX_ATTRIBUTE_DEGREE, 

20 ENTITY_LABEL, 

21) 

22from intelligence_toolkit.detect_entity_networks.explore_networks import ( 

23 build_network_from_entities, 

24 get_entity_graph, 

25 simplify_entities_graph, 

26) 

27from intelligence_toolkit.detect_entity_networks.exposure_report import ( 

28 build_exposure_report, 

29) 

30from intelligence_toolkit.detect_entity_networks.identify_networks import ( 

31 build_entity_records, 

32 build_networks, 

33 trim_nodeset, 

34) 

35from intelligence_toolkit.detect_entity_networks.index_and_infer import ( 

36 index_nodes, 

37 infer_nodes, 

38) 

39from intelligence_toolkit.detect_entity_networks.prepare_model import ( 

40 build_flag_links, 

41 build_flags, 

42 build_groups, 

43 build_main_graph, 

44 format_data_columns, 

45 generate_attribute_links, 

46) 

47from intelligence_toolkit.helpers.classes import IntelligenceWorkflow 

48from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR 

49from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

50 

51 

class DetectEntityNetworks(IntelligenceWorkflow):
    """Workflow that builds an entity/attribute graph and identifies entity networks.

    Links are accumulated from data frames via ``add_attribute_links``,
    ``add_flag_links`` and ``add_group_links``. Similar nodes can optionally be
    indexed (``index_nodes``) and linked (``infer_nodes``). ``identify`` then
    groups entities into networks, which can be explored
    (``get_entities_graph``, ``get_merged_graph_df``) and reported on
    (``get_exposure_report``, ``generate_report``).
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Link lists accumulated by the add_*_links methods.
        self.attribute_links = []
        self.attributes_list = []
        self.flag_links = []
        self.group_links = []
        # Attribute nodes the user explicitly asked to exclude; passed to
        # trim_nodeset by ``identify``.
        self.additional_trimmed_attributes = []
        self.graph = nx.Graph()
        self.integrated_flags = pl.DataFrame()
        # Flag statistics normally produced by build_flags; defaulted so that
        # ``generate_report`` does not raise AttributeError before any flags
        # have been added.
        self.max_entity_flags = 0
        self.mean_flagged_flags = 0
        self.node_types = set()
        self.inferred_links = {}
        # Results of ``identify``; defaulted so that the summary helpers can be
        # called before identification without raising AttributeError.
        self.community_nodes = []
        self.entity_to_community_ix = {}
        self.entity_records = []
        self.trimmed_attributes = pl.DataFrame()
        self.exposure_report = ""

    def format_links_added(
        self, values_df: pl.DataFrame, entity_id: int | str, columns: list[str]
    ) -> pl.DataFrame:
        """Format ``columns`` of ``values_df`` via ``format_data_columns``."""
        return format_data_columns(values_df, columns, entity_id)

    def get_entity_types(self) -> list[str]:
        """Return the entity label plus every attribute node type seen, sorted."""
        return sorted(
            [
                ENTITY_LABEL,
                *list(self.node_types),
            ]
        )

    def get_attributes(self) -> pl.DataFrame:
        """Return the attribute (non-entity) node names as an ``Attribute`` column.

        ``attributes_list`` is refreshed by ``get_model_summary_data``.
        """
        # ``pl.DataFrame`` does not accept ``columns=``; build from a dict and
        # pin the dtype so an empty list still yields a String column.
        return pl.DataFrame(
            {"Attribute": list(self.attributes_list)},
            schema={"Attribute": pl.Utf8},
        )

    def remove_attributes(self, selected_rows: pl.DataFrame) -> list[str]:
        """Record the ``Attribute`` values of ``selected_rows`` for extra trimming.

        Returns:
            The list of attribute names that ``identify`` will additionally trim.
        """
        # Polars Series expose ``to_list()`` (``tolist()`` is a pandas API).
        self.additional_trimmed_attributes = selected_rows["Attribute"].to_list()
        return self.additional_trimmed_attributes

    def add_attribute_links(
        self, data_df: pl.DataFrame, entity_id_column: str, columns_to_link: list[str]
    ) -> list:
        """Add entity-attribute links from ``columns_to_link`` and rebuild the graph.

        Rows with a null in any of ``columns_to_link`` are dropped first.

        Returns:
            All attribute links accumulated so far.
        """
        for column in columns_to_link:
            data_df = data_df.filter(pl.col(column).is_not_null())

        data_df_formatted = self.format_links_added(
            data_df, entity_id_column, columns_to_link
        )
        links = generate_attribute_links(
            data_df_formatted, entity_id_column, columns_to_link
        )
        self.attribute_links.extend(links)
        for attribute_link in links:
            # NOTE(review): assumes link[0][1] holds the attribute type — confirm
            # against generate_attribute_links.
            self.node_types.add(attribute_link[0][1])
        # The main graph is rebuilt from the full accumulated link list.
        self.graph = build_main_graph(self.attribute_links)

        return self.attribute_links

    def add_flag_links(
        self,
        data_df: pl.DataFrame,
        entity_id_column: str,
        flag_columns: list[str],
        flag_format: FlagAggregatorType,
    ) -> list:
        """Add flag links from ``flag_columns`` and recompute the flag statistics.

        Returns:
            All flag links accumulated so far.
        """
        data_df = self.format_links_added(data_df, entity_id_column, flag_columns)

        links = build_flag_links(
            data_df,
            entity_id_column,
            flag_format,
            flag_columns,
        )
        self.flag_links.extend(links)

        (
            self.integrated_flags,
            self.max_entity_flags,
            self.mean_flagged_flags,
        ) = build_flags(self.flag_links)
        return self.flag_links

    def add_group_links(
        self,
        data_df: pl.DataFrame,
        entity_id_column: str,
        group_cols: list[str],
    ) -> list:
        """Add group links for ``group_cols``.

        Returns:
            All group links accumulated so far.
        """
        data_df = self.format_links_added(data_df, entity_id_column, group_cols)

        links = build_groups(
            group_cols,
            data_df,
            entity_id_column,
        )
        self.group_links.extend(links)

        return self.group_links

    def get_model_summary_data(self) -> SummaryData:
        """Count entities, attributes, flags, distinct groups and links in the model.

        Side effect: refreshes ``self.attributes_list`` with the names of all
        graph nodes that do not start with ``ENTITY_LABEL``.
        """
        num_entities = 0
        num_attributes = 0
        num_flags = 0
        groups = set()
        for link_list in self.group_links:
            for link in link_list:
                groups.add(f"{link[1]}{ATTRIBUTE_VALUE_SEPARATOR}{link[2]}")

        if len(self.graph.nodes) > 0:
            all_nodes = self.graph.nodes()
            entity_nodes = [node for node in all_nodes if node.startswith(ENTITY_LABEL)]
            self.attributes_list = [
                node for node in all_nodes if not node.startswith(ENTITY_LABEL)
            ]
            num_entities = len(entity_nodes)
            num_attributes = len(all_nodes) - num_entities

        if len(self.integrated_flags) > 0:
            num_flags = self.integrated_flags["count"].sum()
        return SummaryData(
            entities=num_entities,
            attributes=num_attributes,
            flags=num_flags,
            groups=len(groups),
            links=len(self.graph.edges()),
        )

    def get_model_summary_value(self):
        """Return the model summary as a single human-readable string."""
        summary = self.get_model_summary_data()
        return f"Number of entities: {summary.entities}, Number of attributes: {summary.attributes}, Number of flags: {summary.flags}, Number of groups: {summary.groups}, Number of links: {summary.links}"

    async def index_nodes(self, node_types: list[str]) -> None:
        """Index graph nodes of ``node_types``; required before ``infer_nodes``."""
        (
            self.embedded_texts,
            self.nearest_text_distances,
            self.nearest_text_indices,
        ) = await index_nodes(
            node_types,
            self.graph,
        )

    def infer_nodes(
        self,
        similarity_threshold: float,
        progress_callbacks: list[ProgressBatchCallback] | None = None,
    ) -> defaultdict[Any, set]:
        """Infer links between similar nodes using the index built by ``index_nodes``.

        Returns:
            Mapping from a node to the set of nodes inferred to be similar.
        """
        self.inferred_links = infer_nodes(
            similarity_threshold,
            self.embedded_texts,
            self.nearest_text_indices,
            self.nearest_text_distances,
            progress_callbacks,
        )
        return self.inferred_links

    def clear_inferred_links(self) -> None:
        """Discard all inferred links."""
        self.inferred_links = {}

    def clear_data_model(self) -> None:
        """Reset all links, the graph, flags, node types and inferred links."""
        self.attribute_links = []
        self.attributes_list = []
        self.flag_links = []
        self.group_links = []
        self.graph = nx.Graph()
        self.integrated_flags = pl.DataFrame()
        # Flag statistics are derived from the flag links cleared above.
        self.max_entity_flags = 0
        self.mean_flagged_flags = 0
        self.node_types = set()
        self.inferred_links = {}

    def inferred_nodes_df(self) -> pl.DataFrame:
        """Return each inferred pair once as ``text``/``similar`` columns.

        The entity prefix is stripped from both columns and the result is sorted.
        """
        prefix = ENTITY_LABEL + ATTRIBUTE_VALUE_SEPARATOR
        link_list = [
            (text, n)
            for text, near in self.inferred_links.items()
            for n in near
            if text < n  # keep each unordered pair exactly once
        ]
        # orient="row": without it polars reads the tuples as columns whenever
        # the number of pairs equals the number of columns (2).
        inferred_df = pl.DataFrame(link_list, schema=["text", "similar"], orient="row")
        return inferred_df.with_columns(
            [
                # Cast so an empty (Null-typed) frame still supports .str ops;
                # literal=True because the prefix is not meant as a regex.
                pl.col("text").cast(pl.Utf8).str.replace(prefix, "", literal=True),
                pl.col("similar").cast(pl.Utf8).str.replace(prefix, "", literal=True),
            ]
        ).sort(["text", "similar"])

    def identify(
        self,
        max_network_entities: int | None = 20,
        max_attribute_degree: int | None = DEFAULT_MAX_ATTRIBUTE_DEGREE,
        supporting_attribute_types: list[str] | None = None,
    ) -> Any:
        """Trim high-degree attributes, build networks and compute entity records.

        Args:
            max_network_entities: Passed to ``build_networks``.
            max_attribute_degree: Passed to ``trim_nodeset`` together with the
                user-selected ``additional_trimmed_attributes``.
            supporting_attribute_types: Passed to ``build_networks``.

        Returns:
            The entity records produced by ``build_entity_records`` (also stored
            on ``self.entity_records``).
        """
        (trimmed_degrees, trimmed_nodes) = trim_nodeset(
            self.graph, max_attribute_degree, self.additional_trimmed_attributes
        )

        # orient="row" prevents a transposed frame when exactly two
        # attributes are trimmed.
        self.trimmed_attributes = pl.DataFrame(
            list(trimmed_degrees),
            schema=["Attribute", "Linked Entities"],
            orient="row",
        ).sort("Linked Entities")

        (
            self.community_nodes,
            self.entity_to_community_ix,
        ) = build_networks(
            self.graph,
            trimmed_nodes,
            self.inferred_links,
            supporting_attribute_types,
            max_network_entities,
        )

        self.entity_records = build_entity_records(
            self.community_nodes,
            self.integrated_flags,
            self.inferred_links,
        )
        return self.entity_records

    def get_community_sizes(self) -> list[int]:
        """Return the sizes of all networks with more than one member."""
        return [len(comm) for comm in self.community_nodes if len(comm) > 1]

    def get_records_summary(self) -> str:
        """Describe how many networks were identified; empty string if none."""
        if len(self.community_nodes) == 0:
            return ""
        comm_sizes = self.get_community_sizes()
        # Maximum over *all* networks: identical to max(comm_sizes) whenever a
        # multi-member network exists, and avoids max([]) raising ValueError
        # when every network is a singleton.
        max_comm_size = max((len(comm) for comm in self.community_nodes), default=0)

        return f"Networks identified: {len(self.community_nodes)} ({len(comm_sizes)} with multiple entities, maximum {max_comm_size})"

    def get_entity_df(self) -> pl.DataFrame:
        """Return the entity records as a frame sorted by ``flagged/unflagged`` descending."""
        # orient="row" prevents a transposed frame when there are exactly as
        # many records as columns (8).
        entity_df = pl.DataFrame(
            self.entity_records,
            schema=[
                "entity_id",
                "entity_flags",
                "network_id",
                "network_entities",
                "network_flags",
                "flagged",
                "flags/entity",
                "flagged/unflagged",
            ],
            orient="row",
        )
        return entity_df.sort(by=["flagged/unflagged"], descending=True)

    def get_grouped_df(self) -> pl.DataFrame:
        """Return ``get_entity_df()`` left-joined with one column per group name."""
        show_df = self.get_entity_df()

        for group_links in self.group_links:
            # orient="row": each link is an (entity_id, group, value) row.
            selected_df = pl.DataFrame(
                group_links, schema=["entity_id", "group", "value"], orient="row"
            )

            selected_df = selected_df.filter(pl.col("value").is_not_null())

            selected_df = selected_df.pivot(
                values="value",
                index="entity_id",
                columns="group",
                aggregate_function="first",
            )

            show_df = show_df.join(selected_df, on="entity_id", how="left")
        return show_df

    def get_exposure_report(
        self, selected_entity: str, selected_network: int
    ) -> pl.DataFrame:
        """Build and store the exposure report for an entity within a network.

        NOTE(review): the result is stored in ``self.exposure_report`` (initialized
        to ``""``) and later inserted into the report prompt, so the
        ``pl.DataFrame`` annotation may be inaccurate — confirm against
        ``build_exposure_report``.
        """
        c_nodes = self.community_nodes[selected_network]

        self.exposure_report = build_exposure_report(
            self.integrated_flags,
            selected_entity,
            c_nodes,
            self.get_entities_graph(selected_network),
            self.inferred_links,
        )
        return self.exposure_report

    def get_entities_graph(self, selected_network: int) -> nx.Graph:
        """Build the graph for network ``selected_network`` (an index into ``community_nodes``)."""
        c_nodes = self.community_nodes[selected_network]
        return build_network_from_entities(
            self.graph,
            self.entity_to_community_ix,
            self.integrated_flags,
            self.trimmed_attributes,
            self.inferred_links,
            c_nodes,
        )

    def get_single_entity_graph(
        self, entities_graph: nx.Graph, selected_entity: str
    ) -> tuple[list, list]:
        """Extract the subgraph around ``selected_entity`` from ``entities_graph``."""
        entity_name = f"{ENTITY_LABEL}{ATTRIBUTE_VALUE_SEPARATOR}{selected_entity}"
        return get_entity_graph(entities_graph, entity_name, self.get_entity_types())

    def get_merged_graph_df(
        self, selected_network: int
    ) -> tuple[pl.DataFrame, pl.DataFrame]:
        """Simplify a network's graph and return its nodes and links as frames.

        Returns:
            ``(nodes, links)``. ``nodes`` has ``node``/``type``/``flags`` columns.
            ``links`` has ``source``/``target`` columns plus ``attribute``, the
            part of ``target`` before the first ``ATTRIBUTE_VALUE_SEPARATOR``.
        """
        self.simplified_graph = simplify_entities_graph(
            self.get_entities_graph(selected_network),
        )
        # orient="row" prevents transposed frames when the node/edge count
        # equals the column count.
        nodes = pl.DataFrame(
            [
                (n, d["type"], d["flags"])
                for n, d in self.simplified_graph.nodes(data=True)
            ],
            schema=["node", "type", "flags"],
            orient="row",
        )
        links = pl.DataFrame(
            list(self.simplified_graph.edges()),
            schema=["source", "target"],
            orient="row",
        )
        # Vectorized replacement for the removed ``Expr.apply``; the cast keeps
        # an empty (Null-typed) frame working with the .str namespace.
        links = links.with_columns(
            pl.col("target")
            .cast(pl.Utf8)
            .str.split(ATTRIBUTE_VALUE_SEPARATOR)
            .list.first()
            .alias("attribute")
        )

        return nodes, links

    def generate_report(
        self,
        selected_network,
        selected_entity: str | None = "",
        ai_instructions: str | None = prompts.user_prompt,
    ):
        """Ask the configured AI client for a report on a network.

        The prompt includes the flag statistics, the last exposure report stored
        by ``get_exposure_report``, and the network's merged nodes and edges
        serialized as CSV.
        """
        nodes_merged, links_merged = self.get_merged_graph_df(selected_network)
        variables = {
            "entity_id": selected_entity,
            "network_id": selected_network,
            "max_flags": self.max_entity_flags,
            "mean_flags": self.mean_flagged_flags,
            "exposure": self.exposure_report,
            "network_nodes": nodes_merged.write_csv(),
            "network_edges": links_merged.write_csv(),
        }
        messages = utils.generate_messages(
            ai_instructions,
            prompts.list_prompts["report_prompt"],
            variables,
            prompts.list_prompts["safety_prompt"],
        )
        return OpenAIClient(self.ai_configuration).generate_chat(messages, stream=False)