Coverage for src / molecular_simulations / simulate / mmpbsa.py: 13%

543 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-12 10:07 -0600

1from concurrent.futures import as_completed, ThreadPoolExecutor, wait, ALL_COMPLETED 

2from dataclasses import dataclass 

3import json 

4import logging 

5import os 

6import pandas as pd 

7from pathlib import Path 

8import polars as pl 

9import re 

10import subprocess 

11import time 

12from typing import Literal, Optional, Union 

13 

14# This is simply to enable higher level parallelism by parsl/academy 

15# Numpy by default allows all threads to be used and in agentic settings 

16# we have seen oversubscription of threads and some calculations fail to 

17# to write out. These settings must be set BEFORE importing numpy 

18os.environ.setdefault('OPENBLAS_NUM_THREADS', '1') 

19os.environ.setdefault('MKL_NUM_THREADS', '1') 

20os.environ.setdefault('OMP_NUM_THREADS', '1') 

21import numpy as np 

22 

23PathLike = Union[Path, str] 

24 

25logger = logging.getLogger(__name__) 

26 

def _run_energy_calculation(args: tuple[str, ...],
                            max_retries: int = 3,
                            timeout: Optional[float] = None) -> tuple[str, bool, str]:
    """Worker function for parallel energy calculations.
    Must be module level so executors can reference it directly.

    Args:
        args (tuple): Tuple of (mmpbsa_binary, mdin_path, prmtop, pdb,
            traj_chunk, output_name, cwd), all strings. output_name is
            interpreted relative to cwd.
        max_retries (int): Number of retry attempts for failed calculations.
        timeout (Optional[float]): Per-attempt timeout in seconds forwarded to
            subprocess.run. None (default) disables the timeout, which matches
            the historical behavior; previously TimeoutExpired was caught but
            could never fire because no timeout was ever set.

    Returns:
        (tuple[str, bool, str]): (output name, success flag, error message).
            The error message is empty on success.
    """
    mmpbsa_binary, mdin, prm, pdb, trj, out, cwd = args
    cmd = f'{mmpbsa_binary} -O -i {mdin} -p {prm} -c {pdb} -y {trj} -o {out}'

    expected_output = Path(cwd) / out
    # Initialized so the final return is well-defined even if max_retries <= 0
    # (previously this raised UnboundLocalError).
    error = 'No attempts were made'

    for attempt in range(max_retries):
        try:
            result = subprocess.run(cmd, shell=True, cwd=str(cwd),
                                    capture_output=True, text=True,
                                    timeout=timeout)

            # Validate the output file rather than trusting the return code
            # alone; success requires a non-empty file with an energy table.
            if expected_output.exists() and expected_output.stat().st_size > 0:
                with open(expected_output, 'r') as f:
                    content = f.read()
                if ' BOND' in content:  # energy tables always include BOND
                    return (out, True, '')
                else:
                    error = 'Output file exists but contains no energy data'
            else:
                error = 'Output file missing or empty after subprocess complete'

            # A nonzero exit code is the more informative diagnostic, so it
            # overrides the file-based error message.
            if result.returncode != 0:
                error = f'Return code: {result.returncode}: {result.stderr or result.stdout}'

        except subprocess.TimeoutExpired:
            error = 'Calculation timed out'
        except Exception as e:
            error = f'Exception: {e}'

        if attempt < max_retries - 1:
            logger.warning(f'Energy calculation {out} failed (attempt {attempt + 1}/{max_retries})')
            logger.warning(f'Error: {error}')
            time.sleep(2 ** attempt)  # exponential backoff between retries

    return (out, False, error)

73 

74def _run_sasa_calculation(args: tuple[str], 

75 max_retries: int=3) -> tuple[Path, bool, str]: 

76 """Worker function for parallel SASA calculations. 

77 Must be module level for ThreadPoolExecutor pickling.""" 

78 cpptraj_binary, sasa_script, cwd = args 

79 

80 # Parse expected output from script 

81 script_path = Path(sasa_script) 

82 with open(script_path, 'r') as f: 

83 script_content = f.read() 

84 

85 match = re.search(r'molsurf\s+.*?\s+out\s+(\S+)', script_content) 

86 if match: 

87 expected_output = Path(cwd) / match.group(1) 

88 else: 

89 expected_output = None 

90 

91 for attempt in range(max_retries): 

92 try: 

93 result = subprocess.run(f'{cpptraj_binary} -i {sasa_script}', shell=True, cwd=str(cwd), 

94 capture_output=True, text=True) 

95 if expected_output and expected_output.exists(): 

96 if expected_output.stat().st_size > 0: 

97 with open(expected_output, 'r') as f: 

98 lines = f.readlines() 

99 data_lines = [l for l in lines if l.strip() and not l.strip().startswith('#')] 

100 if len(data_lines) > 0: 

101 return (sasa_script, True, '') 

102 else: 

103 error = 'Output file has no data lines' 

104 else: 

105 error = 'Expected output file is empty' 

106 else: 

107 error = f'Expected output file not found: {expected_output}' 

108 

109 if result.returncode != 0: 

110 error = f'Return code {result.returncode}: {result.stderr or result.stdout}' 

111 

112 except subprocess.TimeoutExpired: 

113 error = 'Calculation timed out' 

114 except Exception as e: 

115 error = f'Exception: {e}' 

116 

117 if attempt < max_retries - 1: 

118 logger.warning(f'SASA calculation {sasa_script} failed (attempt {attempt+1}/{max_retries})') 

119 time.sleep(2 ** attempt) 

120 

121 return (script_path, False, error) 

122 

@dataclass
class MMPBSA_settings:
    """Configuration bag shared by the MM-PBSA pipeline.

    Fields mirror the keyword arguments of ``MMPBSA``; see that class's
    docstring for detailed descriptions of each value.
    """
    # Required inputs: solvated topology, trajectory, and the receptor/ligand
    # residue selections (cpptraj mask syntax, receptor first).
    top: PathLike
    dcd: PathLike
    selections: list[str]
    # Frame window and stride applied to the trajectory (-1 = last frame).
    first_frame: int = 0
    last_frame: int = -1
    stride: int = 1
    # Parallelism and output naming.
    n_cpus: int = 1
    out: str = 'mmpbsa'
    # Surface-area / GB-surface-energy parameters.
    solvent_probe: float = 1.4
    offset: int = 0
    gb_surften: float = 0.0072
    gb_surfoff: float = 0.

137 

