Coverage for /Users/Newville/Codes/xraylarch/larch/fitting/__init__.py: 66%

247 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2 

3from copy import copy, deepcopy 

4import random 

5import numpy as np 

6from scipy.stats import f 

7 

8import lmfit 

9from lmfit import Parameter 

10from lmfit import (Parameters, Minimizer, conf_interval, 

11 ci_report, conf_interval2d) 

12 

13from lmfit.minimizer import MinimizerResult 

14from lmfit.model import (ModelResult, save_model, load_model, 

15 save_modelresult, load_modelresult) 

16from lmfit.confidence import f_compare 

17from uncertainties import ufloat, correlated_values 

18from uncertainties import wrap as un_wrap 

19 

20from ..symboltable import Group, isgroup 

21 

22 

def isParameter(x):
    """Return True if `x` is an lmfit Parameter.

    Besides a real isinstance() check, any object whose class is named
    'Parameter' is accepted, so Parameter objects coming from another
    copy of the class are still recognized.
    """
    if isinstance(x, Parameter):
        return True
    return x.__class__.__name__ == 'Parameter'

26 

def param_value(val):
    "get param value -- useful for 3rd party code"
    # A Parameter's value may itself be a Parameter: keep unwrapping
    # until a non-Parameter value is reached.
    result = val
    while isinstance(result, Parameter):
        result = result.value
    return result

32 

def f_test(ndata, nvars, chisquare, chisquare0, nfix=1):
    """return the F-test value for the following input values:
       f = f_test(ndata, nparams, chisquare, chisquare0, nfix=1)

    ndata      = number of data points.
    nvars      = number of variables in the fit.
    chisquare  = chi-square of the fit being compared.
    chisquare0 = best (reference) chi-square.
    nfix       = the number of fixed parameters.

    Returns the cumulative F-distribution value
        scipy.stats.f.cdf(F, nfix, nfree)
    with nfree = ndata - (nvars + nfix) and
         F     = (chisquare/chisquare0 - 1) * nfree / nfix

    This is computed directly rather than through lmfit's f_compare,
    because current lmfit defines f_compare(best_fit, new_fit) on
    MinimizerResult objects instead of on these scalar values.
    """
    nfix = 1.0 * nfix
    nfree = ndata - (nvars + nfix)
    dchi = chisquare / chisquare0 - 1.0
    return f.cdf(dchi * nfree / nfix, nfix, nfree)

40 

def confidence_report(conf_vals, **kws):
    """return a formatted report of confidence intervals calculated
    by confidence_intervals.

    Extra keyword arguments are accepted for call compatibility but
    are not used.
    """
    report = ci_report(conf_vals)
    return report

46 

47 

48 

def asteval_with_uncertainties(*vals, **kwargs):
    """Calculate object value, given values for variables.

    This is used by the uncertainties package to calculate the
    uncertainty in an object even with a complicated expression.

    Parameters
    ----------
    *vals : values to assign to the variables named in `_names`, in order.
    _obj : keyword, the Parameter whose ``_expr_ast`` is evaluated.
    _pars : keyword, the Parameters object providing ``_asteval``.
    _names : keyword, list of names matching `vals`.

    Returns
    -------
    The value of ``_obj``'s expression, or 0 if any of `_obj`, `_pars`,
    `_names`, the interpreter, or the expression AST is missing.
    """
    _obj = kwargs.get('_obj', None)
    _pars = kwargs.get('_pars', None)
    _names = kwargs.get('_names', None)
    # Use getattr so a missing `_pars` reaches the `return 0` guard below
    # instead of raising AttributeError here.
    _asteval = getattr(_pars, '_asteval', None)
    if (_obj is None or _pars is None or _names is None or
            _asteval is None or getattr(_obj, '_expr_ast', None) is None):
        return 0
    for val, name in zip(vals, _names):
        _asteval.symtable[name] = val

    # re-evaluate all constraint parameters to
    # force the propagation of uncertainties
    for par in _pars.values():
        par._getval()
    return _asteval.eval(_obj._expr_ast)

70 

71 

