Coverage for intelligence_toolkit/graph/graph_fusion_encoder_embedding.py: 100%
104 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4import numpy as np
5import networkx as nx
6from collections import defaultdict
7from sklearn.decomposition import TruncatedSVD
8from sklearn.preprocessing import normalize
9from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR
10from intelligence_toolkit.graph.graph_encoder_embed import GraphEncoderEmbed
def _get_edge_list(graph, node_list):
    """Return the graph's edges as ``[source_ix, target_ix, weight]`` triples.

    Node ids are translated to their position in ``node_list``. Edges with an
    endpoint that is not in ``node_list`` are dropped.

    Args:
        graph: Object whose ``edges(data="weight")`` yields
            ``(source, target, weight)`` tuples.
        node_list: Sequence of hashable node ids defining the index space.

    Returns:
        A list of ``[source_ix, target_ix, weight]`` lists, in the order the
        graph yields its edges. ``weight`` is passed through unchanged.
    """
    node_to_ix = {node: ix for ix, node in enumerate(node_list)}
    return [
        [node_to_ix[source], node_to_ix[target], weight]
        for source, target, weight in graph.edges(data="weight")
        # Test membership against the dict (O(1)) rather than the list, which
        # cost O(len(node_list)) per endpoint.
        if source in node_to_ix and target in node_to_ix
    ]
def _generate_embeddings_for_period(
    graph, node_list, node_to_label, correlation, diaga, laplacian, max_level
):
    """Generate one joint, multi-level embedding per node for a single period.

    For every hierarchy level ``0..max_level`` the graph is embedded with
    ``GraphEncoderEmbed`` using that level's node labels. Each level's matrix
    is brought to the width of the level-0 embedding (via ``TruncatedSVD``
    when its width differs), row-normalised with ``normalize``, and the
    per-level matrices are concatenated column-wise.

    ``node_to_label`` must be complete: every node needs a label at every
    level. Where a node has no community of its own at a level, the caller is
    expected to have filled in the label of its nearest populated ancestor.
    A node whose label at ``level`` equals its label at ``level - 1`` is
    treated as not existing natively at ``level`` and its slice for that
    level is zeroed.

    Args:
        graph: Graph for the period; see ``_get_edge_list`` for what is read.
        node_list: Node ids defining row order (assumed unique).
        node_to_label: Mapping ``node -> {level: label}``.
        correlation: Forwarded to the encoder as ``Correlation``.
        diaga: Forwarded to the encoder as ``DiagA``.
        laplacian: Forwarded to the encoder as ``Laplacian``.
        max_level: Deepest hierarchy level to embed (inclusive).

    Returns:
        A float64 array with one row per entry of ``node_list`` and
        ``(max_level + 1) * embedding_length`` columns, where
        ``embedding_length`` is the width of the level-0 embedding.
    """
    edge_list = _get_edge_list(graph, node_list)
    num_nodes = len(node_list)
    levels = range(max_level + 1)

    # One encoder run per hierarchy level, labelled with that level's clusters.
    level_embeddings = {}
    for level in levels:
        labels = np.array(
            [
                # Nodes without any label entry are marked -1 for the encoder.
                node_to_label[node][level] if node in node_to_label else -1
                for node in node_list
            ]
        ).reshape((num_nodes, 1))
        Z, _ = GraphEncoderEmbed().run(
            edge_list,
            labels,
            num_nodes,
            EdgeList=True,
            Laplacian=laplacian,
            DiagA=diaga,
            Correlation=correlation,
        )
        level_embeddings[level] = Z

    # The root-level width is the common dimensionality for every level.
    # TODO: Experiment with different balances of each layer (intuition is
    # that higher levels in the hierarchy deserve more weight).
    embedding_length = level_embeddings[0].shape[1]

    level_blocks = []
    for level in levels:
        # Densify once per level; the encoder output exposes ``toarray()``.
        dense = level_embeddings[level].toarray()
        if dense.shape[1] != embedding_length:
            # Ideally this would be a PCA, but that doesn't scale, so a
            # truncated SVD is used to reach the common width.
            # NOTE(review): no random_state is passed, so results may vary
            # between runs if the solver is randomised - confirm.
            tsvd = TruncatedSVD(n_components=embedding_length)
            tsvd.fit(dense)
            dense = tsvd.transform(dense)
        vectors = normalize(dense)

        if level > 0:
            # A node exists natively at this level only when its label differs
            # from its parent-level label; otherwise its slice is zeroed.
            exists_here = np.array(
                [
                    node_to_label[node][level - 1] != node_to_label[node][level]
                    for node in node_list
                ],
                dtype=bool,
            )
            vectors = np.where(exists_here[:, None], vectors, 0.0)
        level_blocks.append(vectors)

    # Concatenate all levels side by side: one long vector per node.
    return np.hstack(level_blocks).astype(np.float64, copy=False)
