Coverage for MPP/utils.py: 93%
206 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-22 11:26 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-22 11:26 +0200
1"""
2utils.py
3========
5Utilities for MPP.
6"""
8import numpy as np
9from numpy.typing import NDArray
10import mdtraj as md
11from tqdm import tqdm
def translate_trajectory(
    trajectory: NDArray[np.int_], map: NDArray[np.int_]
) -> NDArray[np.int_]:
    """
    Relabel a state trajectory according to a lookup array.

    trajectory (NDArray[np.int_]): original state trajectory
    map (NDArray[np.int_]): lookup array; the index is the original state and
        the value stored at that index is the new state

    returns the relabelled trajectory, stored as uint8, uint16 or uint32
    depending on the largest value in map. Entries of trajectory that are not
    an index of map stay 0.
    """
    # Pick the narrowest unsigned dtype that can hold the largest new label.
    largest_label = map.max()
    if largest_label < 2**8:
        label_dtype = np.uint8
    elif largest_label < 2**16:
        label_dtype = np.uint16
    else:
        label_dtype = np.uint32

    translated = np.zeros(trajectory.shape, dtype=label_dtype)
    for label in np.unique(map):
        # All original states that are mapped onto this label.
        source_states = np.nonzero(map == label)[0]
        translated[np.isin(trajectory, source_states)] = label
    return translated
def macrostate_tmat(tmat, macrostate_assignment, pop):
    """
    Coarse-grain a microstate transition matrix to macrostates.

    tmat: microstate transition matrix
    macrostate_assignment: one row per macrostate; each row is used to index
        the microstates belonging to that macrostate (presumably a boolean
        mask -- confirm against the caller)
    pop: microstate populations used to weight the rows of tmat

    returns the macrostate matrix with every column divided by its sum
    """
    n_macro = macrostate_assignment.shape[0]
    coarse = np.zeros((n_macro, n_macro), dtype=tmat.dtype.type)
    for row, members in enumerate(macrostate_assignment):
        # Population weights of the origin microstates as a column vector.
        weights = pop[members][..., np.newaxis]
        origin_rows = tmat[members]
        for col, targets in enumerate(macrostate_assignment):
            coarse[row, col] = (origin_rows[:, targets] * weights).sum()
    # NOTE(review): the normalisation uses column sums (axis 0); confirm this
    # is the orientation the callers expect.
    column_totals = coarse.sum(axis=0)
    return coarse / column_totals
def get_grid_format(n):
    """
    Return a grid shape (x, y) for n cells.

    y is the integer part of sqrt(n). x equals y when n is a perfect square;
    otherwise it is y + 1 if y * (y + 1) >= n, else y + 2.
    """
    root = np.sqrt(n)
    rows = int(root)
    cols = rows
    if rows < root:
        # Not a perfect square: one extra column if that suffices, else two.
        cols = rows + 1 if rows * (rows + 1) >= n else rows + 2
    return cols, rows
def gmrq(tmat, n_eigenvalues=3):
    """
    Generalized matrix Rayleigh quotient for a sequence of square matrices.

    tmat: sequence (or stacked array) of square matrices
    n_eigenvalues: number of leading eigenvalues to sum (default 3)

    returns a 1D float array with one score per matrix: the sum of the real
    parts of its n_eigenvalues largest eigenvalues
    """
    scores = np.zeros(len(tmat))
    for idx, matrix in enumerate(tmat):
        # The eigenvalues come back in no particular order, so they must be
        # sorted before the leading ones can be selected.
        eigenvalues = np.linalg.eigvals(matrix).real
        leading = np.sort(eigenvalues)[::-1][:n_eigenvalues]
        scores[idx] = leading.sum()
    return scores
def Z_to_linkage(Z):
    """
    Convert a Z matrix into a linkage table with 1-based state labels.

    The first three columns of Z are copied. Every reference to the cluster
    created in row i (label i + Z.shape[0] + 1) is replaced by the current
    second entry of row i, and finally both label columns are shifted by +1.
    """
    n_leaves = Z.shape[0] + 1
    table = Z[:, :3].copy()
    pairs = table[:, :2]  # view: writes go straight into table
    for step in range(table.shape[0]):
        pairs[pairs == n_leaves + step] = table[step, 1]
    pairs += 1
    return table
