# Coverage for src/molecular_simulations/analysis/cov_ppi.py: 49%
# 237 statements
# Report generated by coverage.py v7.13.0, created at 2025-12-13 01:26 -0600
1from collections import defaultdict
2import json
3import matplotlib.pyplot as plt
4import MDAnalysis as mda
5from MDAnalysis.analysis.distances import distance_array
6from MDAnalysis.lib.util import convert_aa_code
7import numpy as np
8from pathlib import Path
9import polars as pl
10import seaborn as sns
11from typing import Callable, Union
# Anything accepted as a filesystem location: a plain string or a Path.
PathLike = Union[str, Path]
# Pair label -> {interaction label -> fraction of simulation time}.
Results = dict[str, dict[str, float]]
# (functions to call, labels to store their results under), index-aligned.
TaskTree = tuple[list[Callable], list[str]]
class PPInteractions:
    """Covariance-guided survey of interactions between two selections.

    Code herein adapted from:
    https://www.biorxiv.org/content/10.1101/2025.03.24.644990v1.full.pdf

    Takes an input topology file and trajectory file, and highlights relevant
    interactions between two selections. To this end we first compute the
    C-alpha covariance matrix between the two selections, zero out every
    entry whose residues sit too far apart on average (11Å for positive
    covariance, 13Å for negative covariance by default), and examine each
    surviving residue pair with distance and angle cutoffs for hydrophobic
    contacts, hydrogen bonds and salt bridges.

    Arguments:
        top (PathLike): Path to topology file.
        traj (PathLike): Path to trajectory file.
        out (PathLike): Path of the JSON file the results are written to.
        sel1 (str): Defaults to 'chainID A'. MDAnalysis selection string for
            the first selection.
        sel2 (str): Defaults to 'chainID B'. MDAnalysis selection string for
            the second selection.
        cov_cutoff (tuple[float, float]): Defaults to (11., 13.). Distance
            cutoffs used for positive and negative covariance respectively.
        sb_cutoff (float): Defaults to 6.0Å. Distance cutoff for salt bridges.
        hbond_cutoff (float): Defaults to 3.5Å. Donor-acceptor distance cutoff
            for hydrogen bonds.
        hbond_angle (float): Defaults to 30.0 degrees. Angle cutoff for
            hydrogen bonds; stored internally in radians.
        hydrophobic_cutoff (float): Defaults to 8.0Å. Distance cutoff for
            hydrophobic interactions.
        plot (bool): Defaults to True. Whether or not to plot results. Plots
            are written to a 'plots' directory relative to the current
            working directory.
    """

    # Residue names for which hydrogen bonds are evaluated.
    _HBOND_RESNAMES = frozenset({
        'TYR', 'HIS', 'HID', 'HIE', 'SER', 'THR', 'ASN', 'GLN',
        'ASP', 'GLU', 'LYS', 'ARG', 'HIP',
    })
    # Residue names for which salt bridges are evaluated.
    _SALTBRIDGE_RESNAMES = frozenset({'ASP', 'GLU', 'LYS', 'ARG', 'HIP'})

    def __init__(self,
                 top: PathLike,
                 traj: PathLike,
                 out: PathLike,
                 sel1: str='chainID A',
                 sel2: str='chainID B',
                 cov_cutoff: tuple[float, float]=(11., 13.),
                 sb_cutoff: float=6.,
                 hbond_cutoff: float=3.5,
                 hbond_angle: float=30.,
                 hydrophobic_cutoff: float=8.,
                 plot: bool=True):
        self.u = mda.Universe(top, traj)
        self.n_frames = len(self.u.trajectory)
        self.out = out
        self.sel1 = sel1
        self.sel2 = sel2
        self.cov_cutoff = cov_cutoff
        self.sb = sb_cutoff
        self.hb_d = hbond_cutoff
        # np.arccos yields radians, so the cutoff has to be stored in radians.
        self.hb_a = np.deg2rad(hbond_angle)
        self.hydr = hydrophobic_cutoff
        self.plot = plot

    def run(self) -> None:
        """Main function that runs the workflow. Obtains a covariance matrix,
        screens for close interactions, evaluates each pairwise interaction
        and saves (and optionally plots) the contact probability of each.

        Returns:
            None
        """
        cov = self.get_covariance()
        positive, negative = self.interpret_covariance(cov)

        results = {'positive': {}, 'negative': {}}
        for label, pairs in (('positive', positive), ('negative', negative)):
            for res1, res2 in pairs:
                results[label].update(self.compute_interactions(res1, res2))

        self.save(results)

        if self.plot:
            self.plot_results(results)

    def compute_interactions(self,
                             res1: int,
                             res2: int) -> Results:
        """Ingests two resIDs, generates MDAnalysis AtomGroups for each,
        identifies relevant non-bonded interactions (HBonds, saltbridge,
        hydrophobic) and computes each.

        Arguments:
            res1 (int): ResID for a residue in sel1.
            res2 (int): ResID for a residue in sel2.

        Returns:
            (Results): A nested dictionary keyed by the pair label, holding
                the fraction of simulation time each interaction type is
                engaged. Interaction types that are not evaluated stay 0.
        """
        # Parenthesise the user selection so an 'or' inside it cannot capture
        # the trailing 'and resid ...' clause.
        grp1 = self.u.select_atoms(f'({self.sel1}) and resid {res1}')
        grp2 = self.u.select_atoms(f'({self.sel2}) and resid {res2}')
        resname1 = grp1.resnames[0]
        resname2 = grp2.resnames[0]
        r1 = convert_aa_code(resname1)
        r2 = convert_aa_code(resname2)
        name = f'A_{r1}{res1}-B_{r2}{res2}'

        data = {name: {label: 0. for label in ['hydrophobic', 'hbond', 'saltbridge']}}
        function_calls, labels = self.identify_interaction_type(resname1, resname2)

        for call, label in zip(function_calls, labels):
            data[name][label] = call(grp1, grp2)

        return data

    def get_covariance(self) -> np.ndarray:
        """Compute the C-alpha positional covariance between the selections
        using the functional form:
            C = <(R1 - <R1>)(R2 - <R2>)T>
        where each element corresponds to the ensemble average movement
            C_ij = <deltaR_i * deltaR_j>
        with the magnitude being the strength of correlation and the sign
        corresponding to positive and negative correlation respectively.

        Entries whose residues are, on average, further apart than the
        relevant covariance cutoff are set to 0. Also populates
        `self.mapping` via `res_map`.

        Returns:
            (np.ndarray): Distance-filtered covariance matrix with one row per
                C-alpha in sel1 and one column per C-alpha in sel2.
        """
        # Use the user-provided selections rather than fixed chain IDs so the
        # matrix indices agree with the residues compute_interactions selects.
        p1_ca = self.u.select_atoms(f'({self.sel1}) and name CA')
        p2_ca = self.u.select_atoms(f'({self.sel2}) and name CA')
        self.res_map(p1_ca, p2_ca)

        n1 = p1_ca.n_atoms
        n2 = p2_ca.n_atoms

        # First pass: mean position of every C-alpha.
        R1_avg = np.zeros((n1, 3))
        R2_avg = np.zeros((n2, 3))
        for _ in self.u.trajectory:
            R1_avg += p1_ca.positions
            R2_avg += p2_ca.positions

        R1_avg /= self.n_frames
        R2_avg /= self.n_frames

        # Second pass: accumulate dot products of the displacements. The
        # matrix product yields dR1[i] . dR2[j] for every (i, j) at once.
        C = np.zeros((n1, n2))
        for _ in self.u.trajectory:
            dR1 = p1_ca.positions - R1_avg
            dR2 = p2_ca.positions - R2_avg
            C += dR1 @ dR2.T

        C /= self.n_frames

        # Positive entries are held to the first cutoff, all others to the
        # second; anything beyond its cutoff is discarded.
        dist = np.linalg.norm(R1_avg[:, None, :] - R2_avg[None, :, :], axis=-1)
        cutoff = np.where(C > 0, self.cov_cutoff[0], self.cov_cutoff[1])
        C[dist > cutoff] = 0.

        return C

    def res_map(self,
                ag1: mda.AtomGroup,
                ag2: mda.AtomGroup) -> None:
        """Map covariance matrix indices to AtomGroup resIDs so that we are
        examining the correct pairs of residues. Stores the result in
        `self.mapping` under the keys 'ag1' and 'ag2'.

        Arguments:
            ag1 (mda.AtomGroup): AtomGroup of the first selection.
            ag2 (mda.AtomGroup): AtomGroup of the second selection.

        Returns:
            None
        """
        self.mapping = {
            'ag1': {i: int(resid) for i, resid in enumerate(ag1.resids)},
            'ag2': {i: int(resid) for i, resid in enumerate(ag2.resids)},
        }

    def interpret_covariance(
        self,
        cov_mat: np.ndarray,
    ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
        """Identify pairs of residues with positive or negative correlations.

        Arguments:
            cov_mat (np.ndarray): Covariance matrix.

        Returns:
            (tuple[list[tuple[int, int]], list[tuple[int, int]]]): Lists of
                positively and negatively correlated (sel1 resID, sel2 resID)
                pairs, in row-major matrix order.
        """
        map1 = self.mapping['ag1']
        map2 = self.mapping['ag2']
        # (a, b) and (b, a) describe the same pair only when both selections
        # are identical. Across two different selections the residue numbers
        # may overlap while referring to different residues, so reversed
        # pairs must be kept there.
        same_selection = self.sel1 == self.sel2
        seen = set()

        def collect(mask: np.ndarray) -> list[tuple[int, int]]:
            pairs = []
            for i, j in zip(*np.nonzero(mask)):
                pair = (int(map1[int(i)]), int(map2[int(j)]))
                if pair in seen:
                    continue
                pairs.append(pair)
                seen.add(pair)
                if same_selection:
                    seen.add(pair[::-1])
            return pairs

        positive = collect(cov_mat > 0.)
        negative = collect(cov_mat < 0.)

        return positive, negative

    def identify_interaction_type(self,
                                  res1: str,
                                  res2: str) -> TaskTree:
        """Identifies what analyses to compute for a given pair of protein
        residues (i.e. hydrophobic interactions, hydrogen bonds, saltbridges).
        Hydrophobic analysis is always included; hydrogen bond and saltbridge
        analyses are included only when both residues are eligible for them.

        Arguments:
            res1 (str): 3-letter code resname for a residue from selection 1.
            res2 (str): 3-letter code resname for a residue from selection 2.

        Returns:
            (TaskTree): Tuple containing list of function calls and list of
                labels, index-aligned.
        """
        functions = [self.analyze_hydrophobic]
        labels = ['hydrophobic']

        optional = (
            (self._HBOND_RESNAMES, self.analyze_hbond, 'hbond'),
            (self._SALTBRIDGE_RESNAMES, self.analyze_saltbridge, 'saltbridge'),
        )
        for eligible, func, label in optional:
            if res1 in eligible and res2 in eligible:
                functions.append(func)
                labels.append(label)

        return functions, labels

    def analyze_saltbridge(self,
                           res1: mda.AtomGroup,
                           res2: mda.AtomGroup) -> float:
        """Uses a simple distance cutoff to highlight the occupancy of
        saltbridge between two residues. A frame counts when the closest
        pair of charged-group atoms is nearer than the saltbridge cutoff.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (float): Proportion of simulation time spent in salt bridge.
        """
        # NOTE: HIP is routed here by identify_interaction_type but appears in
        # neither list below, so HIP pairs always yield 0.
        pos = ('LYS', 'ARG')
        neg = ('ASP', 'GLU')
        name1 = res1.resnames[0]
        name2 = res2.resnames[0]
        opposite_charges = (
            (name1 in pos and name2 in neg)
            or (name1 in neg and name2 in pos)
        )
        if not opposite_charges:
            return 0.

        atom_names = ('NZ', 'NH1', 'NH2', 'OD1', 'OD2', 'OE1', 'OE2')
        mask1 = np.array([n in atom_names for n in res1.atoms.names], dtype=bool)
        mask2 = np.array([n in atom_names for n in res2.atoms.names], dtype=bool)
        grp1 = res1.atoms[mask1]
        grp2 = res2.atoms[mask2]
        if len(grp1) == 0 or len(grp2) == 0:
            return 0.

        n_frames = 0
        for _ in self.u.trajectory:
            # Minimum over all atom pairs; subtracting the two position
            # arrays directly would not give a pairwise distance.
            if distance_array(grp1.positions, grp2.positions).min() < self.sb:
                n_frames += 1

        return n_frames / self.n_frames

    def analyze_hbond(self,
                      res1: mda.AtomGroup,
                      res2: mda.AtomGroup) -> float:
        """Identifies all potential donor/acceptor atoms between two
        residues, culls this list based on distance across the simulation
        and then evaluates the survivors in every frame utilizing a
        distance and angle cutoff.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (float): Proportion of simulation time spent in hydrogen bond.
        """
        donors, acceptors = self.survey_donors_acceptors(res1, res2)
        if len(donors) == 0 or len(acceptors) == 0:
            return 0.

        n_frames = 0
        for _ in self.u.trajectory:
            n_frames += self.evaluate_hbond(donors, acceptors)

        return n_frames / self.n_frames

    def analyze_hydrophobic(self,
                            res1: mda.AtomGroup,
                            res2: mda.AtomGroup) -> float:
        """Uses a simple distance cutoff to highlight the occupancy of
        hydrophobic interaction between two residues. A frame counts when
        the closest pair of atoms whose type contains 'C' is nearer than the
        hydrophobic cutoff.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (float): Proportion of simulation time spent in interaction.
        """
        mask1 = np.array(['C' in t for t in res1.atoms.types], dtype=bool)
        mask2 = np.array(['C' in t for t in res2.atoms.types], dtype=bool)
        h1 = res1.atoms[mask1]
        h2 = res2.atoms[mask2]
        if len(h1) == 0 or len(h2) == 0:
            return 0.

        n_frames = 0
        for _ in self.u.trajectory:
            if distance_array(h1.positions, h2.positions).min() < self.hydr:
                n_frames += 1

        return n_frames / self.n_frames

    def survey_donors_acceptors(self,
                                res1: mda.AtomGroup,
                                res2: mda.AtomGroup) -> tuple[mda.AtomGroup, mda.AtomGroup]:
        """First pass distance threshold to identify potential hydrogen bonds.
        Should be followed by querying HBond angles but this serves to reduce
        our search space and time complexity. Only returns donors/acceptors
        which come within the distance cutoff of an atom of the *other*
        residue in at least a single frame.

        Atoms whose type contains 'O' or 'N' are acceptors; those that are
        additionally bonded to an atom whose type contains 'H' are donors.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (tuple[mda.AtomGroup, mda.AtomGroup]): Donor and acceptor
                AtomGroups which pass the crude distance cutoff.
        """
        donor_ix = []
        acceptor_ix = []
        for residue in (res1, res2):
            for atom in residue.atoms:
                if not any(element in atom.type for element in ('O', 'N')):
                    continue
                if any('H' in t for t in atom.bonded_atoms.types):
                    donor_ix.append(atom.ix)
                acceptor_ix.append(atom.ix)

        donors = self.u.atoms[np.array(donor_ix, dtype=int)]
        acceptors = self.u.atoms[np.array(acceptor_ix, dtype=int)]
        if len(donors) == 0 or len(acceptors) == 0:
            return donors, acceptors

        # Every donor is also listed as an acceptor, so same-residue pairs
        # (including an atom paired with itself at distance 0) are masked out.
        inter_residue = (
            donors.resindices[:, None] != acceptors.resindices[None, :]
        )
        close = np.zeros(inter_residue.shape, dtype=bool)
        # Scan the whole trajectory so contacts that only form in later
        # frames are not culled.
        for _ in self.u.trajectory:
            close |= distance_array(donors.positions,
                                    acceptors.positions) < self.hb_d
        close &= inter_residue

        don_contacts = np.flatnonzero(close.any(axis=1))
        acc_contacts = np.flatnonzero(close.any(axis=0))

        return donors[don_contacts], acceptors[acc_contacts]

    def evaluate_hbond(self,
                       donor: mda.AtomGroup,
                       acceptor: mda.AtomGroup) -> int:
        """Evaluates whether there is a defined hydrogen bond between any
        donor and acceptor atoms in the current frame. A pair must belong to
        different residues, pass the donor-acceptor distance cutoff and
        have an angle between the donor->hydrogen and hydrogen->acceptor
        vectors no larger than the angle cutoff. Returns early when a legal
        HBond is detected.

        Arguments:
            donor (mda.AtomGroup): AtomGroup of HBond donors.
            acceptor (mda.AtomGroup): AtomGroup of HBond acceptors.

        Returns:
            (int): 1 if legal hbond found, else 0
        """
        for d in donor.atoms:
            d_pos = d.position
            h_positions = [
                atom.position for atom in d.bonded_atoms if 'H' in atom.type
            ]
            if not h_positions:
                continue

            for a in acceptor.atoms:
                # Skip the donor itself and any atom of the donor's residue.
                if a.resindex == d.resindex:
                    continue

                a_pos = a.position
                if np.linalg.norm(a_pos - d_pos) > self.hb_d:
                    continue

                for h_pos in h_positions:
                    dh = h_pos - d_pos
                    ha = a_pos - h_pos
                    cosine = np.dot(dh, ha) / (
                        np.linalg.norm(dh) * np.linalg.norm(ha)
                    )
                    # Rounding can push the cosine marginally outside
                    # [-1, 1], where arccos would return NaN.
                    angle = np.arccos(np.clip(cosine, -1., 1.))
                    if angle <= self.hb_a:
                        return 1

        return 0

    def save(self,
             results: Results) -> None:
        """Save results as a json file at `self.out`.

        Arguments:
            results (Results): Dictionary of results to be saved.

        Returns:
            None
        """
        with open(self.out, 'w') as fout:
            json.dump(results, fout, indent=4)

    def plot_results(self,
                     results: Results) -> None:
        """Plot results. Writes one bar plot per covariance sign and
        interaction type that has any non-zero entry into a 'plots'
        directory relative to the current working directory.

        Arguments:
            results (Results): Dictionary of results to be plotted.

        Returns:
            None
        """
        df = self.parse_results(results)
        # With no rows the frame has no columns to filter on.
        if df.is_empty():
            return

        plot = Path('plots')
        plot.mkdir(exist_ok=True)
        for cov_type in ['positive', 'negative']:
            for int_type in ['Hydrophobic', 'Hydrogen Bond', 'Salt Bridge']:
                data = df.filter(
                    (pl.col('Covariance') == cov_type) & (pl.col(int_type) > 0.)
                )
                if data.is_empty():
                    continue

                name = f'{cov_type.capitalize()}_Covariance_'
                name += f'{"_".join(int_type.split(" "))}.png'

                self.make_plot(data, int_type, plot / name)

    def parse_results(self,
                      results: Results) -> pl.DataFrame:
        """Prepares results for plotting. Removes any entries which are
        all 0. and returns as a polars DataFrame for easier plotting.

        Arguments:
            results (Results): Dictionary of results to be prepped.

        Returns:
            (pl.DataFrame): Polars dataframe of results, one row per residue
                pair with at least one non-zero interaction.
        """
        data_rows = [
            {
                'Residue Pair': pair,
                'Hydrophobic': data['hydrophobic'],
                'Hydrogen Bond': data['hbond'],
                'Salt Bridge': data['saltbridge'],
                'Covariance': cov_type,
            }
            for cov_type, pair_dict in results.items()
            for pair, data in pair_dict.items()
            if any(val > 0. for val in data.values())
        ]

        return pl.DataFrame(data_rows)

    def make_plot(self,
                  data: pl.DataFrame,
                  column: str,
                  name: PathLike,
                  fs: int=15) -> None:
        """Generates a seaborn barplot from a dataframe for a specified
        column and saves it to disk.

        Arguments:
            data (pl.DataFrame): Polars dataframe of data.
            column (str): Label for desired column.
            name (PathLike): Path to file to save plot to.
            fs (int): Defaults to 15. Size of font for plot.

        Returns:
            None
        """
        fig, ax = plt.subplots(1, 1, figsize=(6, 5))

        sns.barplot(data=data, x='Residue Pair', y=column, ax=ax)

        ax.set_xlabel('Residue Pair', fontsize=fs)
        ax.set_ylabel('Probability', fontsize=fs)
        ax.set_title(column, fontsize=fs+2)
        ax.tick_params(labelsize=fs)
        ax.tick_params(axis='x', rotation=45)

        fig.tight_layout()
        fig.savefig(str(name), dpi=300)
        # Release the figure; one is created per plot.
        plt.close(fig)