Coverage for src/driada/experiment/wavelet_ridge.py: 45.45%

66 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1from numba.experimental import jitclass 

2import numpy as np 

3from numba import int32, float32, boolean # import the types 

4from numba import types, typed, njit 

5 

6 

# Field layout for the ``Ridge`` jitclass.
#
# Per-point data is kept in typed float64 lists that grow with every
# ``Ridge.extend`` call. The scalar summary fields are filled in by
# ``Ridge.terminate`` and hold the sentinel -1 until then.
#
# The summary scalars are float64, not float32, because they are copied
# straight from the float64 per-point lists. float32 would silently round
# them; for example, any index above 2**24 would be rounded, which would
# corrupt ``start``, ``end`` and ``duration``.
spec = [
    ('indices', types.ListType(types.float64)),    # per-point indices
    ('ampls', types.ListType(types.float64)),      # per-point amplitudes
    ('birth_scale', types.float64),                # scale of the first point
    ('scales', types.ListType(types.float64)),     # per-point scales
    ('wvt_times', types.ListType(types.float64)),  # per-point wavelet times
    ('terminated', boolean),                       # True once terminate() ran
    ('end_scale', types.float64),                  # scale of the last point
    ('length', types.int64),                       # number of points
    ('max_scale', types.float64),                  # scale at the amplitude peak
    ('max_ampl', types.float64),                   # peak amplitude
    ('start', types.float64),                      # first index
    ('end', types.float64),                        # last index
    ('duration', types.float64),                   # |end - start|
]

22 

23 

@njit()
def maxpos_numba(x):
    """Return the position of the first occurrence of the maximum of ``x``.

    ``x`` must support ``max()`` and ``.index()`` inside numba nopython mode;
    in this file it is called with a typed list of float64 values.
    """
    return x.index(max(x))

28 

29 

# Numba-compiled accumulator for a single wavelet ridge.
#
# A ridge starts from one (index, amplitude, scale, wavelet-time) point. It
# is grown one point at a time with ``extend`` until ``terminate`` is called.
# ``terminate`` freezes the ridge and fills in the summary fields
# (``end_scale``, ``length``, ``max_scale``, ``max_ampl``, ``start``, ``end``,
# ``duration``). Before termination those fields hold the sentinel -1.
@jitclass(spec)
class Ridge(object):

    def __init__(self, start_index, ampl, start_scale, wvt_time):
        """Create a ridge containing a single point.

        Parameters
        ----------
        start_index : float
            Index of the first ridge point (stored in a float64 list).
        ampl : float
            Amplitude at the first point.
        start_scale : float
            Scale of the first point; also recorded as ``birth_scale``.
        wvt_time : float
            Wavelet time of the first point.
        """
        self.indices = typed.List.empty_list(types.float64)
        self.indices.append(start_index)

        self.ampls = typed.List.empty_list(types.float64)
        self.ampls.append(ampl)

        self.birth_scale = start_scale

        self.scales = typed.List.empty_list(types.float64)
        self.scales.append(start_scale)

        self.wvt_times = typed.List.empty_list(types.float64)
        self.wvt_times.append(wvt_time)

        self.terminated = False

        # Sentinels: these are only meaningful after terminate().
        self.end_scale = -1
        self.length = -1
        self.max_scale = -1
        self.max_ampl = -1
        self.start = -1
        self.end = -1
        self.duration = -1

    def extend(self, index, ampl, scale, wvt_time):
        """Append one point to the ridge.

        Raises
        ------
        ValueError
            If the ridge has already been terminated.
        """
        if self.terminated:
            raise ValueError('Ridge is terminated')
        self.scales.append(scale)
        self.ampls.append(ampl)
        self.indices.append(index)
        self.wvt_times.append(wvt_time)

    def tip(self):
        """Return the index of the most recently added point."""
        return self.indices[-1]

    def terminate(self):
        """Freeze the ridge and compute its summary statistics.

        Calling this on a ridge that is already terminated does nothing.
        """
        if self.terminated:
            return

        # Locate the amplitude peak once. The value at that position is the
        # maximum itself, so a second scan with max() is unnecessary.
        peak = maxpos_numba(self.ampls)

        self.end_scale = self.scales[-1]
        self.length = len(self.scales)
        self.max_scale = self.scales[peak]
        self.max_ampl = self.ampls[peak]
        self.start = self.indices[0]
        self.end = self.indices[-1]
        self.duration = np.abs(self.end - self.start)
        self.terminated = True

86 

87 

class RidgeInfoContainer(object):
    """Plain-Python snapshot of a ridge, backed by NumPy arrays.

    The constructor converts the four per-point sequences to ``np.ndarray``.
    It then computes the same summary statistics that ``Ridge.terminate``
    produces.
    """

    def __init__(self, indices, ampls, scales, wvt_times):
        """Build the container from per-point ridge data.

        Parameters
        ----------
        indices, ampls, scales, wvt_times : sequence of numbers
            Per-point ridge data. The sequences are presumably of equal
            length; this is not checked here.

        Raises
        ------
        IndexError
            If ``scales`` (and therefore the ridge) is empty.
        """
        self.indices = np.array(indices)
        self.ampls = np.array(ampls)
        self.scales = np.array(scales)
        self.wvt_times = np.array(wvt_times)

        self.birth_scale = self.scales[0]
        self.end_scale = self.scales[-1]
        self.length = len(self.scales)

        # Find the amplitude peak once and reuse it for both statistics.
        # Builtin max() over an ndarray would iterate element by element
        # in Python and repeat the work done by np.argmax.
        peak = np.argmax(self.ampls)
        self.max_scale = self.scales[peak]
        self.max_ampl = self.ampls[peak]

        self.start = self.indices[0]
        self.end = self.indices[-1]
        self.duration = np.abs(self.end - self.start)

103 

104 

def ridges_to_containers(ridges):
    """Convert an iterable of ridges into ``RidgeInfoContainer`` objects.

    Each ridge must expose ``indices``, ``ampls``, ``scales`` and
    ``wvt_times``. The returned list preserves the input order.
    """
    containers = []
    for r in ridges:
        containers.append(
            RidgeInfoContainer(r.indices, r.ampls, r.scales, r.wvt_times)
        )
    return containers