Coverage for intelligence_toolkit/tests/unit/graph/test_graph_fusion_encoder_embedding.py: 100%
123 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4from collections import namedtuple
5from unittest.mock import MagicMock
7import networkx as nx
8import numpy as np
9import pytest
11from intelligence_toolkit.graph.graph_fusion_encoder_embedding import (
12 _cosine_distance,
13 _generate_embeddings_for_period,
14 _get_edge_list,
15 create_concept_to_community_hierarchy,
16 generate_graph_fusion_encoder_embedding,
17 is_converging_pair,
18)
@pytest.fixture
def simple_graph():
    """Return an undirected 4-cycle A-B-C-D-A in which every edge has weight 1.0."""
    cycle_edges = [("A", "B"), ("B", "C"), ("C", "D"), ("D", "A")]
    graph = nx.Graph()
    graph.add_weighted_edges_from((src, dst, 1.0) for src, dst in cycle_edges)
    return graph
@pytest.fixture
def node_list():
    """Return the node ordering ["A", "B", "C", "D"] shared by the edge-list and embedding tests."""
    return list("ABCD")
@pytest.fixture
def node_to_label():
    """Map each node to its community id at hierarchy levels 0 and 1.

    Level 0 pairs the nodes up (A, B -> 0; C, D -> 1). Level 1 gives every
    node its own community (A -> 0, B -> 1, C -> 2, D -> 3).
    """
    return {
        node: {0: position // 2, 1: position}
        for position, node in enumerate("ABCD")
    }
def test_get_edge_list(simple_graph, node_list):
    """Every edge of the 4-cycle comes back as a valid [source, target, weight] triple."""
    edge_list = _get_edge_list(simple_graph, node_list)
    node_count = len(node_list)

    assert len(edge_list) == 4
    # Each entry is [source index, target index, weight].
    assert all(len(edge) == 3 for edge in edge_list)

    for source, target, weight in edge_list:
        # Endpoints must be valid positions in node_list.
        assert 0 <= source < node_count
        assert 0 <= target < node_count
        assert weight > 0  # the fixture only uses positive weights
def test_get_edge_list_filters_missing_nodes(simple_graph):
    """Edges touching nodes absent from the node list are dropped."""
    # Of the four cycle edges, only A-B has both endpoints in this subset.
    edge_list = _get_edge_list(simple_graph, ["A", "B"])

    assert len(edge_list) == 1
def test_cosine_distance():
    """Orthogonal vectors are at cosine distance 1."""
    x = np.array([1, 0, 0])
    y = np.array([0, 1, 0])

    dist = _cosine_distance(x, y)

    # Compare with a tolerance, as the other cosine-distance tests do, rather
    # than relying on exact floating-point equality.
    assert np.isclose(dist, 1.0)
def test_cosine_distance_identical_vectors():
    """A vector is at (numerically) zero cosine distance from an equal vector."""
    vec = np.array([1, 2, 3])
    assert np.isclose(_cosine_distance(vec, vec.copy()), 0.0)
def test_cosine_distance_opposite_vectors():
    """Antiparallel vectors are at the maximum cosine distance of 2."""
    unit_x = np.array([1, 0, 0])
    assert np.isclose(_cosine_distance(unit_x, -unit_x), 2.0)
def test_cosine_distance_zero_vector():
    """A zero-norm operand is expected to yield an infinite distance."""
    nonzero = np.array([1, 2, 3])
    zero = np.zeros_like(nonzero)  # same integer dtype as np.array([0, 0, 0])
    assert np.isinf(_cosine_distance(nonzero, zero))
def test_generate_embeddings_for_period(simple_graph, node_list, node_to_label):
    """A single-level embedding yields one non-empty row per node."""
    options = {
        "correlation": True,
        "diaga": True,
        "laplacian": False,
        "max_level": 1,
    }
    embeddings = _generate_embeddings_for_period(
        simple_graph, node_list, node_to_label, **options
    )

    row_count, column_count = embeddings.shape[0], embeddings.shape[1]
    # One row per node; the per-level embeddings are concatenated column-wise.
    assert row_count == len(node_list)
    assert column_count > 0
def test_generate_embeddings_for_period_multiple_levels(simple_graph, node_list):
    """A three-level hierarchy (levels 0-2) still yields one row per node."""
    # Level 0 pairs nodes up; levels 1 and 2 both give each node its own community.
    node_to_label_deep = {
        node: {0: idx // 2, 1: idx, 2: idx} for idx, node in enumerate("ABCD")
    }

    embeddings = _generate_embeddings_for_period(
        simple_graph,
        node_list,
        node_to_label_deep,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=2,
    )

    assert embeddings.shape[0] == len(node_list)
def test_generate_graph_fusion_encoder_embedding_single_period(simple_graph, node_to_label):
    """One period yields, per node, that period's position plus the "ALL" centroid."""
    node_to_period_to_pos, _node_to_period_to_shift = generate_graph_fusion_encoder_embedding(
        {"2020": simple_graph},
        node_to_label,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=1,
        callbacks=[],
    )

    # Every labelled node (four in the fixture) receives positions.
    assert len(node_to_period_to_pos) == len(node_to_label)
    for node in node_to_label:
        periods = node_to_period_to_pos[node]
        assert "2020" in periods
        assert "ALL" in periods  # centroid across all periods
def test_generate_graph_fusion_encoder_embedding_multiple_periods(simple_graph, node_to_label):
    """Two periods yield both period positions plus the "ALL" and "<2021" centroids."""
    # The same graph is reused for both periods to keep the fixture small.
    period_to_graph = dict.fromkeys(("2020", "2021"), simple_graph)

    node_to_period_to_pos, _node_to_period_to_shift = generate_graph_fusion_encoder_embedding(
        period_to_graph,
        node_to_label,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=1,
        callbacks=[],
    )

    # "<2021" is the centroid of the periods before 2021.
    expected_keys = ("2020", "2021", "ALL", "<2021")
    for node in node_to_label:
        for key in expected_keys:
            assert key in node_to_period_to_pos[node]
def test_generate_graph_fusion_encoder_embedding_with_callbacks(simple_graph, node_to_label):
    """Callbacks passed in are notified through on_batch_change."""
    # MagicMock creates child mocks such as on_batch_change on first access.
    callback = MagicMock()

    generate_graph_fusion_encoder_embedding(
        {"2020": simple_graph},
        node_to_label,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=1,
        callbacks=[callback],
    )

    callback.on_batch_change.assert_called()
def test_is_converging_pair_with_positions():
    """Nodes that share a centroid but sit far apart in the period are not converging."""
    # The period position is stored at index 1 of each tuple, the same layout
    # used by the other is_converging_pair tests in this module.
    node_to_period_to_pos = {
        "A": {
            "2020": (None, np.array([1.0, 0.0, 0.0])),
            "ALL": np.array([0.5, 0.5, 0.0]),
        },
        "B": {
            "2020": (None, np.array([0.0, 1.0, 0.0])),
            "ALL": np.array([0.5, 0.5, 0.0]),
        },
    }

    result = is_converging_pair("2020", "A", "B", node_to_period_to_pos, all_time=True)

    # Identical centroids (distance ~0) versus orthogonal period positions
    # (distance 1): the pair is drifting apart, not converging.
    assert not result
def test_is_converging_pair_missing_node():
    """A pair involving a node absent from the position map is never converging."""
    positions = {
        "A": {
            "2020": (np.array([1, 0, 0]), None),
            "ALL": np.array([0.5, 0.5, 0]),
        },
    }

    # "Z" has no entry in the position map.
    assert is_converging_pair("2020", "A", "Z", positions, all_time=True) is False
def test_is_converging_pair_converging():
    """Nodes nearly aligned in the period but orthogonal overall are converging."""
    period_a = np.array([0.0, 0.0, 1.0])
    period_b = np.array([0.0, 0.0, 0.99])  # almost the same direction as A
    centroid_a = np.array([1.0, 0.0, 0.0])
    centroid_b = np.array([0.0, 1.0, 0.0])  # orthogonal to A's centroid
    node_to_period_to_pos = {
        "A": {"2020": (None, period_a), "ALL": centroid_a},
        "B": {"2020": (None, period_b), "ALL": centroid_b},
    }

    result = is_converging_pair("2020", "A", "B", node_to_period_to_pos, all_time=True)

    # The period distance is far below the centroid distance. Use ``==`` rather
    # than ``is``, in case the function returns a NumPy bool.
    assert result == True  # noqa: E712
def test_is_converging_pair_diverging():
    """Nodes orthogonal in the period but close overall are not converging."""
    period_a = np.array([1.0, 0.0, 0.0])
    period_b = np.array([0.0, 1.0, 0.0])  # orthogonal to A in this period
    centroid_a = np.array([0.5, 0.5, 0.0])
    centroid_b = np.array([0.5, 0.4, 0.0])  # nearly aligned with A's centroid
    node_to_period_to_pos = {
        "A": {"2020": (None, period_a), "ALL": centroid_a},
        "B": {"2020": (None, period_b), "ALL": centroid_b},
    }

    result = is_converging_pair("2020", "A", "B", node_to_period_to_pos, all_time=True)

    # The period distance is >= the centroid distance. Use ``==`` rather than
    # ``is``, in case the function returns a NumPy bool.
    assert result == False  # noqa: E712
def test_is_converging_pair_with_prior_centroid():
    """With all_time=False the prior centroid ("<2020"), not "ALL", is the baseline.

    The fixture is built so the two baselines disagree:
    - Against the prior centroids (orthogonal) the pair is converging.
    - Against the identical all-time centroids it is not.

    The test therefore fails if ``all_time`` is ignored.
    """
    node_to_period_to_pos = {
        "A": {
            "2020": (None, np.array([0.0, 0.0, 1.0])),
            "<2020": np.array([1.0, 0.0, 0.0]),  # prior centroid
            "ALL": np.array([1.0, 0.0, 0.0]),
        },
        "B": {
            "2020": (None, np.array([0.0, 1.0, 1.0])),
            "<2020": np.array([0.0, 1.0, 0.0]),  # prior centroid, orthogonal to A's
            "ALL": np.array([1.0, 0.0, 0.0]),  # identical to A's all-time centroid
        },
    }

    # Period cosine distance is 1 - 1/sqrt(2) (~0.29).
    # Prior-centroid distance is 1, so the pair is converging relative to the past.
    prior_result = is_converging_pair(
        "2020", "A", "B", node_to_period_to_pos, all_time=False
    )
    assert prior_result == True  # noqa: E712  (may be a NumPy bool)

    # All-time centroid distance is 0, so the same pair is not converging
    # relative to "ALL".
    all_time_result = is_converging_pair(
        "2020", "A", "B", node_to_period_to_pos, all_time=True
    )
    assert not all_time_result
def test_create_concept_to_community_hierarchy():
    """Nodes assigned at levels "0" and "1" produce a hierarchy whose max level is 1."""
    HierarchicalCommunity = namedtuple("HierarchicalCommunity", "node level cluster")
    # (node, level, cluster); levels are passed as strings.
    rows = [
        ("A", "0", 0), ("A", "1", 0),
        ("B", "0", 0), ("B", "1", 1),
        ("C", "0", 1), ("C", "1", 2),
    ]
    hierarchical_communities = [HierarchicalCommunity(*row) for row in rows]

    concept_to_community, _max_cluster_per_level, max_level = (
        create_concept_to_community_hierarchy(hierarchical_communities)
    )

    assert max_level == 1
    for node in ("A", "B", "C"):
        assert node in concept_to_community

    # Both levels are present for a node that was given both.
    a_levels = concept_to_community["A"]
    assert 0 in a_levels
    assert 1 in a_levels
def test_create_concept_to_community_hierarchy_fills_missing_levels():
    """A node missing a deeper level inherits its level-0 assignment there."""
    Community = namedtuple("HierarchicalCommunity", "node level cluster")
    # A appears only at level 0, while B also has a level-1 assignment.
    communities = [
        Community("A", "0", 5),
        Community("B", "0", 3),
        Community("B", "1", 7),
    ]

    concept_to_community, _max_cluster_per_level, _max_level = (
        create_concept_to_community_hierarchy(communities)
    )

    a_levels = concept_to_community["A"]
    assert a_levels[0] == 5
    assert a_levels[1] == 5  # propagated from level 0
def test_create_concept_to_community_hierarchy_max_clusters():
    """The largest cluster id seen at each level is reported per level."""
    Community = namedtuple("HierarchicalCommunity", "node level cluster")
    communities = [
        Community(node, "0", cluster)
        for node, cluster in zip("ABC", (0, 5, 10))
    ]

    _concept_to_community, max_cluster_per_level, _max_level = (
        create_concept_to_community_hierarchy(communities)
    )

    # 10 is the largest cluster id assigned at level 0.
    assert max_cluster_per_level[0] == 10
def test_create_concept_to_community_hierarchy_empty():
    """No community records yield a max level of -1 and no concepts."""
    result = create_concept_to_community_hierarchy([])
    concept_to_community, _max_cluster_per_level, max_level = result

    assert max_level == -1
    assert len(concept_to_community) == 0