Coverage for /Users/Newville/Codes/xraylarch/larch/io/rixsdata.py: 0%
192 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4RIXS data object
5================
6"""
7import numpy as np
8import copy
9import time
10from itertools import cycle
11from scipy.interpolate import griddata
12from silx.io.dictdump import dicttoh5, h5todict
14from larch.math.gridxyz import gridxyz
15from larch.xafs.xafsutils import guess_energy_units
16from larch.utils.logging import getLogger
#: module logger; RixsData instances created without an explicit logger use it
_logger = getLogger(__name__)
_logger.setLevel("INFO")
22def _tostr(arr):
23 """Numpy array to string"""
24 try:
25 return np.array_str(arr)
26 except Exception:
27 return arr
30def _restore_from_array(dictin):
31 """restore str/float from a nested dictionary of numpy.ndarray (when using silx.io.dictdump.h5todict)
33 Note: discussed here https://github.com/silx-kit/silx/issues/3633
34 """
35 for k, v in dictin.items():
36 if isinstance(v, dict):
37 _restore_from_array(v)
38 else:
39 if isinstance(v[()], np.str_):
40 dictin[k] = np.array_str(v)
41 if isinstance(v[()], (np.float64, np.float32)):
42 dictin[k] = copy.deepcopy(v.item())
class CycleColors:
    """Utility for setting the line colors of the RIXS map cuts.

    Colors are handed out one at a time from a palette. The sequence wraps
    around (``itertools.cycle``) once the palette is exhausted.
    """

    DEFAULT_PALETTE = (
        "#1F77B4",
        "#AEC7E8",
        "#FF7F0E",
        "#FFBB78",
        "#2CA02C",
        "#98DF8A",
        "#D62728",
        "#FF9896",
        "#9467BD",
        "#C5B0D5",
        "#8C564B",
        "#C49C94",
        "#E377C2",
        "#F7B6D2",
        "#7F7F7F",
        "#C7C7C7",
        "#BCBD22",
        "#DBDB8D",
        "#17BECF",
        "#9EDAE5",
    )

    def __init__(self, palette=None) -> None:
        """Create a color cycler.

        Parameters
        ----------
        palette : sequence of str, optional [None]
            colors to cycle through; if None, ``DEFAULT_PALETTE`` is used

        Raises
        ------
        ValueError
            if the palette is empty (``next()`` on an empty cycle would
            otherwise raise ``StopIteration`` at the first ``get_color()``)
        """
        if palette is None:
            palette = self.DEFAULT_PALETTE
        # freeze the colors so later changes to the caller's list have no effect
        palette = tuple(palette)
        if not palette:
            raise ValueError("palette must contain at least one color")
        self.colors = cycle(palette)

    def get_color(self) -> str:
        """Return the next color of the palette, wrapping around at the end."""
        return next(self.colors)
class RixsData(object):
    """RIXS plane object

    Holds the raw columns ``_x`` (incident energy), ``_y`` (emission energy)
    and ``_z`` (signal). From these it builds gridded maps both in emission
    energy (``rixs_map``) and in energy transfer (``rixs_et_map``), and it
    stores any line cuts taken from them (``line_cuts``).
    """

    def __init__(self, name=None, logger=None):
        """Constructor

        Parameters
        ----------
        name : str, optional [None]
            name of the object; if None, a name based on ``id(self)`` is used
        logger : optional [None]
            object with ``debug``/``info``/``warning``/``error`` methods;
            if None, the module logger is used
        """
        self.__name__ = "RixsData_{0}".format(hex(id(self)))
        self.name = name or self.__name__
        self.label = self.name

        self._logger = logger or _logger
        self._palette = CycleColors()
        # attributes skipped by save_to_h5
        self._no_save = ("_logger", "_palette")

        self.sample_name = "UnknownSample"
        self.counter_all, self.counter_signal, self.counter_norm = None, None, None
        # raw columns: incident energy, emission energy, signal
        self._x, self._y, self._z = None, None, None
        self.ene_in, self.ene_out, self.rixs_map = None, None, None
        self.ene_et, self.rixs_et_map = None, None
        self.ene_grid, self.ene_unit, self.grid_method = None, None, None
        self.line_cuts = {}
        self.datatype = "rixs"

    def set_energy_unit(self, unit=None):
        """set the energy unit to eV

        Parameters
        ----------
        unit : str, optional [None]
            energy unit of the raw columns. If None and ``self.ene_unit``
            is not set yet, the unit is guessed from ``self._x``.

        If the unit is "keV", the following happens:

        - ``_x``/``_y`` are converted to eV. New arrays are created, so
          arrays shared with a caller (e.g. the dictionary given to
          :meth:`load_from_dict`) are not modified.
        - the grid step is set to 0.1.
        - the maps are re-gridded with :meth:`reset`.

        Raises
        ------
        AssertionError
            if the resulting energy unit is not "eV"
        """
        if unit is not None:
            self.ene_unit = unit
        if self.ene_unit is None:
            self.ene_unit = guess_energy_units(self._x)
        if self.ene_unit == "keV":
            self._logger.info(f"Energy unit is {self.ene_unit} -> converting to eV")
            # not in place (``*=``): the arrays may be shared with the caller
            self._x = np.asarray(self._x) * 1000
            self._y = np.asarray(self._y) * 1000
            self.ene_grid = 0.1
            self.reset()
            self.ene_unit = "eV"
        assert (
            self.ene_unit == "eV"
        ), f"energy unit is {self.ene_unit} -> must be eV"

    def load_from_dict(self, rxdict):
        """Load RIXS data from a dictionary

        Parameters
        ----------
        rxdict : dict
            Minimal required structure
            {
             'writer_version': '1.5.0',
             'sample_name': str,
             '_x': 1D array, #: energy in
             '_y': 1D array, #: energy out
             '_z': 1D array, #: signal
            }

        Return
        ------
        None, set attributes: self.*

        All keys of ``rxdict`` become attributes. The energy unit is then
        checked (converted to eV if needed) and the maps are re-gridded.
        """
        self.__dict__.update(rxdict)
        self.set_energy_unit()
        self.grid_rixs_from_col()

    def load_from_h5(self, filename):
        """Load RIXS from HDF5 file

        Nothing is loaded in these cases:

        - the file has no 'writer_version' key: an error is logged;
        - 'writer_version' does not contain "1.5": a warning is logged.
        """
        rxdict = h5todict(filename)
        _restore_from_array(rxdict)
        if "writer_version" not in rxdict:
            self._logger.error("Key 'writer_version' not found")
            return
        if "1.5" not in _tostr(rxdict["writer_version"]):
            self._logger.warning("Data format not understood")
            return
        self.load_from_dict(rxdict)
        self._logger.info("RIXS map loaded from file: {0}".format(filename))

    def load_from_ascii(self, filename, **kws):
        """load data from a 3 columns ASCII file assuming the format:

        e_in(eV), e_out(eV), signal

        If the file cannot be read, an error is logged and nothing else
        changes. Extra keyword arguments are accepted but currently unused.
        """
        try:
            dat = np.loadtxt(filename)
            self.filename = filename
            self._logger.info("Loaded {0}".format(filename))
        except Exception:
            # deliberate best effort: report and leave the object untouched
            self._logger.error("Cannot load from {0}".format(filename))
            return

        self._x = dat[:, 0]
        self._y = dat[:, 1]
        self._z = dat[:, 2]

        self.set_energy_unit()
        self.reset()

    def save_to_h5(self, filename=None):
        """Dump dictionary representation to HDF5 file

        Parameters
        ----------
        filename : str, optional [None]
            output file name. If None, ``self.filename`` is used with its
            extension replaced by ".h5".

        Attributes named in ``self._no_save`` are not written.

        Raises
        ------
        ValueError
            if no filename is given and ``self.filename`` is not set
        """
        import os

        if filename is None:
            source = getattr(self, "filename", None)
            if not source:
                raise ValueError("no filename given and self.filename is not set")
            # splitext strips only the last extension; dots in directories
            # (e.g. "./scan.dat") or in the base name are kept
            filename = f"{os.path.splitext(source)[0]}.h5"
        # drop the helpers *before* deep-copying: copying the itertools.cycle
        # inside the palette relies on itertools copy/pickle support, which
        # is deprecated since Python 3.12
        save_dict = copy.deepcopy(
            {key: val for key, val in self.__dict__.items() if key not in self._no_save}
        )
        dicttoh5(save_dict, filename, update_mode="replace")
        self._logger.info(f"{self.name} saved to {filename}")

    def crop(self, crop_area, et=None):
        """Crop the plane in a given range

        Parameters
        ----------

        crop_area : tuple
            (x1, y1, x2, y2) : floats
            x1 < x2 (ene_in)
            y1 < y2 (if et=False: ene_out, else: ene_et)

        et : bool, optional [None]
            if True: y1, y2 are given in energy transfer.
            If None: energy transfer is assumed when y2 < max(self.ene_et).

        Both maps are re-gridded from the raw columns over the cropped
        range. This updates ``ene_in``, ``ene_out``, ``ene_et``,
        ``rixs_map``, ``rixs_et_map`` and ``label``.
        """
        self._crop_area = crop_area
        x1, y1, x2, y2 = crop_area
        assert x1 < x2, "wrong crop area, x1 >= x2"
        assert y1 < y2, "wrong crop area, y1 >= y2"

        if et is None:
            # heuristic: a y2 below the largest energy transfer of the current
            # grid is interpreted as an energy-transfer value
            if y2 < np.max(self.ene_et):
                self._logger.debug("crop in energy transfer")
                et = True
            else:
                self._logger.debug("crop in emission energy")
                et = False

        _xystep = self.ene_grid or 0.1
        _method = self.grid_method or "linear"

        # number of points is int(span / step), so the actual spacing of
        # linspace is slightly larger than the grid step
        _nxpts = int((x2 - x1) / _xystep)
        _xcrop = np.linspace(x1, x2, num=_nxpts)

        if et:
            _etmin = y1
            _etmax = y2
            # emission range covering every (ene_in, et) pair of the crop
            # area, with emission = ene_in - et
            _ymin = x1 - _etmax
            _ymax = x2 - _etmin
            self._logger.debug(f"-> emission range: {_ymin:.2f}:{_ymax:.2f}")
        else:
            _ymin = y1
            _ymax = y2
            # NOTE(review): this is the range of energy transfers whose
            # constant-et line stays inside the crop area over the full
            # [x1, x2]. It is empty when x2 - x1 > y2 - y1, and np.linspace
            # then raises for a negative number of points -- confirm intended.
            _etmin = x2 - _ymax
            _etmax = x1 - _ymin
            self._logger.debug(f"-> et range: {_etmin:.2f}:{_etmax:.2f}")

        _netpts = int((_etmax - _etmin) / _xystep)
        _nypts = int((_ymax - _ymin) / _xystep)
        _etcrop = np.linspace(_etmin, _etmax, num=_netpts)
        _ycrop = np.linspace(_ymin, _ymax, num=_nypts)

        _xx, _yy = np.meshgrid(_xcrop, _ycrop)
        _exx, _et = np.meshgrid(_xcrop, _etcrop)
        self._logger.info("Gridding data...")
        _zzcrop = griddata((self._x, self._y), self._z, (_xx, _yy), method=_method)
        _ezzcrop = griddata(
            (self._x, self._x - self._y), self._z, (_exx, _et), method=_method
        )

        self.ene_in = _xcrop
        self.ene_out = _ycrop
        self.ene_et = _etcrop
        self.rixs_map = _zzcrop
        self.rixs_et_map = _ezzcrop
        self.label = f"{self.name} [{self._crop_area}]"

    def reset(self, **grid_kws):
        """resets to initial data

        Steps:

        - re-grid the maps from the raw columns (``grid_kws`` are passed to
          :meth:`grid_rixs_from_col`);
        - drop all line cuts;
        - restore the label;
        - restart the color palette.
        """
        self._logger.info("resetting to initial data (grid RIXS plane and line cuts)")
        self.grid_rixs_from_col(**grid_kws)
        self.line_cuts = {}
        self.label = self.name
        self._palette = CycleColors()

    def grid_rixs_from_col(self, ene_grid=None, grid_method=None):
        """Grid RIXS map from XYZ columns

        Parameters
        ----------
        ene_grid : float, optional [None]
            grid step. If None, ``self.ene_grid`` is used (0.1 if unset).
        grid_method : str, optional [None]
            method passed to ``gridxyz``. If None, ``self.grid_method`` is
            used ("linear" if unset).

        Sets the following attributes:

        - ``ene_in``, ``ene_out``, ``rixs_map`` (emission-energy map);
        - ``_et`` (= _x - _y);
        - ``ene_et``, ``rixs_et_map`` (energy-transfer map);
        - ``ene_grid``, set to the step actually used.
        """
        if ene_grid is not None:
            self.ene_grid = ene_grid
        if grid_method is not None:
            self.grid_method = grid_method
        _xystep = self.ene_grid or 0.1
        _method = self.grid_method or "linear"
        self.ene_in, self.ene_out, self.rixs_map = gridxyz(
            self._x, self._y, self._z, xystep=_xystep, method=_method
        )
        self._et = self._x - self._y
        # the x axis returned here is discarded; ene_in from above is kept
        _, self.ene_et, self.rixs_et_map = gridxyz(
            self._x, self._et, self._z, xystep=_xystep, method=_method
        )
        self.ene_grid = _xystep

    def cut(self, energy=None, mode="CEE", label=None):
        """cut the RIXS plane at a given energy

        Parameters
        ----------
        energy : float
            energy of the cut

        mode : str
            defines the way to cut the plane (case insensitive):
            - "CEE" (constant emission energy)
            - "CIE" (constant incident energy)
            - "CET" (constant energy transfer)

        label : str, optional [None]
            custom label, if None: label = 'mode_enecut'

        Return
        ------
        None -> adds dict(x:array, y:array, info:dict) to
        self.line_cuts[cut_key]:dict, where cut_key = f"{mode}_{enecut*10:.0f}" and

            info = {label: str, #: 'mode_enecut'
                    mode: str, #: as input, upper case
                    enecut: float, #: grid energy closest to the requested one
                    datatype: str, #: 'xas' or 'xes'
                    color: str, #: color from a common palette
                    timestamp: str, #: local time, 'YYYY-MM-DD_HHMM'
                    }

        An unknown mode is logged as an error and nothing is added.
        """
        assert energy is not None, "The energy of the cut must be given"

        mode = mode.upper()

        if mode == "CEE":
            xc = self.ene_in
            iy = np.abs(self.ene_out - energy).argmin()
            enecut = self.ene_out[iy]
            yc = self.rixs_map[iy, :]
            datatype = "xas"
        elif mode == "CIE":
            xc = self.ene_out
            iy = np.abs(self.ene_in - energy).argmin()
            enecut = self.ene_in[iy]
            yc = self.rixs_map[:, iy]
            datatype = "xes"
        elif mode == "CET":
            xc = self.ene_in
            iy = np.abs(self.ene_et - energy).argmin()
            enecut = self.ene_et[iy]
            yc = self.rixs_et_map[iy, :]
            datatype = "xas"
        else:
            self._logger.error(f"wrong mode: {mode}")
            return

        if label is None:
            label = f"{mode}_{enecut:.1f}"

        info = dict(
            label=label,
            mode=mode,
            enecut=enecut,
            datatype=datatype,
            color=self._palette.get_color(),
            timestamp="{0:04d}-{1:02d}-{2:02d}_{3:02d}{4:02d}".format(
                *time.localtime()
            ),
        )

        # key with 0.1 resolution: a new cut at the same grid energy replaces the old one
        cut_key = f"{mode}_{enecut*10:.0f}"
        self.line_cuts[cut_key] = dict(x=xc, y=yc, info=info)
        self._logger.info(f"added RIXS {mode} cut: '{label}'")

    def norm(self):
        """Simple map normalization to max-min

        ``rixs_map`` is divided by its max - min span (NaNs are ignored).
        The minimum is not subtracted and ``rixs_et_map`` is left unchanged.
        If the span is zero or not finite, a warning is logged and the map
        is left unchanged, instead of being filled with inf/NaN.
        """
        span = np.nanmax(self.rixs_map) - np.nanmin(self.rixs_map)
        if not np.isfinite(span) or span == 0:
            self._logger.warning(
                "cannot normalize rixs map: max-min is zero or not finite"
            )
            return
        self.rixs_map = self.rixs_map / span
        self._logger.info("rixs map normalized to max-min")
if __name__ == "__main__":
    # the module has no command-line behavior: running it as a script does nothing
    ...