Coverage for intelligence_toolkit/detect_entity_networks/identify_networks.py: 88%

152 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4from collections import defaultdict 

5from typing import Any 

6 

7import networkx as nx 

8import polars as pl 

9from graspologic.partition import hierarchical_leiden 

10 

11from intelligence_toolkit.detect_entity_networks.config import ( 

12 DEFAULT_MAX_ATTRIBUTE_DEGREE, 

13 ENTITY_LABEL, 

14) 

15from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR 

16 

17 

18# ruff: noqa 

def trim_nodeset(
    graph: nx.Graph,
    max_attribute_degree: int,
    additional_trimmed_attributes: list[str] | None = None,
) -> tuple[set, set[Any | str]]:
    """Collect the non-entity nodes whose degree exceeds the allowed maximum.

    Nodes whose name starts with ``ENTITY_LABEL`` are never selected by the
    degree check.

    Args:
        graph: Graph to inspect; node names are expected to be strings.
        max_attribute_degree: A non-entity node is trimmed when its degree is
            strictly greater than this value.
        additional_trimmed_attributes: Extra node names to include in the
            returned node set regardless of their degree.

    Returns:
        A pair ``(trimmed_degrees, trimmed_nodes)``. ``trimmed_degrees`` is a
        set of ``(node, degree)`` tuples for the over-connected nodes;
        ``trimmed_nodes`` is the set of those node names plus every name in
        ``additional_trimmed_attributes``.
    """
    extra_names = (
        [] if additional_trimmed_attributes is None else additional_trimmed_attributes
    )

    over_connected = {
        (name, node_degree)
        for name, node_degree in graph.degree()
        if not name.startswith(ENTITY_LABEL) and node_degree > max_attribute_degree
    }

    trimmed_names = {name for name, _ in over_connected}
    trimmed_names.update(extra_names)
    return over_connected, trimmed_names

34 

35 

def get_entity_neighbors(overall_graph, inferred_links, trimmed_nodeset, node) -> list:
    """Return the sorted neighbours of ``node``, including inferred links.

    Args:
        overall_graph: Graph object exposing ``nodes()`` and ``neighbors()``.
        inferred_links: Mapping from a node to the collection of nodes it is
            linked to by inference; may be empty or ``None``.
        trimmed_nodeset: Nodes to exclude from the result.
        node: Node whose neighbours are requested.

    Returns:
        A sorted list of the graph neighbours of ``node`` plus every node
        linked to it through ``inferred_links`` (in either direction), with
        trimmed nodes and ``node`` itself removed. An empty list is returned
        when the graph has no nodes.

    Raises:
        ValueError: If the graph is non-empty and does not contain ``node``.
    """
    graph_nodes = overall_graph.nodes()
    if len(graph_nodes) == 0:
        return []
    if node not in graph_nodes:
        raise ValueError(f"Node {node} not in graph")

    linked = set(overall_graph.neighbors(node))

    if inferred_links:
        # Links recorded under this node's own key.
        if node in inferred_links:
            linked.update(inferred_links[node])
        # Links recorded under another key that point back at this node.
        linked.update(
            source for source in inferred_links if node in inferred_links[source]
        )

    linked.difference_update(trimmed_nodeset)
    linked.discard(node)
    return sorted(linked)

55 

56 

def neighbor_is_valid(
    neighbor_node, supporting_attribute_types, trimmed_nodeset
) -> bool:
    """Tell whether a neighbouring node may be used when projecting links.

    Args:
        neighbor_node: Node name; the text before the first
            ``ATTRIBUTE_VALUE_SEPARATOR`` is treated as its type.
        supporting_attribute_types: Types that must not be used as links.
        trimmed_nodeset: Nodes that have been trimmed from consideration.

    Returns:
        ``True`` when the node is non-empty, its type is not listed in
        ``supporting_attribute_types``, and it is not in ``trimmed_nodeset``;
        ``False`` otherwise.
    """
    if not neighbor_node:
        return False

    attribute_type = neighbor_node.split(ATTRIBUTE_VALUE_SEPARATOR)[0]
    if attribute_type in supporting_attribute_types:
        return False
    return neighbor_node not in trimmed_nodeset

