Coverage for intelligence_toolkit/tests/unit/graph/test_graph_fusion_encoder_embedding.py: 100%

123 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4from collections import namedtuple 

5from unittest.mock import MagicMock 

6 

7import networkx as nx 

8import numpy as np 

9import pytest 

10 

11from intelligence_toolkit.graph.graph_fusion_encoder_embedding import ( 

12 _cosine_distance, 

13 _generate_embeddings_for_period, 

14 _get_edge_list, 

15 create_concept_to_community_hierarchy, 

16 generate_graph_fusion_encoder_embedding, 

17 is_converging_pair, 

18) 

19 

20 

@pytest.fixture
def simple_graph():
    """Build an undirected 4-node cycle A-B-C-D-A with unit edge weights."""
    cycle_edges = [
        ("A", "B", 1.0),
        ("B", "C", 1.0),
        ("C", "D", 1.0),
        ("D", "A", 1.0),
    ]
    graph = nx.Graph()
    graph.add_weighted_edges_from(cycle_edges)
    return graph

32 

33 

@pytest.fixture
def node_list():
    """Return the four node names "A".."D" in a fixed order."""
    return list("ABCD")

37 

38 

@pytest.fixture
def node_to_label():
    """Return hierarchical labels per node as ``{node: {level: label}}``.

    Two levels (0 and 1) are provided for each of the nodes A-D.
    """
    # Labels listed per node in level order: (level 0, level 1).
    labels_in_level_order = {
        "A": (0, 0),
        "B": (0, 1),
        "C": (1, 2),
        "D": (1, 3),
    }
    return {
        node: dict(enumerate(labels))
        for node, labels in labels_in_level_order.items()
    }

48 

49 

def test_get_edge_list(simple_graph, node_list):
    """Check the edge list built from the full cycle graph.

    Expects four entries of length three whose first two items are valid
    positions into ``node_list`` and whose third item is positive.
    """
    edges = _get_edge_list(simple_graph, node_list)
    node_count = len(node_list)

    assert len(edges) == 4
    # Each entry is expected to be [source, target, weight].
    assert all(len(entry) == 3 for entry in edges)

    for entry in edges:
        # Both endpoints must index into node_list.
        assert 0 <= entry[0] < node_count
        assert 0 <= entry[1] < node_count
        # Weight should be positive.
        assert entry[2] > 0

61 

62 

def test_get_edge_list_filters_missing_nodes(simple_graph):
    """Expect a single edge when the node list is restricted to A and B."""
    subset = ["A", "B"]
    edges = _get_edge_list(simple_graph, subset)

    # Of the four cycle edges, only A-B has both endpoints in the subset.
    assert len(edges) == 1

70 

71 

def test_cosine_distance():
    """Expect a cosine distance of 1 between two orthogonal unit vectors."""
    x = np.array([1, 0, 0])
    y = np.array([0, 1, 0])

    dist = _cosine_distance(x, y)

    # Compare with a tolerance, as the sibling cosine-distance tests do,
    # rather than relying on exact floating-point equality.
    assert np.isclose(dist, 1.0)

79 

80 

def test_cosine_distance_identical_vectors():
    """Expect a cosine distance of approximately 0 between equal vectors."""
    first = np.array([1, 2, 3])
    second = np.array([1, 2, 3])

    assert np.isclose(_cosine_distance(first, second), 0.0)

88 

89 

def test_cosine_distance_opposite_vectors():
    """Expect a cosine distance of approximately 2 between opposite vectors."""
    forward = np.array([1, 0, 0])
    backward = np.array([-1, 0, 0])

    assert np.isclose(_cosine_distance(forward, backward), 2.0)

97 

98 

def test_cosine_distance_zero_vector():
    """Expect an infinite distance when one of the vectors is all zeros."""
    nonzero = np.array([1, 2, 3])
    zeros = np.array([0, 0, 0])

    # A zero-norm operand is expected to yield infinity.
    assert np.isinf(_cosine_distance(nonzero, zeros))

106 

107 

def test_generate_embeddings_for_period(simple_graph, node_list, node_to_label):
    """Check the shape of the embeddings generated for one period graph.

    Expects one row per node and at least one embedding column.
    """
    encoder_options = {
        "correlation": True,
        "diaga": True,
        "laplacian": False,
        "max_level": 1,
    }

    result = _generate_embeddings_for_period(
        simple_graph, node_list, node_to_label, **encoder_options
    )

    # One row per node.
    assert result.shape[0] == len(node_list)
    # The embedding dimension must be non-empty.
    assert result.shape[1] > 0

124 

125 

