Coverage for /Users/Newville/Codes/xraylarch/larch/io/rixsdata.py: 0%

192 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3""" 

4RIXS data object 

5================ 

6""" 

7import numpy as np 

8import copy 

9import time 

10from itertools import cycle 

11from scipy.interpolate import griddata 

12from silx.io.dictdump import dicttoh5, h5todict 

13 

14from larch.math.gridxyz import gridxyz 

15from larch.xafs.xafsutils import guess_energy_units 

16from larch.utils.logging import getLogger 

17 

#: module logger, used by RixsData instances created without their own logger
_logger = getLogger(__name__)
_logger.setLevel("INFO")

20 

21 

22def _tostr(arr): 

23 """Numpy array to string""" 

24 try: 

25 return np.array_str(arr) 

26 except Exception: 

27 return arr 

28 

29 

30def _restore_from_array(dictin): 

31 """restore str/float from a nested dictionary of numpy.ndarray (when using silx.io.dictdump.h5todict) 

32 

33 Note: discussed here https://github.com/silx-kit/silx/issues/3633 

34 """ 

35 for k, v in dictin.items(): 

36 if isinstance(v, dict): 

37 _restore_from_array(v) 

38 else: 

39 if isinstance(v[()], np.str_): 

40 dictin[k] = np.array_str(v) 

41 if isinstance(v[()], (np.float64, np.float32)): 

42 dictin[k] = copy.deepcopy(v.item()) 

43 

44 

class CycleColors:
    """Utility for setting the line colors of the RIXS map cuts

    Colors are handed out in palette order; once the palette is exhausted
    it starts again from the first color.
    """

    DEFAULT_PALETTE = (
        "#1F77B4",
        "#AEC7E8",
        "#FF7F0E",
        "#FFBB78",
        "#2CA02C",
        "#98DF8A",
        "#D62728",
        "#FF9896",
        "#9467BD",
        "#C5B0D5",
        "#8C564B",
        "#C49C94",
        "#E377C2",
        "#F7B6D2",
        "#7F7F7F",
        "#C7C7C7",
        "#BCBD22",
        "#DBDB8D",
        "#17BECF",
        "#9EDAE5",
    )

    def __init__(self, palette=None) -> None:
        """
        Parameters
        ----------
        palette : sequence of str, optional [None]
            colors to cycle through; ``DEFAULT_PALETTE`` when None

        Raises
        ------
        ValueError
            if ``palette`` is empty (``next`` on an empty cycle would
            otherwise raise a bare StopIteration later)
        """
        if palette is None:
            palette = self.DEFAULT_PALETTE
        palette = tuple(palette)
        if not palette:
            raise ValueError("palette must contain at least one color")
        self.colors = cycle(palette)

    def get_color(self) -> str:
        """Return the next color of the palette."""
        return next(self.colors)

76 

77 