66 

67 

def _linked_neighbors(graph: nx.Graph, inferred_links, node) -> set:
    """Return the graph neighbours of ``node`` plus its inferred links."""
    linked = set(graph.neighbors(node))
    if node in inferred_links:
        linked = linked.union(inferred_links[node])
    return linked


def project_entity_graph(
    overall_graph: nx.Graph,
    trimmed_nodeset: set,
    inferred_links: dict[str, set[str]] | None = None,
    supporting_attribute_types: list[str] | None = None,
) -> nx.Graph:
    """Project the entity/attribute graph onto an entity-to-entity graph.

    Two entities are connected in the result when they are direct neighbours
    (including inferred links), when they share a valid attribute node, or
    when their attribute nodes are themselves linked to each other (the
    "fuzzy" attribute link). Validity of intermediate nodes is decided by
    ``neighbor_is_valid``.

    Args:
        overall_graph: Graph containing entity nodes (names starting with
            ``ENTITY_LABEL``) and attribute nodes.
        trimmed_nodeset: Nodes that must not be traversed.
        inferred_links: Mapping from a node to the nodes it is linked to by
            inference. ``None`` means no inferred links.
        supporting_attribute_types: Attribute types that must not be used to
            connect entities. ``None`` means no restriction.

    Returns:
        A new graph whose edges connect entity nodes. Entities that end up
        with no connection are not added to it.
    """
    if supporting_attribute_types is None:
        supporting_attribute_types = []
    if inferred_links is None:
        # Use an empty mapping so the fallback matches the documented type.
        inferred_links = {}

    projected = nx.Graph()
    entity_nodes = [
        node for node in overall_graph.nodes() if node.startswith(ENTITY_LABEL)
    ]

    for node in entity_nodes:
        direct_neighbors = get_entity_neighbors(
            overall_graph, inferred_links, trimmed_nodeset, node
        )
        for ent_neighbor in direct_neighbors:
            # Entity directly adjacent to this entity.
            if ent_neighbor.startswith(ENTITY_LABEL):
                projected.add_edge(node, ent_neighbor)
                continue
            if not neighbor_is_valid(
                ent_neighbor, supporting_attribute_types, trimmed_nodeset
            ):
                continue

            # One hop further: whatever hangs off the shared attribute.
            for att_neighbor in _linked_neighbors(
                overall_graph, inferred_links, ent_neighbor
            ):
                if not neighbor_is_valid(
                    att_neighbor, supporting_attribute_types, trimmed_nodeset
                ):
                    continue
                if att_neighbor.startswith(ENTITY_LABEL):
                    if node != att_neighbor:
                        projected.add_edge(node, att_neighbor)
                    continue

                # Fuzzy attribute link: attribute -> attribute -> entity.
                for fuzzy_att_neighbor in _linked_neighbors(
                    overall_graph, inferred_links, att_neighbor
                ):
                    if (
                        neighbor_is_valid(
                            fuzzy_att_neighbor,
                            supporting_attribute_types,
                            trimmed_nodeset,
                        )
                        and fuzzy_att_neighbor.startswith(ENTITY_LABEL)
                        and node != fuzzy_att_neighbor
                    ):
                        projected.add_edge(node, fuzzy_att_neighbor)
    return projected

127 

128 

def get_subgraph(
    entity_graph: nx.Graph,
    nodes: list[str | int],
    random_seed: int = 42,
    max_network_entities: int = 20,
) -> tuple[list, dict]:
    """Split the subgraph induced by ``nodes`` into communities.

    The clustering is produced by ``hierarchical_leiden`` and its final level
    is used.

    Args:
        entity_graph: Graph that contains ``nodes``.
        nodes: Nodes that induce the subgraph to partition.
        random_seed: Seed forwarded to ``hierarchical_leiden``.
        max_network_entities: Forwarded to ``hierarchical_leiden`` as
            ``max_cluster_size``.

    Returns:
        A pair ``(community_nodes, entity_to_community)``: a list with one
        list of nodes per community, and a mapping from each node to the
        position of its community in that list. Both are empty when ``nodes``
        or the graph is empty.
    """
    if not nodes or not entity_graph.nodes():
        return [], {}

    clustering = hierarchical_leiden(
        nx.subgraph(entity_graph, nodes),
        resolution=1.0,
        random_seed=random_seed,
        max_cluster_size=max_network_entities,
    ).final_level_hierarchical_clustering()

    # Group nodes by the cluster they were assigned to.
    members_by_cluster = defaultdict(set)
    for member, cluster_id in clustering.items():
        members_by_cluster[cluster_id].add(member)

    community_nodes = [list(members) for members in members_by_cluster.values()]
    entity_to_community = {
        member: position
        for position, members in enumerate(community_nodes)
        for member in members
    }
    return community_nodes, entity_to_community

