Coverage for src/tracekit/analyzers/statistics/trend.py: 99% — 164 statements (coverage.py v7.13.1, created 2026-01-11 23:04 +0000)
1"""Trend detection and analysis for signal data.
3This module provides linear trend detection, drift analysis, and
4detrending functions for identifying systematic changes in signals.
7Example:
8 >>> from tracekit.analyzers.statistics.trend import (
9 ... detect_trend, detrend, moving_average
10 ... )
11 >>> result = detect_trend(trace)
12 >>> print(f"Slope: {result['slope']:.2e} V/s")
13 >>> detrended = detrend(trace)
15References:
16 Montgomery, D. C. (2012). Introduction to Statistical Quality Control
17 NIST Engineering Statistics Handbook
18"""
20from __future__ import annotations
22from dataclasses import dataclass
23from typing import TYPE_CHECKING, Any, Literal
25import numpy as np
26from scipy import stats
28from tracekit.core.types import WaveformTrace
30if TYPE_CHECKING:
31 from numpy.typing import NDArray
@dataclass
class TrendResult:
    """Result of trend analysis.

    Produced by :func:`detect_trend`. When the input has fewer than 3
    samples, every numeric field is NaN, ``is_significant`` is False,
    and ``trend_line`` is NaN-filled.

    Attributes:
        slope: Trend slope (units per second).
        intercept: Trend intercept (at t=0).
        r_squared: Coefficient of determination.
        p_value: Statistical significance (p < 0.05 is significant).
        std_error: Standard error of slope estimate.
        is_significant: Whether trend is statistically significant.
        trend_line: Fitted trend values at each sample.
    """

    slope: float
    intercept: float
    r_squared: float
    p_value: float
    std_error: float
    is_significant: bool
    trend_line: NDArray[np.float64]
def detect_trend(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    significance_level: float = 0.05,
    sample_rate: float | None = None,
) -> TrendResult:
    """Detect a linear trend in signal data.

    Fits an ordinary least-squares line to the samples against a time
    axis in seconds and tests the slope for statistical significance.

    Args:
        trace: Input trace or numpy array.
        significance_level: P-value threshold for significance (default 0.05).
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        TrendResult with slope, intercept, R-squared, p-value, standard
        error, significance flag, and the fitted trend line.

    Raises:
        ValueError: If trace is an array and sample_rate is not provided.

    Example:
        >>> result = detect_trend(trace)
        >>> if result.is_significant:
        ...     print(f"Significant drift: {result.slope:.2e} V/s")
        ...     print(f"R-squared: {result.r_squared:.4f}")

    References:
        NIST Engineering Statistics Handbook Section 6.6
    """
    # Resolve samples and sampling rate from either input form.
    if isinstance(trace, WaveformTrace):
        samples = trace.data
        rate = trace.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        samples = trace
        rate = sample_rate

    num = len(samples)

    # Fewer than 3 points cannot support a meaningful regression.
    if num < 3:
        return TrendResult(
            slope=np.nan,
            intercept=np.nan,
            r_squared=np.nan,
            p_value=np.nan,
            std_error=np.nan,
            is_significant=False,
            trend_line=np.full(num, np.nan, dtype=np.float64),
        )

    # Time axis in seconds, then ordinary least-squares fit.
    time_axis = np.arange(num) / rate
    fit = stats.linregress(time_axis, samples)

    slope_val = float(fit.slope)
    intercept_val = float(fit.intercept)
    p_val = float(fit.pvalue)

    # Evaluate the fitted line at every sample.
    fitted = (intercept_val + slope_val * time_axis).astype(np.float64)

    return TrendResult(
        slope=slope_val,
        intercept=intercept_val,
        r_squared=float(fit.rvalue**2),
        p_value=p_val,
        std_error=float(fit.stderr),
        is_significant=p_val < significance_level,
        trend_line=fitted,
    )
