Coverage for emd/support.py: 66%
95 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-09 10:07 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-09 10:07 +0000
1#!/usr/bin/python
3# vim: set expandtab ts=4 sw=4:
5"""
6Helper functions for interacting with an EMD install and ensuring array sizes.
8Main Routines:
9 get_install_dir
10 get_installed_version
11 run_tests
13Ensurance Routines:
14 ensure_equal_dims
15 ensure_vector
16 ensure_1d_with_singleton
17 ensure_2d
19Errors:
20 EMDSiftCovergeError
22"""
24import logging
25import os
26import pathlib
28import numpy as np
30# Housekeeping for logging
31logger = logging.getLogger(__name__)
def get_install_dir():
    """Return the directory that contains this module as an absolute path string.

    The location is derived from this file's own ``__file__`` attribute, so it
    reflects wherever the currently imported copy of the package lives.

    Returns
    -------
    str
        Resolved path of the directory holding this file.

    """
    this_file = pathlib.Path(__file__)
    package_dir = this_file.parent.resolve()
    return str(package_dir)
def get_installed_version():
    """Return the version string of the installed ``emd`` distribution.

    The version is looked up in the installed package metadata through
    ``importlib.metadata``. It therefore describes the distribution that is
    installed locally, which may not match the latest published package.

    Returns
    -------
    str
        Version reported by the metadata of the ``emd`` distribution.

    """
    # Imported lazily so the lookup machinery is only loaded when requested
    import importlib.metadata

    return importlib.metadata.version('emd')
def run_tests():
    """Run the package tests directly from python.

    Useful for people without a dev-install to run tests perhaps.
    https://docs.pytest.org/en/latest/usage.html#calling-pytest-from-python-code

    The tests are only run when a ``tests`` directory exists inside the
    install directory; otherwise an informational message is logged and the
    function returns without doing anything else. The outcome of the test run
    is reported through the module logger, nothing is returned.

    """
    import pytest

    inst_dir = get_install_dir()

    if not os.path.exists(os.path.join(inst_dir, 'tests')):
        logger.info('Test directory not found in: {0}'.format(inst_dir))
        logger.info('(this is normal for PyPI/pip EMD installs)')
        return

    logger.info('Running EMD package tests from: {0}'.format(inst_dir))
    out = pytest.main(['-x', inst_dir])

    # pytest.main returns an ExitCode enum member for the standard exit codes
    # but a plain int for any other value, so convert with int() rather than
    # relying on the enum-only ``.value`` attribute.
    if int(out) != 0:
        logger.warning('EMD package tests FAILED - EMD may not behave as expected')
    else:
        logger.info('EMD package tests passed')
77# Parallel processing
def run_parallel(pfunc, args, nprocesses=1):
    """Run a function over a set of argument tuples in serial or parallel.

    Parameters
    ----------
    pfunc : callable
        Function to call once per entry in args
    args : iterable of sequences
        Each entry is unpacked into the positional arguments of one call
    nprocesses : int
        Number of joblib workers to use. Values of 1 or less run every call
        serially in the current process (Default value = 1)

    Returns
    -------
    list
        One result per entry in args

    """
    if nprocesses > 1:
        # joblib is only needed on the parallel path - importing it here lets
        # serial runs work even when joblib is not available.
        from joblib import Parallel, delayed

        with Parallel(n_jobs=nprocesses) as parallel:
            res = parallel(delayed(pfunc)(*aa) for aa in args)
    else:
        res = [pfunc(*aa) for aa in args]

    return res
92# Ensurance Department
def ensure_equal_dims(to_check, names, func_name, dim=None):
    """Check that a set of arrays all have the same dimension.

    Raises an error with details if not.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims
    dim : int
        Integer index of specific axes to ensure shape of, default is to compare all dims

    Raises
    ------
    ValueError
        If any of the inputs in to_check have differing shapes. When dim is
        None this includes inputs with a different number of dimensions.

    """
    if dim is None:
        # Compare the complete shapes. Indexing every shape with the axes of
        # the first array would let (5,) match (5, 3) and would raise an
        # IndexError instead of a ValueError for (5, 3) against (5,).
        all_dims = [tuple(x.shape) for x in to_check]
    else:
        all_dims = [(x.shape[dim],) for x in to_check]

    # Every input is compared against the first one
    if any(dims != all_dims[0] for dims in all_dims[1:]):
        msg = 'Checking {0} inputs - Input dim mismatch'.format(func_name)
        logger.error(msg)
        msg = "Mismatch between inputs: "
        msg += "".join("'{0}': {1}, ".format(name, x.shape)
                       for name, x in zip(names, to_check))
        logger.error(msg)
        raise ValueError(msg)
