Coverage for src / molecular_simulations / simulate / multires_simulator.py: 64%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-13 01:26 -0600

1from ..build.build_calvados import CGBuilder 

2from ..build import ImplicitSolvent, ExplicitSolvent 

3from calvados import sim 

4from .omm_simulator import ImplicitSimulator, Simulator 

5from cg2all.script.convert_cg2all import main as convert 

6import openmm 

7from openmm.app import * 

8from openmm.unit import * 

9import subprocess 

10import tempfile 

11import parmed as pmd 

12import pip._vendor.tomli as tomllib # for 3.10 

13from pathlib import Path 

14from dataclasses import dataclass 

15import os 

16from typing import Union, Type, TypeVar 

17 

# Generic type variable used to annotate alternate constructors (e.g. from_toml).
_T = TypeVar('_T')
# Anything that can be turned into a filesystem path.
PathLike = Union[Path, str]
# A path-like value that may also be absent.
OptPath = Union[PathLike, None]

21 

@dataclass
class sander_min_defaults:
    """
    Default settings for a sander energy minimization.

    The field values below are rendered into the ``&cntrl`` namelist of a
    sander input file; the rendered text is stored in ``mdin_contents`` right
    after construction. Because the fields are real dataclass fields, any of
    them can be overridden at construction time, e.g.
    ``sander_min_defaults(maxcyc=1000).mdin_contents``.
    """
    imin: int = 1        # Perform energy minimization
    maxcyc: int = 5000   # Maximum number of minimization cycles
    ncyc: int = 2500     # Switch from steepest descent to conjugate gradient after this many steps
    ntb: int = 0         # No periodic boundary conditions
    ntr: int = 0         # No restraints
    cut: float = 10.0    # Nonbonded cutoff in Angstroms
    ntpr: int = 10000    # Print energy every ntpr steps (> maxcyc, so effectively never)
    ntwr: int = 5000     # Write restart file every ntwr steps (only once)
    ntxo: int = 1        # Output restart file format (ASCII)

    def __post_init__(self) -> None:
        """Render the current field values into sander mdin text (``mdin_contents``)."""
        settings = [
            f'imin={self.imin}',
            f'maxcyc={self.maxcyc}',
            f'ncyc={self.ncyc}',
            f'ntb={self.ntb}',
            f'ntr={self.ntr}',
            f'cut={self.cut:.1f}',
            f'ntpr={self.ntpr}',
            f'ntwr={self.ntwr}',
            f'ntxo={self.ntxo}',
        ]
        # Namelist entries are comma-separated; the last one takes no comma.
        body = ',\n'.join(f'  {entry}' for entry in settings)
        self.mdin_contents = f'Minimization input\n &cntrl\n{body}\n /\n'

52 

def sander_minimize(path: Path | str,
                    inpcrd_file: str,
                    prmtop_file: str,
                    sander_cmd: str) -> None:
    """
    Minimize MD system with sander and output new inpcrd file.

    The minimized coordinates are written to ``path / <inpcrd stem>.min.inpcrd``
    (e.g. ``system.inpcrd`` -> ``system.min.inpcrd``) and sander's info file to
    ``path / 'min.mdinfo'``. The mdin and mdout files are temporary files in
    ``path`` and are removed before the function returns.

    Arguments:
        path (Path | str): Directory containing inpcrd and prmtop. The new inpcrd
            is written here as well. A str is accepted and converted to Path.
        inpcrd_file (str): Name of inpcrd file in path
        prmtop_file (str): Name of prmtop file in path
        sander_cmd (str): Command (name or full path) used to invoke sander

    Raises:
        RuntimeError: if sander exits with a non-zero status (the message
            includes stderr, stdout and the mdout contents), or if it exits
            cleanly without writing the minimized coordinates.
    """
    path = Path(path)
    mdin = sander_min_defaults().mdin_contents
    outfile = path / Path(inpcrd_file).with_suffix('.min.inpcrd')

    with (tempfile.NamedTemporaryFile(mode='w+', suffix='.in', dir=str(path)) as tmp_in,
          tempfile.NamedTemporaryFile(mode='w', suffix='.out', dir=str(path)) as tmp_out):
        tmp_in.write(mdin)
        # sander opens the input by name, so the buffered text must reach disk first.
        tmp_in.flush()
        command = [sander_cmd, '-O',
                   '-i', tmp_in.name,
                   '-o', tmp_out.name,
                   '-p', str(path / prmtop_file),
                   '-c', str(path / inpcrd_file),
                   '-r', str(outfile),
                   '-inf', str(path / 'min.mdinfo')]
        result = subprocess.run(command, shell=False, capture_output=True, text=True)
        if result.returncode != 0:
            # The mdout file is deleted when this `with` exits; read it now in
            # case it holds diagnostics that did not appear on stdout/stderr.
            try:
                mdout = Path(tmp_out.name).read_text(errors='replace')
            except OSError:
                mdout = ''
            raise RuntimeError(f'sander error!\n{result.stderr}\n{result.stdout}\n{mdout}')

    # Catch a "successful" run that produced nothing here rather than letting the
    # next stage fail on a missing coordinate file.
    if not outfile.exists():
        raise RuntimeError(f'sander finished without writing {outfile}')

