Coverage for MPP/run.py: 86%

161 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-22 12:22 +0200

1#!/usr/bin/env python 

2 

3import os 

4import yaml 

5from pathlib import Path 

6import argparse 

7 

8import numpy as np 

9import MPP 

10 

11 

class Data:
    """Configuration and driver for an MPP lumping run.

    Reads a YAML specification and loads the input trajectories it points to.
    All file names are interpreted relative to the ``source`` directory.

    Required keys: ``source``, ``microstate trajectory``,
    ``multi feature trajectory``, ``cluster file``, ``frame length``,
    ``lagtime``, ``pop_thr``, ``q_min``.
    Optional keys: ``limits`` (missing or null means no limits),
    ``topology file``, ``xtc file``, ``helices``, ``stochastic`` (sub-keys
    ``method``, ``param``, ``n``) and ``xtc stride``.
    """

    # Raw multi-feature values below this threshold count as a contact. Used
    # both for the boolean feature trajectory and for ``MPP.Lumping`` so the
    # two always agree.
    CONTACT_THRESHOLD = 0.45

    def __init__(self, yaml_file):
        """Parse ``yaml_file`` and load the trajectories it references."""
        with open(yaml_file, "r") as f:
            self.d = yaml.safe_load(f)

        self.source = self.d["source"]

        self.microstate_trajectory = np.loadtxt(
            self._path("microstate trajectory"), dtype=np.uint16
        )
        self.multi_state_trajectory_raw = np.loadtxt(
            self._path("multi feature trajectory")
        )
        self.limits = (
            None
            if self.d.get("limits") is None
            else np.loadtxt(self._path("limits"), dtype=np.int_)
        )
        self.multi_feature_trajectory = (
            self.multi_state_trajectory_raw < self.CONTACT_THRESHOLD
        )
        # Mean over axis 1 -- presumably the fraction of features in contact
        # per frame (rows = frames); TODO confirm the file layout.
        self.feature_trajectory = self.multi_feature_trajectory.mean(axis=1)
        self.cluster = self._path("cluster file")

        # Optional inputs stay None when not configured.
        self.top = self._optional_path("topology file")
        self.xtc = self._optional_path("xtc file")
        helices_file = self._optional_path("helices")
        self.helices = (
            None if helices_file is None else np.loadtxt(helices_file, dtype=int)
        )

        self.frame_length = self.d["frame length"]
        self.lagtime = self.d["lagtime"]
        self.pop_thr = self.d["pop_thr"]
        self.q_min = self.d["q_min"]

        # Filled in later by setup_mpp / prepare_mpp / callers.
        self.lumping_dir = None
        self.kernel = None
        self.feature_kernel = None
        self.mpp = None

        self.n_random_frames = 20
        self.use_ref = True

    def _path(self, key):
        """Return the config entry ``key`` joined onto ``self.source``."""
        return os.path.join(self.source, self.d[key])

    def _optional_path(self, key):
        """Like ``_path`` but return None when ``key`` is missing or null."""
        value = self.d.get(key)
        return None if value is None else os.path.join(self.source, value)

    def prepare_mpp(self, dij, gij):
        """Build the lumping kernel for ``dij`` and the feature kernel for ``gij``.

        If the config has a ``stochastic`` section, its ``method`` and
        ``param`` are passed to the lumping kernel.

        Raises:
            ValueError: if ``gij`` is neither ``"none"`` nor ``"JS"``.
        """
        if "stochastic" in self.d:
            stochastic = self.d["stochastic"]
            kernel = MPP.kernel.LumpingKernel(
                method=stochastic["method"],
                param=stochastic["param"],
                similarity=dij,
            )
        else:
            kernel = MPP.kernel.LumpingKernel(similarity=dij)

        if gij == "none":
            feature_kernel = None
        elif gij == "JS":
            feature_kernel = MPP.kernel.FeatureKernel(
                self.multi_feature_trajectory,
                self.microstate_trajectory,
            )
        else:
            raise ValueError(
                f"feature kernel (gij) must be 'none' or 'JS', got {gij!r}."
            )

        # Deterministic "T" lumping without a feature kernel is presumably the
        # reference itself (cf. perform_gpcca's "ref"), so plots should not
        # compare against a reference in that case.
        if dij == "T" and gij == "none" and "stochastic" not in self.d:
            self.use_ref = False

        self.kernel = kernel
        self.feature_kernel = feature_kernel

    def setup_mpp(self, dij, gij):
        """Create ``self.mpp`` (an ``MPP.Lumping``) for ``dij``/``gij``.

        For ``dij == "gpcca"`` no lumping kernels are prepared; call
        ``perform_gpcca`` afterwards instead of ``perform_mpp``.
        """
        if dij != "gpcca":
            self.prepare_mpp(dij, gij)
        self.mpp = MPP.Lumping(
            self.microstate_trajectory,
            self.lagtime,
            self.multi_state_trajectory_raw,
            contact_threshold=self.CONTACT_THRESHOLD,
            pop_thr=self.pop_thr,
            q_min=self.q_min,
            limits=self.limits,
            quiet=True,
        )
        # Topology/trajectory are optional; os.path.exists(None) raises
        # TypeError, so check for None first.
        if self.top is not None and os.path.exists(self.top):
            self.mpp.topology_file = self.top
        if self.xtc is not None and os.path.exists(self.xtc):
            self.mpp.xtc_trajectory_file = self.xtc
        self.mpp.xtc_stride = self.d.get("xtc stride", None)
        self.mpp.frame_length = self.frame_length

    def perform_mpp(self, out, overwrite=False):
        """Run (or load) the MPP lumping.

        out: path of the Z matrix (``Z.npy``). If it exists and ``overwrite``
            is false, Z is loaded from it. Otherwise MPP is run and Z is saved
            there, with parent directories created. If ``None``, MPP is run and
            nothing is saved.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.load_Z(out)
            return
        if out is not None:
            # Create the target directory before the (long) run so a bad path
            # fails fast.
            Path(out).parent.mkdir(parents=True, exist_ok=True)
        self.mpp.run_mpp(
            self.kernel,
            feature_kernel=self.feature_kernel,
            n=self.d["stochastic"]["n"] if "stochastic" in self.d else 1,
        )
        if out is not None:
            self.mpp.save_Z(out)

    def perform_gpcca(self, n_macrostates="ref", out=None, overwrite=False):
        """Run (or load) a GPCCA lumping.

        n_macrostates: int, a string holding an int (as passed from the
            command line), or ``"ref"`` to take the count from the reference
            lumping (T).
        out: optional Z path. It is loaded if it exists and ``overwrite`` is
            false; otherwise the result is saved there.

        Raises:
            ValueError: if ``n_macrostates`` is a string that is neither
                ``"ref"`` nor an integer.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.load_Z(out)
            return
        if n_macrostates == "ref":
            n_macrostates = self.mpp.reference.n_macrostates[0]
            print(f"n_macrostates: {n_macrostates}")
        elif isinstance(n_macrostates, str):
            try:
                n_macrostates = int(n_macrostates)
            except ValueError as err:
                raise ValueError(
                    f"n_macrostates must be an integer or 'ref', got {n_macrostates!r}"
                ) from err
        self.mpp.gpcca(n_macrostates)
        if out is not None:
            Path(out).parent.mkdir(parents=True, exist_ok=True)
            self.mpp.save_Z(out)

    def get_rmsd(self, out, overwrite=False):
        """Load the RMSD from ``out`` if present, otherwise compute and save it.

        out: rmsd file; ``.npy`` is appended if missing.
        """
        if not out.endswith(".npy"):
            out += ".npy"
        if os.path.exists(out) and not overwrite:
            self.mpp.load_rmsd(out)
        else:
            self.mpp.save_rmsd(out)

