Source code for ase2sprkkr.common.init_tests

""" Routines and classes used in tests """

import sys
from pathlib import Path
import numpy as np
import inspect
import asyncio
from ase.cell import Cell
from contextlib import contextmanager
import pytest
import tempfile
import unittest
import os


def run_sprkkr(fn):
    """ Decorator: run the test only if the SPRKKR executable can be run.

    The decision is delegated to ``TestCase.run_sprkkr()``; when it returns
    a false value, the decorated test is marked to be skipped.

    Parameters
    ----------
    fn
      The test function (or class) to be decorated.

    Returns
    -------
    The object ``fn`` with the ``pytest.mark.skipif`` marker applied.
    """
    skip_without_sprkkr = pytest.mark.skipif(
        not TestCase.run_sprkkr(),
        reason="The test require SPRKKR running",
    )
    return skip_without_sprkkr(fn)
def patch_package(package, name):
    """ Set the package name for the tests, to make the relative imports working.

    Usage:

    .. code::

        if __package__:
            from .init_tests import TestCase, patch_package
        else:
            from init_tests import TestCase, patch_package
        __package__, __name__ = patch_package(__package__, __name__)

    Parameters
    ----------
    package
      The current value of ``__package__`` of the calling module
      (it can be None or an empty string).

    name
      The current value of ``__name__`` of the calling module.

    Returns
    -------
    package, name
      The values to be assigned to ``__package__`` and ``__name__``. If the
      ``package`` already contains at least three dots, both the values are
      returned unchanged.

    Raises
    ------
    ValueError
      If the file of the calling module is not located inside a directory
      named ``ase2sprkkr``.
    """
    if package and package.count('.') >= 3:
        return package, name

    # Only the filename of the caller is needed: context=0 avoids reading
    # the source lines of every frame on the stack.
    caller_file = Path(inspect.stack(0)[1].filename).resolve()
    current = caller_file.parent

    # Climb up to the (innermost) directory named 'ase2sprkkr'.
    root = caller_file
    while root.name != 'ase2sprkkr':
        if root.parent == root:
            # The filesystem root has been reached: its parent is the root
            # itself, so without this check the loop would never end.
            raise ValueError(
                f"The file {caller_file} is not located inside "
                "an 'ase2sprkkr' directory"
            )
        root = root.parent
    top = root.parent

    try:
        import ase2sprkkr  # NOQA
    except ImportError:
        sys.path.append(str(top))

    # Build the dotted name from the path components: this is independent
    # of the path separator and it works even if ``top`` is the filesystem root.
    package = '.'.join(current.relative_to(top).parts)
    return package, package + '.' + name.rsplit('.', 1)[-1]
# Apply the patch to this very module, too: patch_package returns the values
# that make the relative imports working (or the unchanged ones).
__package__, __name__ = patch_package(__package__,__name__)
class extdict(dict):
    """ A dictionary that can be called to obtain its updated copy. """

    def __call__(self, **kwargs):
        """ Return a new plain ``dict`` holding the items of this dictionary,
        overridden/extended by the given keyword arguments.

        The dictionary itself is not modified.
        """
        return {**self, **kwargs}
