Coverage for /Users/Newville/Codes/xraylarch/larch/fitting/__init__.py: 66%
247 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
3from copy import copy, deepcopy
4import random
5import numpy as np
6from scipy.stats import f
8import lmfit
9from lmfit import Parameter
10from lmfit import (Parameters, Minimizer, conf_interval,
11 ci_report, conf_interval2d)
13from lmfit.minimizer import MinimizerResult
14from lmfit.model import (ModelResult, save_model, load_model,
15 save_modelresult, load_modelresult)
16from lmfit.confidence import f_compare
17from uncertainties import ufloat, correlated_values
18from uncertainties import wrap as un_wrap
20from ..symboltable import Group, isgroup
def isParameter(x):
    """Return True if `x` should be treated as an lmfit Parameter.

    Besides real instances of the imported ``Parameter`` class, any object
    whose class is named exactly 'Parameter' is also accepted.
    """
    if isinstance(x, Parameter):
        return True
    class_name = x.__class__.__name__
    return class_name == 'Parameter'
def param_value(val):
    """Return the plain value behind `val` -- useful for 3rd party code.

    If `val` is a Parameter, its ``.value`` is taken, repeatedly, until a
    non-Parameter object is reached.  Any other object is returned as-is.
    """
    current = val
    while isinstance(current, Parameter):
        current = current.value
    return current
def f_test(ndata, nvars, chisquare, chisquare0, nfix=1):
    """return the F-test value for the following input values:
       f = f_test(ndata, nvars, chisquare, chisquare0, nfix=1)

    Parameters
    ----------
    ndata : int
        number of data points.
    nvars : int
        number of variable parameters in the fit that has `nfix`
        parameters held fixed (so the full fit has nvars + nfix).
    chisquare : float
        chi-square of the fit with `nfix` parameters held fixed.
    chisquare0 : float
        chi-square of the best fit (all parameters varied).
    nfix : int, optional
        the number of fixed parameters (default 1).

    Returns
    -------
    float
        F-distribution CDF value for the relative increase in chi-square,
        with (nfix, ndata - nvars - nfix) degrees of freedom.
    """
    # Computed directly with scipy.stats.f rather than via
    # lmfit.confidence.f_compare: recent lmfit releases changed f_compare to
    # take two fit results instead of these scalars.  The previous call also
    # ignored `nfix` (it always passed nfix=1).
    nfix = float(nfix)
    nfree = ndata - (nvars + nfix)
    dchi = chisquare / chisquare0 - 1.0
    return f.cdf(dchi * nfree / nfix, nfix, nfree)
def confidence_report(conf_vals, **kws):
    """Format confidence intervals as a text report.

    `conf_vals` are the confidence intervals as calculated by
    confidence_intervals().  Extra keyword arguments are accepted for
    compatibility but are not used.
    """
    report = ci_report(conf_vals)
    return report
def asteval_with_uncertainties(*vals, **kwargs):
    """Calculate object value, given values for variables.

    This is used by the uncertainties package to calculate the
    uncertainty in an object even with a complicated expression.

    The positional `vals` are written, in order, into the asteval symbol
    table of `_pars` under the names in `_names`.  Every parameter in
    `_pars` is then re-evaluated, and the compiled expression of `_obj`
    is evaluated and returned.

    Keyword Arguments
    -----------------
    _obj : object with an ``_expr_ast`` attribute (the expression to evaluate).
    _pars : mapping of parameters with an ``_asteval`` interpreter attribute.
    _names : names to bind to `vals`, in the same order.

    Returns 0 if any of `_obj`, `_pars`, `_names`, the interpreter, or the
    compiled expression is missing.
    """
    _obj = kwargs.get('_obj', None)
    _pars = kwargs.get('_pars', None)
    _names = kwargs.get('_names', None)
    # Only read the interpreter once `_pars` is known to exist; reading
    # `_pars._asteval` unconditionally raised AttributeError for _pars=None
    # instead of taking the "return 0" path below.
    _asteval = None if _pars is None else _pars._asteval
    if (_obj is None or _pars is None or _names is None or
            _asteval is None or _obj._expr_ast is None):
        return 0
    for val, name in zip(vals, _names):
        _asteval.symtable[name] = val

    # re-evaluate all constraint parameters to
    # force the propagation of uncertainties
    for par in _pars.values():
        par._getval()
    return _asteval.eval(_obj._expr_ast)