84 

85class MultiResolutionSimulator: 

86 """ 

87 Class for performing multi-resolution simulations with switching between CG and AA  

88 representations. Utilizes CALVADOS for CG simulations and omm_simulator.py for AA 

89 simulations.  

90  

91 Arguments: 

92 path (PathLike): Path to simulation input files, also serves as output path. 

93 input_pdb (str): Input pdb for simulations, must exist in path. 

94 n_rounds (int): Number of rounds of CG/AA simulation to perform. 

95 cg_params (dict): Parameters for CG simulations. Initializes CGBuilder. 

96 aa_params (dict): Parameters for AA simulations. Initializes omm_simulator. 

97 cg2all_bin (str): Defaults to 'convert_cg2all'. Path to cg2all binary. Must 

98 be provided if cg2all is installed in a separate environment.  

99 cg2all_ckpt (OptPath): Path to cg2all checkpoint file.  

100 AMBERHOME (str | None): Defaults to None. Path to AMBERHOME (excluding bin).  

101 Used for sander and pdb4amber. If None, assumes AmberTools binaries are  

102 available in the current $PATH. 

103 

104 Usage: 

105 sim = MultiResolutionSimulator.from_toml('config.toml') 

106 sim.run() 

107 """ 

108 def __init__(self, 

109 path: PathLike, 

110 input_pdb: str, 

111 n_rounds: int, 

112 cg_params: dict, 

113 aa_params: dict, 

114 cg2all_bin: str = 'convert_cg2all', 

115 cg2all_ckpt: OptPath = None, 

116 AMBERHOME: str | None = None): 

117 self.path = Path(path) 

118 self.input_pdb = input_pdb 

119 self.n_rounds = n_rounds 

120 self.cg_params = cg_params 

121 self.aa_params = aa_params 

122 self.cg2all_bin = cg2all_bin 

123 self.cg2all_ckpt = cg2all_ckpt 

124 self.AMBERHOME = Path(AMBERHOME) if AMBERHOME is not None else None 

125 

126 @classmethod 

127 def from_toml(cls: Type[_T], config: PathLike) -> _T: 

128 """ 

129 Constructs MultiResolutionSimulator from .toml configuration file. 

130 Recommended method for instantiating MultiResolutionSimulator. 

131 """ 

132 with open(config, 'rb') as f: 

133 cfg = tomllib.load(f) 

134 settings = cfg['settings'] 

135 cg_params = cfg['cg_params'][0] 

136 aa_params = cfg['aa_params'] 

137 path = settings['path'] 

138 input_pdb = settings['input_pdb'] 

139 n_rounds = settings['n_rounds'] 

140 

141 if 'cg2all_bin' in settings: 

142 cg2all_bin = settings['cg2all_bin'] 

143 else: 

144 cg2all_bin = 'convert_cg2all' 

145 

146 if 'cg2all_ckpt' in settings: 

147 cg2all_ckpt = settings['cg2all_ckpt'] 

148 else: 

149 cg2all_ckpt = None 

150 

151 if 'AMBERHOME' in settings: 

152 AMBERHOME = Path(settings['AMBERHOME']) 

153 else: 

154 AMBERHOME = None 

155 

156 return cls(path, 

157 input_pdb, 

158 n_rounds, 

159 cg_params, 

160 aa_params, 

161 cg2all_bin = cg2all_bin, 

162 cg2all_ckpt = cg2all_ckpt, 

163 AMBERHOME = AMBERHOME) 

164 

165 @staticmethod 

166 def strip_solvent(simulation: Simulation, 

167 output_pdb: PathLike = 'protein.pdb' 

168 ) -> None: 

169 """ 

170 Use parmed to strip solvent from an openmm simulation and write out pdb 

171 """ 

172 struc = pmd.openmm.load_topology( 

173 simulation.topology, 

174 simulation.system, 

175 xyz = simulation.context.getState(getPositions=True).getPositions() 

176 ) 

177 solvent_resnames = [ 

178 'WAT', 'HOH', 'TIP3', 'TIP3P', 'SOL', 'OW', 'H2O', 

179 'NA', 'K', 'CL', 'MG', 'CA', 'ZN', 'MN', 'FE', 

180 'Na+', 'K+', 'Cl-', 'Mg2+', 'Ca2+', 'Zn2+', 'Mn2+', 'Fe2+', 'Fe3+', 

181 'SOD', 'POT', 'CLA' 

182 ] 

