Coverage for MPT/run.py: 62%

164 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-11 11:20 +0200

1#!/usr/bin/env python 

2 

3import os 

4import yaml 

5from pathlib import Path 

6import argparse 

7 

8import numpy as np 

9import MPT 

10 

11 

class Data:
    """Inputs and MPP/MPT state for one simulation, described by a YAML file.

    Every file name in the YAML specification is resolved relative to its
    ``source`` entry.
    """

    # Cut-off applied to the raw multi-feature trajectory: values below it
    # count as formed contacts. Shared by the boolean feature trajectory and
    # the MPT object so both use the same definition.
    CONTACT_THRESHOLD = 0.45

    def __init__(self, yaml_file):
        """Load trajectories and parameters from ``yaml_file``.

        Required keys: ``source``, ``microstate trajectory``,
        ``multi feature trajectory``, ``cluster file``, ``frame length``,
        ``tlag``, ``pop_min`` and ``q_min``.
        Optional keys: ``limits``, ``topology file``, ``xtc file`` and
        ``helices`` (``stochastic`` and ``xtc stride`` are read later by
        the MPP methods).
        """
        with open(yaml_file, "r") as f:
            self.d = yaml.safe_load(f)

        self.source = self.d["source"]

        self.microtraj = np.loadtxt(
            os.path.join(self.source, self.d["microstate trajectory"]),
            dtype=np.uint16,
        )
        self.mtraj_raw = np.loadtxt(
            os.path.join(self.source, self.d["multi feature trajectory"])
        )
        # A missing "limits" key and an explicit null both mean "no limits".
        limits_file = self.d.get("limits")
        self.limits = (
            None
            if limits_file is None
            else np.loadtxt(os.path.join(self.source, limits_file), dtype=np.int_)
        )
        # Boolean contact trajectory and its mean over axis 1 (presumably the
        # fraction of formed contacts per frame; assumes frames x features --
        # confirm against the input format).
        self.mfeature_traj = self.mtraj_raw < self.CONTACT_THRESHOLD
        self.feature_traj = self.mfeature_traj.mean(axis=1)
        self.cluster = os.path.join(self.source, self.d["cluster file"])

        self.top = self._optional_path("topology file")
        self.xtc = self._optional_path("xtc file")
        helices_file = self._optional_path("helices")
        self.helices = (
            None if helices_file is None else np.loadtxt(helices_file, dtype=int)
        )

        self.frame_length = self.d["frame length"]
        self.tlag = self.d["tlag"]
        self.pop_min = self.d["pop_min"]
        self.q_min = self.d["q_min"]

        # Filled in by prepare_mpp / setup_mpp. lumping_dir is read by
        # draw_random_frames and is not set anywhere in this class.
        self.lumping_dir = None
        self.kernel = None
        self.feature_kernel = None
        self.mpp = None

        self.n_random_frames = 20
        # Cleared by prepare_mpp for dij "T" with no feature kernel and no
        # "stochastic" section; forwarded to the implied-timescales plot.
        self.use_ref = True

    def _optional_path(self, key):
        """Return ``key``'s file joined onto ``source``, or None if absent."""
        if key in self.d:
            return os.path.join(self.source, self.d[key])
        return None

    def prepare_mpp(self, dij, gij):
        """Build the MPT kernel for ``dij`` and the feature kernel for ``gij``.

        Args:
            dij: Similarity passed to ``MPT.kernel.MPTKernel``. If the YAML
                has a ``stochastic`` section, its ``method`` and ``param``
                are passed as well.
            gij: Feature kernel selector: ``"none"`` (no feature kernel),
                ``"q"`` (``FeatureKernel`` on ``feature_traj``) or ``"JS"``
                (``MultiFeatureKernel`` on ``mfeature_traj``).

        Raises:
            ValueError: If ``gij`` is not one of the accepted selectors.
        """
        if "stochastic" in self.d:
            kernel = MPT.kernel.MPTKernel(
                method=self.d["stochastic"]["method"],
                param=self.d["stochastic"]["param"],
                similarity=dij,
            )
        else:
            kernel = MPT.kernel.MPTKernel(similarity=dij)

        if gij == "none":
            feature_kernel = None
        elif gij == "q":
            feature_kernel = MPT.kernel.FeatureKernel(
                self.feature_traj,
                self.microtraj,
            )
        elif gij == "JS":
            feature_kernel = MPT.kernel.MultiFeatureKernel(
                self.mfeature_traj,
                self.microtraj,
            )
        else:
            raise ValueError(
                f"feature kernel must be 'none', 'q' or 'JS', got {gij!r}."
            )

        if dij == "T" and gij == "none" and "stochastic" not in self.d:
            self.use_ref = False

        self.kernel = kernel
        self.feature_kernel = feature_kernel

    def setup_mpp(self, dij, gij):
        """Create ``self.mpp`` (an ``MPT.MPT``) and attach optional file paths.

        Kernels are prepared via ``prepare_mpp`` unless ``dij == "gpcca"``.
        """
        if dij != "gpcca":
            self.prepare_mpp(dij, gij)
        self.mpp = MPT.MPT(
            self.microtraj,
            self.tlag,
            self.mtraj_raw,
            contact_threshold=self.CONTACT_THRESHOLD,
            macrostate_thresholds=(self.pop_min, self.q_min),
            limits=self.limits,
            quiet=True,
        )
        # top/xtc are None when the YAML omits them, and os.path.exists(None)
        # raises TypeError, so test for None first.
        if self.top is not None and os.path.exists(self.top):
            self.mpp.topology_file = self.top
        if self.xtc is not None and os.path.exists(self.xtc):
            self.mpp.xtc_trajectory_file = self.xtc
        self.mpp.xtc_stride = self.d.get("xtc stride", None)
        self.mpp.frame_length = self.frame_length

    def perform_mpp(self, out=None, overwrite=False):
        """Run MPP lumping, or load a previous result.

        Args:
            out: Path of the Z matrix (``Z.npy``). If it exists and
                ``overwrite`` is False it is loaded instead of recomputed;
                otherwise the fresh result is saved there. ``None`` computes
                without loading or saving.
            overwrite: Recompute even if ``out`` already exists.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.from_Z(out)
            return
        if out is not None:
            # Create the target directory first so a bad path fails before
            # the lumping runs.
            Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True)
        self.mpp.mpt(
            self.kernel,
            feature_kernel=self.feature_kernel,
            n=self.d["stochastic"]["n"] if "stochastic" in self.d else 1,
        )
        if out is not None:
            self.mpp.save_Z(out)

    def perform_gpcca(self, n_macrostates, out=None, overwrite=False):
        """Run GPCCA lumping, or load a previous result.

        Args:
            n_macrostates: Number of macrostates as an int or numeric
                string, or ``"ref"`` to take it from ``self.mpp.reference``.
            out: Optional Z matrix path, loaded if present (unless
                ``overwrite``) and otherwise written after lumping.
            overwrite: Recompute even if ``out`` already exists.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.from_Z(out)
            return
        if n_macrostates == "ref":
            n_macrostates = self.mpp.reference.n_macrostates[0]
        elif isinstance(n_macrostates, str):
            # Command-line values arrive as text; GPCCA needs an integer.
            n_macrostates = int(n_macrostates)
        self.mpp.gpcca(n_macrostates)
        if out is not None:
            Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True)
            self.mpp.save_Z(out)

    def get_rmsd(self, out, overwrite=False):
        """Load RMSD values from ``out``, or have ``self.mpp`` save them there.

        ``.npy`` is appended to ``out`` if missing. With ``overwrite=True``
        ``save_rmsd`` is always called.
        """
        if not out.endswith(".npy"):
            out += ".npy"
        if os.path.exists(out) and not overwrite:
            self.mpp.load_rmsd(out)
        else:
            self.mpp.save_rmsd(out)