# asteval_with_uncertainties wrapped with uncertainties.wrap, so that it can
# be called with ufloat arguments and return a value carrying the propagated
# uncertainty (used by eval_stderr to set .stderr of constrained parameters).
wrap_ueval = un_wrap(asteval_with_uncertainties)
def eval_stderr(obj, uvars, _names, _pars):
    """Set ``obj.stderr`` from the uncertainty of its constraint expression.

    `uvars` is a list of uncertain values (``uncertainties`` ufloats),
    `_names` the parameter names matching `uvars`, and `_pars` the
    parameter objects keyed by name.  The expression in ``obj._expr_ast``
    is evaluated through the uncertainties-wrapped evaluator, and its
    standard deviation becomes ``obj.stderr``.

    Objects that are not Parameters, or that have no compiled expression,
    are left untouched.
    """
    has_expression = getattr(obj, '_expr_ast', None) is not None
    if not (isinstance(obj, Parameter) and has_expression):
        return
    uval = wrap_ueval(*uvars, _obj=obj, _names=_names, _pars=_pars)
    try:
        obj.stderr = uval.std_dev
    except Exception:
        # presumably the result carried no uncertainty (no .std_dev)
        obj.stderr = 0
class ParameterGroup(Group):
    """
    Group for Fitting Parameters

    Parameter attributes are mirrored into an lmfit ``Parameters`` object
    stored as ``self.__params__``.  Other public (non-dunder) attributes are
    also placed in that object's asteval symbol table.
    """
    def __init__(self, name=None, **kws):
        if name is not None:
            self.__name__ = name
        if '_larch' in kws:
            kws.pop('_larch')
        self.__params__ = Parameters()
        # NOTE(review): confirm Group.__init__ does not overwrite the
        # __name__ assigned above.
        Group.__init__(self)
        # Constraint expressions are removed while the parameters are added,
        # and restored only after all of them exist -- presumably so an
        # expression may refer to a parameter given later in `kws`.
        self.__exprsave__ = {}
        for key, val in kws.items():
            expr = getattr(val, 'expr', None)
            if expr is not None:
                self.__exprsave__[key] = expr
                val.expr = None
            setattr(self, key, val)

        for key, val in self.__exprsave__.items():
            self.__params__[key].expr = val

    def __repr__(self):
        return '<Param Group {:s}>'.format(self.__name__)

    def __setattr__(self, name, val):
        """Set an attribute, registering Parameters in ``__params__``.

        A Parameter is re-added to ``__params__`` under `name`, and the
        stored attribute becomes the resulting lmfit Parameter.  Other
        values with a non-dunder name are also put into the asteval symbol
        table, once ``__params__`` exists.
        """
        if isParameter(val):
            if val.name != name:
                # allow 'a=Parameter(2, ..)' to mean Parameter(name='a', value=2, ...)
                nval = None
                try:
                    nval = float(val.name)
                except (ValueError, TypeError):
                    pass
                if nval is not None:
                    val.value = nval
            skip = getattr(val, 'skip', None)
            self.__params__.add(name, value=val.value, vary=val.vary, min=val.min,
                                max=val.max, expr=val.expr, brute_step=val.brute_step)
            val = self.__params__[name]

            val.skip = skip
        elif hasattr(self, '__params__') and not name.startswith('__'):
            self.__params__._asteval.symtable[name] = val
        self.__dict__[name] = val

    def __delattr__(self, name):
        """Remove attribute `name` and any Parameter of the same name.

        Raises AttributeError (as ``del obj.attr`` / ``delattr`` callers
        expect) when `name` is neither an attribute nor a Parameter.
        """
        in_dict = name in self.__dict__
        in_params = name in self.__params__
        if not (in_dict or in_params):
            raise AttributeError("{!r} object has no attribute {!r}".format(
                type(self).__name__, name))
        if in_dict:
            del self.__dict__[name]
        if in_params:
            self.__params__.pop(name)

    def __add(self, name, value=None, vary=True, min=-np.inf, max=np.inf,
              expr=None, stderr=None, correl=None, brute_step=None, skip=None):
        """Add a Parameter `name` to ``__params__`` and as an attribute.

        A string `value` with no `expr` is treated as the expression.
        """
        if expr is None and isinstance(value, str):
            expr = value
            value = None
        if self.__params__ is not None:
            self.__params__.add(name, value=value, vary=vary, min=min, max=max,
                                expr=expr, brute_step=brute_step)
            self.__params__[name].stderr = stderr
            self.__params__[name].correl = correl
            self.__params__[name].skip = skip
            self.__dict__[name] = self.__params__[name]
