Coverage for /Users/Newville/Codes/xraylarch/larch/wxxas/regress_panel.py: 10%
418 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
2"""
3Linear Combination panel
4"""
5import os
6import sys
7import time
8import wx
9import wx.grid as wxgrid
10import numpy as np
11import pickle
12import base64
13from copy import deepcopy
14from functools import partial
16from larch import Group
17from larch.math import index_of
18from larch.wxlib import (BitmapButton, TextCtrl, FloatCtrl, get_icon,
19 SimpleText, pack, Button, HLine, Choice, Check,
20 NumericCombo, CEN, LEFT, Font, FileSave, FileOpen,
21 DataTableGrid, Popup, FONTSIZE_FW, ExceptionPopup)
22from larch.io import save_groups, read_groups, read_csv
23from larch.utils.strutils import fix_varname
24from larch.utils import get_cwd, gformat
26from .taskpanel import TaskPanel
27from .config import Linear_ArrayChoices, Regress_Choices
# Wildcard patterns handed to the file dialogs for CSV data files and for
# saved regression model files.
CSV_WILDCARDS = "CSV Files(*.csv,*.dat)|*.csv*;*.dat|All files (*.*)|*.*"
MODEL_WILDCARDS = "Regression Model Files(*.regmod,*.dat)|*.regmod*;*.dat|All files (*.*)|*.*"

# Plot options; in the visible part of this file these are referenced only
# by the commented-out 'plotchoice' widget in build_display().
Plot_Choices = ['Mean Spectrum + Active Energies',
                'Spectra Stack',
                'Predicted External Varliable']

# Number of rows requested for the data table grid.
MAX_ROWS = 1000
def make_steps(max=1, decades=8):
    """Return the step values 1.0, 0.5, 0.2, 0.1, ... down to 1e-6.

    The list starts at 1.0 and then follows a 5-2-1 pattern through six
    decades. The `max` and `decades` arguments are accepted but not used.
    """
    return [1.0] + [mult*10**(-power)
                    for power in range(1, 7)
                    for mult in (5, 2, 1)]
