gaitsetpy.classification.models

 1from .random_forest import RandomForestModel
 2from .mlp import MLPModel
 3from .lstm import LSTMModel
 4from .bilstm import BiLSTMModel
 5from .gnn import GNNModel
 6from .cnn import CNNModel
 7
 8def get_classification_model(name: str, **kwargs):
 9    """
10    Factory function to get a classification model by name.
11
12    Args:
13        name (str): Name of the model. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
14        **kwargs: Model-specific parameters.
15
16    Returns:
17        An instance of the requested model.
18
19    Raises:
20        ValueError: If the model name is not recognized.
21
22    Example:
23        model = get_classification_model('cnn', input_channels=20, num_classes=4)
24    """
25    name = name.lower()
26    if name == 'random_forest':
27        return RandomForestModel(**kwargs)
28    elif name == 'mlp':
29        return MLPModel(**kwargs)
30    elif name == 'lstm':
31        return LSTMModel(**kwargs)
32    elif name == 'bilstm':
33        return BiLSTMModel(**kwargs)
34    elif name == 'gnn':
35        return GNNModel(**kwargs)
36    elif name == 'cnn':
37        return CNNModel(**kwargs)
38    else:
39        raise ValueError(f"Unknown model name: {name}. Supported: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.")
def get_classification_model(name: str, **kwargs):
 9def get_classification_model(name: str, **kwargs):
10    """
11    Factory function to get a classification model by name.
12
13    Args:
14        name (str): Name of the model. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
15        **kwargs: Model-specific parameters.
16
17    Returns:
18        An instance of the requested model.
19
20    Raises:
21        ValueError: If the model name is not recognized.
22
23    Example:
24        model = get_classification_model('cnn', input_channels=20, num_classes=4)
25    """
26    name = name.lower()
27    if name == 'random_forest':
28        return RandomForestModel(**kwargs)
29    elif name == 'mlp':
30        return MLPModel(**kwargs)
31    elif name == 'lstm':
32        return LSTMModel(**kwargs)
33    elif name == 'bilstm':
34        return BiLSTMModel(**kwargs)
35    elif name == 'gnn':
36        return GNNModel(**kwargs)
37    elif name == 'cnn':
38        return CNNModel(**kwargs)
39    else:
40        raise ValueError(f"Unknown model name: {name}. Supported: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.")

Factory function to get a classification model by name.

Args:
- name (str): Name of the model, matched case-insensitively. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
- **kwargs: Model-specific parameters, forwarded to the selected model's constructor.

Returns: An instance of the requested model.

Raises: ValueError: If the model name is not recognized.

Example: `model = get_classification_model('cnn', input_channels=20, num_classes=4)`