# asteval_with_uncertainties wrapped by the `uncertainties` package's
# wrap(), so it can be called with uncertain values; eval_stderr reads
# `.std_dev` from the result of calling this.
wrap_ueval = un_wrap(asteval_with_uncertainties)

73 

74 

def eval_stderr(obj, uvars, _names, _pars):
    """Evaluate uncertainty and set ``.stderr`` for a parameter `obj`.

    Arguments: `uvars` are uncertain values (``uncertainties`` ufloats),
    `_names` is the list of parameter names matching `uvars`, and `_pars`
    is the collection of parameter objects keyed by name.

    The uncertainty of the expression compiled in ``obj._expr_ast`` is
    propagated through the wrapped evaluator ``wrap_ueval``. Objects that
    are not Parameters, or that have no expression, are left untouched.
    If the evaluated result has no usable ``std_dev``, ``stderr`` is 0.
    """
    if not isinstance(obj, Parameter):
        return
    if getattr(obj, '_expr_ast', None) is None:
        return
    uncertain_value = wrap_ueval(*uvars, _obj=obj, _names=_names, _pars=_pars)
    try:
        obj.stderr = uncertain_value.std_dev
    except Exception:
        obj.stderr = 0

94 

95 

class ParameterGroup(Group):
    """
    Group for Fitting Parameters

    Parameter-valued attributes are registered in the lmfit ``Parameters``
    object ``self.__params__``; the attribute then refers to the entry
    held there. Other attributes whose names do not start with '__' are
    placed in that object's asteval symbol table, so that constraint
    expressions can refer to them.
    """
    def __init__(self, name=None, **kws):
        kws.pop('_larch', None)
        self.__params__ = Parameters()
        Group.__init__(self)
        # Set the explicit name only after Group.__init__, so that any
        # default name the base class may assign cannot overwrite it.
        if name is not None:
            self.__name__ = name
        # Constraint expressions are held back until every keyword
        # attribute exists, so an expression may refer to a parameter
        # that appears later in `kws`.
        self.__exprsave__ = {}
        for key, val in kws.items():
            expr = getattr(val, 'expr', None)
            if expr is not None:
                self.__exprsave__[key] = expr
                val.expr = None
            setattr(self, key, val)

        for key, val in self.__exprsave__.items():
            self.__params__[key].expr = val

    def __repr__(self):
        return '<Param Group {}>'.format(self.__name__)

    def __setattr__(self, name, val):
        if isParameter(val):
            if val.name != name:
                # allow 'a=Parameter(2, ..)' to mean Parameter(name='a', value=2, ...)
                nval = None
                try:
                    nval = float(val.name)
                except (ValueError, TypeError):
                    pass
                if nval is not None:
                    val.value = nval
            # Parameters.add() is not given `skip`, so carry it over to
            # the new entry by hand.
            skip = getattr(val, 'skip', None)
            self.__params__.add(name, value=val.value, vary=val.vary, min=val.min,
                                max=val.max, expr=val.expr, brute_step=val.brute_step)
            val = self.__params__[name]
            val.skip = skip
        elif hasattr(self, '__params__') and not name.startswith('__'):
            # make plain values visible to constraint expressions
            self.__params__._asteval.symtable[name] = val
        self.__dict__[name] = val

    def __delattr__(self, name):
        try:
            self.__dict__.pop(name)
        except KeyError:
            # follow the attribute protocol: del/hasattr expect
            # AttributeError for a missing attribute, not KeyError
            raise AttributeError(name) from None
        if name in self.__params__:
            self.__params__.pop(name)

    def __add(self, name, value=None, vary=True, min=-np.inf, max=np.inf,
              expr=None, stderr=None, correl=None, brute_step=None, skip=None):
        # a string `value` with no `expr` is taken as a constraint expression
        if expr is None and isinstance(value, str):
            expr = value
            value = None
        if self.__params__ is not None:
            self.__params__.add(name, value=value, vary=vary, min=min, max=max,
                                expr=expr, brute_step=brute_step)
            self.__params__[name].stderr = stderr
            self.__params__[name].correl = correl
            self.__params__[name].skip = skip
            self.__dict__[name] = self.__params__[name]