def param_group(**kws):
    """Create a ParameterGroup from keyword arguments and return it."""
    group = ParameterGroup(**kws)
    return group
def randstr(n):
    """Return a string of `n` random lowercase ASCII letters ('a' to 'z')."""
    letters = (chr(random.randint(97, 122)) for _ in range(n))
    return ''.join(letters)
class unnamedParameter(Parameter):
    """A Parameter that can be nameless

    If `name` is None, a random 8-letter lowercase name from randstr() is
    used instead.  `skip` is stored as a plain attribute; it is not passed
    on to Parameter.__init__.
    """
    def __init__(self, name=None, value=None, vary=True, min=-np.inf, max=np.inf,
                 expr=None, brute_step=None, user_data=None, skip=None):
        if name is None:
            name = randstr(8)
        # Attributes are assigned directly before Parameter.__init__ runs.
        # NOTE(review): presumably required so that _init_bounds() below can
        # run on a not-yet-initialized instance -- confirm against the lmfit
        # version in use, since these are lmfit-internal attribute names.
        self.name = name
        self.user_data = user_data
        self.init_value = value
        self.min = min
        self.max = max
        self.brute_step = brute_step
        self.vary = vary
        self.skip = skip
        self._expr = expr
        self._expr_ast = None
        self._expr_eval = None
        self._expr_deps = []
        self._delay_asteval = False
        self.stderr = None
        self.correl = None
        self.from_internal = lambda val: val
        self._val = value
        self._init_bounds()
        Parameter.__init__(self, name, value=value, vary=vary,
                           min=min, max=max, expr=expr,
                           brute_step=brute_step,
                           user_data=user_data)
def param(*args, **kws):
    """create a fitting Parameter as a Variable

    An optional first positional argument is interpreted as:
      * a string -> the constraint expression (``expr``)
      * a number -> the initial ``value``.  Python int/float and NumPy
        integer/floating scalars are accepted.
    Remaining positional and keyword arguments are passed on to
    unnamedParameter.  Unless `vary` is given, the Parameter is fixed
    (``vary=False``).  A ``_larch`` keyword is discarded.

    Raises
    ------
    ValueError
        if the first positional argument is neither a string nor a number.
    """
    if len(args) > 0:
        a0 = args[0]
        if isinstance(a0, str):
            kws['expr'] = a0
        # NumPy scalars such as np.int64 / np.float32 are not int/float
        # subclasses, so they are listed explicitly.
        elif isinstance(a0, (int, float, np.integer, np.floating)):
            kws['value'] = a0
        else:
            raise ValueError("first argument to param() must be string or number")
        args = args[1:]
    kws.pop('_larch', None)
    kws.setdefault('vary', False)
    return unnamedParameter(*args, **kws)
def guess(value, **kws):
    """create a fitting Parameter that is varied in the fit.

    Equivalent to ``param(value, ..., vary=True)``; any `vary` keyword
    passed in is overridden.  Bounds for the value can be given:
       x = guess(10, min=0)
       y = guess(1.2, min=1, max=2)
    """
    kws['vary'] = True
    return param(value, **kws)
def is_param(obj):
    """Return True when `obj` is a Parameter, as decided by isParameter()."""
    answer = isParameter(obj)
    return answer
