Coverage for src / molecular_simulations / simulate / mmpbsa.py: 13%
543 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 10:07 -0600
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 10:07 -0600
1from concurrent.futures import as_completed, ThreadPoolExecutor, wait, ALL_COMPLETED
2from dataclasses import dataclass
3import json
4import logging
5import os
6import pandas as pd
7from pathlib import Path
8import polars as pl
9import re
10import subprocess
11import time
12from typing import Literal, Optional, Union
14# This is simply to enable higher level parallelism by parsl/academy
15# Numpy by default allows all threads to be used and in agentic settings
16# we have seen oversubscription of threads and some calculations fail to
17# to write out. These settings must be set BEFORE importing numpy
18os.environ.setdefault('OPENBLAS_NUM_THREADS', '1')
19os.environ.setdefault('MKL_NUM_THREADS', '1')
20os.environ.setdefault('OMP_NUM_THREADS', '1')
21import numpy as np
23PathLike = Union[Path, str]
25logger = logging.getLogger(__name__)
27def _run_energy_calculation(args: tuple[str],
28 max_retries: int=3) -> tuple[Path, bool, str]:
29 """Worker function for parallel energy calculations.
30 Must be module level for ThreadPoolExecutor pickling.
32 Args:
33 args (tuple): Tuple of (mmpbsa_binary, mdin_path, prmtop, pdb, traj_chunk, output_path, cwd)
34 max_retries (int): Number of retry attempts for failed calculations
36 Returns:
37 (Path): Path to the output file.
38 """
39 mmpbsa_binary, mdin, prm, pdb, trj, out, cwd = args
40 cmd = f'{mmpbsa_binary} -O -i {mdin} -p {prm} -c {pdb} -y {trj} -o {out}'
42 expected_output = Path(cwd) / out
44 for attempt in range(max_retries):
45 try:
46 result = subprocess.run(cmd, shell=True, cwd=str(cwd),
47 capture_output=True, text=True)
49 if expected_output.exists() and expected_output.stat().st_size > 0:
50 with open(expected_output, 'r') as f:
51 content = f.read()
52 if ' BOND' in content:
53 return (out, True, '')
54 else:
55 error = 'Output file exists but contains no energy data'
56 else:
57 error = 'Output file missing or empty after subprocess complete'
59 if result.returncode != 0:
60 error = f'Return code: {result.returncode}: {result.stderr or result.stdout}'
62 except subprocess.TimeoutExpired:
63 error = 'Calculation timed out'
64 except Exception as e:
65 error = f'Exception: {e}'
67 if attempt < max_retries - 1:
68 logger.warning(f'Energy calculation {out} failed (attempt {attempt + 1}/{max_retries})')
69 logger.warning(f'Error: {error}')
70 time.sleep(2 ** attempt)
72 return (out, False, error)
74def _run_sasa_calculation(args: tuple[str],
75 max_retries: int=3) -> tuple[Path, bool, str]:
76 """Worker function for parallel SASA calculations.
77 Must be module level for ThreadPoolExecutor pickling."""
78 cpptraj_binary, sasa_script, cwd = args
80 # Parse expected output from script
81 script_path = Path(sasa_script)
82 with open(script_path, 'r') as f:
83 script_content = f.read()
85 match = re.search(r'molsurf\s+.*?\s+out\s+(\S+)', script_content)
86 if match:
87 expected_output = Path(cwd) / match.group(1)
88 else:
89 expected_output = None
91 for attempt in range(max_retries):
92 try:
93 result = subprocess.run(f'{cpptraj_binary} -i {sasa_script}', shell=True, cwd=str(cwd),
94 capture_output=True, text=True)
95 if expected_output and expected_output.exists():
96 if expected_output.stat().st_size > 0:
97 with open(expected_output, 'r') as f:
98 lines = f.readlines()
99 data_lines = [l for l in lines if l.strip() and not l.strip().startswith('#')]
100 if len(data_lines) > 0:
101 return (sasa_script, True, '')
102 else:
103 error = 'Output file has no data lines'
104 else:
105 error = 'Expected output file is empty'
106 else:
107 error = f'Expected output file not found: {expected_output}'
109 if result.returncode != 0:
110 error = f'Return code {result.returncode}: {result.stderr or result.stdout}'
112 except subprocess.TimeoutExpired:
113 error = 'Calculation timed out'
114 except Exception as e:
115 error = f'Exception: {e}'
117 if attempt < max_retries - 1:
118 logger.warning(f'SASA calculation {sasa_script} failed (attempt {attempt+1}/{max_retries})')
119 time.sleep(2 ** attempt)
121 return (script_path, False, error)
@dataclass
class MMPBSA_settings:
    """Configuration container for an MM-P(G)BSA run; MMPBSA inherits these fields."""
    top: PathLike            # solvated topology; should match the input trajectory
    dcd: PathLike            # input trajectory (DCD or MDCRD)
    selections: list[str]    # cpptraj residue selections, [receptor, ligand] in that order
    first_frame: int = 0     # first trajectory frame to process
    last_frame: int = -1     # last frame; -1 means run the whole trajectory
    stride: int = 1          # frame stride
    n_cpus: int = 1          # number of parallel workers (and trajectory chunks)
    out: str = 'mmpbsa'      # output prefix or output directory path
    solvent_probe: float = 1.4  # probe radius (Å) for SASA calculations
    offset: int = 0          # passed through to cpptraj molsurf `offset`
    gb_surften: float=0.0072  # GB surface-tension coefficient (kcal/mol/Å²)
    gb_surfoff: float=0.      # GB SASA energy offset
