Coverage for emd/support.py: 59%
108 statements
« prev ^ index » next coverage.py v7.6.11, created at 2025-03-08 15:44 +0000
« prev ^ index » next coverage.py v7.6.11, created at 2025-03-08 15:44 +0000
#!/usr/bin/python

# vim: set expandtab ts=4 sw=4:

"""
Helper functions for interacting with an EMD install and ensuring array sizes.

Main Routines:
    get_install_dir
    get_installed_version
    run_tests
    create_readthedocs_requirements
    run_parallel

Ensurance Routines:
    ensure_equal_dims
    ensure_vector
    ensure_1d_with_singleton
    ensure_2d

Errors:
    EMDSiftCovergeError

"""

import logging
import os
import pathlib

import numpy as np

# Housekeeping for logging
logger = logging.getLogger(__name__)
def get_install_dir():
    """Return the resolved directory that contains this module, as a string."""
    module_path = pathlib.Path(__file__)
    install_dir = module_path.parent.resolve()
    return str(install_dir)
def get_installed_version():
    """Read version of the installed emd distribution.

    The version is read from the installed ``emd`` distribution's metadata via
    ``importlib.metadata.version``. If a user has made local changes without
    reinstalling, this version may not match the code actually being imported
    or the online package.

    Returns
    -------
    str
        Version string of the installed ``emd`` distribution.

    Raises
    ------
    importlib.metadata.PackageNotFoundError
        If no ``emd`` distribution is installed.

    """
    from importlib.metadata import version
    return version('emd')
def create_readthedocs_requirements():
    """Write ``requirements_rtd.txt`` from the project's ``pyproject.toml``.

    Reads ``pyproject.toml`` from the parent of the install directory. Then it
    writes ``requirements_rtd.txt`` in that same directory, listing one
    requirement per line:

    - every entry of ``[project] dependencies``;
    - every optional-dependency group except ``full``;
    - a final editable-install line (``-e .``).

    The TOML is fully parsed before the output file is opened, so a malformed
    ``pyproject.toml`` does not leave a half-written requirements file behind.

    Raises
    ------
    KeyError
        If ``pyproject.toml`` has no ``[project]`` table.

    """
    import tomllib  # stdlib from Python 3.11

    project_root = pathlib.Path(get_install_dir()).parent

    with open(project_root / 'pyproject.toml', 'rb') as f:
        project = tomllib.load(f)['project']

    requirements = list(project.get('dependencies', []))
    for group, deps in project.get('optional-dependencies', {}).items():
        if group == 'full':
            # NOTE(review): 'full' presumably aggregates the other groups;
            # it is skipped to avoid duplicate entries — confirm.
            continue
        requirements.extend(deps)
    requirements.append('-e .')

    with open(project_root / 'requirements_rtd.txt', 'w', encoding='utf-8') as f:
        f.writelines(req + '\n' for req in requirements)
def run_tests():
    """Run the EMD package tests directly from python.

    Useful for people without a dev-install to run tests perhaps.
    https://docs.pytest.org/en/latest/how-to/usage.html#calling-pytest-from-python-code

    Tests are only run when a ``tests`` directory exists inside the install
    directory. Otherwise an informational message is logged and nothing is run.
    The outcome is reported through the module logger.

    """
    import pytest
    inst_dir = get_install_dir()

    if not os.path.exists(os.path.join(inst_dir, 'tests')):
        logger.info('Test directory not found in: {0}'.format(inst_dir))
        logger.info('(this is normal for PyPI/pip EMD installs)')
        return

    logger.info('Running EMD package tests from: {0}'.format(inst_dir))
    # '-x' stops at the first failing test.
    out = pytest.main(['-x', inst_dir])

    # pytest.main is typed ``int | ExitCode``. Compare by value instead of
    # reading ``.value``, which a plain int does not have.
    if out != 0:
        logger.warning('EMD package tests FAILED - EMD may not behave as expected')
    else:
        logger.info('EMD package tests passed')
# Parallel processing

def run_parallel(pfunc, args, nprocesses=1):
    """Run set of processes in serial or parallel.

    Parameters
    ----------
    pfunc : callable
        Function to run once per item of ``args``.
    args : iterable of sequences
        Each item is unpacked as the positional arguments to ``pfunc``.
    nprocesses : int
        Number of joblib workers. Values greater than 1 use
        ``joblib.Parallel``; anything else runs serially in this process.

    Returns
    -------
    list
        One result of ``pfunc`` per item of ``args``.

    """
    if nprocesses > 1:
        # joblib is only needed for the parallel path, so the serial path
        # works even where joblib is not installed.
        from joblib import Parallel, delayed
        with Parallel(n_jobs=nprocesses) as parallel:
            return parallel(delayed(pfunc)(*aa) for aa in args)

    return [pfunc(*aa) for aa in args]
# Ensurance Department


