Coverage for tests/test_run.py: 96%

97 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-22 12:22 +0200

1import unittest 

2 

3 

4import sys 

5from io import StringIO 

6from contextlib import redirect_stdout, redirect_stderr 

7import tempfile 

8from pathlib import Path 

9import numpy as np 

10import yaml 

11import MPP.run as run_module 

12 

13 

14# TODO: 

15# - MultiFeatureKernel.full_feature_from_Z 

16 

17 

# Dataset names. Alternative selections that were tried are kept below for
# reference.
# DATASETS = ["HP35", "PDZ3", "aSyn"]
# DATASETS = ["HP35", "PDZ3"]
DATASETS = ["HP35"]

# YAML mapping from a short key (e.g. 'kl', 't_js') to its
# "kernel similarity" / "feature kernel" settings; lives next to this file
# in the "data" directory.
MAPPING_FILE = Path(__file__).parent / "data" / "lumpings.yaml"

22 

23 

def _run_main_with_args(args_list):
    """Run ``run_module.main()`` with a patched ``sys.argv`` and capture output.

    Args:
        args_list: Command-line arguments, without the program name.

    Returns:
        Tuple ``(exit_code, stdout_text, stderr_text)``. ``exit_code`` is ``0``
        when ``main()`` returns normally or raises ``SystemExit`` with a code
        of ``None``; otherwise it is the ``SystemExit.code`` value.
    """
    saved_argv = sys.argv
    sys.argv = ["run.py", *args_list]
    stdout, stderr = StringIO(), StringIO()
    try:
        with redirect_stdout(stdout), redirect_stderr(stderr):
            run_module.main()
        exit_code = 0
    except SystemExit as exc:
        # A bare sys.exit() carries code None, which means a successful exit.
        exit_code = 0 if exc.code is None else exc.code
    finally:
        # Always restore argv so later tests see the real command line.
        sys.argv = saved_argv
    return exit_code, stdout.getvalue(), stderr.getvalue()

37 

38 