44class RegressionPanel(TaskPanel):
45 """Regression Panel"""
46 def __init__(self, parent, controller, **kws):
47 TaskPanel.__init__(self, parent, controller, panel='regression', **kws)
48 self.result = None
49 self.save_csvfile = 'RegressionData.csv'
50 self.save_modelfile = 'Model.regmod'
52 def process(self, dgroup, **kws):
53 """ handle processing"""
54 if self.skip_process:
55 return
56 self.skip_process = True
57 form = self.read_form()
    def build_display(self):
        """Create all widgets of the regression panel and lay them out.

        Widgets are stored in `self.wids` under the keys that `read_form`,
        `fill_form` and the event handlers look up later.
        """
        panel = self.panel
        wids = self.wids
        # suppress processing while widgets are created; cleared on the last line
        self.skip_process = True

        # which data array the regression is run on
        wids['fitspace'] = Choice(panel, choices=list(Linear_ArrayChoices.keys()),
                                  action=self.onFitSpace, size=(175, -1))
        wids['fitspace'].SetSelection(0)
        # wids['plotchoice'] = Choice(panel, choices=Plot_Choices,
        # size=(250, -1), action=self.onPlot)

        # regression method; onRegressMethod enables the matching option widgets
        wids['method'] = Choice(panel, choices=Regress_Choices, size=(250, -1),
                                action=self.onRegressMethod)
        wids['method'].SetSelection(1)
        add_text = self.add_text

        opts = dict(digits=2, increment=1.0)
        defaults = self.get_defaultconfig()

        # presumably creates wids['fitspace_label'], self.elo_wids and
        # self.ehi_wids, which are laid out below -- defined outside this view
        self.make_fit_xspace_widgets(elo=defaults['elo_rel'], ehi=defaults['ehi_rel'])

        wids['alpha'] = NumericCombo(panel, make_steps(), fmt='%.6g',
                                     default_val=0.01, width=100)

        wids['auto_scale_pls'] = Check(panel, default=True, label='auto scale?')
        wids['auto_alpha'] = Check(panel, default=False, label='auto alpha?')

        wids['fit_intercept'] = Check(panel, default=True, label='fit intercept?')

        wids['save_csv'] = Button(panel, 'Save CSV File', size=(150, -1),
                                  action=self.onSaveCSV)
        wids['load_csv'] = Button(panel, 'Load CSV File', size=(150, -1),
                                  action=self.onLoadCSV)

        # disabled here; enabled by use_regmodel() once a model is available
        wids['save_model'] = Button(panel, 'Save Model', size=(150, -1),
                                    action=self.onSaveModel)
        wids['save_model'].Disable()

        wids['load_model'] = Button(panel, 'Load Model', size=(150, -1),
                                    action=self.onLoadModel)

        wids['train_model'] = Button(panel, 'Train Model From These Data',
                                     size=(275, -1), action=self.onTrainModel)

        # disabled here; enabled by use_regmodel() and onLoadModel()
        wids['fit_group'] = Button(panel, 'Predict Variable for Selected Groups',
                                   size=(275, -1), action=self.onPredictGroups)
        wids['fit_group'].Disable()

        # NOTE(review): 'cv_folds', 'cv_repeats' and 'ncomps' are later read
        # from self.wids, so add_floatspin presumably registers them there
        w_cvfolds = self.add_floatspin('cv_folds', digits=0, with_pin=False,
                                       value=0, increment=1, min_val=-1)

        w_cvreps = self.add_floatspin('cv_repeats', digits=0, with_pin=False,
                                      value=0, increment=1, min_val=-1)

        w_ncomps = self.add_floatspin('ncomps', digits=0, with_pin=False,
                                      value=3, increment=1, min_val=1)

        # name of the group attribute that holds the external variable
        wids['varname'] = wx.TextCtrl(panel, -1, 'valence', size=(150, -1))
        wids['stat1'] = SimpleText(panel, ' - - - ')
        wids['stat2'] = SimpleText(panel, ' - - - ')

        # column definitions for the data table
        collabels = [' File Group Name ', 'External Value',
                     'Predicted Value', 'Training?']
        colsizes = [325, 110, 110, 90]
        coltypes = ['str', 'float:12,4', 'float:12,4', 'str']
        coldefs = ['', 0.0, 0.0, '']

        self.font_fixedwidth = wx.Font(FONTSIZE_FW, wx.MODERN, wx.NORMAL, wx.BOLD)

        wids['table'] = DataTableGrid(panel, nrows=MAX_ROWS,
                                      collabels=collabels,
                                      datatypes=coltypes,
                                      defaults=coldefs,
                                      colsizes=colsizes)
        wids['table'].SetMinSize((700, 225))
        wids['table'].SetFont(self.font_fixedwidth)

        wids['use_selected'] = Button(panel, 'Use Selected Groups',
                                      size=(150, -1), action=self.onFillTable)

        # ---- layout: widgets are added to the panel in display order ----
        panel.Add(SimpleText(panel, 'Feature Regression, Model Selection',
                             size=(350, -1), **self.titleopts), style=LEFT, dcol=4)

        add_text('Array to Use: ', newrow=True)
        panel.Add(wids['fitspace'], dcol=4)

        panel.Add(wids['fitspace_label'], newrow=True)
        panel.Add(self.elo_wids)
        add_text(' : ', newrow=False)
        panel.Add(self.ehi_wids, dcol=3)
        add_text('Regression Method:')
        panel.Add(wids['method'], dcol=4)
        add_text('PLS # components: ')
        panel.Add(w_ncomps)
        panel.Add(wids['auto_scale_pls'], dcol=2)
        add_text('Lasso Alpha: ')
        panel.Add(wids['alpha'])
        panel.Add(wids['auto_alpha'], dcol=2)
        panel.Add(wids['fit_intercept'])

        add_text('Cross Validation: ')
        add_text(' # folds, # repeats: ', newrow=False)
        panel.Add(w_cvfolds, dcol=2)
        panel.Add(w_cvreps)

        panel.Add(HLine(panel, size=(600, 2)), dcol=6, newrow=True)

        add_text('Build Model: ', newrow=True)
        panel.Add(wids['use_selected'], dcol=2)
        add_text('Attribute Name: ', newrow=False)
        panel.Add(wids['varname'], dcol=4)

        add_text('Read/Save Data: ', newrow=True)
        panel.Add(wids['load_csv'], dcol=3)
        panel.Add(wids['save_csv'], dcol=2)

        panel.Add(wids['table'], newrow=True, dcol=5) # , drow=3)

        panel.Add(HLine(panel, size=(550, 2)), dcol=5, newrow=True)
        panel.Add((5, 5), newrow=True)
        add_text('Train Model : ')
        panel.Add(wids['train_model'], dcol=3)
        panel.Add(wids['load_model'])

        add_text('Use This Model : ')
        panel.Add(wids['fit_group'], dcol=3)
        panel.Add(wids['save_model'])
        add_text('Fit Statistics : ')
        panel.Add(wids['stat1'], dcol=4)
        panel.Add((5, 5), newrow=True)
        panel.Add(wids['stat2'], dcol=4)
        panel.pack()

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add((10, 10), 0, LEFT, 3)
        sizer.Add(panel, 1, LEFT, 3)
        pack(self, sizer)
        # set the enabled state of the method-specific widgets for the default method
        self.onRegressMethod()
        self.skip_process = False