160 

161 

def param_group(**kws):
    """create a parameter group: all keyword arguments are passed to
    ParameterGroup, and the new group is returned."""
    group = ParameterGroup(**kws)
    return group

165 

def randstr(n):
    """Return a string of `n` random lowercase ASCII letters ('a'..'z')."""
    letters = (chr(random.randint(97, 122)) for _ in range(n))
    return ''.join(letters)

168 

class unnamedParameter(Parameter):
    """A Parameter that can be nameless

    If `name` is None, a random 8-letter name is generated with randstr().
    `skip` is not forwarded to Parameter.__init__; it is stored only as an
    instance attribute.
    """
    def __init__(self, name=None, value=None, vary=True, min=-np.inf, max=np.inf,
                 expr=None, brute_step=None, user_data=None, skip=None):
        if name is None:
            name = randstr(8)
        # NOTE(review): the assignments below pre-populate attributes that
        # Parameter.__init__ is then given again; presumably this ensures
        # they exist before _init_bounds() and the base-class init run
        # (lmfit internals) -- confirm before removing any of them.
        self.name = name
        self.user_data = user_data
        self.init_value = value
        self.min = min
        self.max = max
        self.brute_step = brute_step
        self.vary = vary
        self.skip = skip
        self._expr = expr
        self._expr_ast = None
        self._expr_eval = None
        self._expr_deps = []
        self._delay_asteval = False
        self.stderr = None
        self.correl = None
        self.from_internal = lambda val: val
        self._val = value
        self._init_bounds()
        Parameter.__init__(self, name, value=value, vary=vary,
                           min=min, max=max, expr=expr,
                           brute_step=brute_step,
                           user_data=user_data)

197 

def param(*args, **kws):
    """create a fitting Parameter as a Variable

    The first positional argument may be a string (used as the constraint
    expression `expr`) or a real number (used as `value`). Python ints and
    floats are accepted, and so are NumPy integer and floating scalars,
    which are not subclasses of int/float. Any remaining positional
    arguments and the keyword arguments go to unnamedParameter.
    Unless given, `vary` defaults to False.

    Raises
    ------
    ValueError
        if the first positional argument is neither a string nor a number.
    """
    if len(args) > 0:
        a0 = args[0]
        if isinstance(a0, str):
            kws['expr'] = a0
        elif isinstance(a0, (int, float, np.integer, np.floating)):
            kws['value'] = a0
        else:
            raise ValueError("first argument to param() must be string or number")
        args = args[1:]
    kws.pop('_larch', None)
    kws.setdefault('vary', False)

    return unnamedParameter(*args, **kws)

215 

def guess(value, **kws):
    """create a fitting Parameter as a Variable.

    This is param(value, ...) with `vary` forced to True. A minimum or
    maximum value for the variable value can be given:
        x = guess(10, min=0)
        y = guess(1.2, min=1, max=2)
    """
    kws['vary'] = True
    return param(value, **kws)

224 

def is_param(obj):
    """Return True if `obj` is a Parameter (same test as isParameter)."""
    answer = isParameter(obj)
    return answer

228 

def dict2params(pars):
    """Convert a plain dict whose values are Parameter objects into an
    lmfit.Parameters.

    A Parameters object is returned unchanged. Entries whose values are
    not Parameter instances are dropped.
    """
    if isinstance(pars, Parameters):
        return pars
    converted = Parameters()
    for key, value in pars.items():
        if not isinstance(value, Parameter):
            continue
        converted[key] = value
    return converted

239 