def dict2params(pars):
    """Convert a plain dict of Parameters to an lmfit Parameters object.

    A Parameters object is returned unchanged.  Otherwise only the entries
    whose values are Parameter instances are copied, in their original
    order; all other entries are dropped.
    """
    if isinstance(pars, Parameters):
        return pars
    selected = [(key, val) for key, val in pars.items()
                if isinstance(val, Parameter)]
    converted = Parameters()
    for key, val in selected:
        converted[key] = val
    return converted
def group2params(paramgroup):
    """take a Group of Parameter objects (and maybe other things)
    and put them into a lmfit.Parameters, ready for use in fitting

    * a Parameters object is returned unchanged;
    * a dict is converted with dict2params (non-Parameter values dropped);
    * a ParameterGroup returns its own ``__params__``;
    * for any other object, every attribute listed by ``dir()`` whose
      ``skip`` attribute is not False/None is ignored.  Parameters are
      added (constraint expressions are set only after every symbol is
      defined), and other values go into the asteval symbol table.
    """
    if isinstance(paramgroup, Parameters):
        return paramgroup
    if isinstance(paramgroup, dict):
        return dict2params(paramgroup)

    if isinstance(paramgroup, ParameterGroup):
        return paramgroup.__params__

    params = Parameters()
    if paramgroup is not None:
        exprs = {}
        for name in dir(paramgroup):
            par = getattr(paramgroup, name)
            if getattr(par, 'skip', None) not in (False, None):
                continue
            if isParameter(par):
                params.add(name, value=par.value, vary=par.vary,
                           min=par.min, max=par.max,
                           brute_step=par.brute_step)
                if par.expr is not None:
                    exprs[name] = par.expr
            else:
                params._asteval.symtable[name] = par

        # now set any expression (that is, after all symbols are defined).
        # Only Parameters actually added above are considered: a skipped
        # Parameter with an expression used to raise KeyError here.
        for name, expr in exprs.items():
            params[name].expr = expr

    return params
def params2group(params, paramgroup):
    """fill Parameter objects in paramgroup with
    values from lmfit.Parameters

    For each name in `params` that is also a Parameter attribute of
    `paramgroup`, copy value, vary, stderr, min, max, expr, name, correl,
    brute_step and user_data (None if missing).  That Parameter is also
    stored in ``paramgroup.__params__`` when that exists.  When stderr is
    set, ``uvalue`` is set to ``ufloat(value, stderr)`` on a best-effort
    basis.
    """
    _params = getattr(paramgroup, '__params__', None)
    # `par` (not `param`) avoids shadowing the module-level param() function
    for name, par in params.items():
        this = getattr(paramgroup, name, None)
        if not isParameter(this):
            continue
        if _params is not None:
            _params[name] = this
        for attr in ('value', 'vary', 'stderr', 'min', 'max', 'expr',
                     'name', 'correl', 'brute_step', 'user_data'):
            setattr(this, attr, getattr(par, attr, None))
        if this.stderr is not None:
            try:
                this.uvalue = ufloat(this.value, this.stderr)
            except Exception:
                # best effort: leave uvalue unset if ufloat rejects the
                # value/stderr pair (previously a bare except, which also
                # swallowed KeyboardInterrupt/SystemExit)
                pass
