geronimo.models

Geronimo Models Module.

The models module defines the base classes and interfaces for creating machine learning models within Geronimo. It standardizes the interface for training and inference.

Key components:

  • Model: The base class that all user models must inherit from.
  • HyperParams: A lightweight container of fixed hyperparameter values and candidate-value lists for grid search.

By subclassing Model, users can plug their custom logic (using PyTorch, scikit-learn, etc.) into the Geronimo ecosystem.

 1"""Geronimo Models Module.
 2
 3The models module defines the base classes and interfaces for creating machine
 4learning models within Geronimo. It handles standardizing the interface for
 5training, evaluation, and inference.
 6
 7Key components:
 8- Model: The base class that all user models must inherit from.
 9- HyperParams: A Pydantic model for defining and validating hyperparameters.
10
11By subclassing `Model`, users can plug their custom logic (using PyTorch,
12Scikit-Learn, etc.) into the Geronimo ecosystem.
13"""
14
15from geronimo.models.base import Model
16from geronimo.models.params import HyperParams
17
18__all__ = ["Model", "HyperParams"]
19
20__docformat__ = "google"
class Model(abc.ABC):
 14class Model(ABC):
 15    """Base class for Geronimo models.
 16
 17    Provides a standardized interface for training and inference
 18    with integrated artifact management.
 19
 20    Example:
 21        ```python
 22        from geronimo.models import Model, HyperParams
 23        from geronimo.features import FeatureSet
 24        from sklearn.ensemble import RandomForestClassifier
 25
 26        class CreditRiskModel(Model):
 27            name = "credit-risk"
 28            version = "1.2.0"
 29
 30            def train(self, X, y, params: HyperParams):
 31                self.estimator = RandomForestClassifier(
 32                    n_estimators=params.n_estimators,
 33                    max_depth=params.max_depth,
 34                )
 35                self.estimator.fit(X, y)
 36
 37            def predict(self, X):
 38                return self.estimator.predict_proba(X)
 39        ```
 40    """
 41
 42    # Override in subclass
 43    name: str = "unnamed"
 44    """The name of the model."""
 45
 46    version: str = "1.0.0"
 47    """The version of the model."""
 48
 49    features: Optional["FeatureSet"] = None
 50    """The associated feature set."""
 51
 52    estimator: Any
 53    """The underlying model estimator (e.g., sklearn object)."""
 54
 55    _is_fitted: bool
 56    """Internal flag tracking whether the model has been trained."""
 57
 58    def __init__(self):
 59        """Initialize model."""
 60        self.estimator: Any = None
 61        self._is_fitted: bool = False
 62
 63    @abstractmethod
 64    def train(self, X, y, params: HyperParams) -> None:
 65        """Train the model.
 66
 67        Args:
 68            X: Feature matrix.
 69            y: Target vector.
 70            params: Hyperparameters.
 71        """
 72        pass
 73
 74    @abstractmethod
 75    def predict(self, X) -> Any:
 76        """Generate predictions.
 77
 78        Args:
 79            X: Feature matrix.
 80
 81        Returns:
 82            Predictions.
 83        """
 84        pass
 85
 86    def save(self, store: "ArtifactStore") -> None:
 87        """Save model to artifact store.
 88
 89        Args:
 90            store: ArtifactStore instance.
 91        """
 92        if self.estimator is None:
 93            raise ValueError("No estimator to save. Train the model first.")
 94
 95        store.save("model", self.estimator, artifact_type="estimator")
 96
 97        # Save features if present
 98        if self.features is not None and hasattr(self.features, "save"):
 99            self.features.save(store)
100
101    def load(self, store: "ArtifactStore") -> None:
102        """Load model from artifact store.
103
104        Args:
105            store: ArtifactStore instance.
106        """
107        self.estimator = store.get("model")
108        self._is_fitted = True
109
110        # Load features if present
111        if self.features is not None and hasattr(self.features, "load"):
112            self.features.load(store)
113
114    @property
115    def is_fitted(self) -> bool:
116        """Check if model has been trained."""
117        return self._is_fitted
118
119    def __repr__(self) -> str:
120        status = "fitted" if self._is_fitted else "not fitted"
121        return f"{self.__class__.__name__}({self.name}@{self.version}, {status})"

