Coverage for src / tracekit / optimization / search.py: 99%

157 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Parameter optimization via grid search and random search. 

2 

3This module provides tools for finding optimal analysis parameters through 

4systematic or random search of the parameter space. 

5""" 

6 

7from __future__ import annotations 

8 

9import itertools 

10from collections.abc import Callable 

11from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed 

12from dataclasses import dataclass 

13from typing import TYPE_CHECKING, Any, Literal 

14 

15import numpy as np 

16import pandas as pd 

17 

18from tracekit.analyzers.waveform.spectral import thd as compute_thd 

19from tracekit.core.exceptions import AnalysisError 

20 

if TYPE_CHECKING:
    from numpy.typing import NDArray

    from tracekit.core.types import WaveformTrace

    # Scorer contract: (transformed_trace, params) -> score, higher is better.
    ScoringFunction = Callable[[WaveformTrace, dict[str, Any]], float]
else:
    # At runtime WaveformTrace is not imported (type-checking-only import),
    # so the subscripted alias cannot be evaluated; fall back to bare Callable.
    ScoringFunction = Callable

29 

30 

@dataclass
class SearchResult:
    """Result from parameter search.

    Returned by both :class:`GridSearchCV` and :class:`RandomizedSearchCV`.

    Attributes:
        best_params: Dictionary of best parameters found.
        best_score: Best score achieved (scores are maximized).
        all_results: DataFrame with all parameter combinations and scores
            (columns: one per parameter, plus ``mean_score``/``std_score``
            and ``cv_<i>`` columns when CV was used).
        cv_scores: Cross-validation scores of the best combination if CV
            was used, otherwise None.

    Example:
        >>> result = search.fit(traces)
        >>> print(f"Best params: {result.best_params}")
        >>> print(f"Best score: {result.best_score}")

    References:
        API-014: Parameter Grid Search
    """

    best_params: dict[str, Any]
    best_score: float
    all_results: pd.DataFrame
    cv_scores: NDArray[np.float64] | None = None

54 

55 

56def _default_snr_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float: 

57 """Default SNR scoring function. 

58 

59 Args: 

60 trace: Waveform trace to analyze. 

61 params: Parameters to apply (not used in basic SNR). 

62 

63 Returns: 

64 Signal-to-noise ratio in dB. 

65 """ 

66 # Simple SNR: signal power / noise power 

67 # Assume first half is signal, second half is noise (oversimplified) 

68 data = trace.data 

69 mid = len(data) // 2 

70 signal_power = np.mean(data[:mid] ** 2) 

71 noise_power = np.mean((data[mid:] - np.mean(data[mid:])) ** 2) 

72 

73 if noise_power == 0: 

74 return float("inf") 

75 

76 snr = signal_power / noise_power 

77 return float(10 * np.log10(snr)) 

78 

79 

def _default_thd_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float:
    """Default THD scoring function.

    Args:
        trace: Waveform trace to analyze.
        params: Parameters to apply (unused by this basic scorer).

    Returns:
        Negated THD percentage. Lower distortion is better, but searches
        maximize scores, so the sign is flipped.
    """
    # Negate so that the best (lowest) THD yields the highest score.
    return -float(compute_thd(trace))

93 

94 

95class GridSearchCV: 

96 """Grid search over parameter space with optional cross-validation. 

97 

98 Systematically evaluates all combinations of parameters to find the 

99 optimal configuration. 

100 

101 Example: 

102 >>> from tracekit.optimization.search import GridSearchCV 

103 >>> param_grid = { 

104 ... 'cutoff': [1e5, 5e5, 1e6], 

105 ... 'order': [2, 4, 6] 

106 ... } 

107 >>> search = GridSearchCV( 

108 ... param_grid=param_grid, 

109 ... scoring='snr', 

110 ... cv=3 

111 ... ) 

112 >>> result = search.fit(traces, apply_filter) 

113 >>> print(result.best_params) 

114 

115 References: 

116 API-014: Parameter Grid Search 

117 """ 

118 

119 def __init__( 

120 self, 

121 param_grid: dict[str, list[Any]], 

122 scoring: Literal["snr", "thd"] | ScoringFunction = "snr", 

123 cv: int | None = None, 

124 *, 

125 parallel: bool = True, 

126 max_workers: int | None = None, 

127 use_threads: bool = True, 

128 ) -> None: 

129 """Initialize grid search. 

130 

131 Args: 

132 param_grid: Dictionary mapping parameter names to lists of values. 

133 scoring: Scoring function. Built-in: 'snr', 'thd', or custom callable. 

