Coverage for /Users/Newville/Codes/xraylarch/larch/io/mergegroups.py: 12%

146 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2""" 

3merge groups, interpolating if necessary 

4""" 

5import os 

6import numpy as np 

7from larch import Group 

8from larch.math import interp, interp1d, index_of, remove_dups 

9from larch.utils.logging import getLogger 

10 

11_logger = getLogger("larch.io.mergegroups") 

12 

def merge_groups(grouplist, master=None, xarray='energy', yarray='mu',
                 kind='cubic', trim=True, calc_yerr=True):
    """merge arrays from a list of groups.

    Arguments
    ---------
      grouplist   list of groups to merge
      master      group to use for common x array [None -> 1st group]
      xarray      name of x-array for merge ['energy']
      yarray      name of y-array for merge ['mu']
      kind        interpolation kind ['cubic']
      trim        whether to trim to the shortest energy range [True]
      calc_yerr   whether to use the variance in the input as yerr [True]

    Returns
    --------
     group with x-array and y-array containing merged data.

    """
    if master is None:
        master = grouplist[0]

    # common x grid comes from the master group; duplicate x values are
    # removed so interpolation is well defined
    xout = remove_dups(getattr(master, xarray))
    xmins = [min(xout)]
    xmaxs = [max(xout)]
    yvals = []

    for g in grouplist:
        x = getattr(g, xarray)
        y = getattr(g, yarray)
        yvals.append(interp(x, y, xout, kind=kind))
        xmins.append(min(x))
        xmaxs.append(max(x))

    yvals = np.array(yvals)
    yave = yvals.mean(axis=0)
    ystd = yvals.std(axis=0)

    if trim:
        # trim to the range covered by *all* groups: the overlap is
        # [max(xmins), min(xmaxs)].  The previous code used min(xmins),
        # which kept points below the start of some scans, where interp()
        # has to extrapolate.
        xmin = max(xmins)
        xmax = min(xmaxs)
        ixmin = index_of(xout, xmin)
        ixmax = index_of(xout, xmax)
        xout = xout[ixmin:ixmax]
        yave = yave[ixmin:ixmax]
        ystd = ystd[ixmin:ixmax]

    grp = Group()
    setattr(grp, xarray, xout)
    setattr(grp, yarray, yave)
    setattr(grp, yarray + '_std', ystd)

    if kind == 'cubic':
        y0 = getattr(master, yarray)
        # if the derivative gets much worse, cubic interpolation is likely
        # ringing: redo the merge with linear interpolation
        if max(np.diff(yave)) > 50*max(np.diff(y0)):
            grp = merge_groups(grouplist, master=master, xarray=xarray,
                               yarray=yarray, trim=trim,
                               calc_yerr=calc_yerr, kind='linear')
    return grp

73 

def imin(arr, debug=False):
    """Return the index of the minimum value of *arr*."""
    idx = np.argmin(arr)
    if debug:
        # sanity check: value at the returned index equals the minimum
        _logger.debug("Check: {0} = {1}".format(np.min(arr), arr[idx]))
    return idx

80 

81 

def imax(arr, debug=False):
    """Return the index of the maximum value of *arr*."""
    idx = np.argmax(arr)
    if debug:
        # sanity check: value at the returned index equals the maximum
        _logger.debug("Check: {0} = {1}".format(np.max(arr), arr[idx]))
    return idx

88 

89 

def lists_to_matrix(data, axis=None, **kws):
    """Build a 2D matrix from two parallel lists of 1D arrays.

    Parameters
    ----------
    data : list of two lists of 1D arrays
        ``[[x1, ..., xN], [z1, ..., zN]]``
    axis : None or 1D array, optional
        reference axis used for the interpolation [None -> xdats[0]]
    **kws : optional
        keyword arguments for scipy.interpolate.interp1d()

    Returns
    -------
    axis, outmat : arrays
    """
    assert len(data) == 2, "'data' should be a list of two lists"
    xdats, zdats = data
    assert isinstance(xdats, list), "'xdats' not a list"
    assert isinstance(zdats, list), "'zdats' not a list"
    assert len(xdats) == len(zdats), "lists of data not of the same length"
    assert all(isinstance(z, np.ndarray) for z in zdats), "data in list must be arrays"
    if axis is None:
        axis = xdats[0]
    assert isinstance(axis, np.ndarray), "axis must be array"
    if all(z.size == axis.size for z in zdats):
        # every z already has the axis length -> stack directly
        return axis, np.array(zdats)
    # otherwise resample each (x, z) pair onto the reference axis
    outmat = np.zeros((len(zdats), axis.size))
    for row, (xcur, zcur) in enumerate(zip(xdats, zdats)):
        outmat[row] = interp1d(xcur, zcur, **kws)(axis)
    return axis, outmat

