Coverage for /Users/Newville/Codes/xraylarch/larch/wxxas/pca_panel.py: 14%

280 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2""" 

3Principal Component Analysis (PCA) panel 

4""" 

5import os 

6import time 

7import wx 

8import wx.lib.scrolledpanel as scrolled 

9import wx.dataview as dv 

10 

11import numpy as np 

12 

13from functools import partial 

14 

15from larch import Group 

16from larch.math import index_of 

17from larch.math.lincombo_fitting import get_arrays 

18from larch.utils import get_cwd, gformat 

19 

20from larch.xafs import etok, ktoe 

21from larch.wxlib import (BitmapButton, FloatCtrl, get_icon, SimpleText, 

22 pack, Button, HLine, Choice, Check, CEN, RIGHT, 

23 LEFT, Font, FileSave, FileOpen, DataTableGrid) 

24 

25from .taskpanel import TaskPanel, autoset_fs_increment 

26from .config import Linear_ArrayChoices 

27 

# Tell numpy to ignore every class of floating-point error
# (divide, overflow, underflow, invalid) instead of warning.
np.seterr(all='ignore')

# Combined wx.dataview style flags.
DVSTYLE = dv.DV_SINGLE | dv.DV_VERT_RULES | dv.DV_ROW_LINES

# Maximum number of *reported* PCA weights after a fit; also used as the
# row count of the model table and the number of 'Comp N' fit-table columns.
MAX_COMPS = 30

34 

