Coverage for MPT/run.py: 62%
164 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:20 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 11:20 +0200
1#!/usr/bin/env python
3import os
4import yaml
5from pathlib import Path
6import argparse
8import numpy as np
9import MPT
class Data:
    """Load one analysis configuration and drive the MPT lumping it describes.

    The YAML file is parsed into ``self.d``. Keys read unconditionally by
    this class are ``source``, ``microstate trajectory``,
    ``multi feature trajectory``, ``cluster file``, ``frame length``,
    ``tlag``, ``pop_min`` and ``q_min``. The keys ``limits``,
    ``topology file``, ``xtc file``, ``helices``, ``stochastic`` and
    ``xtc stride`` are optional. Every file name is resolved relative to
    ``source``.
    """

    # Threshold applied to the raw multi-feature trajectory (values below it
    # become True in ``mfeature_traj``) and handed to MPT.MPT as
    # ``contact_threshold``. Defined once so both uses cannot drift apart.
    CONTACT_THRESHOLD = 0.45

    def __init__(self, yaml_file):
        """Parse ``yaml_file`` and load every trajectory/array it refers to."""
        with open(yaml_file, "r") as f:
            self.d = yaml.safe_load(f)

        self.source = self.d["source"]

        self.microtraj = np.loadtxt(
            os.path.join(self.source, self.d["microstate trajectory"]),
            dtype=np.uint16,
        )
        self.mtraj_raw = np.loadtxt(
            os.path.join(self.source, self.d["multi feature trajectory"])
        )

        # A missing ``limits`` key is treated like an explicit null.
        limits_path = self._source_path("limits")
        self.limits = (
            None if limits_path is None else np.loadtxt(limits_path, dtype=np.int_)
        )

        self.mfeature_traj = self.mtraj_raw < self.CONTACT_THRESHOLD
        self.feature_traj = self.mfeature_traj.mean(axis=1)
        self.cluster = os.path.join(self.source, self.d["cluster file"])

        self.top = self._source_path("topology file")
        self.xtc = self._source_path("xtc file")

        helices_path = self._source_path("helices")
        self.helices = (
            None if helices_path is None else np.loadtxt(helices_path, dtype=int)
        )

        self.frame_length = self.d["frame length"]
        self.tlag = self.d["tlag"]
        self.pop_min = self.d["pop_min"]
        self.q_min = self.d["q_min"]

        # Filled in later by prepare_mpp / setup_mpp (lumping_dir is only
        # read by code outside this class).
        self.lumping_dir = None
        self.kernel = None
        self.feature_kernel = None
        self.mpp = None

        self.n_random_frames = 20
        self.use_ref = True

    def _source_path(self, key):
        """Return ``source``/``d[key]``, or None if the key is absent or null."""
        name = self.d.get(key)
        if name is None:
            return None
        return os.path.join(self.source, name)

    def prepare_mpp(self, dij, gij):
        """Build the kernel and feature kernel and store them on the instance.

        dij: passed to MPT.kernel.MPTKernel as ``similarity``.
        gij: "none" (no feature kernel), "q" (FeatureKernel on
            ``feature_traj``) or "JS" (MultiFeatureKernel on
            ``mfeature_traj``).

        Raises ValueError for any other ``gij``.
        """
        if "stochastic" in self.d:
            stochastic = self.d["stochastic"]
            kernel = MPT.kernel.MPTKernel(
                method=stochastic["method"],
                param=stochastic["param"],
                similarity=dij,
            )
        else:
            kernel = MPT.kernel.MPTKernel(similarity=dij)

        if gij == "none":
            feature_kernel = None
        elif gij == "q":
            feature_kernel = MPT.kernel.FeatureKernel(
                self.feature_traj,
                self.microtraj,
            )
        elif gij == "JS":
            feature_kernel = MPT.kernel.MultiFeatureKernel(
                self.mfeature_traj,
                self.microtraj,
            )
        else:
            raise ValueError("feature kernel must be None, q or JS.")

        # With dij "T", no feature kernel and no stochastic section the
        # reference is switched off for the implied-timescale plot.
        if dij == "T" and gij == "none" and "stochastic" not in self.d:
            self.use_ref = False

        self.kernel = kernel
        self.feature_kernel = feature_kernel

    def setup_mpp(self, dij, gij):
        """Create ``self.mpp``; kernels are prepared unless ``dij`` is "gpcca"."""
        if dij != "gpcca":
            self.prepare_mpp(dij, gij)
        self.mpp = MPT.MPT(
            self.microtraj,
            self.tlag,
            self.mtraj_raw,
            contact_threshold=self.CONTACT_THRESHOLD,
            macrostate_thresholds=(self.pop_min, self.q_min),
            limits=self.limits,
            quiet=True,
        )
        self._configure_mpp()

    def _configure_mpp(self):
        """Copy structure-file paths, stride and frame length onto ``self.mpp``."""
        # top/xtc are None when the config omits them; os.path.exists(None)
        # raises TypeError, so guard before touching the file system.
        if self.top is not None and os.path.exists(self.top):
            self.mpp.topology_file = self.top
        if self.xtc is not None and os.path.exists(self.xtc):
            self.mpp.xtc_trajectory_file = self.xtc
        self.mpp.xtc_stride = self.d.get("xtc stride", None)
        self.mpp.frame_length = self.frame_length

    def perform_mpp(self, out, overwrite=False):
        """Run the lumping, or reload it from ``out`` if that file exists.

        out: path of the Z file (e.g. ``Z.npy``). ``None`` runs the lumping
            without loading or saving, mirroring perform_gpcca.
        overwrite: recompute even if ``out`` already exists.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.from_Z(out)
            return

        if out is not None:
            # Create the target directory before the computation so an
            # unwritable location fails early.
            Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True)
        self.mpp.mpt(
            self.kernel,
            feature_kernel=self.feature_kernel,
            n=self.d["stochastic"]["n"] if "stochastic" in self.d else 1,
        )
        if out is not None:
            self.mpp.save_Z(out)

    def perform_gpcca(self, n_macrostates, out=None, overwrite=False):
        """Run GPCCA, or reload an existing result from ``out``.

        n_macrostates: int, a string holding an int, or "ref" to take
            ``self.mpp.reference.n_macrostates[0]``.
        out: optional path of the Z file; nothing is saved when None.
        overwrite: recompute even if ``out`` already exists.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.from_Z(out)
            return

        if n_macrostates == "ref":
            n_macrostates = self.mpp.reference.n_macrostates[0]
        elif isinstance(n_macrostates, str):
            # Values coming from the command line arrive as strings.
            n_macrostates = int(n_macrostates)
        self.mpp.gpcca(n_macrostates)
        if out is not None:
            self.mpp.save_Z(out)

    def get_rmsd(self, out, overwrite=False):
        """Load the RMSD from ``out`` if present, otherwise save it there.

        out: target file name; ``.npy`` is appended when missing.
        overwrite: call ``save_rmsd`` even if the file already exists.
        """
        if not out.endswith(".npy"):
            out += ".npy"
        if os.path.exists(out) and not overwrite:
            self.mpp.load_rmsd(out)
        else:
            self.mpp.save_rmsd(out)