151 

152 

def plot(data, out, kind="dendrogram", scale=1):
    """Render one diagnostic plot of ``data.mpp`` to ``out``.

    Parameters
    ----------
    data : Data
        Configured analysis whose ``mpp`` already holds a lumping.
    out : str
        Output file of the plot. For ``rmsd`` and ``delta_rmsd`` the RMSD cache
        ``rmsd.npy`` in the same directory is loaded if present, else written.
    kind : str
        One of: dendrogram, timescales, sankey, contacts, macrotraj (alias:
        macrostate_trajectory), ck_test, rmsd, delta_rmsd, state_network,
        macro_feature, stochastic_state_similarity,
        relative_implied_timescales, transition_matrix, transition_time.
    scale : float
        Size scaling; only used by dendrogram, timescales, sankey and contacts.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the kinds above.
    """
    if kind == "dendrogram":
        print("Plotting dendrogram")
        data.mpp.plot.dendrogram(out, scale=scale, offset=0.0)
    elif kind == "timescales":
        if "n timescales" in data.d:
            data.mpp.calc_timescales(data.d["n timescales"])
        data.mpp.plot.implied_timescales(out, scale=scale, use_ref=data.use_ref)
    elif kind == "sankey":
        data.mpp.plot.sankey(out, scale=scale)
    elif kind == "contacts":
        data.mpp.plot.contact_rep(data.cluster, out, scale=scale)
    elif kind in ("macrotraj", "macrostate_trajectory"):
        # One row per trajectory segment when limits are given, else 6 rows.
        row_length = 1 / 6 if data.limits is None else 1 / len(data.limits)
        data.mpp.plot.macrostate_trajectory(out, row_length=row_length)
    elif kind == "ck_test":
        data.mpp.plot.ck_test(out)
    elif kind == "rmsd":
        data.get_rmsd(os.path.join(os.path.dirname(out), "rmsd.npy"))
        data.mpp.plot.rmsd(out, helices=data.helices)
    elif kind == "delta_rmsd":
        data.get_rmsd(os.path.join(os.path.dirname(out), "rmsd.npy"))
        data.mpp.plot.delta_rmsd(out, helices=data.helices)
    elif kind == "state_network":
        print("Plotting state network")
        data.mpp.plot.state_network(out)
    elif kind == "macro_feature":
        data.mpp.plot.macro_feature(out)
    elif kind == "stochastic_state_similarity":
        data.mpp.plot.stochastic_state_similarity(out)
    elif kind == "relative_implied_timescales":
        data.mpp.plot.relative_implied_timescales(out)
    elif kind == "transition_matrix":
        data.mpp.plot.transition_matrix(out)
    elif kind == "transition_time":
        data.mpp.plot.transition_time(out)
    else:
        raise ValueError(f"Unknown plot kind: {kind}")

