Source code for WORC.classification.ObjectSampler

#!/usr/bin/env python

# Copyright 2016-2019 Biomedical Imaging Group Rotterdam, Departments of
# Medical Informatics and Radiology, Erasmus MC, Rotterdam, The Netherlands
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from imblearn import over_sampling, under_sampling, combine
import numpy as np
from sklearn.utils import check_random_state
import WORC.addexceptions as ae


class ObjectSampler(object):
    """
    Samples objects for learning based on various under-, over- and combined
    sampling methods.

    The choice of included methods is largely based on:

    He, Haibo, and Edwardo A. Garcia. "Learning from imbalanced data."
    IEEE Transactions on Knowledge & Data Engineering 9 (2008): 1263-1284.

    The imbalanced-learn sampler selected through ``method`` is stored in
    ``self.object``.
    """

    def __init__(self, method,
                 sampling_strategy='auto',
                 SMOTE_ratio=1,
                 SMOTE_neighbors=5,
                 n_jobs=1,
                 n_neighbors=3,
                 ):
        """
        Create the sampler identified by ``method``.

        Parameters
        ----------
        method: string
            One of 'RandomUnderSampling', 'NearMiss',
            'NeigbourhoodCleaningRule', 'RandomOverSampling', 'ADASYN',
            'BorderlineSMOTE', 'SMOTEENN' or 'SMOTETomek'. The spelling
            'NeigbourhoodCleaningRule' is kept as-is, since callers select
            the method through this exact string.
        sampling_strategy:
            Forwarded to the random under-/over-samplers and NearMiss.
        SMOTE_ratio, SMOTE_neighbors:
            Accepted for backward compatibility; currently not forwarded
            to any sampler.
        n_jobs, n_neighbors:
            Forwarded to NearMiss only.

        Raises
        ------
        WORC.addexceptions.WORCKeyError
            If ``method`` is not one of the names listed above.
        """
        # Initialize a random state. The seed is stored on the instance;
        # the original passed the undefined local name ``random_seed`` here,
        # which raised a NameError on every construction.
        self.random_seed = np.random.randint(5000)
        self.random_state = check_random_state(self.random_seed)

        # Initialize all objects as Nones: overridden when required by
        # the init_* functions
        self.sampling_strategy = None
        self.object = None
        self.n_neighbors = None
        self.n_jobs = None

        if method == 'RandomUnderSampling':
            self.init_RandomUnderSampling(sampling_strategy)
        elif method == 'NearMiss':
            self.init_NearMiss(sampling_strategy, n_neighbors, n_jobs)
        elif method == 'NeigbourhoodCleaningRule':
            self.init_NeigbourhoodCleaningRule()
        elif method == 'RandomOverSampling':
            self.init_RandomOverSampling(sampling_strategy)
        elif method == 'ADASYN':
            self.init_ADASYN()
        elif method == 'BorderlineSMOTE':
            self.init_BorderlineSMOTE()
        elif method == 'SMOTEENN':
            self.init_SMOTEENN()
        elif method == 'SMOTETomek':
            self.init_SMOTETomek()
        else:
            raise ae.WORCKeyError(f'{method} is not a valid sampling method!')

    def init_RandomUnderSampling(self, sampling_strategy):
        """Use a RandomUnderSampler seeded with this object's random state."""
        self.object = under_sampling.RandomUnderSampler(
            sampling_strategy=sampling_strategy,
            random_state=self.random_state)
        self.sampling_strategy = sampling_strategy

    def init_NearMiss(self, sampling_strategy, n_neighbors, n_jobs):
        """Use a NearMiss under-sampler.

        ``random_state`` is deliberately not forwarded: imbalanced-learn
        deprecated that argument for NearMiss and later removed it, so
        passing it fails on current releases.
        """
        self.object = under_sampling.NearMiss(
            sampling_strategy=sampling_strategy,
            n_neighbors=n_neighbors,
            n_jobs=n_jobs)
        self.sampling_strategy = sampling_strategy
        self.n_neighbors = n_neighbors
        self.n_jobs = n_jobs

    def init_NeigbourhoodCleaningRule(self):
        """Use a NeighbourhoodCleaningRule under-sampler with defaults."""
        self.object = under_sampling.NeighbourhoodCleaningRule()

    def init_RandomOverSampling(self, sampling_strategy):
        """Use a RandomOverSampler seeded with this object's random state."""
        self.object = over_sampling.RandomOverSampler(
            sampling_strategy=sampling_strategy,
            random_state=self.random_state)
        self.sampling_strategy = sampling_strategy

    def init_ADASYN(self):
        """Use an ADASYN over-sampler seeded with this object's random state."""
        self.object = over_sampling.ADASYN(random_state=self.random_state)

    def init_BorderlineSMOTE(self):
        """Use a borderline-1 SMOTE over-sampler with library defaults."""
        self.object = over_sampling.BorderlineSMOTE(
            random_state=self.random_state,
            kind='borderline-1')

    def init_SMOTE(self):
        """Legacy entry point for borderline SMOTE.

        The previous body referenced the undefined names ``SMOTE`` and
        ``para_estimator`` and requested ``kind='borderline1'``; that variant
        is provided by BorderlineSMOTE, so this delegates to
        ``init_BorderlineSMOTE``.
        """
        self.init_BorderlineSMOTE()

    def init_SMOTEENN(self):
        """Use the combined SMOTEENN sampler with library defaults."""
        self.object = combine.SMOTEENN(random_state=self.random_state)

    def init_SMOTETomek(self):
        """Use the combined SMOTETomek sampler with library defaults."""
        self.object = combine.SMOTETomek(random_state=self.random_state)

    def fit(self, **kwargs):
        """Forward all keyword arguments to the ``fit`` of the sampler."""
        self.object.fit(**kwargs)