154 

155 

def plot(data, out, kind="dendrogram", scale=1):
    """Render a single diagnostic plot of ``data.mpp`` into ``out``.

    Supported ``kind`` values: dendrogram, timescales, sankey, contacts,
    macrotraj, ck_test, rmsd, delta_rmsd, state_network, macro_feature,
    stochastic_state_similarity, relative_implied_timescales,
    transition_matrix and transition_time.

    ``scale`` is only forwarded to the dendrogram, timescales, sankey and
    contacts plots. The rmsd and delta_rmsd plots first load or write
    ``rmsd.npy`` in the directory of ``out``.

    Raises:
        ValueError: If ``kind`` is not a supported plot kind.
    """

    def _dendrogram():
        print("Plotting dendrogram")
        data.mpp.plot(out, scale=scale, offset=0.0)

    def _timescales():
        if "n timescales" in data.d:
            data.mpp.calc_timescales(data.d["n timescales"])
        data.mpp.plot_implied_timescales(out, scale=scale, use_ref=data.use_ref)

    def _macrotraj():
        # row_length is passed as a fraction: 1/len(limits) when limits are
        # known, otherwise a fixed 1/6.
        n_rows = 6 if data.limits is None else len(data.limits)
        data.mpp.plot_macrotraj(out, row_length=1 / n_rows)

    def _with_rmsd(method_name):
        data.get_rmsd(os.path.join(os.path.dirname(out), "rmsd.npy"))
        getattr(data.mpp, method_name)(out, helices=data.helices)

    def _state_network():
        print("Plotting state network")
        data.mpp.plot_state_network(out)

    dispatch = (
        ("dendrogram", _dendrogram),
        ("timescales", _timescales),
        ("sankey", lambda: data.mpp.plot_sankey(out, scale=scale)),
        (
            "contacts",
            lambda: data.mpp.plot_contact_rep(data.cluster, out, scale=scale),
        ),
        ("macrotraj", _macrotraj),
        ("ck_test", lambda: data.mpp.plot_ck_test(out)),
        ("rmsd", lambda: _with_rmsd("plot_rmsd")),
        ("delta_rmsd", lambda: _with_rmsd("plot_delta_rmsd")),
        ("state_network", _state_network),
        ("macro_feature", lambda: data.mpp.plot_macro_feature(out)),
        (
            "stochastic_state_similarity",
            lambda: data.mpp.plot_stochastic_state_similarity(out),
        ),
        (
            "relative_implied_timescales",
            lambda: data.mpp.plot_relative_implied_timescales(out),
        ),
        ("transition_matrix", lambda: data.mpp.plot_transition_matrix(out)),
        ("transition_time", lambda: data.mpp.plot_transition_time(out)),
    )
    for candidate, handler in dispatch:
        if kind == candidate:
            handler()
            return
    raise ValueError(f"Unknown plot kind: {kind}")

