Coverage for src / tracekit / analyzers / statistics / trend.py: 99%

164 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Trend detection and analysis for signal data. 

2 

3This module provides linear trend detection, drift analysis, and 

4detrending functions for identifying systematic changes in signals. 

5 

6 

7Example: 

8 >>> from tracekit.analyzers.statistics.trend import ( 

9 ... detect_trend, detrend, moving_average 

10 ... ) 

11 >>> result = detect_trend(trace) 

12 >>> print(f"Slope: {result['slope']:.2e} V/s") 

13 >>> detrended = detrend(trace) 

14 

15References: 

16 Montgomery, D. C. (2012). Introduction to Statistical Quality Control 

17 NIST Engineering Statistics Handbook 

18""" 

19 

20from __future__ import annotations 

21 

22from dataclasses import dataclass 

23from typing import TYPE_CHECKING, Any, Literal 

24 

25import numpy as np 

26from scipy import stats 

27 

28from tracekit.core.types import WaveformTrace 

29 

30if TYPE_CHECKING: 

31 from numpy.typing import NDArray 

32 

33 

@dataclass
class TrendResult:
    """Result of trend analysis.

    Produced by ``detect_trend``. When the input has fewer than 3 samples,
    every scalar field is NaN, ``is_significant`` is False, and
    ``trend_line`` is an all-NaN array of the input length.

    Attributes:
        slope: Trend slope (units per second).
        intercept: Trend intercept (at t=0).
        r_squared: Coefficient of determination.
        p_value: Statistical significance (p < 0.05 is significant).
        std_error: Standard error of slope estimate.
        is_significant: Whether trend is statistically significant.
        trend_line: Fitted trend values at each sample.
    """

    slope: float  # units per second
    intercept: float  # value of the fit at t=0
    r_squared: float  # squared Pearson correlation of the fit
    p_value: float  # two-sided p-value from the regression
    std_error: float  # standard error of the slope estimate
    is_significant: bool  # p_value < significance_level at detection time
    trend_line: NDArray[np.float64]  # fitted value per sample, same length as input

55 

56 

def detect_trend(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    significance_level: float = 0.05,
    sample_rate: float | None = None,
) -> TrendResult:
    """Detect linear trend in signal data.

    Performs an ordinary least-squares fit of the signal against time and
    reports slope, R-squared, and whether the drift is statistically
    significant at the requested level.

    Args:
        trace: Input trace or numpy array.
        significance_level: P-value threshold for significance (default 0.05).
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        TrendResult with trend analysis. All scalar fields are NaN when the
        input has fewer than 3 samples.

    Raises:
        ValueError: If trace is array and sample_rate is not provided.

    Example:
        >>> result = detect_trend(trace)
        >>> if result.is_significant:
        ...     print(f"Significant drift: {result.slope:.2e} V/s")
        ...     print(f"R-squared: {result.r_squared:.4f}")

    References:
        NIST Engineering Statistics Handbook Section 6.6
    """
    if isinstance(trace, WaveformTrace):
        samples = trace.data
        rate = trace.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        samples = trace
        rate = sample_rate

    count = len(samples)

    # Fewer than 3 points cannot support a meaningful regression: return
    # a NaN-filled sentinel result rather than raising.
    if count < 3:
        return TrendResult(
            slope=np.nan,
            intercept=np.nan,
            r_squared=np.nan,
            p_value=np.nan,
            std_error=np.nan,
            is_significant=False,
            trend_line=np.full(count, np.nan, dtype=np.float64),
        )

    # Time axis in seconds, then least-squares fit.
    time_axis = np.arange(count) / rate
    fit = stats.linregress(time_axis, samples)

    slope_val = float(fit.slope)
    offset_val = float(fit.intercept)
    p_val = float(fit.pvalue)
    fitted = offset_val + slope_val * time_axis

    return TrendResult(
        slope=slope_val,
        intercept=offset_val,
        r_squared=float(fit.rvalue**2),
        p_value=p_val,
        std_error=float(fit.stderr),
        is_significant=p_val < significance_level,
        trend_line=fitted.astype(np.float64),
    )

135 

136 

def detrend(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    method: Literal["linear", "constant", "polynomial"] = "linear",
    order: int = 1,
    return_trend: bool = False,
    sample_rate: float | None = None,
) -> NDArray[np.float64] | tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Remove trend from signal data.

    Subtracts a fitted baseline so that only fluctuations around it remain.

    Args:
        trace: Input trace or numpy array.
        method: Detrending method:
            - "constant": Remove mean (DC offset)
            - "linear": Remove linear trend (default)
            - "polynomial": Remove polynomial trend
        order: Polynomial order (for method="polynomial").
        return_trend: If True, also return the removed trend.
        sample_rate: Sample rate in Hz (used only for method="linear" with
            array input; defaults to 1.0 if omitted — the fitted line itself
            is unaffected by the time scaling).

    Returns:
        Detrended data array.
        If return_trend=True, returns (detrended, trend).

    Raises:
        ValueError: If method is not recognized.

    Example:
        >>> detrended = detrend(trace, method="linear")
        >>> # Or get the trend too
        >>> detrended, trend = detrend(trace, return_trend=True)
    """
    if isinstance(trace, WaveformTrace):
        values = trace.data.astype(np.float64)
        rate = trace.metadata.sample_rate
    else:
        values = np.array(trace, dtype=np.float64)
        rate = sample_rate if sample_rate else 1.0

    count = len(values)

    if method == "constant":
        # Flat baseline at the signal mean.
        baseline = np.full(count, np.mean(values), dtype=np.float64)
    elif method == "linear":
        # Delegate the fit to detect_trend; pass the original input so a
        # WaveformTrace keeps its own metadata sample rate.
        baseline = detect_trend(trace, sample_rate=rate).trend_line
    elif method == "polynomial":
        # Fit in sample-index units; only the removed shape matters here.
        idx = np.arange(count)
        baseline = np.polyval(np.polyfit(idx, values, order), idx)
    else:
        raise ValueError(f"Unknown method: {method}")

    residual = values - baseline

    if return_trend:
        return residual, baseline.astype(np.float64)
    return residual

