Coverage for /Users/Newville/Codes/xraylarch/larch/io/export_modelresult.py: 5%
93 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
2#
3# export a ModelResult
4#
5import sys
6import numpy as np
7from lmfit.model import ModelResult
8from larch.utils import gformat, getfloat_attr
def export_modelresult(result, filename='fitresult.xdi',
                       datafile=None, ydata=None, yerr=None,
                       **kwargs):
    """
    export an lmfit ModelResult to an XDI data file

    Arguments
    ---------
     result     ModelResult, required
     filename   name of output file ['fitresult.xdi']
     datafile   name of data file [`None`]
     ydata      data array used for fit [`None`]
     yerr       data error array used for fit [`None`]

    Returns
    -------
     None.  The file `filename` is (over)written.

    Raises
    ------
     ValueError  if `result` is not an lmfit ModelResult, or if a data
                 column's length does not match `len(result.best_fit)`.

    Notes
    -----
     keyword arguments should include independent variables.  They are
     also passed on to `result.eval_uncertainty` and
     `result.eval_components`.  An independent variable is written as a
     data column only when it is an array of the same length as
     `result.best_fit`.

    Example
    -------
      result = model.fit(ydata, params, x=x)
      export_modelresult(result, 'fitresult_1.xdi', x=x,
                         datafile='XYData.txt')
    """
    if not isinstance(result, ModelResult):
        raise ValueError("export_modelresult needs an lmfit ModelResult")

    header = ["XDI/1.1 Lmfit Result File"]
    hadd = header.append
    if datafile is not None:
        hadd(" Datafile.name: %s " % datafile)
    else:
        hadd(" Datafile.name: <unknown>")

    ndata = len(result.best_fit)
    columns = {}
    for aname in result.model.independent_vars:
        val = kwargs.get(aname, None)
        # scalars have no len(): only keep full-length data arrays
        if val is not None and np.ndim(val) > 0 and len(val) == ndata:
            columns[aname] = val

    if ydata is not None:
        columns['ydata'] = ydata
    if yerr is not None:
        columns['yerr'] = yerr

    columns['best_fit'] = result.best_fit
    columns['init_fit'] = result.init_fit
    # the uncertainty band is only evaluated when every parameter has a
    # standard error; otherwise write a column of zeros
    delta_fit = 0.0*result.best_fit
    if all(p.stderr is not None for p in result.params.values()):
        delta_fit = result.eval_uncertainty(result.params, **kwargs)
    columns['delta_fit'] = delta_fit

    if len(result.model.components) > 1:
        columns.update(result.eval_components(result.params, **kwargs))

    clabel = []
    for i, cname in enumerate(columns, start=1):
        hadd(" Column.%i: %s" % (i, cname))
        clabel.append('%15s ' % cname)

    hadd("Fit.Statistics: Start here")
    hadd(" Fit.model_name: %s" % result.model.name)
    hadd(" Fit.method: %s" % result.method)
    hadd(" Fit.n_function_evals: %s" % getfloat_attr(result, 'nfev'))
    hadd(" Fit.n_data_points: %s" % getfloat_attr(result, 'ndata'))
    hadd(" Fit.n_variables: %s" % getfloat_attr(result, 'nvarys'))
    hadd(" Fit.chi_square: %s" % getfloat_attr(result, 'chisqr', length=11))
    hadd(" Fit.reduced_chi_square: %s" % getfloat_attr(result, 'redchi', length=11))
    hadd(" Fit.akaike_info_crit: %s" % getfloat_attr(result, 'aic', length=11))
    hadd(" Fit.bayesian_info_crit: %s" % getfloat_attr(result, 'bic', length=11))

    hadd("Param.Statistics: Start here")
    header.extend(_param_lines(result.params))

    hadd("//////// Fit Report ////////")
    for r in result.fit_report().split('\n'):
        hadd(" %s" % r)
    hadd("-" * 77)
    # drop the first character so the labels line up after the '#' prefix
    hadd("".join(clabel)[1:])
    header[0] = "XDI/1.1 Lmfit Result File %i header lines" % (len(header))

    # build (and validate) the data table before touching the output file
    datatable = _format_datatable(columns, ndata)
    datatable.append('')  # so the file ends with a newline
    with open(filename, 'w', encoding=sys.getdefaultencoding()) as fh:
        fh.write("\n".join(['#%s' % s for s in header]))
        fh.write("\n")
        fh.write("\n".join(datatable))


def _param_lines(params):
    """Return one ' Param.<name>: ...' header line per parameter in `params`."""
    # default=0 keeps an empty parameter set from raising in max()
    namelen = max((len(name) for name in params), default=0)
    lines = []
    for name, par in params.items():
        space = ' '*(namelen+1-len(name))
        nout = "Param.%s:%s" % (name, space)
        inval = '(init= ?)'
        if par.init_value is not None:
            inval = '(init=% .7g)' % par.init_value

        try:
            sval = gformat(par.value)
        except (TypeError, ValueError):
            sval = 'Non Numeric Value?'
        if par.stderr is not None:
            serr = gformat(par.stderr, length=9)
            sval = '%s +/-%s' % (sval, serr)

        if par.vary:
            bounds = "[%s: %s]" % (gformat(par.min), gformat(par.max))
            lines.append(" %s %s %s %s" % (nout, sval, bounds, inval))
        elif par.expr is not None:
            lines.append(" %s %s == '%s'" % (nout, sval, par.expr))
        else:
            lines.append(" %s % .7g (fixed)" % (nout, par.value))
    return lines


def _format_datatable(columns, ndata):
    """Return the data rows (space-joined, 15-char fields) for `columns`.

    Scalar columns are broadcast to `ndata` values. Any other column whose
    shape is not (ndata,) raises ValueError.
    """
    table = []
    for name, dat in columns.items():
        arr = np.asarray(dat)
        if arr.ndim == 0:
            # e.g. a model component that evaluates to a constant
            arr = np.full(ndata, arr)
        elif arr.shape != (ndata,):
            raise ValueError("column '%s' has shape %s, expected (%d,)"
                             % (name, arr.shape, ndata))
        table.append(arr)

    # one row per data point; *1.0 promotes integer/bool columns to float
    dtable = np.array(table).transpose()*1.0
    rows = []
    for rowvals in dtable:
        fields = []
        for cval in rowvals:
            try:
                fields.append(gformat(cval, length=15))
            except Exception:
                # best-effort fallback for values gformat cannot handle
                fields.append(repr(cval))
        rows.append(" ".join(fields))
    return rows