# Source code for mlcalcdriver.calculators.ase_calculators.asespkcalculator

from ase.calculators.calculator import Calculator, all_changes
from mlcalcdriver.calculators import SchnetPackCalculator
from mlcalcdriver.base import Posinp, Job
from copy import deepcopy
import numpy as np


class AseSpkCalculator(Calculator):
    r"""
    Wrapper :class:`Calculator` class around the :class:`SchnetPackCalculator`
    class to use directly inside ASE functions.
    """

    def __init__(
        self, model_dir, available_properties=None, device="cpu", md=False, **kwargs
    ):
        r"""
        Parameters
        ----------
        model_dir : str
            Same as :class:`SchnetPackCalculator`.
        available_properties : str or list of str
            Same as :class:`SchnetPackCalculator`.
        device : str
            Same as :class:`SchnetPackCalculator`.
        md : bool
            Default is False. Should be set to True if the calculator
            is used for molecular dynamics.
        **kwargs
            Forwarded unchanged to the ASE :class:`Calculator` constructor.
            They are *not* passed on to :class:`SchnetPackCalculator`.
        """
        Calculator.__init__(self, **kwargs)
        self.schnetpackcalculator = SchnetPackCalculator(
            model_dir=model_dir,
            available_properties=available_properties,
            device=device,
            md=md,
        )
        # Keep a private copy: "forces" may be appended below, and that
        # append must not mutate whatever object the wrapped calculator
        # handed back (it may be its own internal list).
        self.implemented_properties = deepcopy(
            self.schnetpackcalculator._get_available_properties()
        )
        # Advertise "forces" to ASE whenever "energy" is available.
        # NOTE(review): assumes SchnetPackCalculator can provide forces for
        # an energy model — confirm against SchnetPackCalculator.
        if (
            "energy" in self.implemented_properties
            and "forces" not in self.implemented_properties
        ):
            self.implemented_properties.append("forces")

    # The list default mirrors ASE's Calculator.calculate signature; it is
    # only read here, never mutated.
    def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
        r"""
        Compute ``properties`` for ``atoms``. This method is called by ASE
        functions.

        Parameters
        ----------
        atoms : ase.Atoms
            Structure to evaluate.
        properties : list of str
            Names of the properties to compute. Each one is run through a
            :class:`Job` that wraps the SchnetPack calculator.
        system_changes : list of str
            Part of the ASE ``Calculator.calculate`` signature. It is not used
            directly; ``calculation_required`` decides whether to recompute.

        When a calculation is required, every entry of the job results is
        passed through :func:`numpy.squeeze` and stored in ``self.results``,
        replacing the previous results. Otherwise nothing is done.
        """
        if not self.calculation_required(atoms, properties):
            return
        Calculator.calculate(self, atoms)
        posinp = Posinp.from_ase(atoms)

        job = Job(posinp=posinp, calculator=self.schnetpackcalculator)
        for prop in properties:
            job.run(prop)
        # Drop singleton axes from the raw job outputs before handing them
        # to ASE.
        self.results = {
            prop: np.squeeze(result)
            for prop, result in zip(job.results.keys(), job.results.values())
        }