# Coverage for src/tracekit/optimization/search.py: 99% — 157 statements
# (coverage.py v7.13.1, created at 2026-01-11 23:04 +0000)
"""Parameter optimization via grid search and random search.

This module provides tools for finding optimal analysis parameters through
systematic or random search of the parameter space.
"""
from __future__ import annotations

import itertools
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import pandas as pd

from tracekit.analyzers.waveform.spectral import thd as compute_thd
from tracekit.core.exceptions import AnalysisError

if TYPE_CHECKING:
    from numpy.typing import NDArray

    from tracekit.core.types import WaveformTrace

    ScoringFunction = Callable[[WaveformTrace, dict[str, Any]], float]
else:
    ScoringFunction = Callable
@dataclass
class SearchResult:
    """Outcome of a parameter search (grid or randomized).

    Example:
        >>> result = search.fit(traces)
        >>> print(f"Best params: {result.best_params}")
        >>> print(f"Best score: {result.best_score}")

    Attributes:
        best_params: Parameter combination that achieved the best score.
        best_score: Highest score observed across all combinations.
        all_results: One row per evaluated combination, with scores.
        cv_scores: Per-fold scores for the best combination when
            cross-validation was used, otherwise ``None``.

    References:
        API-014: Parameter Grid Search
    """

    best_params: dict[str, Any]
    best_score: float
    all_results: pd.DataFrame
    cv_scores: NDArray[np.float64] | None = None
56def _default_snr_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float:
57 """Default SNR scoring function.
59 Args:
60 trace: Waveform trace to analyze.
61 params: Parameters to apply (not used in basic SNR).
63 Returns:
64 Signal-to-noise ratio in dB.
65 """
66 # Simple SNR: signal power / noise power
67 # Assume first half is signal, second half is noise (oversimplified)
68 data = trace.data
69 mid = len(data) // 2
70 signal_power = np.mean(data[:mid] ** 2)
71 noise_power = np.mean((data[mid:] - np.mean(data[mid:])) ** 2)
73 if noise_power == 0:
74 return float("inf")
76 snr = signal_power / noise_power
77 return float(10 * np.log10(snr))
def _default_thd_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float:
    """Score a trace by negated total harmonic distortion.

    Lower THD is better, so the sign is flipped to fit the
    maximize-the-score convention used by the search classes.

    Args:
        trace: Waveform trace to analyze.
        params: Applied parameters (ignored by this scorer).

    Returns:
        Negative THD percentage (higher score == lower distortion).
    """
    return float(-compute_thd(trace))
