Coverage for emd/tests/test_cycles.py: 100%

131 statements  

« prev     ^ index     » next       coverage.py v7.6.11, created at 2025-03-08 15:44 +0000

1"""Tests for single-cycle analyses in emd.cycles.""" 

2 

3import unittest 

4 

5import numpy as np 

6 

7 

8class TestCycles(unittest.TestCase): 

9 """Basic cycle identification tests.""" 

10 

11 def setUp(self): 

12 """Initialise cycles for testing.""" 

13 self.sample_rate = 1000 

14 self.seconds = 2 

15 self.pad_time = .1 

16 nsamples = int((self.sample_rate * self.seconds) + (2*self.pad_time*self.sample_rate)) 

17 self.time_vect = np.linspace(-self.pad_time, 

18 self.seconds+self.pad_time, 

19 nsamples) 

20 self.signal = np.sin(2 * np.pi * 10 * self.time_vect)[:, None] 

21 

22 def cycle_generator(self, f, phase=0, distort=None): 

23 """Return cycles to be tested.""" 

24 from ..cycles import get_cycle_vector 

25 from ..spectra import frequency_transform 

26 

27 x = np.sin(2 * np.pi * f * self.time_vect + phase)[:, None] 

28 

29 # Add a wobble 

30 if distort is not None: 

31 x[distort - 25:distort + 25, 0] += np.linspace(-.1, .1, 50) 

32 

33 # This is a perfect sin so we can use normal hilbert 

34 IP, IF, IA = frequency_transform(x, self.sample_rate, 'hilbert') 

35 # Find good cycles 

36 cycles = get_cycle_vector(IP, return_good=True)[:, 0] 

37 

38 return cycles 

39 

40 def test_simple_cycle_counting(self): 

41 """Test that correct number of cycles are identified.""" 

42 # Test basic cycle detection 

43 uni_cycles = np.unique(self.cycle_generator(4)) 

44 assert(np.all(uni_cycles == np.arange(-1, 8))) 

45 

46 uni_cycles = np.unique(self.cycle_generator(5, phase=1.5 * np.pi)) 

47 print(uni_cycles) 

48 assert(np.all(uni_cycles == np.arange(-1, 10))) 

49 

50 def test_cycle_count_with_bad_start_and_end(self): 

51 """Test that cycles clipped at edges of signal are dropped.""" 

52 # Test basic cycle detection 

53 cycles = self.cycle_generator(4, phase=0) 

54 uni_cycles = np.unique(cycles) 

55 assert(np.all(uni_cycles == np.arange(-1, 8))) 

56 assert(cycles[50] == -1) 

57 assert(cycles[2150] == -1) 

58 

59 cycles = self.cycle_generator(5, phase=0) 

60 uni_cycles = np.unique(cycles) 

61 assert(np.all(uni_cycles == np.arange(-1, 10))) 

62 assert(cycles[50] == -1) 

63 assert(cycles[2150] == -1) 

64 

65 def test_cycle_count_with_bad_in_middle(self): 

66 """Test that bad cycles in continuous signal are dropped.""" 

67 cycles = self.cycle_generator(4, phase=1.5 * np.pi, distort=1100) 

68 uni_cycles = np.unique(cycles) 

69 assert(np.all(uni_cycles == np.arange(-1, 7))) 

70 assert(cycles[1100] == -1) 

71 

72 def test_cycle_control_points(self): 

73 """Test that cycle control points are correctly identified.""" 

74 from ..cycles import get_control_points 

75 

76 x = np.sin(2*np.pi*np.linspace(0, 1, 1280)) 

77 cycles = np.ones_like(x, dtype=int) 

78 ctrl = get_control_points(x, cycles) 

79 

80 # We accept a 1 sample error in ctrl point location... 

81 ref = 1280*np.linspace(0, 1, 5) 

82 assert(np.abs(ctrl-ref).max()) 

83 

84 

85class TestCyclesSupport(unittest.TestCase): 

86 """Test functionality in emd._cycles_support module.""" 

87 

88 def setUp(self): 

89 """Initialise cycles for testing.""" 

90 from .._cycles_support import get_cycle_stat_from_samples 

91 from ..cycles import (get_chain_vector, get_cycle_vector, 

92 get_subset_vector) 

93 from ..spectra import frequency_transform 

94 

95 X = np.sin(2*np.pi*10*np.linspace(0, 2, 512)) 

96 X = X * (2-np.cos(2*np.pi*1*np.linspace(0, 2, 512))) 

97 IP, IF, IA = frequency_transform(X, 512, 'hilbert') 

98 

99 self.cycle_vect = get_cycle_vector(IP[:, 0], return_good=False) 

100 self.max_amps = get_cycle_stat_from_samples(IA[:, 0], self.cycle_vect, np.max) 

101 

102 valids = self.max_amps > 1.5 

103 self.subset_vect = get_subset_vector(valids) 

104 self.chain_vect = get_chain_vector(self.subset_vect) 

105 

106 def test_cycle_maps(self): 

107 """Ensure that mapping between samples, cycles, subsets and chains are working.""" 

108 from .._cycles_support import map_cycle_to_samples, map_sample_to_cycle 

109 

110 # Test 1 - q2 should contain 350 

111 q1 = map_sample_to_cycle(self.cycle_vect, 350) 