class TestCase:
    """ A testcase class with some useful assertions and a better
    comparison of numpy arrays.

    The comparisons are delegated to the module-level ``assertion`` object.
    """

    # True if the tests have been run with the verbose switch
    _print_output = '-v' in sys.argv or '--verbose' in sys.argv

    # The default arguments for the calculator used in the tests.
    # The dictionary is shared by all the tests: see calc_args and temporary_dir.
    _calc_args = extdict(
        directory=False,
        input_file='output_test_calc.inp',
        output_file='output_test_calc.out',
        potential_file='output_test_calc.pot',
        print_output=_print_output,
        mpi='auto',
        options={'NKTAB': 5, 'NE': 20},
        empty_spheres=False,
    )

    @pytest.fixture
    def temporary_dir(self):
        """ Fixture that creates a temporary directory for the test.

        During the test, the name of the directory is available in
        ``self.dirname`` and it is set as the ``directory`` item of the
        shared default calculator arguments. After the test, the directory
        is removed, ``directory`` is reset to False and ``self.dirname``
        is deleted.
        """
        with tempfile.TemporaryDirectory() as d:
            TestCase._calc_args['directory'] = d
            self.dirname = d
            yield
            TestCase._calc_args['directory'] = False
            del self.dirname

    @classmethod
    def calc_args(cls, TYPE=None, **kwargs):
        """ Return the arguments for the calculator.

        Parameters
        ----------
        TYPE
          If it is ``'save_input'``, the ``print_output`` item is removed
          from the result.

        **kwargs
          The values that override the default ones. The dictionary given
          as ``options`` is merged with the default options.

        Returns
        -------
        The shared dictionary of the default arguments itself if there is
        nothing to change, a new dictionary otherwise.
        """
        # The shared defaults can be returned only if they need no change:
        # 'save_input' has to drop print_output even without any kwargs.
        if not kwargs and TYPE != 'save_input':
            return cls._calc_args
        out = cls._calc_args.copy()
        if 'options' in kwargs:
            # Merge into a new dictionary, not to modify the caller's one.
            # NOTE(review): the default options take precedence over the
            # given ones, as they always did - confirm that this is intended.
            kwargs['options'] = {**kwargs['options'], **cls._calc_args['options']}
        out.update(kwargs)
        if TYPE == 'save_input':
            out.pop('print_output', None)
        return out

    @staticmethod
    def run_sprkkr():
        """ Return True, unless the environment variable
        ``DO_NOT_RUN_SPRKKR`` is set to a non-empty value. """
        return os.environ.get('DO_NOT_RUN_SPRKKR', '') == ''

    def assertAsyncEqual(self, a, b):
        """ Assert that ``a`` is equal to the result of the coroutine ``b``. """
        return self.assertEqual(a, self.runAsync(b))

    def assertAsyncRaises(self, a, b):
        """ Assert that running the coroutine ``b`` raises the exception ``a``. """
        with pytest.raises(a):
            self.runAsync(b)

    def assertRaises(self, a, b=None):
        """ Assert that the exception ``a`` is raised.

        If the callable ``b`` is given, it is called (without arguments) and
        it is expected to raise. Otherwise, the ``pytest.raises`` context
        manager is returned, to be used in a ``with`` statement.
        """
        if b is None:
            return pytest.raises(a)
        with pytest.raises(a):
            b()

    def assertAlmostEqual(self, a, b, **kwargs):
        """ Compare the values using ``numpy.testing.assert_almost_equal``,
        the ``kwargs`` are passed to it. """
        np.testing.assert_almost_equal(a, b, **kwargs)

    def assertIsNone(self, a):
        """ Assert that ``a`` is None. """
        assert a is None

    @staticmethod
    def runAsync(corr):
        """ Run the coroutine ``corr`` using ``asyncio.run`` and return its result. """
        return asyncio.run(corr)

    def assertTrue(self, val):
        """ Assert that ``val`` is truthy. """
        assert val

    def assertFalse(self, val):
        """ Assert that ``val`` is falsy. """
        assert not val

    def assertEqual(self, a, b, msg=None):
        """ Assert that ``a`` equals ``b``, using the module-level
        ``assertion`` object and the comparison functions registered to it. """
        assertion.assertEqual(a, b, msg=msg)

    def assertNotEqual(self, a, b):
        """ Assert that the comparison of ``a`` and ``b`` (done in the same
        way as in ``assertEqual``) fails. """
        with pytest.raises(AssertionError):
            assertion.assertEqual(a, b)

    @classmethod
    @contextmanager
    def almost_equal_precision(cls, **kwargs):
        """ Context manager that sets the keyword arguments used for
        the almost-equal comparison of numpy arrays.

        The previous setting is restored on the exit, even if the body
        of the ``with`` statement raises (e.g. due to a failed assertion).
        """
        tmp = assertion._almost_equal_precision
        assertion._almost_equal_precision = kwargs
        try:
            yield
        finally:
            # A failed assertion must not leak the setting to other tests.
            assertion._almost_equal_precision = tmp