class PCAPanel(TaskPanel):
    """Task panel for Principal Component Analysis (PCA).

    The panel builds, saves and loads a PCA model that lives in the Larch
    symbol table as ``pca_result``, shows per-component statistics in the
    ``pca_table`` grid, and tests the current or the checked groups against
    the model, listing the results in the ``fit_table`` grid.
    """

    def __init__(self, parent, controller, **kws):
        """Create the panel; ``self.result`` is None until a model is in use."""
        TaskPanel.__init__(self, parent, controller, panel='pca', **kws)
        self.result = None

    def process(self, dgroup, **kws):
        """Handle PCA processing for ``dgroup``.

        Nothing is computed here: unless processing is being skipped, the
        form is read and the method returns.
        """
        if self.skip_process:
            return
        # NOTE(review): the returned form is not used; the call is kept in
        # case read_form() has side effects -- confirm and drop if not.
        self.read_form()

    def build_display(self):
        """Create all widgets of the panel and lay them out.

        ``skip_process`` is set to True while the widgets are created and
        reset to False once the layout is packed.
        """
        panel = self.panel
        wids = self.wids
        add_text = self.add_text
        self.skip_process = True

        # --- controls -----------------------------------------------------
        wids['fitspace'] = Choice(panel,
                                  choices=list(Linear_ArrayChoices.keys()),
                                  action=self.onFitSpace, size=(175, -1))
        wids['fitspace'].SetSelection(0)

        defaults = self.get_defaultconfig()
        self.make_fit_xspace_widgets(elo=defaults['elo_rel'],
                                     ehi=defaults['ehi_rel'])

        w_wmin = self.add_floatspin('weight_min', digits=4,
                                    value=defaults['weight_min'],
                                    increment=0.0005, with_pin=False,
                                    min_val=0, max_val=0.5,
                                    action=self.onSet_WeightMin)

        self.wids['weight_auto'] = Check(panel, default=True, label='auto?')

        w_mcomps = self.add_floatspin('max_components', digits=0,
                                      value=defaults['max_components'],
                                      increment=1, with_pin=False, min_val=0)

        wids['build_model'] = Button(panel, 'Build Model With Selected Groups',
                                     size=(250, -1),
                                     action=self.onBuildPCAModel)

        wids['plot_model'] = Button(panel, 'Plot Components and Statistics',
                                    size=(250, -1),
                                    action=self.onPlotPCAModel)

        wids['fit_group'] = Button(panel, 'Test Current Group with Model',
                                   size=(250, -1), action=self.onFitGroup)
        wids['fit_selected'] = Button(panel, 'Test Selected Groups with Model',
                                      size=(250, -1),
                                      action=self.onFitSelected)

        wids['save_model'] = Button(panel, 'Save PCA Model', size=(125, -1),
                                    action=self.onSavePCAModel)
        wids['load_model'] = Button(panel, 'Load PCA Model', size=(125, -1),
                                    action=self.onLoadPCAModel)

        # fitting and saving need a model; use_model() enables these later
        wids['fit_group'].Disable()
        wids['fit_selected'].Disable()
        wids['load_model'].Enable()
        wids['save_model'].Disable()

        # --- table of model statistics: one row per component -------------
        collabels = [' Variance ', ' IND value ', 'IND/IND_Best']
        colsizes = [125, 125, 125]
        coltypes = ['float:12,6', 'float:12,6', 'float:12,5']
        coldefs = [0.0, 0.0, 1.0]

        wids['pca_table'] = DataTableGrid(panel, nrows=MAX_COMPS,
                                          collabels=collabels,
                                          datatypes=coltypes,
                                          defaults=coldefs,
                                          colsizes=colsizes, rowlabelsize=60)

        wids['pca_table'].SetMinSize((500, 150))
        wids['pca_table'].EnableEditing(False)

        # --- table of fit results: one row per fitted group ---------------
        # three fixed columns followed by one column per reported component
        collabels = [' Group ', ' Chi-square ', ' Scale ']
        collabels.extend(f'Comp {i+1:d}' for i in range(MAX_COMPS))
        colsizes = [200, 80, 80] + [80]*MAX_COMPS
        coltypes = ['string']*len(collabels)
        coldefs = [' ', '0.0', '1.0'] + ['0.0']*MAX_COMPS

        wids['fit_table'] = DataTableGrid(panel, nrows=50,
                                          collabels=collabels,
                                          datatypes=coltypes,
                                          defaults=coldefs,
                                          colsizes=colsizes, rowlabelsize=60)

        wids['fit_table'].SetMinSize((700, 200))
        wids['fit_table'].EnableEditing(False)

        wids['status'] = SimpleText(panel, ' ')

        # --- layout -------------------------------------------------------
        panel.Add(SimpleText(panel, 'Principal Component Analysis',
                             size=(350, -1), **self.titleopts),
                  style=LEFT, dcol=4)

        add_text('Array to Use: ', newrow=True)
        panel.Add(wids['fitspace'], dcol=2)
        panel.Add(wids['fitspace_label'], newrow=True)
        panel.Add(self.elo_wids)
        add_text(' : ', newrow=False)
        panel.Add(self.ehi_wids)

        panel.Add(wids['load_model'], dcol=1, newrow=True)
        panel.Add(wids['save_model'], dcol=1)

        panel.Add(wids['build_model'], dcol=3, newrow=True)
        panel.Add(wids['plot_model'], dcol=2)

        add_text('Min Weight: ')
        panel.Add(w_wmin)
        panel.Add(wids['weight_auto'], dcol=2)
        panel.Add(wids['pca_table'], dcol=6, newrow=True)

        add_text('Status: ')
        panel.Add(wids['status'], dcol=6)

        panel.Add(HLine(panel, size=(550, 2)), dcol=5, newrow=True)

        add_text('Use this PCA Model : ', dcol=1, newrow=True)
        add_text('Max Components:', dcol=1, newrow=False)
        panel.Add(w_mcomps, dcol=2)

        panel.Add(wids['fit_group'], dcol=3, newrow=True)
        panel.Add(wids['fit_selected'], dcol=3, newrow=False)

        panel.Add(wids['fit_table'], dcol=6, newrow=True)

        panel.pack()
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add((10, 10), 0, LEFT, 3)
        sizer.Add(panel, 1, LEFT, 3)
        pack(self, sizer)
        self.skip_process = False

    def onSet_WeightMin(self, evt=None, value=None):
        """Handle a change of the minimum-weight control.

        Unchecks the 'auto?' box and stores the new value in the config.
        """
        wmin = self.wids['weight_min'].GetValue()
        self.wids['weight_auto'].SetValue(0)
        self.update_config({'weight_min': wmin})

    def fill_form(self, dgroup):
        """Set the form widgets from the stored configuration of ``dgroup``."""
        opts = self.get_config(dgroup, with_erange=True)
        self.dgroup = dgroup

        wids = self.wids
        self.skip_process = True
        try:
            for attr in ('elo', 'ehi', 'weight_min'):
                val = opts.get(attr, None)
                if val is not None:
                    wids[attr].SetValue(val)

            for attr in ('fitspace',):
                if attr in opts:
                    wids[attr].SetStringSelection(opts[attr])
        finally:
            # always re-enable processing, even if a widget update failed
            self.skip_process = False

    def onPanelExposed(self, **kws):
        """Refresh the panel when its notebook page is selected."""
        fname = self.controller.filelist.GetStringSelection()
        if fname in self.controller.file_groups:
            gname = self.controller.file_groups[fname]
            dgroup = self.controller.get_group(gname)
            self.ensure_xas_processed(dgroup)
            self.fill_form(dgroup)
            self.process(dgroup=dgroup)

        # NOTE(review): the nesting of this check relative to the block above
        # could not be confirmed from the reviewed listing -- verify that it
        # is meant to run even when no file is selected.
        if hasattr(self.larch.symtable, 'pca_result'):
            self.use_model(plot=False)

    def onCopyParam(self, name=None, evt=None):
        """Copy the current PCA settings to the config of every checked group."""
        conf = self.get_config()
        conf.update(self.read_form())
        attrs = ('elo', 'ehi', 'weight_min',
                 'max_components', 'fitspace')

        out = {a: conf[a] for a in attrs}
        for checked in self.controller.filelist.GetCheckedStrings():
            groupname = self.controller.file_groups[str(checked)]
            dgroup = self.controller.get_group(groupname)
            self.update_config(out, dgroup=dgroup)

    def plot_pca_weights(self, win=2):
        """Plot the weights of the current model in plot window ``win``."""
        if self.result is None or self.skip_plotting:
            return
        self.larch_eval(f"plot_pca_weights(pca_result, win={win:d})")
        self.controller.set_focus()

    def plot_pca_components(self, win=1):
        """Plot the components of the current model in plot window ``win``."""
        if self.result is None or self.skip_plotting:
            return
        self.larch_eval(f"plot_pca_components(pca_result, win={win:d})")
        self.controller.set_focus()

    def plot_pca_fit(self, win=1):
        """Plot the PCA fit of the current group, if it has one."""
        if self.result is None or self.skip_plotting:
            return
        dgroup = self.controller.get_group()
        if hasattr(dgroup, 'pca_result'):
            self.larch_eval(f"plot_pca_fit({dgroup.groupname:s}, with_components=True, win={win:d})")
        self.controller.set_focus()

    def onPlot(self, event=None):
        """Plot the fit for the current group unless plotting is disabled."""
        # NOTE(review): the returned form is not used; the call is kept in
        # case read_form() has side effects -- confirm and drop if not.
        self.read_form()
        if not self.skip_plotting:
            self.plot_pca_fit()

    def onFitSpace(self, evt=None):
        """Store the selected array choice and update the fit range widgets."""
        fitspace = self.wids['fitspace'].GetStringSelection()
        self.update_config(dict(fitspace=fitspace))
        arrname = Linear_ArrayChoices.get(fitspace, 'norm')
        self.update_fit_xspace(arrname)

    @staticmethod
    def _fit_table_row(grp):
        """Return the fit-table row for ``grp``: filename, chi-square, scale, weights."""
        res = grp.pca_result
        row = [grp.filename, gformat(res.chi_square), gformat(res.data_scale)]
        row.extend(gformat(w) for w in res.weights)
        return row

    def _show_fit_rows(self, new_rows):
        """Put ``new_rows`` first in the fit table, replacing older rows for the same files."""
        table = self.wids['fit_table'].table
        fnames = {row[0] for row in new_rows}
        rows = list(new_rows)
        # keep previous results, except those superseded by a new row
        rows.extend(row for row in table.data
                    if len(row) < 2 or row[0] not in fnames)
        table.data = rows
        table.View.Refresh()

    def onFitSelected(self, event=None):
        """Fit every checked group with the current model and list the results."""
        form = self.read_form()
        if self.result is None:
            print("need result first!")
            return  # nothing to fit against
        ncomps = int(form['max_components'])

        new_rows = []
        for checked in self.controller.filelist.GetCheckedStrings():
            gname = self.controller.file_groups[checked]
            grp = self.controller.get_group(gname)
            if not hasattr(grp, 'norm'):
                self.xasmain.process_normalization(grp)
            cmd = f"pca_fit({gname:s}, pca_result, ncomps={ncomps:d})"
            self.larch_eval(cmd)
            grp.journal.add('pca_fit', cmd)
            new_rows.append(self._fit_table_row(grp))

        self._show_fit_rows(new_rows)
        self.plot_pca_fit()

    def onFitGroup(self, event=None):
        """Fit the current group with the current model and list the result."""
        form = self.read_form()
        if self.result is None:
            print("need result first!")
            return  # nothing to fit against
        ncomps = int(form['max_components'])
        gname = form['groupname']
        cmd = f"pca_fit({gname:s}, pca_result, ncomps={ncomps:d})"
        self.larch_eval(cmd)

        dgroup = self.controller.get_group()
        # NOTE(review): elsewhere in this class the config is stored with
        # self.update_config(..., dgroup=...); confirm that the group object
        # really provides update_config().
        dgroup.update_config(form)
        dgroup.journal.add('pca_fit', cmd)

        self._show_fit_rows([self._fit_table_row(dgroup)])
        self.plot_pca_fit()

    def onBuildPCAModel(self, event=None):
        """Train a PCA model from the checked groups and put it in use."""
        self.wids['status'].SetLabel(" training model...")
        form = self.read_form()
        selected_groups = self.controller.filelist.GetCheckedStrings()
        groups = [self.controller.file_groups[cn] for cn in selected_groups]
        for gname in groups:
            grp = self.controller.get_group(gname)
            if not hasattr(grp, 'norm'):
                self.xasmain.process_normalization(grp)

        opts = dict(groups=', '.join(groups),
                    arr=Linear_ArrayChoices.get(form['fitspace'], 'norm'),
                    elo=form['elo'], ehi=form['ehi'])

        cmd = "pca_result = pca_train([{groups}], arrayname='{arr}', xmin={elo:.2f}, xmax={ehi:.2f})"

        self.larch_eval(cmd.format(**opts))
        self.use_model('pca_result')

    def use_model(self, modelname='pca_result', plot=True):
        """Make the model named ``modelname`` the active one and show it.

        The model is fetched with ``larch_get`` and stored as ``self.result``.
        The status line, the 'max_components' control and the model table are
        updated, the fit table is cleared, and the fit/save buttons are
        enabled.  With the 'auto?' box checked, the minimum weight is derived
        from the model; otherwise the number of significant components is
        counted from the minimum weight in the form.

        Parameters
        ----------
        modelname : str
            name of the model passed to ``larch_get``.
        plot : bool
            whether to plot the components and weights afterwards.
        """
        form = self.read_form()
        r = self.result = self.larch_get(modelname)
        ncomps = len(r.components)
        wmin = form['weight_min']
        if self.wids['weight_auto'].GetValue():
            nsig = int(r.nsig)
            wmin = r.variances[nsig-1]
            # use the midpoint to the next variance only when there is a
            # next one: index nsig is out of range when nsig == len(variances)
            if nsig < len(r.variances):
                wmin = (r.variances[nsig] + r.variances[nsig-1])/2.0
            self.wids['weight_min'].SetValue(wmin)
        else:
            nsig = len(np.where(r.variances > wmin)[0])

        status = " Model built, %d of %d components have weight > %.4f"
        self.wids['status'].SetLabel(status % (nsig, ncomps, wmin))
        self.wids['max_components'].SetValue(nsig+1)

        for b in ('fit_group', 'fit_selected', 'save_model'):
            self.wids[b].Enable()

        # pad the IND values so that the shifted lookups below stay in range
        ind = list(r.ind)
        ind.extend([0, 0, 0])
        ind_best = ind[nsig]
        grid_data = [[var, ind[i+1], ind[i+1]/ind_best]
                     for i, var in enumerate(r.variances)]
        self.wids['pca_table'].table.data = grid_data
        self.wids['pca_table'].table.View.Refresh()

        # results of fits with a previous model no longer apply
        self.wids['fit_table'].table.data = []
        self.wids['fit_table'].table.Clear()
        self.wids['fit_table'].table.View.Refresh()
        if plot:
            self.plot_pca_components()
            self.plot_pca_weights()

    def onPlotPCAModel(self, event=None):
        """Plot the components and the weights of the current model."""
        self.plot_pca_components()
        self.plot_pca_weights()

    def onSavePCAModel(self, event=None):
        """Ask for a file name and save the current model to it."""
        # NOTE(review): the returned form is not used; the call is kept in
        # case read_form() has side effects -- confirm and drop if not.
        self.read_form()
        if self.result is None:
            print("need result first!")
            return
        wildcard = 'Larch PCA Model (*.pcamod)|*.pcamod|All files (*.*)|*.*'
        fname = time.strftime('%Y%b%d_%H%M.pcamod')
        path = FileSave(self, message='Save PCA Model',
                        wildcard=wildcard,
                        default_file=fname)
        if path is not None:
            # !r emits a correctly escaped string literal, so backslashes
            # and quotes in the path cannot break the command
            self.larch.eval(f"save_pca_model(pca_result, {path!r})")
            self.write_message("Saved PCA Model to '%s'" % path, 0)

    def onLoadPCAModel(self, event=None):
        """Ask for a model file, load it as ``pca_result`` and put it in use.

        An already existing ``pca_result`` is first copied to
        ``old_pca_result``.
        """
        # NOTE(review): the returned form is not used; the call is kept in
        # case read_form() has side effects -- confirm and drop if not.
        self.read_form()
        wildcard = 'Larch PCA Model (*.pcamod)|*.pcamod|All files (*.*)|*.*'
        path = FileOpen(self, message="Read PCA Model",
                        wildcard=wildcard, default_file='a.pcamod')
        if path is None:
            return

        if hasattr(self.larch.symtable, 'pca_result'):
            self.larch.eval("old_pca_result = copy_group(pca_result)")

        # !r emits a correctly escaped string literal, so backslashes
        # and quotes in the path cannot break the command
        self.larch.eval(f"pca_result = read_pca_model({path!r})")
        self.write_message("Read PCA Model from '%s'" % path, 0)
        self.use_model()

409 self.larch.eval(f"pca_result = read_pca_model('{path:s}')") 

410 self.write_message("Read PCA Model from '%s'" % path, 0) 

411 self.use_model()