class TestRunScript(unittest.TestCase):
    """End-to-end checks of ``run_module.main()`` against stored fixture data."""

    def setUp(self):
        # The mapping is reloaded for every test; a setUpClass variant would
        # load it only once.
        self.base_data_dir = Path(__file__).parent / "data"
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(MAPPING_FILE, "r", encoding="utf-8") as f:
            self.param_map = yaml.safe_load(f)

    def _run_command(self, config_path, d, g, output_file, r=None, o=None):
        """Invoke the run script and return what ``_run_main_with_args`` returns.

        The positional CLI arguments are the config path, ``d`` and ``g``;
        ``-Z`` always receives ``output_file``. ``-r`` and ``-o`` are appended
        only when ``r`` / ``o`` are not ``None``.
        """
        args = [str(config_path), d, g, "-Z", str(output_file)]
        if r is not None:
            args += ["-r", str(r)]
        if o is not None:
            args += ["-o", str(o)]
        return _run_main_with_args(args)

    def _get_key(self, d, g):
        """Return the mapping key (like 'kl', 't_js') matching ``d`` and ``g``.

        Raises:
            ValueError: If no entry of ``self.param_map`` has this combination
                of "kernel similarity" and "feature kernel".
        """
        for key, val in self.param_map.items():
            if val["kernel similarity"] == d and val["feature kernel"] == g:
                return key
        raise ValueError(f"No key found for d={d}, g={g}")

    def run_and_validate_output(self, dataset, d, g, stochastic=False):
        """Run the script twice for one configuration and validate ``Z.npy``.

        The first run must exit with 0 and create ``Z.npy`` in a temporary
        directory. For non-stochastic runs the array is compared with the
        stored expected array; for stochastic runs both files are only loaded.
        The second run targets the same file and must print
        "Loading existing Z" on stdout.
        """
        suffix = "_stochastic" if stochastic else ""
        config_file = self.base_data_dir / dataset / f"config{suffix}.yaml"
        key = self._get_key(d, g)
        with self.subTest(dataset=dataset, d=d, g=g):
            with tempfile.TemporaryDirectory() as tmpdir:
                z_output = Path(tmpdir) / "Z.npy"

                # Run first time: should compute and save
                exit_code, _, stderr = self._run_command(
                    config_file, d, g, z_output
                )
                self.assertEqual(exit_code, 0, f"Failed for {stderr} {d}-{g}")
                self.assertTrue(
                    z_output.exists(),
                    f"Z.npy not created for {dataset} {d}-{g}",
                )

                # Compare with expected
                expected_path = (
                    self.base_data_dir
                    / dataset
                    / "expected_output"
                    / key
                    / f"Z{suffix}.npy"
                )
                self.assertTrue(
                    expected_path.exists(),
                    f"Expected file missing: {expected_path}",
                )

                output_data = np.load(z_output)
                expected_data = np.load(expected_path)

                # Stochastic results are not compared numerically.
                if not stochastic:
                    np.testing.assert_allclose(
                        output_data,
                        expected_data,
                        rtol=1e-5,
                        err_msg=f"Mismatch in Z for {dataset} {d}-{g}",
                    )

                # Second run: should load existing file (tests the from_Z
                # logic indirectly)
                exit_code2, stdout2, stderr2 = self._run_command(
                    config_file, d, g, z_output
                )
                self.assertEqual(
                    exit_code2,
                    0,
                    f"Reload failed for {stderr2} {d}-{g}",
                )

                # Verify "Loading existing Z" is printed (indicating from_Z
                # was called)
                self.assertIn(
                    "Loading existing Z",
                    stdout2,
                    f"Z not loaded from file for {stderr2} {d}-{g}",
                )

    def test_HP35_t_ref(self):
        self.run_and_validate_output("HP35", "T", "none")

    def test_HP35_t_stoch(self):
        self.run_and_validate_output("HP35", "T", "none", stochastic=True)

    def test_HP35_kl(self):
        self.run_and_validate_output("HP35", "KL", "none")

    def test_HP35_t_js(self):
        self.run_and_validate_output("HP35", "T", "JS")

    def test_HP35_js(self):
        self.run_and_validate_output("HP35", "none", "JS")

    def test_HP35_gpcca(self):
        self.run_and_validate_output("HP35", "gpcca", "ref")

    def test_PDZ3_kl(self):
        self.run_and_validate_output("PDZ3", "KL", "none")

    def test_aSyn_t(self):
        self.run_and_validate_output("aSyn", "T", "none")

    def test_aSyn_kl_js(self):
        self.run_and_validate_output("aSyn", "KL", "JS")

    def test_aSyn_t_stoch(self):
        self.run_and_validate_output("aSyn", "T", "none", stochastic=True)

    def assert_same_file_count(self, expected_dir, actual_dir, pattern="*"):
        """Assert both directories hold the same number of ``pattern`` matches.

        Only the counts are compared, not file names or contents.
        """
        expected_files = list(Path(expected_dir).glob(pattern))
        actual_files = list(Path(actual_dir).glob(pattern))

        self.assertEqual(
            len(actual_files),
            len(expected_files),
            f"Mismatch in file count: expected {len(expected_files)} but got {len(actual_files)}",
        )

    def _run_random_frames_indices(self, dataset, d, g, r=20):
        """Run the script with ``-r``/``-o`` and check the number of ``*.ndx`` files.

        The stored ``Z.npy`` of the configuration is passed as ``-Z`` and a
        temporary directory as ``-o``. Afterwards the count of ``*.ndx`` files
        in the output directory must equal the count in the stored
        ``random_frames`` directory.
        """
        key = self._get_key(d, g)
        z_file = self.base_data_dir / dataset / "expected_output" / key / "Z.npy"
        config_file = self.base_data_dir / dataset / "config.yaml"
        with self.subTest(dataset=dataset, d=d, g=g, r=r):
            # -Z points into the fixture tree. With no file there the script
            # is expected to compute Z and save it at that path, which would
            # silently create fixture data; fail before invoking it instead.
            self.assertTrue(
                z_file.exists(),
                f"Expected file missing: {z_file}",
            )
            with tempfile.TemporaryDirectory() as tmpdir:
                output_dir = Path(tmpdir)

                exit_code, _, stderr = self._run_command(
                    config_file, d, g, z_file, r=r, o=output_dir
                )
                self.assertEqual(exit_code, 0, f"Failed for {stderr} {d}-{g}")

                # Compare with expected
                expected_path = (
                    self.base_data_dir
                    / dataset
                    / "expected_output"
                    / key
                    / "random_frames"
                )
                self.assertTrue(
                    expected_path.exists(),
                    f"Expected directory missing: {expected_path}",
                )

                self.assert_same_file_count(expected_path, output_dir, pattern="*.ndx")

    def test_random_frames_indices_aSyn_t_ref(self):
        self._run_random_frames_indices("aSyn", "T", "none")