Coverage for /Users/Newville/Codes/xraylarch/larch/plot/plot_rixsdata.py: 0%

281 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4"""Plot RIXS data sets (2D maps) 

5================================ 

6""" 

7 

8import copy 

9import numpy as np 

10import matplotlib.pyplot as plt 

11 

12 

13from matplotlib import gridspec 

14from matplotlib import cm 

15from matplotlib.ticker import MaxNLocator, AutoLocator 

16from sympy import EX 

17 

18from larch.utils.logging import getLogger 

19 

20_logger = getLogger(__name__) #: module logger, used as self._logger if not given 

21 

22 

def plot_rixs(
    rd,
    et=True,
    fig_name="plot_rixs_fig",
    fig_size=(10, 10),
    fig_dpi=75,
    fig_title=None,
    x_label=None,
    y_label=None,
    x_nticks=0,
    y_nticks=0,
    x_min=None,
    x_max=None,
    y_min=None,
    y_max=None,
    cbar_show=True,
    cbar_pos="vertical",
    cbar_nticks=0,
    cbar_label="Signal intensity",
    cbar_norm0=False,
    cont_nlevels=50,
    cont_imshow=True,
    cmap=cm.gist_heat_r,
    cmap2=cm.RdBu,
    cmap_linlog="linear",
    cont_type="line",
    cont_lwidths=0.25,
    cont_labels=None,
    cont_labelformat="%.3f",
    origin="lower",
):
    """RIXS map plotter

    Parameters
    ----------

    rd : RixsData
        object holding the maps; objects whose type name does not contain
        "RixsData" are rejected (an error is logged and None is returned)

    et : boolean, optional [True]
        if True plot ``rd.rixs_et_map`` versus ``rd.ene_in`` / ``rd.ene_et``
        (energy transfer), otherwise ``rd.rixs_map`` versus ``rd.ene_in`` /
        ``rd.ene_out`` (emitted energy)

    fig_name, fig_size, fig_dpi : optional
        matplotlib figure label, size and resolution; an existing figure
        with the same label is closed first

    fig_title, x_label, y_label : str, optional [None]
        default to ``rd.label`` and to axis labels matching ``et``

    x_nticks, y_nticks : int, optional [0]
        maximum number of major ticks (0: automatic)

    x_min, x_max, y_min, y_max : float, optional [None]
        axis limits; a limit left to None keeps the automatic value

    cbar_show : boolean, optional [True]
        draw a color bar

    cbar_pos : str, optional ["vertical"]
        color bar orientation

    cbar_nticks : int, optional [0]
        maximum number of color bar ticks (0: automatic)

    cbar_label : str, optional ["Signal intensity"]
        color bar label

    cbar_norm0 : boolean, optional [False]
        Normalize color bar around 0

    cont_nlevels : int, optional [50]
        number of contour levels

    cont_imshow : boolean, optional [True]
        use plt.imshow instead of plt.contourf

    cmap : matplotlib colormap, optional [cm.gist_heat_r]
        colormap of the map

    cont_type : str, optional ["line"]
        if it contains "line", black contour lines are drawn on top

    cont_lwidths : float, optional [0.25]
        width of the contour lines

    origin : str, optional ["lower"]
        ``origin`` passed to imshow

    cmap2, cmap_linlog, cont_labels, cont_labelformat : optional
        accepted for compatibility, currently not used

    Returns
    -------
    matplotlib.figure.Figure, or None if ``rd`` cannot be plotted
    """
    if "RixsData" not in str(type(rd)):
        _logger.error('only "RixsData" objects can be plotted!')
        return None

    if fig_title is None:
        fig_title = rd.label

    if x_label is None:
        x_label = "Incoming energy (eV)"

    if et:
        names = ("ene_in", "ene_et", "rixs_et_map")
        y_label_default = "Energy transfer (eV)"
    else:
        names = ("ene_in", "ene_out", "rixs_map")
        y_label_default = "Emitted energy (eV)"
    # a missing attribute and an attribute still set to None both mean
    # "no data to plot"
    x, y, zz = (getattr(rd, name, None) for name in names)
    if x is None or y is None or zz is None:
        _logger.error(f"`{'/'.join(names)}` arrays missing")
        return None
    if y_label is None:
        y_label = y_label_default

    plt.close(fig_name)
    fig = plt.figure(num=fig_name, figsize=fig_size, dpi=fig_dpi)

    # NOTE: np.nanmin/np.nanmax were found to fail with masked arrays (and
    # masked min()/max() too), so NaNs are ignored when possible and plain
    # min/max are used as a fallback
    try:
        zzmin, zzmax = np.nanmin(zz), np.nanmax(zz)
    except Exception:
        zzmin, zzmax = np.min(zz), np.max(zz)

    if cbar_norm0:
        # symmetric color scale centered on 0
        vnorm = max(abs(zzmin), abs(zzmax))
        norm = cm.colors.Normalize(vmin=-vnorm, vmax=vnorm)
    else:
        # color scale from min to max
        norm = cm.colors.Normalize(vmin=zzmin, vmax=zzmax)

    extent = (x.min(), x.max(), y.min(), y.max())
    levels = np.linspace(zzmin, zzmax, cont_nlevels)

    ### FIGURE LAYOUT ###
    plane = fig.add_subplot(111)
    plane.set_title(fig_title)
    plane.set_xlabel(x_label)
    plane.set_ylabel(y_label)
    # `is not None`: 0 is a valid limit (e.g. on energy-transfer axes);
    # a None limit leaves that side unchanged
    if x_min is not None or x_max is not None:
        plane.set_xlim(x_min, x_max)
    if y_min is not None or y_max is not None:
        plane.set_ylim(y_min, y_max)

    # contour mode: 'imshow' or 'contourf'
    if cont_imshow:
        contf = plane.imshow(zz, origin=origin, extent=extent, cmap=cmap, norm=norm)
    else:
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9
        contf = plane.contourf(
            x, y, zz, levels, cmap=plt.get_cmap(cmap, len(levels) - 1), norm=norm
        )

    if "line" in cont_type.lower():
        # `hold` was removed from matplotlib (3.0); not passed anymore
        plane.contour(x, y, zz, levels, colors="k", linewidths=cont_lwidths)

    if x_nticks:
        plane.xaxis.set_major_locator(MaxNLocator(int(x_nticks)))
    else:
        plane.xaxis.set_major_locator(AutoLocator())
    if y_nticks:
        plane.yaxis.set_major_locator(MaxNLocator(int(y_nticks)))
    else:
        plane.yaxis.set_major_locator(AutoLocator())

    # colorbar
    if cbar_show:
        xyratio = y.shape[0] / x.shape[0]
        cbar = fig.colorbar(
            contf,
            use_gridspec=True,
            orientation=cbar_pos,
            fraction=0.046 * xyratio,
            pad=0.04,
        )
        if cbar_nticks:
            cbar.set_ticks(MaxNLocator(int(cbar_nticks)))
        else:
            cbar.set_ticks(AutoLocator())
        cbar.set_label(cbar_label)

    fig.tight_layout()
    return fig