202 def onRegressMethod(self, evt=None):
203 meth = self.wids['method'].GetStringSelection()
204 use_lasso = meth.lower().startswith('lasso')
205 self.wids['alpha'].Enable(use_lasso)
206 self.wids['auto_alpha'].Enable(use_lasso)
207 self.wids['fit_intercept'].Enable(use_lasso)
208 self.wids['auto_scale_pls'].Enable(not use_lasso)
209 self.wids['ncomps'].Enable(not use_lasso)
211 def onFitSpace(self, evt=None):
212 fitspace = self.wids['fitspace'].GetStringSelection()
213 self.update_config(dict(fitspace=fitspace))
214 arrname = Linear_ArrayChoices.get(fitspace, 'norm')
215 self.update_fit_xspace(arrname)
218 def fill_form(self, dgroup=None, opts=None):
219 conf = deepcopy(self.get_config(dgroup=dgroup, with_erange=True))
220 if opts is None:
221 opts = {}
222 conf.update(opts)
223 self.dgroup = dgroup
224 self.skip_process = True
225 wids = self.wids
227 for attr in ('fitspace','method'):
228 if attr in conf:
229 wids[attr].SetStringSelection(conf[attr])
231 for attr in ('elo', 'ehi', 'alpha', 'varname', 'cv_folds', 'cv_repeats'):
232 val = conf.get(attr, None)
233 if val is not None:
234 if attr == 'alpha':
235 if val < 0:
236 val = 0.001
237 conf['auto_alpha'] = True
238 val = '%.6g' % val
239 if attr in wids:
240 wids[attr].SetValue(val)
242 use_lasso = conf['method'].lower().startswith('lasso')
244 for attr in ('auto_alpha', 'fit_intercept','auto_scale_pls'):
245 val = conf.get(attr, True)
246 if attr == 'auto_scale_pls':
247 val = val and not use_lasso
248 else:
249 val = val and use_lasso
250 wids[attr].SetValue(val)
251 self.onRegressMethod()
253 self.skip_process = False
255 def read_form(self):
256 dgroup = self.controller.get_group()
257 form = {'groupname': getattr(dgroup, 'groupname', 'No Group')}
259 for k in ('fitspace', 'method'):
260 form[k] = self.wids[k].GetStringSelection()
262 for k in ('elo', 'ehi', 'alpha', 'cv_folds',
263 'cv_repeats', 'ncomps', 'varname'):
264 form[k] = self.wids[k].GetValue()
266 form['alpha'] = float(form['alpha'])
267 if form['alpha'] < 0:
268 form['alpha'] = 1.e-3
270 for k in ('auto_scale_pls', 'auto_alpha', 'fit_intercept'):
271 form[k] = self.wids[k].IsChecked()
273 mname = form['method'].lower()
274 form['use_lars'] = 'lars' in mname
275 form['funcname'] = 'pls'
276 if mname.startswith('lasso'):
277 form['funcname'] = 'lasso'
278 if form['auto_alpha']:
279 form['alpha'] = None
281 return form
284 def onFillTable(self, event=None):
285 selected_groups = self.controller.filelist.GetCheckedStrings()
286 varname = fix_varname(self.wids['varname'].GetValue())
287 predname = varname + '_predicted'
288 grid_data = []
289 for fname in self.controller.filelist.GetCheckedStrings():
290 gname = self.controller.file_groups[fname]
291 grp = self.controller.get_group(gname)
292 grid_data.append([fname, getattr(grp, varname, 0.0),
293 getattr(grp, predname, 0.0), 'Yes'])
295 self.wids['table'].table.data = grid_data
296 self.wids['table'].table.View.Refresh()
298 def onTrainModel(self, event=None):
299 form = self.read_form()
300 self.update_config(form)
301 varname = form['varname']
302 predname = varname + '_predicted'
304 grid_data = self.wids['table'].table.data
305 groups = []
306 for fname, yval, pval, istrain in grid_data:
307 gname = self.controller.file_groups[fname]
308 grp = self.controller.get_group(gname)
309 setattr(grp, varname, yval)
310 setattr(grp, predname, pval)
311 groups.append(gname)
313 cmds = ['# train linear regression model',
314 'training_groups = [%s]' % ', '.join(groups)]
316 copts = ["varname='%s'" % varname, "xmin=%.4f" % form['elo'],
317 "xmax=%.4f" % form['ehi']]
319 arrname = Linear_ArrayChoices.get(form['fitspace'], 'norm')
320 copts.append("arrayname='%s'" % arrname)
322 if form['method'].lower().startswith('lasso'):
323 if form['auto_alpha']:
324 copts.append('alpha=None')
325 else:
326 copts.append('alpha=%.6g' % form['alpha'])
327 copts.append('use_lars=%s' % repr('lars' in form['method'].lower()))
328 copts.append('fit_intercept=%s' % repr(form['fit_intercept']))
329 else:
330 copts.append('ncomps=%d' % form['ncomps'])
331 copts.append('scale=%s' % repr(form['auto_scale_pls']))
333 callargs = ', '.join(copts)
335 cmds.append("reg_model = %s_train(training_groups, %s)" %
336 (form['funcname'], callargs))
338 self.larch_eval('\n'.join(cmds))
339 reg_model = self.larch_get('reg_model')
340 reg_model.form = form
341 self.use_regmodel(reg_model)
343 def use_regmodel(self, reg_model):
344 if reg_model is None:
345 return
346 opts = self.read_form()
348 if hasattr(reg_model, 'form'):
349 opts.update(reg_model.form)
351 self.write_message('Regression Model trained: %s' % opts['method'])
352 rmse_cv = reg_model.rmse_cv
353 if rmse_cv is not None:
354 rmse_cv = "%.4f" % rmse_cv
355 stat = "RMSE_CV = %s, RMSE = %.4f" % (rmse_cv, reg_model.rmse)
356 self.wids['stat1'].SetLabel(stat)
357 if opts['funcname'].startswith('lasso'):
358 stat = "Alpha = %.4f, %d active components"
359 self.wids['stat2'].SetLabel(stat % (reg_model.alpha,
360 len(reg_model.active)))
362 if opts['auto_alpha']:
363 self.wids['alpha'].add_choice(reg_model.alpha)
365 else:
366 self.wids['stat2'].SetLabel('- - - ')
367 training_groups = reg_model.groupnames
368 ntrain = len(training_groups)
369 grid_data = self.wids['table'].table.data
370 grid_new = []
371 for i in range(ntrain): # min(ntrain, len(grid_data))):
372 fname = training_groups[i]
373 istrain = 'Yes' if fname in training_groups else 'No'
374 grid_new.append( [fname, reg_model.ydat[i], reg_model.ypred[i], istrain])
375 self.wids['table'].table.data = grid_new
376 self.wids['table'].table.View.Refresh()
378 if reg_model.cv_folds not in (0, None):
379 self.wids['cv_folds'].SetValue(reg_model.cv_folds)
380 if reg_model.cv_repeats not in (0, None):
381 self.wids['cv_repeats'].SetValue(reg_model.cv_repeats)
383 self.wids['save_model'].Enable()
384 self.wids['fit_group'].Enable()
386 wx.CallAfter(self.onPlotModel, model=reg_model)
388 def onPanelExposed(self, **kws):
389 # called when notebook is selected
390 try:
391 fname = self.controller.filelist.GetStringSelection()
392 gname = self.controller.file_groups[fname]
393 dgroup = self.controller.get_group(gname)
394 self.ensure_xas_processed(dgroup)
395 self.fill_form(dgroup)
396 except:
397 pass # print(" Cannot Fill prepeak panel from group ")
399 reg_model = getattr(self.larch.symtable, 'reg_model', None)
400 if reg_model is not None:
401 self.use_regmodel(reg_model)
404 def onPredictGroups(self, event=None):
405 opts = self.read_form()
406 varname = opts['varname'] + '_predicted'
408 reg_model = self.larch_get('reg_model')
409 training_groups = reg_model.groupnames
411 grid_data = self.wids['table'].table.data
413 gent = {}
414 if len(grid_data[0][0].strip()) == 0:
415 grid_data = []
416 else:
417 for i, row in enumerate(grid_data):
418 gent[row[0]] = i
420 for fname in self.controller.filelist.GetCheckedStrings():
421 gname = self.controller.file_groups[fname]
422 grp = self.controller.get_group(gname)
423 extval = getattr(grp, opts['varname'], 0)
424 cmd = "%s.%s = %s_predict(%s, reg_model)" % (gname, varname,
425 opts['funcname'], gname)
426 self.larch_eval(cmd)
427 val = self.larch_get('%s.%s' % (gname, varname))
428 if fname in gent:
429 grid_data[gent[fname]][2] = val
430 else:
431 istrain = 'Yes' if fname in training_groups else 'No'
432 grid_data.append([fname, extval, val, istrain])
433 self.wids['table'].table.data = grid_data
434 self.wids['table'].table.View.Refresh()
436 def onSaveModel(self, event=None):
437 try:
438 reg_model = self.larch_get('reg_model')
439 except:
440 title = "No regresion model to save"
441 message = [f"Cannot get regression model to save"]
442 ExceptionPopup(self, title, message)
443 return
445 dlg = wx.FileDialog(self, message="Save Regression Model",
446 defaultDir=get_cwd(),
447 defaultFile=self.save_modelfile,
448 wildcard=MODEL_WILDCARDS,
449 style=wx.FD_SAVE)
450 fname = None
451 if dlg.ShowModal() == wx.ID_OK:
452 fname = dlg.GetPath()
453 dlg.Destroy()
454 if fname is None:
455 return
456 save_groups(fname, ['#regression model 1.0', reg_model])
457 self.write_message('Wrote Regression Model to %s ' % fname)
459 def onLoadModel(self, event=None):
460 dlg = wx.FileDialog(self, message="Load Regression Model",
461 defaultDir=get_cwd(),
462 wildcard=MODEL_WILDCARDS, style=wx.FD_OPEN)
464 fname = None
465 if dlg.ShowModal() == wx.ID_OK:
466 fname = dlg.GetPath()
467 dlg.Destroy()
468 if fname is None:
469 return
470 dat = read_groups(fname)
471 if len(dat) != 2 or not dat[0].startswith('#regression model'):
472 Popup(self, f" '{rfile}' is not a valid Regression model file",
473 "Invalid file")
475 reg_model = dat[1]
476 self.controller.symtable.reg_model = reg_model
478 self.write_message('Read Regression Model from %s ' % fname)
479 self.wids['fit_group'].Enable()
481 self.use_regmodel(reg_model)
483 def onLoadCSV(self, event=None):
484 dlg = wx.FileDialog(self, message="Load CSV Data File",
485 defaultDir=get_cwd(),
486 wildcard=CSV_WILDCARDS, style=wx.FD_OPEN)
488 fname = None
489 if dlg.ShowModal() == wx.ID_OK:
490 fname = dlg.GetPath()
491 dlg.Destroy()
492 if fname is None:
493 return
495 self.save_csvfile = os.path.split(fname)[1]
496 varname = fix_varname(self.wids['varname'].GetValue())
497 csvgroup = read_csv(fname)
498 script = []
499 grid_data = []
500 for sname, yval in zip(csvgroup.col_01, csvgroup.col_02):
501 if sname.startswith('#'):
502 continue
503 if sname in self.controller.file_groups:
504 gname = self.controller.file_groups[sname]
505 script.append('%s.%s = %f' % (gname, varname, yval))
506 grid_data.append([sname, yval, 0])
508 self.larch_eval('\n'.join(script))
509 self.wids['table'].table.data = grid_data
510 self.wids['table'].table.View.Refresh()
511 self.write_message('Read CSV File %s ' % fname)
513 def onSaveCSV(self, event=None):
514 wildcard = 'CSV file (*.csv)|*.csv|All files (*.*)|*.*'
515 fname = FileSave(self, message='Save CSV Data File',
516 wildcard=wildcard,
517 default_file=self.save_csvfile)
518 if fname is None:
519 return
520 self.save_csvfile = os.path.split(fname)[1]
521 buff = []
522 for row in self.wids['table'].table.data:
523 buff.append("%s, %s, %s" % (row[0], gformat(row[1]), gformat(row[2])))
524 buff.append('')
525 with open(fname, 'w', encoding=sys.getdefaultencoding()) as fh:
526 fh.write('\n'.join(buff))
527 self.write_message('Wrote CSV File %s ' % fname)
529 def onPlotModel(self, event=None, model=None):
530 opts = self.read_form()
531 if model is None:
532 return
533 opts.update(model.form)
535 ppanel = self.controller.get_display(win=1).panel
536 viewlims = ppanel.get_viewlimits()
537 plotcmd = ppanel.plot
539 d_ave = model.spectra.mean(axis=0)
540 d_std = model.spectra.std(axis=0)
541 ymin, ymax = (d_ave-d_std).min(), (d_ave+d_std).max()
543 if opts['funcname'].startswith('lasso'):
544 active = [int(i) for i in model.active]
545 active_coefs = (model.coefs[active])
546 active_coefs = active_coefs/max(abs(active_coefs))
547 ymin = min(active_coefs.min(), ymin)
548 ymax = max(active_coefs.max(), ymax)
550 else:
551 ymin = min(model.coefs.min(), ymin)
552 ymax = max(model.coefs.max(), ymax)
554 ymin = ymin - 0.02*(ymax-ymin)
555 ymax = ymax + 0.02*(ymax-ymin)
558 title = '%s Regression results' % (opts['method'])
560 ppanel.plot(model.x, d_ave, win=1, title=title,
561 label='mean spectra', xlabel='Energy (eV)',
562 ylabel=opts['fitspace'], show_legend=True,
563 ymin=ymin, ymax=ymax)
564 ppanel.axes.fill_between(model.x, d_ave-d_std, d_ave+d_std,
565 color='#1f77b433')
566 if opts['funcname'].startswith('lasso'):
567 ppanel.axes.bar(model.x[active], active_coefs,
568 1.0, color='#9f9f9f88',
569 label='coefficients')
570 else:
571 _, ncomps = model.coefs.shape
572 for i in range(ncomps):
573 ppanel.oplot(model.x, model.coefs[:, i], label='coef %d' % (i+1))
575 ppanel.canvas.draw()
577 ngoups = len(model.groupnames)
578 indices = np.arange(len(model.groupnames))
579 diff = model.ydat - model.ypred
580 sx = np.argsort(model.ydat)
582 ppanel = self.controller.get_display(win=2).panel
584 ppanel.plot(model.ydat[sx], indices, xlabel='valence',
585 label='experimental', linewidth=0, marker='o',
586 markersize=8, win=2, new=True, title=title)
588 ppanel.oplot(model.ypred[sx], indices, label='predicted',
589 labelxsxfontsize=7, markersize=6, marker='o',
590 linewidth=0, show_legend=True, new=False)
592 ppanel.axes.barh(indices, diff[sx], 0.5, color='#9f9f9f88')
593 ppanel.axes.set_yticks(indices)
594 ppanel.axes.set_yticklabels([model.groupnames[o] for o in sx])
595 ppanel.conf.auto_margins = False
596 ppanel.conf.set_margins(left=0.35, right=0.05, bottom=0.15, top=0.1)
597 ppanel.canvas.draw()
598 self.controller.set_focus()
601 def onCopyParam(self, name=None, evt=None):
602 conf = self.get_config()