def ensure_vector(to_check, names, func_name):
    """Check that a set of arrays are all vectors with only 1-dimension.

    Arrays with singleton second dimensions will be trimmed and an error will
    be raised for non-singleton 2d or greater than 2d inputs.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_vector

    Returns
    -------
    out
        Arrays in to_check with 1d shape. A single array is returned when
        to_check contains one item, otherwise a list of arrays.

    Raises
    ------
    ValueError
        If any input is a non-singleton 2d array or has more than 2 dimensions

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):

        if xx.ndim > 2:
            # Checked first - otherwise an (N, 1, M) input would match the
            # singleton test below and be silently cut down to a 2d array.
            msg = "Checking {0} inputs - Shape of input '{1}' {2} must be a vector."
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)
        elif (xx.ndim == 2) and (xx.shape[1] == 1):
            msg = "Checking {0} inputs - trimming singleton from input '{1}'"
            msg = msg.format(func_name, names[idx])
            out_args[idx] = xx[:, 0]
            logger.warning(msg)
        elif xx.ndim == 2:
            msg = "Checking {0} inputs - Input '{1}' {2} must be a vector or 2d with singleton second dim"
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)

    if len(out_args) == 1:
        return out_args[0]
    else:
        return out_args
def ensure_1d_with_singleton(to_check, names, func_name):
    """Check that a set of arrays are all vectors with singleton second dimensions.

    1d arrays will have a singleton second dimension added, trailing singleton
    dimensions beyond the second are removed and an error will be raised for
    inputs with any non-singleton trailing dimension.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_1d_with_singleton

    Returns
    -------
    out
        Arrays in to_check with '1d with singleton' shape, ie (n, 1). A single
        array is returned when to_check contains one item, otherwise a list
        of arrays.

    Raises
    ------
    ValueError
        If any input has a trailing dimension whose length is not 1

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):

        if xx.ndim == 1:
            # Vector input - add a dummy dimension
            msg = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(msg.format(func_name, names[idx]))
            out_args[idx] = xx[:, np.newaxis]
        elif xx.ndim >= 2:
            if all(size == 1 for size in xx.shape[1:]):
                # nd input where all trailing are ones
                msg = "Checking {0} inputs - Trimming trailing singletons from input '{1}' (input size {2})"
                logger.debug(msg.format(func_name, names[idx], xx.shape))
                # Reshape keeps the first axis even when its length is 1,
                # squeezing would remove it and break the indexing.
                out_args[idx] = xx.reshape(xx.shape[0], 1)
            else:
                # nd input where some trailing are not one
                msg = "Checking {0} inputs - trailing dims of input '{1}' {2} must be singletons (length=1)"
                # Format once so the log and the exception carry the same text
                msg = msg.format(func_name, names[idx], xx.shape)
                logger.error(msg)
                raise ValueError(msg)

    if len(out_args) == 1:
        return out_args[0]
    else:
        return out_args
def ensure_2d(to_check, names, func_name):
    """Give every 1d array in a set of arrays a singleton second dimension.

    Inputs that are not 1d are passed through untouched.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to process
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of the calling function, used in log messages

    Returns
    -------
    out
        The single processed array when to_check holds exactly one item,
        otherwise a list of the processed arrays.

    """
    template = "Checking {0} inputs - Adding dummy dimension to input '{1}'"

    processed = []
    for position, arr in enumerate(to_check):
        if arr.ndim != 1:
            # Nothing to do for this input
            processed.append(arr)
            continue

        logger.debug(template.format(func_name, names[position]))
        processed.append(arr[:, np.newaxis])

    return processed[0] if len(processed) == 1 else processed
271# Exceptions & Errors
class EMDSiftCovergeError(Exception):
    """Exception raised when a sift has failed to converge.

    Attributes
    ----------
    message : str
        Explanation of the error

    """

    def __init__(self, message):
        """Raise error indicating that sift has failed to converge."""
        super().__init__(message)
        self.message = message
        # logger.exception is only meant to be called from inside an exception
        # handler - when called here it attaches an empty 'NoneType: None'
        # traceback to the record, so log a plain error instead.
        logger.error(self.message)