Coverage for emd/support.py: 59%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.11, created at 2025-03-08 15:44 +0000

1#!/usr/bin/python 

2 

3# vim: set expandtab ts=4 sw=4: 

4 

5""" 

6Helper functions for interacting with an EMD install and ensuring array sizes. 

7 

8Main Routines: 

9 get_install_dir 

10 get_installed_version 

11 run_tests 

12 

13Ensurance Routines: 

14 ensure_equal_dims 

15 ensure_vector 

16 ensure_1d_with_singleton 

17 ensure_2d 

18 

19Errors: 

20 EMDSiftCovergeError 

21 

22""" 

23 

24import logging 

25import os 

26import pathlib 

27 

28import numpy as np 

29 

30# Housekeeping for logging 

31logger = logging.getLogger(__name__) 

32 

33 

def get_install_dir():
    """Return the directory holding the currently imported emd package.

    Returns
    -------
    str
        Fully resolved path of the directory that contains this module.

    """
    module_path = pathlib.Path(__file__)
    package_dir = module_path.parent.resolve()
    return str(package_dir)

38 

39 

def get_installed_version():
    """Return the version of the currently installed emd distribution.

    The version is read from the installed package metadata for the
    distribution named 'emd'. If a user has made local changes, this version
    may not be exactly the same as the online package.

    Returns
    -------
    str
        Version string reported by ``importlib.metadata.version('emd')``.

    """
    # Imported here, as in the rest of this module, to keep it local to its
    # single use.
    from importlib import metadata

    return metadata.version('emd')

52 

53 

def create_readthedocs_requirements():
    """Write 'requirements_rtd.txt' from the dependencies in 'pyproject.toml'.

    Both files live in the parent of the emd install directory. The output
    lists every core dependency, then every optional dependency except those
    in the 'full' group, one per line, followed by a final '-e .' line.

    """
    import tomllib

    project_root = pathlib.Path(get_install_dir()).parent

    with open(project_root / 'pyproject.toml', "rb") as toml_file:
        config = tomllib.load(toml_file)

    with open(project_root / 'requirements_rtd.txt', 'w') as req_file:
        # Core dependencies first
        req_file.writelines(dep + '\n' for dep in config['project']['dependencies'])

        # Then every optional group apart from the 'full' group
        for group, extras in config['project']['optional-dependencies'].items():
            if group != 'full':
                req_file.writelines(dep + '\n' for dep in extras)

        req_file.write('-e .\n')

68 

69 

def run_tests():
    """Run the EMD package tests directly from python.

    Useful for people without a dev-install to run tests perhaps.
    https://docs.pytest.org/en/latest/usage.html#calling-pytest-from-python-code

    Nothing is run if the install directory does not contain a 'tests'
    directory. The outcome is reported through the module logger; nothing is
    returned.

    """
    import pytest
    inst_dir = get_install_dir()

    if not os.path.exists(os.path.join(inst_dir, 'tests')):
        logger.info('Test directory not found in: {0}'.format(inst_dir))
        logger.info('(this is normal for PyPI/pip EMD installs)')
        return

    logger.info('Running EMD package tests from: {0}'.format(inst_dir))
    out = pytest.main(['-x', inst_dir])

    # pytest.main is documented to return either an ExitCode enum or a plain
    # int; int() handles both whereas `.value` only exists on the enum.
    if int(out) != 0:
        logger.warning('EMD package tests FAILED - EMD may not behave as expected')
    else:
        logger.info('EMD package tests passed')

91 

92 

93# Parallel processing 

94 

def run_parallel(pfunc, args, nprocesses=1):
    """Run a set of function calls in serial or in parallel.

    Parameters
    ----------
    pfunc : callable
        Function to call once per entry in args
    args : iterable of sequences
        Each entry is unpacked as the positional arguments of one call
    nprocesses : int
        Calls are dispatched through joblib when this is greater than 1,
        otherwise they are run one after another in this process.

    Returns
    -------
    res
        Results of the calls. The serial path returns a list in the order of
        args.

    """
    if nprocesses > 1:
        # Imported lazily so that serial execution does not need joblib
        from joblib import Parallel, delayed

        with Parallel(n_jobs=nprocesses) as parallel:
            res = parallel(delayed(pfunc)(*aa) for aa in args)
    else:
        res = [pfunc(*aa) for aa in args]

    return res

106 

107 

108# Ensurance Department 

109 

110 

def ensure_equal_dims(to_check, names, func_name, dim=None):
    """Check that a set of arrays all have the same dimension.

    Raises an error with details if not.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check for equal dimensions
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_equal_dims
    dim : int
        Integer index of specific axes to ensure shape of, default is to
        compare every axis of the first array in to_check. Inputs with more
        dimensions than the first array are only compared on those axes.

    Raises
    ------
    ValueError
        If any of the inputs in to_check have differing shapes, or if an
        input does not have one of the axes being compared

    """
    arrays = list(to_check)
    if len(arrays) == 0:
        # Nothing to compare
        return

    if dim is None:
        axes = list(range(arrays[0].ndim))
    else:
        axes = [dim]

    def _sizes(arr):
        """Return lengths of arr along axes, or None if an axis is missing."""
        try:
            return tuple(arr.shape[ax] for ax in axes)
        except IndexError:
            return None

    selected = [_sizes(arr) for arr in arrays]
    reference = selected[0]

    # A missing axis (None) counts as a mismatch rather than surfacing as an
    # IndexError from the shape lookup.
    if reference is not None and all(sizes == reference for sizes in selected):
        return

    msg = 'Checking {0} inputs - Input dim mismatch'.format(func_name)
    logger.error(msg)
    msg = "Mismatch between inputs: " + "".join(
        "'{0}': {1}, ".format(name, arr.shape) for name, arr in zip(names, arrays)
    )
    logger.error(msg)
    raise ValueError(msg)

