Coverage for emd/support.py: 66%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-09 10:07 +0000

1#!/usr/bin/python 

2 

3# vim: set expandtab ts=4 sw=4: 

4 

5""" 

6Helper functions for interacting with an EMD install and ensuring array sizes. 

7 

8Main Routines: 

9 get_install_dir 

10 get_installed_version 

11 run_tests 

12 

13Ensurance Routines: 

14 ensure_equal_dims 

15 ensure_vector 

16 ensure_1d_with_singleton 

17 ensure_2d 

18 

19Errors: 

20 EMDSiftCovergeError 

21 

22""" 

23 

24import logging 

25import os 

26import pathlib 

27 

28import numpy as np 

29 

# Module-level logger named after this module; the helpers below report
# their checks, warnings and errors through it.
logger = logging.getLogger(name=__name__)

32 

33 

def get_install_dir():
    """Return the directory path of the currently installed & imported emd.

    Returns
    -------
    str
        Absolute, resolved path of the directory containing this module.

    """
    package_dir = pathlib.Path(__file__).parent
    return str(package_dir.resolve())

38 

39 

def get_installed_version():
    """Return the version string of the currently installed emd distribution.

    The version is read from the installed package metadata via
    ``importlib.metadata.version``. If a user has made local changes, this
    version may not exactly match the package published online.

    Returns
    -------
    str
        Version of the installed ``emd`` distribution.

    Raises
    ------
    importlib.metadata.PackageNotFoundError
        If no ``emd`` distribution is installed.

    """
    import importlib.metadata as _metadata
    return _metadata.version('emd')

52 

53 

def run_tests():
    """Run the EMD package tests directly from python.

    Useful for people without a dev-install who still want to run the tests.
    https://docs.pytest.org/en/latest/usage.html#calling-pytest-from-python-code

    Returns
    -------
    int or None
        The pytest exit code (0 means every test passed), or None when no
        ``tests`` directory exists in the install location.

    """
    import pytest
    inst_dir = get_install_dir()

    if not os.path.exists(os.path.join(inst_dir, 'tests')):
        logger.info('Test directory not found in: {0}'.format(inst_dir))
        logger.info('(this is normal for PyPI/pip EMD installs)')
        return None

    logger.info('Running EMD package tests from: {0}'.format(inst_dir))
    # '-x' stops the run at the first failing test. pytest.main usually
    # returns an ExitCode enum but may return a plain int, so coerce with
    # int() rather than relying on the enum-only ``.value`` attribute.
    exit_code = int(pytest.main(['-x', inst_dir]))

    if exit_code != 0:
        logger.warning('EMD package tests FAILED - EMD may not behave as expected')
    else:
        logger.info('EMD package tests passed')

    return exit_code

75 

76 

77# Parallel processing 

78 

def run_parallel(pfunc, args, nprocesses=1):
    """Run a function over a set of argument tuples in serial or parallel.

    Parameters
    ----------
    pfunc : callable
        Function to evaluate; it is called as ``pfunc(*aa)`` for each ``aa``
        in ``args``.
    args : iterable of sequences
        Positional-argument sequences, one per call.
    nprocesses : int
        Number of parallel jobs. Values greater than 1 dispatch the calls via
        joblib; otherwise the calls run serially in the current process.

    Returns
    -------
    list
        Result of each call, in the same order as ``args``.

    """
    if nprocesses > 1:
        # joblib is only needed for the parallel path, so import it lazily:
        # serial execution then works even when joblib is not installed.
        from joblib import Parallel, delayed
        with Parallel(n_jobs=nprocesses) as parallel:
            return parallel(delayed(pfunc)(*aa) for aa in args)

    return [pfunc(*aa) for aa in args]

90 

91 

92# Ensurance Department 

93 

94 

def ensure_equal_dims(to_check, names, func_name, dim=None):
    """Check that a set of arrays all have the same dimension.

    Raises an error with details if not.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of the calling function, used in log and error messages
    dim : int
        Integer index of a specific axis to compare. By default every axis of
        the first array in to_check is compared.

    Raises
    ------
    ValueError
        If any of the inputs in to_check have differing sizes along the
        compared axes, or do not have one of those axes at all.

    """
    if len(to_check) == 0:
        # Nothing to compare.
        return

    if dim is None:
        axes = range(to_check[0].ndim)
    else:
        axes = [dim]

    try:
        selected = [tuple(arr.shape[ax] for ax in axes) for arr in to_check]
    except IndexError:
        # An input lacks one of the compared axes (e.g. it has fewer dims
        # than the first input); report it as a dimension mismatch rather
        # than leaking a bare IndexError to the caller.
        mismatch = True
    else:
        mismatch = any(shape != selected[0] for shape in selected[1:])

    if mismatch:
        msg = 'Checking {0} inputs - Input dim mismatch'.format(func_name)
        logger.error(msg)
        msg = "Mismatch between inputs: "
        for name, arr in zip(names, to_check):
            msg += "'{0}': {1}, ".format(name, arr.shape)
        logger.error(msg)
        raise ValueError(msg)