161 

162 

def get_community_nodes(
    entity_graph: nx.Graph,
    max_network_entities: int | None = 20,
) -> tuple[list, dict]:
    """Partition the entity graph into communities of bounded size.

    Connected components are processed from largest to smallest. A component
    with more nodes than ``max_network_entities`` is split further with
    ``get_subgraph``; any other component becomes one community as is.

    Args:
        entity_graph: Entity-to-entity graph.
        max_network_entities: Largest component size kept as a single
            community. ``None`` disables splitting.

    Returns:
        A pair ``(community_nodes, entity_to_community_ix)``: the list of
        communities (a set of nodes for an unsplit component, a list of nodes
        for each community produced by ``get_subgraph``) and a mapping from
        each node to the index of its community in that list.
    """
    components = sorted(
        nx.components.connected_components(entity_graph),
        key=len,
        reverse=True,
    )

    community_nodes = []
    entity_to_community_ix = {}
    for component in components:
        needs_split = (
            max_network_entities is not None and len(component) > max_network_entities
        )
        if needs_split:
            # Pass the size cap by keyword: the third positional parameter of
            # get_subgraph is the random seed, not the cluster size.
            sub_communities, sub_index = get_subgraph(
                entity_graph,
                component,
                max_network_entities=max_network_entities,
            )
            # get_subgraph numbers its communities from 0, so shift them past
            # the communities already collected.
            offset = len(community_nodes)
            community_nodes.extend(sub_communities)
            entity_to_community_ix.update(
                {entity: offset + local_ix for entity, local_ix in sub_index.items()}
            )
        else:
            community_nodes.append(component)
            community_ix = len(community_nodes) - 1
            entity_to_community_ix.update(dict.fromkeys(component, community_ix))
    return community_nodes, entity_to_community_ix

192 

193 

def build_networks(
    main_graph: nx.Graph,
    trimmed_nodes: set,
    inferred_links: set | None = None,
    supporting_attribute_types: list[str] | None = None,
    max_network_entities: int | None = 20,
) -> tuple[list, dict]:
    """Project the main graph onto entities and split it into communities.

    Args:
        main_graph: Graph passed to ``project_entity_graph``.
        trimmed_nodes: Nodes excluded during the projection.
        inferred_links: Inferred links, forwarded unchanged to
            ``project_entity_graph``.
        supporting_attribute_types: Attribute types, forwarded unchanged to
            ``project_entity_graph``.
        max_network_entities: Size limit forwarded to ``get_community_nodes``.

    Returns:
        The ``(community_nodes, entity_to_community)`` pair produced by
        ``get_community_nodes`` for the projected graph.
    """
    entity_graph = project_entity_graph(
        main_graph, trimmed_nodes, inferred_links, supporting_attribute_types
    )
    community_nodes, entity_to_community = get_community_nodes(
        entity_graph, max_network_entities
    )
    return community_nodes, entity_to_community

214 

215 