95class GridSearchCV:
96 """Grid search over parameter space with optional cross-validation.
98 Systematically evaluates all combinations of parameters to find the
99 optimal configuration.
101 Example:
102 >>> from tracekit.optimization.search import GridSearchCV
103 >>> param_grid = {
104 ... 'cutoff': [1e5, 5e5, 1e6],
105 ... 'order': [2, 4, 6]
106 ... }
107 >>> search = GridSearchCV(
108 ... param_grid=param_grid,
109 ... scoring='snr',
110 ... cv=3
111 ... )
112 >>> result = search.fit(traces, apply_filter)
113 >>> print(result.best_params)
115 References:
116 API-014: Parameter Grid Search
117 """
119 def __init__(
120 self,
121 param_grid: dict[str, list[Any]],
122 scoring: Literal["snr", "thd"] | ScoringFunction = "snr",
123 cv: int | None = None,
124 *,
125 parallel: bool = True,
126 max_workers: int | None = None,
127 use_threads: bool = True,
128 ) -> None:
129 """Initialize grid search.
131 Args:
132 param_grid: Dictionary mapping parameter names to lists of values.
133 scoring: Scoring function. Built-in: 'snr', 'thd', or custom callable.
134 cv: Number of cross-validation folds. None for no CV.
135 parallel: Enable parallel evaluation.
136 max_workers: Maximum parallel workers.
137 use_threads: Use threads instead of processes.
139 Raises:
140 AnalysisError: If scoring function is invalid.
142 Example:
143 >>> param_grid = {'cutoff': [1e6, 2e6], 'order': [4, 6]}
144 >>> search = GridSearchCV(param_grid, scoring='snr', cv=3)
145 """
146 self.param_grid = param_grid
147 self.cv = cv
148 self.parallel = parallel
149 self.max_workers = max_workers
150 self.use_threads = use_threads
152 # Set scoring function
153 if scoring == "snr":
154 self.scoring_fn = _default_snr_scorer
155 elif scoring == "thd":
156 self.scoring_fn = _default_thd_scorer
157 elif callable(scoring):
158 self.scoring_fn = scoring # type: ignore[assignment]
159 else:
160 raise AnalysisError(f"Unknown scoring function: {scoring}")
162 self.best_params_: dict[str, Any] | None = None
163 self.best_score_: float | None = None
164 self.results_df_: pd.DataFrame | None = None
166 def fit(
167 self,
168 traces: list[WaveformTrace] | WaveformTrace,
169 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
170 ) -> SearchResult:
171 """Fit grid search on traces.
173 Evaluates all parameter combinations and finds the best.
175 Args:
176 traces: Trace or list of traces to evaluate on.
177 transform_fn: Function that applies parameters to trace.
178 Should accept (trace, **params) and return transformed trace.
180 Returns:
181 SearchResult with best parameters and all results.
183 Example:
184 >>> def apply_filter(trace, cutoff, order):
185 ... return lowpass_filter(trace, cutoff=cutoff, order=order)
186 >>> result = search.fit(traces, apply_filter)
188 References:
189 API-014: Parameter Grid Search
190 """
191 # Convert single trace to list
192 if not isinstance(traces, list):
193 traces = [traces]
195 # Generate all parameter combinations
196 param_combinations = self._generate_combinations()
198 # Evaluate each combination
199 results = self._evaluate_combinations(param_combinations, traces, transform_fn)
201 # Convert to DataFrame
202 self.results_df_ = pd.DataFrame(results)
204 # Find best
205 best_idx = self.results_df_["mean_score"].idxmax()
206 best_row = self.results_df_.iloc[best_idx]
208 self.best_params_ = {k: best_row[k] for k in self.param_grid}
209 self.best_score_ = float(best_row["mean_score"])
211 # Collect CV scores if available
212 cv_scores = None
213 if self.cv:
214 cv_cols = [c for c in self.results_df_.columns if c.startswith("cv_")]
215 if cv_cols: 215 ↛ 218line 215 didn't jump to line 218 because the condition on line 215 was always true
216 cv_scores = self.results_df_.loc[best_idx, cv_cols].values
218 return SearchResult(
219 best_params=self.best_params_,
220 best_score=self.best_score_,
221 all_results=self.results_df_,
222 cv_scores=cv_scores,
223 )
225 def _generate_combinations(self) -> list[dict[str, Any]]:
226 """Generate all parameter combinations from grid.
228 Returns:
229 List of parameter dictionaries.
230 """
231 keys = list(self.param_grid.keys())
232 values = [self.param_grid[k] for k in keys]
234 combinations = []
235 for combo in itertools.product(*values):
236 combinations.append(dict(zip(keys, combo, strict=False)))
238 return combinations
240 def _evaluate_combinations(
241 self,
242 param_combinations: list[dict[str, Any]],
243 traces: list[WaveformTrace],
244 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
245 ) -> list[dict[str, Any]]:
246 """Evaluate all parameter combinations.
248 Args:
249 param_combinations: List of parameter dicts to evaluate.
250 traces: Traces to evaluate on.
251 transform_fn: Transformation function.
253 Returns:
254 List of result dictionaries.
255 """
256 if self.parallel:
257 return self._evaluate_parallel(param_combinations, traces, transform_fn)
258 else:
259 return self._evaluate_sequential(param_combinations, traces, transform_fn)
261 def _evaluate_one(
262 self,
263 params: dict[str, Any],
264 traces: list[WaveformTrace],
265 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
266 ) -> dict[str, Any]:
267 """Evaluate one parameter combination.
269 Args:
270 params: Parameters to evaluate.
271 traces: Traces to evaluate on.
272 transform_fn: Transformation function.
274 Returns:
275 Result dictionary with scores.
276 """
277 scores: list[float] = []
279 if self.cv:
280 # Cross-validation - split traces into folds
281 fold_size = len(traces) // self.cv
282 for i in range(self.cv):
283 # Select fold
284 start = i * fold_size
285 end = start + fold_size if i < self.cv - 1 else len(traces)
286 fold_traces = traces[start:end]
288 # Evaluate on fold
289 fold_scores = []
290 for trace in fold_traces:
291 transformed = transform_fn(trace, **params) # type: ignore[call-arg]
292 score = self.scoring_fn(transformed, params)
293 fold_scores.append(score)
295 scores.append(float(np.mean(fold_scores)))
297 else:
298 # No CV - evaluate on all traces
299 for trace in traces:
300 transformed = transform_fn(trace, **params) # type: ignore[call-arg]
301 score = self.scoring_fn(transformed, params)
302 scores.append(score)
304 # Build result
305 result = params.copy()
306 result["mean_score"] = float(np.mean(scores))
307 result["std_score"] = float(np.std(scores))
309 if self.cv:
310 for i, score in enumerate(scores):
311 result[f"cv_{i}"] = float(score)
313 return result
315 def _evaluate_sequential(
316 self,
317 param_combinations: list[dict[str, Any]],
318 traces: list[WaveformTrace],
319 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
320 ) -> list[dict[str, Any]]:
321 """Evaluate combinations sequentially.
323 Args:
324 param_combinations: Parameter combinations.
325 traces: Traces to evaluate on.
326 transform_fn: Transformation function.
328 Returns:
329 List of results.
330 """
331 results = []
332 for params in param_combinations:
333 result = self._evaluate_one(params, traces, transform_fn)
334 results.append(result)
335 return results
337 def _evaluate_parallel(
338 self,
339 param_combinations: list[dict[str, Any]],
340 traces: list[WaveformTrace],
341 transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
342 ) -> list[dict[str, Any]]:
343 """Evaluate combinations in parallel.
345 Args:
346 param_combinations: Parameter combinations.
347 traces: Traces to evaluate on.
348 transform_fn: Transformation function.
350 Returns:
351 List of results.
352 """
353 executor_class = ThreadPoolExecutor if self.use_threads else ProcessPoolExecutor
355 with executor_class(max_workers=self.max_workers) as executor:
356 futures = {
357 executor.submit(self._evaluate_one, params, traces, transform_fn): params
358 for params in param_combinations
359 }
361 results = []
362 for future in as_completed(futures):
363 result = future.result()
364 results.append(result)
366 return results
class RandomizedSearchCV:
    """Random search over parameter distributions.

    Samples random combinations from parameter distributions rather than
    exhaustively evaluating all combinations.

    Example:
        >>> from tracekit.optimization.search import RandomizedSearchCV
        >>> import numpy as np
        >>> param_distributions = {
        ...     'cutoff': lambda: np.random.uniform(1e5, 1e7),
        ...     'order': lambda: np.random.choice([2, 4, 6, 8])
        ... }
        >>> search = RandomizedSearchCV(
        ...     param_distributions=param_distributions,
        ...     n_iter=20,
        ...     scoring='snr'
        ... )
        >>> result = search.fit(traces, apply_filter)

    References:
        API-014: Parameter Grid Search
    """

    def __init__(
        self,
        param_distributions: dict[str, Callable[[], Any]],
        n_iter: int = 10,
        scoring: Literal["snr", "thd"] | ScoringFunction = "snr",
        cv: int | None = None,
        *,
        parallel: bool = True,
        max_workers: int | None = None,
        use_threads: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Initialize randomized search.

        Args:
            param_distributions: Dict mapping parameter names to sampling functions.
            n_iter: Number of parameter combinations to sample.
            scoring: Scoring function.
            cv: Number of cross-validation folds.
            parallel: Enable parallel evaluation.
            max_workers: Maximum parallel workers.
            use_threads: Use threads instead of processes.
            random_state: Random seed for reproducibility.

        Raises:
            AnalysisError: If scoring function is invalid.

        Example:
            >>> param_dist = {'cutoff': lambda: np.random.uniform(1e5, 1e7)}
            >>> search = RandomizedSearchCV(param_dist, n_iter=50)
        """
        self.param_distributions = param_distributions
        self.n_iter = n_iter
        self.cv = cv
        self.parallel = parallel
        self.max_workers = max_workers
        self.use_threads = use_threads

        # NOTE: seeds NumPy's *global* RNG because user-supplied sampler
        # lambdas typically draw from np.random directly; this is a
        # process-wide side effect.
        if random_state is not None:
            np.random.seed(random_state)

        # Resolve the scoring function: built-in name or user callable.
        if scoring == "snr":
            self.scoring_fn = _default_snr_scorer
        elif scoring == "thd":
            self.scoring_fn = _default_thd_scorer
        elif callable(scoring):
            self.scoring_fn = scoring  # type: ignore[assignment]
        else:
            raise AnalysisError(f"Unknown scoring function: {scoring}")

        # Fitted state (populated by fit()); trailing underscore per sklearn convention.
        self.best_params_: dict[str, Any] | None = None
        self.best_score_: float | None = None
        self.results_df_: pd.DataFrame | None = None

    def fit(
        self,
        traces: list[WaveformTrace] | WaveformTrace,
        transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
    ) -> SearchResult:
        """Fit randomized search on traces.

        Args:
            traces: Trace or list of traces to evaluate on.
            transform_fn: Function that applies parameters to trace.

        Returns:
            SearchResult with best parameters.

        Example:
            >>> result = search.fit(traces, apply_filter)
            >>> print(f"Best cutoff: {result.best_params['cutoff']:.2e}")

        References:
            API-014: Parameter Grid Search
        """
        # Convert single trace to list.
        if not isinstance(traces, list):
            traces = [traces]

        # Sample the combinations to evaluate.
        param_combinations = self._sample_combinations()

        # Delegate evaluation (sequential/parallel, CV) to GridSearchCV's
        # machinery; the grid itself is unused for randomized search.
        grid_search = GridSearchCV(
            param_grid={},  # Not used
            scoring=self.scoring_fn,
            cv=self.cv,
            parallel=self.parallel,
            max_workers=self.max_workers,
            use_threads=self.use_threads,
        )
        results = grid_search._evaluate_combinations(param_combinations, traces, transform_fn)
        self.results_df_ = pd.DataFrame(results)

        # idxmax() returns an index *label*, so look the row up with .loc
        # (label-based), not .iloc (position-based). The two only coincide
        # for a default RangeIndex.
        best_idx = self.results_df_["mean_score"].idxmax()
        best_row = self.results_df_.loc[best_idx]

        self.best_params_ = {k: best_row[k] for k in self.param_distributions}
        self.best_score_ = float(best_row["mean_score"])

        # Collect per-fold scores for the winning combination if CV was used.
        cv_scores = None
        if self.cv:
            cv_cols = [c for c in self.results_df_.columns if c.startswith("cv_")]
            if cv_cols:
                cv_scores = self.results_df_.loc[best_idx, cv_cols].values

        return SearchResult(
            best_params=self.best_params_,
            best_score=self.best_score_,
            all_results=self.results_df_,
            cv_scores=cv_scores,
        )

    def _sample_combinations(self) -> list[dict[str, Any]]:
        """Draw ``n_iter`` random parameter combinations.

        Returns:
            List of sampled parameter dictionaries, one sampler call per
            parameter per iteration.
        """
        return [
            {key: sampler() for key, sampler in self.param_distributions.items()}
            for _ in range(self.n_iter)
        ]
# Public API of this module (alphabetical).
__all__ = ["GridSearchCV", "RandomizedSearchCV", "ScoringFunction", "SearchResult"]