Coverage for intelligence_toolkit/detect_entity_networks/explore_networks.py: 76%

153 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5 

6import colorsys 

7from typing import Any 

8 

9import networkx as nx 

10import polars as pl 

11 

12from intelligence_toolkit.detect_entity_networks.config import ( 

13 ENTITY_LABEL, 

14 LIST_SEPARATOR, 

15) 

16from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR 

17 

18 

def _integrate_flags(graph: nx.Graph, df_integrated_flags: pl.DataFrame) -> nx.Graph:
    """Annotate graph nodes with their total positive flag count.

    Args:
        graph: Graph whose nodes are matched against the ``qualified_entity``
            column. It is modified in place.
        df_integrated_flags: Frame read through its ``qualified_entity`` and
            ``count`` columns.

    Returns:
        ``graph``, with a ``flags`` attribute set on every node that has at
        least one row with ``count > 0`` (the value is the sum of those
        counts). Other nodes are left untouched. A new empty graph is
        returned when ``graph`` has no nodes or the frame is empty.
    """
    if not graph.nodes() or df_integrated_flags.is_empty():
        # NOTE(review): an empty flags frame discards the input graph instead
        # of returning it unchanged; kept because callers may rely on it.
        return nx.Graph()

    # Aggregate once per entity instead of re-filtering the frame per node.
    flag_totals = (
        df_integrated_flags.filter(pl.col("count") > 0)
        .group_by("qualified_entity")
        .agg(pl.col("count").sum())
    )
    for entity, total in flag_totals.iter_rows():
        if entity in graph:
            graph.nodes[entity]["flags"] = total
    return graph

34 

35 

def _build_fuzzy_neighbors(
    graph: nx.Graph,
    network_graph: nx.Graph,
    att_neighbor: str | int,
    trimmed_nodeset: set,
    inferred_links: dict,
) -> nx.Graph:
    """Link ``att_neighbor`` to its non-entity neighbours in ``network_graph``.

    Neighbours are taken from ``graph`` and, when ``att_neighbor`` is a key of
    ``inferred_links``, from that entry as well. Neighbours found in
    ``trimmed_nodeset`` or whose name starts with ``ENTITY_LABEL`` are skipped.

    Args:
        graph: Source graph that supplies the neighbours.
        network_graph: Graph being built; modified in place.
        att_neighbor: Node whose neighbours are attached.
        trimmed_nodeset: Node names to leave out.
        inferred_links: Mapping from a node to extra linked nodes.

    Returns:
        ``network_graph`` with the added nodes and edges, or a new empty graph
        when ``graph`` has no nodes.

    Raises:
        ValueError: If ``att_neighbor`` is not a node of ``graph``.
    """
    if len(graph.nodes()) == 0:
        return nx.Graph()

    if att_neighbor not in graph.nodes():
        msg = f"Node {att_neighbor} not in graph"
        raise ValueError(msg)

    candidates = set(graph.neighbors(att_neighbor))
    if att_neighbor in inferred_links:
        candidates = candidates.union(inferred_links[att_neighbor])

    # Filter everything first so nothing is added if a candidate is malformed.
    kept = [
        candidate
        for candidate in candidates
        if not (
            candidate in trimmed_nodeset or candidate.startswith(ENTITY_LABEL)
        )
    ]
    for candidate in kept:
        # The node type is the text before the attribute/value separator.
        candidate_type = candidate.partition(ATTRIBUTE_VALUE_SEPARATOR)[0]
        network_graph.add_node(candidate, type=candidate_type, flags=0)
        network_graph.add_edge(att_neighbor, candidate)
    return network_graph

68 

69 