112 q2 = map_cycle_to_samples(self.cycle_vect, q1[0]) 

113 assert(350 in q2) 

114 

115 from .._cycles_support import map_cycle_to_subset, map_subset_to_cycle 

116 

117 # Test 2 - should recover 9 

118 q3 = map_subset_to_cycle(self.subset_vect, 9) 

119 q4 = map_cycle_to_subset(self.subset_vect, q3) 

120 assert(q4[0] == 9) 

121 

122 from .._cycles_support import (map_sample_to_subset, 

123 map_subset_to_sample) 

124 

125 # Test 3 - should recover 350 

126 q5 = map_sample_to_subset(self.subset_vect, self.cycle_vect, 350) 

127 q6 = map_subset_to_sample(self.subset_vect, self.cycle_vect, q5[0]) 

128 assert(350 in q6) 

129 

130 from .._cycles_support import map_chain_to_subset, map_subset_to_chain 

131 

132 # Test 4 - Should recover 7 

133 q7 = map_subset_to_chain(self.chain_vect, 7) 

134 q8 = map_chain_to_subset(self.chain_vect, q7) 

135 assert(7 in q8) 

136 

137 from .._cycles_support import map_cycle_to_chain 

138 

139 # Test 5 - check that third cycle with -1 in subset doesn't have a chain 

140 q9 = map_cycle_to_chain(self.chain_vect, self.subset_vect, np.where(self.subset_vect == -1)[0][3]) 

141 assert(q9 is None) 

142 

143 

144class TestCyclesObject(unittest.TestCase): 

145 """Ensure that cycle object is working as expected.""" 

146 

147 def setUp(self): 

148 """Initialise cycles for testing.""" 

149 from ..cycles import Cycles 

150 from ..spectra import frequency_transform 

151 

152 X = np.sin(2*np.pi*10*np.linspace(0, 2, 512)) 

153 self.X = X * (2-np.cos(2*np.pi*1*np.linspace(0, 2, 512))) 

154 self.IP, self.IF, self.IA = frequency_transform(X, 512, 'hilbert') 

155 

156 self.C = Cycles(self.IP[:, 0]) 

157 

158 def test_cycle_object_metrics(self): 

159 """Ensure that cycle metric computation and storage are working.""" 

160 from ..cycles import cf_ascending_zero_sample 

161 

162 self.C.compute_cycle_metric('max_amp', self.IA[:, 0], np.max) 

163 self.C.compute_cycle_timings() 

164 

165 self.C.compute_cycle_metric('asc_samp', self.X, cf_ascending_zero_sample, mode='augmented') 

166 

167 xx = np.arange(self.C.ncycles) 

168 self.C.add_cycle_metric('range', xx) 

169 

170 conditions = ['max_amp>0.75'] 

171 self.C.pick_cycle_subset(conditions) 

172 self.C.compute_chain_timings() 

173 

174 df = self.C.get_metric_dataframe() 

175 assert(len(df['max_amp']) == self.C.ncycles) 

176 df = self.C.get_metric_dataframe(subset=True) 

177 assert(len(df['max_amp']) == 20) 

178 conditions = ['max_amp>0.75', 'range>5'] 

179 df = self.C.get_metric_dataframe(conditions=conditions) 

180 assert(len(df['max_amp']) == 14) 

181 

182 def test_cycle_object_iteration(self): 

183 """Ensure that cycle iteration is working.""" 

184 from ..cycles import phase_align 

185 pa, phasex = phase_align(self.IP, self.IF, self.C) 

186 

187 

188class TestKDTreeMatch(unittest.TestCase): 

189 """Ensure that KD-Tree matching is working.""" 

190 

191 def test_kdt(self): 

192 """Ensure that KD-Tree matching is working.""" 

193 x = np.linspace(0, 1) 

194 y = np.linspace(0, 1, 10) 

195 

196 from ..cycles import kdt_match 

197 x_inds, y_inds = kdt_match(x, y, K=2) 

198 

199 assert(all(y_inds == np.arange(10))) 

200 

201 xx = np.array([0, 5, 11, 16, 22, 27, 33, 38, 44, 49]) 

202 assert(all(x_inds == xx)) 

203 

204 

205class TestCycleStats(): 

206 """Ensure that cycle stats are computed as expected.""" 

207 

208 def test_get_cycle_stat(self): 

209 """Ensure that cycle stats are computed as expected.""" 

210 from ..cycles import get_cycle_stat 

211 

212 x = np.array([-1, 0, 0, 0, 0, 1, 1, 2, 2, 2, -1]) 

213 y = np.ones_like(x) 

214 

215 # Compute the average of y within bins of x 

216 bin_avg = get_cycle_stat(x, y) 

217 print(bin_avg) 

218 assert(np.all(bin_avg == [1., 1., 1.])) 

219 

220 # Compute sum of y within bins of x and return full vector 

221 bin_avg = get_cycle_stat(x, y, out='samples', func=np.sum) 

222 assert(np.allclose(bin_avg, np.array([np.nan, 4., 4., 4., 4., 2., 2., 3., 3., 3., np.nan]), equal_nan=True)) 

223 

224 # Compute the sum of y within bins of x 

225 bin_counts = get_cycle_stat(x, y, func=np.sum) 

226 print(bin_counts) 

227 assert(np.all(bin_counts == [4, 2, 3]))