def linkage_to_Z(linkage, pop):
    """
    Convert a 1-based linkage table into a Z matrix with populations.

    linkage: array-like with one row per merge; the first two columns are
        1-based state labels and the third column is copied unchanged
    pop: numpy array holding the populations of the original states

    returns (Z, full_pop): Z has four columns (two 0-based state indices, the
    copied third column and the population of the merged state); full_pop
    holds the populations of the original states followed by those of the
    merged states.
    """
    linkage = np.array(linkage)
    n_merges = linkage.shape[0]
    n_states = n_merges + 1
    Z = np.zeros((n_merges, 4))
    Z[:, :3] = linkage[:, :3]
    Z[:, :2] -= 1  # switch to 0-based indices

    full_pop = np.zeros(2 * n_states - 1, dtype=pop.dtype.type)
    full_pop[:n_states] = pop
    for i in range(n_merges - 1):
        new_state = n_states + i
        old_state = Z[i, 1]
        full_pop[new_state] = full_pop[Z[i, :2].astype(int)].sum()
        # Later merges that refer to the absorbed label now refer to the
        # newly created state.
        later_pairs = Z[i + 1 :, :2]
        later_pairs[later_pairs == old_state] = new_state
    # The last merge is addressed directly instead of through the loop
    # variable, so a linkage with a single merge works as well.
    full_pop[-1] = full_pop[Z[-1, :2].astype(int)].sum()
    Z[:, 3] = full_pop[n_states:]
    return Z, full_pop
def merge_states(tmat, states, new_state, full_pop, reset_states=True):
    """
    Lump several states into new_state; tmat and full_pop are modified in place.

    tmat: square transition matrix that already has room for new_state
    states: list of state indices to merge
    new_state: index that receives the merged state
    full_pop: population array; entry new_state becomes the summed population
    reset_states: if True, rows and columns of the merged states are zeroed

    returns the (modified) tmat and full_pop
    """
    full_pop[new_state] = full_pop[states].sum()

    # Row of the merged state: population-weighted average of the member rows.
    weighted_rows = tmat[states] * full_pop[states, np.newaxis]
    tmat[new_state] = weighted_rows.sum(axis=0) / full_pop[new_state]

    # Column of the merged state: sum of the member columns.
    tmat[:, new_state] = tmat[:, states].sum(axis=1)

    if reset_states:
        tmat[states, :] = 0
        tmat[:, states] = 0
    return tmat, full_pop
def calc_full_tmat(tmat, pop, Z):
    """
    Calculate the full transition matrix and populations for a given Z matrix.

    tmat: square transition matrix of the original states
    pop: populations of the original states (stored as uint32)
    Z: Z matrix of one run (2D) or a stack of Z matrices (3D); the first two
        columns of each row are the indices of the states merged in that step

    returns (full_tmat, full_pop) with one entry per run; the original states
    come first, followed by one merged state per row of Z
    """
    # A single 2D Z matrix is treated as one run.
    if Z.ndim == 2:
        Z = Z.reshape((1, *Z.shape))

    n_states = tmat.shape[0]
    full_dim = 2 * n_states - 1
    n_runs = Z.shape[0]
    # Zero-initialised on purpose: merge_states reads complete rows and
    # columns, including slots of merged states that are not filled yet, so
    # uninitialised memory would enter the intermediate arithmetic.
    full_tmat = np.zeros((n_runs, full_dim, full_dim))
    full_pop = np.zeros((n_runs, full_dim), dtype=np.uint32)

    full_tmat[:, :n_states, :n_states] = tmat
    full_pop[:, :n_states] = pop

    for run, z in enumerate(Z):
        for i, (origin, target) in enumerate(z[:, :2].astype(int)):
            full_tmat[run], full_pop[run] = merge_states(
                full_tmat[run],
                [origin, target],
                n_states + i,
                full_pop[run],
                reset_states=False,
            )
    return full_tmat, full_pop