200 

201 

def draw_random_frames(mpp, data):
    """Draw ``data.n_random_frames`` random frames into ``<lumping_dir>/random_frames``.

    Loads Z from ``<lumping_dir>/Z.npy`` if ``mpp`` has none yet, points
    ``mpp`` at the topology/xtc files from ``data``, and returns ``mpp``.
    """
    if mpp.Z is None:
        mpp.load_Z(os.path.join(data.lumping_dir, "Z.npy"))
    # Build the directory with a real path join. Plain string concatenation
    # without a separator created "<dir>random_frames" instead of
    # "<dir>/random_frames".
    frames_dir = Path(data.lumping_dir) / "random_frames"
    frames_dir.mkdir(parents=True, exist_ok=True)
    mpp.topology_file = data.top
    mpp.xtc_trajectory_file = data.xtc
    mpp.draw_random_frames(frames_dir, n=data.n_random_frames)
    return mpp

216 

217 

def write_random_frames_indices(mpp, out, n):
    """Delegate to ``mpp.draw_random_frames_indices`` with ``out`` as a Path.

    ``n`` is forwarded unchanged; the CLI describes it as the number of
    random frames per macrostate.
    """
    destination = Path(out)
    mpp.draw_random_frames_indices(destination, n)

221 

222 

def parse_args():
    """Parse this script's command line and return the ``argparse.Namespace``.

    ``data_specification`` is an already-opened file object
    (``argparse.FileType``); callers that only need its ``.name`` should
    close it.
    """
    parser = argparse.ArgumentParser(
        prog="Perform MPP on MD simulation data",
        description=(
            "This program allows for the analysis of MD data utilizing the "
            "most probable path algorithm. It allows for easy plotting of "
            "different quality measures."
        ),
    )
    parser.add_argument(
        "data_specification",
        help=(
            "yaml file containing specification of files and parameters of "
            "the simulation"
        ),
        type=argparse.FileType("r", encoding="latin-1"),
    )
    parser.add_argument(
        "d",
        help="dij (similarity) to be used, or 'gpcca' to run GPCCA instead of MPP.",
    )
    parser.add_argument(
        "g",
        help=(
            "gij (feature kernel: 'none' or 'JS') to be used; with d='gpcca' "
            "the number of macrostates or 'ref'."
        ),
    )
    parser.add_argument(
        "-o",
        "--out",
        metavar="PATH",
        help=(
            "Output path: the plot file for --plot, and the target for "
            "--draw-random and --get-least-moving-residues."
        ),
    )
    parser.add_argument(
        "-Z",
        metavar="PATH",
        help=(
            "Path of the Z matrix (.npy): loaded if it exists, otherwise "
            "computed and written there."
        ),
    )
    parser.add_argument(
        "--rmsd",
        metavar="PATH",
        help="Compute the RMSD and write it to PATH ('.npy' appended if missing).",
    )
    parser.add_argument(
        "--xtc-stride",
        metavar="N",
        type=int,
        help="Read every Nth frame of the xtc trajectory.",
    )
    parser.add_argument(
        "-r",
        "--draw-random",
        help="Draw N random frames for each macrostate",
        metavar="N",
        type=int,
    )
    parser.add_argument(
        "-p",
        "--plot",
        metavar="KIND",
        help=(
            "Generate one plot of the given kind, written to --out. Kinds: "
            "dendrogram, timescales, sankey, contacts, macrotraj, ck_test, "
            "rmsd, delta_rmsd, state_network, macro_feature, "
            "stochastic_state_similarity, relative_implied_timescales, "
            "transition_matrix, transition_time."
        ),
    )
    parser.add_argument(
        "--get-least-moving-residues",
        help="Write least moving residues for each macrostate to a file.",
    )
    return parser.parse_args()