203 

204 

def draw_random_frames(mpt, data):
    """Draw ``data.n_random_frames`` random frames with ``mpt``.

    If ``mpt.Z`` is None, ``Z.npy`` is first loaded from ``data.lumping_dir``.
    The topology/xtc paths of ``data`` are attached to ``mpt``, and the
    frames are written into ``<lumping_dir>/random_frames/``, which is
    created if needed.

    Returns:
        The same ``mpt`` object.
    """
    if mpt.Z is None:
        mpt.from_Z(os.path.join(data.lumping_dir, "Z.npy"))
    # Join rather than concatenate, so a lumping_dir without a trailing
    # separator still yields <lumping_dir>/random_frames/. The final "" keeps
    # a trailing separator on the directory handed to mpt.
    frames_dir = os.path.join(data.lumping_dir, "random_frames", "")
    Path(frames_dir).mkdir(parents=True, exist_ok=True)
    mpt.topology_file = data.top
    mpt.xtc_trajectory_file = data.xtc
    mpt.draw_random_frames(frames_dir, n=data.n_random_frames)
    return mpt

217 

218 

def write_random_frames_indices(mpt, out, n):
    """Create directory ``out`` if needed and have ``mpt`` write ``n`` random
    frame indices into it (the ``-r/--draw-random`` option)."""
    target_dir = Path(out)
    target_dir.mkdir(parents=True, exist_ok=True)
    mpt.draw_random_frames_indices(out, n)