def test_generate_embeddings_for_period_multiple_levels(simple_graph, node_list):
    """Check the row count when labels cover three hierarchy levels (0-2)."""
    # Labels listed per node in level order: (level 0, level 1, level 2).
    labels_in_level_order = {
        "A": (0, 0, 0),
        "B": (0, 1, 1),
        "C": (1, 2, 2),
        "D": (1, 3, 3),
    }
    deep_labels = {
        node: dict(enumerate(labels))
        for node, labels in labels_in_level_order.items()
    }

    result = _generate_embeddings_for_period(
        simple_graph,
        node_list,
        deep_labels,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=2,
    )

    assert result.shape[0] == len(node_list)

146 

147 

def test_generate_graph_fusion_encoder_embedding_single_period(simple_graph, node_to_label):
    """Check the position mapping produced for a single period.

    Expects one entry per labelled node, each holding the period key
    "2020" and the "ALL" key.
    """
    period_to_graph = {"2020": simple_graph}

    # The second return value (shifts) is not examined by this test.
    node_to_period_to_pos, _ = generate_graph_fusion_encoder_embedding(
        period_to_graph,
        node_to_label,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=1,
        callbacks=[],
    )

    # Every labelled node should have an entry.
    assert len(node_to_period_to_pos) == len(node_to_label)

    # Each node should have period "2020" and "ALL" (centroid).
    for node in node_to_label:
        assert "2020" in node_to_period_to_pos[node]
        assert "ALL" in node_to_period_to_pos[node]

168 

169 

def test_generate_graph_fusion_encoder_embedding_multiple_periods(simple_graph, node_to_label):
    """Check the position mapping produced for two periods.

    Expects every labelled node to hold both period keys, the "ALL" key
    and the "<2021" key.
    """
    period_to_graph = {
        "2020": simple_graph,
        "2021": simple_graph,  # Same graph reused for simplicity.
    }

    # The second return value (shifts) is not examined by this test.
    node_to_period_to_pos, _ = generate_graph_fusion_encoder_embedding(
        period_to_graph,
        node_to_label,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=1,
        callbacks=[],
    )

    # "<2021" is the expected prior-centroid key for the second period.
    expected_keys = ("2020", "2021", "ALL", "<2021")
    for node in node_to_label:
        for key in expected_keys:
            assert key in node_to_period_to_pos[node]

192 

193 

def test_generate_graph_fusion_encoder_embedding_with_callbacks(simple_graph, node_to_label):
    """Expect ``on_batch_change`` to be invoked on a supplied callback."""
    # MagicMock creates the ``on_batch_change`` child mock on first access.
    listener = MagicMock()

    generate_graph_fusion_encoder_embedding(
        {"2020": simple_graph},
        node_to_label,
        correlation=True,
        diaga=True,
        laplacian=False,
        max_level=1,
        callbacks=[listener],
    )

    # The callback should have been called at least once.
    listener.on_batch_change.assert_called()

212 

213 

def test_is_converging_pair_with_positions():
    """Check that the hand-built position mapping contains nodes A and B.

    NOTE(review): despite its name, this test never calls
    ``is_converging_pair``; it only inspects the dictionary it builds.
    """
    shared_centroid = [0.5, 0.5, 0]
    positions = {
        "A": {
            "2020": (np.array([1, 0, 0]), None),
            "ALL": np.array(shared_centroid),
        },
        "B": {
            "2020": (np.array([0, 1, 0]), None),
            "ALL": np.array(shared_centroid),
        },
    }

    # Structural check only.
    for node in ("A", "B"):
        assert node in positions

230 

231 

def test_is_converging_pair_missing_node():
    """Expect ``False`` when one node of the pair has no position entry."""
    positions = {
        "A": {
            "2020": (np.array([1, 0, 0]), None),
            "ALL": np.array([0.5, 0.5, 0]),
        },
    }

    # "Z" is absent from the mapping.
    outcome = is_converging_pair("2020", "A", "Z", positions, all_time=True)

    assert outcome is False

244 

245 

def test_is_converging_pair_converging():
    """Expect a truthy result for a pair that is closer in the period.

    The "2020" vectors of A and B are nearly parallel while their "ALL"
    vectors are orthogonal.
    """
    node_to_period_to_pos = {
        "A": {
            "2020": (None, np.array([0.0, 0.0, 1.0])),  # Close to B in this period
            "ALL": np.array([1.0, 0.0, 0.0]),  # Far from B in centroid
        },
        "B": {
            "2020": (None, np.array([0.0, 0.0, 0.99])),  # Very close to A in this period
            "ALL": np.array([0.0, 1.0, 0.0]),  # Far from A in centroid
        },
    }

    result = is_converging_pair("2020", "A", "B", node_to_period_to_pos, all_time=True)

    # Truthiness check instead of ``== True`` (PEP 8); this also accepts a
    # NumPy boolean result.
    assert result

263 

264 