def ensure_equal_dims(to_check, names, func_name, dim=None):
    """Check that a set of arrays all have the same dimension.

    Raises an error with details if not.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims
    dim : int
        Integer index of specific axes to ensure shape of. The default is to
        compare every axis of the first array in to_check.

    Raises
    ------
    ValueError
        If any of the inputs in to_check have differing shapes. This includes
        inputs that lack one of the axes being compared.

    """
    if len(to_check) == 0:
        return

    # Axes to compare: every axis of the first input, or only the requested one.
    axes = range(to_check[0].ndim) if dim is None else (dim,)

    def _sizes(arr):
        # Sizes of ``arr`` along ``axes``, or None if ``arr`` lacks one of them.
        try:
            return tuple(arr.shape[ax] for ax in axes)
        except IndexError:
            return None

    sizes = [_sizes(arr) for arr in to_check]
    if None not in sizes and all(size == sizes[0] for size in sizes[1:]):
        return

    logger.error('Checking {0} inputs - Input dim mismatch'.format(func_name))
    msg = "Mismatch between inputs: " + "".join(
        "'{0}': {1}, ".format(name, arr.shape) for name, arr in zip(names, to_check)
    )
    logger.error(msg)
    raise ValueError(msg)
def ensure_vector(to_check, names, func_name):
    """Check that a set of arrays are all vectors with only 1-dimension.

    Arrays with a singleton second dimension (shape ``(n, 1)``) will be trimmed
    to 1d. An error will be raised for non-singleton 2d inputs and for any
    input with more than 2 dimensions.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_vector

    Returns
    -------
    out
        Arrays from to_check with 1d shape. If to_check holds a single array,
        that array is returned directly rather than inside a list.

    Raises
    ------
    ValueError
        If any input is a non-singleton 2d array or has more than 2 dimensions

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):
        if xx.ndim > 2:
            # Checked first so that e.g. (n, 1, m) inputs are rejected rather
            # than trimmed down to a 2d array.
            msg = "Checking {0} inputs - Shape of input '{1}' {2} must be a vector."
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)
        if xx.ndim == 2:
            if xx.shape[1] == 1:
                msg = "Checking {0} inputs - trimming singleton from input '{1}'"
                msg = msg.format(func_name, names[idx])
                out_args[idx] = xx[:, 0]
                logger.warning(msg)
            else:
                msg = "Checking {0} inputs - Input '{1}' {2} must be a vector or 2d with singleton second dim"
                msg = msg.format(func_name, names[idx], xx.shape)
                logger.error(msg)
                raise ValueError(msg)

    if len(out_args) == 1:
        return out_args[0]
    return out_args
def ensure_1d_with_singleton(to_check, names, func_name):
    """Check that a set of arrays are all vectors with singleton second dimensions.

    1d arrays will have a singleton second dimension added. Arrays with 2 or
    more dimensions whose trailing dimensions are all length 1 are reshaped to
    ``(n, 1)``. An error is raised if any trailing dimension is not length 1.
    Arrays with 0 dimensions are passed through unchanged.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_1d_with_singleton

    Returns
    -------
    out
        Arrays from to_check with '1d with singleton' shape. If to_check holds
        a single array, that array is returned directly rather than in a list.

    Raises
    ------
    ValueError
        If any input with 2 or more dimensions has a non-singleton trailing dim

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):
        if xx.ndim >= 2:
            if all(size == 1 for size in xx.shape[1:]):
                # nd input where all trailing dims are ones
                msg = "Checking {0} inputs - Trimming trailing singletons from input '{1}' (input size {2})"
                logger.debug(msg.format(func_name, names[idx], xx.shape))
                # reshape (not squeeze) so a length-1 first axis is preserved
                out_args[idx] = xx.reshape(xx.shape[0], 1)
            else:
                # nd input where some trailing dims are not one
                msg = "Checking {0} inputs - trailing dims of input '{1}' {2} must be singletons (length=1)"
                msg = msg.format(func_name, names[idx], xx.shape)
                logger.error(msg)
                raise ValueError(msg)
        elif xx.ndim == 1:
            # Vector input - add a dummy dimension
            msg = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(msg.format(func_name, names[idx]))
            out_args[idx] = xx[:, np.newaxis]

    if len(out_args) == 1:
        return out_args[0]
    return out_args
def ensure_2d(to_check, names, func_name):
    """Give every 1d array in a set a trailing singleton dimension.

    Arrays that are not 1d are passed through untouched.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims

    Returns
    -------
    out
        Copy of arrays in to_check with 2d shape. A lone array is returned
        bare; otherwise a list is returned.

    """
    converted = []
    for position, arr in enumerate(to_check):
        if arr.ndim == 1:
            note = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(note.format(func_name, names[position]))
            arr = arr[:, np.newaxis]
        converted.append(arr)

    return converted[0] if len(converted) == 1 else converted
# Exceptions & Errors

class EMDSiftCovergeError(Exception):
    """Exception raised when the sift fails to converge.

    Attributes
    ----------
    message : str
        Explanation of the error (also available as ``str(err)``).

    """

    def __init__(self, message):
        """Store and log the reason the sift failed to converge."""
        # Pass the message to Exception so str(err) and err.args carry it.
        super().__init__(message)
        self.message = message
        # logger.error rather than logger.exception: the instance is built
        # before `raise` runs, so there is normally no active exception whose
        # traceback could be attached.
        logger.error(self.message)