def build_network_from_entities(
    graph,
    entity_to_community,
    integrated_flags: pl.DataFrame | None = None,
    trimmed_attributes: list[tuple[str, int]] | None = None,
    inferred_links: Any | None = None,
    selected_nodes: list[str] | None = None,
) -> nx.Graph:
    """Build the network graph surrounding a set of selected nodes.

    Each selected node is added with type ``ENTITY_LABEL``. Its neighbours in
    ``graph`` (plus any inferred links) are added as well: neighbours whose
    name starts with ``ENTITY_LABEL`` become entity nodes, all others become
    attribute nodes typed by the text before ``ATTRIBUTE_VALUE_SEPARATOR``.
    For attribute neighbours, their own non-entity neighbours are added too.

    Args:
        graph: Source graph supplying neighbours.
        entity_to_community: Mapping from node to community, stored on entity
            nodes as the ``network`` attribute.
        integrated_flags: Optional flags frame; when non-empty it is applied
            through ``_integrate_flags``.
        trimmed_attributes: Optional sequence whose items' first element is a
            node name to exclude. ``None`` means nothing is excluded.
        inferred_links: Optional mapping from a node to extra linked nodes.
        selected_nodes: Nodes to build the network around. ``None`` means no
            nodes, which yields an empty graph.

    Returns:
        The newly built graph.
    """
    network_graph = nx.Graph()

    # These parameters default to None; treat that as "empty" rather than
    # failing with a TypeError when iterating.
    nodes = selected_nodes if selected_nodes is not None else []
    if trimmed_attributes is None:
        trimmed_nodeset = set()
    else:
        trimmed_nodeset = {t[0] for t in trimmed_attributes}

    if inferred_links is None:
        inferred_links = {}

    if integrated_flags is None:
        integrated_flags = pl.DataFrame()

    for node in nodes:
        n_c = str(entity_to_community[node]) if node in entity_to_community else ""
        network_graph.add_node(node, type=ENTITY_LABEL, network=n_c, flags=0)
        ent_neighbors = set(graph.neighbors(node))
        if node in inferred_links:
            ent_neighbors = ent_neighbors.union(inferred_links[node])
        # NOTE(review): when `node` appears in another key's inferred links
        # this adds `node` itself, not that key; the self-reference is then
        # skipped below for entity-labelled nodes. Confirm whether the key
        # `i` was the intended addition.
        for i in inferred_links:
            if node in inferred_links[i]:
                ent_neighbors = ent_neighbors.union([node])

        ent_neighbors_not_trimmed = [
            ent_neighbor
            for ent_neighbor in ent_neighbors
            if ent_neighbor not in trimmed_nodeset
        ]

        for ent_neighbor in ent_neighbors_not_trimmed:
            if ent_neighbor.startswith(ENTITY_LABEL):
                if node != ent_neighbor:
                    en_c = entity_to_community.get(ent_neighbor, "")
                    network_graph.add_node(
                        ent_neighbor, type=ENTITY_LABEL, network=en_c
                    )
                    network_graph.add_edge(node, ent_neighbor)
            else:
                network_graph.add_node(
                    ent_neighbor,
                    type=ent_neighbor.split(ATTRIBUTE_VALUE_SEPARATOR)[0],
                    flags=0,
                )
                network_graph.add_edge(node, ent_neighbor)
                att_neighbors = set(graph.neighbors(ent_neighbor))
                if ent_neighbor in inferred_links:
                    att_neighbors = att_neighbors.union(inferred_links[ent_neighbor])
                att_neighbors_not_trimmed = [
                    att_neighbor
                    for att_neighbor in att_neighbors
                    if att_neighbor not in trimmed_nodeset
                    and not att_neighbor.startswith(ENTITY_LABEL)
                ]
                for att_neighbor in att_neighbors_not_trimmed:
                    network_graph.add_node(
                        att_neighbor,
                        type=att_neighbor.split(ATTRIBUTE_VALUE_SEPARATOR)[0],
                        flags=0,
                    )
                    network_graph = _build_fuzzy_neighbors(
                        graph,
                        network_graph,
                        att_neighbor,
                        trimmed_nodeset,
                        inferred_links,
                    )

    # Flags are applied last so the flags=0 placeholders above get overwritten.
    if len(integrated_flags) > 0:
        network_graph = _integrate_flags(network_graph, integrated_flags)
    return network_graph

148 

149 

def _merge_condition(x: str, y: str) -> bool:
    """Decide whether two node names should be merged.

    Each name is split on ``LIST_SEPARATOR`` into parts and each part on
    ``ATTRIBUTE_VALUE_SEPARATOR`` into fields. The names match when some part
    of ``x`` and some part of ``y`` agree on their first field, or on their
    second field.
    """
    x_fields = [
        part.split(ATTRIBUTE_VALUE_SEPARATOR) for part in set(x.split(LIST_SEPARATOR))
    ]
    y_fields = [
        part.split(ATTRIBUTE_VALUE_SEPARATOR) for part in set(y.split(LIST_SEPARATOR))
    ]
    # Check every pair on field 0 before moving on to field 1.
    for position in (0, 1):
        for left in x_fields:
            for right in y_fields:
                if left[position] == right[position]:
                    return True
    return False