def plot(data, out, kind="dendrogram", scale=1):
    """Render one plot of ``data.mpp`` to ``out``.

    kind: one of dendrogram, timescales, sankey, contacts, macrotraj,
        ck_test, rmsd, delta_rmsd, state_network, macro_feature,
        stochastic_state_similarity, relative_implied_timescales,
        transition_matrix, transition_time.
    scale: forwarded to the dendrogram, timescales, sankey and contacts
        plots only.

    Raises ValueError for any other ``kind``.
    """
    # Plots whose method takes nothing but the output path.
    path_only = {
        "ck_test": "plot_ck_test",
        "macro_feature": "plot_macro_feature",
        "stochastic_state_similarity": "plot_stochastic_state_similarity",
        "relative_implied_timescales": "plot_relative_implied_timescales",
        "transition_matrix": "plot_transition_matrix",
        "transition_time": "plot_transition_time",
    }
    # Plots that first need the RMSD stored next to the output file.
    rmsd_based = {
        "rmsd": "plot_rmsd",
        "delta_rmsd": "plot_delta_rmsd",
    }

    if kind in path_only:
        getattr(data.mpp, path_only[kind])(out)
        return

    if kind in rmsd_based:
        rmsd_file = os.path.join(os.path.dirname(out), "rmsd.npy")
        data.get_rmsd(rmsd_file)
        getattr(data.mpp, rmsd_based[kind])(out, helices=data.helices)
        return

    if kind == "dendrogram":
        print("Plotting dendrogram")
        data.mpp.plot(out, scale=scale, offset=0.0)
        return

    if kind == "timescales":
        if "n timescales" in data.d:
            data.mpp.calc_timescales(data.d["n timescales"])
        data.mpp.plot_implied_timescales(out, scale=scale, use_ref=data.use_ref)
        return

    if kind == "sankey":
        data.mpp.plot_sankey(out, scale=scale)
        return

    if kind == "contacts":
        data.mpp.plot_contact_rep(data.cluster, out, scale=scale)
        return

    if kind == "macrotraj":
        # One row per entry in limits when limits are given, otherwise six rows.
        fraction = 1 / 6 if data.limits is None else 1 / len(data.limits)
        data.mpp.plot_macrotraj(out, row_length=fraction)
        return

    if kind == "state_network":
        print("Plotting state network")
        data.mpp.plot_state_network(out)
        return

    raise ValueError(f"Unknown plot kind: {kind}")