183 

184 

def plot_rixs_cuts(rd, et=True, fig_name="plot_rixs_cuts", fig_size=(8, 10), fig_dpi=75):
    """plot RIXS line cuts

    The figure has three stacked panels, one per cut mode: "CEE" (top),
    "CIE" (middle) and "CET" (bottom). Panels without any cut stay hidden.
    Cuts with any other mode are skipped and an error is logged.

    Parameters
    ----------
    rd : RixsData
        object with a ``line_cuts`` dictionary; each value holds the arrays
        ``"x"``, ``"y"`` and an ``"info"`` dictionary with ``"mode"``,
        ``"label"`` and ``"color"`` (plus ``"enecut"`` for "CIE" cuts when
        ``et`` is True)
    et : boolean, optional [True]
        plot "CIE" cuts against ``info["enecut"] - x`` (energy transfer)
        instead of ``x`` (emitted energy)
    fig_name, fig_size, fig_dpi : optional
        matplotlib figure label, size and resolution; an existing figure
        with the same label is closed first

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    AssertionError
        if ``rd.line_cuts`` is empty
    """
    # explicit raise instead of `assert` so the check survives `python -O`;
    # the AssertionError type is kept for existing callers
    if not rd.line_cuts:
        raise AssertionError("no line cuts are present")
    plt.close(fig_name)
    fig, axs = plt.subplots(nrows=3, num=fig_name, figsize=fig_size, dpi=fig_dpi)

    # panels are shown only once a cut is drawn in them
    for ax in axs:
        ax.set_axis_off()

    y_label = "Signal intensity"
    # cut mode -> (panel index, x-axis label)
    panels = {
        "CEE": (0, "Incoming energy (eV)"),
        "CIE": (1, "Energy transfer (eV)" if et else "Emitted energy (eV)"),
        "CET": (2, "Incoming energy (eV)"),
    }

    used = set()
    for val in rd.line_cuts.values():
        x, y, info = val["x"], val["y"], val["info"]
        mode = info["mode"]
        label = info["label"]
        color = info["color"]
        if mode not in panels:
            _logger.error(f"wrong mode: {mode}")
            continue
        iax, x_label = panels[mode]
        if mode == "CIE" and et:
            # energy transfer = incident energy of the cut - emitted energy
            x = info["enecut"] - x
        ax = axs[iax]
        ax.set_axis_on()
        ax.set_title(mode)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.plot(x, y, label=label, color=color)
        used.add(iax)

    # build each legend once, after all the cuts of that panel are drawn
    for iax in sorted(used):
        axs[iax].legend()
    fig.tight_layout()
    return fig