def Z_to_mask(Z):
    """
    Calculate the mask of present states for each lumping step.

    Z (Nx4): Z matrix; the first two columns hold the indices of the states
        merged in each step

    returns a boolean array of shape (N, 2N + 1). Row 0 marks the N + 1
    original states; every following row is the previous one with the two
    merged states switched off and the newly created state switched on.
    """
    n_merges = Z.shape[0]
    n_leaves = n_merges + 1
    masks = np.zeros((n_merges, 2 * n_leaves - 1), dtype=bool)
    masks[0, :n_leaves] = True
    merged_pairs = Z[:-1, :2].astype(int)
    for step, (first, second) in enumerate(merged_pairs, start=1):
        current = masks[step - 1].copy()
        current[[first, second]] = False
        current[n_leaves + step - 1] = True
        masks[step] = current
    return masks
def get_macrostate_assignment_from_tree(tree):
    """
    Build the macrostate assignment matrix from the leaves of a tree.

    tree: object whose leaves each provide name (used as column index) and
        assigned_macrostate (an object with name and feature)

    returns a boolean array with one row per distinct macrostate, ordered by
    descending feature, and one column per leaf, indexed by the leaf name
    """
    leaf_macrostates = [leaf.assigned_macrostate.name for leaf in tree.leaves]
    macrostates = {leaf.assigned_macrostate for leaf in tree.leaves}
    n_leaves = len(leaf_macrostates)

    name_and_feature = np.array([(m.name, m.feature) for m in macrostates])
    by_feature_desc = np.argsort(name_and_feature[:, 1])[::-1]
    # Translate macrostate names into row numbers of the result.
    row_of = {
        name: row for row, name in enumerate(name_and_feature[by_feature_desc, 0])
    }

    leaf_positions = np.arange(n_leaves)
    assignment = np.full((len(macrostates), n_leaves), False)
    assignment[[row_of[name] for name in leaf_macrostates], leaf_positions] = True

    # Columns are still in leaf order; rearrange them so that column k
    # belongs to the leaf named k.
    position_of_leaf = np.zeros(n_leaves, dtype=int)
    position_of_leaf[[leaf.name for leaf in tree.leaves]] = leaf_positions
    return assignment[:, position_of_leaf]
def similarity(ref, sto):
    """
    Return the similarity of a reference clustering and a set of runs.

    For every reference macrostate i and every run, the best match over all
    macrostates j of that run is stored in S[:, i, run]:
    S[0] -- weight of the intersection divided by the weight of the union
    S[1] -- weight of the intersection divided by the weight of reference i
    S[2] -- weight of the intersection divided by the weight of macrostate j
    All weights are taken from ref.full_pop[0, :ref.n_states].
    """
    n_ref = ref.n_macrostates[0]
    weights = ref.full_pop[0, : ref.n_states]
    ref_masks = ref.macrostate_assignment[0].astype(bool)

    S = np.zeros((3, n_ref, sto.n_runs))
    for run in range(sto.n_runs):
        sto_masks = sto.macrostate_assignment[run].astype(bool)
        for i in range(n_ref):
            ref_weight = (ref_masks[i] * weights).sum()
            for j in range(sto.n_macrostates[run]):
                shared = (np.logical_and(ref_masks[i], sto_masks[j]) * weights).sum()
                combined = (np.logical_or(ref_masks[i], sto_masks[j]) * weights).sum()
                sto_weight = (sto_masks[j] * weights).sum()
                # Keep the best score seen so far for this reference state.
                S[0, i, run] = max(S[0, i, run], shared / combined)
                S[1, i, run] = max(S[1, i, run], shared / ref_weight)
                S[2, i, run] = max(S[2, i, run], shared / sto_weight)
    return S
def shannon_entropy(p):
    """
    Return the Shannon entropy of p, normalised by log(number of entries).

    p: array-like of non-negative weights; it is normalised to sum to one
        before the entropy is evaluated

    Entries equal to zero contribute nothing (0 * log 0 is taken as 0), and a
    distribution with fewer than two entries has entropy 0.
    """
    p = np.asarray(p, dtype=float)
    p = p / p.sum(axis=0)
    n_entries = p.shape[0]
    if n_entries < 2:
        # log(1) == 0 would turn the normalisation into 0 / 0.
        return 0.0
    positive = p > 0
    terms = np.zeros_like(p)
    terms[positive] = p[positive] * np.log(p[positive])
    return -terms.sum() / np.log(n_entries)