def group2params(paramgroup):
    """take a Group of Parameter objects (and maybe other things)
    and put them into a lmfit.Parameters, ready for use in fitting

    - a Parameters object is returned as-is;
    - a plain dict is converted with dict2params (non-Parameter values dropped);
    - a ParameterGroup returns its own ``__params__``;
    - for any other group, every attribute whose ``skip`` is neither
      None nor False is ignored. Parameters are copied into a new
      Parameters object, and other attributes go into its asteval symbol
      table. Constraint expressions are applied last, once all symbols
      are defined. None gives an empty Parameters.
    """
    if isinstance(paramgroup, Parameters):
        return paramgroup
    if isinstance(paramgroup, dict):
        return dict2params(paramgroup)
    if isinstance(paramgroup, ParameterGroup):
        return paramgroup.__params__

    params = Parameters()
    if paramgroup is None:
        return params

    exprs = {}
    for name in dir(paramgroup):
        par = getattr(paramgroup, name)
        if getattr(par, 'skip', None) not in (False, None):
            continue
        if isParameter(par):
            params.add(name, value=par.value, vary=par.vary,
                       min=par.min, max=par.max,
                       brute_step=par.brute_step)
            # only parameters actually added may receive an expression;
            # a skipped parameter with an expr would otherwise KeyError
            if par.expr is not None:
                exprs[name] = par.expr
        else:
            params._asteval.symtable[name] = par

    # now set any expression (that is, after all symbols are defined)
    for name, expr in exprs.items():
        params[name].expr = expr

    return params

277 

def params2group(params, paramgroup):
    """fill Parameter objects in paramgroup with
    values from lmfit.Parameters

    For every name in `params` for which `paramgroup` has a Parameter
    attribute of the same name, the attributes value, vary, stderr, min,
    max, expr, name, correl, brute_step and user_data are copied onto that
    attribute (None where the source lacks one). If `paramgroup` has a
    ``__params__`` collection, that attribute is also stored there. When
    a stderr is available, ``uvalue`` is set to a ufloat(value, stderr).
    This is best effort: failures to build ``uvalue`` are ignored.
    """
    _params = getattr(paramgroup, '__params__', None)
    for name, src in params.items():
        this = getattr(paramgroup, name, None)
        if not isParameter(this):
            continue
        if _params is not None:
            _params[name] = this
        for attr in ('value', 'vary', 'stderr', 'min', 'max', 'expr',
                     'name', 'correl', 'brute_step', 'user_data'):
            setattr(this, attr, getattr(src, attr, None))
        if this.stderr is not None:
            try:
                this.uvalue = ufloat(this.value, this.stderr)
            except Exception:
                # e.g. a None value or a negative stderr: leave uvalue unset
                pass

296 

297 

def minimize(fcn, paramgroup, method='leastsq', args=None, kws=None,
             scale_covar=True, iter_cb=None, reduce_fcn=None, nan_polcy='omit',
             **fit_kws):
    """
    wrapper around lmfit minimizer for Larch

    Parameters
    ----------
    fcn : callable
        objective function, called as ``fcn(paramgroup, *args, **kws)``.
        When `paramgroup` is an lmfit.Parameters, it is called with the
        current Parameters instead. It returns the array to minimize.
    paramgroup : ParameterGroup, Group, or lmfit.Parameters
        the fitting parameters.
    method : str
        lmfit minimization method (default 'leastsq').
    args, kws : tuple, dict
        extra positional / keyword arguments for `fcn`.
    scale_covar : bool
        passed to lmfit.Minimizer.
    iter_cb, reduce_fcn :
        passed to lmfit.Minimizer.
    nan_polcy : str
        NaN handling, passed to lmfit.Minimizer as ``nan_policy``. A
        correctly spelled ``nan_policy=`` keyword is also accepted and
        takes precedence.
    **fit_kws :
        further keyword arguments for lmfit.Minimizer.

    Returns
    -------
    Group with fitter, fit_details, chi_square, chi_reduced and the
    main attributes of the lmfit MinimizerResult.

    Raises
    ------
    ValueError
        if `paramgroup` is not one of the accepted types.
    """
    pass_params = False
    if isinstance(paramgroup, ParameterGroup):
        params = paramgroup.__params__
    elif isgroup(paramgroup):
        params = group2params(paramgroup)
    elif isinstance(paramgroup, Parameters):
        params = paramgroup
        pass_params = True
    else:
        raise ValueError('minimize takes ParameterGroup, Group, or Parameters as second argument')

    if args is None:
        args = ()
    if kws is None:
        kws = {}
    # avoid a duplicate-keyword error if the caller used lmfit's spelling
    nan_policy = fit_kws.pop('nan_policy', nan_polcy)

    def _residual(pars):
        if pass_params:
            # a bare Parameters has no attributes to refresh; hand the
            # current values straight to the objective function
            return fcn(pars, *args, **kws)
        params2group(pars, paramgroup)
        return fcn(paramgroup, *args, **kws)

    fitter = Minimizer(_residual, params, iter_cb=iter_cb,
                       scale_covar=scale_covar, reduce_fcn=reduce_fcn,
                       nan_policy=nan_policy, **fit_kws)

    result = fitter.minimize(method=method)
    if not pass_params:
        params2group(result.params, paramgroup)

    out = Group(name='minimize results', fitter=fitter, fit_details=result,
                chi_square=result.chisqr, chi_reduced=result.redchi)

    for attr in ('aic', 'bic', 'covar', 'params', 'nvarys',
                 'nfree', 'ndata', 'var_names', 'nfev', 'success',
                 'errorbars', 'message', 'lmdif_message', 'residual'):
        setattr(out, attr, getattr(result, attr, None))
    return out