129 

130 

def curves_to_matrix(curves, axis=None, **kws):
    """Stack a list of curves into a 2D data matrix.

    Parameters
    ----------
    curves : list of lists
        each curve has the format ``[x, y, label, info]``:
        [
         [x1, y1, label1, info1],
         ...
         [xN, yN, labelN, infoN]
        ]
    axis : None or 1D array, optional
        reference axis used for the interpolation [None -> curves[0][0]]
    **kws : optional
        keyword arguments for func:`scipy.interpolate.interp1d`

    Returns
    -------
    axis, outmat : arrays
    """
    assert isinstance(curves, list), "'curves' not a list"
    assert all(
        (isinstance(curve, list) and len(curve) == 4) for curve in curves
    ), "curves should be lists of four elements"
    if axis is None:
        axis = curves[0][0]
    assert isinstance(axis, np.ndarray), "axis must be array"
    outmat = np.zeros((len(curves), axis.size))
    for irow, curve in enumerate(curves):
        assert len(curve) == 4, "wrong curve format, should contain four elements"
        x, y, label, info = curve
        try:
            assert isinstance(x, np.ndarray), "not array!"
            assert isinstance(y, np.ndarray), "not array!"
        except AssertionError:
            # leave this row as zeros instead of failing the whole stack
            _logger.error(
                "[curve_to_matrix] Curve %d (%s) not containing arrays -> ADDING ZEROS",
                irow,
                label,
            )
            continue
        if (x.size == axis.size) and (y.size == axis.size):
            # already on the reference axis -> copy as-is
            outmat[irow] = y
        else:
            # resample onto the reference axis
            outmat[irow] = interp1d(x, y, **kws)(axis)
            _logger.debug("[curve_to_matrix] Curve %d (%s) interpolated", irow, label)
    return axis, outmat

183 

184 

def sum_arrays_1d(data, axis=None, **kws):
    """Sum list of 1D arrays or curves by interpolation on a reference axis

    Parameters
    ----------
    data : lists of lists
    axis : None or 1D array, optional
        reference axis used for the interpolation
    data_fmt : str
        define data format
        - "curves" -> :func:`curves_to_matrix`
        - "lists" -> :func:`lists_to_matrix`

    Returns
    -------
    axis, zsum : 1D arrays
    """
    data_fmt = kws.pop("data_fmt", "curves")
    if data_fmt == "curves":
        ax, mat = curves_to_matrix(data, axis=axis, **kws)
    elif data_fmt == "lists":
        ax, mat = lists_to_matrix(data, axis=axis, **kws)
    else:
        raise NameError("'data_fmt' not understood")
    # sum the rows of the matrix -> one 1D array on the common axis
    return ax, np.sum(mat, 0)

208 

209 

def avg_arrays_1d(data, axis=None, weights=None, **kws):
    """Average list of 1D arrays or curves by interpolation on a reference axis

    Parameters
    ----------
    data : lists of lists
    data_fmt : str
        define data format
        - "curves" -> :func:`curves_to_matrix`
        - "lists" -> :func:`lists_to_matrix`
    weights : None or array
        weights for the average

    Returns
    -------
    axis, zavg : 1D arrays
        np.average(zdats)
    """
    data_fmt = kws.pop("data_fmt", "curves")
    # dispatch on data format; unknown formats raise as before
    converters = {"curves": curves_to_matrix, "lists": lists_to_matrix}
    if data_fmt not in converters:
        raise NameError("'data_fmt' not understood")
    ax, mat = converters[data_fmt](data, axis=axis, **kws)
    return ax, np.average(mat, axis=0, weights=weights)

236 

237 