def draw_random_frames(mpt, data):
    """Draw ``data.n_random_frames`` random frames into ``lumping_dir/random_frames/``.

    If ``mpt.Z`` is None the Z file ``Z.npy`` in ``data.lumping_dir`` is
    loaded first. The topology and xtc paths of ``data`` are copied onto
    ``mpt`` before drawing.

    Returns the same ``mpt`` object.
    """
    if mpt.Z is None:
        mpt.from_Z(os.path.join(data.lumping_dir, "Z.npy"))

    # Join the components instead of concatenating strings, so a lumping_dir
    # without a trailing separator no longer produces "<dir>random_frames/".
    # The trailing "" keeps the trailing separator that was passed on before.
    frames_dir = os.path.join(data.lumping_dir, "random_frames", "")
    Path(frames_dir).mkdir(parents=True, exist_ok=True)

    mpt.topology_file = data.top
    mpt.xtc_trajectory_file = data.xtc
    mpt.draw_random_frames(frames_dir, n=data.n_random_frames)
    return mpt
def write_random_frames_indices(mpt, out, n):
    """Ensure directory ``out`` exists, then call ``mpt.draw_random_frames_indices(out, n)``."""
    target_dir = Path(out)
    target_dir.mkdir(parents=True, exist_ok=True)
    mpt.draw_random_frames_indices(out, n)
def parse_args(argv=None):
    """Parse the command line of the MPT runner.

    argv: optional list of argument strings. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, which is the previous behaviour;
        passing a list allows the parser to be driven programmatically.

    Returns the argparse.Namespace. ``data_specification`` is an already
    opened file object (argparse.FileType); the caller owns that handle.
    """
    parser = argparse.ArgumentParser(
        prog="Perform MPP on MD simulation data",
        description=(
            "This program allows for the analysis of MD data utilizing the "
            "most probable path algorithm. It allows for easy plotting of "
            "different quality measures."
        ),
    )

    # Positional arguments.
    parser.add_argument(
        "data_specification",
        help=(
            "yaml file containing specification of files and parameters of "
            "the simulation"
        ),
        type=argparse.FileType("r", encoding="latin-1"),
    )
    parser.add_argument("d", help=("dij to be used."))
    parser.add_argument("g", help=("gij to be used."))

    # Optional arguments.
    parser.add_argument(
        "-o",
        "--out",
        help=("Override output directory set by config file"),
    )
    parser.add_argument(
        "-Z",
        help="Perform MPP and write the Z matrix.",
    )
    parser.add_argument(
        "--rmsd",
        help="Generate and write RMSD to file.",
    )
    parser.add_argument(
        "--xtc-stride",
        help="Read every nth frame.",
    )
    parser.add_argument(
        "-r",
        "--draw-random",
        help="Draw N random frames for each macrostate",
        metavar="N",
        type=int,
    )
    parser.add_argument(
        "-p",
        "--plot",
        help="Generate listed plots. Possible arguments include dendrogram, contacts, sankey, rmsd, macrotraj, timescales and more. (not yet implemented)",
    )
    parser.add_argument(
        "--get-least-moving-residues",
        help="Write least moving residues for each macrostate to a file.",
    )
    return parser.parse_args(argv)
def main():
    """Command-line entry point: load the config, run the lumping, then do
    whichever optional outputs (RMSD, plot, random frames, least moving
    residues) were requested.
    """
    print("MPT.run running")
    args = parse_args()

    # argparse.FileType has already opened the config file, but only its
    # path is needed here (Data opens the file itself), so release the
    # handle instead of leaking it for the lifetime of the process.
    config_path = args.data_specification.name
    args.data_specification.close()

    # Parse input files
    data = Data(config_path)
    data.setup_mpp(args.d, args.g)
    if args.d == "gpcca":
        # In gpcca mode the second positional argument carries n_macrostates.
        data.perform_gpcca(args.g, args.Z)
    else:
        data.perform_mpp(args.Z)

    if args.rmsd:
        data.get_rmsd(args.rmsd, overwrite=True)

    if args.plot:
        plot(data, args.out, kind=args.plot)

    if args.draw_random:
        write_random_frames_indices(data.mpp, args.out, args.draw_random)

    if args.get_least_moving_residues:
        data.mpp.write_least_moving_residues(args.get_least_moving_residues, args.out)


if __name__ == "__main__":
    main()