Coverage for /Users/Newville/Codes/xraylarch/larch/wxxas/xas_controller.py: 12%

375 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1import os 

2import copy 

3import time 

4import shutil 

5from glob import glob 

6import numpy as np 

7from copy import deepcopy 

8 

9import wx 

10 

11import larch 

12from larch import Group, Journal, Entry 

13from larch.larchlib import read_config, save_config 

14from larch.utils import (group2dict, unique_name, fix_varname, get_cwd, 

15 asfloat, get_sessionid, mkdir) 

16from larch.wxlib.plotter import last_cursor_pos 

17from larch.wxlib import ExceptionPopup 

18from larch.io import fix_varname, save_session 

19from larch.site_config import home_dir, user_larchdir 

20 

21from .config import XASCONF, CONF_FILE, OLDCONF_FILE 

22 

class XASController():
    """
    class holding the Larch session and doing the processing work for XAS GUI
    """
    def __init__(self, wxparent=None, _larch=None):
        """
        Parameters
        ----------
        wxparent : parent wx window or None; its `statusbar` is used by
                   write_message().
        _larch :   larch Interpreter to use; a new one is created if None.
        """
        self.wxparent = wxparent
        self.filelist = None
        self.group = None
        self.groupname = None
        self.plot_erange = None
        self.report_frame = None
        self.recentfiles = []
        self.larch = _larch
        if _larch is None:
            self.larch = larch.Interpreter()
        self.larix_folder = os.path.join(user_larchdir, 'larix')
        self.config_file = os.path.join(self.larix_folder, CONF_FILE)
        self.init_larch_session()
        self.init_workdir()

    def init_larch_session(self):
        """(re)initialize the session: file-group mapping and configuration.

        Migrates an old 'xas_viewer' folder and config file to the 'larix'
        names if needed, creates the larix folder, merges the user config
        file over the XASCONF defaults, publishes the result to the larch
        symbol table, and prunes old autosave files.
        """
        self.symtable = self.larch.symtable
        self.file_groups = self.symtable._xasgroups = {}

        # deep copy: merging user settings below must not mutate the
        # nested dicts of the module-level XASCONF defaults
        config = deepcopy(XASCONF)

        # may migrate old 'xas_viewer' folder to 'larix' folder
        xasv_folder = os.path.join(user_larchdir, 'xas_viewer')
        if (os.path.exists(xasv_folder) and not os.path.exists(self.larix_folder)):
            print("Migrating xas_viewer to larix folder")
            shutil.move(xasv_folder, self.larix_folder)

        if not os.path.exists(self.larix_folder):
            try:
                mkdir(self.larix_folder)
            except Exception:
                title = "Cannot create Larix folder"
                message = [f"Cannot create directory {self.larix_folder}"]
                # NOTE(review): the popup parent should be a wx window (or
                # None), not this controller object
                ExceptionPopup(self.wxparent, title, message)

        # may migrate old 'xas_viewer.conf' file to 'larix.conf'
        old_config_file = os.path.join(self.larix_folder, OLDCONF_FILE)
        if (os.path.exists(old_config_file)
                and not os.path.exists(self.config_file)):
            shutil.move(old_config_file, self.config_file)

        if os.path.exists(self.config_file):
            user_config = read_config(self.config_file)
            if user_config is not None:
                for sname in config:
                    if sname not in user_config:
                        continue
                    val = user_config[sname]
                    # dict sections are merged key-by-key; anything else
                    # replaces the default outright
                    if isinstance(val, dict) and isinstance(config[sname], dict):
                        config[sname].update(val)
                    else:
                        config[sname] = val

        self.config = self.larch.symtable._sys.larix_config = config
        self.larch.symtable._sys.wx.plotopts = config['plot']
        self.clean_autosave_sessions()

    def install_group(self, groupname, filename, source=None, journal=None):
        """add groupname / filename to list of available data groups

        Parameters
        ----------
        groupname : name of the group in the larch symbol table; the group
                    is created if it does not exist.
        filename :  display name for the group.  Surrounding whitespace is
                    removed, then `_0`, `_1`, ... is appended if the name
                    is already used in `file_groups`.
        source :    source description for the journal (default: filename)
        journal :   dict of extra journal entries, or a list / Journal
                    whose repr() is used as the journal arguments.

        Returns
        -------
        the (possibly modified) filename under which the group is listed.

        Raises
        ------
        ValueError if no unused filename could be found.
        """
        try:
            thisgroup = getattr(self.symtable, groupname)
        except AttributeError:
            thisgroup = self.symtable.new_group(groupname)

        # strip *before* the duplicate check, so the name actually stored
        # is the one that was checked for uniqueness
        filename = filename.strip()

        # file /group may already exist in list
        if filename in self.file_groups:
            fbase, i = filename, 0
            while i < 50000 and filename in self.file_groups:
                filename = f"{fbase}_{i}"
                i += 1
            if filename in self.file_groups:
                raise ValueError(f"Too many repeated filenames: {fbase}")

        if source is None:
            source = filename

        # repr() produces valid larch string literals even when values
        # contain quotes or backslashes (e.g. Windows paths)
        jopts = f"source={str(source)!r}"
        if isinstance(journal, dict):
            jnl = {'source': f"{source}"}
            jnl.update(journal)
            jopts = ', '.join(f"{k}={str(v)!r}" for k, v in jnl.items())
        elif isinstance(journal, (list, Journal)):
            jopts = repr(journal)

        cmds = [f"{groupname:s}.groupname = '{groupname:s}'",
                f"{groupname:s}.filename = {filename!r}"]
        needs_config = not hasattr(thisgroup, 'config')
        if needs_config:
            cmds.append(f"{groupname:s}.config = group(__name__='larix config')")

        cmds.append(f"{groupname:s}.journal = journal({jopts:s})")

        datatype = getattr(thisgroup, 'datatype', 'raw')
        if datatype == 'xas':
            cmds.append(f"{groupname:s}.energy_orig = {groupname:s}.energy[:]")
            array_labels = getattr(thisgroup, 'array_labels', [])
            if len(array_labels) > 2 and getattr(thisgroup, 'data', None) is not None:
                # when several candidate labels exist, the last one wins
                for i0name in ('i0', 'i_0', 'monitor'):
                    if i0name in array_labels:
                        i0x = array_labels.index(i0name)
                        cmds.append(f"{groupname:s}.i0 = {groupname:s}.data[{i0x}, :]")

        self.larch.eval('\n'.join(cmds))

        if needs_config:
            self.init_group_config(thisgroup)

        self.file_groups[filename] = groupname
        self.filelist.Append(filename)
        self.filelist.SetStringSelection(filename)
        self.sync_xasgroups()
        return filename

    def sync_xasgroups(self):
        "make sure `_xasgroups` is identical to file_groups"
        if self.file_groups != self.symtable._xasgroups:
            self.symtable._xasgroups = self.file_groups

    def get_config(self, key, default=None):
        "get top-level, program-wide configuration setting (as a deep copy)"
        if key not in self.config:
            return default
        return deepcopy(self.config[key])

    def init_group_config(self, dgroup):
        """set up 'config' group with values from self.config

        Creates `dgroup.config` if missing, then sets a deep copy of each
        analysis section of the program configuration on it.
        """
        if not hasattr(dgroup, 'config'):
            dgroup.config = larch.Group(__name__='larix config')

        for sect in ('exafs', 'feffit', 'lincombo', 'pca', 'prepeaks',
                     'regression', 'xasnorm'):
            setattr(dgroup.config, sect, deepcopy(self.config[sect]))

    def get_plot_conf(self):
        """get basic plot options to pass to plot() ** not window sizes **"""
        dx = {'linewidth': 3, 'markersize': 4,
              'show_grid': True, 'show_fullbox': True, 'theme': 'light'}
        pconf = self.config['plot']
        return {attr: pconf.get(attr, val) for attr, val in dx.items()}

    def save_config(self):
        """save configuration to the larix config file"""
        save_config(self.config_file, self.config)

    def chdir_on_fileopen(self):
        "return the 'main.chdir_on_fileopen' configuration setting"
        return self.config['main']['chdir_on_fileopen']

    def set_workdir(self):
        "record the current working directory as 'main.workdir'"
        self.config['main']['workdir'] = get_cwd()

    def save_workdir(self):
        """save last workdir and recent session files

        Writes the current directory to 'workdir.txt' and up to 10 of the
        most recent (timestamp, filename) entries, newest first and without
        duplicate filenames, to 'recent_sessions.txt' in the larix folder.
        Writing is best effort: failures are ignored.
        """
        try:
            with open(os.path.join(self.larix_folder, 'workdir.txt'), 'w') as fh:
                fh.write("%s\n" % get_cwd())
        except Exception:
            pass

        buffer = []
        rfiles = []
        for tstamp, fname in sorted(self.recentfiles, key=lambda x: x[0], reverse=True)[:10]:
            if fname not in rfiles:
                buffer.append(f"{tstamp:.1f} {fname:s}")
                rfiles.append(fname)
        buffer.append('')
        buffer = '\n'.join(buffer)

        try:
            with open(os.path.join(self.larix_folder, 'recent_sessions.txt'), 'w') as fh:
                fh.write(buffer)
        except Exception:
            pass

    def init_workdir(self):
        """set initial working folder, read recent session files

        If 'main.use_last_workdir' is set, 'main.workdir' is taken from the
        first line of 'workdir.txt'.  Then changes to 'main.workdir' (best
        effort) and prepends each "<timestamp> <filename>" line of
        'recent_sessions.txt' to `self.recentfiles`.
        """
        if self.config['main'].get('use_last_workdir', False):
            wfile = os.path.join(self.larix_folder, 'workdir.txt')
            if os.path.exists(wfile):
                try:
                    with open(wfile, 'r') as fh:
                        # rstrip, not [:-1]: the last line may lack '\n'
                        workdir = fh.readlines()[0].rstrip('\r\n')
                    self.config['main']['workdir'] = workdir
                except Exception:
                    pass
        try:
            os.chdir(self.config['main']['workdir'])
        except Exception:
            pass

        rfile = os.path.join(self.larix_folder, 'recent_sessions.txt')
        if os.path.exists(rfile):
            with open(rfile, 'r') as fh:
                for line in fh.readlines():
                    if len(line) < 2 or line.startswith('#'):
                        continue
                    try:
                        w = line.rstrip('\r\n').split(None, maxsplit=1)
                        self.recentfiles.insert(0, (float(w[0]), w[1]))
                    except (ValueError, IndexError):
                        pass

    def autosave_session(self):
        """save the session to '<fileroot>_<sessionid>.larix' in the larix folder

        Existing copies are first rotated: '_i.larix' -> '_{i+1}.larix'
        (for i = nhistory-1 down to 1), then the current file -> '_1.larix'.
        Returns the path of the saved file.
        """
        conf = self.get_config('autosave', {})
        fileroot = conf.get('fileroot', 'autosave')
        # NOTE(review): max() makes 8 the minimum, so the default of 4
        # never takes effect -- confirm this is intended
        nhistory = max(8, int(conf.get('nhistory', 4)))

        fname = f"{fileroot:s}_{get_sessionid():s}.larix"
        savefile = os.path.join(self.larix_folder, fname)
        for i in reversed(range(1, nhistory)):
            curf = savefile.replace('.larix', f'_{i:d}.larix')
            if os.path.exists(curf):
                newf = savefile.replace('.larix', f'_{i+1:d}.larix')
                shutil.move(curf, newf)
        if os.path.exists(savefile):
            curf = savefile.replace('.larix', '_1.larix')
            shutil.move(savefile, curf)
        save_session(savefile, _larch=self.larch)
        return savefile

    def clean_autosave_sessions(self):
        """remove old autosave files from the larix folder

        Only files named '<fileroot>*.larix' are considered, so other
        session files saved in the folder are never deleted.  While more
        than 'maxfiles' remain, the oldest rotated copies (numeric
        '_N' suffix) are removed first.  Then un-numbered files older than
        one day are removed.
        """
        conf = self.get_config('autosave', {})
        fileroot = conf.get('fileroot', 'autosave')
        max_hist = int(conf.get('maxfiles', 10))

        def get_autosavefiles():
            "return (path, version, mtime) for autosave files, oldest first"
            dat = []
            if not os.path.isdir(self.larix_folder):
                return dat
            for afile in os.listdir(self.larix_folder):
                if not (afile.startswith(fileroot) and afile.endswith('.larix')):
                    continue
                ffile = os.path.join(self.larix_folder, afile)
                mtime = os.stat(ffile).st_mtime
                words = afile.replace('.larix', '').split('_')
                try:
                    version = int(words[-1])
                except ValueError:
                    version = 0
                dat.append((ffile, version, mtime))
            return sorted(dat, key=lambda x: x[2])

        def remove(fname):
            "delete a file, returning whether it was removed"
            try:
                os.unlink(fname)
            except OSError:
                return False
            return True

        dat = get_autosavefiles()
        nremove = max(0, len(dat) - max_hist)
        # first remove oldest "version > 0" files
        while nremove > 0 and len(dat) > 0:
            dfile, version, mtime = dat.pop(0)
            if version > 0 and remove(dfile):
                nremove -= 1

        dat = get_autosavefiles()
        nremove = max(0, len(dat) - max_hist)
        # then remove the oldest "version 0" files, if older than one day
        while nremove > 0 and len(dat) > 0:
            dfile, vers, mtime = dat.pop(0)
            if vers == 0 and abs(mtime - time.time()) > 86400 and remove(dfile):
                nremove -= 1

    def get_recentfiles(self, max=10):
        "return up to `max` (timestamp, filename) entries, newest first"
        return sorted(self.recentfiles, key=lambda x: x[0], reverse=True)[:max]

    def recent_autosave_sessions(self):
        "return list of (timestamp, name) for most recent autosave session files"
        conf = self.get_config('autosave', {})
        fileroot = conf.get('fileroot', 'autosave')
        max_hist = int(conf.get('maxfiles', 10))
        flist = []
        if not os.path.isdir(self.larix_folder):
            return flist
        for afile in os.listdir(self.larix_folder):
            if afile.startswith(fileroot) and afile.endswith('.larix'):
                ffile = os.path.join(self.larix_folder, afile)
                flist.append((os.stat(ffile).st_mtime, ffile))

        return sorted(flist, key=lambda x: x[0], reverse=True)[:max_hist]

    def clear_session(self):
        "clear the larch session and the file list, then re-initialize"
        self.larch.eval("clear_session()")
        self.filelist.Clear()
        self.init_larch_session()

    def write_message(self, msg, panel=0):
        """write a message to the Status Bar"""
        self.wxparent.statusbar.SetStatusText(msg, panel)

    def close_all_displays(self):
        "close all displays, as at exit"
        self.symtable._plotter.close_all_displays()

    def get_display(self, win=1, stacked=False):
        "get plot display `win`, using the 'plot' configuration options"
        wintitle = 'Larch XAS Plot Window %i' % win

        conf = self.get_config('plot', {})
        opts = dict(wintitle=wintitle, stacked=stacked, win=win)
        opts.update(conf or {})
        return self.symtable._plotter.get_display(**opts)

    def set_focus(self, topwin=None):
        """
        set wx focus to main window or selected Window,
        even after plot
        """
        if topwin is None:
            topwin = wx.GetApp().GetTopWindow()
            flist = self.filelist
        else:
            flist = getattr(topwin, 'filelist', topwin)
        time.sleep(0.025)
        topwin.Raise()
        flist.SetFocus()

    def get_group(self, groupname=None):
        """return a data group, by group name or by displayed filename

        If `groupname` is None, `self.groupname` is used (None is returned
        if that is also None).  If no group matches, falls back to the
        group of the first entry of `file_groups`.
        """
        if groupname is None:
            groupname = self.groupname
            if groupname is None:
                return None
        dgroup = getattr(self.symtable, groupname, None)
        if dgroup is None and groupname in self.file_groups:
            groupname = self.file_groups[groupname]
            dgroup = getattr(self.symtable, groupname, None)

        if dgroup is None and len(self.file_groups) > 0:
            # file_groups maps filename -> groupname: use the group *name*
            gname = next(iter(self.file_groups.values()))
            dgroup = getattr(self.symtable, gname, None)
        return dgroup

    def filename2group(self, filename):
        "convert filename (as displayed) to larch group"
        return self.get_group(self.file_groups[str(filename)])

    def merge_groups(self, grouplist, master=None, yarray='mu', outgroup=None):
        """merge groups

        Parameters
        ----------
        grouplist : list of group names to merge
        master :    group name whose normalization config is copied to the
                    result (default: first of grouplist)
        yarray :    name of the array to merge (default 'mu')
        outgroup :  name for the merged group (default 'merged'), made
                    unique against `file_groups`.

        Returns the merged group.
        """
        cmd = """%s = merge_groups(%s, master=%s,
        xarray='energy', yarray='%s', kind='cubic', trim=True)
        """
        glist = "[%s]" % (', '.join(grouplist))
        # apply the default *before* lower(): None has no .lower()
        if outgroup is None:
            outgroup = 'merged'
        outgroup = fix_varname(outgroup.lower())

        outgroup = unique_name(outgroup, self.file_groups, max=1000)

        cmd = cmd % (outgroup, glist, master, yarray)
        self.larch.eval(cmd)

        if master is None:
            master = grouplist[0]
        this = self.get_group(outgroup)
        master = self.get_group(master)
        if not hasattr(master, 'config'):
            self.init_group_config(master)
        if not hasattr(this, 'config'):
            self.init_group_config(this)
        this.config.xasnorm.update(master.config.xasnorm)
        this.datatype = master.datatype
        this.xdat = 1.0*this.energy
        this.ydat = 1.0*getattr(this, yarray)
        this.yerr = getattr(this, 'd' + yarray, 1.0)
        if yarray != 'mu':
            this.mu = this.ydat
        this.plot_xlabel = 'energy'
        this.plot_ylabel = yarray
        return this

    def set_plot_erange(self, erange):
        "set the energy range used for plots"
        self.plot_erange = erange

    def copy_group(self, filename, new_filename=None):
        """copy XAS group (by filename) to new group

        Deep-copies every non-None attribute, gives the copy a filename
        (default: filename + '_1') and group name that are unique within
        `file_groups`, and adds it to the larch symbol table.  Returns the
        new group, or None if the source group is not in the symbol table.
        """
        groupname = self.file_groups[filename]
        if not hasattr(self.larch.symtable, groupname):
            return

        ogroup = self.get_group(groupname)
        ngroup = larch.Group(datatype=ogroup.datatype, copied_from=groupname)

        for attr in dir(ogroup):
            val = getattr(ogroup, attr, None)
            if val is not None:
                setattr(ngroup, attr, copy.deepcopy(val))

        if new_filename is None:
            new_filename = filename + '_1'
        ngroup.filename = unique_name(new_filename, self.file_groups.keys())
        ngroup.groupname = unique_name(groupname, self.file_groups.values())
        ngroup.journal.add('source_desc', f"copied from '{filename:s}'")
        setattr(self.larch.symtable, ngroup.groupname, ngroup)
        return ngroup

    def get_cursor(self, win=None):
        """get last cursor from selected window"""
        return last_cursor_pos(win=win, _larch=self.larch)

    def plot_group(self, groupname=None, title=None, plot_yarrays=None,
                   new=True, **kws):
        """plot a data group in XAS plot window 1

        Parameters
        ----------
        groupname :    group (or displayed filename) to plot
        title :        plot title (default: base name of the group filename,
                       for new plots)
        plot_yarrays : list of (arrayname, options, label) tuples to plot
                       against `xdat` (default: the group's `plot_yarrays`)
        new :          start a new plot (True) or overplot (False)
        kws :          extra options passed to the plot commands

        Prints a message and returns without plotting if the group, its
        `xdat`, or the list of arrays to plot is missing.
        """
        ppanel = self.get_display(stacked=False).panel
        newplot = ppanel.plot
        oplot = ppanel.oplot
        plotcmd = newplot if new else oplot

        dgroup = self.get_group(groupname)
        if dgroup is None:
            print("Cannot plot group ", groupname)
            return

        # process groups that lack the arrays needed for plotting
        if ((getattr(dgroup, 'plot_yarrays', None) is None or
             getattr(dgroup, 'energy', None) is None or
             getattr(dgroup, 'mu', None) is None)):
            self.process(dgroup)

        if plot_yarrays is None:
            plot_yarrays = getattr(dgroup, 'plot_yarrays', None)
        if not hasattr(dgroup, 'xdat') or plot_yarrays is None:
            print("Cannot plot group ", groupname)
            return

        popts = kws
        fname = os.path.basename(dgroup.filename)
        if 'label' not in popts:
            popts['label'] = dgroup.plot_ylabel

        popts['xlabel'] = dgroup.plot_xlabel
        popts['ylabel'] = dgroup.plot_ylabel
        if getattr(dgroup, 'plot_y2label', None) is not None:
            popts['y2label'] = dgroup.plot_y2label

        plot_extras = None
        if new:
            if title is None:
                title = fname
            plot_extras = getattr(dgroup, 'plot_extras', None)

        popts['title'] = title

        narr = len(plot_yarrays) - 1
        for i, (yaname, yopts, yalabel) in enumerate(plot_yarrays):
            popts.update(yopts)
            if yalabel is not None:
                popts['label'] = yalabel
            # delay drawing until the last trace
            popts['delay_draw'] = (i != narr)

            plotcmd(dgroup.xdat, getattr(dgroup, yaname), **popts)
            plotcmd = oplot

        if plot_extras is not None:
            axes = ppanel.axes
            for etype, x, y, opts in plot_extras:
                if etype == 'marker':
                    popts = {'marker': 'o', 'markersize': 4,
                             'label': '_nolegend_',
                             'markerfacecolor': 'red',
                             'markeredgecolor': '#884444'}
                    popts.update(opts)
                    axes.plot([x], [y], **popts)
                elif etype == 'vline':
                    popts = {'ymin': 0, 'ymax': 1.0,
                             'color': '#888888'}
                    popts.update(opts)
                    axes.axvline(x, **popts)
            ppanel.canvas.draw()
        self.set_focus()