133 

134 

def ensure_vector(to_check, names, func_name):
    """Check that a set of arrays are all vectors with only 1-dimension.

    2d arrays with a singleton second dimension are trimmed to 1d. An error
    is raised for 2d inputs whose second dimension is not a singleton and for
    any input with more than two dimensions.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of the calling function, used in log and error messages

    Returns
    -------
    out
        Arrays from to_check with 1d shape (trimmed inputs are views of the
        originals). The single array itself is returned if to_check has one
        element, otherwise a list.

    Raises
    ------
    ValueError
        If any input has more than two dimensions, or is 2d with a
        non-singleton second dimension.

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):
        if xx.ndim > 2:
            # Reject >2d inputs first so that e.g. an (N, 1, M) array is not
            # mistaken for a column vector and "trimmed" to 2d.
            msg = "Checking {0} inputs - Shape of input '{1}' {2} must be a vector."
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)
        elif xx.ndim == 2 and xx.shape[1] == 1:
            msg = "Checking {0} inputs - trimming singleton from input '{1}'"
            msg = msg.format(func_name, names[idx])
            out_args[idx] = xx[:, 0]
            logger.warning(msg)
        elif xx.ndim == 2:
            msg = "Checking {0} inputs - Input '{1}' {2} must be a vector or 2d with singleton second dim"
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)

    if len(out_args) == 1:
        return out_args[0]
    else:
        return out_args

184 

185 

def ensure_1d_with_singleton(to_check, names, func_name):
    """Check that a set of arrays are all vectors with singleton second dimensions.

    1d arrays have a singleton second dimension added, and nd arrays whose
    trailing dimensions are all singletons are reshaped to (N, 1). An error
    is raised for any input with a non-singleton trailing dimension.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of the calling function, used in log and error messages

    Returns
    -------
    out
        Arrays from to_check with '1d with singleton' (N, 1) shape. The
        single array itself is returned if to_check has one element,
        otherwise a list.

    Raises
    ------
    ValueError
        If any input with two or more dimensions has a trailing dimension of
        length other than 1.

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):
        if xx.ndim >= 2:
            if all(size == 1 for size in xx.shape[1:]):
                # nd input where all trailing dims are ones
                msg = "Checking {0} inputs - Trimming trailing singletons from input '{1}' (input size {2})"
                logger.debug(msg.format(func_name, names[idx], xx.shape))
                # reshape rather than squeeze: squeezing a (1, 1) input would
                # also drop the first axis and break the [:, np.newaxis] step.
                out_args[idx] = xx.reshape(xx.shape[0], 1)
            else:
                # nd input where some trailing dims are not one
                msg = "Checking {0} inputs - trailing dims of input '{1}' {2} must be singletons (length=1)"
                msg = msg.format(func_name, names[idx], xx.shape)
                logger.error(msg)
                raise ValueError(msg)
        elif xx.ndim == 1:
            # Vector input - add a dummy dimension
            msg = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(msg.format(func_name, names[idx]))
            out_args[idx] = xx[:, np.newaxis]

    if len(out_args) == 1:
        return out_args[0]
    else:
        return out_args

235 

236 

def ensure_2d(to_check, names, func_name):
    """Give every 1d array in a set a trailing singleton dimension.

    Inputs that are not 1d are passed through unchanged.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of the calling function, used in log messages

    Returns
    -------
    out
        The arrays from to_check, with 1d inputs reshaped to (N, 1). The
        single array itself is returned if to_check has one element,
        otherwise a list.

    """
    def _promote(position, arr):
        # Only 1d inputs are touched; everything else is returned as-is.
        if arr.ndim != 1:
            return arr
        note = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
        logger.debug(note.format(func_name, names[position]))
        return arr[:, np.newaxis]

    checked = [_promote(pos, arr) for pos, arr in enumerate(to_check)]
    return checked[0] if len(checked) == 1 else checked

269 

270 

271# Exceptions & Errors 

272 

class EMDSiftCovergeError(Exception):
    """Exception raised when the sift fails to converge.

    Attributes
    ----------
    message : str
        Explanation of the error.

    """

    def __init__(self, message):
        """Raise error indicating that sift has failed to converge."""
        # Forward the message to Exception so that str(err) and err.args
        # carry it, as for any standard exception.
        super().__init__(message)
        self.message = message
        # NOTE(review): logger.exception is intended for use inside an
        # ``except`` block; outside one it logs at ERROR level with a
        # "NoneType: None" traceback line appended.
        logger.exception(self.message)