149 

150 

def ensure_vector(to_check, names, func_name):
    """Check that a set of arrays are all vectors with only 1-dimension.

    Arrays with singleton second dimensions will be trimmed and an error will
    be raised for non-singleton 2d or greater than 2d inputs.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_vector

    Returns
    -------
    out
        Arrays in to_check with 1d shape. A single array is returned directly
        rather than in a list when to_check contains exactly one input.

    Raises
    ------
    ValueError
        If any input is a non-singleton 2d array or has more than 2 dimensions

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):

        # Rank is checked first - otherwise inputs with more than 2 dims are
        # caught by the 2d branches and may be silently 'trimmed' to 2d.
        if xx.ndim > 2:
            msg = "Checking {0} inputs - Shape of input '{1}' {2} must be a vector."
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)
        elif xx.ndim == 2 and xx.shape[1] != 1:
            msg = "Checking {0} inputs - Input '{1}' {2} must be a vector or 2d with singleton second dim"
            msg = msg.format(func_name, names[idx], xx.shape)
            logger.error(msg)
            raise ValueError(msg)
        elif xx.ndim == 2:
            # 2d with a singleton second dim - drop the second dim
            msg = "Checking {0} inputs - trimming singleton from input '{1}'"
            msg = msg.format(func_name, names[idx])
            out_args[idx] = xx[:, 0]
            logger.warning(msg)

    if len(out_args) == 1:
        return out_args[0]
    else:
        return out_args

200 

201 

def ensure_1d_with_singleton(to_check, names, func_name):
    """Check that a set of arrays are all vectors with singleton second dimensions.

    1d arrays will have a singleton second dimension added, trailing
    singleton dimensions are collapsed into one and an error will be raised
    if any trailing dimension has a length other than 1.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_1d_with_singleton

    Returns
    -------
    out
        Arrays in to_check with '1d with singleton' shape. A single array is
        returned directly rather than in a list when to_check contains
        exactly one input.

    Raises
    ------
    ValueError
        If any input has a trailing dimension whose length is not 1

    """
    out_args = list(to_check)
    for idx, xx in enumerate(to_check):

        if xx.ndim >= 2:
            if all(size == 1 for size in xx.shape[1:]):
                # nd input where all trailing are ones
                msg = "Checking {0} inputs - Trimming trailing singletons from input '{1}' (input size {2})"
                logger.debug(msg.format(func_name, names[idx], xx.shape))
                # Reshape rather than squeeze - squeeze would also remove a
                # leading dimension of length 1 and break the indexing.
                out_args[idx] = xx.reshape(xx.shape[0], 1)
            else:
                # nd input where some trailing are not one
                msg = "Checking {0} inputs - trailing dims of input '{1}' {2} must be singletons (length=1)"
                msg = msg.format(func_name, names[idx], xx.shape)
                logger.error(msg)
                raise ValueError(msg)
        elif xx.ndim == 1:
            # Vector input - add a dummy dimension
            msg = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(msg.format(func_name, names[idx]))
            out_args[idx] = xx[:, np.newaxis]

    if len(out_args) == 1:
        return out_args[0]
    else:
        return out_args

251 

252 

def ensure_2d(to_check, names, func_name):
    """Check that a set of arrays are all arrays with 2 dimensions.

    1d arrays will have a singleton second dimension added. Inputs with any
    other number of dimensions are passed through unchanged.

    Parameters
    ----------
    to_check : list of arrays
        List of arrays to check
    names : list
        List of variable names for arrays in to_check
    func_name : str
        Name of function calling ensure_2d

    Returns
    -------
    out
        Arrays in to_check, with a dummy second dimension added to any 1d
        input. A single array is returned directly rather than in a list
        when to_check contains exactly one input.

    """
    out_args = list(to_check)
    for idx, xx in enumerate(out_args):

        if xx.ndim == 1:
            msg = "Checking {0} inputs - Adding dummy dimension to input '{1}'"
            logger.debug(msg.format(func_name, names[idx]))
            out_args[idx] = xx[:, np.newaxis]

    if len(out_args) == 1:
        return out_args[0]
    else:
        return out_args

285 

286 

287# Exceptions & Errors 

288 

289class EMDSiftCovergeError(Exception): 

290 """Exception raised for errors in the input. 

291 

292 Attributes 

293 ---------- 

294 expression -- input expression in which the error occurred 

295 message -- explanation of the error 

296 

297 """ 

298 

299 def __init__(self, message): 

300 """Raise error indicating that sift has failed to converge.""" 

301 self.message = message 

302 logger.exception(self.message)