134 cv: Number of cross-validation folds. None for no CV. 

135 parallel: Enable parallel evaluation. 

136 max_workers: Maximum parallel workers. 

137 use_threads: Use threads instead of processes. 

138 

139 Raises: 

140 AnalysisError: If scoring function is invalid. 

141 

142 Example: 

143 >>> param_grid = {'cutoff': [1e6, 2e6], 'order': [4, 6]} 

144 >>> search = GridSearchCV(param_grid, scoring='snr', cv=3) 

145 """ 

146 self.param_grid = param_grid 

147 self.cv = cv 

148 self.parallel = parallel 

149 self.max_workers = max_workers 

150 self.use_threads = use_threads 

151 

152 # Set scoring function 

153 if scoring == "snr": 

154 self.scoring_fn = _default_snr_scorer 

155 elif scoring == "thd": 

156 self.scoring_fn = _default_thd_scorer 

157 elif callable(scoring): 

158 self.scoring_fn = scoring # type: ignore[assignment] 

159 else: 

160 raise AnalysisError(f"Unknown scoring function: {scoring}") 

161 

162 self.best_params_: dict[str, Any] | None = None 

163 self.best_score_: float | None = None 

164 self.results_df_: pd.DataFrame | None = None 

165 

166 def fit( 

167 self, 

168 traces: list[WaveformTrace] | WaveformTrace, 

169 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace], 

170 ) -> SearchResult: 

171 """Fit grid search on traces. 

172 

173 Evaluates all parameter combinations and finds the best. 

174 

175 Args: 

176 traces: Trace or list of traces to evaluate on. 

177 transform_fn: Function that applies parameters to trace. 

178 Should accept (trace, **params) and return transformed trace. 

179 

180 Returns: 

181 SearchResult with best parameters and all results. 

182 

183 Example: 

184 >>> def apply_filter(trace, cutoff, order): 

185 ... return lowpass_filter(trace, cutoff=cutoff, order=order) 

186 >>> result = search.fit(traces, apply_filter) 

187 

188 References: 

189 API-014: Parameter Grid Search 

190 """ 

191 # Convert single trace to list 

192 if not isinstance(traces, list): 

193 traces = [traces] 

194 

195 # Generate all parameter combinations 

196 param_combinations = self._generate_combinations() 

197 

198 # Evaluate each combination 

199 results = self._evaluate_combinations(param_combinations, traces, transform_fn) 

200 

201 # Convert to DataFrame 

202 self.results_df_ = pd.DataFrame(results) 

203 

204 # Find best 

205 best_idx = self.results_df_["mean_score"].idxmax() 

206 best_row = self.results_df_.iloc[best_idx] 

207 

208 self.best_params_ = {k: best_row[k] for k in self.param_grid} 

209 self.best_score_ = float(best_row["mean_score"]) 

210 

211 # Collect CV scores if available 

212 cv_scores = None 

213 if self.cv: 

214 cv_cols = [c for c in self.results_df_.columns if c.startswith("cv_")] 

215 if cv_cols: 215 ↛ 218line 215 didn't jump to line 218 because the condition on line 215 was always true

216 cv_scores = self.results_df_.loc[best_idx, cv_cols].values 

217 

218 return SearchResult( 

219 best_params=self.best_params_, 

220 best_score=self.best_score_, 

221 all_results=self.results_df_, 

222 cv_scores=cv_scores, 

223 ) 

224 

225 def _generate_combinations(self) -> list[dict[str, Any]]: 

226 """Generate all parameter combinations from grid. 

227 

228 Returns: 

229 List of parameter dictionaries. 

230 """ 

231 keys = list(self.param_grid.keys()) 

232 values = [self.param_grid[k] for k in keys] 

233 

234 combinations = [] 

235 for combo in itertools.product(*values): 

236 combinations.append(dict(zip(keys, combo, strict=False))) 

237 

238 return combinations 

239 

240 def _evaluate_combinations( 

241 self, 

242 param_combinations: list[dict[str, Any]], 

243 traces: list[WaveformTrace], 

244 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace], 

245 ) -> list[dict[str, Any]]: 

246 """Evaluate all parameter combinations. 

247 

248 Args: 

249 param_combinations: List of parameter dicts to evaluate. 

250 traces: Traces to evaluate on. 

251 transform_fn: Transformation function. 

252 

253 Returns: 

254 List of result dictionaries. 