def _cosine_distance(x, y):
    """Return the cosine distance ``1 - cos(x, y)`` between two vectors.

    Args:
        x: First vector (array-like).
        y: Second vector (array-like), same length as ``x``.

    Returns:
        The distance clamped to ``[0, 2]``, or ``np.inf`` when the product of
        the two norms is not positive (e.g. either vector is all zeros).
    """
    den = np.linalg.norm(x) * np.linalg.norm(y)
    # ``not den > 0`` (rather than ``den <= 0``) keeps a NaN norm on the
    # infinite-distance path, as before.
    if not den > 0:
        return np.inf
    # Floating-point rounding can push the similarity marginally outside
    # [-1, 1]; clamp so the distance is never negative or above 2.
    return np.clip(1 - (np.dot(x, y) / den), 0.0, 2.0)
def generate_graph_fusion_encoder_embedding(
    period_to_graph,
    node_to_label,
    correlation,
    diaga,
    laplacian,
    max_level,
    callbacks=None,
):
    """Generate embeddings for all periods and calculate centroids.

    All-time centroids are encoded as 'ALL' and prior centroids are encoded
    as '<'+period.

    Args:
        period_to_graph: Mapping ``period -> graph``. Period labels must be
            strings that sort chronologically.
        node_to_label: Mapping ``node -> {level: label}``; its sorted keys
            define the embedded node set.
        correlation: Forwarded to the per-period embedding.
        diaga: Forwarded to the per-period embedding.
        laplacian: Forwarded to the per-period embedding.
        max_level: Deepest hierarchy level to embed (inclusive).
        callbacks: Optional iterable of objects exposing
            ``on_batch_change(current, total)``, called once per node.

    Returns:
        A tuple ``(node_to_period_to_pos, node_to_period_to_shift)``:

        * ``node_to_period_to_pos[node]`` maps each period to the node's
          embedding, ``"ALL"`` to the mean over all periods, and
          ``"<" + period`` to the mean over the periods sorted before
          ``period`` (absent for the first period).
        * ``node_to_period_to_shift[node]`` maps each period to the cosine
          distance from the all-time centroid, and ``"<" + period`` to the
          cosine distance from the prior centroid.
    """
    # Avoid a shared mutable default argument.
    callbacks = [] if callbacks is None else callbacks
    node_list = sorted(node_to_label.keys())
    # Inner mappings are plain dicts: the previous ``np.array`` factory could
    # never produce a default (calling it without arguments raises TypeError).
    node_to_period_to_pos = defaultdict(dict)
    node_to_period_to_shift = defaultdict(dict)

    for period, graph in period_to_graph.items():
        period_embedding = _generate_embeddings_for_period(
            graph, node_list, node_to_label, correlation, diaga, laplacian, max_level
        )
        # Row i of the embedding belongs to node_list[i].
        for node, position in zip(node_list, period_embedding):
            node_to_period_to_pos[node][period] = position

    total_nodes = len(node_to_period_to_pos)
    for ix, (node, period_to_pos) in enumerate(node_to_period_to_pos.items()):
        for callback in callbacks:
            callback.on_batch_change(ix + 1, total_nodes)

        # Snapshot the real periods BEFORE adding derived keys. Previously
        # "ALL" was inserted first and then iterated as if it were a period,
        # producing spurious "<ALL"/"ALL" entries and, when "ALL" sorted ahead
        # of real labels, contaminating the prior centroids.
        all_positions = list(period_to_pos.values())
        sorted_periods = sorted(period_to_pos)

        centroid = np.mean(all_positions, axis=0)
        period_to_pos["ALL"] = centroid

        period_to_shift = node_to_period_to_shift[node]
        prior_positions = []
        for period in sorted_periods:
            position = period_to_pos[period]
            if prior_positions:
                # Encodes prior centroid position and shift.
                prior_centroid = np.mean(prior_positions, axis=0)
                period_to_pos["<" + period] = prior_centroid
                period_to_shift["<" + period] = _cosine_distance(
                    position, prior_centroid
                )
            prior_positions.append(position)
            period_to_shift[period] = _cosine_distance(position, centroid)

    return node_to_period_to_pos, node_to_period_to_shift