200 

201 

def moving_average(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    window_size: int,
    method: Literal["simple", "exponential", "weighted"] = "simple",
    alpha: float = 0.1,
) -> NDArray[np.float64]:
    """Compute moving average of signal.

    Smooths the signal with a causal sliding-window (or exponential)
    average; the output has the same length as the input.

    Args:
        trace: Input trace or numpy array.
        window_size: Size of averaging window in samples (clamped to the
            signal length; values < 1 return an unmodified copy).
        method: Averaging method:
            - "simple": Simple moving average (default)
            - "exponential": Exponential moving average (ignores window_size)
            - "weighted": Linearly weighted moving average
        alpha: Smoothing factor for exponential method (0-1).

    Returns:
        Smoothed signal array (same length as input).

    Raises:
        ValueError: If method is not recognized.

    Example:
        >>> smoothed = moving_average(trace, window_size=10)
        >>> # Exponential smoothing
        >>> ema = moving_average(trace, window_size=10, method="exponential", alpha=0.2)
    """
    values = (
        trace.data.astype(np.float64)
        if isinstance(trace, WaveformTrace)
        else np.array(trace, dtype=np.float64)
    )
    length = len(values)

    window_size = min(window_size, length)
    if window_size < 1:
        return values.copy()

    if method == "exponential":
        # Recursive EMA: out[i] = alpha*x[i] + (1-alpha)*out[i-1].
        smoothed = np.zeros(length, dtype=np.float64)
        smoothed[0] = values[0]
        for idx in range(1, length):
            smoothed[idx] = alpha * values[idx] + (1 - alpha) * smoothed[idx - 1]
        return smoothed.astype(np.float64)

    if method == "simple":
        # Uniform kernel; edge-pad on the left so output length matches input.
        kernel = np.ones(window_size) / window_size
    elif method == "weighted":
        # Linearly increasing weights, normalized to sum to 1.
        kernel = np.arange(1, window_size + 1, dtype=np.float64)
        kernel = kernel / np.sum(kernel)
    else:
        raise ValueError(f"Unknown method: {method}")

    extended = np.pad(values, (window_size - 1, 0), mode="edge")
    return np.convolve(extended, kernel, mode="valid").astype(np.float64)