def minimize(fcn, paramgroup, method='leastsq', args=None, kws=None,
             scale_covar=True, iter_cb=None, reduce_fcn=None, nan_polcy='omit',
             nan_policy=None, **fit_kws):
    """
    wrapper around lmfit minimizer for Larch

    Parameters
    ----------
    fcn : callable
        objective function, called as ``fcn(paramgroup, *args, **kws)``
        after the current parameter values are copied into `paramgroup`.
    paramgroup : ParameterGroup, Group, or lmfit.Parameters
        the parameters to fit.
    method : str
        fitting method passed to ``Minimizer.minimize``.
    args, kws : tuple, dict, optional
        extra positional / keyword arguments for `fcn`.
    scale_covar, iter_cb, reduce_fcn : passed to ``lmfit.Minimizer``.
    nan_polcy : str
        NaN handling policy (historical, misspelled name; default 'omit').
    nan_policy : str, optional
        correctly spelled alias for `nan_polcy`; takes precedence if given.
    **fit_kws : further keyword arguments for ``lmfit.Minimizer``.

    Returns
    -------
    Group with ``fitter``, ``fit_details`` (the MinimizerResult),
    ``chi_square``, ``chi_reduced``, and copies of several result attributes.

    Raises
    ------
    ValueError
        if `paramgroup` is not a ParameterGroup, Group, or Parameters.
    """
    if isinstance(paramgroup, ParameterGroup):
        params = paramgroup.__params__
    elif isgroup(paramgroup):
        params = group2params(paramgroup)
    elif isinstance(paramgroup, Parameters):
        params = paramgroup
    else:
        raise ValueError('minimize takes a ParameterGroup, Group, or Parameters '
                         'as the paramgroup argument')

    # honor the caller's NaN policy (it used to be hard-coded to 'omit')
    if nan_policy is None:
        nan_policy = nan_polcy
    if args is None:
        args = ()
    if kws is None:
        kws = {}

    def _residual(params):
        params2group(params, paramgroup)
        return fcn(paramgroup, *args, **kws)

    fitter = Minimizer(_residual, params, iter_cb=iter_cb,
                       reduce_fcn=reduce_fcn, nan_policy=nan_policy,
                       scale_covar=scale_covar, **fit_kws)

    result = fitter.minimize(method=method)
    params2group(result.params, paramgroup)

    out = Group(name='minimize results', fitter=fitter, fit_details=result,
                chi_square=result.chisqr, chi_reduced=result.redchi)

    for attr in ('aic', 'bic', 'covar', 'params', 'nvarys',
                 'nfree', 'ndata', 'var_names', 'nfev', 'success',
                 'errorbars', 'message', 'lmdif_message', 'residual'):
        setattr(out, attr, getattr(result, attr, None))
    return out
def fit_report(fit_result, modelpars=None, show_correl=True, min_correl=0.1,
               sort_pars=True, **kws):
    """generate a report of fitting results
    wrapper around lmfit.fit_report

    The report lists best-fit values of the parameters together with
    their uncertainties and correlations.

    Parameters
    ----------
    fit_result : result from fit
        Group returned by minimize(), an lmfit MinimizerResult or
        ModelResult, a Parameters object, or a Group of Parameters.
    modelpars : Parameters, optional
        Known Model Parameters.
    show_correl : bool, optional
        Whether to list sorted correlations (default is True).
    min_correl : float, optional
        Smallest absolute correlation to list (default is 0.1).
    sort_pars : bool or callable, optional
        If True, list parameter names in alphanumerical order; if False,
        in the order they were added; if callable, use it (one argument)
        as the sort key.
    **kws :
        accepted but ignored.

    Returns
    -------
    string
        Multi-line fit report, or a "Cannot make fit report" message.
    """
    report_opts = dict(modelpars=modelpars, show_correl=show_correl,
                       min_correl=min_correl, sort_pars=sort_pars)

    details = getattr(fit_result, 'fit_details', fit_result)
    if isinstance(details, MinimizerResult):
        return lmfit.fit_report(details, **report_opts)
    if isinstance(details, ModelResult):
        return details.fit_report(**report_opts)

    pars = getattr(fit_result, 'params', fit_result)
    if isinstance(pars, Parameters):
        return lmfit.fit_report(pars, **report_opts)

    try:
        return lmfit.fit_report(group2params(fit_result), **report_opts)
    except (ValueError, AttributeError):
        pass
    return "Cannot make fit report with %s" % repr(fit_result)
def confidence_intervals(fit_result, sigmas=(1, 2, 3), **kws):
    """calculate the confidence intervals from a fit
    for supplied sigma values

    wrapper around lmfit.conf_interval

    `fit_result` must be the output of minimize() (with ``fitter`` and
    ``fit_details`` attributes); extra keywords go to conf_interval.

    Raises
    ------
    ValueError
        if `fit_result` does not provide ``fitter`` and ``fit_details``.
    """
    fitter = getattr(fit_result, 'fitter', None)
    result = getattr(fit_result, 'fit_details', None)
    # fail early with a clear message (as chi2_map does) instead of passing
    # None on to conf_interval
    if fitter is None or result is None:
        raise ValueError("confidence_intervals needs valid fit result as first argument")
    return conf_interval(fitter, result, sigmas=sigmas, **kws)