336 

def fit_report(fit_result, modelpars=None, show_correl=True, min_correl=0.1,
               sort_pars=True, **kws):
    """generate a report of fitting results
    wrapper around lmfit.fit_report

    The report contains the best-fit values for the parameters and their
    uncertainties and correlations.

    Parameters
    ----------
    fit_result : result from fit
        Fit Group output from fit, or lmfit.MinimizerResult returned from a fit.
        A ModelResult, a Parameters object, or a group of Parameters also works.
    modelpars : Parameters, optional
        Known Model Parameters.
    show_correl : bool, optional
        Whether to show list of sorted correlations (default is True).
    min_correl : float, optional
        Smallest correlation in absolute value to show (default is 0.1).
    sort_pars : bool or callable, optional
        Whether to show parameter names sorted in alphanumerical order. If
        False, then the parameters will be listed in the order they
        were added to the Parameters dictionary. If callable, then this (one
        argument) function is used to extract a comparison key from each
        list element.

    Returns
    -------
    string
        Multi-line text of fit report, or a short message if no report
        can be made from `fit_result`.
    """
    report_opts = dict(modelpars=modelpars, show_correl=show_correl,
                       min_correl=min_correl, sort_pars=sort_pars)

    details = getattr(fit_result, 'fit_details', fit_result)
    if isinstance(details, MinimizerResult):
        return lmfit.fit_report(details, **report_opts)
    if isinstance(details, ModelResult):
        return details.fit_report(**report_opts)

    fit_params = getattr(fit_result, 'params', fit_result)
    if isinstance(fit_params, Parameters):
        return lmfit.fit_report(fit_params, **report_opts)

    try:
        return lmfit.fit_report(group2params(fit_result), **report_opts)
    except (ValueError, AttributeError):
        pass
    return "Cannot make fit report with %s" % repr(fit_result)

393 

394 

def confidence_intervals(fit_result, sigmas=(1, 2, 3), **kws):
    """calculate the confidence intervals from a fit
    for supplied sigma values

    wrapper around lmfit.conf_interval

    Arguments
    ---------
    fit_result : output of minimize(); its `fitter` and `fit_details`
                 attributes are used.
    sigmas     : sigma levels to compute intervals for [(1, 2, 3)]
    **kws      : passed to lmfit.conf_interval

    Raises
    ------
    ValueError if `fit_result` lacks `fitter` or `fit_details`
    (the same check chi2_map makes).
    """
    fitter = getattr(fit_result, 'fitter', None)
    result = getattr(fit_result, 'fit_details', None)
    if fitter is None or result is None:
        raise ValueError("confidence_intervals needs valid fit result as first argument")
    return conf_interval(fitter, result, sigmas=sigmas, **kws)

404 