271 

272 

def detect_drift_segments(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    segment_size: int,
    threshold_slope: float | None = None,
    sample_rate: float | None = None,
) -> list[dict]:  # type: ignore[type-arg]
    """Detect segments with significant drift.

    Splits the signal into fixed-size windows and runs a per-window linear
    regression, collecting the windows whose trend is statistically
    significant (and, optionally, steeper than a slope threshold).

    Args:
        trace: Input trace or numpy array.
        segment_size: Size of each segment in samples.
        threshold_slope: Minimum slope magnitude to flag (units/second).
            If None, uses statistical significance alone.
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        List of dictionaries describing drift segments:
            - start_sample: Start index of segment
            - end_sample: End index of segment
            - start_time: Start time in seconds
            - end_time: End time in seconds
            - slope: Trend slope
            - r_squared: Coefficient of determination
            - p_value: Regression p-value

    Raises:
        ValueError: If trace is array and sample_rate is not provided.

    Example:
        >>> segments = detect_drift_segments(trace, segment_size=1000)
        >>> for seg in segments:
        ...     print(f"Drift at {seg['start_time']:.3f}s: {seg['slope']:.2e} V/s")
    """
    if isinstance(trace, WaveformTrace):
        samples = trace.data
        rate = trace.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        samples = trace
        rate = sample_rate

    total = len(samples)
    flagged: list[dict] = []  # type: ignore[type-arg]

    for begin in range(0, total, segment_size):
        finish = min(begin + segment_size, total)

        # Skip trailing windows too short for a meaningful regression.
        if finish - begin < 10:
            continue

        fit = detect_trend(samples[begin:finish], sample_rate=rate)

        drifting = fit.is_significant
        if threshold_slope is not None:
            drifting = drifting and abs(fit.slope) >= threshold_slope
        if not drifting:
            continue

        flagged.append(
            {
                "start_sample": begin,
                "end_sample": finish,
                "start_time": begin / rate,
                "end_time": finish / rate,
                "slope": fit.slope,
                "r_squared": fit.r_squared,
                "p_value": fit.p_value,
            }
        )

    return flagged

351 

352 