class RixsData(object):
    """RIXS plane object

    Holds the scattered XYZ columns of a RIXS measurement (``_x``: energy in,
    ``_y``: energy out, ``_z``: signal), the maps gridded from them
    (``rixs_map`` vs. emission energy, ``rixs_et_map`` vs. energy transfer
    ``_x - _y``) and the line cuts taken through those maps.
    """

    def __init__(self, name=None, logger=None):
        """Constructor

        Parameters
        ----------
        name : str, optional [None]
            name of the object; defaults to ``RixsData_<hex id>``
        logger : logger, optional [None]
            logger to use; defaults to the module logger
        """

        self.__name__ = "RixsData_{0}".format(hex(id(self)))
        self.name = name or self.__name__
        self.label = self.name

        self._logger = logger or _logger
        self._palette = CycleColors()
        #: attribute names that save_to_h5 never writes
        self._no_save = ("_logger", "_palette")

        self.sample_name = "UnknownSample"
        #: source data file (set by load_from_ascii); default stem for save_to_h5
        self.filename = None
        self.counter_all, self.counter_signal, self.counter_norm = None, None, None
        self._x, self._y, self._z = None, None, None
        self.ene_in, self.ene_out, self.rixs_map = None, None, None
        self.ene_et, self.rixs_et_map = None, None
        self.ene_grid, self.ene_unit, self.grid_method = None, None, None
        self.line_cuts = {}
        self.datatype = "rixs"

    def set_energy_unit(self, unit=None):
        """set the energy unit to eV

        Parameters
        ----------
        unit : str, optional [None]
            energy unit of the ``_x``/``_y`` columns; when None and no unit
            is known yet, it is guessed from ``_x``

        If the unit is "keV", ``_x`` and ``_y`` are replaced by new arrays in
        eV (arrays shared with the caller are not modified), the grid step is
        reset to 0.1 and the maps are regridded via ``reset``.

        Raises
        ------
        AssertionError
            if the resulting unit is not "eV"
        """
        if unit is not None:
            self.ene_unit = unit
        if self.ene_unit is None:
            self.ene_unit = guess_energy_units(self._x)
        if self.ene_unit == "keV":
            self._logger.info(f"Energy unit is {self.ene_unit} -> converting to eV")
            # rebind rather than `*=`: in-place scaling would also rescale
            # arrays the caller handed in (e.g. through load_from_dict)
            self._x = np.asarray(self._x) * 1000
            self._y = np.asarray(self._y) * 1000
            self.ene_grid = 0.1
            self.reset()
            self.ene_unit = "eV"
        assert self.ene_unit == "eV", f"energy unit is {self.ene_unit} -> must be eV"

    def load_from_dict(self, rxdict):
        """Load RIXS data from a dictionary

        Parameters
        ----------
        rxdict : dict
            Minimal required structure
            {
             'writer_version': '1.5.0',
             'sample_name': str,
             '_x': 1D array, #: energy in
             '_y': 1D array, #: energy out
             '_z': 1D array, #: signal
            }

        Return
        ------
        None, set attributes: self.* (every key of ``rxdict`` becomes an
        attribute), then the energy unit is checked and the maps are regridded
        """
        self.__dict__.update(rxdict)
        self.set_energy_unit()
        self.grid_rixs_from_col()

    def load_from_h5(self, filename):
        """Load RIXS from HDF5 file

        Only files whose ``writer_version`` contains "1.5" are accepted;
        otherwise an error/warning is logged and nothing is loaded.
        """
        rxdict = h5todict(filename)
        _restore_from_array(rxdict)
        if "writer_version" not in rxdict:
            self._logger.error("Key 'writer_version' not found")
            return
        if "1.5" not in _tostr(rxdict["writer_version"]):
            self._logger.warning("Data format not understood")
            return
        self.load_from_dict(rxdict)
        self._logger.info("RIXS map loaded from file: {0}".format(filename))

    def load_from_ascii(self, filename, **kws):
        """load data from a 3 columns ASCII file assuming the format:

        e_in(eV), e_out(eV), signal

        Errors while reading are logged and the object is left unchanged.
        Extra keyword arguments are accepted but not used.
        """

        try:
            dat = np.loadtxt(filename)
            self.filename = filename
            self._logger.info("Loaded {0}".format(filename))
        except Exception:
            self._logger.error("Cannot load from {0}".format(filename))
            return

        self._x = dat[:, 0]
        self._y = dat[:, 1]
        self._z = dat[:, 2]

        self.set_energy_unit()
        self.reset()

    def save_to_h5(self, filename=None):
        """Dump dictionary representation to HDF5 file

        Parameters
        ----------
        filename : str, optional [None]
            output file; when None, the extension of ``self.filename`` is
            replaced by ".h5"

        Raises
        ------
        ValueError
            if no filename is given and no source filename is known
        """
        import os

        if filename is None:
            source = getattr(self, "filename", None)
            if not source:
                raise ValueError(
                    "no output filename given and no source filename known"
                )
            # splitext only strips the final extension, so dots in directory
            # names (e.g. "./data/scan.dat") are preserved
            filename = f"{os.path.splitext(source)[0]}.h5"
        # filter before copying: the excluded attributes (logger, color cycle)
        # must not be deep-copied at all
        no_save = set(self._no_save)
        save_dict = {
            key: copy.deepcopy(value)
            for key, value in self.__dict__.items()
            if key not in no_save
        }
        dicttoh5(save_dict, filename, update_mode="replace")
        self._logger.info(f"{self.name} saved to {filename}")

    def crop(self, crop_area, et=None):
        """Crop the plane in a given range

        Parameters
        ----------

        crop_area : tuple
            (x1, y1, x2, y2) : floats
            x1 < x2 (ene_in)
            y1 < y2 (if et=False: ene_out, else: ene_et)

        et: bool, optional [None]
            if True: y1, y2 are given in energy transfer
            if None: energy transfer is assumed when y2 < max(ene_et)

        Return
        ------
        None, regrids the cropped area from the raw XYZ columns and sets
        ``ene_in``, ``ene_out``, ``ene_et``, ``rixs_map``, ``rixs_et_map``
        and ``label``
        """
        self._crop_area = crop_area
        x1, y1, x2, y2 = crop_area
        assert x1 < x2, "wrong crop area, x1 >= x2"
        assert y1 < y2, "wrong crop area, y1 >= y2"

        if et is None:
            if y2 < np.max(self.ene_et):
                self._logger.debug("crop in energy transfer")
                et = True
            else:
                self._logger.debug("crop in emission energy")
                et = False

        _xystep = self.ene_grid or 0.1
        _method = self.grid_method or "linear"

        _nxpts = int((x2 - x1) / _xystep)
        _xcrop = np.linspace(x1, x2, num=_nxpts)

        if et:
            _etmin = y1
            _etmax = y2
            _ymin = x1 - _etmax
            _ymax = x2 - _etmin
            self._logger.debug(f"-> emission range: {_ymin:.2f}:{_ymax:.2f}")
        else:
            _ymin = y1
            _ymax = y2
            # NOTE(review): _etmin > _etmax whenever x2 - x1 > y2 - y1, which
            # gives np.linspace a negative point count (ValueError) below;
            # confirm this narrower range (vs. x1 - y2 .. x2 - y1) is intended
            _etmin = x2 - _ymax
            _etmax = x1 - _ymin
            self._logger.debug(f"-> et range: {_etmin:.2f}:{_etmax:.2f}")

        _netpts = int((_etmax - _etmin) / _xystep)
        _nypts = int((_ymax - _ymin) / _xystep)
        _etcrop = np.linspace(_etmin, _etmax, num=_netpts)
        _ycrop = np.linspace(_ymin, _ymax, num=_nypts)

        _xx, _yy = np.meshgrid(_xcrop, _ycrop)
        _exx, _et = np.meshgrid(_xcrop, _etcrop)
        self._logger.info("Gridding data...")
        _zzcrop = griddata((self._x, self._y), self._z, (_xx, _yy), method=_method)
        _ezzcrop = griddata(
            (self._x, self._x - self._y), self._z, (_exx, _et), method=_method
        )

        self.ene_in = _xcrop
        self.ene_out = _ycrop
        self.ene_et = _etcrop
        self.rixs_map = _zzcrop
        self.rixs_et_map = _ezzcrop
        self.label = f"{self.name} [{self._crop_area}]"

    def reset(self, **grid_kws):
        """resets to initial data

        Regrids the maps from the raw columns (``grid_kws`` are passed to
        ``grid_rixs_from_col``), drops all line cuts, restores the label and
        restarts the color palette.
        """
        self._logger.info("resetting to initial data (grid RIXS plane and line cuts)")
        self.grid_rixs_from_col(**grid_kws)
        self.line_cuts = {}
        self.label = self.name
        self._palette = CycleColors()

    def grid_rixs_from_col(self, ene_grid=None, grid_method=None):
        """Grid RIXS map from XYZ columns

        Parameters
        ----------
        ene_grid : float, optional [None]
            grid step; stored when given, 0.1 used when none is known
        grid_method : str, optional [None]
            interpolation method; stored when given, "linear" used when none
            is known
        """
        if ene_grid is not None:
            self.ene_grid = ene_grid
        if grid_method is not None:
            self.grid_method = grid_method
        _xystep = self.ene_grid or 0.1
        _method = self.grid_method or "linear"
        self.ene_in, self.ene_out, self.rixs_map = gridxyz(
            self._x, self._y, self._z, xystep=_xystep, method=_method
        )
        # energy transfer column: incident minus emission energy
        self._et = self._x - self._y
        _, self.ene_et, self.rixs_et_map = gridxyz(
            self._x, self._et, self._z, xystep=_xystep, method=_method
        )
        self.ene_grid = _xystep

    def cut(self, energy=None, mode="CEE", label=None):
        """cut the RIXS plane at a given energy

        Parameters
        ----------
        energy : float
            energy of the cut (the nearest grid value is used)

        mode : str
            defines the way to cut the plane (case-insensitive):
            - "CEE" (constant emission energy)
            - "CIE" (constant incident energy)
            - "CET" (constant energy transfer)

        label : str, optional [None]
            custom label, if None: label = 'mode_enecut'

        Return
        ------
        None -> adds dict(x:array, y:array, info:dict) to self.line_cuts[cut_key]:dict,
        where cut_key = 'mode_<10*enecut>' and

            info = {label: str, #: 'mode_enecut'
                    mode: str, #: as input, upper-cased
                    enecut: float, #: energy cut given from the initial interpolation
                    datatype: str, #: 'xas' or 'xes'
                    color: str, #: color from a common palette
                    timestamp: str, #: local time as YYYY-MM-DD_HHMM
                    }

        An unknown mode is logged as an error and nothing is added.
        """
        assert energy is not None, "The energy of the cut must be given"

        mode = mode.upper()

        if mode == "CEE":
            xc = self.ene_in
            iy = np.abs(self.ene_out - energy).argmin()
            enecut = self.ene_out[iy]
            yc = self.rixs_map[iy, :]
            datatype = "xas"
        elif mode == "CIE":
            xc = self.ene_out
            iy = np.abs(self.ene_in - energy).argmin()
            enecut = self.ene_in[iy]
            yc = self.rixs_map[:, iy]
            datatype = "xes"
        elif mode == "CET":
            xc = self.ene_in
            iy = np.abs(self.ene_et - energy).argmin()
            enecut = self.ene_et[iy]
            yc = self.rixs_et_map[iy, :]
            datatype = "xas"
        else:
            self._logger.error(f"wrong mode: {mode}")
            return

        if label is None:
            label = f"{mode}_{enecut:.1f}"

        info = dict(
            label=label,
            mode=mode,
            enecut=enecut,
            datatype=datatype,
            color=self._palette.get_color(),
            timestamp=time.strftime("%Y-%m-%d_%H%M", time.localtime()),
        )

        cut_key = f"{mode}_{enecut*10:.0f}"
        self.line_cuts[cut_key] = dict(x=xc, y=yc, info=info)
        self._logger.info(f"added RIXS {mode} cut: '{label}'")

    def norm(self):
        """Simple map normalization to max-min

        Divides ``rixs_map`` by its (max - min) range, ignoring NaNs; the
        minimum is not subtracted and ``rixs_et_map`` is left unchanged.
        """
        self.rixs_map = self.rixs_map / (
            np.nanmax(self.rixs_map) - np.nanmin(self.rixs_map)
        )
        self._logger.info("rixs map normalized to max-min")

363 

if __name__ == "__main__":
    # no command-line behaviour: the module only provides definitions
    pass