class MMPBSA(MMPBSA_settings):
    """
    This is an experiment in patience. What follows is a reconstruction of the various
    pieces of code that run MM-P(G)BSA from AMBER but written in a more digestible manner
    with actual documentation. Herein we have un-CLI'd what should never have been a
    CLI and piped together the correct pieces of the ambertools ecosystem to perform
    MM-P(G)BSA and that alone. Your trajectory is required to be concatenated into a single
    continuous trajectory - or you can run this serially over each by instancing this class
    for each trajectory you have. In this way we have also disentangled the requirement to
    parallelize by use of MPI, allowing the user to choose their own parallelization/scaling
    scheme.

    Arguments:
        top (PathLike): Input topology for a solvated system. Should match the input trajectory.
        dcd (PathLike): Input trajectory. Can be DCD format or MDCRD already.
        selections (list[str]): A list of residue ID selections for the receptor and ligand
            in that order. Should be formatted for cpptraj (e.g. `:1-10`).
        first_frame (int): Defaults to 0. The first frame of the input trajectory to begin
            the calculations on.
        last_frame (int): Defaults to -1. Optional final frame to cut trajectory at. If -1,
            acts as a flag to run the whole trajectory.
        stride (int): Defaults to 1. The number of frames to stride the trajectory by.
        n_cpus (int): Number of parallel processes (also used as the number of
            trajectory chunks handed to FileHandler).
        out (str): The prefix name or path for output files. The literal default
            'mmpbsa' places outputs in a 'mmpbsa' directory next to the topology;
            any other value is resolved as an output directory path.
        solvent_probe (float): Defaults to 1.4Å. The probe radius to use for SA calculations.
        offset (int): Defaults to 0. Forwarded to cpptraj molsurf's `offset`
            keyword — presumably added to atomic radii; TODO confirm against
            the cpptraj manual.
        gb_surften (float): Defaults to 0.0072. Surface tension used to convert
            SASA into the ESURF energy term.
        gb_surfoff (float): Defaults to 0.0. Additive offset applied to the
            scaled SASA.
        amberhome (Optional[str]): Root of an AMBER install containing
            bin/cpptraj and bin/mmpbsa_py_energy. Falls back to the AMBERHOME
            environment variable; raises ValueError if neither is available.
        parallel_mode (str): 'frame' (default) for frame-level parallelization
            over trajectory chunks; any other value ('serial') falls back to the
            serial implementation.
    """
    def __init__(self,
                 top: PathLike,
                 dcd: PathLike,
                 selections: list[str],
                 first_frame: int=0,
                 last_frame: int=-1,
                 stride: int=1,
                 n_cpus: int=1,
                 out: str='mmpbsa',
                 solvent_probe: float=1.4,
                 offset: int=0,
                 gb_surften: float=0.0072,
                 gb_surfoff: float=0.,
                 amberhome: Optional[str]=None,
                 parallel_mode: Literal['frame', 'serial'] = 'frame',
                 **kwargs):
        # Dataclass-generated __init__ stores the shared settings.
        super().__init__(top=top,
                         dcd=dcd,
                         selections=selections,
                         first_frame=first_frame,
                         last_frame=last_frame,
                         stride=stride,
                         n_cpus=n_cpus,
                         out=out,
                         solvent_probe=solvent_probe,
                         offset=offset,
                         gb_surften=gb_surften,
                         gb_surfoff=gb_surfoff)
        self.parallel_mode = parallel_mode
        self.top = Path(self.top).resolve()
        self.traj = Path(self.dcd).resolve()
        # Output directory: sibling 'mmpbsa' dir by default, otherwise `out`
        # itself is treated as the output directory path.
        self.path = self.top.parent
        if out == 'mmpbsa':
            self.path = self.path / 'mmpbsa'
        else:
            self.path = Path(out).resolve()

        self.path.mkdir(exist_ok=True, parents=True)

        self.cpptraj = 'cpptraj'
        self.mmpbsa_py_energy = 'mmpbsa_py_energy'
        if amberhome is None:  # no explicit install given: fall back to $AMBERHOME
            if 'AMBERHOME' in os.environ:
                amberhome = os.environ['AMBERHOME']
            else:
                raise ValueError('AMBERHOME not set in env vars!')

        # Fully-qualified paths to the AMBER binaries we shell out to.
        self.cpptraj = Path(amberhome) / 'bin' / self.cpptraj
        self.mmpbsa_py_energy = Path(amberhome) / 'bin' / self.mmpbsa_py_energy

        # FileHandler (defined elsewhere in this module) prepares per-system
        # topology/trajectory/PDB files and chunked trajectories.
        self.fh = FileHandler(
            top=self.top,
            traj=self.traj,
            path=self.path,
            sels=self.selections,
            first=self.first_frame,
            last=self.last_frame,
            stride=self.stride,
            cpptraj_binary=self.cpptraj,
            n_chunks=self.n_cpus
        )

        self.analyzer = OutputAnalyzer(
            path=self.path,
            surface_tension=self.gb_surften,
            sasa_offset=self.gb_surfoff
        )

        # Escape hatch: arbitrary extra attributes for downstream tooling.
        for key, value in kwargs.items():
            setattr(self, key, value)


    def run(self) -> None:
        """
        Main logic of MM-PBSA with parallelization.

        Depending on parallel_mode:
        - 'frame': Splits trajectory into chunks, processes in parallel
        - anything else: serial fallback

        After the energy/SASA runs, results are collated by the analyzer and
        the PB ∆G (mean, std) is exposed as `self.free_energy`.
        """
        logger.debug(f'Preparing MM-PBSA calculation with {self.n_cpus} CPUs (mode: {self.parallel_mode})')
        gb_mdin, pb_mdin = self.write_mdins()

        if self.parallel_mode == 'frame':
            self._run_frame_parallel(gb_mdin, pb_mdin)
        else:
            # Fallback to serial
            self._run_serial(gb_mdin, pb_mdin)

        logger.debug('Collating results.')
        self.analyzer.parse_outputs()

        self.free_energy = self.analyzer.free_energy

    def _run_serial(self, gb_mdin: Path, pb_mdin: Path) -> None:
        """Original serial implementation: one SASA + GB + PB pass per system."""
        for (prefix, top, traj, pdb) in self.fh.files:
            logger.debug(f'Computing energy terms for {prefix.name}.')
            self.calculate_sasa(prefix, top, traj)
            self.calculate_energy(prefix, top, traj, pdb, gb_mdin, 'gb')
            self.calculate_energy(prefix, top, traj, pdb, pb_mdin, 'pb')

    def _run_frame_parallel(self, gb_mdin: Path, pb_mdin: Path) -> None:
        """
        Frame-level parallelization: split trajectory into chunks, process in parallel.
        This provides the best speedup for long trajectories.

        Any failed chunk raises RuntimeError (after the workers' own retries),
        so partial results never silently reach the analyzer.
        """
        # Collect all calculation tasks
        energy_tasks = []
        sasa_tasks = []

        for (prefix, top, traj_chunks, pdb) in self.fh.files_chunked:
            system_name = prefix.name
            logger.debug(f'Preparing parallel energy calculations for {system_name}.')

            # SASA calculations for each chunk
            for i, traj_chunk in enumerate(traj_chunks):
                sasa_script = self._write_sasa_script(prefix, top, traj_chunk, chunk_idx=i)
                sasa_tasks.append((str(self.cpptraj), str(sasa_script), str(self.path)))

            # Energy calculations for each chunk
            for i, traj_chunk in enumerate(traj_chunks):
                # GB calculation
                out_gb = f'{system_name}_chunk{i}_gb.mdout'
                energy_tasks.append((
                    str(self.mmpbsa_py_energy), str(gb_mdin), str(top),
                    str(pdb), str(traj_chunk), out_gb, str(self.path)
                ))
                # PB calculation
                out_pb = f'{system_name}_chunk{i}_pb.mdout'
                energy_tasks.append((
                    str(self.mmpbsa_py_energy), str(pb_mdin), str(top),
                    str(pdb), str(traj_chunk), out_pb, str(self.path)
                ))

        # Run SASA calculations in parallel
        logger.debug(f'Running {len(sasa_tasks)} SASA calculations in parallel.')
        sasa_failures = []
        with ThreadPoolExecutor(max_workers=self.n_cpus) as executor:
            futures = []
            for task in sasa_tasks:
                futures.append(executor.submit(_run_sasa_calculation, task))

            logger.debug(f'Submitted {len(futures)} SASA futures, waiting for completion...')

            # Wait for ALL to complete before proceeding
            done, _ = wait(futures, return_when=ALL_COMPLETED)

            for future in done:
                script, success, error = future.result()
                if not success:
                    sasa_failures.append((script, error))
                    logger.error(f'SASA calculation failed: {script}: {error[:300]}')

        if sasa_failures:
            failed_scripts = [f[0] for f in sasa_failures]
            raise RuntimeError(f'{len(sasa_failures)} SASA calculations failed: {failed_scripts}')
        logger.debug('All SASA calculations completed successfully')

        # Combine SASA results
        self._combine_sasa_chunks()

        # Run Energy calculations in parallel
        logger.debug(f'Running {len(energy_tasks)} energy calculations in parallel.')
        energy_failures = []
        with ThreadPoolExecutor(max_workers=self.n_cpus) as executor:
            futures = []
            for task in energy_tasks:
                futures.append(executor.submit(_run_energy_calculation, task))

            logger.debug(f'Submitted {len(futures)} Energy futures, waiting for completion...')

            # Wait for ALL to complete before proceeding
            done, _ = wait(futures, return_when=ALL_COMPLETED)

            for future in done:
                script, success, error = future.result()
                if not success:
                    energy_failures.append((script, error))
                    logger.error(f'Energy calculation failed: {script}: {error[:300]}')

        if energy_failures:
            failed_scripts = [f[0] for f in energy_failures]
            raise RuntimeError(f'{len(energy_failures)} Energy calculations failed: {failed_scripts}')
        logger.debug('All Energy calculations completed successfully')

        # Combine Energy results
        self._combine_energy_chunks()

        self._verify_combined_outputs()

    def _write_sasa_script(self, prefix: Path, prm: Path, trj: Path,
                           chunk_idx: int = 0) -> Path:
        """Write a SASA calculation script for a trajectory chunk.

        Arguments:
            prefix (Path): System prefix; its name is used in the file names.
            prm (Path): Topology for this system.
            trj (Path): Trajectory chunk to analyze.
            chunk_idx (int): Chunk index embedded in both script and output names.

        Returns:
            (Path): Path to the written cpptraj input script.
        """
        sasa = self.path / f'sasa_{prefix.name}_chunk{chunk_idx}.in'
        out_file = f'{prefix.name}_chunk{chunk_idx}_surf.dat'
        sasa_in = [
            f'parm {prm}',
            f'trajin {trj}',
            f'molsurf :* out {out_file} probe {self.solvent_probe} offset {self.offset}',
            'run',
            'quit'
        ]
        self.fh.write_file(sasa_in, sasa)
        return sasa

    def _combine_sasa_chunks(self) -> None:
        """Combine SASA results from all chunks into single files (in correct frame order)."""
        def extract_chunk_idx(filepath: Path) -> int:
            """Extract chunk index from filename for proper numerical sorting."""
            match = re.search(r'_chunk(\d+)_', filepath.name)
            return int(match.group(1)) if match else 0

        for system in ['complex', 'receptor', 'ligand']:
            combined_data = []
            chunk_files = list(self.path.glob(f'{system}_chunk*_surf.dat'))
            # Sort numerically by chunk index, not lexicographically
            chunk_files.sort(key=extract_chunk_idx)

            for chunk_file in chunk_files:
                with open(chunk_file) as f:
                    lines = f.readlines()
                    if combined_data:
                        # Skip header for subsequent chunks
                        combined_data.extend(lines[1:])
                    else:
                        combined_data.extend(lines)

            # Write combined file
            output = self.path / f'{system}_surf.dat'
            with open(output, 'w') as f:
                f.writelines(combined_data)

            # Clean up chunk files
            for chunk_file in chunk_files:
                chunk_file.unlink()

    def _combine_energy_chunks(self) -> None:
        """Combine energy results from all chunks into single files (in correct frame order)."""
        def extract_chunk_idx(filepath: Path) -> int:
            """Extract chunk index from filename for proper numerical sorting."""
            match = re.search(r'_chunk(\d+)_', filepath.name)
            return int(match.group(1)) if match else 0

        for system in ['complex', 'receptor', 'ligand']:
            for level in ['gb', 'pb']:
                combined_data = []
                chunk_files = list(self.path.glob(f'{system}_chunk*_{level}.mdout'))
                # Sort numerically by chunk index, not lexicographically
                chunk_files.sort(key=extract_chunk_idx)

                for chunk_file in chunk_files:
                    with open(chunk_file) as f:
                        content = f.read()
                        combined_data.append(content)

                # Write combined file (mdout format allows concatenation of frame data)
                output = self.path / f'{system}_{level}.mdout'
                with open(output, 'w') as f:
                    f.write('\n'.join(combined_data))

                # Clean up chunk files
                for chunk_file in chunk_files:
                    chunk_file.unlink()

    def calculate_sasa(self,
                       pre: str,
                       prm: PathLike,
                       trj: PathLike) -> None:
        """
        Runs the molsurf command in cpptraj to compute the SASA of a given system.

        Arguments:
            pre (str): Prefix for output SASA file. (Serial path passes the
                system prefix Path here; it is only ever interpolated into the
                output file name.)
            prm (PathLike): Path to prmtop file.
            trj (PathLike): Path to CRD trajectory file.

        Returns:
            None
        """
        sasa = self.fh.path / 'sasa.in'
        sasa_in = [
            f'parm {prm}',
            f'trajin {trj}',
            f'molsurf :* out {pre}_surf.dat probe {self.solvent_probe} offset {self.offset}',
            'run',
            'quit'
        ]

        self.fh.write_file(sasa_in, sasa)

        # Output is discarded; the result we care about is the written .dat file.
        subprocess.run(f'{self.cpptraj} -i {sasa}', shell=True, cwd=str(self.path),
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        sasa.unlink()

    def calculate_energy(self,
                         pre: str,
                         prm: PathLike,
                         trj: PathLike,
                         pdb: PathLike,
                         mdin: PathLike,
                         suf: str) -> None:
        """
        Runs mmpbsa_py_energy, an undocumented binary file which somehow mysteriously
        computes the energy of a system. This software is not only undocumented but is
        a binary which we cannot inspect ourselves.

        Arguments:
            pre (str): Prefix for output file.
            prm (PathLike): Path to prmtop file.
            trj (PathLike): Path to CRD trajectory file.
            pdb (PathLike): Path to PDB file.
            mdin (PathLike): Configuration file for the program.
            suf (str): Suffix for output file ('gb' or 'pb').

        Returns:
            None
        """
        cmd = f'{self.mmpbsa_py_energy} -O -i {mdin} -p {prm} -c {pdb} -y {trj} -o {pre}_{suf}.mdout'
        # Output is discarded; results land in {pre}_{suf}.mdout.
        subprocess.run(cmd, shell=True, cwd=str(self.path),
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    def write_mdins(self) -> tuple[Path, Path]:
        """
        Writes out the configuration files that are to be fed to mmpbsa_py_energy.
        These are also undocumented and I took the parameters from the location
        in which they are hardcoded in ambertools.

        Returns:
            (tuple[Path, Path]): Tuple of paths to the GB mdin and the PB mdin.
        """
        gb = self.fh.path / 'gb_mdin'
        gb_mdin = [
            'GB',
            'igb = 2',
            'extdiel = 78.3',
            'saltcon = 0.10',
            f'surften = {self.gb_surften}',
            'rgbmax = 25.0'
        ]

        self.fh.write_file(gb_mdin, gb)

        pb = self.fh.path / 'pb_mdin'
        pb_mdin = [
            'PB',
            'inp = 2',
            'smoothopt = 1',
            'radiopt = 0',
            'npbopt = 0',
            'solvopt = 1',
            'maxitn = 1000',
            'nfocus = 2',
            'bcopt = 5',
            'eneopt = 2',
            'fscale = 8',
            'epsin = 1.0',
            'epsout = 80.0',
            'istrng = 0.10',
            'dprob = 1.4',
            'iprob = 2.0',
            'accept = 0.001',
            'fillratio = 4.0',
            'space = 0.5',
            'cutnb = 0',
            'sprob = 0.557',
            'cavity_surften = 0.0378',
            'cavity_offset = -0.5692'
        ]

        self.fh.write_file(pb_mdin, pb)

        return gb, pb

    def _verify_combined_outputs(self) -> None:
        """Verify all required output files exist and have valid content.
        Raises RuntimeError if any files are missing or invalid"""
        missing_files = []
        empty_files = []
        invalid_files = []

        for system in ['complex', 'receptor', 'ligand']:
            # Check SASA file
            sasa_file = self.path / f'{system}_surf.dat'
            if not sasa_file.exists():
                missing_files.append(str(sasa_file))
            elif sasa_file.stat().st_size == 0:
                empty_files.append(str(sasa_file))
            else:
                # Verify it has data lines
                with open(sasa_file) as f:
                    data_lines = [l for l in f if l.strip() and not l.strip().startswith('#')]
                if len(data_lines) == 0:
                    invalid_files.append(f'{sasa_file} (no data lines)')

            # Check energy files
            for level in ['gb', 'pb']:
                energy_file = self.path / f'{system}_{level}.mdout'
                if not energy_file.exists():
                    missing_files.append(str(energy_file))
                elif energy_file.stat().st_size == 0:
                    empty_files.append(str(energy_file))
                else:
                    # Verify it has BOND lines (energy data)
                    with open(energy_file) as f:
                        content = f.read()
                    if ' BOND' not in content:
                        invalid_files.append(f'{energy_file} (no energy data)')

        errors = []
        if missing_files:
            errors.append(f"Missing files: {missing_files}")
        if empty_files:
            errors.append(f"Empty files: {empty_files}")
        if invalid_files:
            errors.append(f"Invalid files: {invalid_files}")

        if errors:
            raise RuntimeError(f"Output verification failed: {'; '.join(errors)}")

588 

589 

590class OutputAnalyzer: 

591 """ 

592 Analyzes the outputs from an MM-PBSA run. Stores data in a Polars dataframe 

593 internally, and writes out data in the form of json/plain text. 

594 """ 

595 def __init__(self, 

596 path: PathLike, 

597 surface_tension: float=0.0072, 

598 sasa_offset: float=0., 

599 _tolerance: float = 0.005, 

600 log: bool=True): 

601 self.path = Path(path) 

602 self.surften = surface_tension 

603 self.offset = sasa_offset 

604 self.tolerance = _tolerance 

605 self.log = log 

606 

607 self.free_energy = None 

608 

609 self.systems = ['receptor', 'ligand', 'complex'] 

610 self.levels = ['gb', 'pb'] 

611 

612 self.solvent_contributions = ['EGB', 'ESURF', 'EPB', 'ECAVITY'] 

613 

    def parse_outputs(self) -> None:
        """
        Parse all the output files: per-system SASA, GB, and PB results are
        stacked into `self.gb` / `self.pb`, columns are partitioned into
        gas-phase vs. solvation contributions, then the sanity check, summary,
        and ∆G computation run in sequence.

        Returns:
            None
        """
        self.gb = pl.DataFrame()
        self.pb = pl.DataFrame()

        for system in self.systems:
            E_sasa = self.read_sasa(self.path / f'{system}_surf.dat')
            E_gb = self.read_GB(self.path / f'{system}_gb.mdout', system)
            E_pb = self.read_PB(self.path / f'{system}_pb.mdout', system)

            # Replace the GB run's ESURF column with the cpptraj-derived SASA
            # energy, which is what we actually trust.
            E_gb = E_gb.drop('ESURF').with_columns(E_sasa)

            self.gb = pl.concat([self.gb, E_gb], how='vertical')
            self.pb = pl.concat([self.pb, E_pb], how='vertical')

        # Partition every column seen in either theory level into gas-phase
        # vs. solvation contributions for later pooling.
        all_cols = list(set(self.gb.columns + self.pb.columns))
        self.contributions = {
            'G gas': [col for col in all_cols
                      if col not in self.solvent_contributions],
            'G solv': [col for col in all_cols
                       if col in self.solvent_contributions]
        }

        self.check_bonded_terms()
        self.generate_summary()
        self.compute_dG()

645 

646 def read_sasa(self, 

647 _file: PathLike) -> np.ndarray: 

648 """ 

649 Reads in the results of the cpptraj SASA calculation and returns the 

650 per-frame SASA scaled by a hardcoded value for surface tension that is 

651 a mostly undocumented heuristic. 

652 

653 Arguments: 

654 _file (PathLike): Path to a file containing the SASA data. 

655 

656 Returns: 

657 (np.ndarray): A numpy array of the per-frame rescaled SASA energies. 

658 """ 

659 df = pd.read_csv(_file, sep='\s+') # read in dataframe 

660 sasa = df.iloc[:, -1].to_numpy(dtype=float) * self.surften + self.offset 

661 

662 return pl.Series('ESURF', sasa) 

663 

664 def read_GB(self, 

665 _file: PathLike, 

666 system: str) -> pl.DataFrame: 

667 """ 

668 Read in the GB mdout files and returns a Polars dataframe of the values 

669 for each term for every frame. Also adds a `system` label to more easily 

670 compute summary statistics later. 

671 

672 Arguments: 

673 _file (PathLike): Energy data file path. 

674 system (str): String label for which system we are processing (e.g. complex). 

675 

676 Returns: 

677 (pl.DataFrame): Polars dataframe containing the parsed energy data. 

678 """ 

679 gb_terms = ['BOND', 'ANGLE', 'DIHED', 'VDWAALS', 'EEL', 

680 'EGB', '1-4 VDW', '1-4 EEL', 'RESTRAINT', 'ESURF'] 

681 data = {gb_term: [] for gb_term in gb_terms} 

682 

683 lines = open(_file, 'r').readlines() 

684 

685 return self.parse_energy_file(lines, data, system) 

686 

687 def read_PB(self, 

688 _file: PathLike, 

689 system: str) -> pl.DataFrame: 

690 """ 

691 Read in the PB mdout files and returns a Polars dataframe of the values 

692 for each term for every frame. Also adds a `system` label to more easily 

693 compute summary statistics later. 

694 

695 Arguments: 

696 _file (PathLike): Energy data file path. 

697 system (str): String label for which system we are processing (e.g. complex). 

698 

699 Returns: 

700 (pl.DataFrame): Polars dataframe containing the parsed energy data. 

701 """ 

702 pb_terms = ['BOND', 'ANGLE', 'DIHED', 'VDWAALS', 'EEL', 

703 'EPB', '1-4 VDW', '1-4 EEL', 'RESTRAINT', 

704 'ECAVITY', 'EDISPER'] 

705 data = {pb_term: [] for pb_term in pb_terms} 

706 

707 lines = open(_file, 'r').readlines() 

708 

709 return self.parse_energy_file(lines, data, system) 

710 

711 def parse_energy_file(self, file_contents: list[str], 

712 data: dict[str, list], system: str) -> pl.DataFrame: 

713 """Parse energy file contents.""" 

714 for line in file_contents: 

715 if '=' in line and any(key in line for key in data.keys()): 

716 parsed = self.parse_line(line) 

717 for key, val in parsed: 

718 if key in data: 

719 data[key].append(val) 

720 

721 df = pl.DataFrame(data) 

722 df = df.with_columns(pl.lit(system).alias('system')) 

723 return df 

724 

    def parse_energy_file_OLD(self,
                              file_contents: list[str],
                              data: dict[str, list],
                              system: str) -> pl.DataFrame:
        """
        Parses the contents of an energy calculation using a dictionary of
        energy terms to extract theory-level observables (e.g. EGB vs EPB).

        NOTE(review): legacy implementation, apparently superseded by
        parse_energy_file; kept for reference. It walks the file with a manual
        index, consuming a fixed 4-line energy table whenever a ' BOND' line is
        found, and tracks the frame count from 'Processing frame' markers.

        Arguments:
            file_contents (list[str]): A list of each line from an energy calculation.
            data (dict[str, list]): The relevant energy terms to be scraped from input.
            system (str): The name of the system which will be included as an additional
                kv pair in the returned dataframe. This ensures we can track which portion
                of the calculation we are accounting for (e.g. complex, receptor, ligand).
        Returns:
            (pl.DataFrame): A Polars dataframe of shape (n_frames, n_calculations + system).
        """
        idx = 0
        n_frames = 0
        while idx < len(file_contents):
            if file_contents[idx].startswith(' BOND'):
                # The energy table is exactly 4 consecutive lines; idx is
                # advanced inside this inner loop as well as below.
                for _ in range(4): # number of lines to read. DO NOT CHANGE!!!
                    line = file_contents[idx]
                    parsed = self.parse_line(line)
                    for key, val in parsed:
                        data[key].append(val)

                    idx += 1

            if 'Processing frame' in file_contents[idx]:
                # Keeps the *last* frame number seen == total frame count.
                n_frames = int(file_contents[idx].strip().split()[-1])

            idx +=1

        data['system'] = [system] * n_frames

        return pl.DataFrame(
            {key: np.array(val) for key, val in data.items()}
        )

764 

765 def check_bonded_terms(self) -> None: 

766 """ 

767 Performs a sanity check on the bonded terms which should perfectly cancel out 

768 (e.g. complex = receptor + ligand). If this is not the case something horrible 

769 has happened and we can't trust the non-bonded energies either. Additionally 

770 sets a few terms we will need later such as the number of frames as given by 

771 the dataframe height and sqrt(n_frames). 

772 

773 Returns: 

774 None 

775 """ 

776 bonded = ['BOND', 'ANGLE', 'DIHED', '1-4 VDW', '1-4 EEL'] 

777 

778 for theory_level in (self.gb, self.pb): 

779 a = theory_level.filter(pl.col('system') == 'receptor') 

780 b = theory_level.filter(pl.col('system') == 'ligand') 

781 c = theory_level.filter(pl.col('system') == 'complex') 

782 

783 a = a.select(pl.col([col for col in a.columns if col in bonded])).to_numpy() 

784 b = b.select(pl.col([col for col in b.columns if col in bonded])).to_numpy() 

785 c = c.select(pl.col([col for col in c.columns if col in bonded])).to_numpy() 

786 

787 diffs = np.array(c - b - a) 

788 if np.where(diffs >= self.tolerance)[0].size > 0: 

789 raise ValueError('Bonded terms for receptor + ligand != complex!') 

790 

791 remove = ['RESTRAINT', 'EDISPER'] 

792 self.gb = self.gb.select( 

793 pl.col([col for col in self.gb.columns if col not in remove]) 

794 ) 

795 self.pb = self.pb.select( 

796 pl.col([col for col in self.pb.columns if col not in remove]) 

797 ) 

798 

799 self.n_frames = self.gb.height 

800 self.square_root_N = np.sqrt(self.n_frames) 

801 

802 def generate_summary(self) -> None: 

803 """ 

804 Summarizes all processed energy data into a single polars dataframe 

805 and dumps it to a json file. 

806 

807 Returns: 

808 None 

809 """ 

810 full_statistics = {sys: {} for sys in self.systems} 

811 for theory, level in zip([self.gb, self.pb], self.levels): 

812 for system in self.systems: 

813 sys = theory.filter(pl.col('system') == system).drop('system') 

814 

815 stats = {} 

816 for col in sys.columns: 

817 mean = sys.select(pl.mean(col)).item() 

818 stdev = sys.select(pl.std(col)).item() 

819 

820 stats[col] = {'mean': mean, 

821 'std': stdev, 

822 'err': stdev / self.square_root_N} 

823 

824 for energy, contributors in self.contributions.items(): 

825 pooled_data = sys.select( 

826 pl.col([col for col in sys.columns if col in contributors]) 

827 ).to_numpy().flatten() 

828 

829 stats[energy] = {'mean': np.mean(pooled_data), 

830 'std': np.std(pooled_data), 

831 'err': np.std(pooled_data) / self.square_root_N} 

832 

833 total_data = sys.to_numpy().flatten() 

834 stats['total'] = {'mean': np.mean(total_data), 

835 'std': np.std(total_data), 

836 'err': np.std(total_data) / self.square_root_N} 

837 

838 full_statistics[system][level] = stats 

839 

840 with open('statistics.json', 'w') as fout: 

841 json.dump(full_statistics, fout, indent=4) 

842 

    def compute_dG(self) -> None:
        """
        For each energy dataframe (GB/PB) compute the ∆G of binding by subtracting out
        relevant contributions in accordance with how this is done under the hood of the
        MMPBSA code.

        Builds, per theory level, a (3, n_terms) table of mean/std/err rows for
        every energy column plus the pooled 'G gas'/'G solv' columns and the
        final '∆G Binding' column, then hands both tables to pretty_print.

        Returns:
            None
        """
        differences = []
        for theory, level in zip([self.gb, self.pb], self.levels):
            diff_cols = [col for col in theory.columns if col != 'system']
            # Start from the complex and subtract receptor then ligand
            # (self.systems[:2]) frame-by-frame.
            diff_arr = theory.filter(pl.col('system') == 'complex').drop('system').to_numpy()
            for system in self.systems[:2]:
                diff_arr -= theory.filter(pl.col('system') == system).drop('system').to_numpy()

            means = np.mean(diff_arr, axis=0)
            stds = np.std(diff_arr, axis=0)
            errs = stds / self.square_root_N

            # Pool per-frame gas-phase and solvation sums; note diff_cols and
            # the stats arrays are extended in lockstep inside this loop.
            gas_solv_phase = []
            for energy, contributors in self.contributions.items():
                indices = [i for i, diff_col in enumerate(diff_cols)
                           if diff_col in contributors]
                contribution = np.sum(diff_arr[:, indices], axis=1)
                gas_solv_phase.append(contribution)

                diff_cols.append(energy)
                means = np.concatenate((means, [np.mean(contribution)]))
                stds = np.concatenate((stds, [np.std(contribution)]))
                errs = np.concatenate((errs, [np.std(contribution) / self.square_root_N]))

            # Total binding free energy = gas-phase + solvation per frame.
            diff_cols.append('∆G Binding')
            total = np.sum(np.vstack(gas_solv_phase), axis=0)

            means = np.concatenate((means, [np.mean(total)]))
            stds = np.concatenate((stds, [np.std(total)]))
            errs = np.concatenate((errs, [np.std(total) / self.square_root_N]))

            # Rows: mean, std, err; one column per entry in diff_cols.
            data = np.vstack((means, stds, errs))

            differences.append(pl.DataFrame(
                {diff_cols[i]: data[:,i] for i in range(len(diff_cols))}
            ))

        self.pretty_print(differences)

889 

890 def pretty_print(self, 

891 dfs: list[pl.DataFrame]) -> None: 

892 """ 

893 Ingests a list of Polars dataframes for GB and PB and prints their contents 

894 in a human-readable form to STDIN. Also saves out the energies to a plain 

895 text file called `deltaG.txt`. 

896 

897 Arguments: 

898 dfs (list[pl.DataFrame]): List of dataframes for GB and PB. 

899 

900 Returns: 

901 None 

902 """ 

903 print_statement = [] 

904 log_statement = [] 

905 for df, level in zip(dfs, ['Generalized Born ', 'Poisson Boltzmann']): 

906 print_statement += [ 

907 f'{" ":<20}=========================', 

908 f'{" ":<20}=== {level} ===', 

909 f'{" ":<20}=========================', 

910 'Energy Component Average Std. Dev. Std. Err. of Mean', 

911 '---------------------------------------------------------------------' 

912 ] 

913 for col in df.columns: 

914 mean, std, err = [x.item() for x in df.select(pl.col(col)).to_numpy()] 

915 report = f'{col:<20}{mean:<16.3f}{std:<16.3f}{err:<16.3f}' 

916 if abs(mean) <= self.tolerance: 

917 continue 

918 

919 if col in ['G gas', '∆G Binding']: 

920 print_statement.append('') 

921 

922 if col == '∆G Binding': 

923 log_statement.append(f'{level.strip()}:') 

924 log_statement.append(report) 

925 

926 if level == 'Poisson Boltzmann': 

927 self.free_energy = [mean, std] 

928 

929 print_statement.append(report) 

930 

931 print_statement = '\n'.join(print_statement) 

932 with open(self.path / 'deltaG.txt', 'w') as fout: 

933 fout.write(print_statement) 

934 

935 if self.log: 

936 for statement in log_statement: 

937 logging.info(statement) 

938 else: 

939 print(print_statement) 

940 

941 @staticmethod 

942 def parse_line(line) -> tuple[list[str], list[float]]: 

943 """ 

944 Parses a line from mmpbsa_energy to get the various energy terms and values. 

945 

946 Returns: 

947 (tuple[list[str], list[float]]): A tuple containing the list of energy 

948 term names and corresponding energy values. 

949 """ 

950 eq_split = line.split('=') 

951 

952 if len(eq_split) == 2: 

953 splits = [eq_spl.strip() for eq_spl in eq_split] 

954 else: 

955 splits = [eq_split[0].strip()] 

956 

957 for i in range(1, len(eq_split) - 1): 

958 splits += [spl.strip() for spl in eq_split[i].strip().split(' ')] 

959 

960 splits += [eq_split[-1].strip()] 

961 

962 keys = splits[::2] 

963 vals = np.array(splits[1::2], dtype=float) 

964 

965 return zip(keys, vals) 

966 

967 

968class FileHandler: 

969 """ 

970 Performs preprocessing for MM-PBSA runs and manages the pathing to all file 

971 inputs. Additionally used to write out various cpptraj input files by the 

972 MMPBSA class. 

973 """ 

974 def __init__(self, 

975 top: Path, 

976 traj: Path, 

977 path: Path, 

978 sels: list[str], 

979 first: int, 

980 last: int, 

981 stride: int, 

982 cpptraj_binary: PathLike, 

983 n_chunks: int=1): 

984 self.top = top 

985 self.traj = traj 

986 self.path = path 

987 self.selections = sels 

988 self.ff = first 

989 self.lf = last 

990 self.stride = stride 

991 self.cpptraj = cpptraj_binary 

992 self.n_chunks = n_chunks 

993 

994 self.prepare_topologies() 

995 self.prepare_trajectories() 

996 

997 self.trajectory_chunks = {} 

998 

999 if n_chunks > 1: 

1000 self._count_frames() 

1001 self._split_trajectories() 

1002 else: 

1003 for system, traj in zip(['complex', 'receptor', 'ligand'], self.trajectories): 

1004 self.trajectory_chunks[system] = [traj] 

1005 

1006 def prepare_topologies(self) -> None: 

1007 """ 

1008 Slices out each sub-topology for the desolvated complex, receptor and 

1009 ligand using cpptraj due to the difficulty of working with AMBER FF 

1010 files otherwise (including PARMED). 

1011 

1012 Returns: 

1013 None 

1014 """ 

1015 self.topologies = [ 

1016 self.path / 'complex.prmtop', 

1017 self.path / 'receptor.prmtop', 

1018 self.path / 'ligand.prmtop' 

1019 ] 

1020 

1021 cpptraj_in = [ 

1022 f'parm {self.top}', 

1023 'parmstrip :Na+,Cl-,WAT', 

1024 'parmbox nobox', 

1025 f'parmwrite out {self.topologies[0]}', 

1026 'run', 

1027 'clear all', 

1028 f'parm {self.topologies[0]}', 

1029 f'parmstrip {self.selections[0]}', 

1030 f'parmwrite out {self.topologies[1]}', 

1031 'run', 

1032 'clear all', 

1033 f'parm {self.topologies[0]}', 

1034 f'parmstrip {self.selections[1]}', 

1035 f'parmwrite out {self.topologies[2]}', 

1036 'run', 

1037 'quit' 

1038 ] 

1039 

1040 script = self.path / 'cpptraj.in' 

1041 self.write_file('\n'.join(cpptraj_in), script) 

1042 subprocess.call(f'{self.cpptraj} -i {script}', shell=True, cwd=str(self.path), 

1043 stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 

1044 script.unlink() 

1045 

    def prepare_trajectories(self) -> None:
        """
        Converts DCD trajectory to AMBER CRD format which is explicitly
        required by MM-G(P)BSA.

        One cpptraj script with four passes: (1) convert the raw trajectory to
        CRD applying the frame window/stride; (2) strip solvent/ions, image and
        RMS-fit to produce the complex trajectory; (3) and (4) strip each
        selection from the complex to produce receptor-only and ligand-only
        trajectories. A single-frame PDB is also written for each system.
        Sets `self.trajectories` and `self.pdbs`, and rebinds `self.traj` to
        the converted CRD file.

        Returns:
            None
        """
        # Per-system CRD/PDB outputs share the prmtop basenames.
        self.trajectories = [path.with_suffix('.crd') for path in self.topologies]
        self.pdbs = [path.with_suffix('.pdb') for path in self.topologies]

        # Frame window clause: 'start N [stop M] offset S'; stop is omitted
        # when the last frame is -1 (i.e. read to the end).
        frame_control = f'start {self.ff}'

        if self.lf > -1:
            frame_control += f' stop {self.lf}'

        frame_control += f' offset {self.stride}'

        # Pass 1: raw trajectory -> CRD with the requested frame window.
        cpptraj_in = [
            f'parm {self.top}',
            f'trajin {self.traj}',
            f'trajout {self.traj.with_suffix(".crd")} crd {frame_control}',
            'run',
            'clear all',
        ]

        # Subsequent passes read the converted CRD file, so rebind before
        # building the remaining commands.
        self.traj = self.traj.with_suffix('.crd')

        cpptraj_in += [
            # Pass 2: strip solvent/ions, image, RMS-fit; write complex CRD + PDB.
            f'parm {self.top}',
            f'trajin {self.traj}',
            'strip :WAT,Na+,Cl*',
            'autoimage',
            f'rmsd !(:WAT,Cl*,CIO,Cs+,IB,K*,Li+,MG*,Na+,Rb+,CS,RB,NA,F,CL) mass first',
            f'trajout {self.trajectories[0]} crd nobox',
            f'trajout {self.pdbs[0]} pdb onlyframes 1',
            'run',
            'clear all',
            # Pass 3: strip the first selection from the complex -> receptor files.
            f'parm {self.topologies[0]}',
            f'trajin {self.trajectories[0]}',
            f'strip {self.selections[0]}',
            f'trajout {self.trajectories[1]} crd',
            f'trajout {self.pdbs[1]} pdb onlyframes 1',
            'run',
            'clear all',
            # Pass 4: strip the second selection -> ligand files.
            f'parm {self.topologies[0]}',
            f'trajin {self.trajectories[0]}',
            f'strip {self.selections[1]}',
            f'trajout {self.trajectories[2]} crd',
            f'trajout {self.pdbs[2]} pdb onlyframes 1',
            'run',
            'quit'
        ]

        name = self.path / 'mdcrd.in'
        self.write_file('\n'.join(cpptraj_in), name)
        subprocess.call(f'{self.cpptraj} -i {name}', shell=True, cwd=str(self.path),
                        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        name.unlink()

1106 

1107 def _count_frames(self) -> None: 

1108 """Count the total number of frames in the processed trajectory.""" 

1109 # Use cpptraj to count frames 

1110 count_script = self.path / 'count_frames.in' 

1111 count_out = self.path / 'frame_count.dat' 

1112 

1113 script_content = [ 

1114 f'parm {self.topologies[0]}', 

1115 f'trajin {self.trajectories[0]}', 

1116 f'trajinfo {self.trajectories[0]} name myinfo', 

1117 'run', 

1118 'quit' 

1119 ] 

1120 self.write_file('\n'.join(script_content), count_script) 

1121 

1122 result = subprocess.run( 

1123 f'{self.cpptraj} -i {count_script}', 

1124 shell=True, cwd=str(self.path), 

1125 capture_output=True, text=True 

1126 ) 

1127 

1128 # Parse frame count from cpptraj output 

1129 for line in result.stdout.split('\n'): 

1130 if 'frames' in line.lower() and 'total' in line.lower(): 

1131 parts = line.split() 

1132 for i, part in enumerate(parts): 

1133 if part.isdigit(): 

1134 self.total_frames = int(part) 

1135 break 

1136 

1137 # Fallback: count lines in trajectory (for ASCII formats) or use file size heuristic 

1138 if not hasattr(self, 'total_frames'): 

1139 # Try to infer from first trajectory file 

1140 self.total_frames = self._estimate_frames() 

1141 

1142 count_script.unlink(missing_ok=True) 

1143 logger.debug(f'Total frames in trajectory: {self.total_frames}') 

1144 

1145 def _estimate_frames(self) -> int: 

1146 """Estimate frame count by running a quick cpptraj analysis.""" 

1147 script = self.path / 'estimate_frames.in' 

1148 script_content = [ 

1149 f'parm {self.topologies[0]}', 

1150 f'trajin {self.trajectories[0]}', 

1151 'run', 

1152 'quit' 

1153 ] 

1154 self.write_file('\n'.join(script_content), script) 

1155 

1156 result = subprocess.run( 

1157 f'{self.cpptraj} -i {script}', 

1158 shell=True, cwd=str(self.path), 

1159 capture_output=True, text=True 

1160 ) 

1161 

1162 script.unlink(missing_ok=True) 

1163 

1164 # Parse output for frame count 

1165 for line in result.stdout.split('\n'): 

1166 if 'frames' in line.lower(): 

1167 for word in line.split(): 

1168 if word.isdigit(): 

1169 return int(word) 

1170 

1171 # Default fallback 

1172 return 100 

1173 

1174 def _split_trajectories(self) -> None: 

1175 """Split trajectories into chunks for parallel processing.""" 

1176 actual_chunks = min(self.n_chunks, self.total_frames) # in case we have fewer frames than resources 

1177 

1178 if actual_chunks < self.n_chunks: 

1179 logger.warning( 

1180 f'Requested {self.n_chunks} chunks but only {self.total_frames} frames.' 

1181 f'Using {actual_chunks} chunks (some CPUs will be idle)' 

1182 ) 

1183 

1184 frames_per_chunk = max(1, self.total_frames // actual_chunks) 

1185 self.trajectory_chunks = {system: [] for system in ['complex', 'receptor', 'ligand']} 

1186 

1187 for i, (top, traj, system) in enumerate(zip( 

1188 self.topologies, 

1189 self.trajectories, 

1190 ['complex', 'receptor', 'ligand'] 

1191 )): 

1192 for chunk_idx in range(actual_chunks): 

1193 start_frame = chunk_idx * frames_per_chunk + 1 # cpptraj is 1-indexed 

1194 

1195 if chunk_idx == actual_chunks - 1: 

1196 # Last chunk gets remaining frames 

1197 end_frame = self.total_frames 

1198 else: 

1199 end_frame = (chunk_idx + 1) * frames_per_chunk 

1200 

1201 chunk_traj = self.path / f'{system}_chunk{chunk_idx}.crd' 

1202 

1203 # Create chunk trajectory 

1204 split_script = self.path / f'split_{system}_{chunk_idx}.in' 

1205 script_content = [ 

1206 f'parm {top}', 

1207 f'trajin {traj} {start_frame} {end_frame}', 

1208 f'trajout {chunk_traj} crd', 

1209 'run', 

1210 'quit' 

1211 ] 

1212 self.write_file('\n'.join(script_content), split_script) 

1213 

1214 subprocess.run( 

1215 f'{self.cpptraj} -i {split_script}', 

1216 shell=True, cwd=str(self.path), 

1217 stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL 

1218 ) 

1219 

1220 split_script.unlink() 

1221 self.trajectory_chunks[system].append(chunk_traj) 

1222 

1223 logger.debug(f'Split trajectories into {actual_chunks} chunks of ~{frames_per_chunk} frames each.') 

1224 

1225 @property 

1226 def files(self) -> tuple[list[str]]: 

1227 """ 

1228 Returns a zip generator containing the output paths, topologies, 

1229 trajectories and pdbs for each system. This is done to ensure we 

1230 have the correct order for housekeeping reasons. 

1231 

1232 Returns: 

1233 (tuple[list[str]]): System order, topologies, trajectories and pdbs. 

1234 """ 

1235 _order = [self.path / prefix for prefix in ['complex', 'receptor', 'ligand']] 

1236 return zip(_order, self.topologies, self.trajectories, self.pdbs) 

1237 

1238 @property 

1239 def files_chunked(self) -> list[tuple]: 

1240 """Returns file info with chunked trajectories for parallel processing.""" 

1241 result = [] 

1242 for system, top, pdb in zip( 

1243 ['complex', 'receptor', 'ligand'], 

1244 self.topologies, 

1245 self.pdbs 

1246 ): 

1247 prefix = self.path / system 

1248 traj_chunks = self.trajectory_chunks[system] 

1249 result.append((prefix, top, traj_chunks, pdb)) 

1250 

1251 return result 

1252 

1253 @staticmethod 

1254 def write_file(lines: list[str], 

1255 filepath: PathLike) -> None: 

1256 """ 

1257 Given an input of either a list of strings or a single string, 

1258 write input to file. If a list, join by newline characters. 

1259 

1260 Arguments: 

1261 lines (list[str]): Input to be written to file. 

1262 filepath (PathLike): Path to the file to be written. 

1263 

1264 Returns: 

1265 None 

1266 """ 

1267 if isinstance(lines, list): 

1268 lines = '\n'.join(lines) 

1269 with open(str(filepath), 'w') as f: 

1270 f.write(lines)