Coverage for /Users/Newville/Codes/xraylarch/larch/plot/plot_rixsdata.py: 0%
281 statements
coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4"""Plot RIXS data sets (2D maps)
5================================
6"""
8import copy
9import numpy as np
10import matplotlib.pyplot as plt
13from matplotlib import gridspec
14from matplotlib import cm
15from matplotlib.ticker import MaxNLocator, AutoLocator
16from sympy import EX
18from larch.utils.logging import getLogger
#: module-level logger, used by the plotting helpers below to report errors
_logger = getLogger(__name__)
def plot_rixs(
    rd,
    et=True,
    fig_name="plot_rixs_fig",
    fig_size=(10, 10),
    fig_dpi=75,
    fig_title=None,
    x_label=None,
    y_label=None,
    x_nticks=0,
    y_nticks=0,
    x_min=None,
    x_max=None,
    y_min=None,
    y_max=None,
    cbar_show=True,
    cbar_pos="vertical",
    cbar_nticks=0,
    cbar_label="Signal intensity",
    cbar_norm0=False,
    cont_nlevels=50,
    cont_imshow=True,
    cmap=cm.gist_heat_r,
    cmap2=cm.RdBu,
    cmap_linlog="linear",
    cont_type="line",
    cont_lwidths=0.25,
    cont_labels=None,
    cont_labelformat="%.3f",
    origin="lower",
):
    """RIXS map plotter

    Parameters
    ----------
    rd : RixsData
        data set to plot; ``rd.label`` and either ``ene_in``/``ene_et``/
        ``rixs_et_map`` (``et=True``) or ``ene_in``/``ene_out``/``rixs_map``
        (``et=False``) are read from it
    et : boolean, optional [True]
        plot the energy-transfer map instead of the emitted-energy map
    fig_name, fig_size, fig_dpi : optional
        ``num``, ``figsize`` and ``dpi`` of the figure; an existing figure
        with the same name is closed first
    fig_title : str, optional [None -> rd.label]
    x_label, y_label : str, optional [None -> automatic]
    x_nticks, y_nticks : int, optional [0]
        maximum number of major ticks (0 -> automatic)
    x_min, x_max, y_min, y_max : float, optional [None]
        axis limits; a bound left to None keeps its automatic value
        (0 is a valid bound)
    cbar_show : boolean, optional [True]
        draw a color bar
    cbar_pos : str, optional ["vertical"]
        color bar orientation
    cbar_nticks : int, optional [0]
        maximum number of color bar ticks (0 -> automatic)
    cbar_label : str, optional ["Signal intensity"]
    cbar_norm0 : boolean, optional [False]
        Normalize color bar around 0
    cont_nlevels : int, optional [50]
        number of contour levels
    cont_imshow : boolean, optional [True]
        use plt.imshow instead of plt.contourf
    cmap : matplotlib colormap, optional [cm.gist_heat_r]
    cont_type : str, optional ["line"]
        if it contains "line", black contour lines are overlaid
    cont_lwidths : float, optional [0.25]
        width of the contour lines
    origin : str, optional ["lower"]
        ``origin`` passed to ``imshow``
    cmap2, cmap_linlog, cont_labels, cont_labelformat : optional
        accepted for API compatibility, currently unused

    Returns
    -------
    matplotlib.figure.Figure, or None if ``rd`` is not a RixsData object or
    the required arrays are missing
    """
    if "RixsData" not in str(type(rd)):
        _logger.error('only "RixsData" objects can be plotted!')
        return None

    if fig_title is None:
        fig_title = rd.label
    if x_label is None:
        x_label = "Incoming energy (eV)"

    # attributes holding (x axis, y axis, map) and the matching y label
    if et:
        names = ("ene_in", "ene_et", "rixs_et_map")
        default_ylabel = "Energy transfer (eV)"
    else:
        names = ("ene_in", "ene_out", "rixs_map")
        default_ylabel = "Emitted energy (eV)"
    missing = "/".join(names)
    try:
        x, y, zz = [getattr(rd, name) for name in names]
    except Exception:
        _logger.error(f"`{missing}` arrays missing")
        return None
    if x is None or y is None or zz is None:
        # attributes present but not filled: nothing to plot
        _logger.error(f"`{missing}` arrays missing")
        return None
    if y_label is None:
        y_label = default_ylabel

    plt.close(fig_name)
    fig = plt.figure(num=fig_name, figsize=fig_size, dpi=fig_dpi)

    # NaN-aware extrema; np.nanmin/np.nanmax may fail on masked arrays,
    # in which case plain min/max are used
    try:
        zzmin, zzmax = np.nanmin(zz), np.nanmax(zz)
    except Exception:
        zzmin, zzmax = np.min(zz), np.max(zz)

    if cbar_norm0:
        # symmetric color scale centered on 0
        vnorm = max(abs(zzmin), abs(zzmax))
        norm = cm.colors.Normalize(vmin=-vnorm, vmax=vnorm)
    else:
        # color scale spanning the data range
        norm = cm.colors.Normalize(vmin=zzmin, vmax=zzmax)

    extent = (x.min(), x.max(), y.min(), y.max())
    levels = np.linspace(zzmin, zzmax, cont_nlevels)

    ### FIGURE LAYOUT ###
    plane = fig.add_subplot(111)
    plane.set_title(fig_title)
    plane.set_xlabel(x_label)
    plane.set_ylabel(y_label)

    # map drawn either as an image or as filled contours
    if cont_imshow:
        contf = plane.imshow(zz, origin=origin, extent=extent, cmap=cmap, norm=norm)
    else:
        # one color per band between consecutive levels
        # (plt.get_cmap replaces cm.get_cmap, removed in matplotlib 3.9)
        contf = plane.contourf(
            x, y, zz, levels, cmap=plt.get_cmap(cmap, len(levels) - 1), norm=norm
        )

    if "line" in cont_type.lower():
        # no `hold` keyword: it was removed from matplotlib, overlaying is default
        plane.contour(x, y, zz, levels, colors="k", linewidths=cont_lwidths)

    plane.xaxis.set_major_locator(
        MaxNLocator(int(x_nticks)) if x_nticks else AutoLocator()
    )
    plane.yaxis.set_major_locator(
        MaxNLocator(int(y_nticks)) if y_nticks else AutoLocator()
    )

    # user limits applied after drawing; a None bound keeps the current value
    if x_min is not None or x_max is not None:
        plane.set_xlim(x_min, x_max)
    if y_min is not None or y_max is not None:
        plane.set_ylim(y_min, y_max)

    if cbar_show:
        # color bar size scaled by the ratio of y to x sample counts
        xyratio = y.shape[0] / x.shape[0]
        cbar = fig.colorbar(
            contf,
            use_gridspec=True,
            orientation=cbar_pos,
            fraction=0.046 * xyratio,
            pad=0.04,
        )
        cbar.set_ticks(MaxNLocator(int(cbar_nticks)) if cbar_nticks else AutoLocator())
        cbar.set_label(cbar_label)

    fig.tight_layout()
    return fig