def weighting_function(dq):
    """
    Return weights for the values in dq.

    A single value is weighted with exp(-dq); otherwise the Gaussian
    exp(-dq**2 / (2 * var(dq))) is used.
    """
    if dq.shape[0] != 1:
        variance = np.var(dq)
        return np.exp(-np.square(dq) / (2 * variance))
    return np.exp(-dq)
224### RMSD #####################################################################
def load_trajectory(
    topfile, trajectoryfile, atom_selection="all", frames=None, stride=None
):
    """
    Load an xtc trajectory restricted to the selected atoms.

    topfile: topology file passed to md.load_topology
    trajectoryfile: xtc file to read
    atom_selection: selection string evaluated on the topology
    frames: optional iterable of frame indices; when given, each frame is
        loaded on its own and the results are joined
    stride: passed to md.load_xtc; only used when frames is None
    """
    print("Loading trajectory...")
    top = md.load_topology(topfile)
    atom_indices = top.select(atom_selection)
    if frames is not None:
        single_frames = [
            md.load_xtc(
                trajectoryfile,
                top=top,
                atom_indices=atom_indices,
                frame=frame,
            )
            for frame in frames
        ]
        return md.join(single_frames)
    return md.load_xtc(
        trajectoryfile,
        top=top,
        atom_indices=atom_indices,
        stride=stride,
    )
def load_mean_frames(topfile, trajectoryfile, mean_frames, dt=0.1):
    """
    Reload the given frames from an xtc trajectory file.

    topfile: topology file passed to md.load_topology
    trajectoryfile: xtc file to read from
    mean_frames: iterable of frames; the first entry of each frame's time
        attribute is used to locate the frame in trajectoryfile
    dt: divisor that converts a time into a frame index (presumably the time
        between two consecutive frames -- confirm the unit)

    returns the joined single-frame trajectories
    """
    top = md.load_topology(topfile)
    # The division yields a float that can be off by rounding error, while
    # md.load_xtc expects an integer frame index, so round before use.
    # NOTE(review): the time is truncated to an integer before the division;
    # confirm that this is intended for non-integer times.
    idxs = [int(round(int(frame.time[0]) / dt)) for frame in mean_frames]
    return md.join([md.load_xtc(trajectoryfile, top=top, frame=idx) for idx in idxs])
def find_mean_frame(trajectory):
    """
    Return the frame of trajectory with the smallest estimate_rmsd value,
    i.e. the smallest mean RMSD to all frames of the trajectory.
    """
    scores = [estimate_rmsd(candidate, trajectory) for candidate in trajectory]
    best = np.argmin(np.array(scores))
    return trajectory[best]
def estimate_rmsd(frame, trajectory):
    """Return the mean of md.rmsd(trajectory, frame)."""
    per_frame_rmsd = md.rmsd(trajectory, frame)
    return np.mean(per_frame_rmsd)
def align_trajectory_to_reference(trajectory, reference):
    """
    Aligns each frame in the trajectory array to the reference frame using the Kabsch algorithm.

    Parameters:
    - trajectory: numpy array of shape (n_frames, n_atoms, 3).
    - reference: numpy array of shape (1, n_atoms, 3) or (n_atoms, 3) representing the reference points.

    Returns:
    - aligned_trajectory: numpy array with the shape and dtype of trajectory; each frame is
      rotated onto the reference and shifted to the reference centroid.
    """
    # Drop the leading frame axis of the reference, if present.
    reference_frame = reference.squeeze()
    reference_centroid = np.mean(reference_frame, axis=0)
    centered_reference = reference_frame - reference_centroid

    aligned_trajectory = np.zeros_like(trajectory)
    for i, current_frame in enumerate(trajectory):
        centered_frame = current_frame - np.mean(current_frame, axis=0)

        # Cross-covariance between the frame and the reference.
        H = centered_frame.T @ centered_reference
        U, _, Vt = np.linalg.svd(H)

        # Coordinates are stored as row vectors, so the rotation minimising
        # |centered_frame @ R - centered_reference| is R = U @ Vt.
        if np.linalg.det(U @ Vt) < 0:
            # Improper rotation (reflection): flip the axis that belongs to
            # the smallest singular value to get a proper rotation.
            Vt[-1, :] *= -1
        R = U @ Vt

        aligned_trajectory[i] = centered_frame @ R + reference_centroid

    return aligned_trajectory