183 mask = ':' + ','.join(solvent_resnames) 

184 struc.strip(mask) 

185 struc.save(output_pdb) 

186 

187 def run_rounds(self) -> None: 

188 """ 

189 Main logic for running MultiResolutionSimulator. 

190 Does not currently handle restart runs (TODO). 

191 """ 

192 

193 for r in range(self.n_rounds): 

194 aa_path = self.path / f'aa_round{r}' 

195 aa_path.mkdir() 

196 

197 if r == 0: 

198 input_pdb = str(self.path / self.input_pdb) 

199 else: 

200 input_pdb = str(self.path / f'cg_round{r-1}/last_frame.amber.pdb') 

201 

202 

203 match self.aa_params['solvation_scheme']: 

204 case 'implicit': 

205 _aa_builder = ImplicitSolvent 

206 _aa_simulator = ImplicitSimulator 

207 case 'explicit': 

208 _aa_builder = ExplicitSolvent 

209 _aa_simulator = Simulator 

210 case _: 

211 raise AttributeError("solvation_scheme must be 'implicit' or 'explicit'") 

212 

213 aa_builder = _aa_builder( 

214 aa_path, 

215 input_pdb, 

216 protein = self.aa_params['protein'], 

217 rna = self.aa_params['rna'], 

218 dna = self.aa_params['dna'], 

219 phos_protein = self.aa_params['phos_protein'], 

220 use_amber = self.aa_params['use_amber'], 

221 out = self.aa_params['out']) 

222 

223 aa_builder.build() 

224 

225 # cg2all may create clashes which OpenMM minimization does not address. 

226 # Therefore, we want to minimize all cg2all-created structures with sander instead. 

227 if self.AMBERHOME is None: 

228 sander = 'sander' 

229 else: 

230 sander = str(self.AMBERHOME / 'bin/sander') 

231 sander_minimize(aa_path, 'system.inpcrd', 'system.prmtop', sander) 

232 

233 aa_simulator = _aa_simulator( 

234 aa_path, 

235 coor_name = 'system.min.inpcrd', 

236 ff = 'amber', 

237 equil_steps = int(self.aa_params['equilibration_steps']), 

238 prod_steps = int(self.aa_params['production_steps']), 

239 n_equil_cycles = 1, 

240 device_ids = self.aa_params['device_ids']) 

241 

242 aa_simulator.run() 

243 

244 # strip solvent and output AA structure for next step (CG) 

245 self.strip_solvent(aa_simulator.simulation, 

246 str(aa_path / 'protein.pdb')) 

247 

248 # build CG 

249 cg_path = self.path / f'cg_round{r}' 

250 cg_path.mkdir() 

251 cg_params = self.cg_params 

252 cg_params['config']['path'] = str(cg_path) 

253 cg_params['config']['input_pdb'] = str(aa_path / 'protein.pdb') 

254 

255 cg_builder = CGBuilder.from_dict(cg_params) 

256 cg_builder.build() # writes config and components yamls 

257 

258 # run CG 

259 sim.run(path = str(cg_path), 

260 fconfig = 'config.yaml', 

261 fcomponents = 'components.yaml') 

262 

263 # convert CG to AA for next round using cg2all 

264 command = [self.cg2all_bin, 

265 '-p', str(cg_path / 'top.pdb'), 

266 '-d', str(cg_path / 'protein.dcd'), 

267 '-o', str(cg_path / 'traj_aa.dcd'), 

268 '-opdb', str(cg_path / 'last_frame.pdb'), 

269 '--cg', 'ResidueBasedModel', 

270 '--standard-name', 

271 '--device', 'cuda', 

272 '--proc', '1'] 

273 if self.cg2all_ckpt is not None: 

274 command += ['--ckpt', self.cg2all_ckpt] 

275 

276 result = subprocess.run(command, shell=False, capture_output=True, text=True) 

277 if result.returncode != 0: 

278 raise RuntimeError(f'cg2all error!\n{result.stderr}') 

279 

280 # use pdb4amber to fix cg2all-generated pdb 

281 if self.AMBERHOME is None: 

282 command = ['pdb4amber'] 

283 else: 

284 command = [str(self.AMBERHOME / 'bin/pdb4amber')] 

285 command += [str(cg_path / 'last_frame.pdb'), '-y'] 

286 result = subprocess.run(command, shell=False, capture_output=True, text=True) 

287 if result.returncode == 0: 

288 with open(str(cg_path / 'last_frame.amber.pdb'), 'w') as f: 

289 f.write(result.stdout) 

290 else: 

291 raise RuntimeError(f'pdb4amber error!\n{result.stderr}')