227 

228 

class RixsDataPlotter(object):
    """plotter for a RixsData object

    Plot settings are kept in ``self.kwsd``. They start from the defaults
    returned by :meth:`get_plot_kwsd`, overridden by a copy of
    ``rd.kwsd["plot"]`` when ``rd`` provides one. Keyword arguments passed to
    :meth:`plot` are merged into ``self.kwsd``, so they also apply to later
    calls.
    """

    def __init__(self, rd):
        """store *rd* and initialize the plot settings dictionary"""
        if "RixsData" not in str(type(rd)):
            _logger.error('I can only plot "RixsData" objects!')
            return
        # start from the defaults so that stored settings lacking some keys
        # do not make plot() fail with a KeyError
        self.kwsd = self.get_plot_kwsd()
        try:
            self.kwsd.update(copy.deepcopy(rd.kwsd["plot"]))
        except Exception:
            # best effort: rd has no usable stored plot settings
            pass
        self.rd = rd

    def get_plot_kwsd(self):
        """return a (flat) dictionary with the default plot settings

        Keys "xlabel", "ylabel", "xystep", "xshift", "ystack", "xscale",
        "yscale", "cmap2", "cmap_linlog", "cont_labels" and
        "cont_labelformat" are not read by :meth:`plot`. "xlabel" and
        "ylabel" are honored only as keyword arguments of that call.
        """
        kwsd = {
            "replace": True,
            "figname": "RixsDataPlotter",
            "figsize": (10, 10),
            "figdpi": 150,
            "title": None,
            "xlabel": None,
            "ylabel": None,
            "x_nticks": 0,
            "y_nticks": 0,
            "z_nticks": 0,
            "xlabelE": r"Incoming Energy (eV)",
            "ylabelE": r"Emitted Energy (eV)",
            "ylabelEt": r"Energy transfer (eV)",
            "zlabel": r"Intensity (a.u)",
            "xystep": 0.01,
            "xmin": None,
            "xmax": None,
            "ymin": None,
            "ymax": None,
            "xshift": 0,
            "ystack": 0,
            "xscale": 1,
            "yscale": 1,
            "cbar_show": False,
            "cbar_pos": "vertical",
            "cbar_nticks": 0,
            "cbar_label": "Counts/s",
            "cbar_norm0": False,
            "cmap": cm.gist_heat_r,
            "cmap2": cm.RdBu,
            "cmap_linlog": "linear",
            "cont_imshow": True,
            "cont_type": "line",
            "cont_lwidths": 0.25,
            "cont_levels": 50,
            "cont_labels": None,
            "cont_labelformat": "%.3f",
            "origin": "lower",
            "lcuts": False,
            "xcut": None,
            "ycut": None,
            "dcut": None,
            "lc_dticks": 2,
            "lc_color": "red",
            "lc_lw": 3,
        }
        return kwsd

    @staticmethod
    def _xy_ndim(x, y):
        """return 1 if x and y are both 1D, 2 if both 2D, otherwise None"""
        dims = (np.ndim(x), np.ndim(y))
        if dims == (1, 1):
            return 1
        if dims == (2, 2):
            return 2
        return None

    @staticmethod
    def _zrange(zz):
        """return (min, max) of zz, ignoring NaNs when possible

        np.nanmin/np.nanmax were found to fail with masked arrays, so plain
        min/max are used as a fallback.
        """
        try:
            return np.nanmin(zz), np.nanmax(zz)
        except Exception:
            return np.min(zz), np.max(zz)

    @staticmethod
    def _nearest_index(values, target, axis):
        """return the index of the element of *values* closest to *target*

        For 1D *values* this is the plain argmin. For 2D (grid-like) arrays,
        the flat argmin is converted back to (row, column) and the component
        along *axis* is returned (1: column, 0: row).
        """
        values = np.asarray(values)
        flat = np.argmin(np.abs(target - values))
        if values.ndim == 1:
            return int(flat)
        return int(np.unravel_index(flat, values.shape)[axis])

    @staticmethod
    def _format_cut(ax, x_nticks, z_nticks, lc_dticks, xlabel, zlabel, lim_min, lim_max):
        """set tick locators, labels and x-limits of a line-cut panel

        Tick counts are divided by *lc_dticks*, keeping at least one bin to
        avoid a zero bin count for MaxNLocator. The intensity tick labels are
        hidden.
        """
        if x_nticks:
            ax.xaxis.set_major_locator(MaxNLocator(max(1, int(x_nticks / lc_dticks))))
        else:
            ax.xaxis.set_major_locator(AutoLocator())
        if z_nticks:
            ax.yaxis.set_major_locator(MaxNLocator(max(1, int(z_nticks / lc_dticks))))
        else:
            ax.yaxis.set_major_locator(AutoLocator())
        ax.set_yticklabels([])
        ax.set_ylabel(zlabel)
        ax.set_xlabel(xlabel)
        if lim_min is not None or lim_max is not None:
            ax.set_xlim(lim_min, lim_max)

    def plot(self, x=None, y=None, zz=None, **kws):
        """make the plot

        Parameters
        ----------
        x, y, zz : arrays, optional
            data of the 2D map; default to ``rd.ene_in``, ``rd.ene_et`` and
            ``rd.rixs_et_map``. Line cuts are drawn only when x and y are
            both 1D or both 2D.
        **kws :
            any key of :meth:`get_plot_kwsd`; merged into ``self.kwsd``, so
            they persist for later calls ("xlabel"/"ylabel" excepted).

        With ``lcuts`` enabled, the map is drawn next to three cut panels:
        ``xcut`` (at a given x), ``dcut`` (taken from ``rd.rixs_map`` versus
        ``rd.ene_in`` at the emitted energy closest to ``dcut``) and ``ycut``
        (at a given y). A cut value of 0 is valid.
        """
        if x is None:
            x = self.rd.ene_in
        if y is None:
            y = self.rd.ene_et
        if zz is None:
            zz = self.rd.rixs_et_map

        self.kwsd.update(**kws)
        opts = self.kwsd

        xyshape = self._xy_ndim(x, y)

        lcuts = opts["lcuts"]
        xcut, ycut, dcut = opts["xcut"], opts["ycut"], opts["dcut"]
        lc_dticks = opts["lc_dticks"]
        lc_color = opts["lc_color"]
        lc_lw = opts["lc_lw"]

        replace = opts["replace"]
        figname = opts["figname"]
        figsize = opts["figsize"]
        figdpi = opts["figdpi"]
        title = opts["title"]
        # axis labels are only taken from the keywords of *this* call
        xlabel = kws.get("xlabel", opts["xlabelE"])
        # heuristic: energy-transfer values are much smaller than the
        # incoming energies, emitted energies are comparable
        if y.max() / x.max() < 0.5:
            ylabel = kws.get("ylabel", opts["ylabelEt"])
        else:
            ylabel = kws.get("ylabel", opts["ylabelE"])
        zlabel = opts["zlabel"]
        xmin, xmax = opts["xmin"], opts["xmax"]
        ymin, ymax = opts["ymin"], opts["ymax"]
        x_nticks = opts["x_nticks"]
        y_nticks = opts["y_nticks"]
        z_nticks = opts["z_nticks"]
        cmap = opts["cmap"]
        origin = opts["origin"]

        cbar_show = opts["cbar_show"]
        cbar_pos = opts["cbar_pos"]
        cbar_nticks = opts["cbar_nticks"]
        cbar_label = opts["cbar_label"]
        cbar_norm0 = opts["cbar_norm0"]

        cont_imshow = opts["cont_imshow"]
        cont_type = opts["cont_type"]
        cont_levels = opts["cont_levels"]
        cont_lwidths = opts["cont_lwidths"]

        zzmin, zzmax = self._zrange(zz)
        if cbar_norm0:
            # symmetric color scale centered on 0
            vnorm = max(abs(zzmin), abs(zzmax))
            norm = cm.colors.Normalize(vmin=-vnorm, vmax=vnorm)
        else:
            # color scale from min to max
            norm = cm.colors.Normalize(vmin=zzmin, vmax=zzmax)

        extent = (x.min(), x.max(), y.min(), y.max())
        levels = np.linspace(zzmin, zzmax, cont_levels)

        ### FIGURE LAYOUT ###
        if replace:
            plt.close(figname)
        self.fig = plt.figure(num=figname, figsize=figsize, dpi=figdpi)
        if replace:
            self.fig.clear()

        # 1 DATA SET WITH OR WITHOUT LINE CUTS
        if lcuts:
            gs = gridspec.GridSpec(3, 3)  # 3x3 grid
            self.plane = plt.subplot(gs[:, :-1])  # plane
            self.lxcut = plt.subplot(gs[0, 2])  # cut along x-axis
            self.ldcut = plt.subplot(gs[1, 2])  # cut along d-axis (diagonal)
            self.lycut = plt.subplot(gs[2, 2])  # cut along y-axis
        else:
            self.plane = self.fig.add_subplot(111)  # plot only plane

        # plane
        if title:
            self.plane.set_title(title)
        self.plane.set_xlabel(xlabel)
        self.plane.set_ylabel(ylabel)
        # `is not None`: 0 is a valid limit; a None limit is left unchanged
        if xmin is not None or xmax is not None:
            self.plane.set_xlim(xmin, xmax)
        if ymin is not None or ymax is not None:
            self.plane.set_ylim(ymin, ymax)

        # contour mode: 'imshow' or 'contourf'
        if cont_imshow:
            self.contf = self.plane.imshow(
                zz, origin=origin, extent=extent, cmap=cmap, norm=norm
            )
        else:
            # matplotlib.cm.get_cmap was removed in matplotlib 3.9
            self.contf = self.plane.contourf(
                x, y, zz, levels, cmap=plt.get_cmap(cmap, len(levels) - 1), norm=norm
            )

        if "line" in cont_type.lower():
            # `hold` was removed from matplotlib (3.0); not passed anymore
            self.cont = self.plane.contour(
                x, y, zz, levels, colors="k", linewidths=cont_lwidths
            )
        if x_nticks:
            self.plane.xaxis.set_major_locator(MaxNLocator(int(x_nticks)))
        else:
            self.plane.xaxis.set_major_locator(AutoLocator())
        if y_nticks:
            self.plane.yaxis.set_major_locator(MaxNLocator(int(y_nticks)))
        else:
            self.plane.yaxis.set_major_locator(AutoLocator())

        # colorbar
        if cbar_show:
            self.cbar = self.fig.colorbar(
                self.contf, use_gridspec=True, orientation=cbar_pos
            )
            if cbar_nticks:
                self.cbar.set_ticks(MaxNLocator(int(cbar_nticks)))
            else:
                self.cbar.set_ticks(AutoLocator())
            self.cbar.set_label(cbar_label)

        # xcut plot: intensity versus y at the x closest to xcut
        if lcuts and xcut is not None:
            xpos = self._nearest_index(x, xcut, axis=1)
            if xyshape == 1:
                self.lxcut.plot(
                    y, zz[:, xpos], label=str(x[xpos]), color=lc_color, linewidth=lc_lw
                )
            elif xyshape == 2:
                self.lxcut.plot(
                    y[:, xpos],
                    zz[:, xpos],
                    label=str(x[:, xpos][0]),
                    color=lc_color,
                    linewidth=lc_lw,
                )
            self._format_cut(
                self.lxcut, y_nticks, z_nticks, lc_dticks, ylabel, zlabel, ymin, ymax
            )

        # ycut plot: intensity versus x at the y closest to ycut
        if lcuts and ycut is not None:
            ypos = self._nearest_index(y, ycut, axis=0)
            if xyshape == 1:
                self.lycut.plot(
                    x, zz[ypos, :], label=str(y[ypos]), color=lc_color, linewidth=lc_lw
                )
            elif xyshape == 2:
                self.lycut.plot(
                    x[ypos, :],
                    zz[ypos, :],
                    label=str(y[ypos, :][0]),
                    color=lc_color,
                    linewidth=lc_lw,
                )
            self._format_cut(
                self.lycut, x_nticks, z_nticks, lc_dticks, xlabel, zlabel, xmin, xmax
            )

        # dcut plot => equivalent to ycut plot for (zz0, x0, y0), i.e. the
        # emitted-energy map of rd, independently of the x/y/zz arguments
        if lcuts and dcut is not None:
            x0, y0, zz0 = self.rd.ene_in, self.rd.ene_out, self.rd.rixs_map
            ypos0 = self._nearest_index(y0, dcut, axis=0)
            xyshape0 = self._xy_ndim(x0, y0)
            if xyshape0 == 1:
                self.ldcut.plot(
                    x0,
                    zz0[ypos0, :],
                    label=str(y0[ypos0]),
                    color=lc_color,
                    linewidth=lc_lw,
                )
            elif xyshape0 == 2:
                self.ldcut.plot(
                    x0[ypos0, :],
                    zz0[ypos0, :],
                    label=str(y0[ypos0, :][0]),
                    color=lc_color,
                    linewidth=lc_lw,
                )
            self._format_cut(
                self.ldcut, x_nticks, z_nticks, lc_dticks, xlabel, zlabel, xmin, xmax
            )
        plt.draw()
        plt.show()

548 

549 

if __name__ == "__main__":
    # plotting helpers only: nothing is executed when run as a script
    ...