276 

277 

def main():
    """Command-line entry point: run the lumping, then the requested extras."""
    import sys

    args = parse_args()

    # argparse.FileType already opened the spec file, but Data re-opens it by
    # name; close this handle so it is not leaked.
    spec_path = args.data_specification.name
    args.data_specification.close()

    # These outputs need a target path; fail before the (potentially long)
    # lumping instead of crashing on a None path afterwards.
    if (args.plot or args.draw_random) and args.out is None:
        sys.exit("error: -o/--out is required with --plot and --draw-random")

    data = Data(spec_path)
    data.setup_mpp(args.d, args.g)
    if args.xtc_stride is not None:
        # The command-line stride takes precedence over "xtc stride" in the YAML.
        data.mpp.xtc_stride = int(args.xtc_stride)
    if args.d == "gpcca":
        data.perform_gpcca(args.g, args.Z)
    else:
        data.perform_mpp(args.Z)

    if args.rmsd:
        data.get_rmsd(args.rmsd, overwrite=True)

    if args.plot:
        plot(data, args.out, kind=args.plot)

    if args.draw_random:
        write_random_frames_indices(data.mpp, args.out, args.draw_random)

    if args.get_least_moving_residues:
        data.mpp.write_least_moving_residues(args.get_least_moving_residues, args.out)


if __name__ == "__main__":
    main()