Source code for ase2sprkkr.common.init_tests

""" Routines and classes used in tests """

import sys
import os
from pathlib import Path
import unittest
import numpy as np
import inspect
import sys
import asyncio
from ase.cell import Cell
from contextlib import contextmanager
def patch_package(package, name):
    """Set the package name for the tests, to make relative imports work.

    Usage:

    .. code::

        if __package__:
            from .init_tests import TestCase, patch_package
        else:
            from init_tests import TestCase, patch_package
        __package__, __name__ = patch_package(__package__, __name__)

    Parameters
    ----------
    package
        The ``__package__`` of the calling module (may be empty or ``None``).
    name
        The ``__name__`` of the calling module.

    Returns
    -------
    tuple
        ``(package, name)`` to be assigned to ``__package__`` and
        ``__name__``. If ``package`` already contains at least three dots,
        both arguments are returned unchanged.

    Raises
    ------
    RuntimeError
        If the calling file is not located below a directory named
        ``ase2sprkkr``.
    """
    if package and package.count('.') >= 3:
        return package, name

    # The file from which patch_package has been called.
    caller = Path(inspect.stack()[1].filename).resolve()
    current = caller.parents[0]

    # Find the innermost path component named ``ase2sprkkr``. Iterating over
    # the (finite) chain of parents guarantees termination: the filesystem
    # root is its own parent, so a ``while`` loop would never end.
    for candidate in (caller, *caller.parents):
        if candidate.name == 'ase2sprkkr':
            break
    else:
        raise RuntimeError(
            f"Cannot find the 'ase2sprkkr' directory above '{caller}'"
        )

    top = candidate.parent
    top_dir = str(top)
    # Do not pollute sys.path with duplicates on repeated calls.
    if top_dir not in sys.path:
        sys.path.append(top_dir)

    # Build the dotted package name from the path components, which is
    # independent of the path separator of the platform.
    package = '.'.join(current.relative_to(top).parts)
    return package, package + '.' + name.rsplit('.', 1)[-1]
# Apply the patch to this module itself: rebind its own __package__ and
# __name__ to the values computed by patch_package.
__package__, __name__ = patch_package(__package__,__name__)
class TestCase(unittest.TestCase):
    """A test case class with some useful assertions and a better
    comparison of numpy arrays.

    After :meth:`setUp` has run, ``assertEqual`` compares numpy arrays (and
    ASE ``Cell`` objects) using ``numpy.testing`` and compares dicts using
    :meth:`assertDictEqual`.
    """

    # Keyword arguments added to the ``np.testing.assert_almost_equal`` call
    # when numpy arrays are compared via ``assertEqual``. Replaced temporarily
    # by :meth:`almost_equal_precision`.
    _almost_equal_precision = {}

    def assertAsyncEqual(self, a, b):
        """Run the awaitable ``b`` and assert that its result equals ``a``."""
        return self.assertEqual(a, self.runAsync(b))

    def assertAsyncRaises(self, a, b):
        """Assert that running the awaitable ``b`` raises the exception ``a``.

        The awaitable has to be run inside the ``assertRaises`` context;
        running it beforehand would let the expected exception escape.
        """
        with self.assertRaises(a):
            self.runAsync(b)

    def assertAlmostEqual(self, a, b, **kwargs):
        """Compare ``a`` and ``b`` using ``np.testing.assert_almost_equal``.

        ``kwargs`` are passed to the numpy function unchanged.
        """
        np.testing.assert_almost_equal(a, b, **kwargs)

    @staticmethod
    def runAsync(corr):
        """Run the coroutine ``corr`` using ``asyncio.run`` and return its
        result."""
        return asyncio.run(corr)

    @classmethod
    @contextmanager
    def almost_equal_precision(cls, **kwargs):
        """Context manager that temporarily replaces the keyword arguments
        used for the almost-equal comparison of numpy arrays.

        The previous value is restored on exit, even if the body of the
        ``with`` statement raises.
        """
        tmp = cls._almost_equal_precision
        cls._almost_equal_precision = kwargs
        try:
            yield
        finally:
            cls._almost_equal_precision = tmp

    def assertDictEqual(self, a, b, msg=''):
        """Assert that two dicts have the same class and the same items in
        the same order.

        Raises
        ------
        ValueError
            If the classes of ``a`` and ``b`` differ.
        AssertionError
            If the lengths, keys or values differ.
        """
        if a.__class__ != b.__class__:
            if msg:
                msg += '\n'
            # NOTE(review): a ValueError is reported by unittest as an error,
            # not as a failure; kept for backward compatibility.
            raise ValueError(
                msg + f'Classes {a.__class__} and {b.__class__} are not equal'
            )
        if len(a) != len(b):
            # Let the standard implementation create the failure message.
            super().assertDictEqual(dict(a), dict(b), msg)
        for (ai, av), (bi, bv) in zip(a.items(), b.items()):
            self.assertEqual(ai, bi, 'Dict keys are not equal')
            self.assertEqual(av, bv, 'Dict values are not equal')

    def setUp(self):
        """Register the equality functions for numpy arrays, dicts and
        ``Cell`` objects, so ``assertEqual`` can compare them."""

        def testfce(fce, msg='', **kwargs):
            """Wrap the numpy testing function ``fce`` to be usable as a
            type equality function."""

            def np_array_equal(a, b, msg=msg, **kwar):
                try:
                    fce(a, b, **kwargs, **kwar)
                except AssertionError as e:
                    if msg:
                        # Prepend the user-supplied message to the numpy one.
                        msg = msg + '\n' + str(e)
                        raise self.failureException(msg) from e
                    raise

            return np_array_equal

        assert_equals = testfce(np.testing.assert_equal)
        assert_almost_equals = testfce(np.testing.assert_almost_equal)

        def arr_testfce(a, b, msg, **kwargs):
            """ assert_almost_equal does not work for non-numeric dtypes """
            dtype = a.dtype
            if dtype == 'O':
                return assert_equals(a, b, msg)
            # Structured arrays: fall back to the exact comparison if any
            # of the fields (not just one fixed field) holds objects.
            if dtype.names and any(dtype[n] == 'O' for n in dtype.names):
                return assert_equals(a, b, msg)
            kwargs.update(self._almost_equal_precision)
            return assert_almost_equals(a, b, msg, **kwargs)

        self.addTypeEqualityFunc(np.ndarray, arr_testfce)
        self.addTypeEqualityFunc(dict, self.assertDictEqual)
        self.addTypeEqualityFunc(
            Cell,
            lambda a, b, msg: arr_testfce(np.array(a), np.array(b), msg)
        )