def chi2_map(fit_result, xname, yname, nx=11, ny=11, sigma=3, **kws):
    """generate a confidence map for any two parameters for a fit

    Arguments
    ==========
     fit_result  output of minimize() fit (must be run first)
     xname       name of variable parameter for x-axis
     yname       name of variable parameter for y-axis
     nx          number of steps in x [11]
     ny          number of steps in y [11]
     sigma       scale for uncertainty range [3]
     **kws       passed to lmfit.conf_interval2d

    Returns
    =======
     xpts, ypts, map

    Raises
    ======
     ValueError if `fit_result` lacks `fitter`/`fit_details`, or if either
     parameter has no stderr (so the explored range cannot be set).

    Notes
    =====
     1. sigma sets the extent of values to explore:
              param.value +/- sigma * param.stderr
     2. the map holds chi-square values scaled to fit_result.chi_square,
        not probabilities.
    """
    #
    fitter = getattr(fit_result, 'fitter', None)
    result = getattr(fit_result, 'fit_details', None)
    if fitter is None or result is None:
        raise ValueError("chi2_map needs valid fit result as first argument")

    c2_scale = fit_result.chi_square / result.chisqr

    def scaled_chisqr(*pargs, **pkws):
        """return scaled chi-square, instead of probability.

        lmfit >= 1.0 calls prob_func(best_fit, new_fit) with fit results;
        older lmfit called prob_func(ndata, nparas, new_chi, best_chi, nfix).
        Both conventions are accepted.
        """
        if len(pargs) == 2 and hasattr(pargs[1], 'chisqr'):
            new_chi = pargs[1].chisqr
        elif len(pargs) > 2:
            new_chi = pargs[2]
        else:
            new_chi = pkws['new_chi']
        return new_chi * c2_scale

    x = result.params[xname]
    y = result.params[yname]
    for pname, par in ((xname, x), (yname, y)):
        if par.stderr is None:
            raise ValueError("chi2_map needs an uncertainty (stderr) for parameter '%s'" % pname)

    xrange = (x.value + sigma * x.stderr, x.value - sigma * x.stderr)
    yrange = (y.value + sigma * y.stderr, y.value - sigma * y.stderr)

    return conf_interval2d(fitter, result, xname, yname,
                           limits=(xrange, yrange),
                           prob_func=scaled_chisqr,
                           nx=nx, ny=ny, **kws)

447 

_larch_name = '_math'
# Names exposed to Larch by this module, mapped to the objects they
# refer to. Some objects appear under two names (isParameter as
# is_param/isparam, minimize as minimize/lm_minimize).
exports = dict(
    param=param,
    guess=guess,
    param_group=param_group,
    confidence_intervals=confidence_intervals,
    confidence_report=confidence_report,
    f_test=f_test,
    chi2_map=chi2_map,
    is_param=isParameter,
    isparam=isParameter,
    minimize=minimize,
    ufloat=ufloat,
    fit_report=fit_report,
    Parameters=Parameters,
    Parameter=Parameter,
    lm_minimize=minimize,
    lm_save_model=save_model,
    lm_load_model=load_model,
    lm_save_modelresult=save_modelresult,
    lm_load_modelresult=load_modelresult,
)

468 

# Also export these lmfit.models names; names the installed lmfit does
# not provide (getattr returns None) are skipped.
for name in ('BreitWignerModel', 'ComplexConstantModel',
             'ConstantModel', 'DampedHarmonicOscillatorModel',
             'DampedOscillatorModel', 'DoniachModel',
             'ExponentialGaussianModel', 'ExponentialModel',
             'ExpressionModel', 'GaussianModel', 'Interpreter',
             'LinearModel', 'LognormalModel', 'LorentzianModel',
             'MoffatModel', 'ParabolicModel', 'Pearson7Model',
             'PolynomialModel', 'PowerLawModel',
             'PseudoVoigtModel', 'QuadraticModel',
             'RectangleModel', 'SkewedGaussianModel',
             'StepModel', 'StudentsTModel', 'VoigtModel'):
    val = getattr(lmfit.models, name, None)
    if val is None:
        continue
    exports[name] = val

483 

# register the exported names under the '_math' namespace
_larch_builtins = dict(_math=exports)