222 

223 

def parse_args():
    """Define and parse the command line of ``MPT.run``.

    Returns:
        argparse.Namespace with ``data_specification`` (an open text file),
        ``d``, ``g``, ``out``, ``Z``, ``rmsd``, ``xtc_stride`` (int or None),
        ``draw_random`` (int or None), ``plot`` and
        ``get_least_moving_residues``.
    """
    parser = argparse.ArgumentParser(
        prog="Perform MPP on MD simulation data",
        description=(
            "This program allows for the analysis of MD data utilizing the "
            "most probable path algorithm. It allows for easy plotting of "
            "different quality measures."
        ),
    )
    parser.add_argument(
        "data_specification",
        help=(
            "yaml file containing specification of files and parameters of "
            "the simulation"
        ),
        type=argparse.FileType("r", encoding="latin-1"),
    )
    parser.add_argument(
        "d",
        help=(
            "dij to be used (similarity of the MPT kernel), or 'gpcca' to "
            "lump with GPCCA instead."
        ),
    )
    parser.add_argument(
        "g",
        help=(
            "gij to be used (feature kernel: none, q or JS). With d=gpcca: "
            "the number of macrostates, or 'ref'."
        ),
    )
    parser.add_argument(
        "-o",
        "--out",
        help=(
            "Output path: the figure file for --plot, the directory for "
            "--draw-random."
        ),
    )
    parser.add_argument(
        "-Z",
        metavar="PATH",
        help=(
            "Path of the Z matrix (.npy). Loaded if it exists, otherwise "
            "computed and written there."
        ),
    )
    parser.add_argument(
        "--rmsd",
        metavar="PATH",
        help="Generate and write RMSD to this file (.npy is appended if missing).",
    )
    parser.add_argument(
        "--xtc-stride",
        help="Read every Nth frame.",
        metavar="N",
        type=int,
    )
    parser.add_argument(
        "-r",
        "--draw-random",
        help="Draw N random frames for each macrostate",
        metavar="N",
        type=int,
    )
    parser.add_argument(
        "-p",
        "--plot",
        metavar="KIND",
        help=(
            "Generate one plot of the given kind, e.g. dendrogram, contacts, "
            "sankey, rmsd, macrotraj, timescales, ck_test."
        ),
    )
    parser.add_argument(
        "--get-least-moving-residues",
        help="Write least moving residues for each macrostate to a file.",
    )
    return parser.parse_args()

277 

278 

def main():
    """Command-line entry point: lump the data, then produce requested outputs."""
    print("MPT.run running")
    args = parse_args()

    # argparse.FileType already opened the spec file, but Data re-opens it by
    # name; close this handle rather than leaking it.
    spec_path = args.data_specification.name
    args.data_specification.close()

    data = Data(spec_path)
    data.setup_mpp(args.d, args.g)
    if args.xtc_stride is not None:
        # The command-line value replaces the "xtc stride" config entry that
        # setup_mpp applied.
        data.mpp.xtc_stride = int(args.xtc_stride)

    if args.d == "gpcca":
        # With gpcca the second positional argument is the macrostate count:
        # "ref" or an integer given as text.
        n_macrostates = args.g if args.g == "ref" else int(args.g)
        data.perform_gpcca(n_macrostates, args.Z)
    else:
        data.perform_mpp(args.Z)

    if args.rmsd:
        data.get_rmsd(args.rmsd, overwrite=True)

    if args.plot:
        plot(data, args.out, kind=args.plot)

    if args.draw_random:
        write_random_frames_indices(data.mpp, args.out, args.draw_random)

    if args.get_least_moving_residues:
        data.mpp.write_least_moving_residues(args.get_least_moving_residues, args.out)

303 

304 

if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()