Base class for Geronimo models.

Provides a standardized interface for training and inference with integrated artifact management.

Example:
from geronimo.models import Model, HyperParams
from geronimo.features import FeatureSet
from sklearn.ensemble import RandomForestClassifier

class CreditRiskModel(Model):
    name = "credit-risk"
    version = "1.2.0"

    def train(self, X, y, params: HyperParams):
        self.estimator = RandomForestClassifier(
            n_estimators=params.n_estimators,
            max_depth=params.max_depth,
        )
        self.estimator.fit(X, y)
        self._is_fitted = True

    def predict(self, X):
        return self.estimator.predict_proba(X)
Model()
58    def __init__(self):
59        """Initialize an untrained model with no estimator attached."""
60        self.estimator: Any = None
61        self._is_fitted: bool = False

Initialize model.

name: str = 'unnamed'

The name of the model.

version: str = '1.0.0'

The version of the model.

features: Optional[geronimo.features.FeatureSet] = None

The associated feature set.

estimator: Any

The underlying model estimator (e.g., sklearn object).

@abstractmethod
def train(self, X, y, params: HyperParams) -> None:
63    @abstractmethod
64    def train(self, X, y, params: HyperParams) -> None:
65        """Train the model; implementations should set ``self._is_fitted = True``.
66
67        Args:
68            X: Feature matrix.
69            y: Target vector.
70            params: Hyperparameters, e.g. one combination from ``HyperParams.grid()``.
71        """
72        pass

Train the model.

Arguments:
  • X: Feature matrix.
  • y: Target vector.
  • params: Hyperparameters.
@abstractmethod
def predict(self, X) -> Any:
74    @abstractmethod
75    def predict(self, X) -> Any:
76        """Generate predictions for ``X``.
77
78        Args:
79            X: Feature matrix.
80
81        Returns:
82            Predictions; the type is model-specific.
83        """
84        pass

Generate predictions.

Arguments:
  • X: Feature matrix.
Returns:

Predictions.

def save(self, store: geronimo.artifacts.ArtifactStore) -> None:
86    def save(self, store: "ArtifactStore") -> None:
87        """Save the estimator (and features, if present) to an artifact store.
88
89        Args:
90            store: ArtifactStore instance.
91        """
92        if self.estimator is None:
93            raise ValueError("No estimator to save. Train the model first.")
94
95        store.save("model", self.estimator, artifact_type="estimator")
96
97        # Save features too when an attached feature set exposes ``save``
98        if self.features is not None and hasattr(self.features, "save"):
99            self.features.save(store)

Save model to artifact store.

Arguments:
  • store: ArtifactStore instance.
def load(self, store: geronimo.artifacts.ArtifactStore) -> None:
101    def load(self, store: "ArtifactStore") -> None:
102        """Load the estimator (and features, if present) and mark the model fitted.
103
104        Args:
105            store: ArtifactStore instance.
106        """
107        self.estimator = store.get("model")
108        self._is_fitted = True  # a loaded estimator is treated as trained
109
110        # Load features too when an attached feature set exposes ``load``
111        if self.features is not None and hasattr(self.features, "load"):
112            self.features.load(store)

Load model from artifact store.

Arguments:
  • store: ArtifactStore instance.
is_fitted: bool
114    @property
115    def is_fitted(self) -> bool:
116        """Whether the model is marked as trained (``load`` sets this flag)."""
117        return self._is_fitted

Check if model has been trained.