def chi2_map(fit_result, xname, yname, nx=11, ny=11, sigma=3, **kws):
    """generate a confidence map for any two parameters for a fit

    Arguments
    ==========
      fit_result  output of minimize() fit (must be run first)
      xname       name of variable parameter for x-axis
      yname       name of variable parameter for y-axis
      nx          number of steps in x [11]
      ny          number of steps in y [11]
      sigma       scale for uncertainty range [3]

    Returns
    =======
      xpts, ypts, map

    Raises
    ======
      ValueError  if fit_result is not a valid fit result, or if either
                  parameter has no stderr (uncertainties not estimated).

    Notes
    =====
     1. sigma sets the extent of values to explore:
            param.value +/- sigma * param.stderr
    """
    fitter = getattr(fit_result, 'fitter', None)
    result = getattr(fit_result, 'fit_details', None)
    if fitter is None or result is None:
        raise ValueError("chi2_map needs valid fit result as first argument")

    c2_scale = fit_result.chi_square / result.chisqr

    def scaled_chisqr(ndata, nparas, new_chi, best_chi, nfix=1.):
        """return scaled chi-square, instead of probability"""
        return new_chi * c2_scale

    x = result.params[xname]
    y = result.params[yname]
    # without error bars the scan range cannot be built; report that
    # clearly instead of failing on `sigma * None`
    for pname, par in ((xname, x), (yname, y)):
        if par.stderr is None:
            raise ValueError("chi2_map: parameter %r has no stderr "
                             "(uncertainties were not estimated)" % pname)
    xrange = (x.value + sigma * x.stderr, x.value - sigma * x.stderr)
    yrange = (y.value + sigma * y.stderr, y.value - sigma * y.stderr)

    return conf_interval2d(fitter, result, xname, yname,
                           limits=(xrange, yrange),
                           prob_func=scaled_chisqr,
                           nx=nx, ny=ny, **kws)
# Name of the larch namespace for these exports (matches the key used in
# _larch_builtins at the end of this section).
_larch_name = '_math'
# Functions and classes exported to larch; keys are the names larch uses.
exports = {'param': param,
           'guess': guess,
           'param_group': param_group,
           'confidence_intervals': confidence_intervals,
           'confidence_report': confidence_report,
           'f_test': f_test, 'chi2_map': chi2_map,
           'is_param': isParameter,
           'isparam': isParameter,
           'minimize': minimize,
           'ufloat': ufloat,
           'fit_report': fit_report,
           'Parameters': Parameters,
           'Parameter': Parameter,
           'lm_minimize': minimize,
           'lm_save_model': save_model,
           'lm_load_model': load_model,
           'lm_save_modelresult': save_modelresult,
           'lm_load_modelresult': load_modelresult,
           }

# Also export lmfit's built-in model classes.  Names that the installed
# lmfit.models does not provide are skipped silently.
for name in ('BreitWignerModel', 'ComplexConstantModel',
             'ConstantModel', 'DampedHarmonicOscillatorModel',
             'DampedOscillatorModel', 'DoniachModel',
             'ExponentialGaussianModel', 'ExponentialModel',
             'ExpressionModel', 'GaussianModel', 'Interpreter',
             'LinearModel', 'LognormalModel', 'LorentzianModel',
             'MoffatModel', 'ParabolicModel', 'Pearson7Model',
             'PolynomialModel', 'PowerLawModel',
             'PseudoVoigtModel', 'QuadraticModel',
             'RectangleModel', 'SkewedGaussianModel',
             'StepModel', 'StudentsTModel', 'VoigtModel'):
    val = getattr(lmfit.models, name, None)
    if val is not None:
        exports[name] = val

_larch_builtins = {'_math': exports}