Coverage for /Users/Newville/Codes/xraylarch/larch/math/spline.py: 19%

32 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2""" 

3Splines for fitting to data within Larch 

4 

5""" 

6from scipy.interpolate import splrep, splev 

7 

8from .. import Group, isgroup 

9from ..fitting import Parameter, isParameter 

10 

11 

def _n_active_coefs(knots, order):
    """Return how many B-spline coefficients splev() actually uses.

    scipy's splrep() returns a coefficient array as long as the knot
    array, padded with order+1 trailing zeros; only the first
    len(knots) - order - 1 entries contribute to the spline.  Both
    spline_rep() and spline_eval() use this count so that parameter
    '<name>_c<i>' always maps onto coefficient i.
    """
    return len(knots) - order - 1


def spline_rep(x, y, group=None, name='spl1'):
    """create a spline representation for an (x, y) data set to be
    evaluated with spline_eval(), with

        pars = spline_rep(x, y)

    or

        pars = group()
        spline_rep(x, y, group=pars)

    and then

        ynew = spline_eval(xnew, pars)

    arguments:
    ------------
      x       1-d array for x
      y       1-d array for y
      group   optional group to use to hold spline parameters
      name    name for spline params and subgroup ['spl1']

    returns:
    --------
      group containing spline representation, which will include
      one Parameter per active B-spline coefficient (len(x) of them
      for the default interpolating cubic spline), named
      '<name>_c0' ... '<name>_c<len(x)-1>', and a subgroup
      '<name>_details' holding the knots, order, and full coefficient
      array returned by splrep().

    notes:
    ------
      in order to hold multiple splines in a single parameter group,
      the ``name`` argument must be different for each spline
      representation, and passed to spline_eval()
    """
    if group is None:
        group = Group()

    knots, coefs, order = splrep(x, y)
    dgroup = Group(knots=knots, order=order, coefs=coefs)
    setattr(group, "{:s}_details".format(name), dgroup)

    # expose only the coefficients splev() reads: the trailing
    # order+1 entries of coefs are zero padding from splrep().
    for i in range(_n_active_coefs(knots, order)):
        pname = "{:s}_c{:d}".format(name, i)
        par = Parameter(value=coefs[i], name=pname, vary=True)
        setattr(group, pname, par)
    return group


def spline_eval(x, group, name='spl1'):
    """evaluate spline at specified x values

    arguments:
    ------------
      x       input 1-d array for abscissa
      group   Group containing spline representation,
              as defined by spline_rep()
      name    name for spline params and subgroups ['spl1']

    returns:
    --------
      1-d array with interpolated values

    raises:
    -------
      Warning  if the '<name>_details' subgroup, or any of the
               '<name>_c<i>' values, is missing from group

    notes:
    ------
      the current values of the '<name>_c<i>' entries (Parameters or
      plain numbers) are written into the coefficient array stored in
      '<name>_details' before evaluating.
    """
    dname = "{:s}_details".format(name)
    sgroup = getattr(group, dname, None)
    if sgroup is None or not isgroup(sgroup):
        raise Warning("spline_eval: subgroup '{:s}' not found".format(dname))

    knots = sgroup.knots
    order = sgroup.order
    coefs = sgroup.coefs
    # parameter i maps onto coefficient i, matching spline_rep()
    for i in range(_n_active_coefs(knots, order)):
        pname = "{:s}_c{:d}".format(name, i)
        cval = getattr(group, pname, None)
        if cval is None:
            raise Warning("spline_eval: param '{:s}' not found".format(pname))
        if isParameter(cval):
            cval = cval.value
        coefs[i] = cval
    sgroup.coefs = coefs
    return splev(x, [knots, coefs, order])