255 """ 

256 if self.parallel: 

257 return self._evaluate_parallel(param_combinations, traces, transform_fn) 

258 else: 

259 return self._evaluate_sequential(param_combinations, traces, transform_fn) 

260 

261 def _evaluate_one( 

262 self, 

263 params: dict[str, Any], 

264 traces: list[WaveformTrace], 

265 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace], 

266 ) -> dict[str, Any]: 

267 """Evaluate one parameter combination. 

268 

269 Args: 

270 params: Parameters to evaluate. 

271 traces: Traces to evaluate on. 

272 transform_fn: Transformation function. 

273 

274 Returns: 

275 Result dictionary with scores. 

276 """ 

277 scores: list[float] = [] 

278 

279 if self.cv: 

280 # Cross-validation - split traces into folds 

281 fold_size = len(traces) // self.cv 

282 for i in range(self.cv): 

283 # Select fold 

284 start = i * fold_size 

285 end = start + fold_size if i < self.cv - 1 else len(traces) 

286 fold_traces = traces[start:end] 

287 

288 # Evaluate on fold 

289 fold_scores = [] 

290 for trace in fold_traces: 

291 transformed = transform_fn(trace, **params) # type: ignore[call-arg] 

292 score = self.scoring_fn(transformed, params) 

293 fold_scores.append(score) 

294 

295 scores.append(float(np.mean(fold_scores))) 

296 

297 else: 

298 # No CV - evaluate on all traces 

299 for trace in traces: 

300 transformed = transform_fn(trace, **params) # type: ignore[call-arg] 

301 score = self.scoring_fn(transformed, params) 

302 scores.append(score) 

303 

304 # Build result 

305 result = params.copy() 

306 result["mean_score"] = float(np.mean(scores)) 

307 result["std_score"] = float(np.std(scores)) 

308 

309 if self.cv: 

310 for i, score in enumerate(scores): 

311 result[f"cv_{i}"] = float(score) 

312 

313 return result 

314 

315 def _evaluate_sequential( 

316 self, 

317 param_combinations: list[dict[str, Any]], 

318 traces: list[WaveformTrace], 

319 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace], 

320 ) -> list[dict[str, Any]]: 

321 """Evaluate combinations sequentially. 

322 

323 Args: 

324 param_combinations: Parameter combinations. 

325 traces: Traces to evaluate on. 

326 transform_fn: Transformation function. 

327 

328 Returns: 

329 List of results. 

330 """ 

331 results = [] 

332 for params in param_combinations: 

333 result = self._evaluate_one(params, traces, transform_fn) 

334 results.append(result) 

335 return results 

336 

337 def _evaluate_parallel( 

338 self, 

339 param_combinations: list[dict[str, Any]], 

340 traces: list[WaveformTrace], 

341 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace], 

342 ) -> list[dict[str, Any]]: 

343 """Evaluate combinations in parallel. 

344 

345 Args: 

346 param_combinations: Parameter combinations. 

347 traces: Traces to evaluate on. 

348 transform_fn: Transformation function. 

349 

350 Returns: 

351 List of results. 

