Coverage for /Users/Newville/Codes/xraylarch/larch/math/spline.py: 19%
32 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
2"""
3Splines for fitting to data within Larch
5"""
6from scipy.interpolate import splrep, splev
8from .. import Group, isgroup
9from ..fitting import Parameter, isParameter
def spline_rep(x, y, group=None, name='spl1'):
    """create a spline representation for an (x, y) data set to be
    evaluated with spline_eval(), with

        pars = spline_rep(x, y)

        pars = group()
        spline_rep(x, y, group=pars)

        ynew = spline_eval(xnew, pars)

    arguments:
    ------------
      x       1-d array for x
      y       1-d array for y
      name    name for spline params and subgroup ['spl1']
      group   optional group to use to hold spline parameters

    returns:
    --------
      group containing spline representation, which will include
      one Parameter per entry of coefs[2:-2], i.e. len(coefs)-4
      Parameters named '<name>_c0', '<name>_c1', ..., where coefs is
      the coefficient array returned by scipy's splrep(); for the
      default interpolating cubic this is presumably len(x) Parameters.
      It also includes a subgroup '<name>_details' holding the
      knots, order and coefs from splrep().

    notes:
    ------
      in order to hold multiple splines in a single parameter group,
      the ``name`` argument must be different for each spline
      representation, and passed to spline_eval()
    """
    if group is None:
        group = Group()

    knots, coefs, order = splrep(x, y)
    dgroup = Group(knots=knots, order=order, coefs=coefs)
    setattr(group, "{:s}_details".format(name), dgroup)

    # Parameter '<name>_c<i>' maps to coefs[2+i]; spline_eval() writes the
    # parameter values back using the same offset, so the two must agree.
    # NOTE(review): splrep pads coefs with k+1 trailing zeros, so this slice
    # keeps coefs[0:2] fixed while also exposing two padding entries as
    # parameters -- confirm this asymmetry is intended.
    for i, val in enumerate(coefs[2:-2]):
        pname = "{:s}_c{:d}".format(name, i)
        setattr(group, pname, Parameter(value=val, name=pname, vary=True))
    return group
def spline_eval(x, group, name='spl1'):
    """evaluate spline at specified x values

    arguments:
    ------------
      x       input 1-d array for abscissa
      group   Group containing spline representation,
              as defined by spline_rep()
      name    name for spline params and subgroups ['spl1']

    returns:
    --------
      1-d array with interpolated values

    raises:
    -------
      Warning  if the '<name>_details' subgroup is missing (or is not a
               group), or if any '<name>_c<i>' parameter is missing.

    notes:
    ------
      the current parameter values are written into the coefs array
      stored in '<name>_details' (the array is updated in place).
    """
    dname = "{:s}_details".format(name)
    sgroup = getattr(group, dname, None)
    if sgroup is None or not isgroup(sgroup):
        # report the attribute actually looked up, not just the base name
        raise Warning("spline_eval: subgroup '{:s}' not found".format(dname))

    knots = sgroup.knots
    order = sgroup.order
    coefs = sgroup.coefs

    # parameter '<name>_c<i>' corresponds to coefs[2+i], matching the
    # coefs[2:-2] slice used by spline_rep() to create the parameters
    for i in range(len(coefs[2:-2])):
        pname = "{:s}_c{:d}".format(name, i)
        cval = getattr(group, pname, None)
        if cval is None:
            raise Warning("spline_eval: param '{:s}' not found".format(pname))
        if isParameter(cval):
            cval = cval.value
        coefs[2+i] = cval
    setattr(sgroup, 'coefs', coefs)
    return splev(x, [knots, coefs, order])