from ase.calculators.calculator import Calculator, all_changes
from mlcalcdriver.calculators import SchnetPackCalculator
from mlcalcdriver.base import Posinp, Job
from copy import deepcopy
import numpy as np
class AseSpkCalculator(Calculator):
    r"""
    Wrapper :class:`Calculator` class around the :class:`SchnetPackCalculator`
    class to use directly inside ASE functions.
    """

    def __init__(self, model_dir, available_properties=None, device="cpu", md=False, **kwargs):
        r"""
        Parameters
        ----------
        model_dir : str
            Same as :class:`SchnetPackCalculator`.
        available_properties : str or list of str
            Same as :class:`SchnetPackCalculator`.
        device : str
            Same as :class:`SchnetPackCalculator`.
        md : bool
            Default is False. Should be set to True if the
            calculator is used for molecular dynamics.
        **kwargs
            Any additional keyword arguments are forwarded unchanged
            to the constructor of the ASE :class:`Calculator`.
        """
        Calculator.__init__(self, **kwargs)
        self.schnetpackcalculator = SchnetPackCalculator(
            model_dir=model_dir,
            available_properties=available_properties,
            device=device,
            md=md,
        )
        # Work on a copy: the list is extended below, and the wrapped
        # calculator may still hold a reference to the list it returned.
        self.implemented_properties = list(
            self.schnetpackcalculator._get_available_properties()
        )
        # Advertise forces to ASE whenever the model provides an energy.
        # NOTE(review): presumably the wrapped calculator derives forces
        # from the energy -- confirm against SchnetPackCalculator.
        if (
            "energy" in self.implemented_properties
            and "forces" not in self.implemented_properties
        ):
            self.implemented_properties.append("forces")

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        r"""
        This method will be called by ASE functions.

        Each requested property is run through a :class:`Job` built on
        the wrapped :class:`SchnetPackCalculator`. Nothing is recomputed
        when ``calculation_required`` reports that the stored results
        are still valid.

        Parameters
        ----------
        atoms : ase.Atoms
            Structure for which the properties are computed.
        properties : list of str
            Names of the properties to compute.
            Default is ``["energy"]``.
        system_changes : list of str
            Accepted for compatibility with the ASE calculator
            interface; it is not used by this method.

        Notes
        -----
        The results are stored in ``self.results``, keyed by property
        name, after being passed through :func:`numpy.squeeze`.
        """
        # None sentinel instead of a mutable list as default argument.
        if properties is None:
            properties = ["energy"]
        if not self.calculation_required(atoms, properties):
            return
        Calculator.calculate(self, atoms)
        job = Job(
            posinp=Posinp.from_ase(atoms),
            calculator=self.schnetpackcalculator,
        )
        for prop in properties:
            job.run(prop)
        # Replace (rather than update) the stored results with
        # everything the job produced, dropping singleton dimensions.
        self.results = {
            name: np.squeeze(value) for name, value in job.results.items()
        }