def plot_rixs_cuts(rd, et=True, fig_name="plot_rixs_cuts", fig_size=(8, 10), fig_dpi=75):
    """Plot the line cuts stored in ``rd.line_cuts``, one panel per cut mode.

    Panels, top to bottom: "CEE" vs incoming energy; "CIE" vs energy transfer
    (``et=True``, abscissa computed as ``info["enecut"] - x``) or vs emitted
    energy; "CET" vs incoming energy. Panels without cuts stay hidden; cuts
    with any other mode are logged and skipped.

    Parameters
    ----------
    rd : RixsData
        object whose ``line_cuts`` dict values hold ``"x"``, ``"y"`` and an
        ``"info"`` dict with ``"mode"``, ``"label"``, ``"color"`` (plus
        ``"enecut"`` for "CIE" cuts when ``et`` is True)
    et : boolean, optional [True]
    fig_name, fig_size, fig_dpi : optional
        figure ``num``, ``figsize`` and ``dpi``

    Returns
    -------
    matplotlib.figure.Figure
    """
    assert len(rd.line_cuts.keys()) >= 1, "no line cuts are present"
    plt.close(fig_name)
    fig, axs = plt.subplots(nrows=3, num=fig_name, figsize=fig_size, dpi=fig_dpi)

    for panel in axs:
        panel.set_axis_off()

    if et:
        cie_xlabel = "Energy transfer (eV)"
    else:
        cie_xlabel = "Emitted energy (eV)"
    # cut mode -> (panel index, abscissa label)
    layout = {
        "CEE": (0, "Incoming energy (eV)"),
        "CIE": (1, cie_xlabel),
        "CET": (2, "Incoming energy (eV)"),
    }

    for cut in rd.line_cuts.values():
        xdata, ydata, info = cut["x"], cut["y"], cut["info"]
        mode, label, color = info["mode"], info["label"], info["color"]
        if mode not in layout:
            _logger.error(f"wrong mode: {mode}")
            continue
        index, xlabel = layout[mode]
        panel = axs[index]
        panel.set_axis_on()
        if mode == "CIE" and et:
            xdata = info['enecut'] - xdata
        panel.set_title(mode)
        panel.set_xlabel(xlabel)
        panel.set_ylabel("Signal intensity")
        panel.plot(xdata, ydata, label=label, color=color)
        panel.legend()

    fig.tight_layout()
    return fig
