Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2""" 

3Cross-validation iterators for GAM 

4 

5Author: Luca Puggini 

6 

7""" 

8 

9from abc import ABCMeta, abstractmethod 

10from statsmodels.compat.python import with_metaclass 

11import numpy as np 

12 

13 

class BaseCrossValidator(with_metaclass(ABCMeta)):
    """
    Abstract base for cross-validation iterators.

    Concrete iterators (for example KFold) divide the observations into
    training and test sets. Because ``split`` is declared abstract, a
    subclass must override it before it can be instantiated.
    """

    def __init__(self):
        """Accept no arguments and store no state."""

    @abstractmethod
    def split(self):
        """Produce the train/test partitions; subclasses must override."""

25 

26 

class KFold(BaseCrossValidator):
    """
    K-Folds cross validation iterator:
    Provides train/test indexes to split data in train test sets

    Parameters
    ----------
    k_folds : int
        number of folds
    shuffle : bool
        If true, then the index is shuffled before splitting into train and
        test indices.

    Notes
    -----
    Folds are built with ``np.array_split``. With ``n`` observations, the
    first ``n % k_folds`` folds have size ``n // k_folds + 1`` and the
    remaining folds have size ``n // k_folds``. If ``k_folds > n``, the
    trailing test folds are empty.

    Shuffling uses NumPy's global random state (``np.random.shuffle``).
    Seed it, e.g. with ``np.random.seed``, to get reproducible splits.
    """

    def __init__(self, k_folds, shuffle=False):
        # Not read or updated by ``split``; kept so the attribute still exists.
        self.nobs = None
        self.k_folds = k_folds
        self.shuffle = shuffle

    def split(self, X, y=None, label=None):
        """Yield index split into train and test sets, one pair per fold.

        Parameters
        ----------
        X : array_like
            Data whose first dimension indexes observations. Only its length
            along axis 0 is used.
        y : ignored
            Accepted for interface compatibility; not used.
        label : ignored
            Accepted for interface compatibility; not used.

        Yields
        ------
        train_index : ndarray of bool, shape (nobs,)
            True for observations outside the current fold.
        test_index : ndarray of bool, shape (nobs,)
            True exactly for observations in the current fold.
        """
        # TODO: X and y are redundant, we only need nobs
        # np.shape uses X.shape when available and also accepts plain
        # sequences (lists, tuples) that have no ``shape`` attribute.
        nobs = np.shape(X)[0]
        index = np.arange(nobs)

        if self.shuffle:
            # In-place permutation using NumPy's global random state.
            np.random.shuffle(index)

        for fold in np.array_split(index, self.k_folds):
            test_index = np.zeros(nobs, dtype=np.bool_)
            test_index[fold] = True
            # Train and test masks are complementary within each fold.
            train_index = np.logical_not(test_index)
            yield train_index, test_index