163 

164 

def _merge_node_list(graph: nx.Graph, merge_list: list[str]) -> nx.Graph:
    """Return a copy of ``graph`` with the listed nodes collapsed into one.

    The merged node is named by joining the sorted node names with
    ``LIST_SEPARATOR``; its ``type`` is the sorted member types joined the
    same way and its ``flags`` is the largest member ``flags`` value (missing
    values count as 0). Edges from members to nodes outside the list are
    re-attached to the merged node, and the members are removed.

    Args:
        graph: Graph to merge in; it is copied, never modified.
        merge_list: Names of the nodes to merge. Duplicates are ignored.

    Returns:
        The modified copy. When fewer than two distinct nodes are given there
        is nothing to merge and the copy is returned unchanged.
    """
    graph = graph.copy()

    # De-duplicate while keeping order; a repeated name would otherwise be
    # removed twice.
    members = list(dict.fromkeys(merge_list))
    if len(members) < 2:
        # With a single member the merged name equals the member itself, so
        # the removal below would delete the node instead of merging it.
        return graph

    member_set = set(members)
    merged_node = LIST_SEPARATOR.join(sorted(members))
    merged_type = LIST_SEPARATOR.join(sorted(graph.nodes[n]["type"] for n in members))
    merged_risk = max(graph.nodes[n].get("flags", 0) for n in members)
    graph.add_node(merged_node, type=merged_type, flags=merged_risk)
    for n in members:
        # Snapshot the neighbours: edges are added while we walk them.
        for nn in list(graph.neighbors(n)):
            if nn not in member_set:
                graph.add_edge(merged_node, nn)
        graph.remove_node(n)
    return graph

181 

182 

def _merge_nodes(graph: nx.Graph, should_merge=_merge_condition) -> nx.Graph:
    """Merge each node with the neighbours that ``should_merge`` accepts.

    Nodes are visited in the order they had when the call started. A node that
    an earlier merge has already removed is skipped.

    Args:
        graph: Graph to process.
        should_merge: Callable taking ``(node, neighbour)`` and returning a
            truthy value when the two should be merged.

    Returns:
        The input graph when nothing was merged, otherwise the graph returned
        by the last ``_merge_node_list`` call.
    """
    # Snapshot the node names: merging replaces the graph as we go.
    for candidate in tuple(graph.nodes()):
        if candidate not in graph.nodes():
            continue
        group = [candidate]
        group.extend(
            neighbour
            for neighbour in list(graph.neighbors(candidate))
            if neighbour in graph.nodes() and should_merge(candidate, neighbour)
        )
        if len(group) >= 2:
            graph = _merge_node_list(graph, group)

    return graph

199 

200 

def simplify_entities_graph(entities_graph: nx.Graph) -> nx.Graph:
    """Return a simplified copy of ``entities_graph``.

    Nodes with degree below 2 whose name does not start with ``ENTITY_LABEL``
    are removed, the remaining nodes are merged with ``_merge_nodes``, and the
    same removal pass is run again on the result.
    """

    def _drop_low_degree_attributes(target: nx.Graph) -> None:
        # Degrees are re-read per node, so earlier removals affect later ones.
        for name in list(target.nodes()):
            if target.degree(name) >= 2 or name.startswith(ENTITY_LABEL):
                continue
            target.remove_node(name)

    simplified = entities_graph.copy()
    _drop_low_degree_attributes(simplified)

    simplified = _merge_nodes(simplified)

    # Merging can leave new low-degree attribute nodes behind.
    _drop_low_degree_attributes(simplified)
    return simplified

216 

217 