def calc_var(ref, trajectory):
    """
    Return the mean squared deviation from ref for every atom.

    The trajectory is first aligned to ref; the squared distance of every atom
    to its reference position is then averaged over all frames.
    """
    aligned = align_trajectory_to_reference(trajectory, ref)
    squared_distance = np.square(aligned - ref).sum(axis=2)
    return squared_distance.mean(axis=0)
def opt_num_batches(n):
    """
    Return the number of batches used for n frames: int(cbrt(n**2 / 2)),
    but never less than one.
    """
    # Convert to a Python float first so that squaring a fixed-width NumPy
    # integer cannot overflow.
    n = float(n)
    # At least one batch, so that callers always have something to process.
    return max(1, int(np.cbrt(n * n / 2)))
def calc_rmsd(lumping, quiet=False):
    """
    Determine a mean frame and the per-atom calc_var values for every
    macrostate of the lumping selected by lumping.n_i.

    lumping: object providing topology_file, xtc_trajectory_file, xtc_stride,
        n_i, n_macrostates, macrostate_trajectory and macrostate_population
    quiet: if True, neither progress messages nor progress bars are shown

    returns (rmsd, mean_frames): an array with one row per macrostate and one
    column per loaded atom, and the list of mean frames
    """
    traj = load_trajectory(
        lumping.topology_file,
        lumping.xtc_trajectory_file,
        atom_selection="name CA",
        stride=lumping.xtc_stride,
    )
    n_macro = lumping.n_macrostates[lumping.n_i]
    mean_frames = []
    rmsd = np.empty([n_macro, traj.n_atoms])
    for state in range(n_macro):
        if not quiet:
            print(f"Process macrostate {state}")
        in_state = lumping.macrostate_trajectory[lumping.n_i] == state
        state_traj = traj[in_state]
        n_batches = opt_num_batches(lumping.macrostate_population[lumping.n_i][state])
        batch_ids = range(n_batches) if quiet else tqdm(range(n_batches))
        # Mean frame of every interleaved batch, then the mean frame of those.
        batch_means = [
            find_mean_frame(state_traj[batch::n_batches]) for batch in batch_ids
        ]
        state_mean = find_mean_frame(md.join(batch_means))
        mean_frames.append(state_mean)
        rmsd[state] = calc_var(state_mean.xyz, state_traj.xyz)
    return rmsd, mean_frames
def find_state_lengths(arr):
    """
    Run-length encode a 1D sequence of states.

    arr: 1D array-like of states

    returns (states, lengths): the state of every run of consecutive equal
    values and the length of that run. Empty input gives two empty arrays.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return arr[:0], np.array([], dtype=int)

    # A run starts at index 0 and wherever a value differs from its
    # predecessor.
    changes = np.flatnonzero(arr[1:] != arr[:-1]) + 1
    run_starts = np.concatenate(([0], changes))
    # Distance between consecutive run starts (closing with the array end).
    lengths = np.diff(np.append(run_starts, arr.size))
    return arr[run_starts], lengths
def get_multi_state_trajectory(trajectories: np.ndarray, limits: np.ndarray):
    """
    Split a concatenated trajectory into its individual trajectories.

    trajectories: the concatenated trajectory
    limits: length of every individual trajectory, in order, or None

    returns trajectories unchanged if limits is None, otherwise a list with
    one slice of trajectories per entry of limits
    """
    if limits is None:
        return trajectories
    pieces = []
    start = 0
    for length in limits:
        stop = start + length
        # Both slice bounds must be integers, also when limits holds floats.
        pieces.append(trajectories[int(start) : int(stop)])
        start = stop
    return pieces