def get_integrated_flags(
    integrated_flags: pl.DataFrame,
    entities: list[str],
    inferred_links: dict[str, set[str]] | None = None,
) -> tuple[Any, int, float, float, int]:
    """Summarise the flags of a community, including inferred-linked entities.

    Entities reachable from the community through ``inferred_links`` (in
    either direction) that are not already part of it are counted as extra
    members, each exactly once.

    Args:
        integrated_flags: Frame with a ``qualified_entity`` column and a
            ``count`` column.
        entities: Entities that make up the community.
        inferred_links: Mapping from an entity to the entities it is linked
            to by inference. ``None`` or empty means no links.

    Returns:
        A tuple ``(community_flags, flagged, flagged_per_unflagged,
        flags_per_entity, total_entities)``: the summed flag count, the
        number of entities with a positive flag count, the ratio of flagged to
        unflagged entities rounded to two decimals (0 when nothing is
        unflagged), the flags per entity rounded to two decimals, and the
        number of entities counted. When ``integrated_flags`` is empty the
        result is ``(0, 0, 0, 0, len(entities))``.
    """
    total_entities = len(entities)
    if integrated_flags.is_empty():
        return 0, 0, 0, 0, total_entities

    # A set gives O(1) membership tests and works whether the community was
    # supplied as a list or as a set.
    community = set(entities)

    flags_df = integrated_flags.filter(
        pl.col("qualified_entity").is_in(list(entities))
    )
    community_flags = flags_df.get_column("count").sum()
    flagged = len(flags_df.filter(pl.col("count") > 0)["qualified_entity"].unique())

    # Gather every entity tied to the community by an inferred link, looking
    # at each link once: forward (community member is the key) and reverse
    # (community member appears among the key's targets).
    linked = set()
    if inferred_links:
        for source, targets in inferred_links.items():
            if source in community:
                linked.update(targets)
            if not community.isdisjoint(targets):
                linked.add(source)

    for entity in linked - community:
        flags = integrated_flags.filter(pl.col("qualified_entity") == entity)[
            "count"
        ].sum()
        community_flags += flags
        total_entities += 1
        if flags > 0:
            flagged += 1

    unflagged = total_entities - flagged
    flagged_per_unflagged = round(flagged / unflagged if unflagged > 0 else 0, 2)
    flags_per_entity = round(
        community_flags / total_entities if total_entities > 0 else 0, 2
    )

    return (
        community_flags,
        flagged,
        flagged_per_unflagged,
        flags_per_entity,
        total_entities,
    )

280 

281 

def build_entity_records(
    community_nodes: list[str],
    integrated_flags: pl.DataFrame | None = None,
    inferred_links: defaultdict[set] | None = None,
) -> list[tuple[str, int, int, int, Any, int, float, float]]:
    """Build one record per entity with its own and its community's flag stats.

    Args:
        community_nodes: Communities, each an iterable of entity node names of
            the form ``<label><ATTRIBUTE_VALUE_SEPARATOR><value>``.
        integrated_flags: Frame with ``qualified_entity`` and ``count``
            columns. ``None`` is treated as an empty frame.
        inferred_links: Inferred links forwarded to ``get_integrated_flags``.

    Returns:
        A list of tuples ``(entity_value, entity_flags, community_index,
        total_entities, community_flags, flagged, flags_per_entity,
        flagged_per_unflagged)``, one per entity of every community.

    Raises:
        IndexError: If an entity name does not contain
            ``ATTRIBUTE_VALUE_SEPARATOR``.
    """
    if integrated_flags is None:
        integrated_flags = pl.DataFrame()

    # The frame does not change inside the loops, so check emptiness once.
    has_flags = not integrated_flags.is_empty()

    entity_records = []
    for ix, entities in enumerate(community_nodes):
        (
            community_flags,
            flagged,
            flagged_per_unflagged,
            flags_per_entity,
            total_entities,
        ) = get_integrated_flags(integrated_flags, entities, inferred_links)

        for n in entities:  # entities from a network
            flags = 0
            if has_flags:
                flags = integrated_flags.filter(pl.col("qualified_entity") == n)[
                    "count"
                ].sum()

            # Split only on the first separator so an entity value that itself
            # contains the separator is kept whole instead of being truncated.
            entity_value = n.split(ATTRIBUTE_VALUE_SEPARATOR, 1)[1]

            entity_records.append(
                (
                    entity_value,
                    flags,
                    ix,
                    total_entities,
                    community_flags,
                    flagged,
                    flags_per_entity,
                    flagged_per_unflagged,
                )
            )

    return entity_records