Coverage for src / tracekit / utils / progressive.py: 100%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Progressive resolution analysis for memory-constrained scenarios. 

2 

3This module provides multi-pass analysis capabilities: preview first, 

4then zoom into regions of interest for detailed analysis. 

5 

6 

7Example: 

8 >>> from tracekit.utils.progressive import analyze_roi, create_preview, select_roi

9 >>> preview = create_preview(trace, downsample_factor=10)

10 >>> roi = select_roi(trace, start_time=0.001, end_time=0.002)  # ROI chosen after inspecting the preview

11 >>> roi_result = analyze_roi(trace, roi, analysis_func=my_analysis_func)

12 

13References: 

14 Multi-resolution analysis techniques 

15""" 

16 

17from __future__ import annotations 

18 

19from dataclasses import dataclass 

20from typing import TYPE_CHECKING, Any 

21 

22import numpy as np 

23 

24if TYPE_CHECKING: 

25 from collections.abc import Callable 

26 

27 from numpy.typing import NDArray 

28 

29 from tracekit.core.types import WaveformTrace 

30 

31 

@dataclass
class PreviewResult:
    """Downsampled view of a waveform together with the numbers describing it.

    Attributes:
        downsampled_data: Samples that remain after decimation.
        downsample_factor: Decimation step that was applied.
        original_length: Sample count of the signal before decimation.
        preview_length: Sample count of ``downsampled_data``.
        sample_rate: Effective sample rate of the preview, i.e. the original
            rate divided by ``downsample_factor``.
        time_vector: Time axis matching ``downsampled_data``.
        basic_stats: Summary statistics of the preview, keyed by name.
    """

    # Decimated signal and how it relates to the source signal.
    downsampled_data: NDArray[np.float64]
    downsample_factor: int
    original_length: int
    preview_length: int

    # Timing information for the decimated signal.
    sample_rate: float
    time_vector: NDArray[np.float64]

    # Name -> value mapping of summary statistics.
    basic_stats: dict[str, float]

54 

55 

@dataclass
class ROISelection:
    """A time window of a trace expressed both in seconds and in samples.

    Attributes:
        start_time: Window start, in seconds.
        end_time: Window end, in seconds.
        start_index: Index of the first sample of the window in the
            original signal.
        end_index: Index marking the end of the window in the original
            signal.
        duration: Window length, in seconds.
        num_samples: Number of samples covered by the window.
    """

    # Window bounds in seconds.
    start_time: float
    end_time: float

    # Window bounds as sample indices of the original signal.
    start_index: int
    end_index: int

    # Window size in seconds and in samples.
    duration: float
    num_samples: int

76 

77 

def create_preview(
    trace: WaveformTrace,
    *,
    downsample_factor: int | None = None,
    max_samples: int = 10_000,
    apply_antialiasing: bool = True,
) -> PreviewResult:
    """Create a downsampled preview of a waveform for quick inspection.

    Args:
        trace: Input waveform trace. ``trace.data`` and
            ``trace.metadata.sample_rate`` are read.
        downsample_factor: Keep every Nth sample. When None, the factor is
            ``len(trace.data) // max_samples`` (at least 1) rounded up to
            the next power of two.
        max_samples: Target preview size used to derive the factor. Only
            consulted when ``downsample_factor`` is None. Because the
            factor comes from a floor division, the preview can still be
            longer than this value.
        apply_antialiasing: Low-pass filter the data before decimation.
            Ignored when the effective factor is 1.

    Returns:
        PreviewResult with the downsampled data, its time axis, the
        effective sample rate and basic statistics.

    Raises:
        ValueError: If the trace is empty, if ``downsample_factor`` is
            given and is smaller than 1, or if the factor has to be
            derived and ``max_samples`` is smaller than 1.

    Example:
        >>> preview = create_preview(large_trace, downsample_factor=10)
        >>> print(f"Preview: {preview.preview_length} samples (factor {preview.downsample_factor}x)")
        >>> # Inspect preview.basic_stats
    """
    data = trace.data
    original_length = len(data)
    sample_rate = trace.metadata.sample_rate

    # Fail early with a clear message; the statistics below cannot be
    # computed on an empty array.
    if original_length == 0:
        raise ValueError("Cannot create a preview of an empty trace")

    if downsample_factor is None:
        if max_samples < 1:
            raise ValueError(f"max_samples must be >= 1, got {max_samples}")
        base_factor = max(1, int(original_length // max_samples))
        # Round up to the next power of two using exact integer arithmetic.
        downsample_factor = 1 << (base_factor - 1).bit_length()
    elif downsample_factor < 1:
        # A zero step is an invalid slice and a negative step would
        # silently reverse the signal.
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    if apply_antialiasing and downsample_factor > 1:
        # Imported here so scipy is only needed when a filter is applied.
        from scipy import signal as sp_signal

        # Low-pass at the Nyquist frequency of the downsampled rate.
        # NOTE(review): sosfilt is a single forward pass, so the preview
        # carries the filter's phase delay; confirm this is acceptable.
        nyquist_freq = (sample_rate / downsample_factor) / 2
        sos = sp_signal.butter(8, nyquist_freq, btype="low", fs=sample_rate, output="sos")
        downsampled = sp_signal.sosfilt(sos, data)[::downsample_factor]
    else:
        # Plain decimation without filtering.
        downsampled = data[::downsample_factor]

    preview_length = len(downsampled)
    preview_sample_rate = sample_rate / downsample_factor
    time_vector = np.arange(preview_length) / preview_sample_rate

    # Square in float64 so integer sample dtypes cannot overflow.
    mean_square = np.mean(np.square(downsampled, dtype=np.float64))
    basic_stats = {
        "mean": float(np.mean(downsampled)),
        "std": float(np.std(downsampled)),
        "min": float(np.min(downsampled)),
        "max": float(np.max(downsampled)),
        "rms": float(np.sqrt(mean_square)),
        "peak_to_peak": float(np.ptp(downsampled)),
    }

    return PreviewResult(
        downsampled_data=downsampled,
        downsample_factor=downsample_factor,
        original_length=original_length,
        preview_length=preview_length,
        sample_rate=preview_sample_rate,
        time_vector=time_vector,
        basic_stats=basic_stats,
    )

151 

152 

def select_roi(
    trace: WaveformTrace,
    start_time: float,
    end_time: float,
) -> ROISelection:
    """Translate a time window into a sample-index window of ``trace``.

    Args:
        trace: Input waveform trace. ``trace.data`` and
            ``trace.metadata.sample_rate`` are read.
        start_time: Window start in seconds.
        end_time: Window end in seconds.

    Returns:
        ROISelection holding the requested times, the matching sample
        indices, the duration and the sample count.

    Raises:
        ValueError: If the window lies outside the signal, or if
            ``start_time`` is not strictly before ``end_time``.

    Example:
        >>> roi = select_roi(trace, start_time=0.001, end_time=0.002)
        >>> print(f"ROI: {roi.num_samples} samples ({roi.duration*1e6:.1f} µs)")
    """
    fs = trace.metadata.sample_rate
    n_total = len(trace.data)
    signal_duration = n_total / fs

    # The window must sit inside the signal ...
    out_of_bounds = start_time < 0 or end_time > signal_duration
    if out_of_bounds:
        raise ValueError(
            f"Time range [{start_time}, {end_time}] outside signal duration [0, {signal_duration}]"
        )
    # ... and must have a positive length.
    if start_time >= end_time:
        raise ValueError(f"start_time ({start_time}) must be < end_time ({end_time})")

    # Seconds -> samples (truncating), then clamp so that at least one
    # sample is selected.
    first = max(0, min(int(start_time * fs), n_total - 1))
    last = max(first + 1, min(int(end_time * fs), n_total))

    return ROISelection(
        start_time=start_time,
        end_time=end_time,
        start_index=first,
        end_index=last,
        duration=end_time - start_time,
        num_samples=last - first,
    )

207 

208 

def analyze_roi(
    trace: WaveformTrace,
    roi: ROISelection,
    *,
    analysis_func: Callable[[WaveformTrace], Any],
    **analysis_kwargs: Any,
) -> Any:
    """Run an analysis function on the full-resolution samples of an ROI.

    The samples between ``roi.start_index`` and ``roi.end_index`` are
    wrapped in a new trace that copies the standard metadata fields of
    ``trace``; ``analysis_func`` is then called on that new trace.

    Args:
        trace: Input waveform trace.
        roi: ROI selection providing the sample-index window.
        analysis_func: Callable applied to the ROI trace.
        **analysis_kwargs: Extra keyword arguments forwarded to
            ``analysis_func``.

    Returns:
        Whatever ``analysis_func`` returns for the ROI trace.

    Example:
        >>> from tracekit.analyzers.waveform.spectral import fft
        >>> roi = select_roi(trace, 0.001, 0.002)
        >>> freq, mag = analyze_roi(trace, roi, analysis_func=fft, window='hann')
    """
    from tracekit.core.types import TraceMetadata, WaveformTrace

    # Cut the window out of the original samples.
    window = slice(roi.start_index, roi.end_index)
    segment = trace.data[window]

    # Carry over only the standard metadata fields; channel_name is read
    # defensively and falls back to None when the attribute is absent.
    segment_metadata = TraceMetadata(
        sample_rate=trace.metadata.sample_rate,
        vertical_scale=trace.metadata.vertical_scale,
        vertical_offset=trace.metadata.vertical_offset,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=getattr(trace.metadata, "channel_name", None),
    )
    segment_trace = WaveformTrace(data=segment, metadata=segment_metadata)

    return analysis_func(segment_trace, **analysis_kwargs)

254 

255 

def progressive_analysis(
    trace: WaveformTrace,
    *,
    analysis_func: Callable[[WaveformTrace], Any],
    downsample_factor: int = 10,
    roi_selector: Callable[[PreviewResult], ROISelection] | None = None,
    **analysis_kwargs: Any,
) -> tuple[PreviewResult, Any]:
    """Run the preview -> ROI -> detailed-analysis workflow in one call.

    Steps:
        1. Build a downsampled preview of ``trace``.
        2. Let ``roi_selector`` pick a region of interest from the preview.
        3. Run ``analysis_func`` at full resolution on that region only.

    When no ``roi_selector`` is supplied, steps 2 and 3 collapse into a
    single call of ``analysis_func`` on the complete trace.

    Args:
        trace: Input waveform trace.
        analysis_func: Analysis function to apply.
        downsample_factor: Downsampling factor used for the preview.
        roi_selector: Callable mapping the preview to an ROI, or None to
            analyze the full trace.
        **analysis_kwargs: Extra keyword arguments forwarded to
            ``analysis_func``.

    Returns:
        Tuple of (preview_result, analysis_result).

    Example:
        >>> def select_peak_region(preview):
        ...     # Find region with highest amplitude
        ...     peak_idx = np.argmax(np.abs(preview.downsampled_data))
        ...     start_time = max(0, (peak_idx - 500) / preview.sample_rate)
        ...     end_time = min(preview.preview_length / preview.sample_rate,
        ...                    (peak_idx + 500) / preview.sample_rate)
        ...     return select_roi(trace, start_time, end_time)
        >>>
        >>> from tracekit.analyzers.waveform.spectral import fft
        >>> preview, result = progressive_analysis(
        ...     trace,
        ...     analysis_func=fft,
        ...     downsample_factor=10,
        ...     roi_selector=select_peak_region
        ... )
    """
    # Pass 1: the preview is always produced, even without an ROI selector.
    preview = create_preview(trace, downsample_factor=downsample_factor)

    # Without a selector there is nothing to zoom into: analyze everything.
    if roi_selector is None:
        return preview, analysis_func(trace, **analysis_kwargs)

    # Pass 2: pick the region; pass 3: analyze it at full resolution.
    selected_roi = roi_selector(preview)
    outcome = analyze_roi(trace, selected_roi, analysis_func=analysis_func, **analysis_kwargs)
    return preview, outcome

312 

313 

def estimate_optimal_preview_factor(
    trace_length: int,
    *,
    target_memory: int = 100_000_000,  # 100 MB
    bytes_per_sample: int = 8,
) -> int:
    """Estimate a power-of-two downsampling factor for a preview.

    The factor is the smallest power of two that is at least
    ``ceil(trace_length * bytes_per_sample / target_memory)``, and never
    less than 1.

    Args:
        trace_length: Number of samples in the original trace.
        target_memory: Target memory for the preview, in bytes.
        bytes_per_sample: Bytes per sample (8 for float64).

    Returns:
        Recommended downsampling factor as a plain ``int``.

    Raises:
        ZeroDivisionError: If ``target_memory`` is 0.

    Example:
        >>> factor = estimate_optimal_preview_factor(1_000_000_000)  # 1B samples
        >>> print(f"Downsample by {factor}x for preview")
    """
    current_memory = trace_length * bytes_per_sample

    # Ceiling division without going through floats, so very large byte
    # counts are not rounded before the factor is computed.
    required_factor = max(1, int(-(-current_memory // target_memory)))

    # Smallest power of two >= required_factor, again in exact integers.
    return 1 << (required_factor - 1).bit_length()

342 

343 

# Public names of this module, i.e. what a star-import exposes.
__all__ = [
    "PreviewResult",
    "ROISelection",
    "analyze_roi",
    "create_preview",
    "estimate_optimal_preview_factor",
    "progressive_analysis",
    "select_roi",
]