Coverage for src / molecular_simulations / analysis / cov_ppi.py: 49%

237 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-13 01:26 -0600

1from collections import defaultdict 

2import json 

3import matplotlib.pyplot as plt 

4import MDAnalysis as mda 

5from MDAnalysis.analysis.distances import distance_array 

6from MDAnalysis.lib.util import convert_aa_code 

7import numpy as np 

8from pathlib import Path 

9import polars as pl 

10import seaborn as sns 

11from typing import Callable, Union 

12 

# Anything accepted where a filesystem location is expected.
PathLike = Union[Path, str]

# Outer key: residue-pair label; inner key: interaction type
# ('hydrophobic', 'hbond', 'saltbridge'); value: occupancy fraction.
Results = dict[str, dict[str, float]]

# Parallel lists: the analysis callables to run and the label each result
# is stored under.
TaskTree = tuple[list[Callable[..., float]], list[str]]

16 

class PPInteractions:
    """Covariance-guided survey of interactions between two selections.

    Code herein adapted from:
    https://www.biorxiv.org/content/10.1101/2025.03.24.644990v1.full.pdf
    Takes an input topology file and trajectory file, and highlights relevant
    interactions between two selections. To this end we first compute the
    covariance matrix between the two selections, filter out all interactions
    which occur too far apart (11Å for positive covariance, 13Å for negative
    covariance), and examine each based on a variety of distance and angle
    cutoffs defined in the literature.

    Arguments:
        top (PathLike): Path to topology file.
        traj (PathLike): Path to trajectory file.
        out (PathLike): Path of the JSON file the results are written to.
        sel1 (str): Defaults to 'chainID A'. MDAnalysis selection string for the
            first selection.
        sel2 (str): Defaults to 'chainID B'. MDAnalysis selection string for the
            second selection.
        cov_cutoff (tuple[float, float]): Defaults to (11., 13.). Tuple of the
            distance cutoffs to use for positive and negative covariance
            respectively.
        sb_cutoff (float): Defaults to 6.0Å. Distance cutoff for salt bridges.
        hbond_cutoff (float): Defaults to 3.5Å. Distance cutoff for hydrogen bonds.
        hbond_angle (float): Defaults to 30.0 degrees. Angle cutoff for hydrogen
            bonds.
        hydrophobic_cutoff (float): Defaults to 8.0Å. Distance cutoff for
            hydrophobic interactions.
        plot (bool): Defaults to True. Whether or not to plot results.
    """
    def __init__(self,
                 top: PathLike,
                 traj: PathLike,
                 out: PathLike,
                 sel1: str='chainID A',
                 sel2: str='chainID B',
                 cov_cutoff: tuple[float, float]=(11., 13.),
                 sb_cutoff: float=6.,
                 hbond_cutoff: float=3.5,
                 hbond_angle: float=30.,
                 hydrophobic_cutoff: float=8.,
                 plot: bool=True):
        self.u = mda.Universe(top, traj)
        self.n_frames = len(self.u.trajectory)
        self.out = out
        self.sel1 = sel1
        self.sel2 = sel2
        self.cov_cutoff = cov_cutoff
        self.sb = sb_cutoff
        self.hb_d = hbond_cutoff
        # The cutoff is supplied in degrees but compared against np.arccos
        # output, which is in radians.
        self.hb_a = np.deg2rad(hbond_angle)
        self.hydr = hydrophobic_cutoff
        self.plot = plot

    def run(self) -> None:
        """Main function that runs the workflow. Obtains a covariance matrix,
        screens for close interactions, evaluates each pairwise interaction
        for each amino acid and reports the contact probability of each.
        Results are written with `save` and, if `self.plot` is set, plotted.

        Returns:
            None
        """
        cov = self.get_covariance()
        positive, negative = self.interpret_covariance(cov)

        results = {'positive': {}, 'negative': {}}
        for cov_type, pairs in (('positive', positive), ('negative', negative)):
            for res1, res2 in pairs:
                results[cov_type].update(self.compute_interactions(res1, res2))

        self.save(results)

        if self.plot:
            self.plot_results(results)

    def compute_interactions(self,
                             res1: int,
                             res2: int) -> Results:
        """Ingests two resIDs, generates MDAnalysis AtomGroups for each, identifies
        relevant non-bonded interactions (HBonds, saltbridge, hydrophobic) and
        computes each. Returns a dict containing the proportion of simulation time
        that each interaction is engaged.

        Arguments:
            res1 (int): ResID for a residue in sel1.
            res2 (int): ResID for a residue in sel2.

        Returns:
            (Results): A nested dictionary containing the results of each interaction
                type. Interaction types that are not evaluated for this pair
                keep the value 0.
        """
        # Parenthesise the user selection so an 'or' inside it cannot change
        # the meaning of the appended 'and resid ...' clause.
        grp1 = self.u.select_atoms(f'({self.sel1}) and resid {res1}')
        grp2 = self.u.select_atoms(f'({self.sel2}) and resid {res2}')
        resname1 = grp1.resnames[0]
        resname2 = grp2.resnames[0]
        r1 = convert_aa_code(resname1)
        r2 = convert_aa_code(resname2)
        name = f'A_{r1}{res1}-B_{r2}{res2}'

        data = {name: {label: 0. for label in ['hydrophobic', 'hbond', 'saltbridge']}}
        function_calls, labels = self.identify_interaction_type(resname1, resname2)

        for call, label in zip(function_calls, labels):
            data[name][label] = call(grp1, grp2)

        return data

    def get_covariance(self) -> np.ndarray:
        """
        Loop over all C-alpha atoms and compute the positional
        covariance using the functional form:
            C = <(R1 - <R1>)(R2 - <R2>)T>
        where each element corresponds to the ensemble average movement
            C_ij = <deltaR_i * deltaR_j>
        with the magnitude being the strength of correlation and the sign
        corresponding to positive and negative correlation respectively.
        Entries whose mean C-alpha separation exceeds the relevant
        `cov_cutoff` are zeroed. Also populates `self.mapping` via `res_map`.

        Returns:
            (np.ndarray): Covariance matrix of shape (n CA in sel1, n CA in sel2).
        """
        # Use the configured selections so the matrix rows/columns refer to
        # the same residues that compute_interactions later selects.
        p1_ca = self.u.select_atoms(f'({self.sel1}) and name CA')
        p2_ca = self.u.select_atoms(f'({self.sel2}) and name CA')
        n1 = p1_ca.n_atoms
        n2 = p2_ca.n_atoms

        self.res_map(p1_ca, p2_ca)

        # First pass: mean position of every C-alpha.
        R1_avg = np.zeros((n1, 3))
        R2_avg = np.zeros((n2, 3))
        for _ in self.u.trajectory:
            R1_avg += p1_ca.positions
            R2_avg += p2_ca.positions

        R1_avg /= self.n_frames
        R2_avg /= self.n_frames

        # Second pass: accumulate dot products of the displacements.
        # (dR1 @ dR2.T)[i, j] == np.dot(dR1[i], dR2[j]).
        C = np.zeros((n1, n2))
        for _ in self.u.trajectory:
            dR1 = p1_ca.positions - R1_avg
            dR2 = p2_ca.positions - R2_avg
            C += dR1 @ dR2.T

        C /= self.n_frames

        # Distance between mean positions for every (i, j) pair.
        dist = np.linalg.norm(R1_avg[:, None, :] - R2_avg[None, :, :], axis=-1)
        positive = C > 0
        C[positive & (dist > self.cov_cutoff[0])] = 0.
        C[~positive & (dist > self.cov_cutoff[1])] = 0.

        return C

    def res_map(self,
                ag1: mda.AtomGroup,
                ag2: mda.AtomGroup) -> None:
        """Map covariance matrix indices to AtomGroup resIDs so that we are
        examining the correct pairs of residues. The result is stored on
        `self.mapping` under the keys 'ag1' and 'ag2'.

        Arguments:
            ag1 (mda.AtomGroup): AtomGroup of the first selection.
            ag2 (mda.AtomGroup): AtomGroup of the second selection.

        Returns:
            None
        """
        self.mapping = {
            'ag1': {i: int(resid) for i, resid in enumerate(ag1.resids)},
            'ag2': {i: int(resid) for i, resid in enumerate(ag2.resids)},
        }

    def interpret_covariance(self,
                             cov_mat: np.ndarray
                             ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
        """Identify pairs of residues with positive or negative correlations.
        Returns a tuple comprised of pairs for each.

        Arguments:
            cov_mat (np.ndarray): Covariance matrix.

        Returns:
            (tuple): List of positively correlated and list of negatively
                correlated (resid in selection 1, resid in selection 2) pairs.
        """
        # Pairs are ordered (selection 1, selection 2), so (a, b) and (b, a)
        # are different contacts and only exact repeats are skipped.
        seen = set()

        def collect(mask: np.ndarray) -> list[tuple[int, int]]:
            pairs = []
            for i, j in zip(*np.nonzero(mask)):
                pair = (self.mapping['ag1'][int(i)], self.mapping['ag2'][int(j)])
                if pair not in seen:
                    seen.add(pair)
                    pairs.append(pair)
            return pairs

        positive = collect(cov_mat > 0.)
        negative = collect(cov_mat < 0.)

        return positive, negative

    def identify_interaction_type(self,
                                  res1: str,
                                  res2: str) -> TaskTree:
        """Identifies what analyses to compute for a given pair of protein
        residues (i.e. hydrophobic interactions, hydrogen bonds, saltbridges).
        The hydrophobic analysis is always included; the others are included
        only when both residues are listed as capable of them.

        Arguments:
            res1 (str): 3-letter code resname for a residue from selection 1.
            res2 (str): 3-letter code resname for a residue from selection 2.

        Returns:
            (TaskTree): Tuple containing list of function calls and list of labels.
        """
        hbond_only = ('TYR', 'HIS', 'HID', 'HIE', 'SER', 'THR', 'ASN', 'GLN')
        hbond_and_saltbridge = ('ASP', 'GLU', 'LYS', 'ARG', 'HIP')

        capabilities = {res: ('hbond',) for res in hbond_only}
        capabilities.update(
            {res: ('hbond', 'saltbridge') for res in hbond_and_saltbridge}
        )
        dispatch = {
            'hbond': self.analyze_hbond,
            'saltbridge': self.analyze_saltbridge,
        }

        # Residues absent from the table contribute no optional analyses.
        other = capabilities.get(res2, ())
        shared = [lab for lab in capabilities.get(res1, ()) if lab in other]

        functions = [self.analyze_hydrophobic] + [dispatch[lab] for lab in shared]
        labels = ['hydrophobic'] + shared

        return functions, labels

    @staticmethod
    def _filter_atoms(group: mda.AtomGroup,
                      keep: Callable) -> mda.AtomGroup:
        """Return the atoms of `group` for which `keep(atom)` is true."""
        atoms = group.atoms
        mask = np.fromiter((bool(keep(atom)) for atom in atoms),
                           dtype=bool,
                           count=len(atoms))
        return atoms[mask]

    def analyze_saltbridge(self,
                           res1: mda.AtomGroup,
                           res2: mda.AtomGroup) -> float:
        """Uses a simple distance cutoff to highlight the occupancy of
        saltbridge between two residues. Returns the fraction of
        simulation time spent engaged in saltbridge.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (float): Proportion of simulation time spent in salt bridge. 0 if
                the pair is not one of LYS/ARG with one of ASP/GLU.
        """
        pos = ('LYS', 'ARG')
        neg = ('ASP', 'GLU')
        name1 = res1.resnames[0]
        name2 = res2.resnames[0]
        # NOTE(review): identify_interaction_type also routes HIP here, but
        # HIP is in neither list so such pairs always score 0 - confirm intent.
        oppositely_charged = (
            (name1 in pos and name2 in neg) or (name1 in neg and name2 in pos)
        )
        if not oppositely_charged:
            return 0.

        atom_names = ('NZ', 'NH1', 'NH2', 'OD1', 'OD2', 'OE1', 'OE2')
        grp1 = self._filter_atoms(res1, lambda atom: atom.name in atom_names)
        grp2 = self._filter_atoms(res2, lambda atom: atom.name in atom_names)
        if len(grp1) == 0 or len(grp2) == 0:
            return 0.

        n_frames = 0
        for _ in self.u.trajectory:
            # Closest approach between any pair of the selected atoms; the two
            # groups generally contain different numbers of atoms.
            if distance_array(grp1, grp2).min() < self.sb:
                n_frames += 1

        return n_frames / self.n_frames

    def analyze_hbond(self,
                      res1: mda.AtomGroup,
                      res2: mda.AtomGroup) -> float:
        """Identifies all potential donor/acceptor atoms between two
        residues. Culls this list based on distance array across simulation
        and then evaluates each pair over the trajectory utilizing a
        distance and angle cutoff.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (float): Proportion of simulation time spent in hydrogen bond.
        """
        donors, acceptors = self.survey_donors_acceptors(res1, res2)
        if len(donors) == 0 or len(acceptors) == 0:
            return 0.

        n_frames = 0
        for _ in self.u.trajectory:
            n_frames += self.evaluate_hbond(donors, acceptors)

        return n_frames / self.n_frames

    def analyze_hydrophobic(self,
                            res1: mda.AtomGroup,
                            res2: mda.AtomGroup) -> float:
        """Uses a simple distance cutoff to highlight the occupancy of
        hydrophobic interaction between two residues. Returns the fraction of
        simulation time spent engaged in interaction.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (float): Proportion of simulation time spent in interaction.
        """
        # NOTE(review): substring match on the atom type; assumes carbon types
        # are the only ones containing 'C' in this topology - confirm.
        h1 = self._filter_atoms(res1, lambda atom: 'C' in atom.type)
        h2 = self._filter_atoms(res2, lambda atom: 'C' in atom.type)
        if len(h1) == 0 or len(h2) == 0:
            return 0.

        n_frames = 0
        for _ in self.u.trajectory:
            if np.min(distance_array(h1, h2)) < self.hydr:
                n_frames += 1

        return n_frames / self.n_frames

    def survey_donors_acceptors(self,
                                res1: mda.AtomGroup,
                                res2: mda.AtomGroup
                                ) -> tuple[mda.AtomGroup, mda.AtomGroup]:
        """First pass distance threshold to identify potential Hydrogen bonds.
        Should be followed by querying HBond angles but this serves to reduce
        our search space and time complexity. Only returns donors/acceptors which
        are within the distance cutoff of an atom in the *other* residue in at
        least a single frame.

        Arguments:
            res1 (mda.AtomGroup): AtomGroup for a residue from selection 1.
            res2 (mda.AtomGroup): AtomGroup for a residue from selection 2.

        Returns:
            (tuple[mda.AtomGroup, mda.AtomGroup]): Donor and acceptor AtomGroups
                which pass the crude distance cutoff.
        """
        # NOTE(review): substring matches on atom types; a type such as 'HO'
        # or 'OH' would match both tests - confirm against the topology.
        def is_polar(atom) -> bool:
            return any(el in atom.type for el in ('O', 'N'))

        def is_donor(atom) -> bool:
            return is_polar(atom) and any(
                'H' in bonded for bonded in atom.bonded_atoms.types
            )

        both = res1.atoms + res2.atoms
        acceptors = self._filter_atoms(both, is_polar)
        donors = self._filter_atoms(both, is_donor)
        if len(donors) == 0 or len(acceptors) == 0:
            return donors, acceptors

        # Donors and acceptors are pooled from both residues, so mask out
        # same-residue pairs (including an atom paired with itself).
        inter_residue = (
            donors.resindices[:, None] != acceptors.resindices[None, :]
        )

        # Accumulate contacts over every frame, not only the current one.
        close = np.zeros(inter_residue.shape, dtype=bool)
        for _ in self.u.trajectory:
            close |= distance_array(donors, acceptors) < self.hb_d
        close &= inter_residue

        return donors[close.any(axis=1)], acceptors[close.any(axis=0)]

    def evaluate_hbond(self,
                       donor: mda.AtomGroup,
                       acceptor: mda.AtomGroup) -> int:
        """Evaluates whether there is a defined hydrogen bond between any
        donor and acceptor atoms in a given frame. Must pass a distance
        cutoff as well as an angle cutoff. Returns early when a legal
        HBond is detected. Donor/acceptor pairs belonging to the same
        residue are ignored.

        Arguments:
            donor (mda.AtomGroup): AtomGroup of HBond donor.
            acceptor (mda.AtomGroup): AtomGroup of HBond acceptor.

        Returns:
            (int): 1 if legal hbond found, else 0
        """
        for d in donor.atoms:
            pos1 = d.position
            hpos = [atom.position for atom in d.bonded_atoms if 'H' in atom.type]
            if not hpos:
                continue

            for a in acceptor.atoms:
                if a.resindex == d.resindex:
                    continue

                pos3 = a.position
                if np.linalg.norm(pos3 - pos1) > self.hb_d:
                    continue

                for pos2 in hpos:
                    # Angle between donor->hydrogen and hydrogen->acceptor;
                    # 0 corresponds to a perfectly linear D-H...A arrangement.
                    v1 = pos2 - pos1
                    v2 = pos3 - pos2
                    cosine = np.dot(v1, v2) / (
                        np.linalg.norm(v1) * np.linalg.norm(v2)
                    )
                    # Clip so rounding error cannot push arccos to NaN.
                    if np.arccos(np.clip(cosine, -1., 1.)) <= self.hb_a:
                        return 1

        return 0

    def save(self,
             results: Results) -> None:
        """Save results as a json file at `self.out`.

        Arguments:
            results (Results): Dictionary of results to be saved.

        Returns:
            None
        """
        with open(self.out, 'w') as fout:
            json.dump(results, fout, indent=4)

    def plot_results(self,
                     results: Results) -> None:
        """Plot results. One bar plot is written per covariance sign and
        interaction type that has at least one non-zero entry.

        Arguments:
            results (Results): Dictionary of results to be plotted.

        Returns:
            None
        """
        df = self.parse_results(results)
        # With no rows the frame has no columns, so filtering would fail.
        if df.is_empty():
            return

        # NOTE(review): plots go to ./plots relative to the working directory,
        # not next to `self.out` - confirm this is intended.
        plot = Path('plots')
        plot.mkdir(exist_ok=True)
        for cov_type in ['positive', 'negative']:
            for int_type in ['Hydrophobic', 'Hydrogen Bond', 'Salt Bridge']:
                data = df.filter(
                    (pl.col('Covariance') == cov_type) & (pl.col(int_type) > 0.)
                )

                if not data.is_empty():
                    name = f'{cov_type.capitalize()}_Covariance_'
                    name += f'{"_".join(int_type.split(" "))}.png'

                    self.make_plot(
                        data,
                        int_type,
                        plot / name
                    )

    def parse_results(self,
                      results: Results) -> pl.DataFrame:
        """Prepares results for plotting. Removes any entries which are
        all 0. and returns as a polars DataFrame for easier plotting.

        Arguments:
            results (Results): Dictionary of results to be prepped.

        Returns:
            (pl.DataFrame): Polars dataframe of results.
        """
        data_rows = [
            {
                'Residue Pair': pair,
                'Hydrophobic': data['hydrophobic'],
                'Hydrogen Bond': data['hbond'],
                'Salt Bridge': data['saltbridge'],
                'Covariance': cov_type,
            }
            for cov_type, pair_dict in results.items()
            for pair, data in pair_dict.items()
            if any(val > 0. for val in data.values())
        ]

        return pl.DataFrame(data_rows)

    def make_plot(self,
                  data: pl.DataFrame,
                  column: str,
                  name: PathLike,
                  fs: int=15) -> None:
        """Generates a seaborn barplot from a dataframe for a specified column.

        Arguments:
            data (pl.DataFrame): Polars dataframe of data.
            column (str): Label for desired column.
            name (PathLike): Path to file to save plot to.
            fs (int): Defaults to 15. Size of font for plot.

        Returns:
            None
        """
        fig, ax = plt.subplots(1, 1, figsize=(6, 5))

        sns.barplot(data=data, x='Residue Pair', y=column, ax=ax)

        ax.set_xlabel('Residue Pair', fontsize=fs)
        ax.set_ylabel('Probability', fontsize=fs)
        ax.set_title(column, fontsize=fs+2)
        ax.tick_params(labelsize=fs)
        ax.tick_params(axis='x', rotation=45)

        fig.tight_layout()
        fig.savefig(str(name), dpi=300)
        # One figure is created per call; close it so they do not accumulate.
        plt.close(fig)