Coverage for contextualized/analysis/bootstraps.py: 90%

10 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1import numpy as np 

2from contextualized.easy.wrappers import SKLearnWrapper 

3 

4 

def select_good_bootstraps(
    sklearn_wrapper: "SKLearnWrapper", train_errs: np.ndarray, tol: float = 2
) -> "SKLearnWrapper":
    """
    Select bootstraps that are good for a given model.

    Each bootstrap is scored by its mean training error. A bootstrap is kept
    when that score is strictly below ``tol`` times the best (lowest) score.
    The best-scoring bootstrap(s) are always kept, so the wrapper is never
    left without models (e.g. when the best score is 0 or ``tol <= 1``).

    The wrapper is modified in place: its ``models`` list is filtered and its
    ``n_bootstraps`` attribute is updated to the number of remaining models.

    Args:
        sklearn_wrapper (contextualized.easy.wrappers.SKLearnWrapper): Wrapper for the sklearn model.
        train_errs (np.ndarray): Training errors for each bootstrap, shaped
            (n_bootstraps, n_samples, n_outcomes) or (n_bootstraps, n_samples).
        tol (float): Only bootstraps with mean train_errs below tol * min(train_errs) are kept,
            plus the bootstrap(s) achieving that minimum.

    Returns:
        contextualized.easy.wrappers.SKLearnWrapper: The input model with only selected bootstraps.
    """
    if train_errs.ndim == 2:
        # Add a singleton outcome axis so the reduction below is uniform.
        train_errs = train_errs[:, :, None]

    train_errs_by_bootstrap = np.mean(train_errs, axis=(1, 2))
    train_errs_min = np.min(train_errs_by_bootstrap)
    threshold = train_errs_min * tol
    sklearn_wrapper.models = [
        model
        for train_err, model in zip(train_errs_by_bootstrap, sklearn_wrapper.models)
        # The second clause guarantees the best bootstrap survives even when
        # the threshold collapses onto (or below) the minimum error.
        if train_err < threshold or train_err <= train_errs_min
    ]
    sklearn_wrapper.n_bootstraps = len(sklearn_wrapper.models)
    return sklearn_wrapper