Coverage for src / tracekit / utils / progressive.py: 100%
74 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Progressive resolution analysis for memory-constrained scenarios.
3This module provides multi-pass analysis capabilities: preview first,
4then zoom into regions of interest for detailed analysis.
7Example:
8 >>> from tracekit.utils.progressive import analyze_roi, create_preview, select_roi
9 >>> preview = create_preview(trace, downsample_factor=10)
10 >>> # User inspects preview, selects ROI (my_analysis: any callable taking a WaveformTrace)
11 >>> roi_result = analyze_roi(trace, select_roi(trace, 0.001, 0.002), analysis_func=my_analysis)
13References:
14 Multi-resolution analysis techniques
15"""
17from __future__ import annotations
19from dataclasses import dataclass
20from typing import TYPE_CHECKING, Any
22import numpy as np
24if TYPE_CHECKING:
25 from collections.abc import Callable
27 from numpy.typing import NDArray
29 from tracekit.core.types import WaveformTrace
@dataclass
class PreviewResult:
    """Downsampled view of a waveform together with its bookkeeping.

    Attributes:
        downsampled_data: Samples that remain after decimation.
        downsample_factor: Decimation step that was applied.
        original_length: Number of samples in the source signal.
        preview_length: Number of samples in the preview.
        sample_rate: Effective sample rate of the preview
            (original rate divided by the factor).
        time_vector: Time axis belonging to the preview samples.
        basic_stats: Summary statistics computed from the preview.
    """

    # Field order is part of the generated ``__init__`` signature.
    downsampled_data: NDArray[np.float64]
    downsample_factor: int
    original_length: int
    preview_length: int
    sample_rate: float
    time_vector: NDArray[np.float64]
    basic_stats: dict[str, float]
@dataclass
class ROISelection:
    """Time window of a trace expressed both in seconds and in sample indices.

    Attributes:
        start_time: Beginning of the window in seconds.
        end_time: End of the window in seconds.
        start_index: Sample index in the original signal where the window starts.
        end_index: Sample index in the original signal where the window ends.
        duration: Length of the window in seconds.
        num_samples: Count of samples covered by the window.
    """

    # Field order is part of the generated ``__init__`` signature.
    start_time: float
    end_time: float
    start_index: int
    end_index: int
    duration: float
    num_samples: int
def create_preview(
    trace: WaveformTrace,
    *,
    downsample_factor: int | None = None,
    max_samples: int = 10_000,
    apply_antialiasing: bool = True,
) -> PreviewResult:
    """Create downsampled preview of waveform for quick inspection.

    Args:
        trace: Input waveform trace.
        downsample_factor: Downsampling factor, must be >= 1. When None it is
            derived from ``max_samples`` and rounded up to a power of two.
        max_samples: Target maximum samples in preview. Only used when
            ``downsample_factor`` is None. It is a target, not a hard cap:
            the automatic factor is based on ``len(data) // max_samples``, so
            the preview can still be somewhat longer than ``max_samples``.
        apply_antialiasing: Apply anti-aliasing lowpass filter before
            decimation (only has an effect when the factor is > 1).

    Returns:
        PreviewResult with downsampled data and metadata.

    Raises:
        ValueError: If the trace has no samples, if an explicit
            ``downsample_factor`` is < 1, or if ``max_samples`` is < 1 while
            the factor is being auto-computed.

    Example:
        >>> preview = create_preview(large_trace, downsample_factor=10)
        >>> print(f"Preview: {preview.preview_length} samples (factor {preview.downsample_factor}x)")
        >>> # Inspect preview.basic_stats
    """
    data = trace.data
    original_length = len(data)
    sample_rate = trace.metadata.sample_rate

    # The statistics below are undefined for zero samples; fail with a clear
    # message instead of an error raised from inside a NumPy reduction.
    if original_length == 0:
        raise ValueError("Cannot create a preview of an empty trace")

    if downsample_factor is None:
        if max_samples < 1:
            raise ValueError(f"max_samples must be >= 1, got {max_samples}")
        base_factor = int(max(1, original_length // max_samples))
        # Round UP to the next power of two with exact integer arithmetic
        # (1 -> 1, 3 -> 4, 4 -> 4, 5 -> 8).
        downsample_factor = 1 << (base_factor - 1).bit_length()
    elif downsample_factor < 1:
        # A zero step is an invalid slice and a negative step would silently
        # return a reversed preview with a negative sample rate.
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    if apply_antialiasing and downsample_factor > 1:
        # Imported here so scipy is only required when filtering happens.
        from scipy import signal as sp_signal

        # Lowpass filter at Nyquist frequency of downsampled rate
        nyquist_freq = (sample_rate / downsample_factor) / 2
        sos = sp_signal.butter(8, nyquist_freq, btype="low", fs=sample_rate, output="sos")
        filtered = sp_signal.sosfilt(sos, data)
        downsampled = filtered[::downsample_factor]
    else:
        # Simple decimation without filtering.
        # NOTE(review): if trace.data is a NumPy array this slice is a view of
        # the original buffer, not a copy - confirm callers do not mutate it.
        downsampled = data[::downsample_factor]

    preview_length = len(downsampled)
    preview_sample_rate = sample_rate / downsample_factor

    # Time axis of the preview, starting at zero
    time_vector = np.arange(preview_length) / preview_sample_rate

    # Statistics are computed on the preview samples, not on the full trace
    basic_stats = {
        "mean": float(np.mean(downsampled)),
        "std": float(np.std(downsampled)),
        "min": float(np.min(downsampled)),
        "max": float(np.max(downsampled)),
        "rms": float(np.sqrt(np.mean(downsampled**2))),
        "peak_to_peak": float(np.ptp(downsampled)),
    }

    return PreviewResult(
        downsampled_data=downsampled,
        downsample_factor=downsample_factor,
        original_length=original_length,
        preview_length=preview_length,
        sample_rate=preview_sample_rate,
        time_vector=time_vector,
        basic_stats=basic_stats,
    )
def _time_to_sample_index(time_s: float, sample_rate: float) -> int:
    """Convert a time offset to a sample index, truncating toward zero.

    A product that lands within a few ULPs of a whole number is snapped to
    that number first, so binary rounding error (0.29 * 100 evaluates to
    28.999999999999996) cannot shift the index down by one sample.
    """
    position = float(time_s * sample_rate)
    nearest = round(position)
    tolerance = 4 * np.finfo(np.float64).eps * max(1.0, abs(position))
    if abs(position - nearest) <= tolerance:
        return int(nearest)
    return int(position)


def select_roi(
    trace: WaveformTrace,
    start_time: float,
    end_time: float,
) -> ROISelection:
    """Create ROI selection from time range.

    Times are converted to sample indices by truncation, except that values
    within floating-point rounding error of a whole sample are treated as
    that sample. ``end_index`` is used as an exclusive bound, so
    ``num_samples == end_index - start_index``.

    Args:
        trace: Input waveform trace.
        start_time: Start time in seconds.
        end_time: End time in seconds.

    Returns:
        ROISelection with sample indices and metadata.

    Raises:
        ValueError: If time range is invalid (negative start, end beyond the
            signal duration, or start not strictly before end).

    Example:
        >>> roi = select_roi(trace, start_time=0.001, end_time=0.002)
        >>> print(f"ROI: {roi.num_samples} samples ({roi.duration*1e6:.1f} µs)")
    """
    sample_rate = trace.metadata.sample_rate
    total_length = len(trace.data)
    total_duration = total_length / sample_rate

    # Validate time range
    if start_time < 0 or end_time > total_duration:
        raise ValueError(
            f"Time range [{start_time}, {end_time}] outside signal duration [0, {total_duration}]"
        )
    if start_time >= end_time:
        raise ValueError(f"start_time ({start_time}) must be < end_time ({end_time})")

    # Convert to sample indices (robust against float rounding error)
    start_index = _time_to_sample_index(start_time, sample_rate)
    end_index = _time_to_sample_index(end_time, sample_rate)

    # Clamp to valid range; the ROI always contains at least one sample
    start_index = max(0, min(start_index, total_length - 1))
    end_index = max(start_index + 1, min(end_index, total_length))

    # duration reflects the requested times, num_samples the clamped indices
    duration = end_time - start_time
    num_samples = end_index - start_index

    return ROISelection(
        start_time=start_time,
        end_time=end_time,
        start_index=start_index,
        end_index=end_index,
        duration=duration,
        num_samples=num_samples,
    )
def analyze_roi(
    trace: WaveformTrace,
    roi: ROISelection,
    *,
    analysis_func: Callable[[WaveformTrace], Any],
    **analysis_kwargs: Any,
) -> Any:
    """Run an analysis callable on the full-resolution samples inside an ROI.

    The samples between ``roi.start_index`` and ``roi.end_index`` are wrapped
    in a new trace whose metadata is rebuilt from the source trace, and that
    trace is handed to ``analysis_func``.

    Args:
        trace: Input waveform trace.
        roi: ROI selection giving the sample index window.
        analysis_func: Callable invoked with the ROI trace.
        **analysis_kwargs: Extra keyword arguments forwarded to
            ``analysis_func``.

    Returns:
        Whatever ``analysis_func`` returns for the ROI trace.

    Example:
        >>> from tracekit.analyzers.waveform.spectral import fft
        >>> roi = select_roi(trace, 0.001, 0.002)
        >>> freq, mag = analyze_roi(trace, roi, analysis_func=fft, window='hann')
    """
    from tracekit.core.types import TraceMetadata, WaveformTrace

    # Cut the ROI window out of the original samples
    window = slice(roi.start_index, roi.end_index)
    roi_samples = trace.data[window]

    # Rebuild the metadata from the standard fields only
    source_meta = trace.metadata
    roi_metadata = TraceMetadata(
        sample_rate=source_meta.sample_rate,
        vertical_scale=source_meta.vertical_scale,
        vertical_offset=source_meta.vertical_offset,
        acquisition_time=source_meta.acquisition_time,
        trigger_info=source_meta.trigger_info,
        source_file=source_meta.source_file,
        # channel_name may be missing on the source metadata; default to None
        channel_name=getattr(source_meta, "channel_name", None),
    )
    roi_trace = WaveformTrace(data=roi_samples, metadata=roi_metadata)

    return analysis_func(roi_trace, **analysis_kwargs)
def progressive_analysis(
    trace: WaveformTrace,
    *,
    analysis_func: Callable[[WaveformTrace], Any],
    downsample_factor: int = 10,
    roi_selector: Callable[[PreviewResult], ROISelection] | None = None,
    **analysis_kwargs: Any,
) -> tuple[PreviewResult, Any]:
    """Run a preview pass followed by a detailed analysis pass.

    Workflow:
        1. Build a downsampled preview of the trace.
        2. Let ``roi_selector`` pick a region of interest from the preview.
        3. Run ``analysis_func`` at full resolution on that region only.

    When no ``roi_selector`` is given, steps 2 and 3 collapse into a single
    call of ``analysis_func`` on the complete trace.

    Args:
        trace: Input waveform trace.
        analysis_func: Analysis function to apply.
        downsample_factor: Downsampling factor used for the preview.
        roi_selector: Callable mapping the preview to an ROI selection
            (if None, the full trace is analyzed).
        **analysis_kwargs: Extra keyword arguments forwarded to
            ``analysis_func``.

    Returns:
        Tuple of (preview_result, analysis_result).

    Example:
        >>> def select_peak_region(preview):
        ...     # Find region with highest amplitude
        ...     peak_idx = np.argmax(np.abs(preview.downsampled_data))
        ...     start_time = max(0, (peak_idx - 500) / preview.sample_rate)
        ...     end_time = min(preview.preview_length / preview.sample_rate,
        ...                    (peak_idx + 500) / preview.sample_rate)
        ...     return select_roi(trace, start_time, end_time)
        >>>
        >>> from tracekit.analyzers.waveform.spectral import fft
        >>> preview, result = progressive_analysis(
        ...     trace,
        ...     analysis_func=fft,
        ...     downsample_factor=10,
        ...     roi_selector=select_peak_region
        ... )
    """
    # First pass: the preview is always produced, even without a selector
    preview = create_preview(trace, downsample_factor=downsample_factor)

    # Without a selector there is nothing to zoom into; use the whole trace
    if roi_selector is None:
        return preview, analysis_func(trace, **analysis_kwargs)

    # Second pass: pick the region; third pass: analyze only that region
    selected_roi = roi_selector(preview)
    roi_result = analyze_roi(
        trace,
        selected_roi,
        analysis_func=analysis_func,
        **analysis_kwargs,
    )
    return preview, roi_result
def estimate_optimal_preview_factor(
    trace_length: int,
    *,
    target_memory: int = 100_000_000,  # 100 MB
    bytes_per_sample: int = 8,
) -> int:
    """Estimate optimal downsampling factor for preview.

    The factor is the smallest power of two that brings
    ``trace_length * bytes_per_sample`` down to at most ``target_memory``
    bytes; traces that already fit yield a factor of 1.

    Args:
        trace_length: Number of samples in original trace.
        target_memory: Target memory for preview (bytes). Must be positive.
        bytes_per_sample: Bytes per sample (8 for float64).

    Returns:
        Recommended downsampling factor (a power of two, at least 1).

    Raises:
        ValueError: If ``target_memory`` is not positive.

    Example:
        >>> estimate_optimal_preview_factor(1_000_000_000)  # 1B float64 samples
        128
    """
    # A zero budget would divide by zero and a negative one would silently
    # report that no downsampling is needed.
    if target_memory <= 0:
        raise ValueError(f"target_memory must be positive, got {target_memory}")

    current_memory = trace_length * bytes_per_sample

    # Ceiling division via negated floor division: stays exact for integers
    # of any size, unlike float division followed by ceil().
    required = int(-(-current_memory // target_memory))
    factor = max(1, required)

    # Round up to the next power of two (1 -> 1, 3 -> 4, 80 -> 128)
    return 1 << (factor - 1).bit_length()
344__all__ = [
345 "PreviewResult",
346 "ROISelection",
347 "analyze_roi",
348 "create_preview",
349 "estimate_optimal_preview_factor",
350 "progressive_analysis",
351 "select_roi",
352]