def merge_arrays_1d(data, method="average", axis=None, weights=None, **kws):
    """Merge a list of 1D arrays by interpolation on a reference axis

    Parameters
    ----------
    data : lists of lists
    data_fmt : str
        define data format
        - "curves" -> :func:`curves_to_matrix`
        - "lists" -> :func:`lists_to_matrix`
    axis : None or array 1D, optional
        a reference axis used for the interpolation [None -> xdats[0]]
    method : str, optional
        method used to merge, available methods are:
        - "average" : uses np.average()
        - "sum" : uses np.sum()
    weights : None or array 1D, optional
        used if method == "average"

    Returns
    -------
    axis, zmrg : 1D arrays
        merge(zdats)
    """
    if method == "sum":
        return sum_arrays_1d(data, axis=axis, **kws)
    elif method == "average":
        return avg_arrays_1d(data, axis=axis, weights=weights, **kws)
    else:
        # NOTE(review): NameError is unconventional here (ValueError would be
        # typical) but is kept for backward compatibility with callers
        raise NameError("wrong 'method': %s" % method)

268 

269 

def rebin_piecewise_constant(x1, y1, x2):
    """Rebin histogram values y1 from old bin edges x1 to new edges x2.

    Code taken from: https://github.com/jhykes/rebin/blob/master/rebin.py

    It follows the procedure described in Figure 18.13 (chapter 18.IV.B.
    Spectrum Alignment, page 703) of Knoll [1]

    References
    ----------
    [1] Glenn Knoll, Radiation Detection and Measurement, third edition,
        Wiley, 2000.

    Parameters
    ----------
    - x1 : m+1 array of old bin edges.
    - y1 : m array of old histogram values.
        This is the total number in each bin, not an average.
    - x2 : n+1 array of new bin edges.

    Returns
    -------
    - y2 : n array of rebinned histogram values.
    """
    x1 = np.asarray(x1)
    y1 = np.asarray(y1)
    x2 = np.asarray(x2)

    # position of each new edge, measured in fractional old-bin index units
    pos = np.interp(x2, x1, np.arange(len(x1)))

    # cumulative counts with a leading zero, so csum[j] - csum[i] is the
    # total content of old bins i..j-1
    csum = np.r_[[0], np.cumsum(y1)]

    lo = pos[:-1]   # lower edges of the new bins, in old-bin units
    hi = pos[1:]    # upper edges of the new bins, in old-bin units

    # contribution from old bins fully contained inside a new bin
    # (not counting the fractional start and end pieces)
    spans_whole = np.floor(hi) - np.ceil(lo) >= 1.0
    start = csum[np.ceil(lo).astype(int)]
    finish = csum[np.floor(hi).astype(int)]
    y2 = np.where(spans_whole, finish - start, 0.0)

    # old-bin index containing each new edge (clipped to valid range)
    cell = np.clip(np.floor(pos).astype(int), 0, len(y1) - 1)

    # fractional contribution when both edges of a new bin fall inside
    # the same old bin
    same_cell = np.floor(hi) == np.floor(lo)
    y2 += np.where(same_cell, (hi - lo) * y1[cell[:-1]], 0.0)

    # fractional contribution when the two edges fall in different old
    # bins: a partial piece at each end
    split_cell = np.floor(hi) > np.floor(lo)
    part = (np.ceil(lo) - lo) * y1[cell[:-1]]
    part += (hi - np.floor(hi)) * y1[cell[1:]]
    y2 += np.where(split_cell, part, 0.0)

    return y2

334 

335 

def reject_outliers(data, m=5.189, return_ma=False):
    """Reject outliers

    Modified from: https://stackoverflow.com/questions/11686720/is-there-a-numpy-builtin-to-reject-outliers-from-a-list
    See also: https://www.itl.nist.gov/div898/handbook/eda/section3/eda35h.htm
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    # score each point by its absolute deviation from the median, scaled by
    # the median absolute deviation (guarding against a zero MAD)
    dev = np.abs(data - np.median(data))
    med_dev = np.median(dev)
    score = dev / (med_dev if med_dev else 1.0)
    # NOTE(review): points with score exactly == m are dropped by the filter
    # branch but left unmasked in the masked-array branch — confirm this
    # boundary asymmetry is intended
    if return_ma:
        return np.ma.masked_array(data=data, mask=(score > m))
    return data[score < m]