Coverage for MPP/run.py: 86%
161 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-22 12:22 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-22 12:22 +0200
1#!/usr/bin/env python
3import os
4import yaml
5from pathlib import Path
6import argparse
8import numpy as np
9import MPP
class Data:
    """Configuration and state for one MPP (most probable path) lumping run.

    Reads a YAML specification with ``yaml.safe_load``, loads the input files it
    references and wraps construction/execution of an ``MPP.Lumping`` object.
    Every file path in the YAML is joined onto its ``source`` entry.

    Required YAML keys: ``source``, ``microstate trajectory``,
    ``multi feature trajectory``, ``cluster file``, ``frame length``,
    ``lagtime``, ``pop_thr``, ``q_min``.
    Optional YAML keys: ``limits``, ``topology file``, ``xtc file``,
    ``helices``, ``stochastic`` (with ``method``, ``param``, ``n``),
    ``xtc stride``.
    """

    # Raw feature values strictly below this threshold count as formed
    # contacts; the same value is handed to MPP.Lumping as contact_threshold.
    CONTACT_THRESHOLD = 0.45

    def __init__(self, yaml_file):
        """Load the YAML specification and every input file it references.

        Parameters
        ----------
        yaml_file : str or path-like
            Path of the YAML specification.
        """
        with open(yaml_file, "r") as f:
            self.d = yaml.safe_load(f)

        self.source = self.d["source"]

        self.microstate_trajectory = np.loadtxt(
            os.path.join(self.source, self.d["microstate trajectory"]), dtype=np.uint16
        )
        self.multi_state_trajectory_raw = np.loadtxt(
            os.path.join(self.source, self.d["multi feature trajectory"])
        )
        limits_path = self._optional_path("limits")
        self.limits = (
            None if limits_path is None else np.loadtxt(limits_path, dtype=np.int_)
        )
        # Boolean contact matrix, and its mean over axis 1 (presumably the
        # fraction of formed contacts per frame -- confirm the axis layout).
        self.multi_feature_trajectory = (
            self.multi_state_trajectory_raw < self.CONTACT_THRESHOLD
        )
        self.feature_trajectory = self.multi_feature_trajectory.mean(axis=1)
        self.cluster = os.path.join(self.source, self.d["cluster file"])

        # Optional inputs: None when the key is absent (or null) in the YAML.
        self.top = self._optional_path("topology file")
        self.xtc = self._optional_path("xtc file")
        helices_path = self._optional_path("helices")
        self.helices = (
            None if helices_path is None else np.loadtxt(helices_path, dtype=int)
        )

        self.frame_length = self.d["frame length"]
        self.lagtime = self.d["lagtime"]
        self.pop_thr = self.d["pop_thr"]
        self.q_min = self.d["q_min"]

        # Filled in later by prepare_mpp / setup_mpp (lumping_dir by callers).
        self.lumping_dir = None
        self.kernel = None
        self.feature_kernel = None
        self.mpp = None

        self.n_random_frames = 20
        # Cleared by prepare_mpp for the plain "T"/"none" configuration.
        self.use_ref = True

    def _optional_path(self, key):
        """Return ``source`` joined with YAML entry *key*, or None if unset/null."""
        value = self.d.get(key)
        return None if value is None else os.path.join(self.source, value)

    def prepare_mpp(self, dij, gij):
        """Build the lumping kernel and the optional feature kernel.

        Parameters
        ----------
        dij : str
            Similarity measure passed to ``MPP.kernel.LumpingKernel``.
        gij : str
            ``"none"`` for no feature kernel, or ``"JS"`` for an
            ``MPP.kernel.FeatureKernel`` built from the contact trajectory.

        Raises
        ------
        ValueError
            If *gij* is not ``"none"`` or ``"JS"``.
        """
        # Validate first so a bad argument fails before any kernel is built.
        if gij not in ("none", "JS"):
            raise ValueError(f"feature kernel must be 'none' or 'JS', got {gij!r}.")

        if "stochastic" in self.d:
            stochastic = self.d["stochastic"]
            kernel = MPP.kernel.LumpingKernel(
                method=stochastic["method"],
                param=stochastic["param"],
                similarity=dij,
            )
        else:
            kernel = MPP.kernel.LumpingKernel(
                similarity=dij,
            )

        if gij == "none":
            feature_kernel = None
        else:
            feature_kernel = MPP.kernel.FeatureKernel(
                self.multi_feature_trajectory,
                self.microstate_trajectory,
            )

        # use_ref is later forwarded to the implied-timescales plot. This
        # configuration presumably *is* the "T" reference run, so comparing
        # against the reference is switched off.
        if dij == "T" and gij == "none" and "stochastic" not in self.d:
            self.use_ref = False

        self.kernel = kernel
        self.feature_kernel = feature_kernel

    def setup_mpp(self, dij, gij):
        """Create ``self.mpp`` (an ``MPP.Lumping``) for the given dij / gij.

        For ``dij == "gpcca"`` no kernels are prepared; ``perform_gpcca`` is
        then used instead of ``perform_mpp``.
        """
        if dij != "gpcca":
            self.prepare_mpp(dij, gij)
        self.mpp = MPP.Lumping(
            self.microstate_trajectory,
            self.lagtime,
            self.multi_state_trajectory_raw,
            contact_threshold=self.CONTACT_THRESHOLD,
            pop_thr=self.pop_thr,
            q_min=self.q_min,
            limits=self.limits,
            quiet=True,
        )
        # top/xtc are None when missing from the YAML; os.path.exists(None)
        # raises TypeError, so test for None first.
        if self.top is not None and os.path.exists(self.top):
            self.mpp.topology_file = self.top
        if self.xtc is not None and os.path.exists(self.xtc):
            self.mpp.xtc_trajectory_file = self.xtc
        self.mpp.xtc_stride = self.d.get("xtc stride", None)
        self.mpp.frame_length = self.frame_length

    def perform_mpp(self, out, overwrite=False):
        """Run MPP lumping, or load a previously saved Z matrix.

        Parameters
        ----------
        out : str or None
            Z matrix file (``Z.npy``). If it exists and *overwrite* is False,
            Z is loaded from it. Otherwise MPP runs and Z is saved there
            (parent directories are created). ``None`` runs MPP without
            loading or saving.
        overwrite : bool
            Recompute even if *out* already exists.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.load_Z(out)
            return
        if out is not None:
            # Create the output directory before the potentially long run.
            Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True)
        self.mpp.run_mpp(
            self.kernel,
            feature_kernel=self.feature_kernel,
            # Presumably the number of stochastic repetitions; 1 otherwise.
            n=self.d["stochastic"]["n"] if "stochastic" in self.d else 1,
        )
        if out is not None:
            self.mpp.save_Z(out)

    def perform_gpcca(self, n_macrostates="ref", out=None, overwrite=False):
        """Lump with GPCCA, or load a previously saved Z matrix.

        Parameters
        ----------
        n_macrostates : int or str
            Number of macrostates. ``"ref"`` takes
            ``self.mpp.reference.n_macrostates[0]`` (the reference, T, run);
            any other value (e.g. the CLI string ``"5"``) goes through ``int``.
        out : str or None
            Z matrix file; loaded if present (unless *overwrite*), else written
            after the run (parent directories are created).
        overwrite : bool
            Recompute even if *out* already exists.

        Raises
        ------
        ValueError
            If *n_macrostates* is neither ``"ref"`` nor convertible to int.
        """
        if out is not None and os.path.exists(out) and not overwrite:
            print("Loading existing Z")
            self.mpp.load_Z(out)
            return
        if n_macrostates == "ref":
            n_macrostates = self.mpp.reference.n_macrostates[0]
            print(f"n_macrostates: {n_macrostates}")
        else:
            # main() forwards the raw CLI string; gpcca presumably needs an int.
            n_macrostates = int(n_macrostates)
        self.mpp.gpcca(n_macrostates)
        if out is not None:
            Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True)
            self.mpp.save_Z(out)

    def get_rmsd(self, out, overwrite=False):
        """Load RMSD data from *out*, or have ``mpp.save_rmsd`` write it there.

        Parameters
        ----------
        out : str or path-like
            Target file; ``.npy`` is appended if missing.
        overwrite : bool
            Call ``save_rmsd`` even if *out* already exists.
        """
        out = str(out)
        if not out.endswith(".npy"):
            out += ".npy"
        if os.path.exists(out) and not overwrite:
            self.mpp.load_rmsd(out)
        else:
            Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True)
            self.mpp.save_rmsd(out)
def plot(data, out, kind="dendrogram", scale=1):
    """Draw one diagnostic plot of an MPP run to *out*.

    Parameters
    ----------
    data : Data
        Configuration object whose ``mpp`` has already been run or loaded.
    out : str
        Output file of the figure. For ``"rmsd"`` and ``"delta_rmsd"`` the RMSD
        data are cached as ``rmsd.npy`` in the same directory.
    kind : str
        One of ``dendrogram``, ``timescales``, ``sankey``, ``contacts``,
        ``macrotraj`` (alias ``macrostate_trajectory``), ``ck_test``, ``rmsd``,
        ``delta_rmsd``, ``state_network``, ``macro_feature``,
        ``stochastic_state_similarity``, ``relative_implied_timescales``,
        ``transition_matrix``, ``transition_time``.
    scale : float
        Passed as ``scale=`` to the dendrogram, timescales, sankey and
        contacts plots; ignored by the others.

    Raises
    ------
    ValueError
        If *kind* is not one of the names above.
    """
    if kind == "dendrogram":
        print("Plotting dendrogram")
        data.mpp.plot.dendrogram(out, scale=scale, offset=0.0)
    elif kind == "timescales":
        if "n timescales" in data.d:
            data.mpp.calc_timescales(data.d["n timescales"])
        data.mpp.plot.implied_timescales(out, scale=scale, use_ref=data.use_ref)
    elif kind == "sankey":
        data.mpp.plot.sankey(out, scale=scale)
    elif kind == "contacts":
        data.mpp.plot.contact_rep(data.cluster, out, scale=scale)
    elif kind in ("macrotraj", "macrostate_trajectory"):
        # One row per entry of `limits` (presumably one per trajectory),
        # otherwise a fixed six rows. atleast_1d covers a 0-d array from a
        # single-entry limits file; an empty file keeps the default.
        row_length = 1 / 6
        if data.limits is not None:
            n_rows = len(np.atleast_1d(data.limits))
            if n_rows > 0:
                row_length = 1 / n_rows
        data.mpp.plot.macrostate_trajectory(out, row_length=row_length)
    elif kind == "ck_test":
        data.mpp.plot.ck_test(out)
    elif kind == "rmsd":
        data.get_rmsd(os.path.join(os.path.dirname(out), "rmsd.npy"))
        data.mpp.plot.rmsd(out, helices=data.helices)
    elif kind == "delta_rmsd":
        data.get_rmsd(os.path.join(os.path.dirname(out), "rmsd.npy"))
        data.mpp.plot.delta_rmsd(out, helices=data.helices)
    elif kind == "state_network":
        print("Plotting state network")
        data.mpp.plot.state_network(out)
    elif kind == "macro_feature":
        data.mpp.plot.macro_feature(out)
    elif kind == "stochastic_state_similarity":
        data.mpp.plot.stochastic_state_similarity(out)
    elif kind == "relative_implied_timescales":
        data.mpp.plot.relative_implied_timescales(out)
    elif kind == "transition_matrix":
        data.mpp.plot.transition_matrix(out)
    elif kind == "transition_time":
        data.mpp.plot.transition_time(out)
    else:
        raise ValueError(f"Unknown plot kind: {kind}")
def draw_random_frames(mpp, data):
    """Have *mpp* draw random frames per macrostate into ``<lumping_dir>/random_frames``.

    Loads Z from ``<lumping_dir>/Z.npy`` if *mpp* has none yet, points *mpp* at
    ``data.top`` / ``data.xtc``, and calls ``mpp.draw_random_frames`` with
    ``n=data.n_random_frames``.

    Returns
    -------
    The same *mpp* object.
    """
    if mpp.Z is None:
        mpp.load_Z(os.path.join(data.lumping_dir, "Z.npy"))
    # Join with pathlib so the directory created here is exactly the one passed
    # to draw_random_frames. Plain string concatenation drops the separator
    # when lumping_dir has no trailing slash.
    frames_dir = Path(data.lumping_dir) / "random_frames"
    frames_dir.mkdir(parents=True, exist_ok=True)
    mpp.topology_file = data.top
    mpp.xtc_trajectory_file = data.xtc
    mpp.draw_random_frames(frames_dir, n=data.n_random_frames)
    return mpp
def write_random_frames_indices(mpp, out, n):
    """Forward to ``mpp.draw_random_frames_indices`` with *out* as a ``Path``.

    Parameters
    ----------
    mpp : MPP.Lumping
        Object providing ``draw_random_frames_indices``.
    out : str or path-like
        Destination, converted to ``pathlib.Path`` before forwarding.
    n : int
        Number of random frames, forwarded unchanged.
    """
    destination = Path(out)
    mpp.draw_random_frames_indices(destination, n)
def parse_args():
    """Define this script's command-line interface and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        ``data_specification`` is an already opened text handle
        (``argparse.FileType``, latin-1); ``d`` and ``g`` are the raw dij / gij
        strings. Every option defaults to None.
    """
    cli = argparse.ArgumentParser(
        prog="Perform MPP on MD simulation data",
        description=(
            "This program allows for the analysis of MD data utilizing the "
            "most probable path algorithm. It allows for easy plotting of "
            "different quality measures."
        ),
    )
    add = cli.add_argument

    # Positional inputs: the YAML config followed by the dij / gij selectors.
    add(
        "data_specification",
        type=argparse.FileType("r", encoding="latin-1"),
        help=(
            "yaml file containing specification of files "
            "and parameters of the simulation"
        ),
    )
    add("d", help="dij to be used.")
    add("g", help="gij to be used.")

    # Optional outputs and actions.
    add("-o", "--out", help="Override output directory set by config file")
    add("-Z", help="Perform MPP and write the Z matrix.")
    add("--rmsd", help="Generate and write RMSD to file.")
    add("--xtc-stride", help="Read every nth frame.")
    add(
        "-r",
        "--draw-random",
        type=int,
        metavar="N",
        help="Draw N random frames for each macrostate",
    )
    add(
        "-p",
        "--plot",
        help=(
            "Generate listed plots. Possible arguments include dendrogram, "
            "contacts, sankey, rmsd, macrostate_trajectory, timescales and "
            "more. (not yet implemented)"
        ),
    )
    add(
        "--get-least-moving-residues",
        help="Write least moving residues for each macrostate to a file.",
    )
    return cli.parse_args()
def main():
    """Command-line entry point: compute or load a lumping, then run the extras.

    Reads the YAML config and sets up MPP (or GPCCA when ``d == "gpcca"``).
    It then computes or loads Z and optionally writes RMSD, one plot,
    random-frame indices and least-moving residues.
    """
    args = parse_args()

    # argparse.FileType already opened the config, but Data re-opens it by
    # path, so keep only the name and close the handle instead of leaking it.
    spec_path = args.data_specification.name
    args.data_specification.close()

    # write_random_frames_indices needs an output path; fail before the
    # possibly long lumping run rather than on Path(None) at the very end.
    if args.draw_random and args.out is None:
        raise SystemExit("--draw-random requires an output directory (-o/--out).")

    data = Data(spec_path)
    data.setup_mpp(args.d, args.g)
    if args.xtc_stride is not None:
        # The CLI value overrides the config's "xtc stride" set in setup_mpp.
        data.mpp.xtc_stride = int(args.xtc_stride)

    if args.d == "gpcca":
        # For GPCCA the second positional argument is the number of
        # macrostates (or "ref").
        data.perform_gpcca(args.g, args.Z)
    else:
        data.perform_mpp(args.Z)

    if args.rmsd:
        data.get_rmsd(args.rmsd, overwrite=True)

    if args.plot:
        plot(data, args.out, kind=args.plot)

    if args.draw_random:
        write_random_frames_indices(data.mpp, args.out, args.draw_random)

    if args.get_least_moving_residues:
        data.mpp.write_least_moving_residues(args.get_least_moving_residues, args.out)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()