def change_point_detection(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    min_segment_size: int = 10,
    penalty: float | None = None,
) -> list[int]:
    """Detect change points in signal level or trend.

    Identifies locations where the signal characteristics change
    significantly, using binary segmentation with a mean-shift
    (sum-of-squared-deviations) cost and a fixed penalty per split.

    Args:
        trace: Input trace or numpy array.
        min_segment_size: Minimum samples between change points.
        penalty: Penalty for adding change points (controls sensitivity).
            If None, auto-selected as twice the signal variance.

    Returns:
        Sorted list of sample indices where changes occur.

    Example:
        >>> change_points = change_point_detection(trace)
        >>> for cp in change_points:
        ...     print(f"Change at sample {cp}")
    """
    data = trace.data if isinstance(trace, WaveformTrace) else np.array(trace, dtype=np.float64)

    n = len(data)

    if n < 2 * min_segment_size:
        return []

    # Auto-select penalty if not provided.
    if penalty is None:
        penalty = np.var(data) * 2

    change_points: list[int] = []
    # Work stack of (start, end) half-open segments still to examine.
    # Processing order does not affect the result (each segment is split
    # independently and the output is sorted at the end), so a LIFO pop()
    # avoids the O(n) cost of pop(0) on a list.
    pending = [(0, n)]

    while pending:
        start, end = pending.pop()
        segment = data[start:end]
        seg_len = len(segment)

        if seg_len < 2 * min_segment_size:
            continue

        # Cost of the unsplit segment is invariant across candidate splits:
        # compute it once here instead of once per split.
        cost_whole = np.sum((segment - np.mean(segment)) ** 2)

        # Find the split that most reduces the total cost.
        best_cost_reduction = -np.inf
        best_split = None

        for split in range(min_segment_size, seg_len - min_segment_size):
            left = segment[:split]
            right = segment[split:]

            # Cost = sum of squared deviations from each part's own mean.
            cost_left = np.sum((left - np.mean(left)) ** 2)
            cost_right = np.sum((right - np.mean(right)) ** 2)

            cost_reduction = cost_whole - (cost_left + cost_right) - penalty

            if cost_reduction > best_cost_reduction:
                best_cost_reduction = cost_reduction
                best_split = split

        # Accept the split only if it pays for the penalty.
        if best_split is not None and best_cost_reduction > 0:
            cp = start + best_split
            change_points.append(cp)

            # Recurse into both halves.
            pending.append((start, cp))
            pending.append((cp, end))

    change_points.sort()
    return change_points

431 

432 

def piecewise_linear_fit(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    n_segments: int = 3,
    sample_rate: float | None = None,
) -> dict:  # type: ignore[type-arg]
    """Fit piecewise linear model to signal.

    Splits the signal into ``n_segments`` equal-width pieces and fits an
    independent linear trend to each piece.

    Args:
        trace: Input trace or numpy array.
        n_segments: Number of segments to fit.
        sample_rate: Sample rate in Hz (required for array input).

    Returns:
        Dictionary with fit results:
            - breakpoints: Sample indices of segment boundaries
            - segments: List of dicts with slope, intercept, start, end
              for each fitted segment
            - fitted: Full fitted signal
            - residuals: Fitting residuals
            - rmse: Root-mean-square error of the fit

    Raises:
        ValueError: If trace is array and sample_rate is not provided.

    Example:
        >>> result = piecewise_linear_fit(trace, n_segments=4)
        >>> print(f"Breakpoints: {result['breakpoints']}")
    """
    if isinstance(trace, WaveformTrace):
        samples = trace.data
        rate = trace.metadata.sample_rate
    else:
        if sample_rate is None:
            raise ValueError("sample_rate required when trace is array")
        samples = np.array(trace, dtype=np.float64)
        rate = sample_rate

    total = len(samples)

    # Equal-width boundaries: [0, w, 2w, ..., total].
    width = total // n_segments
    bounds = [0] + [k * width for k in range(1, n_segments)] + [total]

    pieces = []
    model = np.zeros(total, dtype=np.float64)

    for lo, hi in zip(bounds[:-1], bounds[1:]):
        chunk = samples[lo:hi]
        # Local time axis in seconds, restarting at each segment.
        t = np.arange(len(chunk)) / rate

        if len(t) < 2:
            # Degenerate (empty/single-sample) segment: leave fit at zero.
            continue

        slope, intercept = np.polyfit(t, chunk, 1)
        model[lo:hi] = intercept + slope * t
        pieces.append(
            {
                "slope": float(slope),
                "intercept": float(intercept),
                "start": lo,
                "end": hi,
            }
        )

    errors = samples - model

    return {
        "breakpoints": bounds,
        "segments": pieces,
        "fitted": model,
        "residuals": errors,
        "rmse": float(np.sqrt(np.mean(errors**2))),
    }

510 

511 

# Public API of this module (ASCII-sorted; class name first).
__all__ = [
    "TrendResult",
    "change_point_detection",
    "detect_drift_segments",
    "detect_trend",
    "detrend",
    "moving_average",
    "piecewise_linear_fit",
]