# A standalone unittest.TestCase instance: its assertEqual (extended below by
# the registered type equality functions) does the comparisons in this module.
# '__init__' is passed as the name of the test method - presumably just because
# the constructor wants a name of an existing method.
assertion = unittest.TestCase('__init__')
# The extra keyword arguments for the almost-equal comparison of numpy arrays,
# see arr_testfce and TestCase.almost_equal_precision.
assertion._almost_equal_precision = {}
[docs] def testfce(fce, msg='', **kwargs): def np_array_equal(a, b, msg=msg, **kwar): try: fce(a,b, **kwargs, **kwar) except AssertionError as e: if msg: msg = msg + '\n' + str(e) raise assertion.failureException(msg) else: raise return np_array_equal
# The numpy comparisons wrapped by testfce, so they accept a message
# as the third argument: the exact one and the almost-equal one.
assert_equals = testfce(np.testing.assert_equal)
assert_almost_equals = testfce(np.testing.assert_almost_equal)
def arr_testfce(a, b, msg, **kwargs):
    """ Compare two numpy arrays.

    assert_almost_equal does not work for non-numeric dtypes, so the arrays
    of the object dtype and the structured arrays with an object field
    are compared exactly. The other arrays are compared using
    ``assert_almost_equals``, with the keyword arguments currently set
    in ``assertion._almost_equal_precision`` (they override ``kwargs``).

    Parameters
    ----------
    a, b
      The arrays to be compared; the way of the comparison is chosen
      according to the dtype of ``a``.

    msg
      The message used if the comparison fails.

    **kwargs
      Keyword arguments for the almost-equal comparison.
    """
    if a.dtype == 'O':
        return assert_equals(a, b, msg)
    if a.dtype.names:
        # A structured array: each of its fields has to be examined
        # (not just a fixed one), one object field is sufficient.
        if any(a.dtype[field] == 'O' for field in a.dtype.names):
            return assert_equals(a, b, msg)
    kwargs.update(assertion._almost_equal_precision)
    return assert_almost_equals(a, b, msg, **kwargs)
# Let assertion.assertEqual compare two numpy arrays using arr_testfce
assertion.addTypeEqualityFunc( np.ndarray, arr_testfce )
def assertDictEqual(a, b, msg=''):
    """ Assert that two dictionaries are equal.

    The dictionaries are equal if they have the same keys and the values
    under the same keys are equal according to ``assertion.assertEqual``
    (so e.g. numpy arrays can be used as the values). As usual for
    dictionaries, the order of the items does not matter.

    Parameters
    ----------
    a, b
      The dictionaries to be compared.

    msg
      An optional message describing the failure (it can be None).

    Raises
    ------
    ValueError
      If the dictionaries are not of the same class.

    AssertionError
      If the dictionaries differ.
    """
    if a.__class__ != b.__class__:
        # msg is None if the function is invoked by unittest's assertEqual
        prefix = msg + '\n' if msg else ''
        raise ValueError(prefix + f'Classes {a.__class__} and {b.__class__} are not equal')
    if len(a) != len(b):
        # An explicit raise: an assert statement would be stripped by python -O
        raise AssertionError(msg or 'Dictionaries differ')
    for key, value in a.items():
        # The lengths are the same, so if each key of ``a`` is in ``b``,
        # the dictionaries have the same keys.
        if key not in b:
            raise assertion.failureException(
                f'Dict keys are not equal: the key {key!r} is missing in the second dictionary'
            )
        assertion.assertEqual(value, b[key], f'Dict values are not equal for the key {key!r}')
def assertListEqual(a, b, msg=''):
    """ Assert that two lists are equal.

    The lists have to be of the same class and of the same length, and their
    items are compared pairwise using ``assertion.assertEqual``.

    Parameters
    ----------
    a, b
      The lists to be compared.

    msg
      An optional message; if given, it is prepended (followed by a colon)
      to the description of the failure.

    Raises
    ------
    AssertionError
      If the lists differ. Any exception raised during the comparison
      of the items is reported as an AssertionError as well.
    """
    def message(error):
        if msg:
            return msg + ': ' + error
        return error

    assert a.__class__ is b.__class__, message('A list is expected')
    assert len(a) == len(b), message(
        'The lists should have the same lengths, they have '
        f'the lengths {len(a)} and {len(b)} respectivelly'
    )
    for position, (left, right) in enumerate(zip(a, b)):
        try:
            assertion.assertEqual(left, right)
        except Exception as e:
            raise AssertionError(
                message(f'The {position}th value of the lists differs: {e}')
            )
# Register the custom comparisons, so assertion.assertEqual uses them
# for lists and dictionaries.
assertion.addTypeEqualityFunc(list, assertListEqual)
assertion.addTypeEqualityFunc(dict, assertDictEqual)
# ASE cells are converted to numpy arrays and compared as the arrays are.
assertion.addTypeEqualityFunc(Cell, lambda a,b,msg: arr_testfce(np.array(a), np.array(b), msg) )