Coverage for contextualized/analysis/bootstraps.py: 90%
10 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1import numpy as np
2from contextualized.easy.wrappers import SKLearnWrapper
def select_good_bootstraps(
    sklearn_wrapper: SKLearnWrapper, train_errs: np.ndarray, tol: float = 2
) -> SKLearnWrapper:
    """
    Select bootstraps that are good for a given model.

    The wrapper is modified in place: its ``models`` list is replaced by the
    retained bootstraps and ``n_bootstraps`` is updated to match. The same
    wrapper object is returned for convenience.

    Args:
        sklearn_wrapper (contextualized.easy.wrappers.SKLearnWrapper): Wrapper for the sklearn model.
        train_errs (np.ndarray): Training errors for each bootstrap, either
            (n_bootstraps, n_samples) or (n_bootstraps, n_samples, n_outcomes).
        tol (float): Only bootstraps with mean train_errs at or below
            tol * min(train_errs) are kept.

    Returns:
        contextualized.easy.wrappers.SKLearnWrapper: The input model with only selected bootstraps.
    """
    # Promote 2D errors to 3D so the reduction below is uniform.
    if train_errs.ndim == 2:
        train_errs = train_errs[:, :, None]

    # One mean error per bootstrap, averaged over samples and outcomes.
    train_errs_by_bootstrap = np.mean(train_errs, axis=(1, 2))
    threshold = np.min(train_errs_by_bootstrap) * tol

    # The comparison is inclusive so the best bootstrap is retained when the
    # threshold equals the minimum itself (minimum error of 0, or tol == 1);
    # a strict "<" would discard every model in those cases.
    sklearn_wrapper.models = [
        model
        for train_err, model in zip(train_errs_by_bootstrap, sklearn_wrapper.models)
        if train_err <= threshold
    ]
    sklearn_wrapper.n_bootstraps = len(sklearn_wrapper.models)
    return sklearn_wrapper