def test_is_converging_pair_diverging():
    """Expect a falsy result for a pair that is farther apart in the period.

    The "2020" vectors of A and B are orthogonal while their "ALL"
    vectors are nearly parallel.
    """
    node_to_period_to_pos = {
        "A": {
            "2020": (None, np.array([1.0, 0.0, 0.0])),  # Far from B in this period
            "ALL": np.array([0.5, 0.5, 0.0]),  # Close to B in centroid
        },
        "B": {
            "2020": (None, np.array([0.0, 1.0, 0.0])),  # Far from A in this period
            "ALL": np.array([0.5, 0.4, 0.0]),  # Close to A in centroid
        },
    }

    result = is_converging_pair("2020", "A", "B", node_to_period_to_pos, all_time=True)

    # Truthiness check instead of ``== False`` (PEP 8); this also accepts a
    # NumPy boolean result.
    assert not result

282 

283 

def test_is_converging_pair_with_prior_centroid():
    """Expect a truthy result when called with ``all_time=False``.

    The "2020" vectors of A and B are nearly parallel while their "<2020"
    vectors are orthogonal.
    """
    node_to_period_to_pos = {
        "A": {
            "2020": (None, np.array([0.0, 0.0, 1.0])),
            "<2020": np.array([1.0, 0.0, 0.0]),  # Prior centroid
            "ALL": np.array([0.5, 0.0, 0.5]),
        },
        "B": {
            "2020": (None, np.array([0.0, 0.0, 0.99])),
            "<2020": np.array([0.0, 1.0, 0.0]),  # Prior centroid
            "ALL": np.array([0.0, 0.5, 0.5]),
        },
    }

    # presumably all_time=False compares against "<2020" rather than "ALL"
    # -- confirm against is_converging_pair.
    result = is_converging_pair("2020", "A", "B", node_to_period_to_pos, all_time=False)

    # Truthiness check instead of ``== True`` (PEP 8); this also accepts a
    # NumPy boolean result.
    assert result

303 

304 

def test_create_concept_to_community_hierarchy():
    """Check the mapping built from three nodes with levels "0" and "1".

    Expects ``max_level`` of 1, an entry for each node, and integer level
    keys 0 and 1 under node A.
    """
    HierarchicalCommunity = namedtuple("HierarchicalCommunity", ["node", "level", "cluster"])

    # (node, level, cluster) rows; levels are given as strings.
    rows = [
        ("A", "0", 0),
        ("A", "1", 0),
        ("B", "0", 0),
        ("B", "1", 1),
        ("C", "0", 1),
        ("C", "1", 2),
    ]
    communities = [
        HierarchicalCommunity(node=node, level=level, cluster=cluster)
        for node, level, cluster in rows
    ]

    concept_to_community, _, max_level = create_concept_to_community_hierarchy(communities)

    assert max_level == 1
    for node in ("A", "B", "C"):
        assert node in concept_to_community

    # Both levels should be present for node A.
    for level in (0, 1):
        assert level in concept_to_community["A"]

330 

331 

def test_create_concept_to_community_hierarchy_fills_missing_levels():
    """Expect node A's level-0 cluster to also appear at level 1.

    Node A is listed only at level "0" while node B is listed at both.
    """
    HierarchicalCommunity = namedtuple("HierarchicalCommunity", ["node", "level", "cluster"])

    communities = [
        HierarchicalCommunity("A", "0", 5),
        HierarchicalCommunity("B", "0", 3),
        HierarchicalCommunity("B", "1", 7),
    ]

    concept_to_community, _, _ = create_concept_to_community_hierarchy(communities)

    a_levels = concept_to_community["A"]
    assert a_levels[0] == 5
    # Level 1 is expected to carry the level-0 value.
    assert a_levels[1] == 5

349 

350 

def test_create_concept_to_community_hierarchy_max_clusters():
    """Expect the level-0 maximum cluster id to be 10 for clusters 0, 5, 10."""
    HierarchicalCommunity = namedtuple("HierarchicalCommunity", ["node", "level", "cluster"])

    cluster_by_node = {"A": 0, "B": 5, "C": 10}
    communities = [
        HierarchicalCommunity(node=node, level="0", cluster=cluster)
        for node, cluster in cluster_by_node.items()
    ]

    _, max_cluster_per_level, _ = create_concept_to_community_hierarchy(communities)

    # Largest cluster id seen at level 0.
    assert max_cluster_per_level[0] == 10

366 

367 

def test_create_concept_to_community_hierarchy_empty():
    """Expect ``max_level`` of -1 and an empty mapping for empty input."""
    outcome = create_concept_to_community_hierarchy([])
    concept_to_community, _, max_level = outcome

    assert max_level == -1
    assert len(concept_to_community) == 0