def detrend(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    method: Literal["linear", "constant", "polynomial"] = "linear",
    order: int = 1,
    return_trend: bool = False,
    sample_rate: float | None = None,
) -> NDArray[np.float64] | tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Remove trend from signal data.

    Fits a baseline with the requested method and subtracts it, leaving
    the fluctuations around that baseline.

    Args:
        trace: Input trace or numpy array.
        method: Detrending method:
            - "constant": Remove mean (DC offset)
            - "linear": Remove linear trend (default)
            - "polynomial": Remove polynomial trend
        order: Polynomial order (for method="polynomial").
        return_trend: If True, also return the removed trend.
        sample_rate: Sample rate in Hz; only consulted by the "linear"
            method, and defaults to 1.0 for array input when omitted.

    Returns:
        Detrended data array.
        If return_trend=True, returns (detrended, trend).

    Raises:
        ValueError: If method is not recognized.

    Example:
        >>> detrended = detrend(trace, method="linear")
        >>> # Or get the trend too
        >>> detrended, trend = detrend(trace, return_trend=True)
    """
    if isinstance(trace, WaveformTrace):
        values = trace.data.astype(np.float64)
        rate = trace.metadata.sample_rate
    else:
        values = np.array(trace, dtype=np.float64)
        rate = sample_rate if sample_rate else 1.0

    length = len(values)

    if method == "constant":
        # Baseline is simply the mean, broadcast over all samples.
        baseline = np.full(length, np.mean(values), dtype=np.float64)
    elif method == "linear":
        # Reuse detect_trend's least-squares fit for the baseline.
        baseline = detect_trend(trace, sample_rate=rate).trend_line
    elif method == "polynomial":
        idx = np.arange(length)
        baseline = np.polyval(np.polyfit(idx, values, order), idx)
    else:
        raise ValueError(f"Unknown method: {method}")

    residual = values - baseline

    if return_trend:
        return residual, baseline.astype(np.float64)
    return residual
def moving_average(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    window_size: int,
    method: Literal["simple", "exponential", "weighted"] = "simple",
    alpha: float = 0.1,
) -> NDArray[np.float64]:
    """Compute moving average of signal.

    Smooths the signal with a causal sliding window (edge-padded on the
    left so output length equals input length).

    Args:
        trace: Input trace or numpy array.
        window_size: Size of averaging window in samples (clamped to the
            signal length).
        method: Averaging method:
            - "simple": Simple moving average (default)
            - "exponential": Exponential moving average
            - "weighted": Linearly weighted moving average
        alpha: Smoothing factor for exponential method (0-1).

    Returns:
        Smoothed signal array (same length as input).

    Raises:
        ValueError: If method is not recognized.

    Example:
        >>> smoothed = moving_average(trace, window_size=10)
        >>> # Exponential smoothing
        >>> ema = moving_average(trace, window_size=10, method="exponential", alpha=0.2)
    """
    data = (
        trace.data.astype(np.float64)
        if isinstance(trace, WaveformTrace)
        else np.array(trace, dtype=np.float64)
    )

    length = len(data)
    win = min(window_size, length)

    # Degenerate window (or empty input): nothing to smooth.
    if win < 1:
        return data.copy()

    if method == "simple":
        # Uniform kernel; edge-pad on the left so "valid" keeps length.
        padded = np.pad(data, (win - 1, 0), mode="edge")
        smoothed = np.convolve(padded, np.ones(win) / win, mode="valid")

    elif method == "exponential":
        # Recursive EMA: each output blends the new sample with the
        # previous output.
        smoothed = np.zeros(length, dtype=np.float64)
        prev = data[0]
        smoothed[0] = prev
        for i in range(1, length):
            prev = alpha * data[i] + (1 - alpha) * prev
            smoothed[i] = prev

    elif method == "weighted":
        # Linearly increasing weights, normalized to sum to one.
        weights = np.arange(1, win + 1, dtype=np.float64)
        weights = weights / np.sum(weights)
        padded = np.pad(data, (win - 1, 0), mode="edge")
        smoothed = np.convolve(padded, weights, mode="valid")

    else:
        raise ValueError(f"Unknown method: {method}")

    return smoothed.astype(np.float64)
def detect_drift_segments(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    segment_size: int,
    threshold_slope: float | None = None,
    sample_rate: float | None = None,
) -> list[dict]:  # type: ignore[type-arg]
    """Detect segments with significant drift.

    Divides the signal into consecutive fixed-size segments and reports
    those whose linear trend is statistically significant (and, when
    given, whose slope magnitude meets ``threshold_slope``).

    Args:
        trace: Input trace or numpy array.
        segment_size: Size of each segment in samples.
        threshold_slope: Minimum slope magnitude to flag (units/second).
            If None, uses statistical significance alone.
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        List of dictionaries describing drift segments:
        - start_sample: Start index of segment
        - end_sample: End index of segment
        - start_time: Start time in seconds
        - end_time: End time in seconds
        - slope: Trend slope
        - r_squared: Coefficient of determination
        - p_value: Regression p-value

    Raises:
        ValueError: If trace is array and sample_rate is not provided.

    Example:
        >>> segments = detect_drift_segments(trace, segment_size=1000)
        >>> for seg in segments:
        ...     print(f"Drift at {seg['start_time']:.3f}s: {seg['slope']:.2e} V/s")
    """
    if isinstance(trace, WaveformTrace):
        samples = trace.data
        rate = trace.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        samples = trace
        rate = sample_rate

    total = len(samples)
    flagged = []

    for begin in range(0, total, segment_size):
        stop = min(begin + segment_size, total)

        # Skip runts: regression needs a minimum number of points.
        if stop - begin < 10:
            continue

        fit = detect_trend(samples[begin:stop], sample_rate=rate)

        # Significant trend, optionally gated by a slope threshold.
        drifting = fit.is_significant
        if threshold_slope is not None:
            drifting = drifting and abs(fit.slope) >= threshold_slope

        if drifting:
            flagged.append(
                {
                    "start_sample": begin,
                    "end_sample": stop,
                    "start_time": begin / rate,
                    "end_time": stop / rate,
                    "slope": fit.slope,
                    "r_squared": fit.r_squared,
                    "p_value": fit.p_value,
                }
            )

    return flagged
def change_point_detection(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    min_segment_size: int = 10,
    penalty: float | None = None,
) -> list[int]:
    """Detect change points in signal level or trend.

    Uses binary segmentation with a mean-shift cost: a segment is split
    at the point that most reduces the within-segment sum of squared
    deviations, provided the reduction exceeds ``penalty``.

    Args:
        trace: Input trace or numpy array.
        min_segment_size: Minimum samples between change points.
        penalty: Penalty for adding change points (controls sensitivity).
            If None, auto-selected as twice the signal variance.

    Returns:
        Sorted list of sample indices where changes occur.

    Example:
        >>> change_points = change_point_detection(trace)
        >>> for cp in change_points:
        ...     print(f"Change at sample {cp}")
    """
    data = trace.data if isinstance(trace, WaveformTrace) else np.array(trace, dtype=np.float64)

    n = len(data)

    if n < 2 * min_segment_size:
        return []

    # Auto-select penalty if not provided.
    if penalty is None:
        penalty = np.var(data) * 2

    # Binary segmentation: repeatedly try to split each pending segment.
    change_points: list[int] = []
    segments = [(0, n)]

    while segments:
        start, end = segments.pop(0)
        segment = data[start:end]
        seg_len = len(segment)

        if seg_len < 2 * min_segment_size:
            continue

        # Candidate splits keep at least min_segment_size samples on each
        # side. (Empty exactly when seg_len == 2 * min_segment_size.)
        splits = np.arange(min_segment_size, seg_len - min_segment_size)
        if splits.size == 0:
            continue

        # Prefix sums give the cost of every candidate split in O(seg_len)
        # total, via sum((x - mean)^2) == sum(x^2) - (sum x)^2 / n.
        # The previous implementation re-summed both halves AND the
        # loop-invariant whole-segment cost for each split (O(seg_len^2)).
        seg = segment.astype(np.float64)
        csum = np.cumsum(seg)
        csumsq = np.cumsum(seg * seg)
        total_sum = csum[-1]
        total_sq = csumsq[-1]

        cost_whole = total_sq - total_sum**2 / seg_len

        left_sum = csum[splits - 1]
        left_sq = csumsq[splits - 1]
        cost_left = left_sq - left_sum**2 / splits
        cost_right = (total_sq - left_sq) - (total_sum - left_sum) ** 2 / (seg_len - splits)

        reductions = cost_whole - (cost_left + cost_right) - penalty

        # argmax returns the first maximum, matching the original strict
        # ">" first-wins tie-breaking.
        best_idx = int(np.argmax(reductions))
        best_cost_reduction = float(reductions[best_idx])
        best_split = int(splits[best_idx])

        # Accept the split only if it pays for the penalty.
        if best_cost_reduction > 0:
            cp = start + best_split
            change_points.append(cp)

            # Recurse into both halves.
            segments.append((start, cp))
            segments.append((cp, end))

    change_points.sort()
    return change_points
def piecewise_linear_fit(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    n_segments: int = 3,
    sample_rate: float | None = None,
) -> dict:  # type: ignore[type-arg]
    """Fit piecewise linear model to signal.

    Splits the signal into ``n_segments`` contiguous, equal-width
    segments (the last absorbs any remainder) and fits an independent
    linear trend to each.

    Args:
        trace: Input trace or numpy array.
        n_segments: Number of segments to fit.
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        Dictionary with fit results:
        - breakpoints: Sample indices of segment boundaries
        - segments: Per-segment dicts with "slope", "intercept",
          "start", and "end" keys
        - fitted: Full fitted signal
        - residuals: Fitting residuals
        - rmse: Root-mean-square error of the fit

    Raises:
        ValueError: If trace is array and sample_rate is not provided.

    Example:
        >>> result = piecewise_linear_fit(trace, n_segments=4)
        >>> print(f"Breakpoints: {result['breakpoints']}")
    """
    if isinstance(trace, WaveformTrace):
        samples = trace.data
        rate = trace.metadata.sample_rate
    else:
        samples = np.array(trace, dtype=np.float64)
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        rate = sample_rate

    total = len(samples)

    # Equal-width boundaries; the final segment takes the remainder.
    width = total // n_segments
    bounds = [0] + [k * width for k in range(1, n_segments)] + [total]

    pieces = []
    fitted = np.zeros(total, dtype=np.float64)

    for lo, hi in zip(bounds[:-1], bounds[1:]):
        chunk = samples[lo:hi]
        t = np.arange(len(chunk)) / rate

        # polyfit needs at least two points; shorter segments are skipped.
        if len(t) >= 2:
            slope, intercept = np.polyfit(t, chunk, 1)
            fitted[lo:hi] = intercept + slope * t
            pieces.append(
                {
                    "slope": float(slope),
                    "intercept": float(intercept),
                    "start": lo,
                    "end": hi,
                }
            )

    residuals = samples - fitted

    return {
        "breakpoints": bounds,
        "segments": pieces,
        "fitted": fitted,
        "residuals": residuals,
        "rmse": float(np.sqrt(np.mean(residuals**2))),
    }
# Public API of this module (TrendResult first, then functions alphabetically).
__all__ = [
    "TrendResult",
    "change_point_detection",
    "detect_drift_segments",
    "detect_trend",
    "detrend",
    "moving_average",
    "piecewise_linear_fit",
]