class MMPBSA(MMPBSA_settings):
    """
    This is an experiment in patience. What follows is a reconstruction of the various
    pieces of code that run MM-P(G)BSA from AMBER but written in a more digestible manner
    with actual documentation. Herein we have un-CLI'd what should never have been a
    CLI and piped together the correct pieces of the ambertools ecosystem to perform
    MM-P(G)BSA and that alone. Your trajectory is required to be concatenated into a single
    continuous trajectory - or you can run this serially over each by instancing this class
    for each trajectory you have. In this way we have also disentangled the requirement to
    parallelize by use of MPI, allowing the user to choose their own parallelization/scaling
    scheme.

    Arguments:
        top (PathLike): Input topology for a solvated system. Should match the input trajectory.
        dcd (PathLike): Input trajectory. Can be DCD format or MDCRD already.
        selections (list[str]): A list of residue ID selections for the receptor and ligand
            in that order. Should be formatted for cpptraj (e.g. `:1-10`).
        first_frame (int): Defaults to 0. The first frame of the input trajectory to begin
            the calculations on.
        last_frame (int): Defaults to -1. Optional final frame to cut trajectory at. If -1,
            acts as a flag to run the whole trajectory.
        stride (int): Defaults to 1. The number of frames to stride the trajectory by.
        n_cpus (int): Number of parallel processes.
        out (str): The prefix name or path for output files.
        solvent_probe (float): Defaults to 1.4Å. The probe radius to use for SA calculations.
        offset (int): Defaults to 0Å. Passed through to cpptraj molsurf's `offset` keyword.
        gb_surften (float): Defaults to 0.0072. GB surface-tension coefficient.
        gb_surfoff (float): Defaults to 0.0. GB SASA energy offset.
        amberhome (str, optional): Root of an AMBER install. When None, taken from the
            AMBERHOME environment variable; raises ValueError if that is unset.
        parallel_mode (str): 'frame' for frame-level parallelization (recommended),
            'serial' to process each system sequentially.
    """
    def __init__(self,
                 top: PathLike,
                 dcd: PathLike,
                 selections: list[str],
                 first_frame: int=0,
                 last_frame: int=-1,
                 stride: int=1,
                 n_cpus: int=1,
                 out: str='mmpbsa',
                 solvent_probe: float=1.4,
                 offset: int=0,
                 gb_surften: float=0.0072,
                 gb_surfoff: float=0.,
                 amberhome: Optional[str]=None,
                 parallel_mode: Literal['frame', 'serial'] = 'frame',
                 **kwargs):
        super().__init__(top=top,
                         dcd=dcd,
                         selections=selections,
                         first_frame=first_frame,
                         last_frame=last_frame,
                         stride=stride,
                         n_cpus=n_cpus,
                         out=out,
                         solvent_probe=solvent_probe,
                         offset=offset,
                         gb_surften=gb_surften,
                         gb_surfoff=gb_surfoff)
        self.parallel_mode = parallel_mode
        # Resolve inputs to absolute paths; outputs live next to the topology
        # unless `out` names an explicit directory.
        self.top = Path(self.top).resolve()
        self.traj = Path(self.dcd).resolve()
        self.path = self.top.parent
        if out == 'mmpbsa':
            self.path = self.path / 'mmpbsa'
        else:
            self.path = Path(out).resolve()

        self.path.mkdir(exist_ok=True, parents=True)

        self.cpptraj = 'cpptraj'
        self.mmpbsa_py_energy = 'mmpbsa_py_energy'
        # Locate the AMBER binaries: an explicit `amberhome` wins, otherwise
        # fall back to the AMBERHOME environment variable.
        if amberhome is None: # we are overriding AMBERHOME or using another env's install
            if 'AMBERHOME' in os.environ:
                amberhome = os.environ['AMBERHOME']
            else:
                raise ValueError('AMBERHOME not set in env vars!')

        self.cpptraj = Path(amberhome) / 'bin' / self.cpptraj
        self.mmpbsa_py_energy = Path(amberhome) / 'bin' / self.mmpbsa_py_energy

        # FileHandler strips topologies/trajectories and (optionally) splits
        # the trajectory into n_cpus chunks for frame-level parallelism.
        self.fh = FileHandler(
            top=self.top,
            traj=self.traj,
            path=self.path,
            sels=self.selections,
            first=self.first_frame,
            last=self.last_frame,
            stride=self.stride,
            cpptraj_binary=self.cpptraj,
            n_chunks=self.n_cpus
        )

        self.analyzer = OutputAnalyzer(
            path=self.path,
            surface_tension=self.gb_surften,
            sasa_offset=self.gb_surfoff
        )

        # Stash any extra keyword arguments as attributes (escape hatch for
        # callers; unvalidated by design).
        for key, value in kwargs.items():
            setattr(self, key, value)

    def run(self) -> None:
        """
        Main logic of MM-PBSA with parallelization.

        Depending on parallel_mode:
        - 'frame': Splits trajectory into chunks, processes in parallel
        - anything else: falls back to the serial implementation

        Returns:
            None
        """
        logger.debug(f'Preparing MM-PBSA calculation with {self.n_cpus} CPUs (mode: {self.parallel_mode})')
        gb_mdin, pb_mdin = self.write_mdins()

        if self.parallel_mode == 'frame':
            self._run_frame_parallel(gb_mdin, pb_mdin)
        else:
            # Fallback to serial
            self._run_serial(gb_mdin, pb_mdin)

        logger.debug('Collating results.')
        self.analyzer.parse_outputs()

        # [mean, std] of the PB ∆G of binding, set by the analyzer.
        self.free_energy = self.analyzer.free_energy

    def _run_serial(self, gb_mdin: Path, pb_mdin: Path) -> None:
        """Original serial implementation.

        NOTE(review): iterates FileHandler.files, which is defined outside this
        view - assumed to yield (prefix, top, traj, pdb) tuples per system.
        """
        for (prefix, top, traj, pdb) in self.fh.files:
            logger.debug(f'Computing energy terms for {prefix.name}.')
            self.calculate_sasa(prefix, top, traj)
            self.calculate_energy(prefix, top, traj, pdb, gb_mdin, 'gb')
            self.calculate_energy(prefix, top, traj, pdb, pb_mdin, 'pb')

    def _run_frame_parallel(self, gb_mdin: Path, pb_mdin: Path) -> None:
        """
        Frame-level parallelization: split trajectory into chunks, process in parallel.
        This provides the best speedup for long trajectories.

        SASA runs first (all chunks), is combined, then the GB/PB energy runs
        follow; either phase failing raises RuntimeError. Chunk file names must
        match the globs used by _combine_sasa_chunks/_combine_energy_chunks.
        """
        # Collect all calculation tasks
        energy_tasks = []
        sasa_tasks = []

        for (prefix, top, traj_chunks, pdb) in self.fh.files_chunked:
            system_name = prefix.name
            logger.debug(f'Preparing parallel energy calculations for {system_name}.')

            # SASA calculations for each chunk
            for i, traj_chunk in enumerate(traj_chunks):
                sasa_script = self._write_sasa_script(prefix, top, traj_chunk, chunk_idx=i)
                sasa_tasks.append((str(self.cpptraj), str(sasa_script), str(self.path)))

            # Energy calculations for each chunk
            for i, traj_chunk in enumerate(traj_chunks):
                # GB calculation
                out_gb = f'{system_name}_chunk{i}_gb.mdout'
                energy_tasks.append((
                    str(self.mmpbsa_py_energy), str(gb_mdin), str(top),
                    str(pdb), str(traj_chunk), out_gb, str(self.path)
                ))
                # PB calculation
                out_pb = f'{system_name}_chunk{i}_pb.mdout'
                energy_tasks.append((
                    str(self.mmpbsa_py_energy), str(pb_mdin), str(top),
                    str(pdb), str(traj_chunk), out_pb, str(self.path)
                ))

        # Run SASA calculations in parallel
        logger.debug(f'Running {len(sasa_tasks)} SASA calculations in parallel.')
        sasa_failures = []
        with ThreadPoolExecutor(max_workers=self.n_cpus) as executor:
            futures = []
            for task in sasa_tasks:
                futures.append(executor.submit(_run_sasa_calculation, task))

            logger.debug(f'Submitted {len(futures)} SASA futures, waiting for completion...')

            # Wait for ALL to complete before proceeding
            done, _ = wait(futures, return_when=ALL_COMPLETED)

            for future in done:
                script, success, error = future.result()
                if not success:
                    sasa_failures.append((script, error))
                    logger.error(f'SASA calculation failed: {script}: {error[:300]}')

        if sasa_failures:
            failed_scripts = [f[0] for f in sasa_failures]
            raise RuntimeError(f'{len(sasa_failures)} SASA calculations failed: {failed_scripts}')
        logger.debug('All SASA calculations completed successfully')

        # Combine SASA results
        self._combine_sasa_chunks()

        # Run Energy calculations in parallel
        logger.debug(f'Running {len(energy_tasks)} energy calculations in parallel.')
        energy_failures = []
        with ThreadPoolExecutor(max_workers=self.n_cpus) as executor:
            futures = []
            for task in energy_tasks:
                futures.append(executor.submit(_run_energy_calculation, task))

            logger.debug(f'Submitted {len(futures)} Energy futures, waiting for completion...')

            # Wait for ALL to complete before proceeding
            done, _ = wait(futures, return_when=ALL_COMPLETED)

            for future in done:
                script, success, error = future.result()
                if not success:
                    energy_failures.append((script, error))
                    logger.error(f'Energy calculation failed: {script}: {error[:300]}')

        if energy_failures:
            failed_scripts = [f[0] for f in energy_failures]
            raise RuntimeError(f'{len(energy_failures)} Energy calculations failed: {failed_scripts}')
        logger.debug('All Energy calculations completed successfully')

        # Combine Energy results
        self._combine_energy_chunks()

        self._verify_combined_outputs()

    def _write_sasa_script(self, prefix: Path, prm: Path, trj: Path,
                           chunk_idx: int = 0) -> Path:
        """Write a SASA calculation script for a trajectory chunk.

        The output file name embeds the chunk index so that
        _combine_sasa_chunks can reassemble frames in order.

        Returns:
            (Path): Path to the written cpptraj input script.
        """
        sasa = self.path / f'sasa_{prefix.name}_chunk{chunk_idx}.in'
        out_file = f'{prefix.name}_chunk{chunk_idx}_surf.dat'
        sasa_in = [
            f'parm {prm}',
            f'trajin {trj}',
            f'molsurf :* out {out_file} probe {self.solvent_probe} offset {self.offset}',
            'run',
            'quit'
        ]
        self.fh.write_file(sasa_in, sasa)
        return sasa

    def _combine_sasa_chunks(self) -> None:
        """Combine SASA results from all chunks into single files (in correct frame order)."""
        def extract_chunk_idx(filepath: Path) -> int:
            """Extract chunk index from filename for proper numerical sorting."""
            match = re.search(r'_chunk(\d+)_', filepath.name)
            return int(match.group(1)) if match else 0

        for system in ['complex', 'receptor', 'ligand']:
            combined_data = []
            chunk_files = list(self.path.glob(f'{system}_chunk*_surf.dat'))
            # Sort numerically by chunk index, not lexicographically
            chunk_files.sort(key=extract_chunk_idx)

            for chunk_file in chunk_files:
                with open(chunk_file) as f:
                    lines = f.readlines()
                if combined_data:
                    # Skip header for subsequent chunks
                    combined_data.extend(lines[1:])
                else:
                    combined_data.extend(lines)

            # Write combined file
            output = self.path / f'{system}_surf.dat'
            with open(output, 'w') as f:
                f.writelines(combined_data)

            # Clean up chunk files
            for chunk_file in chunk_files:
                chunk_file.unlink()

    def _combine_energy_chunks(self) -> None:
        """Combine energy results from all chunks into single files (in correct frame order)."""
        def extract_chunk_idx(filepath: Path) -> int:
            """Extract chunk index from filename for proper numerical sorting."""
            match = re.search(r'_chunk(\d+)_', filepath.name)
            return int(match.group(1)) if match else 0

        for system in ['complex', 'receptor', 'ligand']:
            for level in ['gb', 'pb']:
                combined_data = []
                chunk_files = list(self.path.glob(f'{system}_chunk*_{level}.mdout'))
                # Sort numerically by chunk index, not lexicographically
                chunk_files.sort(key=extract_chunk_idx)

                for chunk_file in chunk_files:
                    with open(chunk_file) as f:
                        content = f.read()
                    combined_data.append(content)

                # Write combined file (mdout format allows concatenation of frame data)
                output = self.path / f'{system}_{level}.mdout'
                with open(output, 'w') as f:
                    f.write('\n'.join(combined_data))

                # Clean up chunk files
                for chunk_file in chunk_files:
                    chunk_file.unlink()

    def calculate_sasa(self,
                       pre: str,
                       prm: PathLike,
                       trj: PathLike) -> None:
        """
        Runs the molsurf command in cpptraj to compute the SASA of a given system.

        NOTE(review): subprocess output is discarded and the return code is not
        checked; a cpptraj failure here surfaces later during output parsing.

        Arguments:
            pre (str): Prefix for output SASA file.
            prm (PathLike): Path to prmtop file.
            trj (PathLike): Path to CRD trajectory file.

        Returns:
            None
        """
        sasa = self.fh.path / 'sasa.in'
        sasa_in = [
            f'parm {prm}',
            f'trajin {trj}',
            f'molsurf :* out {pre}_surf.dat probe {self.solvent_probe} offset {self.offset}',
            'run',
            'quit'
        ]

        self.fh.write_file(sasa_in, sasa)

        subprocess.run(f'{self.cpptraj} -i {sasa}', shell=True, cwd=str(self.path),
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        sasa.unlink()

    def calculate_energy(self,
                         pre: str,
                         prm: PathLike,
                         trj: PathLike,
                         pdb: PathLike,
                         mdin: PathLike,
                         suf: str) -> None:
        """
        Runs mmpbsa_py_energy, an undocumented binary file which somehow mysteriously
        computes the energy of a system. This software is not only undocumented but is
        a binary which we cannot inspect ourselves.

        NOTE(review): subprocess output is discarded and the return code is not
        checked; a failure here surfaces later during output parsing.

        Arguments:
            pre (str): Prefix for output file.
            prm (PathLike): Path to prmtop file.
            trj (PathLike): Path to CRD trajectory file.
            pdb (PathLike): Path to PDB file.
            mdin (PathLike): Configuration file for the program.
            suf (str): Suffix for output file.

        Returns:
            None
        """
        cmd = f'{self.mmpbsa_py_energy} -O -i {mdin} -p {prm} -c {pdb} -y {trj} -o {pre}_{suf}.mdout'
        subprocess.run(cmd, shell=True, cwd=str(self.path),
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    def write_mdins(self) -> tuple[Path, Path]:
        """
        Writes out the configuration files that are to be fed to mmpbsa_py_energy.
        These are also undocumented and I took the parameters from the location
        in which they are hardcoded in ambertools.

        Returns:
            (tuple[Path, Path]): Tuple of paths to the GB mdin and the PB mdin.
        """
        gb = self.fh.path / 'gb_mdin'
        gb_mdin = [
            'GB',
            'igb = 2',
            'extdiel = 78.3',
            'saltcon = 0.10',
            f'surften = {self.gb_surften}',
            'rgbmax = 25.0'
        ]

        self.fh.write_file(gb_mdin, gb)

        pb = self.fh.path / 'pb_mdin'
        pb_mdin = [
            'PB',
            'inp = 2',
            'smoothopt = 1',
            'radiopt = 0',
            'npbopt = 0',
            'solvopt = 1',
            'maxitn = 1000',
            'nfocus = 2',
            'bcopt = 5',
            'eneopt = 2',
            'fscale = 8',
            'epsin = 1.0',
            'epsout = 80.0',
            'istrng = 0.10',
            'dprob = 1.4',
            'iprob = 2.0',
            'accept = 0.001',
            'fillratio = 4.0',
            'space = 0.5',
            'cutnb = 0',
            'sprob = 0.557',
            'cavity_surften = 0.0378',
            'cavity_offset = -0.5692'
        ]

        self.fh.write_file(pb_mdin, pb)

        return gb, pb

    def _verify_combined_outputs(self) -> None:
        """Verify all required output files exist and have valid content.

        Checks, for each of complex/receptor/ligand: the combined SASA file
        (must contain non-comment data lines) and the combined GB/PB mdout
        files (must contain ' BOND' energy records).

        Raises RuntimeError if any files are missing or invalid."""
        missing_files = []
        empty_files = []
        invalid_files = []

        for system in ['complex', 'receptor', 'ligand']:
            # Check SASA file
            sasa_file = self.path / f'{system}_surf.dat'
            if not sasa_file.exists():
                missing_files.append(str(sasa_file))
            elif sasa_file.stat().st_size == 0:
                empty_files.append(str(sasa_file))
            else:
                # Verify it has data lines
                with open(sasa_file) as f:
                    data_lines = [l for l in f if l.strip() and not l.strip().startswith('#')]
                if len(data_lines) == 0:
                    invalid_files.append(f'{sasa_file} (no data lines)')

            # Check energy files
            for level in ['gb', 'pb']:
                energy_file = self.path / f'{system}_{level}.mdout'
                if not energy_file.exists():
                    missing_files.append(str(energy_file))
                elif energy_file.stat().st_size == 0:
                    empty_files.append(str(energy_file))
                else:
                    # Verify it has BOND lines (energy data)
                    with open(energy_file) as f:
                        content = f.read()
                    if ' BOND' not in content:
                        invalid_files.append(f'{energy_file} (no energy data)')

        errors = []
        if missing_files:
            errors.append(f"Missing files: {missing_files}")
        if empty_files:
            errors.append(f"Empty files: {empty_files}")
        if invalid_files:
            errors.append(f"Invalid files: {invalid_files}")

        if errors:
            raise RuntimeError(f"Output verification failed: {'; '.join(errors)}")
class OutputAnalyzer:
    """
    Analyzes the outputs from an MM-PBSA run. Stores data in a Polars dataframe
    internally, and writes out data in the form of json/plain text.
    """
    def __init__(self,
                 path: PathLike,
                 surface_tension: float=0.0072,
                 sasa_offset: float=0.,
                 _tolerance: float = 0.005,
                 log: bool=True):
        """
        Arguments:
            path (PathLike): Directory containing the MM-PBSA output files.
            surface_tension (float): Coefficient applied to raw SASA values.
            sasa_offset (float): Constant added to the scaled SASA energies.
            _tolerance (float): Threshold for the bonded-term sanity check and
                for suppressing near-zero rows in the printed report.
            log (bool): If True, emit the ∆G summary via logging; otherwise
                print the full report to stdout.
        """
        self.path = Path(path)
        self.surften = surface_tension
        self.offset = sasa_offset
        self.tolerance = _tolerance
        self.log = log

        # Populated by compute_dG/pretty_print as [mean, std] of the PB ∆G.
        self.free_energy = None

        self.systems = ['receptor', 'ligand', 'complex']
        self.levels = ['gb', 'pb']

        # Energy terms that belong to the solvation (rather than gas) phase.
        self.solvent_contributions = ['EGB', 'ESURF', 'EPB', 'ECAVITY']

    def parse_outputs(self) -> None:
        """
        Parse all the output files.

        Builds self.gb and self.pb (one row per frame per system), replaces the
        GB ESURF column with the cpptraj-derived SASA energies, then runs the
        sanity check, summary, and ∆G computation.

        Returns:
            None
        """
        self.gb = pl.DataFrame()
        self.pb = pl.DataFrame()

        for system in self.systems:
            E_sasa = self.read_sasa(self.path / f'{system}_surf.dat')
            E_gb = self.read_GB(self.path / f'{system}_gb.mdout', system)
            E_pb = self.read_PB(self.path / f'{system}_pb.mdout', system)

            # GB's own ESURF is replaced by the cpptraj molsurf SASA energies.
            E_gb = E_gb.drop('ESURF').with_columns(E_sasa)

            self.gb = pl.concat([self.gb, E_gb], how='vertical')
            self.pb = pl.concat([self.pb, E_pb], how='vertical')

        all_cols = list(set(self.gb.columns + self.pb.columns))
        self.contributions = {
            'G gas': [col for col in all_cols
                      if col not in self.solvent_contributions],
            'G solv': [col for col in all_cols
                       if col in self.solvent_contributions]
        }

        self.check_bonded_terms()
        self.generate_summary()
        self.compute_dG()

    def read_sasa(self,
                  _file: PathLike) -> np.ndarray:
        """
        Reads in the results of the cpptraj SASA calculation and returns the
        per-frame SASA scaled by a hardcoded value for surface tension that is
        a mostly undocumented heuristic.

        Arguments:
            _file (PathLike): Path to a file containing the SASA data.

        Returns:
            (pl.Series): Per-frame rescaled SASA energies, named 'ESURF'.
            NOTE(review): annotation says np.ndarray but a pl.Series is returned.
        """
        df = pd.read_csv(_file, sep='\s+') # read in dataframe
        # Last column holds the SASA values; scale by surften and shift by offset.
        sasa = df.iloc[:, -1].to_numpy(dtype=float) * self.surften + self.offset

        return pl.Series('ESURF', sasa)

    def read_GB(self,
                _file: PathLike,
                system: str) -> pl.DataFrame:
        """
        Read in the GB mdout files and returns a Polars dataframe of the values
        for each term for every frame. Also adds a `system` label to more easily
        compute summary statistics later.

        Arguments:
            _file (PathLike): Energy data file path.
            system (str): String label for which system we are processing (e.g. complex).

        Returns:
            (pl.DataFrame): Polars dataframe containing the parsed energy data.
        """
        gb_terms = ['BOND', 'ANGLE', 'DIHED', 'VDWAALS', 'EEL',
                    'EGB', '1-4 VDW', '1-4 EEL', 'RESTRAINT', 'ESURF']
        data = {gb_term: [] for gb_term in gb_terms}

        # NOTE(review): file handle is never closed explicitly; relies on GC.
        lines = open(_file, 'r').readlines()

        return self.parse_energy_file(lines, data, system)

    def read_PB(self,
                _file: PathLike,
                system: str) -> pl.DataFrame:
        """
        Read in the PB mdout files and returns a Polars dataframe of the values
        for each term for every frame. Also adds a `system` label to more easily
        compute summary statistics later.

        Arguments:
            _file (PathLike): Energy data file path.
            system (str): String label for which system we are processing (e.g. complex).

        Returns:
            (pl.DataFrame): Polars dataframe containing the parsed energy data.
        """
        pb_terms = ['BOND', 'ANGLE', 'DIHED', 'VDWAALS', 'EEL',
                    'EPB', '1-4 VDW', '1-4 EEL', 'RESTRAINT',
                    'ECAVITY', 'EDISPER']
        data = {pb_term: [] for pb_term in pb_terms}

        # NOTE(review): file handle is never closed explicitly; relies on GC.
        lines = open(_file, 'r').readlines()

        return self.parse_energy_file(lines, data, system)

    def parse_energy_file(self, file_contents: list[str],
                          data: dict[str, list], system: str) -> pl.DataFrame:
        """Parse energy file contents.

        Scans every line containing '=' and one of the expected term names,
        splits it into key/value pairs via parse_line, and accumulates values
        per term. A 'system' label column is appended for later grouping.
        """
        for line in file_contents:
            if '=' in line and any(key in line for key in data.keys()):
                parsed = self.parse_line(line)
                for key, val in parsed:
                    if key in data:
                        data[key].append(val)

        df = pl.DataFrame(data)
        df = df.with_columns(pl.lit(system).alias('system'))
        return df

    def parse_energy_file_OLD(self,
                              file_contents: list[str],
                              data: dict[str, list],
                              system: str) -> pl.DataFrame:
        """
        Parses the contents of an energy calculation using a dictionary of
        energy terms to extract theory-level observables (e.g. EGB vs EPB).

        NOTE(review): superseded by parse_energy_file and apparently unused;
        kept for reference.

        Arguments:
            file_contents (list[str]): A list of each line from an energy calculation.
            data (dict[str, list]): The relevant energy terms to be scraped from input.
            system (str): The name of the system which will be included as an additional
                kv pair in the returned dataframe. This ensures we can track which portion
                of the calculation we are accounting for (e.g. complex, receptor, ligand).
        Returns:
            (pl.DataFrame): A Polars dataframe of shape (n_frames, n_calculations + system).
        """
        idx = 0
        n_frames = 0
        while idx < len(file_contents):
            if file_contents[idx].startswith(' BOND'):
                for _ in range(4): # number of lines to read. DO NOT CHANGE!!!
                    line = file_contents[idx]
                    parsed = self.parse_line(line)
                    for key, val in parsed:
                        data[key].append(val)

                    idx += 1

            if 'Processing frame' in file_contents[idx]:
                n_frames = int(file_contents[idx].strip().split()[-1])

            idx +=1

        data['system'] = [system] * n_frames

        return pl.DataFrame(
            {key: np.array(val) for key, val in data.items()}
        )

    def check_bonded_terms(self) -> None:
        """
        Performs a sanity check on the bonded terms which should perfectly cancel out
        (e.g. complex = receptor + ligand). If this is not the case something horrible
        has happened and we can't trust the non-bonded energies either. Additionally
        sets a few terms we will need later such as the number of frames as given by
        the dataframe height and sqrt(n_frames).

        Returns:
            None
        """
        bonded = ['BOND', 'ANGLE', 'DIHED', '1-4 VDW', '1-4 EEL']

        for theory_level in (self.gb, self.pb):
            a = theory_level.filter(pl.col('system') == 'receptor')
            b = theory_level.filter(pl.col('system') == 'ligand')
            c = theory_level.filter(pl.col('system') == 'complex')

            a = a.select(pl.col([col for col in a.columns if col in bonded])).to_numpy()
            b = b.select(pl.col([col for col in b.columns if col in bonded])).to_numpy()
            c = c.select(pl.col([col for col in c.columns if col in bonded])).to_numpy()

            # Bonded terms of the complex must equal receptor + ligand to
            # within tolerance, frame by frame.
            diffs = np.array(c - b - a)
            if np.where(diffs >= self.tolerance)[0].size > 0:
                raise ValueError('Bonded terms for receptor + ligand != complex!')

        # These terms are not used downstream; drop them before summarizing.
        remove = ['RESTRAINT', 'EDISPER']
        self.gb = self.gb.select(
            pl.col([col for col in self.gb.columns if col not in remove])
        )
        self.pb = self.pb.select(
            pl.col([col for col in self.pb.columns if col not in remove])
        )

        # NOTE(review): self.gb.height counts frames across all three systems
        # stacked vertically; confirm this is the intended N for std. errors.
        self.n_frames = self.gb.height
        self.square_root_N = np.sqrt(self.n_frames)

    def generate_summary(self) -> None:
        """
        Summarizes all processed energy data into a single polars dataframe
        and dumps it to a json file.

        NOTE(review): 'statistics.json' is written to the current working
        directory, not self.path - TODO confirm whether that is intentional.

        Returns:
            None
        """
        full_statistics = {sys: {} for sys in self.systems}
        for theory, level in zip([self.gb, self.pb], self.levels):
            for system in self.systems:
                # `sys` shadows the stdlib module name here (local only).
                sys = theory.filter(pl.col('system') == system).drop('system')

                # Per-term mean/std/standard-error.
                stats = {}
                for col in sys.columns:
                    mean = sys.select(pl.mean(col)).item()
                    stdev = sys.select(pl.std(col)).item()

                    stats[col] = {'mean': mean,
                                  'std': stdev,
                                  'err': stdev / self.square_root_N}

                # Pooled gas-phase and solvation-phase statistics.
                for energy, contributors in self.contributions.items():
                    pooled_data = sys.select(
                        pl.col([col for col in sys.columns if col in contributors])
                    ).to_numpy().flatten()

                    stats[energy] = {'mean': np.mean(pooled_data),
                                     'std': np.std(pooled_data),
                                     'err': np.std(pooled_data) / self.square_root_N}

                total_data = sys.to_numpy().flatten()
                stats['total'] = {'mean': np.mean(total_data),
                                  'std': np.std(total_data),
                                  'err': np.std(total_data) / self.square_root_N}

                full_statistics[system][level] = stats

        with open('statistics.json', 'w') as fout:
            json.dump(full_statistics, fout, indent=4)

    def compute_dG(self) -> None:
        """
        For each energy dataframe (GB/PB) compute the ∆G of binding by subtracting out
        relevant contributions in accordance with how this is done under the hood of the
        MMPBSA code.

        Returns:
            None
        """
        differences = []
        for theory, level in zip([self.gb, self.pb], self.levels):
            diff_cols = [col for col in theory.columns if col != 'system']
            # Per-frame difference: complex - receptor - ligand.
            diff_arr = theory.filter(pl.col('system') == 'complex').drop('system').to_numpy()
            for system in self.systems[:2]:
                diff_arr -= theory.filter(pl.col('system') == system).drop('system').to_numpy()

            means = np.mean(diff_arr, axis=0)
            stds = np.std(diff_arr, axis=0)
            errs = stds / self.square_root_N

            # Aggregate gas-phase and solvation-phase contributions.
            gas_solv_phase = []
            for energy, contributors in self.contributions.items():
                indices = [i for i, diff_col in enumerate(diff_cols)
                           if diff_col in contributors]
                contribution = np.sum(diff_arr[:, indices], axis=1)
                gas_solv_phase.append(contribution)

                diff_cols.append(energy)
                means = np.concatenate((means, [np.mean(contribution)]))
                stds = np.concatenate((stds, [np.std(contribution)]))
                errs = np.concatenate((errs, [np.std(contribution) / self.square_root_N]))

            # Total binding free energy = G gas + G solv.
            diff_cols.append('∆G Binding')
            total = np.sum(np.vstack(gas_solv_phase), axis=0)

            means = np.concatenate((means, [np.mean(total)]))
            stds = np.concatenate((stds, [np.std(total)]))
            errs = np.concatenate((errs, [np.std(total) / self.square_root_N]))

            # Rows: mean, std, err; one column per term.
            data = np.vstack((means, stds, errs))

            differences.append(pl.DataFrame(
                {diff_cols[i]: data[:,i] for i in range(len(diff_cols))}
            ))

        self.pretty_print(differences)

    def pretty_print(self,
                     dfs: list[pl.DataFrame]) -> None:
        """
        Ingests a list of Polars dataframes for GB and PB and prints their contents
        in a human-readable form to STDOUT (when self.log is False). Also saves out
        the energies to a plain text file called `deltaG.txt`.

        Side effect: sets self.free_energy to [mean, std] of the PB ∆G Binding.

        NOTE(review): logs via the root logger (logging.info) rather than the
        module-level `logger` used elsewhere in this file.

        Arguments:
            dfs (list[pl.DataFrame]): List of dataframes for GB and PB.

        Returns:
            None
        """
        print_statement = []
        log_statement = []
        for df, level in zip(dfs, ['Generalized Born ', 'Poisson Boltzmann']):
            print_statement += [
                f'{" ":<20}=========================',
                f'{" ":<20}=== {level} ===',
                f'{" ":<20}=========================',
                'Energy Component Average Std. Dev. Std. Err. of Mean',
                '---------------------------------------------------------------------'
            ]
            for col in df.columns:
                mean, std, err = [x.item() for x in df.select(pl.col(col)).to_numpy()]
                report = f'{col:<20}{mean:<16.3f}{std:<16.3f}{err:<16.3f}'
                # Near-zero rows (e.g. perfectly cancelling bonded terms) are
                # omitted from the report.
                if abs(mean) <= self.tolerance:
                    continue

                if col in ['G gas', '∆G Binding']:
                    print_statement.append('')

                if col == '∆G Binding':
                    log_statement.append(f'{level.strip()}:')
                    log_statement.append(report)

                    if level == 'Poisson Boltzmann':
                        self.free_energy = [mean, std]

                print_statement.append(report)

        print_statement = '\n'.join(print_statement)
        with open(self.path / 'deltaG.txt', 'w') as fout:
            fout.write(print_statement)

        if self.log:
            for statement in log_statement:
                logging.info(statement)
        else:
            print(print_statement)

    @staticmethod
    def parse_line(line) -> tuple[list[str], list[float]]:
        """
        Parses a line from mmpbsa_energy to get the various energy terms and values.

        Returns:
            An iterator of (term name, value) pairs.
            NOTE(review): annotation says tuple of lists, but a zip object is
            returned; callers iterate it once, which works either way.
        """
        eq_split = line.split('=')

        # Simple case: exactly one '=' separates a single key/value pair.
        if len(eq_split) == 2:
            splits = [eq_spl.strip() for eq_spl in eq_split]
        else:
            # Multiple '=' on one line: interior segments hold the previous
            # value and the next key, separated by whitespace.
            splits = [eq_split[0].strip()]

            for i in range(1, len(eq_split) - 1):
                splits += [spl.strip() for spl in eq_split[i].strip().split(' ')]

            splits += [eq_split[-1].strip()]

        # Alternating key/value layout after splitting.
        keys = splits[::2]
        vals = np.array(splits[1::2], dtype=float)

        return zip(keys, vals)
class FileHandler:
    """
    Performs preprocessing for MM-PBSA runs and manages the pathing to all file
    inputs. Additionally used to write out various cpptraj input files by the
    MMPBSA class.
    """
    def __init__(self,
                 top: Path,
                 traj: Path,
                 path: Path,
                 sels: list[str],
                 first: int,
                 last: int,
                 stride: int,
                 cpptraj_binary: PathLike,
                 n_chunks: int=1):
        """
        Arguments:
            top (Path): Solvated AMBER topology for the full system.
            traj (Path): Trajectory to process; converted to AMBER CRD here.
            path (Path): Directory in which all derived files are written.
            sels (list[str]): cpptraj strip masks. Stripping sels[0] from the
                complex yields the receptor; stripping sels[1] yields the
                ligand.
            first (int): First frame to include.
            last (int): Last frame to include; -1 means run to the end.
            stride (int): Offset between frames.
            cpptraj_binary (PathLike): cpptraj executable to invoke.
            n_chunks (int): Number of chunks to split trajectories into for
                parallel energy evaluation. Defaults to 1 (no splitting).
        """
        self.top = top
        self.traj = traj
        self.path = path
        self.selections = sels
        self.ff = first
        self.lf = last
        self.stride = stride
        self.cpptraj = cpptraj_binary
        self.n_chunks = n_chunks

        self.prepare_topologies()
        self.prepare_trajectories()

        self.trajectory_chunks = {}

        if n_chunks > 1:
            self._count_frames()
            self._split_trajectories()
        else:
            # No splitting requested: each system's chunk list is just its
            # single full trajectory.
            for system, traj in zip(['complex', 'receptor', 'ligand'], self.trajectories):
                self.trajectory_chunks[system] = [traj]

    def prepare_topologies(self) -> None:
        """
        Slices out each sub-topology for the desolvated complex, receptor and
        ligand using cpptraj due to the difficulty of working with AMBER FF
        files otherwise (including PARMED).

        Returns:
            None
        """
        self.topologies = [
            self.path / 'complex.prmtop',
            self.path / 'receptor.prmtop',
            self.path / 'ligand.prmtop'
        ]

        # One cpptraj script, three parm passes: desolvate the full system,
        # then strip each selection to produce receptor and ligand topologies.
        cpptraj_in = [
            f'parm {self.top}',
            'parmstrip :Na+,Cl-,WAT',
            'parmbox nobox',
            f'parmwrite out {self.topologies[0]}',
            'run',
            'clear all',
            f'parm {self.topologies[0]}',
            f'parmstrip {self.selections[0]}',
            f'parmwrite out {self.topologies[1]}',
            'run',
            'clear all',
            f'parm {self.topologies[0]}',
            f'parmstrip {self.selections[1]}',
            f'parmwrite out {self.topologies[2]}',
            'run',
            'quit'
        ]

        script = self.path / 'cpptraj.in'
        self.write_file('\n'.join(cpptraj_in), script)
        subprocess.call(f'{self.cpptraj} -i {script}', shell=True, cwd=str(self.path),
                        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        script.unlink()

    def prepare_trajectories(self) -> None:
        """
        Converts DCD trajectory to AMBER CRD format which is explicitly
        required by MM-G(P)BSA.

        Returns:
            None
        """
        self.trajectories = [path.with_suffix('.crd') for path in self.topologies]
        self.pdbs = [path.with_suffix('.pdb') for path in self.topologies]

        # Build the cpptraj frame-selection clause from first/last/stride.
        frame_control = f'start {self.ff}'

        if self.lf > -1:
            frame_control += f' stop {self.lf}'

        frame_control += f' offset {self.stride}'

        # First pass: convert the raw trajectory to CRD with frame selection.
        cpptraj_in = [
            f'parm {self.top}',
            f'trajin {self.traj}',
            f'trajout {self.traj.with_suffix(".crd")} crd {frame_control}',
            'run',
            'clear all',
        ]

        self.traj = self.traj.with_suffix('.crd')

        # Second pass: desolvate/align the complex, then strip each selection
        # to emit receptor and ligand trajectories plus single-frame PDBs.
        cpptraj_in += [
            f'parm {self.top}',
            f'trajin {self.traj}',
            'strip :WAT,Na+,Cl*',
            'autoimage',
            'rmsd !(:WAT,Cl*,CIO,Cs+,IB,K*,Li+,MG*,Na+,Rb+,CS,RB,NA,F,CL) mass first',
            f'trajout {self.trajectories[0]} crd nobox',
            f'trajout {self.pdbs[0]} pdb onlyframes 1',
            'run',
            'clear all',
            f'parm {self.topologies[0]}',
            f'trajin {self.trajectories[0]}',
            f'strip {self.selections[0]}',
            f'trajout {self.trajectories[1]} crd',
            f'trajout {self.pdbs[1]} pdb onlyframes 1',
            'run',
            'clear all',
            f'parm {self.topologies[0]}',
            f'trajin {self.trajectories[0]}',
            f'strip {self.selections[1]}',
            f'trajout {self.trajectories[2]} crd',
            f'trajout {self.pdbs[2]} pdb onlyframes 1',
            'run',
            'quit'
        ]

        name = self.path / 'mdcrd.in'
        self.write_file('\n'.join(cpptraj_in), name)
        subprocess.call(f'{self.cpptraj} -i {name}', shell=True, cwd=str(self.path),
                        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        name.unlink()

    def _count_frames(self) -> None:
        """Count the total number of frames in the processed trajectory.

        Sets `self.total_frames` by parsing cpptraj's trajinfo output,
        falling back to `_estimate_frames()` if nothing parseable is found.
        """
        count_script = self.path / 'count_frames.in'

        script_content = [
            f'parm {self.topologies[0]}',
            f'trajin {self.trajectories[0]}',
            f'trajinfo {self.trajectories[0]} name myinfo',
            'run',
            'quit'
        ]
        self.write_file('\n'.join(script_content), count_script)

        result = subprocess.run(
            f'{self.cpptraj} -i {count_script}',
            shell=True, cwd=str(self.path),
            capture_output=True, text=True
        )

        # Take the first integer on the first line mentioning both 'frames'
        # and 'total', then stop scanning so later lines cannot overwrite
        # the count (the original inner-loop break only exited one level).
        for line in result.stdout.split('\n'):
            lowered = line.lower()
            if 'frames' in lowered and 'total' in lowered:
                digits = [part for part in line.split() if part.isdigit()]
                if digits:
                    self.total_frames = int(digits[0])
                    break

        # Fallback: count lines in trajectory (for ASCII formats) or use file size heuristic
        if not hasattr(self, 'total_frames'):
            # Try to infer from first trajectory file
            self.total_frames = self._estimate_frames()

        count_script.unlink(missing_ok=True)
        logger.debug(f'Total frames in trajectory: {self.total_frames}')

    def _estimate_frames(self) -> int:
        """Estimate frame count by running a quick cpptraj analysis.

        Returns:
            (int): First integer found on a stdout line containing 'frames',
                or 100 as a last-resort default.
        """
        script = self.path / 'estimate_frames.in'
        script_content = [
            f'parm {self.topologies[0]}',
            f'trajin {self.trajectories[0]}',
            'run',
            'quit'
        ]
        self.write_file('\n'.join(script_content), script)

        result = subprocess.run(
            f'{self.cpptraj} -i {script}',
            shell=True, cwd=str(self.path),
            capture_output=True, text=True
        )

        script.unlink(missing_ok=True)

        # Parse output for frame count
        for line in result.stdout.split('\n'):
            if 'frames' in line.lower():
                for word in line.split():
                    if word.isdigit():
                        return int(word)

        # Default fallback
        return 100

    def _split_trajectories(self) -> None:
        """Split trajectories into chunks for parallel processing."""
        actual_chunks = min(self.n_chunks, self.total_frames)  # in case we have fewer frames than resources

        if actual_chunks < self.n_chunks:
            logger.warning(
                f'Requested {self.n_chunks} chunks but only {self.total_frames} frames. '
                f'Using {actual_chunks} chunks (some CPUs will be idle)'
            )

        frames_per_chunk = max(1, self.total_frames // actual_chunks)
        self.trajectory_chunks = {system: [] for system in ['complex', 'receptor', 'ligand']}

        for top, traj, system in zip(
            self.topologies,
            self.trajectories,
            ['complex', 'receptor', 'ligand']
        ):
            for chunk_idx in range(actual_chunks):
                start_frame = chunk_idx * frames_per_chunk + 1  # cpptraj is 1-indexed

                if chunk_idx == actual_chunks - 1:
                    # Last chunk gets remaining frames
                    end_frame = self.total_frames
                else:
                    end_frame = (chunk_idx + 1) * frames_per_chunk

                chunk_traj = self.path / f'{system}_chunk{chunk_idx}.crd'

                # Create chunk trajectory
                split_script = self.path / f'split_{system}_{chunk_idx}.in'
                script_content = [
                    f'parm {top}',
                    f'trajin {traj} {start_frame} {end_frame}',
                    f'trajout {chunk_traj} crd',
                    'run',
                    'quit'
                ]
                self.write_file('\n'.join(script_content), split_script)

                subprocess.run(
                    f'{self.cpptraj} -i {split_script}',
                    shell=True, cwd=str(self.path),
                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
                )

                split_script.unlink()
                self.trajectory_chunks[system].append(chunk_traj)

        logger.debug(f'Split trajectories into {actual_chunks} chunks of ~{frames_per_chunk} frames each.')

    @property
    def files(self) -> zip:
        """
        Returns a zip generator containing the output paths, topologies,
        trajectories and pdbs for each system. This is done to ensure we
        have the correct order for housekeeping reasons.

        Returns:
            (zip): Tuples of (output prefix, topology, trajectory, pdb).
        """
        _order = [self.path / prefix for prefix in ['complex', 'receptor', 'ligand']]
        return zip(_order, self.topologies, self.trajectories, self.pdbs)

    @property
    def files_chunked(self) -> list[tuple]:
        """Returns file info with chunked trajectories for parallel processing.

        Returns:
            (list[tuple]): One (prefix, topology, trajectory chunk list, pdb)
                tuple per system.
        """
        result = []
        for system, top, pdb in zip(
            ['complex', 'receptor', 'ligand'],
            self.topologies,
            self.pdbs
        ):
            prefix = self.path / system
            traj_chunks = self.trajectory_chunks[system]
            result.append((prefix, top, traj_chunks, pdb))

        return result

    @staticmethod
    def write_file(lines: Union[list[str], str],
                   filepath: PathLike) -> None:
        """
        Given an input of either a list of strings or a single string,
        write input to file. If a list, join by newline characters.

        Arguments:
            lines (Union[list[str], str]): Input to be written to file.
            filepath (PathLike): Path to the file to be written.

        Returns:
            None
        """
        if isinstance(lines, list):
            lines = '\n'.join(lines)

        with open(str(filepath), 'w') as f:
            f.write(lines)