def hsl_to_hex(hue: float, saturation: float, lightness: float) -> str:
    """Convert an HSL colour to a ``#rrggbb`` hex string.

    Args:
        hue: Hue in degrees (divided by 360 before conversion).
        saturation: Saturation as a percentage (divided by 100).
        lightness: Lightness as a percentage (divided by 100).

    Returns:
        A seven-character lowercase hex colour. Each channel is truncated to
        an integer and clamped to 0-255, so percentages outside 0-100 still
        give a well-formed colour.
    """
    # colorsys expects (hue, lightness, saturation) - note the argument order.
    rgb = colorsys.hls_to_rgb(hue / 360, lightness / 100, saturation / 100)
    # Without the clamp an out-of-range input yields negative or three-digit
    # channels and therefore an invalid hex string.
    channels = (min(255, max(0, int(c * 255))) for c in rgb)
    return "#{:02x}{:02x}{:02x}".format(*channels)

221 

222 

def get_type_color(node_type: str, is_flagged: bool, attribute_types: list[Any]) -> str:
    """Pick the hex colour for a node.

    Flagged nodes always use hue 0. Otherwise the hue is derived from the
    position of ``node_type`` in ``attribute_types``, starting at 230 degrees
    and shifted away from the band within 35 degrees of 0/360.

    Args:
        node_type: Type to colour; must be present in ``attribute_types``.
        is_flagged: Whether the node carries flags.
        attribute_types: Ordered list of known types.

    Returns:
        The colour produced by ``hsl_to_hex`` with saturation 70 and
        lightness 80.
    """
    saturation, lightness = 70, 80
    if is_flagged:
        return str(hsl_to_hex(0, saturation, lightness))

    start = 230
    reserve = 35
    fraction = attribute_types.index(node_type) / len(attribute_types)
    hue = (start + fraction * (360 - 2 * reserve)) % 360
    # avoid reds
    if hue < reserve:
        hue += 2 * reserve
    if hue > 360 - reserve:
        hue = (hue + 2 * reserve) % 360
    return str(hsl_to_hex(hue, saturation, lightness))

242 

243 

def get_entity_graph(
    network_entities_graph: nx.Graph, selected: str, attribute_types: list[str]
) -> tuple[list, list]:
    """Build the node and edge dictionaries used to draw an entity network.

    Only nodes that take part in at least one edge are emitted, in order of
    first appearance in the graph's edges.

    Args:
        network_entities_graph: Graph to render; node ``flags`` attributes are
            read when present.
        selected: Name of the node to draw at the largest size.
        attribute_types: Ordered list of types passed to ``get_type_color``.

    Returns:
        ``(nodes, edges)``. Each node is a dict with ``title``, ``id``,
        ``label``, ``size``, ``color`` and ``font`` keys; each edge is a dict
        with ``source``, ``target``, ``color`` and ``size`` keys. Both lists
        are empty when the graph has no edges.
    """
    nodes = []
    edges = []

    edge_list = list(network_entities_graph.edges())
    if not edge_list:
        return nodes, edges

    # dict.fromkeys de-duplicates the endpoints while keeping a stable order.
    linked_nodes = dict.fromkeys(
        endpoint for edge in edge_list for endpoint in edge
    )
    for node in linked_nodes:
        size = 20 if node == selected else 12 if node.startswith(ENTITY_LABEL) else 8
        # Shift the label by an amount that grows with the node size.
        vadjust = -size - 10

        # A merged node holds several "type<sep>value" parts in one name.
        parts = [p.split(ATTRIBUTE_VALUE_SEPARATOR) for p in node.split(LIST_SEPARATOR)]
        atts = list(dict.fromkeys(p[0] for p in parts))

        flags = network_entities_graph.nodes[node].get("flags", 0)
        color = get_type_color(
            atts[0],
            flags > 0,
            attribute_types,
        )

        vals = list(dict.fromkeys(p[1] for p in parts if len(p) > 1))
        label = "\n".join(vals) + "\n(" + LIST_SEPARATOR.join(atts) + ")"

        nodes.append(
            {
                "title": node + f"\nFlags: {flags}",
                "id": node,
                "label": label,
                "size": size,
                "color": color,
                "font": {"vadjust": vadjust, "size": 5},
            }
        )

    for row in edge_list:
        edges.append(
            {
                "source": row[0],
                "target": row[1],
                "color": "mediumgray",
                "size": 1,
            }
        )
    return nodes, edges