class HyperParams:
 7class HyperParams:
 8    """Hyperparameter configuration with grid search support.
 9
10    Example:
11        ```python
12        from geronimo.models import HyperParams
13
14        # Define parameter space
15        params = HyperParams(
16            n_estimators=[100, 200, 500],
17            max_depth=[3, 5, 7],
18            learning_rate=0.1,  # Fixed value
19        )
20
21        # Iterate over grid
22        for combo in params.grid():
23            model.train(X, y, combo)
24
25        # Get as dict
26        params.to_dict()  # Returns first combination
27        ```
28    """
29
30    def __init__(self, **params):
31        """Initialize hyperparameters.
32
33        Args:
34            **params: Parameter name-value pairs. Values can be:
35                - Single value (fixed parameter)
36                - List of values (for grid search)
37        """
38        self._params = params
39
40    def to_dict(self) -> dict[str, Any]:
41        """Convert to dictionary with single values.
42
43        For list parameters, returns the first value.
44
45        Returns:
46            Dictionary of parameter values.
47        """
48        return {
49            k: v[0] if isinstance(v, list) else v for k, v in self._params.items()
50        }
51
52    def grid(self) -> Iterator["HyperParams"]:
53        """Generate all parameter combinations for grid search.
54
55        Yields:
56            HyperParams instance for each combination.
57        """
58        import itertools
59
60        # Separate fixed vs grid parameters
61        keys = []
62        values = []
63        for k, v in self._params.items():
64            keys.append(k)
65            values.append(v if isinstance(v, list) else [v])
66
67        # Generate cartesian product
68        for combo in itertools.product(*values):
69            yield HyperParams(**dict(zip(keys, combo)))
70
71    def __getattr__(self, name: str) -> Any:
72        """Access parameter by attribute."""
73        if name.startswith("_"):
74            return super().__getattribute__(name)
75        if name in self._params:
76            v = self._params[name]
77            return v[0] if isinstance(v, list) else v
78        raise AttributeError(f"No parameter: {name}")
79
80    def __repr__(self) -> str:
81        params_str = ", ".join(f"{k}={v}" for k, v in self._params.items())
82        return f"HyperParams({params_str})"

Hyperparameter configuration with grid search support.

Example:
from geronimo.models import HyperParams

# Define parameter space
params = HyperParams(
    n_estimators=[100, 200, 500],
    max_depth=[3, 5, 7],
    learning_rate=0.1,  # Fixed value
)

# Iterate over grid
for combo in params.grid():
    model.train(X, y, combo)

# Get as dict
params.to_dict()  # Returns first combination
HyperParams(**params)
30    def __init__(self, **params):
31        """Store hyperparameters as given (no validation is performed).
32
33        Args:
34            **params: Parameter name-value pairs. Values can be:
35                - Single value (fixed parameter)
36                - List of values (for grid search)
37        """
38        self._params = params

Initialize hyperparameters.

Arguments:
  • **params: Parameter name-value pairs. Values can be:
    • Single value (fixed parameter)
    • List of values (for grid search)
def to_dict(self) -> dict[str, typing.Any]:
40    def to_dict(self) -> dict[str, Any]:
41        """Convert to a dictionary of single values.
42
43        Each list parameter is collapsed to its first element.
44
45        Returns:
46            Dictionary of parameter values.
47        """
48        return {
49            k: v[0] if isinstance(v, list) else v for k, v in self._params.items()
50        }

Convert to dictionary with single values.

For list parameters, returns the first value.

Returns:

Dictionary of parameter values.

def grid(self) -> Iterator[HyperParams]:
52    def grid(self) -> Iterator["HyperParams"]:
53        """Generate all parameter combinations for grid search.
54
55        Yields:
56            HyperParams instance for each combination.
57        """
58        import itertools
59
60        # Wrap fixed values in one-element lists so every key joins the product
61        keys = []
62        values = []
63        for k, v in self._params.items():
64            keys.append(k)
65            values.append(v if isinstance(v, list) else [v])
66
67        # Cartesian product: one HyperParams per combination, in key order
68        for combo in itertools.product(*values):
69            yield HyperParams(**dict(zip(keys, combo)))

Generate all parameter combinations for grid search.

Yields:

HyperParams instance for each combination.