352 """ 

353 executor_class = ThreadPoolExecutor if self.use_threads else ProcessPoolExecutor 

354 

355 with executor_class(max_workers=self.max_workers) as executor: 

356 futures = { 

357 executor.submit(self._evaluate_one, params, traces, transform_fn): params 

358 for params in param_combinations 

359 } 

360 

361 results = [] 

362 for future in as_completed(futures): 

363 result = future.result() 

364 results.append(result) 

365 

366 return results 

367 

368 

class RandomizedSearchCV:
    """Random search over parameter distributions.

    Samples random combinations from parameter distributions rather than
    exhaustively evaluating all combinations. Scores are always maximized.

    Example:
        >>> from tracekit.optimization.search import RandomizedSearchCV
        >>> import numpy as np
        >>> param_distributions = {
        ...     'cutoff': lambda: np.random.uniform(1e5, 1e7),
        ...     'order': lambda: np.random.choice([2, 4, 6, 8])
        ... }
        >>> search = RandomizedSearchCV(
        ...     param_distributions=param_distributions,
        ...     n_iter=20,
        ...     scoring='snr'
        ... )
        >>> result = search.fit(traces, apply_filter)

    References:
        API-014: Parameter Grid Search
    """

    def __init__(
        self,
        param_distributions: dict[str, Callable[[], Any]],
        n_iter: int = 10,
        scoring: Literal["snr", "thd"] | ScoringFunction = "snr",
        cv: int | None = None,
        *,
        parallel: bool = True,
        max_workers: int | None = None,
        use_threads: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Initialize randomized search.

        Args:
            param_distributions: Dict mapping parameter names to zero-argument
                sampling functions.
            n_iter: Number of parameter combinations to sample.
            scoring: Scoring function. Built-in: 'snr', 'thd', or custom callable.
            cv: Number of cross-validation folds. None for no CV.
            parallel: Enable parallel evaluation.
            max_workers: Maximum parallel workers.
            use_threads: Use threads instead of processes.
            random_state: Random seed for reproducibility.

        Raises:
            AnalysisError: If scoring function is invalid.

        Example:
            >>> param_dist = {'cutoff': lambda: np.random.uniform(1e5, 1e7)}
            >>> search = RandomizedSearchCV(param_dist, n_iter=50)
        """
        self.param_distributions = param_distributions
        self.n_iter = n_iter
        self.cv = cv
        self.parallel = parallel
        self.max_workers = max_workers
        self.use_threads = use_threads

        # Seed the *global* legacy numpy RNG deliberately: user-supplied
        # samplers typically draw from np.random directly, so only a global
        # seed makes their draws reproducible.
        if random_state is not None:
            np.random.seed(random_state)

        # Resolve scoring strategy: named built-ins first, then any callable.
        if scoring == "snr":
            self.scoring_fn = _default_snr_scorer
        elif scoring == "thd":
            self.scoring_fn = _default_thd_scorer
        elif callable(scoring):
            self.scoring_fn = scoring  # type: ignore[assignment]
        else:
            raise AnalysisError(f"Unknown scoring function: {scoring}")

        # Populated by fit(); trailing underscore follows the sklearn convention.
        self.best_params_: dict[str, Any] | None = None
        self.best_score_: float | None = None
        self.results_df_: pd.DataFrame | None = None

    def fit(
        self,
        traces: list[WaveformTrace] | WaveformTrace,
        transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
    ) -> SearchResult:
        """Fit randomized search on traces.

        Args:
            traces: Trace or list of traces to evaluate on.
            transform_fn: Function that applies parameters to trace.

        Returns:
            SearchResult with best parameters.

        Example:
            >>> result = search.fit(traces, apply_filter)
            >>> print(f"Best cutoff: {result.best_params['cutoff']:.2e}")

        References:
            API-014: Parameter Grid Search
        """
        # Accept a single trace for convenience.
        if not isinstance(traces, list):
            traces = [traces]

        # Sample n_iter random parameter combinations up front.
        param_combinations = self._sample_combinations()

        # Delegate evaluation to GridSearchCV's machinery (sequential /
        # parallel / CV handling) with an empty, unused grid.
        grid_search = GridSearchCV(
            param_grid={},  # Not used
            scoring=self.scoring_fn,
            cv=self.cv,
            parallel=self.parallel,
            max_workers=self.max_workers,
            use_threads=self.use_threads,
        )

        results = grid_search._evaluate_combinations(param_combinations, traces, transform_fn)

        self.results_df_ = pd.DataFrame(results)

        # idxmax() returns an index *label*; use label-based .loc (not
        # positional .iloc) so this stays correct even if the DataFrame
        # index is ever not a plain RangeIndex.
        best_idx = self.results_df_["mean_score"].idxmax()
        best_row = self.results_df_.loc[best_idx]

        self.best_params_ = {k: best_row[k] for k in self.param_distributions}
        self.best_score_ = float(best_row["mean_score"])

        # Collect per-fold scores of the winner when CV was used.
        cv_scores = None
        if self.cv:
            cv_cols = [c for c in self.results_df_.columns if c.startswith("cv_")]
            if cv_cols:
                # Explicit dtype: row-wise extraction can otherwise yield an
                # object array, violating the NDArray[np.float64] contract.
                cv_scores = self.results_df_.loc[best_idx, cv_cols].to_numpy(dtype=np.float64)

        return SearchResult(
            best_params=self.best_params_,
            best_score=self.best_score_,
            all_results=self.results_df_,
            cv_scores=cv_scores,
        )

    def _sample_combinations(self) -> list[dict[str, Any]]:
        """Sample random parameter combinations.

        Returns:
            List of n_iter sampled parameter dictionaries, one value drawn
            from each distribution per combination.
        """
        return [
            {key: sampler() for key, sampler in self.param_distributions.items()}
            for _ in range(self.n_iter)
        ]

525 

526 

# Public API of this module.
__all__ = [
    "GridSearchCV",
    "RandomizedSearchCV",
    "ScoringFunction",
    "SearchResult",
]