# Source code for scitex_ml.classification.Classifier

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-12-12 06:49:15 (ywatanabe)"
# File: ./scitex_repo/src/scitex/ai/Classifier.py

# Absolute path recorded for this module; nothing in the visible code reads it.
THIS_FILE = "/data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/ai/Classifier.py"

"""
Functionality:
    * Provides a unified interface for initializing various scikit-learn classifiers
    * Supports optional preprocessing by chaining a caller-supplied scaler (e.g. StandardScaler) with the classifier in a pipeline

Input:
    * Classifier name as string
    * Optional class weights for imbalanced datasets
    * Optional scaler for feature preprocessing

Output:
    * Initialized classifier or pipeline with scaler

Prerequisites:
    * scikit-learn
    * Optional: CatBoost for CatBoostClassifier (note: no CatBoostClassifier entry is registered by the Classifier class in this file)
"""

from typing import Dict, List, Optional, Union

from sklearn.base import BaseEstimator as _BaseEstimator
from sklearn.discriminant_analysis import (
    QuadraticDiscriminantAnalysis as _QuadraticDiscriminantAnalysis,
)
from sklearn.ensemble import AdaBoostClassifier as _AdaBoostClassifier
from sklearn.gaussian_process import (
    GaussianProcessClassifier as _GaussianProcessClassifier,
)
from sklearn.linear_model import LogisticRegression as _LogisticRegression
from sklearn.linear_model import (
    PassiveAggressiveClassifier as _PassiveAggressiveClassifier,
)
from sklearn.linear_model import Perceptron as _Perceptron
from sklearn.linear_model import RidgeClassifier as _RidgeClassifier
from sklearn.linear_model import SGDClassifier as _SGDClassifier
from sklearn.neighbors import KNeighborsClassifier as _KNeighborsClassifier
from sklearn.pipeline import Pipeline as _Pipeline
from sklearn.pipeline import make_pipeline as _make_pipeline
from sklearn.preprocessing import StandardScaler as _StandardScaler
from sklearn.svm import SVC as _SVC
from sklearn.svm import LinearSVC as _LinearSVC


class Classifier:
    """
    Server for initializing various scikit-learn classifiers with consistent interface.

    Every supported classifier is instantiated once in ``__init__`` and kept
    in ``clf_candi`` as an unfitted template. Calling the server returns a
    fresh clone of the requested template, so estimators handed out by
    separate calls never share fitted state.

    Example
    -------
    >>> clf_server = Classifier(class_weight={0: 1.0, 1: 2.0}, random_state=42)
    >>> clf = clf_server("SVC", scaler=_StandardScaler())
    >>> print(clf_server.list)
    ['Perceptron', 'PassiveAggressiveClassifier', ...]

    Parameters
    ----------
    class_weight : Optional[Dict[int, float]]
        Class weights for handling imbalanced datasets. Forwarded only to the
        classifiers constructed with a ``class_weight`` argument below.
    random_state : int
        Random seed for reproducibility. Forwarded only to the classifiers
        constructed with a ``random_state`` argument below.
    """

    def __init__(
        self,
        class_weight: Optional[Dict[int, float]] = None,
        random_state: int = 42,
    ):
        """Build the table of classifier templates keyed by class name."""
        self.class_weight = class_weight
        self.random_state = random_state

        # Keyword bundles shared by several constructors.
        seeded = {"random_state": random_state}
        weighted = {"class_weight": self.class_weight, **seeded}

        # Insertion order is preserved and defines the order of ``self.list``.
        self.clf_candi: Dict[str, _BaseEstimator] = {
            "Perceptron": _Perceptron(penalty="l2", **weighted),
            "PassiveAggressiveClassifier": _PassiveAggressiveClassifier(**weighted),
            "LogisticRegression": _LogisticRegression(**weighted),
            "SGDClassifier": _SGDClassifier(**weighted),
            "RidgeClassifier": _RidgeClassifier(**weighted),
            "QuadraticDiscriminantAnalysis": _QuadraticDiscriminantAnalysis(),
            "GaussianProcessClassifier": _GaussianProcessClassifier(**seeded),
            "KNeighborsClassifier": _KNeighborsClassifier(),
            "AdaBoostClassifier": _AdaBoostClassifier(**seeded),
            "LinearSVC": _LinearSVC(**weighted),
            "SVC": _SVC(**weighted),
        }

    def __call__(
        self, clf_str: str, scaler: Optional[_BaseEstimator] = None
    ) -> Union[_BaseEstimator, _Pipeline]:
        """
        Return a fresh, unfitted classifier (optionally preceded by a scaler).

        Parameters
        ----------
        clf_str : str
            Key of the classifier in ``clf_candi`` (see ``list``).
        scaler : Optional[BaseEstimator]
            If given, the result is ``make_pipeline(scaler, classifier)``.
            The scaler object is used as passed in; it is not copied.

        Returns
        -------
        Union[BaseEstimator, Pipeline]
            A clone of the stored template, or a pipeline wrapping that clone.

        Raises
        ------
        ValueError
            If ``clf_str`` is not a registered classifier name.
        """
        if clf_str not in self.clf_candi:
            raise ValueError(
                f"Unknown classifier: {clf_str}. Available options: {self.list}"
            )

        # Imported here so this class does not depend on an extra module-level name.
        from sklearn.base import clone

        # Hand out a copy instead of the stored template: otherwise repeated
        # calls (and every pipeline built from them) would alias one estimator
        # object, and fitting one would overwrite the others.
        # safe=False falls back to a deep copy for entries that do not
        # implement get_params (e.g. objects added to clf_candi by a caller).
        clf = clone(self.clf_candi[clf_str], safe=False)

        if scaler is not None:
            return _make_pipeline(scaler, clf)
        return clf

    @property
    def list(self) -> List[str]:
        """Names of the available classifiers, in registration order."""
        # Inside a method body ``list`` still resolves to the builtin, not to
        # this property: class-level names are not visible from method scopes.
        return list(self.clf_candi)
if __name__ == "__main__":
    # Smoke check when run as a script: request an SVC chained after a
    # StandardScaler from a server built with default settings.
    clf_server = Classifier()
    clf = clf_server(clf_str="SVC", scaler=_StandardScaler())