gaitsetpy.classification.models
from .random_forest import RandomForestModel
from .mlp import MLPModel
from .lstm import LSTMModel
from .bilstm import BiLSTMModel
from .gnn import GNNModel
from .cnn import CNNModel


def get_classification_model(name: str, **kwargs):
    """
    Factory function to get a classification model by name.

    The lookup is case-insensitive and ignores leading/trailing whitespace,
    so ``'CNN'`` and ``' cnn '`` both select the CNN model.

    Args:
        name (str): Name of the model. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
        **kwargs: Model-specific parameters, forwarded unchanged to the
            selected model's constructor.

    Returns:
        An instance of the requested model.

    Raises:
        TypeError: If ``name`` is not a string.
        ValueError: If the model name is not recognized.

    Example:
        model = get_classification_model('cnn', input_channels=20, num_classes=4)
    """
    # Reject non-strings up front with a clear message instead of letting
    # ``.lower()`` fail with an opaque AttributeError.
    if not isinstance(name, str):
        raise TypeError(f"Model name must be a str, got {type(name).__name__}.")
    # Normalize so callers need not match the exact case or spacing.
    name = name.strip().lower()
    if name == 'random_forest':
        return RandomForestModel(**kwargs)
    elif name == 'mlp':
        return MLPModel(**kwargs)
    elif name == 'lstm':
        return LSTMModel(**kwargs)
    elif name == 'bilstm':
        return BiLSTMModel(**kwargs)
    elif name == 'gnn':
        return GNNModel(**kwargs)
    elif name == 'cnn':
        return CNNModel(**kwargs)
    else:
        raise ValueError(f"Unknown model name: {name}. Supported: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.")
def get_classification_model(name: str, **kwargs):
def get_classification_model(name: str, **kwargs):
    """
    Factory function to get a classification model by name.

    The lookup is case-insensitive and ignores leading/trailing whitespace,
    so ``'CNN'`` and ``' cnn '`` both select the CNN model.

    Args:
        name (str): Name of the model. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
        **kwargs: Model-specific parameters, forwarded unchanged to the
            selected model's constructor.

    Returns:
        An instance of the requested model.

    Raises:
        TypeError: If ``name`` is not a string.
        ValueError: If the model name is not recognized.

    Example:
        model = get_classification_model('cnn', input_channels=20, num_classes=4)
    """
    # Reject non-strings up front with a clear message instead of letting
    # ``.lower()`` fail with an opaque AttributeError.
    if not isinstance(name, str):
        raise TypeError(f"Model name must be a str, got {type(name).__name__}.")
    # Normalize so callers need not match the exact case or spacing.
    name = name.strip().lower()
    if name == 'random_forest':
        return RandomForestModel(**kwargs)
    elif name == 'mlp':
        return MLPModel(**kwargs)
    elif name == 'lstm':
        return LSTMModel(**kwargs)
    elif name == 'bilstm':
        return BiLSTMModel(**kwargs)
    elif name == 'gnn':
        return GNNModel(**kwargs)
    elif name == 'cnn':
        return CNNModel(**kwargs)
    else:
        raise ValueError(f"Unknown model name: {name}. Supported: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.")
Factory function to get a classification model by name.
Args: name (str): Name of the model. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'. **kwargs: Model-specific parameters.
Returns: An instance of the requested model.
Raises: ValueError: If the model name is not recognized.
Example: model = get_classification_model('cnn', input_channels=20, num_classes=4)