class RixsDataPlotter(object):
    """Plotter for a RixsData object.

    Plot options are kept in ``self.kwsd`` (keys and defaults: see
    :meth:`get_plot_kwsd`). Options passed to :meth:`plot` are merged into
    ``self.kwsd`` and therefore persist for later calls.
    """

    def __init__(self, rd):
        """Initialize the plotter for ``rd``.

        Options are deep-copied from ``rd.kwsd["plot"]`` when available; keys
        it does not define keep the defaults of :meth:`get_plot_kwsd`, so
        :meth:`plot` does not fail on a missing option. If ``rd`` is not a
        RixsData object an error is logged and the instance is left without
        ``kwsd``/``rd`` attributes.
        """
        if "RixsData" not in str(type(rd)):
            _logger.error('I can only plot "RixsData" objects!')
            return
        self.kwsd = self.get_plot_kwsd()
        try:
            stored = copy.deepcopy(rd.kwsd["plot"])
        except Exception:
            # no stored plot options: keep the defaults
            stored = {}
        self.kwsd.update(stored)
        self.rd = rd

    def get_plot_kwsd(self):
        """Return a new dictionary with the default plot keyword arguments.

        Keys read by :meth:`plot`: figure (``replace``, ``figname``,
        ``figsize``, ``figdpi``, ``title``), labels (``xlabelE``, ``ylabelE``,
        ``ylabelEt``, ``zlabel``), limits and ticks (``xmin``, ``xmax``,
        ``ymin``, ``ymax``, ``x_nticks``, ``y_nticks``, ``z_nticks``), color
        bar (``cbar_*``), map rendering (``cmap``, ``cont_imshow``,
        ``cont_type``, ``cont_lwidths``, ``cont_levels``, ``origin``) and
        line cuts (``lcuts``, ``xcut``, ``ycut``, ``dcut``, ``lc_*``).
        ``xlabel``/``ylabel`` are honored only when passed directly to
        :meth:`plot`; ``xystep``, ``xshift``, ``ystack``, ``xscale``,
        ``yscale``, ``cmap2``, ``cmap_linlog``, ``cont_labels`` and
        ``cont_labelformat`` are currently not used.
        """
        kwsd = {
            # figure
            "replace": True,
            "figname": "RixsDataPlotter",
            "figsize": (10, 10),
            "figdpi": 150,
            "title": None,
            # labels
            "xlabel": None,
            "ylabel": None,
            # ticks (0 -> automatic)
            "x_nticks": 0,
            "y_nticks": 0,
            "z_nticks": 0,
            "xlabelE": r"Incoming Energy (eV)",
            "ylabelE": r"Emitted Energy (eV)",
            "ylabelEt": r"Energy transfer (eV)",
            "zlabel": r"Intensity (a.u)",
            "xystep": 0.01,
            # axis limits (None -> automatic)
            "xmin": None,
            "xmax": None,
            "ymin": None,
            "ymax": None,
            "xshift": 0,
            "ystack": 0,
            "xscale": 1,
            "yscale": 1,
            # color bar
            "cbar_show": False,
            "cbar_pos": "vertical",
            "cbar_nticks": 0,
            "cbar_label": "Counts/s",
            "cbar_norm0": False,
            # map rendering
            "cmap": cm.gist_heat_r,
            "cmap2": cm.RdBu,
            "cmap_linlog": "linear",
            "cont_imshow": True,
            "cont_type": "line",
            "cont_lwidths": 0.25,
            "cont_levels": 50,
            "cont_labels": None,
            "cont_labelformat": "%.3f",
            "origin": "lower",
            # line cuts
            "lcuts": False,
            "xcut": None,
            "ycut": None,
            "dcut": None,
            "lc_dticks": 2,
            "lc_color": "red",
            "lc_lw": 3,
        }
        return kwsd

    @staticmethod
    def _locator(nticks):
        """Major-tick locator: at most ``nticks`` ticks, automatic if 0/None."""
        if nticks:
            return MaxNLocator(max(1, int(nticks)))
        return AutoLocator()

    @staticmethod
    def _set_limits(setter, low, high):
        """Call ``setter(low, high)`` unless both bounds are None (0 is valid)."""
        if low is not None or high is not None:
            setter(low, high)

    @staticmethod
    def _cut_at_x(xa, ya, zmap, value):
        """Return (abscissa, intensity, position) of the cut at the x nearest ``value``.

        ``zmap`` is indexed as ``[y, x]``. For 2D grids the column is found
        with ``np.unravel_index`` (``np.argmin`` alone returns a flat index).
        """
        col = np.unravel_index(np.argmin(np.abs(value - xa)), np.shape(xa))[-1]
        if np.ndim(xa) == 1:
            return ya, zmap[:, col], xa[col]
        return ya[:, col], zmap[:, col], xa[0, col]

    @staticmethod
    def _cut_at_y(xa, ya, zmap, value):
        """Return (abscissa, intensity, position) of the cut at the y nearest ``value``.

        ``zmap`` is indexed as ``[y, x]``. For 2D grids the row is found with
        ``np.unravel_index`` (``np.argmin`` alone returns a flat index).
        """
        row = np.unravel_index(np.argmin(np.abs(value - ya)), np.shape(ya))[0]
        if np.ndim(ya) == 1:
            return xa, zmap[row, :], ya[row]
        return xa[row, :], zmap[row, :], ya[row, 0]

    def _format_cut_panel(self, ax, nticks, z_nticks, lc_dticks, xlabel, zlabel, lims):
        """Set ticks, labels and x-limits of a line-cut panel (no intensity tick labels)."""
        ax.xaxis.set_major_locator(self._locator(nticks / lc_dticks if nticks else 0))
        ax.yaxis.set_major_locator(
            self._locator(z_nticks / lc_dticks if z_nticks else 0)
        )
        ax.set_yticklabels([])
        ax.set_ylabel(zlabel)
        ax.set_xlabel(xlabel)
        self._set_limits(ax.set_xlim, *lims)

    def plot(self, x=None, y=None, zz=None, **kws):
        """Plot the RIXS plane, optionally with three line-cut panels.

        Parameters
        ----------
        x, y, zz : arrays, optional
            axes and intensity map (indexed ``[y, x]``); default to
            ``rd.ene_in``, ``rd.ene_et`` and ``rd.rixs_et_map``. For the line
            cuts ``x`` and ``y`` must be both 1D or both 2D.
        **kws :
            plot options (keys of :meth:`get_plot_kwsd`), stored in
            ``self.kwsd``. With ``lcuts=True`` three side panels are drawn:
            the cut at ``x = xcut`` (intensity vs y), the cut of the
            emitted-energy map ``rd.rixs_map`` at ``rd.ene_out = dcut``
            (intensity vs ``rd.ene_in``) and the cut at ``y = ycut``
            (intensity vs x). A cut whose position is None is skipped
            (0 is a valid position); each cut uses the nearest grid point.
        """
        rd = self.rd
        if x is None:
            x = rd.ene_in
        if y is None:
            y = rd.ene_et
        if zz is None:
            zz = rd.rixs_et_map

        self.kwsd.update(**kws)
        opts = self.kwsd

        lcuts = opts["lcuts"]
        xdim, ydim = np.ndim(x), np.ndim(y)
        cuts_ok = xdim == ydim and xdim in (1, 2)
        if lcuts and not cuts_ok:
            _logger.error("line cuts need x and y both 1D or both 2D: skipped")
        do_cuts = lcuts and cuts_ok

        xlabel = kws.get("xlabel", opts["xlabelE"])
        # a y range well below the x one is taken as an energy-transfer axis
        if y.max() / x.max() < 0.5:
            ylabel = kws.get("ylabel", opts["ylabelEt"])
        else:
            ylabel = kws.get("ylabel", opts["ylabelE"])
        zlabel = opts["zlabel"]
        xmin, xmax = opts["xmin"], opts["xmax"]
        ymin, ymax = opts["ymin"], opts["ymax"]
        x_nticks, y_nticks, z_nticks = opts["x_nticks"], opts["y_nticks"], opts["z_nticks"]

        # NaN-aware extrema; np.nanmin/np.nanmax may fail on masked arrays,
        # in which case plain min/max are used
        try:
            zzmin, zzmax = np.nanmin(zz), np.nanmax(zz)
        except Exception:
            zzmin, zzmax = np.min(zz), np.max(zz)
        if opts["cbar_norm0"]:
            # symmetric color scale centered on 0
            vnorm = max(abs(zzmin), abs(zzmax))
            norm = cm.colors.Normalize(vmin=-vnorm, vmax=vnorm)
        else:
            norm = cm.colors.Normalize(vmin=zzmin, vmax=zzmax)
        extent = (x.min(), x.max(), y.min(), y.max())
        levels = np.linspace(zzmin, zzmax, opts["cont_levels"])

        ### FIGURE LAYOUT ###
        figname = opts["figname"]
        if opts["replace"]:
            plt.close(figname)
        self.fig = plt.figure(num=figname, figsize=opts["figsize"], dpi=opts["figdpi"])
        if opts["replace"]:
            self.fig.clear()

        if lcuts:
            gs = gridspec.GridSpec(3, 3)  # 3x3 grid
            self.plane = plt.subplot(gs[:, :-1])  # map: first two columns
            self.lxcut = plt.subplot(gs[0, 2])  # cut at constant x
            self.ldcut = plt.subplot(gs[1, 2])  # cut of the emitted-energy map
            self.lycut = plt.subplot(gs[2, 2])  # cut at constant y
        else:
            self.plane = self.fig.add_subplot(111)  # map only

        plane = self.plane
        if opts["title"]:
            plane.set_title(opts["title"])
        plane.set_xlabel(xlabel)
        plane.set_ylabel(ylabel)

        # map drawn either as an image or as filled contours
        cmap = opts["cmap"]
        if opts["cont_imshow"]:
            self.contf = plane.imshow(
                zz,
                origin=opts.get("origin", "lower"),
                extent=extent,
                cmap=cmap,
                norm=norm,
            )
        else:
            # one color per band (plt.get_cmap replaces the removed cm.get_cmap)
            self.contf = plane.contourf(
                x, y, zz, levels, cmap=plt.get_cmap(cmap, len(levels) - 1), norm=norm
            )
        if "line" in opts["cont_type"].lower():
            # no `hold` keyword: it was removed from matplotlib
            self.cont = plane.contour(
                x, y, zz, levels, colors="k", linewidths=opts["cont_lwidths"]
            )
        plane.xaxis.set_major_locator(self._locator(x_nticks))
        plane.yaxis.set_major_locator(self._locator(y_nticks))
        # limits applied after drawing; a None bound keeps the current value
        self._set_limits(plane.set_xlim, xmin, xmax)
        self._set_limits(plane.set_ylim, ymin, ymax)

        if opts["cbar_show"]:
            self.cbar = self.fig.colorbar(
                self.contf, use_gridspec=True, orientation=opts["cbar_pos"]
            )
            self.cbar.set_ticks(self._locator(opts["cbar_nticks"]))
            self.cbar.set_label(opts["cbar_label"])

        lc_style = {"color": opts["lc_color"], "linewidth": opts["lc_lw"]}
        lc_dticks = opts["lc_dticks"]

        if do_cuts and opts["xcut"] is not None:
            cut_x, cut_z, pos = self._cut_at_x(x, y, zz, opts["xcut"])
            self.lxcut.plot(cut_x, cut_z, label=str(pos), **lc_style)
            self._format_cut_panel(
                self.lxcut, y_nticks, z_nticks, lc_dticks, ylabel, zlabel, (ymin, ymax)
            )

        if do_cuts and opts["ycut"] is not None:
            cut_x, cut_z, pos = self._cut_at_y(x, y, zz, opts["ycut"])
            self.lycut.plot(cut_x, cut_z, label=str(pos), **lc_style)
            self._format_cut_panel(
                self.lycut, x_nticks, z_nticks, lc_dticks, xlabel, zlabel, (xmin, xmax)
            )

        if do_cuts and opts["dcut"] is not None:
            # same as the y cut, but on the emitted-energy map
            cut_x, cut_z, pos = self._cut_at_y(
                rd.ene_in, rd.ene_out, rd.rixs_map, opts["dcut"]
            )
            self.ldcut.plot(cut_x, cut_z, label=str(pos), **lc_style)
            self._format_cut_panel(
                self.ldcut, x_nticks, z_nticks, lc_dticks, xlabel, zlabel, (xmin, xmax)
            )

        plt.draw()
        plt.show()
if __name__ == "__main__":
    # this module only defines plotting helpers: nothing to run as a script
    ...