def is_converging_pair(period, n1, n2, node_to_period_to_pos, all_time=True):
    """Tell whether two nodes are closer in ``period`` than their centroids are.

    Args:
        period: Period label (string) to test.
        n1: First node id.
        n2: Second node id.
        node_to_period_to_pos: Mapping ``node -> {key: vector}`` holding, per
            node, the embedding for each period, the all-time centroid under
            ``"ALL"`` and prior centroids under ``"<" + period``.
        all_time: Compare against the ``"ALL"`` centroids when true, otherwise
            against the ``"<" + period`` centroids.

    Returns:
        ``True`` when the cosine distance between the two nodes' positions in
        ``period`` is strictly smaller than the cosine distance between their
        centroids. ``False`` otherwise, including when either node, either
        centroid or either period position is missing.
    """
    if n1 not in node_to_period_to_pos or n2 not in node_to_period_to_pos:
        return False

    centroid_key = "ALL" if all_time else "<" + period
    positions_1 = node_to_period_to_pos[n1]
    positions_2 = node_to_period_to_pos[n2]

    # ``.get`` never triggers a defaultdict factory, so an absent entry (e.g.
    # no prior centroid for the first period) yields False instead of raising.
    c1 = positions_1.get(centroid_key)
    c2 = positions_2.get(centroid_key)
    p1 = positions_1.get(period)
    p2 = positions_2.get(period)
    if c1 is None or c2 is None or p1 is None or p2 is None:
        return False

    centroid_dist = _cosine_distance(c1, c2)
    # Compare the full period vectors. The previous ``[period][1]`` indexing
    # reduced each position to its second component, so the "distance" was
    # computed between two scalars while the centroids used whole vectors.
    period_dist = _cosine_distance(p1, p2)
    return bool(period_dist < centroid_dist)
def create_concept_to_community_hierarchy(hierarchical_communities):
    """Build a complete ``node -> {level: cluster}`` mapping from community rows.

    Args:
        hierarchical_communities: Iterable of records exposing ``node``,
            ``level`` (convertible with ``int``) and ``cluster`` attributes.

    Returns:
        A tuple ``(concept_to_community_hierarchy, max_cluster_per_level,
        max_level)``:

        * ``concept_to_community_hierarchy[node][level]`` is the node's
          cluster at ``level`` for every level ``0..max_level``. Levels with
          no record for the node inherit the cluster from the nearest
          shallower level that has one (``-1`` if there is none).
        * ``max_cluster_per_level[level]`` is the largest cluster id assigned
          at ``level`` after that fill-in (``-1`` if none).
        * ``max_level`` is the deepest level seen, or ``-1`` for empty input.
    """
    concept_to_community_hierarchy = {}
    max_level = -1

    for row in hierarchical_communities:
        level = int(row.level)
        concept_to_community_hierarchy.setdefault(row.node, {})[level] = row.cluster
        max_level = max(max_level, level)

    # Propagate each node's last known cluster down to the levels it is
    # missing from, so every node has a label at every level.
    max_cluster_per_level = dict.fromkeys(range(max_level + 1), -1)
    for level_to_cluster in concept_to_community_hierarchy.values():
        last_observed_cluster = -1
        # Always walk from level 0 downwards so a shallower cluster is
        # available as the fallback for deeper, missing levels.
        for level in range(max_level + 1):
            if level in level_to_cluster:
                last_observed_cluster = level_to_cluster[level]
            else:
                # This tier has no cluster of its own for the node.
                level_to_cluster[level] = last_observed_cluster
            if last_observed_cluster > max_cluster_per_level[level]:
                max_cluster_per_level[level] = last_observed_cluster

    return concept_to_community_hierarchy, max_cluster_per_level, max_level