
insurance_gam.anam

insurance_gam.anam — Actuarial Neural Additive Model subpackage.

Re-exports the full public API of the original insurance-anam package.

 1"""
 2insurance_gam.anam — Actuarial Neural Additive Model subpackage.
 3
 4Re-exports the full public API of the original insurance-anam package.
 5"""
 6
 7from .api import ANAM
 8from .feature_network import CategoricalFeatureNetwork, FeatureNetwork
 9from .interaction_network import InteractionNetwork
10from .losses import (
11    bernoulli_deviance,
12    gamma_deviance,
13    l1_sparsity_penalty,
14    l2_ridge_penalty,
15    poisson_deviance,
16    smoothness_penalty,
17    tweedie_deviance,
18)
19from .model import ANAMModel, FeatureConfig, InteractionConfig
20from .shapes import ShapeFunction, extract_shape_functions, plot_all_shapes
21from .trainer import ANAMTrainer, TrainingConfig, TrainingHistory
22from .utils import (
23    StandardScaler,
24    compare_shapes_to_glm,
25    compute_deviance_stat,
26    select_interactions_correlation,
27    select_interactions_residual,
28    shapes_to_relativity_table,
29)
30
31__all__ = [
32    "ANAM",
33    "ANAMModel",
34    "FeatureConfig",
35    "InteractionConfig",
36    "FeatureNetwork",
37    "CategoricalFeatureNetwork",
38    "InteractionNetwork",
39    "ANAMTrainer",
40    "TrainingConfig",
41    "TrainingHistory",
42    "poisson_deviance",
43    "gamma_deviance",
44    "tweedie_deviance",
45    "bernoulli_deviance",
46    "smoothness_penalty",
47    "l1_sparsity_penalty",
48    "l2_ridge_penalty",
49    "ShapeFunction",
50    "extract_shape_functions",
51    "plot_all_shapes",
52    "StandardScaler",
53    "select_interactions_correlation",
54    "select_interactions_residual",
55    "shapes_to_relativity_table",
56    "compare_shapes_to_glm",
57    "compute_deviance_stat",
58]
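
A minimal end-to-end sketch, not taken from the package's own examples: the column names, synthetic data, and hyperparameter choices below are invented for illustration, but every call uses the API re-exported above.

import numpy as np
import polars as pl

from insurance_gam.anam import ANAM

rng = np.random.default_rng(0)
n = 5_000
X = pl.DataFrame({
    "driver_age": rng.uniform(18, 80, n),
    "vehicle_power": rng.uniform(40, 250, n),
    "region": rng.integers(0, 5, n),          # integer-coded categorical
})
exposure = rng.uniform(0.1, 1.0, n)           # policy-years on risk
claims = rng.poisson(0.08 * exposure)         # synthetic claim counts

model = ANAM(
    categorical_features=["region"],
    monotone_increasing=["vehicle_power"],    # illustrative constraint
    loss="poisson",
    link="log",
    n_epochs=20,
    verbose=1,
)
model.fit(X, claims, sample_weight=exposure)
rates = model.predict(X)                      # per-policy-year rates (exposure = 1)
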
class ANAM:
 43class ANAM:
 44    """sklearn-compatible Actuarial Neural Additive Model.
 45
 46    Parameters
 47    ----------
 48    feature_configs:
 49        List of FeatureConfig objects defining features. If None, all
 50        features are treated as continuous with no constraints.
 51    feature_names:
 52        Feature names (used when feature_configs is None). Must match
 53        columns of X passed to fit().
 54    categorical_features:
 55        List of feature names that are categorical. Only used when
 56        feature_configs is None.
 57    monotone_increasing:
 58        Feature names to constrain as monotone increasing.
 59    monotone_decreasing:
 60        Feature names to constrain as monotone decreasing.
 61    link:
 62        Link function: 'log' (Poisson/Tweedie), 'identity' (Gaussian),
 63        'logit' (binary).
 64    loss:
 65        Distributional loss: 'poisson', 'tweedie', 'gamma', 'mse'.
 66    tweedie_p:
 67        Tweedie power parameter (only for loss='tweedie').
 68    interaction_pairs:
 69        List of (feature_i, feature_j) tuples for interaction subnetworks.
 70    hidden_sizes:
 71        Default hidden layer sizes for subnetworks.
 72    n_epochs:
 73        Maximum training epochs.
 74    batch_size:
 75        Mini-batch size.
 76    learning_rate:
 77        Adam learning rate.
 78    lambda_smooth:
 79        Smoothness regularisation weight.
 80    lambda_l2:
 81        L2 ridge weight.
 82    lambda_l1:
 83        L1 sparsity weight.
 84    patience:
 85        Early stopping patience (epochs).
 86    normalize:
 87        If True, standardise continuous features before training. The
 88        scaler is stored and applied automatically during predict().
 89    verbose:
 90        Training verbosity. 0 = silent.
 91    device:
 92        'cpu', 'cuda', or None (auto).
 93    """
 94
 95    def __init__(
 96        self,
 97        feature_configs: Optional[List[FeatureConfig]] = None,
 98        feature_names: Optional[List[str]] = None,
 99        categorical_features: Optional[List[str]] = None,
100        monotone_increasing: Optional[List[str]] = None,
101        monotone_decreasing: Optional[List[str]] = None,
102        link: Literal["log", "identity", "logit"] = "log",
103        loss: Literal["poisson", "tweedie", "gamma", "mse"] = "poisson",
104        tweedie_p: float = 1.5,
105        interaction_pairs: Optional[List[Tuple[str, str]]] = None,
106        hidden_sizes: Optional[List[int]] = None,
107        n_epochs: int = 100,
108        batch_size: int = 512,
109        learning_rate: float = 1e-3,
110        lambda_smooth: float = 1e-4,
111        lambda_l2: float = 1e-4,
112        lambda_l1: float = 0.0,
113        patience: int = 15,
114        normalize: bool = True,
115        verbose: int = 0,
116        device: Optional[str] = None,
117    ) -> None:
118        self.feature_configs = feature_configs
119        self.feature_names = feature_names
120        # Store exactly as passed — no coercion to [] here.
121        # sklearn clone() checks that get_params() round-trips through __init__
122        # exactly. Coercing None->[] would break that invariant.
123        # Defaults are applied at point-of-use in fit() and _build_feature_configs().
124        self.categorical_features = categorical_features
125        self.monotone_increasing = monotone_increasing
126        self.monotone_decreasing = monotone_decreasing
127        self.link = link
128        self.loss = loss
129        self.tweedie_p = tweedie_p
130        self.interaction_pairs = interaction_pairs
131        self.hidden_sizes = hidden_sizes
132        self.n_epochs = n_epochs
133        self.batch_size = batch_size
134        self.learning_rate = learning_rate
135        self.lambda_smooth = lambda_smooth
136        self.lambda_l2 = lambda_l2
137        self.lambda_l1 = lambda_l1
138        self.patience = patience
139        self.normalize = normalize
140        self.verbose = verbose
141        self.device = device
142
143        # Set after fit()
144        self.model_: Optional[ANAMModel] = None
145        self.scaler_: Optional[StandardScaler] = None
146        self.history_: Optional[TrainingHistory] = None
147        self.feature_names_in_: Optional[List[str]] = None
148        self._continuous_col_indices: List[int] = []
149        # P0-1 fix: track both the cached shapes and the n_points used to
150        # build them so we can detect stale cache on different n_points.
151        self._shapes_cache: Optional[Dict[str, ShapeFunction]] = None
152        self._shapes_cache_n_points: Optional[int] = None
153        self._X_train_scaled: Optional[np.ndarray] = None
154        # P0-2 fix: per-feature category remapping tables set during fit().
155        # Maps feature name -> array where remap[original_code] = new_code.
156        self._cat_remap_: Dict[str, np.ndarray] = {}
157
158    def fit(
159        self,
160        X: Union[np.ndarray, pl.DataFrame],
161        y: Union[np.ndarray, pl.Series],
162        sample_weight: Optional[Union[np.ndarray, pl.Series]] = None,
163    ) -> "ANAM":
164        """Fit the ANAM model.
165
166        Parameters
167        ----------
168        X:
169            Feature matrix, shape (n, p). Polars DataFrame accepted.
170        y:
171            Target vector (claim counts, rates, or severities).
172        sample_weight:
173            Exposure weights. For frequency models this is policy duration
174            (e.g. years on risk). If None, uniform exposure assumed.
175
176        Returns
177        -------
178        self
179        """
180        X_arr, y_arr, w_arr, names = self._validate_input(X, y, sample_weight)
181        self.feature_names_in_ = names
182
183        # Build feature configs if not provided. Also populates self._cat_remap_
184        # for any categorical columns that are not zero-indexed consecutive ints.
185        feat_configs = self.feature_configs or self._build_feature_configs(names, X_arr)
186
187        # P0-2 fix: apply category remapping to training data so that the
188        # trainer and scaler see zero-indexed codes, matching the embedding
189        # table sizes that _build_feature_configs computed.
190        if self._cat_remap_:
191            X_arr = self._apply_cat_remap(X_arr)
192
193        # Identify continuous column indices for normalisation
194        self._continuous_col_indices = [
195            i for i, cfg in enumerate(feat_configs)
196            if cfg.feature_type == "continuous"
197        ]
198
199        # Normalise continuous features
200        if self.normalize and self._continuous_col_indices:
201            self.scaler_ = StandardScaler()
202            cont_cols = np.array(self._continuous_col_indices)
203            X_arr_norm = X_arr.copy()
204            X_arr_norm[:, cont_cols] = self.scaler_.fit(
205                X_arr[:, cont_cols], feature_names=[names[i] for i in cont_cols]
206            ).transform(X_arr[:, cont_cols])
207        else:
208            X_arr_norm = X_arr
209
210        self._X_train_scaled = X_arr_norm
211
212        # Build interaction configs
213        interaction_configs: List[InteractionConfig] = []
214        for fi, fj in (self.interaction_pairs or []):
215            interaction_configs.append(
216                InteractionConfig(feature_i=fi, feature_j=fj)
217            )
218
219        # Build model
220        self.model_ = ANAMModel(
221            feature_configs=feat_configs,
222            link=self.link,
223            interaction_configs=interaction_configs,
224            hidden_sizes=self.hidden_sizes or [64, 32],
225        )
226
227        # Build trainer
228        train_cfg = TrainingConfig(
229            loss=self.loss,
230            tweedie_p=self.tweedie_p,
231            n_epochs=self.n_epochs,
232            batch_size=self.batch_size,
233            learning_rate=self.learning_rate,
234            lambda_smooth=self.lambda_smooth,
235            lambda_l2=self.lambda_l2,
236            lambda_l1=self.lambda_l1,
237            patience=self.patience,
238            verbose=self.verbose,
239            device=self.device,
240        )
241        trainer = ANAMTrainer(self.model_, train_cfg)
242        self.history_ = trainer.fit(X_arr_norm, y_arr, exposure=w_arr)
243        self.model_ = trainer.model  # may have moved to device
244
245        return self
246
247    def predict(
248        self,
249        X: Union[np.ndarray, pl.DataFrame],
250        exposure: Optional[Union[np.ndarray, pl.Series]] = None,
251    ) -> np.ndarray:
252        """Predict expected mean (mu) for each observation.
253
254        Parameters
255        ----------
256        X:
257            Feature matrix.
258        exposure:
259            Exposure for each row. If None, exposure=1 assumed (predicts
260            per-policy-year rate for log-link models).
261
262        Returns
263        -------
264        np.ndarray
265            Predicted means, shape (n,).
266        """
267        self._check_fitted()
268
269        X_arr = self._to_array(X)
270        X_arr = self._apply_cat_remap(X_arr)
271        X_arr = self._apply_scaling(X_arr)
272
273        if exposure is not None:
274            if isinstance(exposure, pl.Series):
275                exposure = exposure.to_numpy()
276            exp_arr = np.asarray(exposure, dtype=np.float32)
277            log_exp = torch.tensor(np.log(exp_arr.clip(min=1e-8)), dtype=torch.float32)
278        else:
279            log_exp = None
280
281        X_t = torch.tensor(X_arr, dtype=torch.float32)
282
283        self.model_.eval()
284        with torch.no_grad():
285            mu = self.model_(X_t, log_exposure=log_exp)
286
287        return mu.cpu().numpy()
288
289    def score(
290        self,
291        X: Union[np.ndarray, pl.DataFrame],
292        y: Union[np.ndarray, pl.Series],
293        sample_weight: Optional[Union[np.ndarray, pl.Series]] = None,
294    ) -> float:
295        """Return negative mean deviance (higher = better).
296
297        Follows sklearn convention where score() returns a value where
298        higher is better. We negate deviance so that model.score() can
299        be maximised by hyperparameter search.
300        """
301        self._check_fitted()
302
303        y_arr = np.asarray(y if not isinstance(y, pl.Series) else y.to_numpy(), dtype=np.float32)
304        y_pred = self.predict(X, exposure=sample_weight)
305
306        w = None
307        if sample_weight is not None:
308            w = np.asarray(
309                sample_weight if not isinstance(sample_weight, pl.Series)
310                else sample_weight.to_numpy(),
311                dtype=np.float32,
312            )
313
314        dev = compute_deviance_stat(
315            y_arr, y_pred, exposure=w,
316            loss=self.loss, tweedie_p=self.tweedie_p
317        )
318        return -dev
319
320    def shape_functions(
321        self,
322        n_points: int = 200,
323        category_labels: Optional[Dict[str, Dict[int, str]]] = None,
324    ) -> Dict[str, ShapeFunction]:
325        """Extract and cache shape functions for all features.
326
327        Returns a dict mapping feature name to ShapeFunction. Shape
328        functions are evaluated over the observed training data range.
329
330        The cache is keyed on n_points. Calling with a different n_points
331        than the cached value discards the stale cache and recomputes.
332        """
333        self._check_fitted()
334        assert self._X_train_scaled is not None
335
336        # P0-1 fix: invalidate cache when n_points changes.
337        if self._shapes_cache is None or self._shapes_cache_n_points != n_points:
338            self._shapes_cache = extract_shape_functions(
339                self.model_,
340                self._X_train_scaled,
341                n_points=n_points,
342                category_labels=category_labels,
343            )
344            self._shapes_cache_n_points = n_points
345        return self._shapes_cache
346
347    def feature_importance(self) -> pl.DataFrame:
348        """Feature importance as subnetwork weight norms.
349
350        Returns a Polars DataFrame with columns [feature, importance],
351        sorted by importance descending.
352        """
353        self._check_fitted()
354        imp = self.model_.feature_importance()
355        return pl.DataFrame(
356            {"feature": list(imp.keys()), "importance": list(imp.values())}
357        ).sort("importance", descending=True)
358
359    def get_params(self, deep: bool = True) -> Dict[str, Any]:
360        """sklearn get_params for grid search compatibility.
361
362        Returns ALL constructor parameters — including structural ones
363        (feature_configs, interaction_pairs, etc.) — so that
364        sklearn.base.clone() can reconstruct a functionally equivalent
365        model.
366        """
367        # P0-4 fix: include all constructor parameters, not just numerical ones.
368        return {
369            "feature_configs": self.feature_configs,
370            "feature_names": self.feature_names,
371            "categorical_features": self.categorical_features,
372            "monotone_increasing": self.monotone_increasing,
373            "monotone_decreasing": self.monotone_decreasing,
374            "link": self.link,
375            "loss": self.loss,
376            "tweedie_p": self.tweedie_p,
377            "interaction_pairs": self.interaction_pairs,
378            "hidden_sizes": self.hidden_sizes,
379            "n_epochs": self.n_epochs,
380            "batch_size": self.batch_size,
381            "learning_rate": self.learning_rate,
382            "lambda_smooth": self.lambda_smooth,
383            "lambda_l2": self.lambda_l2,
384            "lambda_l1": self.lambda_l1,
385            "patience": self.patience,
386            "normalize": self.normalize,
387            "verbose": self.verbose,
388            "device": self.device,
389        }
390
391    def set_params(self, **params: Any) -> "ANAM":
392        """sklearn set_params for grid search compatibility."""
393        for key, val in params.items():
394            setattr(self, key, val)
395        return self
396
397    # ------------------------------------------------------------------
398    # Internal helpers
399    # ------------------------------------------------------------------
400
401    def _check_fitted(self) -> None:
402        if self.model_ is None:
403            raise RuntimeError("Call fit() before predict() or score().")
404
405    def _to_array(self, X: Union[np.ndarray, pl.DataFrame]) -> np.ndarray:
406        if isinstance(X, pl.DataFrame):
407            return X.to_numpy()
408        return np.asarray(X, dtype=np.float64)
409
410    def _apply_scaling(self, X_arr: np.ndarray) -> np.ndarray:
411        if self.scaler_ is None or not self._continuous_col_indices:
412            return X_arr.astype(np.float32)
413        X_norm = X_arr.copy().astype(np.float32)
414        cont_cols = np.array(self._continuous_col_indices)
415        X_norm[:, cont_cols] = self.scaler_.transform(X_arr[:, cont_cols].astype(np.float64)).astype(np.float32)
416        return X_norm
417
418    def _apply_cat_remap(self, X_arr: np.ndarray) -> np.ndarray:
419        """Apply stored category remappings.
420
421        P0-2 fix: if any categorical features were remapped during fit
422        (because they were not zero-indexed consecutive integers), apply
423        the same remapping here. Called both during fit() (on training
424        data) and during predict() (on new data) so that codes always
425        align with the trained embedding table.
426        """
427        if not self._cat_remap_:
428            return X_arr
429        if self.feature_names_in_ is None:
430            return X_arr
431        X_out = X_arr.copy()
432        for feat_name, remap in self._cat_remap_.items():
433            if feat_name in self.feature_names_in_:
434                col_idx = self.feature_names_in_.index(feat_name)
435                codes = X_out[:, col_idx].astype(int)
436                X_out[:, col_idx] = remap[codes].astype(X_out.dtype)
437        return X_out
438
439    def _validate_input(
440        self,
441        X: Union[np.ndarray, pl.DataFrame],
442        y: Union[np.ndarray, pl.Series],
443        w: Optional[Union[np.ndarray, pl.Series]],
444    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], List[str]]:
445        """Convert inputs to numpy and extract feature names."""
446        if isinstance(X, pl.DataFrame):
447            names = X.columns
448            X_arr = X.to_numpy().astype(np.float64)
449        else:
450            X_arr = np.asarray(X, dtype=np.float64)
451            names = self.feature_names or [f"f{i}" for i in range(X_arr.shape[1])]
452
453        if isinstance(y, pl.Series):
454            y_arr = y.to_numpy().astype(np.float32)
455        else:
456            y_arr = np.asarray(y, dtype=np.float32)
457
458        if w is not None:
459            if isinstance(w, pl.Series):
460                w_arr: Optional[np.ndarray] = w.to_numpy().astype(np.float32)
461            else:
462                w_arr = np.asarray(w, dtype=np.float32)
463        else:
464            w_arr = None
465
466        return X_arr, y_arr, w_arr, list(names)
467
468    def _build_feature_configs(
469        self, names: List[str], X_arr: np.ndarray
470    ) -> List[FeatureConfig]:
471        """Auto-construct FeatureConfig list from names and constraints.
472
473        P0-2 fix: for categorical features, validate that column values are
474        zero-indexed consecutive integers. If not, auto-remap and warn. The
475        remap is stored in self._cat_remap_ and applied to the training
476        array in fit() before the scaler and trainer receive it.
477        """
478        configs: List[FeatureConfig] = []
479        self._cat_remap_ = {}
480
481        for i, name in enumerate(names):
482            if name in (self.categorical_features or []):
483                col = X_arr[:, i].astype(int)
484                unique_vals = np.unique(col)
485                expected = np.arange(len(unique_vals))
486                if not np.array_equal(unique_vals, expected):
487                    # Not zero-indexed consecutive integers. Auto-remap.
488                    warnings.warn(
489                        f"Categorical feature '{name}' has values {unique_vals.tolist()} "
490                        f"which are not zero-indexed consecutive integers. "
491                        f"Auto-remapping to 0..{len(unique_vals) - 1}. "
492                        f"The same remapping will be applied at predict time. "
493                        f"If you rely on specific category indices in post-hoc analysis, "
494                        f"remap the column yourself before passing to fit().",
495                        UserWarning,
496                        stacklevel=3,
497                    )
498                    # Build a lookup array: remap[original_code] = new_code.
499                    max_orig = int(unique_vals.max())
500                    remap = np.full(max_orig + 1, -1, dtype=int)
501                    for new_code, orig_code in enumerate(unique_vals):
502                        remap[orig_code] = new_code
503                    self._cat_remap_[name] = remap
504                    col = remap[col]
505
506                n_cats = int(col.max()) + 1
507                configs.append(
508                    FeatureConfig(
509                        name=name,
510                        feature_type="categorical",
511                        n_categories=n_cats,
512                    )
513                )
514            else:
515                mono: Literal["increasing", "decreasing", "none"] = "none"
516                if name in (self.monotone_increasing or []):
517                    mono = "increasing"
518                elif name in (self.monotone_decreasing or []):
519                    mono = "decreasing"
520                configs.append(
521                    FeatureConfig(
522                        name=name,
523                        feature_type="continuous",
524                        monotonicity=mono,
525                    )
526                )
527        return configs

sklearn-compatible Actuarial Neural Additive Model.

Parameters

feature_configs: List of FeatureConfig objects defining features. If None, all features are treated as continuous with no constraints.
feature_names: Feature names (used when feature_configs is None). Must match columns of X passed to fit().
categorical_features: List of feature names that are categorical. Only used when feature_configs is None.
monotone_increasing: Feature names to constrain as monotone increasing.
monotone_decreasing: Feature names to constrain as monotone decreasing.
link: Link function: 'log' (Poisson/Tweedie), 'identity' (Gaussian), 'logit' (binary).
loss: Distributional loss: 'poisson', 'tweedie', 'gamma', 'mse'.
tweedie_p: Tweedie power parameter (only for loss='tweedie').
interaction_pairs: List of (feature_i, feature_j) tuples for interaction subnetworks.
hidden_sizes: Default hidden layer sizes for subnetworks.
n_epochs: Maximum training epochs.
batch_size: Mini-batch size.
learning_rate: Adam learning rate.
lambda_smooth: Smoothness regularisation weight.
lambda_l2: L2 ridge weight.
lambda_l1: L1 sparsity weight.
patience: Early stopping patience (epochs).
normalize: If True, standardise continuous features before training. The scaler is stored and applied automatically during predict().
verbose: Training verbosity. 0 = silent.
device: 'cpu', 'cuda', or None (auto).

ANAM( feature_configs: Optional[List[FeatureConfig]] = None, feature_names: Optional[List[str]] = None, categorical_features: Optional[List[str]] = None, monotone_increasing: Optional[List[str]] = None, monotone_decreasing: Optional[List[str]] = None, link: Literal['log', 'identity', 'logit'] = 'log', loss: Literal['poisson', 'tweedie', 'gamma', 'mse'] = 'poisson', tweedie_p: float = 1.5, interaction_pairs: Optional[List[Tuple[str, str]]] = None, hidden_sizes: Optional[List[int]] = None, n_epochs: int = 100, batch_size: int = 512, learning_rate: float = 0.001, lambda_smooth: float = 0.0001, lambda_l2: float = 0.0001, lambda_l1: float = 0.0, patience: int = 15, normalize: bool = True, verbose: int = 0, device: Optional[str] = None)
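
Beyond the convenience keywords, features can also be configured explicitly. A sketch with illustrative feature names and sizes; note that the order of feature_configs must match the column order of X, and that the automatic categorical re-indexing only runs when feature_configs is None, so category codes here are assumed to already be zero-indexed.

from insurance_gam.anam import ANAM, FeatureConfig

configs = [
    FeatureConfig(name="driver_age", feature_type="continuous", monotonicity="none"),
    FeatureConfig(name="bonus_malus", feature_type="continuous", monotonicity="increasing"),
    FeatureConfig(name="region", feature_type="categorical", n_categories=5),  # codes already 0..4
]

model = ANAM(
    feature_configs=configs,
    interaction_pairs=[("driver_age", "bonus_malus")],
    loss="poisson",
    link="log",
    lambda_smooth=1e-4,
    patience=10,
)
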
Instance attributes mirror the constructor parameters listed above. After fit() the following are populated:

model_: Optional[ANAMModel]
scaler_: Optional[StandardScaler]
history_: Optional[TrainingHistory]
feature_names_in_: Optional[List[str]]
def fit( self, X: Union[numpy.ndarray, polars.dataframe.frame.DataFrame], y: Union[numpy.ndarray, polars.series.series.Series], sample_weight: Union[numpy.ndarray, polars.series.series.Series, NoneType] = None) -> ANAM:

Fit the ANAM model.

Parameters

X: Feature matrix, shape (n, p). Polars DataFrame accepted.
y: Target vector (claim counts, rates, or severities).
sample_weight: Exposure weights. For frequency models this is policy duration (e.g. years on risk). If None, uniform exposure assumed.

Returns

self
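
A short sketch of a frequency fit, reusing the synthetic X, claims, and exposure from the module-level example; fit() returns self, and the fitted state lives in the underscore-suffixed attributes.

model = ANAM(loss="poisson", link="log", n_epochs=50, patience=10).fit(
    X, claims, sample_weight=exposure
)

print(model.feature_names_in_)       # column names seen during training
assert model.model_ is not None      # fitted ANAMModel
assert model.history_ is not None    # TrainingHistory recorded by ANAMTrainer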

def predict( self, X: Union[numpy.ndarray, polars.dataframe.frame.DataFrame], exposure: Union[numpy.ndarray, polars.series.series.Series, NoneType] = None) -> numpy.ndarray:

Predict expected mean (mu) for each observation.

Parameters

X: Feature matrix.
exposure: Exposure for each row. If None, exposure=1 assumed (predicts per-policy-year rate for log-link models).

Returns

np.ndarray: Predicted means, shape (n,).
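
A sketch of the exposure semantics for a log-link model; X_new and exposure_new stand in for an assumed holdout set.

rate_per_year = model.predict(X_new)                           # exposure = 1
expected_counts = model.predict(X_new, exposure=exposure_new)

# With a log link the offset is multiplicative on the mean, so
# expected_counts equals rate_per_year * exposure_new up to float precision.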

def score( self, X: Union[numpy.ndarray, polars.dataframe.frame.DataFrame], y: Union[numpy.ndarray, polars.series.series.Series], sample_weight: Union[numpy.ndarray, polars.series.series.Series, NoneType] = None) -> float:

Return negative mean deviance (higher = better).

Follows the sklearn convention that score() returns a value where higher is better: deviance is negated so that model.score() can be maximised by hyperparameter search.
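
Because higher is better, the score can drive simple model selection. A sketch assuming pre-split training and validation arrays (X_tr, y_tr, w_tr, X_val, y_val, w_val):

scores = {}
for lam in (1e-5, 1e-4, 1e-3):
    candidate = ANAM(loss="poisson", lambda_smooth=lam, n_epochs=30)
    candidate.fit(X_tr, y_tr, sample_weight=w_tr)
    scores[lam] = candidate.score(X_val, y_val, sample_weight=w_val)
best_lambda = max(scores, key=scores.get)   # highest score = lowest mean deviance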

def shape_functions( self, n_points: int = 200, category_labels: Optional[Dict[str, Dict[int, str]]] = None) -> Dict[str, ShapeFunction]:

Extract and cache shape functions for all features.

Returns a dict mapping feature name to ShapeFunction. Shape functions are evaluated over the observed training data range.

The cache is keyed on n_points. Calling with a different n_points than the cached value discards the stale cache and recomputes.
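
A sketch of typical use on a fitted instance; given the caching described above, only the second call below triggers recomputation.

shapes = model.shape_functions(n_points=100)    # dict: feature name -> ShapeFunction
print(sorted(shapes))                           # fitted feature names
finer = model.shape_functions(n_points=500)     # different n_points: recomputed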

def feature_importance(self) -> polars.dataframe.frame.DataFrame:

Feature importance as subnetwork weight norms.

Returns a Polars DataFrame with columns [feature, importance], sorted by importance descending.
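
A one-line usage sketch on a fitted instance:

importance = model.feature_importance()   # Polars DataFrame, sorted descending
print(importance.head(5))                 # top five features by weight norm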

def get_params(self, deep: bool = True) -> Dict[str, Any]:

sklearn get_params for grid search compatibility.

Returns ALL constructor parameters — including structural ones (feature_configs, interaction_pairs, etc.) — so that sklearn.base.clone() can reconstruct a functionally equivalent model.
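
Because every constructor argument round-trips through get_params()/set_params(), the estimator works with sklearn.base.clone, which rebuilds an unfitted copy with identical hyperparameters:

from sklearn.base import clone

fresh = clone(model)                   # same hyperparameters, no fitted state
fresh.set_params(learning_rate=5e-4)   # adjust before refitting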

def set_params(self, **params: Any) -> ANAM:

sklearn set_params for grid search compatibility.

class ANAMModel(torch.nn.modules.module.Module):
 85class ANAMModel(nn.Module):
 86    """Actuarial Neural Additive Model.
 87
 88    Orchestrates one subnetwork per feature plus optional pairwise
 89    interaction networks. The output is:
 90
 91        eta = bias + sum_i f_i(x_i) + sum_{(i,j)} g_{ij}(x_i, x_j)
 92        mu  = link_inverse(eta + log_offset)
 93
 94    Parameters
 95    ----------
 96    feature_configs:
 97        Ordered list of FeatureConfig objects. The column order in X arrays
 98        passed to forward() must match this list.
 99    link:
100        Link function. 'log' for Poisson/Tweedie/Gamma, 'identity' for
101        Gaussian, 'logit' for binary.
102    interaction_configs:
103        Optional list of InteractionConfig for pairwise terms.
104    hidden_sizes:
105        Default hidden sizes for all subnetworks (overridden per-feature
106        by FeatureConfig.hidden_sizes).
107    dropout:
108        Dropout rate applied within subnetworks.
109    """
110
111    def __init__(
112        self,
113        feature_configs: List[FeatureConfig],
114        link: LinkFunction = "log",
115        interaction_configs: Optional[List[InteractionConfig]] = None,
116        hidden_sizes: Optional[List[int]] = None,
117        dropout: float = 0.0,
118    ) -> None:
119        super().__init__()
120
121        self.feature_configs = feature_configs
122        self.link = link
123        self.interaction_configs = interaction_configs or []
124
125        # Map feature names to column indices
126        self.feature_name_to_idx: Dict[str, int] = {
127            cfg.name: i for i, cfg in enumerate(feature_configs)
128        }
129
130        default_hidden = hidden_sizes or [64, 32]
131
132        # Build per-feature subnetworks
133        feature_nets: Dict[str, nn.Module] = {}
134        for cfg in feature_configs:
135            net_hidden = cfg.hidden_sizes or default_hidden
136            if cfg.feature_type == "continuous":
137                feature_nets[cfg.name] = FeatureNetwork(
138                    hidden_sizes=net_hidden,
139                    monotonicity=cfg.monotonicity,
140                    dropout=dropout,
141                )
142            elif cfg.feature_type == "categorical":
143                if cfg.n_categories is None:
144                    raise ValueError(
145                        f"Feature '{cfg.name}' is categorical but n_categories is not set."
146                    )
147                feature_nets[cfg.name] = CategoricalFeatureNetwork(
148                    n_categories=cfg.n_categories,
149                    embedding_dim=cfg.embedding_dim,
150                    hidden_sizes=net_hidden,
151                    dropout=dropout,
152                )
153            else:
154                raise ValueError(f"Unknown feature_type: {cfg.feature_type!r}")
155
156        self.feature_nets = nn.ModuleDict(feature_nets)
157
158        # Build interaction subnetworks
159        interaction_nets: Dict[str, nn.Module] = {}
160        for icfg in self.interaction_configs:
161            key = f"{icfg.feature_i}_x_{icfg.feature_j}"
162            i_idx = self.feature_name_to_idx[icfg.feature_i]
163            j_idx = self.feature_name_to_idx[icfg.feature_j]
164            i_hidden = icfg.hidden_sizes or [32, 16]
165            interaction_nets[key] = InteractionNetwork(
166                feature_indices=(i_idx, j_idx),
167                hidden_sizes=i_hidden,
168                dropout=dropout,
169            )
170
171        self.interaction_nets = nn.ModuleDict(interaction_nets)
172
173        # Scalar bias (learnable)
174        self.bias = nn.Parameter(torch.zeros(1))
175
176    def forward(
177        self,
178        X: torch.Tensor,
179        log_exposure: Optional[torch.Tensor] = None,
180    ) -> torch.Tensor:
181        """Forward pass through all subnetworks.
182
183        Parameters
184        ----------
185        X:
186            Feature matrix, shape (batch, n_features). Continuous features
187            should be pre-normalised. Categorical features should be integer
188            indices (will be cast to long internally).
189        log_exposure:
190            Log of exposure (e.g. log policy duration in years), shape
191            (batch,). Added as offset to the linear predictor before the
192            link function. If None, no offset applied (equivalent to
193            exposure=1 for all observations).
194
195        Returns
196        -------
197        torch.Tensor
198            Predicted means mu, shape (batch,).
199        """
200        batch_size = X.shape[0]
201
202        # Accumulate linear predictor starting from bias
203        eta = self.bias.expand(batch_size)
204
205        # Feature subnetwork contributions
206        for i, cfg in enumerate(self.feature_configs):
207            x_i = X[:, i]
208            net = self.feature_nets[cfg.name]
209
210            if cfg.feature_type == "categorical":
211                contrib = net(x_i.long()).squeeze(-1)
212            else:
213                contrib = net(x_i.float()).squeeze(-1)
214
215            eta = eta + contrib
216
217        # Interaction subnetwork contributions
218        for icfg in self.interaction_configs:
219            key = f"{icfg.feature_i}_x_{icfg.feature_j}"
220            i_idx = self.feature_name_to_idx[icfg.feature_i]
221            j_idx = self.feature_name_to_idx[icfg.feature_j]
222
223            x_i = X[:, i_idx].float()
224            x_j = X[:, j_idx].float()
225
226            contrib = self.interaction_nets[key](x_i, x_j).squeeze(-1)
227            eta = eta + contrib
228
229        # Exposure offset
230        if log_exposure is not None:
231            eta = eta + log_exposure
232
233        # Apply link inverse
234        return self._link_inverse(eta)
235
236    def _link_inverse(self, eta: torch.Tensor) -> torch.Tensor:
237        """Apply inverse link function to linear predictor."""
238        if self.link == "log":
239            return torch.exp(eta)
240        elif self.link == "identity":
241            return eta
242        elif self.link == "logit":
243            return torch.sigmoid(eta)
244        else:
245            raise ValueError(f"Unknown link function: {self.link!r}")
246
247    def linear_predictor(
248        self,
249        X: torch.Tensor,
250        log_exposure: Optional[torch.Tensor] = None,
251    ) -> torch.Tensor:
252        """Return eta (linear predictor) without applying link inverse.
253
254        Useful for inspecting additive contributions before exponentiation.
255        """
256        batch_size = X.shape[0]
257        eta = self.bias.expand(batch_size)
258
259        for i, cfg in enumerate(self.feature_configs):
260            x_i = X[:, i]
261            net = self.feature_nets[cfg.name]
262            if cfg.feature_type == "categorical":
263                contrib = net(x_i.long()).squeeze(-1)
264            else:
265                contrib = net(x_i.float()).squeeze(-1)
266            eta = eta + contrib
267
268        for icfg in self.interaction_configs:
269            key = f"{icfg.feature_i}_x_{icfg.feature_j}"
270            i_idx = self.feature_name_to_idx[icfg.feature_i]
271            j_idx = self.feature_name_to_idx[icfg.feature_j]
272            contrib = self.interaction_nets[key](
273                X[:, i_idx].float(), X[:, j_idx].float()
274            ).squeeze(-1)
275            eta = eta + contrib
276
277        if log_exposure is not None:
278            eta = eta + log_exposure
279
280        return eta
281
282    def feature_contribution(
283        self, X: torch.Tensor, feature_name: str
284    ) -> torch.Tensor:
285        """Return the contribution of a single feature for each observation.
286
287        Useful for explaining individual predictions: the contribution from
288        feature i is exactly f_i(x_i), the output of that subnetwork.
289
290        Parameters
291        ----------
292        X:
293            Feature matrix (batch, n_features).
294        feature_name:
295            Name of the feature to inspect.
296
297        Returns
298        -------
299        torch.Tensor
300            Shape (batch,). Values of f_i(x_i) for each row.
301        """
302        idx = self.feature_name_to_idx[feature_name]
303        cfg = self.feature_configs[idx]
304        net = self.feature_nets[feature_name]
305        x_i = X[:, idx]
306
307        if cfg.feature_type == "categorical":
308            return net(x_i.long()).squeeze(-1)
309        else:
310            return net(x_i.float()).squeeze(-1)
311
312    def project_monotone_weights(self) -> None:
313        """Enforce monotonicity constraints on all relevant subnetworks.
314
315        Call this after optimizer.step() in the training loop.
316        Does nothing for non-monotone and categorical features.
317        """
318        for cfg in self.feature_configs:
319            if cfg.feature_type == "continuous" and cfg.monotonicity != "none":
320                net = self.feature_nets[cfg.name]
321                assert isinstance(net, FeatureNetwork)
322                net.project_weights()
323
324    def feature_importance(self) -> Dict[str, float]:
325        """Estimate feature importance as the L2 norm of output layer weights.
326
327        Larger norm = larger potential contribution from that feature. This
328        is a quick heuristic for feature selection — not a replacement for
329        proper permutation importance or SHAP values on the additive model.
330
331        Returns
332        -------
333        Dict[str, float]
334            Feature name -> importance score, sorted descending.
335        """
336        importances: Dict[str, float] = {}
337
338        for name, net in self.feature_nets.items():
339            total_norm = 0.0
340            for param in net.parameters():
341                total_norm += param.data.norm(2).item() ** 2
342            importances[name] = float(total_norm ** 0.5)
343
344        return dict(sorted(importances.items(), key=lambda x: x[1], reverse=True))
345
346    @property
347    def n_features(self) -> int:
348        """Number of features this model was built for."""
349        return len(self.feature_configs)
350
351    @property
352    def feature_names(self) -> List[str]:
353        """Ordered list of feature names."""
354        return [cfg.name for cfg in self.feature_configs]

Actuarial Neural Additive Model.

Orchestrates one subnetwork per feature plus optional pairwise interaction networks. The output is:

eta = bias + sum_i f_i(x_i) + sum_{(i,j)} g_{ij}(x_i, x_j)
mu  = link_inverse(eta + log_offset)

Parameters

feature_configs: Ordered list of FeatureConfig objects. The column order in X arrays passed to forward() must match this list.
link: Link function. 'log' for Poisson/Tweedie/Gamma, 'identity' for Gaussian, 'logit' for binary.
interaction_configs: Optional list of InteractionConfig for pairwise terms.
hidden_sizes: Default hidden sizes for all subnetworks (overridden per-feature by FeatureConfig.hidden_sizes).
dropout: Dropout rate applied within subnetworks.

ANAMModel( feature_configs: List[FeatureConfig], link: Literal['log', 'identity', 'logit'] = 'log', interaction_configs: Optional[List[InteractionConfig]] = None, hidden_sizes: Optional[List[int]] = None, dropout: float = 0.0)

Instance attributes: feature_configs, interaction_configs, feature_name_to_idx (Dict[str, int]), the feature_nets and interaction_nets module dicts, and the learnable scalar bias.
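
A sketch of constructing the torch module directly, bypassing the sklearn wrapper; the feature names, layer sizes, and random inputs are illustrative only.

import torch
from insurance_gam.anam import ANAMModel, FeatureConfig, InteractionConfig

configs = [
    FeatureConfig(name="driver_age", feature_type="continuous", monotonicity="none"),
    FeatureConfig(name="vehicle_power", feature_type="continuous", monotonicity="none"),
    FeatureConfig(name="region", feature_type="categorical", n_categories=5),
]
net = ANAMModel(
    feature_configs=configs,
    link="log",
    interaction_configs=[InteractionConfig(feature_i="driver_age", feature_j="vehicle_power")],
    hidden_sizes=[32, 16],
)

X = torch.randn(8, 3)                        # column order matches configs
X[:, 2] = torch.randint(0, 5, (8,)).float()  # integer codes for "region"
net.eval()
mu = net(X)                                  # predicted means, shape (8,)
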
def forward( self, X: torch.Tensor, log_exposure: Optional[torch.Tensor] = None) -> torch.Tensor:

Forward pass through all subnetworks.

Parameters

X: Feature matrix, shape (batch, n_features). Continuous features should be pre-normalised. Categorical features should be integer indices (will be cast to long internally).
log_exposure: Log of exposure (e.g. log policy duration in years), shape (batch,). Added as offset to the linear predictor before the link function. If None, no offset applied (equivalent to exposure=1 for all observations).

Returns

torch.Tensor: Predicted means mu, shape (batch,).
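
Continuing the construction sketch above: with link="log" the exposure offset is multiplicative on the predicted mean (the assertion assumes the default dropout of 0.0 and eval mode, so the forward pass is deterministic).

exposure = torch.full((8,), 0.5)             # half a policy-year each
with torch.no_grad():
    mu_unit = net(X)                         # no offset, i.e. exposure of 1
    mu_half = net(X, log_exposure=torch.log(exposure))
assert torch.allclose(mu_half, 0.5 * mu_unit, rtol=1e-5)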

def linear_predictor( self, X: torch.Tensor, log_exposure: Optional[torch.Tensor] = None) -> torch.Tensor:

Return eta (linear predictor) without applying link inverse.

Useful for inspecting additive contributions before exponentiation.
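
Continuing the sketch above, a quick consistency check between the linear predictor and forward(), assuming the log link's inverse is the exponential:

with torch.no_grad():
    eta = model.linear_predictor(X, log_exposure=torch.log(exposure))
    mu = model(X, log_exposure=torch.log(exposure))
    print(torch.allclose(torch.exp(eta), mu))  # expected True when link='log'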

def feature_contribution(self, X: torch.Tensor, feature_name: str) -> torch.Tensor:
282    def feature_contribution(
283        self, X: torch.Tensor, feature_name: str
284    ) -> torch.Tensor:
285        """Return the contribution of a single feature for each observation.
286
287        Useful for explaining individual predictions: the contribution from
288        feature i is exactly f_i(x_i), the output of that subnetwork.
289
290        Parameters
291        ----------
292        X:
293            Feature matrix (batch, n_features).
294        feature_name:
295            Name of the feature to inspect.
296
297        Returns
298        -------
299        torch.Tensor
300            Shape (batch,). Values of f_i(x_i) for each row.
301        """
302        idx = self.feature_name_to_idx[feature_name]
303        cfg = self.feature_configs[idx]
304        net = self.feature_nets[feature_name]
305        x_i = X[:, idx]
306
307        if cfg.feature_type == "categorical":
308            return net(x_i.long()).squeeze(-1)
309        else:
310            return net(x_i.float()).squeeze(-1)

Return the contribution of a single feature for each observation.

Useful for explaining individual predictions: the contribution from feature i is exactly f_i(x_i), the output of that subnetwork.

Parameters

X:
    Feature matrix (batch, n_features).
feature_name:
    Name of the feature to inspect.

Returns

torch.Tensor
    Shape (batch,). Values of f_i(x_i) for each row.
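
A sketch of a per-feature explanation for a single policy, continuing the running example; on a log link each exp(f_i(x_i)) reads as a multiplicative relativity.

with torch.no_grad():
    for name in model.feature_names:
        f_i = model.feature_contribution(X[:1], name)
        print(f"{name}: contribution {f_i.item():+.4f}, relativity {f_i.exp().item():.3f}")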

def project_monotone_weights(self) -> None:
312    def project_monotone_weights(self) -> None:
313        """Enforce monotonicity constraints on all relevant subnetworks.
314
315        Call this after optimizer.step() in the training loop.
316        Does nothing for non-monotone and categorical features.
317        """
318        for cfg in self.feature_configs:
319            if cfg.feature_type == "continuous" and cfg.monotonicity != "none":
320                net = self.feature_nets[cfg.name]
321                assert isinstance(net, FeatureNetwork)
322                net.project_weights()

Enforce monotonicity constraints on all relevant subnetworks.

Call this after optimizer.step() in the training loop. Does nothing for non-monotone and categorical features.
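
A hand-rolled training step illustrating the pattern this method supports (ANAMTrainer already applies the projection internally after every optimiser step); the targets and loss below are illustrative only.

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
y = torch.tensor([0.0, 1.0])

mu = model(X, log_exposure=torch.log(exposure))
loss = ((mu - y) ** 2).mean()        # any differentiable loss works for this sketch

optimizer.zero_grad()
loss.backward()
optimizer.step()
model.project_monotone_weights()     # re-impose monotonicity after the gradient step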

def feature_importance(self) -> Dict[str, float]:
324    def feature_importance(self) -> Dict[str, float]:
325        """Estimate feature importance as the L2 norm of each subnetwork's parameters.
326
327        Larger norm = larger potential contribution from that feature. This
328        is a quick heuristic for feature selection — not a replacement for
329        proper permutation importance or SHAP values on the additive model.
330
331        Returns
332        -------
333        Dict[str, float]
334            Feature name -> importance score, sorted descending.
335        """
336        importances: Dict[str, float] = {}
337
338        for name, net in self.feature_nets.items():
339            total_norm = 0.0
340            for param in net.parameters():
341                total_norm += param.data.norm(2).item() ** 2
342            importances[name] = float(total_norm ** 0.5)
343
344        return dict(sorted(importances.items(), key=lambda x: x[1], reverse=True))

Estimate feature importance as the L2 norm of each subnetwork's parameters.

Larger norm = larger potential contribution from that feature. This is a quick heuristic for feature selection — not a replacement for proper permutation importance or SHAP values on the additive model.

Returns

Dict[str, float] Feature name -> importance score, sorted descending.
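
Continuing the running example, a short sketch printing the heuristic ranking:

for name, score in model.feature_importance().items():
    print(f"{name:15s} {score:8.3f}")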

n_features: int
346    @property
347    def n_features(self) -> int:
348        """Number of features this model was built for."""
349        return len(self.feature_configs)

Number of features this model was built for.

feature_names: List[str]
351    @property
352    def feature_names(self) -> List[str]:
353        """Ordered list of feature names."""
354        return [cfg.name for cfg in self.feature_configs]

Ordered list of feature names.

@dataclass
class FeatureConfig:
40@dataclass
41class FeatureConfig:
42    """Configuration for a single feature.
43
44    Parameters
45    ----------
46    name:
47        Feature name (used in output tables and plots).
48    feature_type:
49        'continuous' or 'categorical'.
50    monotonicity:
51        Monotonicity constraint. Only applies to continuous features.
52    n_categories:
53        Number of categories (required if feature_type='categorical').
54    embedding_dim:
55        Embedding dimension for categorical features.
56    hidden_sizes:
57        Hidden layer widths for this feature's subnetwork.
58    """
59
60    name: str
61    feature_type: Literal["continuous", "categorical"] = "continuous"
62    monotonicity: MonotonicityDirection = "none"
63    n_categories: Optional[int] = None
64    embedding_dim: int = 4
65    hidden_sizes: Optional[List[int]] = None

Configuration for a single feature.

Parameters

name:
    Feature name (used in output tables and plots).
feature_type:
    'continuous' or 'categorical'.
monotonicity:
    Monotonicity constraint. Only applies to continuous features.
n_categories:
    Number of categories (required if feature_type='categorical').
embedding_dim:
    Embedding dimension for categorical features.
hidden_sizes:
    Hidden layer widths for this feature's subnetwork.

FeatureConfig( name: str, feature_type: Literal['continuous', 'categorical'] = 'continuous', monotonicity: Literal['increasing', 'decreasing', 'none'] = 'none', n_categories: Optional[int] = None, embedding_dim: int = 4, hidden_sizes: Optional[List[int]] = None)
name: str
feature_type: Literal['continuous', 'categorical'] = 'continuous'
monotonicity: Literal['increasing', 'decreasing', 'none'] = 'none'
n_categories: Optional[int] = None
embedding_dim: int = 4
hidden_sizes: Optional[List[int]] = None
@dataclass
class InteractionConfig:
68@dataclass
69class InteractionConfig:
70    """Configuration for a pairwise interaction subnetwork.
71
72    Parameters
73    ----------
74    feature_i, feature_j:
75        Names of the two interacting features. Must be in the feature list.
76    hidden_sizes:
77        Hidden layer widths for the interaction subnetwork.
78    """
79
80    feature_i: str
81    feature_j: str
82    hidden_sizes: Optional[List[int]] = None

Configuration for a pairwise interaction subnetwork.

Parameters

feature_i, feature_j:
    Names of the two interacting features. Must be in the feature list.
hidden_sizes:
    Hidden layer widths for the interaction subnetwork.

InteractionConfig( feature_i: str, feature_j: str, hidden_sizes: Optional[List[int]] = None)
feature_i: str
feature_j: str
hidden_sizes: Optional[List[int]] = None
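
A small sketch of declaring an interaction between two features that already appear in the feature list; the names and layer widths are illustrative.

from insurance_gam.anam import FeatureConfig, InteractionConfig

feature_configs = [FeatureConfig(name="driver_age"), FeatureConfig(name="vehicle_power")]
interaction_configs = [
    InteractionConfig(feature_i="driver_age", feature_j="vehicle_power", hidden_sizes=[16, 8]),
]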
class FeatureNetwork(torch.nn.modules.module.Module):
 66class FeatureNetwork(nn.Module):
 67    """Single-feature MLP subnetwork.
 68
 69    Takes a scalar input (one feature value) and outputs a scalar
 70    contribution f_i(x_i). The contribution is offset by its mean over a
 71    representative sample so that the bias term in the full model captures
 72    the population mean prediction.
 73
 74    Parameters
 75    ----------
 76    hidden_sizes:
 77        Width of each hidden layer. E.g. [64, 32] gives two hidden layers.
 78    activation:
 79        'relu' (default) or 'exu'. ReLU is more stable; ExU more expressive.
 80        When 'exu' is specified, ExUActivation replaces the nn.Linear +
 81        nn.ReLU pair for each hidden layer (ExU fuses the linear transform
 82        into the activation).
 83    monotonicity:
 84        'increasing', 'decreasing', or 'none'. Enforced by projecting
 85        weight matrices onto the non-negative (or non-positive) orthant
 86        after each gradient step via project_weights().
 87    dropout:
 88        Dropout rate applied between hidden layers. 0.0 disables dropout.
 89    """
 90
 91    def __init__(
 92        self,
 93        hidden_sizes: list[int] | None = None,
 94        activation: Literal["relu", "exu"] = "relu",
 95        monotonicity: MonotonicityDirection = "none",
 96        dropout: float = 0.0,
 97    ) -> None:
 98        super().__init__()
 99
100        if hidden_sizes is None:
101            hidden_sizes = [64, 32]
102
103        self.hidden_sizes = hidden_sizes
104        self.activation_name = activation
105        self.monotonicity = monotonicity
106        self.dropout_rate = dropout
107
108        layers: list[nn.Module] = []
109        in_dim = 1
110
111        for i, out_dim in enumerate(hidden_sizes):
112            if activation == "relu":
113                # Standard linear + ReLU pair.
114                linear = nn.Linear(in_dim, out_dim)
115                # Weight initialisation: smaller init stabilises training with
116                # many subnetworks summing together.
117                nn.init.xavier_uniform_(linear.weight, gain=0.5)
118                nn.init.zeros_(linear.bias)
119                layers.append(linear)
120                layers.append(nn.ReLU())
121            elif activation == "exu":
122                # P0-3 fix: wire up ExUActivation properly.
123                # ExU fuses the linear transform into the activation — no
124                # preceding nn.Linear. The ExUActivation module takes
125                # (in_features, out_features) and outputs (batch, out_features).
126                layers.append(ExUActivation(in_features=in_dim, out_features=out_dim))
127            else:
128                raise ValueError(f"Unknown activation: {activation!r}")
129
130            if dropout > 0.0 and i < len(hidden_sizes) - 1:
131                layers.append(nn.Dropout(p=dropout))
132
133            in_dim = out_dim
134
135        # Output layer: single scalar
136        output_layer = nn.Linear(in_dim, 1)
137        nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
138        nn.init.zeros_(output_layer.bias)
139        layers.append(output_layer)
140
141        self.network = nn.Sequential(*layers)
142
143    def forward(self, x: torch.Tensor) -> torch.Tensor:
144        """Forward pass.
145
146        Parameters
147        ----------
148        x:
149            Shape (batch,) or (batch, 1). Single feature values.
150
151        Returns
152        -------
153        torch.Tensor
154            Shape (batch, 1). Feature contribution f_i(x_i).
155        """
156        if x.dim() == 1:
157            x = x.unsqueeze(-1)
158        return self.network(x)
159
160    def project_weights(self) -> None:
161        """Enforce monotonicity by clamping weight matrices in-place.
162
163        For a ReLU network, all-non-negative weights guarantee a
164        non-decreasing function (clamping is the Euclidean projection onto
165        the non-negative orthant). For decreasing, clamp to non-positive.
166
167        Call this after every optimizer.step() during training.
168        """
169        if self.monotonicity == "none":
170            return
171
172        for module in self.network.modules():
173            if isinstance(module, nn.Linear):
174                if self.monotonicity == "increasing":
175                    module.weight.data.clamp_(min=0.0)
176                elif self.monotonicity == "decreasing":
177                    module.weight.data.clamp_(max=0.0)
178            elif isinstance(module, ExUActivation):
179                if self.monotonicity == "increasing":
180                    module.weights.data.clamp_(min=0.0)
181                elif self.monotonicity == "decreasing":
182                    module.weights.data.clamp_(max=0.0)
183
184    def feature_range(
185        self, x_min: float, x_max: float, n_points: int = 200
186    ) -> tuple[torch.Tensor, torch.Tensor]:
187        """Evaluate shape function over a grid.
188
189        Returns (x_grid, f_values) for plotting shape curves.
190        """
191        self.eval()
192        with torch.no_grad():
193            x_grid = torch.linspace(x_min, x_max, n_points)
194            f_values = self.forward(x_grid).squeeze(-1)
195        return x_grid, f_values

Single-feature MLP subnetwork.

Takes a scalar input (one feature value) and outputs a scalar contribution f_i(x_i). The contribution is offset by its mean over a representative sample so that the bias term in the full model captures the population mean prediction.

Parameters

hidden_sizes:
    Width of each hidden layer. E.g. [64, 32] gives two hidden layers.
activation:
    'relu' (default) or 'exu'. ReLU is more stable; ExU more expressive. When 'exu' is specified, ExUActivation replaces the nn.Linear + nn.ReLU pair for each hidden layer (ExU fuses the linear transform into the activation).
monotonicity:
    'increasing', 'decreasing', or 'none'. Enforced by projecting weight matrices onto the non-negative (or non-positive) orthant after each gradient step via project_weights().
dropout:
    Dropout rate applied between hidden layers. 0.0 disables dropout.

FeatureNetwork( hidden_sizes: list[int] | None = None, activation: Literal['relu', 'exu'] = 'relu', monotonicity: Literal['increasing', 'decreasing', 'none'] = 'none', dropout: float = 0.0)
 91    def __init__(
 92        self,
 93        hidden_sizes: list[int] | None = None,
 94        activation: Literal["relu", "exu"] = "relu",
 95        monotonicity: MonotonicityDirection = "none",
 96        dropout: float = 0.0,
 97    ) -> None:
 98        super().__init__()
 99
100        if hidden_sizes is None:
101            hidden_sizes = [64, 32]
102
103        self.hidden_sizes = hidden_sizes
104        self.activation_name = activation
105        self.monotonicity = monotonicity
106        self.dropout_rate = dropout
107
108        layers: list[nn.Module] = []
109        in_dim = 1
110
111        for i, out_dim in enumerate(hidden_sizes):
112            if activation == "relu":
113                # Standard linear + ReLU pair.
114                linear = nn.Linear(in_dim, out_dim)
115                # Weight initialisation: smaller init stabilises training with
116                # many subnetworks summing together.
117                nn.init.xavier_uniform_(linear.weight, gain=0.5)
118                nn.init.zeros_(linear.bias)
119                layers.append(linear)
120                layers.append(nn.ReLU())
121            elif activation == "exu":
122                # P0-3 fix: wire up ExUActivation properly.
123                # ExU fuses the linear transform into the activation — no
124                # preceding nn.Linear. The ExUActivation module takes
125                # (in_features, out_features) and outputs (batch, out_features).
126                layers.append(ExUActivation(in_features=in_dim, out_features=out_dim))
127            else:
128                raise ValueError(f"Unknown activation: {activation!r}")
129
130            if dropout > 0.0 and i < len(hidden_sizes) - 1:
131                layers.append(nn.Dropout(p=dropout))
132
133            in_dim = out_dim
134
135        # Output layer: single scalar
136        output_layer = nn.Linear(in_dim, 1)
137        nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
138        nn.init.zeros_(output_layer.bias)
139        layers.append(output_layer)
140
141        self.network = nn.Sequential(*layers)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

hidden_sizes
activation_name
monotonicity
dropout_rate
network
def forward(self, x: torch.Tensor) -> torch.Tensor:
143    def forward(self, x: torch.Tensor) -> torch.Tensor:
144        """Forward pass.
145
146        Parameters
147        ----------
148        x:
149            Shape (batch,) or (batch, 1). Single feature values.
150
151        Returns
152        -------
153        torch.Tensor
154            Shape (batch, 1). Feature contribution f_i(x_i).
155        """
156        if x.dim() == 1:
157            x = x.unsqueeze(-1)
158        return self.network(x)

Forward pass.

Parameters

x: Shape (batch,) or (batch, 1). Single feature values.

Returns

torch.Tensor Shape (batch, 1). Feature contribution f_i(x_i).

def project_weights(self) -> None:
160    def project_weights(self) -> None:
161        """Enforce monotonicity by clamping weight matrices in-place.
162
163        For a ReLU network, all-non-negative weights guarantee a
164        non-decreasing function (clamping is the Euclidean projection onto
165        the non-negative orthant). For decreasing, clamp to non-positive.
166
167        Call this after every optimizer.step() during training.
168        """
169        if self.monotonicity == "none":
170            return
171
172        for module in self.network.modules():
173            if isinstance(module, nn.Linear):
174                if self.monotonicity == "increasing":
175                    module.weight.data.clamp_(min=0.0)
176                elif self.monotonicity == "decreasing":
177                    module.weight.data.clamp_(max=0.0)
178            elif isinstance(module, ExUActivation):
179                if self.monotonicity == "increasing":
180                    module.weights.data.clamp_(min=0.0)
181                elif self.monotonicity == "decreasing":
182                    module.weights.data.clamp_(max=0.0)

Enforce monotonicity by clamping weight matrices in-place.

For a ReLU network, all-non-negative weights guarantee a non-decreasing function (clamping is the Euclidean projection onto the non-negative orthant). For decreasing, clamp to non-positive.

Call this after every optimizer.step() during training.

def feature_range( self, x_min: float, x_max: float, n_points: int = 200) -> tuple[torch.Tensor, torch.Tensor]:
184    def feature_range(
185        self, x_min: float, x_max: float, n_points: int = 200
186    ) -> tuple[torch.Tensor, torch.Tensor]:
187        """Evaluate shape function over a grid.
188
189        Returns (x_grid, f_values) for plotting shape curves.
190        """
191        self.eval()
192        with torch.no_grad():
193            x_grid = torch.linspace(x_min, x_max, n_points)
194            f_values = self.forward(x_grid).squeeze(-1)
195        return x_grid, f_values

Evaluate shape function over a grid.

Returns (x_grid, f_values) for plotting shape curves.
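
A sketch of evaluating and plotting a learned shape function over its observed range. matplotlib is assumed to be available; it is not a dependency of this module.

import matplotlib.pyplot as plt
from insurance_gam.anam import FeatureNetwork

net = FeatureNetwork(hidden_sizes=[64, 32], monotonicity="increasing")
x_grid, f_vals = net.feature_range(-3.0, 3.0, n_points=200)

plt.plot(x_grid.numpy(), f_vals.numpy())
plt.xlabel("standardised feature value")
plt.ylabel("f_i(x)")
plt.show()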

class CategoricalFeatureNetwork(torch.nn.modules.module.Module):
198class CategoricalFeatureNetwork(nn.Module):
199    """Subnetwork for categorical features using an embedding layer.
200
201    Each category maps to a learned embedding (dim: embedding_dim), which
202    then passes through a small MLP. This allows the model to discover
203    structure in category space (e.g. similar vehicle groups cluster
204    together) without requiring manual one-hot encoding.
205
206    For regulatory documentation, the output for each category level can be
207    extracted as a relativity table — exactly like a GLM factor table.
208    """
209
210    def __init__(
211        self,
212        n_categories: int,
213        embedding_dim: int = 4,
214        hidden_sizes: list[int] | None = None,
215        dropout: float = 0.0,
216    ) -> None:
217        super().__init__()
218
219        if hidden_sizes is None:
220            hidden_sizes = [32]
221
222        self.n_categories = n_categories
223        self.embedding_dim = embedding_dim
224
225        self.embedding = nn.Embedding(n_categories, embedding_dim)
226        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.1)
227
228        layers: list[nn.Module] = []
229        in_dim = embedding_dim
230
231        for i, out_dim in enumerate(hidden_sizes):
232            linear = nn.Linear(in_dim, out_dim)
233            nn.init.xavier_uniform_(linear.weight, gain=0.5)
234            nn.init.zeros_(linear.bias)
235            layers.append(linear)
236            layers.append(nn.ReLU())
237            if dropout > 0.0 and i < len(hidden_sizes) - 1:
238                layers.append(nn.Dropout(p=dropout))
239            in_dim = out_dim
240
241        output_layer = nn.Linear(in_dim, 1)
242        nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
243        nn.init.zeros_(output_layer.bias)
244        layers.append(output_layer)
245
246        self.network = nn.Sequential(*layers)
247
248    def forward(self, x: torch.Tensor) -> torch.Tensor:
249        """Forward pass.
250
251        Parameters
252        ----------
253        x:
254            Shape (batch,). Integer category indices.
255
256        Returns
257        -------
258        torch.Tensor
259            Shape (batch, 1). Category contribution.
260        """
261        embedded = self.embedding(x.long())  # (batch, embedding_dim)
262        return self.network(embedded)
263
264    def category_table(self) -> dict[int, float]:
265        """Extract per-category contributions as a relativity table.
266
267        Returns a dict mapping category index to scalar contribution value.
268        Useful for regulatory documentation and GLM comparison.
269        """
270        self.eval()
271        with torch.no_grad():
272            indices = torch.arange(self.n_categories)
273            contribs = self.forward(indices).squeeze(-1)
274        return {int(i): float(v) for i, v in enumerate(contribs)}

Subnetwork for categorical features using an embedding layer.

Each category maps to a learned embedding (dim: embedding_dim), which then passes through a small MLP. This allows the model to discover structure in category space (e.g. similar vehicle groups cluster together) without requiring manual one-hot encoding.

For regulatory documentation, the output for each category level can be extracted as a relativity table — exactly like a GLM factor table.

CategoricalFeatureNetwork( n_categories: int, embedding_dim: int = 4, hidden_sizes: list[int] | None = None, dropout: float = 0.0)
210    def __init__(
211        self,
212        n_categories: int,
213        embedding_dim: int = 4,
214        hidden_sizes: list[int] | None = None,
215        dropout: float = 0.0,
216    ) -> None:
217        super().__init__()
218
219        if hidden_sizes is None:
220            hidden_sizes = [32]
221
222        self.n_categories = n_categories
223        self.embedding_dim = embedding_dim
224
225        self.embedding = nn.Embedding(n_categories, embedding_dim)
226        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.1)
227
228        layers: list[nn.Module] = []
229        in_dim = embedding_dim
230
231        for i, out_dim in enumerate(hidden_sizes):
232            linear = nn.Linear(in_dim, out_dim)
233            nn.init.xavier_uniform_(linear.weight, gain=0.5)
234            nn.init.zeros_(linear.bias)
235            layers.append(linear)
236            layers.append(nn.ReLU())
237            if dropout > 0.0 and i < len(hidden_sizes) - 1:
238                layers.append(nn.Dropout(p=dropout))
239            in_dim = out_dim
240
241        output_layer = nn.Linear(in_dim, 1)
242        nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
243        nn.init.zeros_(output_layer.bias)
244        layers.append(output_layer)
245
246        self.network = nn.Sequential(*layers)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

n_categories
embedding_dim
embedding
network
def forward(self, x: torch.Tensor) -> torch.Tensor:
248    def forward(self, x: torch.Tensor) -> torch.Tensor:
249        """Forward pass.
250
251        Parameters
252        ----------
253        x:
254            Shape (batch,). Integer category indices.
255
256        Returns
257        -------
258        torch.Tensor
259            Shape (batch, 1). Category contribution.
260        """
261        embedded = self.embedding(x.long())  # (batch, embedding_dim)
262        return self.network(embedded)

Forward pass.

Parameters

x: Shape (batch,). Integer category indices.

Returns

torch.Tensor Shape (batch, 1). Category contribution.

def category_table(self) -> dict[int, float]:
264    def category_table(self) -> dict[int, float]:
265        """Extract per-category contributions as a relativity table.
266
267        Returns a dict mapping category index to scalar contribution value.
268        Useful for regulatory documentation and GLM comparison.
269        """
270        self.eval()
271        with torch.no_grad():
272            indices = torch.arange(self.n_categories)
273            contribs = self.forward(indices).squeeze(-1)
274        return {int(i): float(v) for i, v in enumerate(contribs)}

Extract per-category contributions as a relativity table.

Returns a dict mapping category index to scalar contribution value. Useful for regulatory documentation and GLM comparison.
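
A sketch of turning per-category contributions into multiplicative relativities for a log-link model, rebased so that the first category is 1.00.

import math
from insurance_gam.anam import CategoricalFeatureNetwork

net = CategoricalFeatureNetwork(n_categories=5, embedding_dim=4)
table = net.category_table()                       # {0: f(0), 1: f(1), ...}
base = table[0]
relativities = {k: math.exp(v - base) for k, v in table.items()}
print(relativities)                                # category 0 maps to 1.0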

class InteractionNetwork(torch.nn.modules.module.Module):
 26class InteractionNetwork(nn.Module):
 27    """Pairwise interaction subnetwork g_{ij}(x_i, x_j).
 28
 29    Takes two feature values and learns their joint effect. The network
 30    is intentionally shallow to avoid overfitting and to keep the
 31    interaction contribution interpretable as a 2D surface.
 32
 33    Parameters
 34    ----------
 35    feature_indices:
 36        Tuple (i, j) identifying which features this network handles.
 37        Used for bookkeeping and shape function export.
 38    hidden_sizes:
 39        Width of each hidden layer. Shallower than single-feature nets
 40        is recommended — interactions should be simple corrections to
 41        the additive baseline.
 42    dropout:
 43        Dropout rate between hidden layers.
 44    """
 45
 46    def __init__(
 47        self,
 48        feature_indices: tuple[int, int],
 49        hidden_sizes: list[int] | None = None,
 50        dropout: float = 0.0,
 51    ) -> None:
 52        super().__init__()
 53
 54        if hidden_sizes is None:
 55            hidden_sizes = [32, 16]
 56
 57        self.feature_indices = feature_indices
 58        self.hidden_sizes = hidden_sizes
 59
 60        layers: list[nn.Module] = []
 61        in_dim = 2  # two features concatenated
 62
 63        for i, out_dim in enumerate(hidden_sizes):
 64            linear = nn.Linear(in_dim, out_dim)
 65            nn.init.xavier_uniform_(linear.weight, gain=0.5)
 66            nn.init.zeros_(linear.bias)
 67            layers.append(linear)
 68            layers.append(nn.ReLU())
 69            if dropout > 0.0 and i < len(hidden_sizes) - 1:
 70                layers.append(nn.Dropout(p=dropout))
 71            in_dim = out_dim
 72
 73        output_layer = nn.Linear(in_dim, 1)
 74        nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
 75        nn.init.zeros_(output_layer.bias)
 76        layers.append(output_layer)
 77
 78        self.network = nn.Sequential(*layers)
 79
 80    def forward(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
 81        """Forward pass.
 82
 83        Parameters
 84        ----------
 85        x_i, x_j:
 86            Shape (batch,) or (batch, 1). Values for the two interacting
 87            features.
 88
 89        Returns
 90        -------
 91        torch.Tensor
 92            Shape (batch, 1). Interaction contribution g_{ij}(x_i, x_j).
 93        """
 94        if x_i.dim() == 1:
 95            x_i = x_i.unsqueeze(-1)
 96        if x_j.dim() == 1:
 97            x_j = x_j.unsqueeze(-1)
 98        combined = torch.cat([x_i, x_j], dim=-1)  # (batch, 2)
 99        return self.network(combined)
100
101    def interaction_grid(
102        self,
103        xi_min: float,
104        xi_max: float,
105        xj_min: float,
106        xj_max: float,
107        n_points: int = 50,
108    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109        """Evaluate interaction surface over a 2D grid.
110
111        Returns (xi_grid, xj_grid, g_values) where g_values is an
112        (n_points, n_points) tensor representing the interaction surface.
113        Useful for heatmap visualisation.
114        """
115        self.eval()
116        with torch.no_grad():
117            xi = torch.linspace(xi_min, xi_max, n_points)
118            xj = torch.linspace(xj_min, xj_max, n_points)
119            xi_grid, xj_grid = torch.meshgrid(xi, xj, indexing="ij")
120            xi_flat = xi_grid.reshape(-1)
121            xj_flat = xj_grid.reshape(-1)
122            g_flat = self.forward(xi_flat, xj_flat).squeeze(-1)
123            g_values = g_flat.reshape(n_points, n_points)
124        return xi_grid, xj_grid, g_values

Pairwise interaction subnetwork g_{ij}(x_i, x_j).

Takes two feature values and learns their joint effect. The network is intentionally shallow to avoid overfitting and to keep the interaction contribution interpretable as a 2D surface.

Parameters

feature_indices:
    Tuple (i, j) identifying which features this network handles. Used for bookkeeping and shape function export.
hidden_sizes:
    Width of each hidden layer. Shallower than single-feature nets is recommended — interactions should be simple corrections to the additive baseline.
dropout:
    Dropout rate between hidden layers.

InteractionNetwork( feature_indices: tuple[int, int], hidden_sizes: list[int] | None = None, dropout: float = 0.0)
46    def __init__(
47        self,
48        feature_indices: tuple[int, int],
49        hidden_sizes: list[int] | None = None,
50        dropout: float = 0.0,
51    ) -> None:
52        super().__init__()
53
54        if hidden_sizes is None:
55            hidden_sizes = [32, 16]
56
57        self.feature_indices = feature_indices
58        self.hidden_sizes = hidden_sizes
59
60        layers: list[nn.Module] = []
61        in_dim = 2  # two features concatenated
62
63        for i, out_dim in enumerate(hidden_sizes):
64            linear = nn.Linear(in_dim, out_dim)
65            nn.init.xavier_uniform_(linear.weight, gain=0.5)
66            nn.init.zeros_(linear.bias)
67            layers.append(linear)
68            layers.append(nn.ReLU())
69            if dropout > 0.0 and i < len(hidden_sizes) - 1:
70                layers.append(nn.Dropout(p=dropout))
71            in_dim = out_dim
72
73        output_layer = nn.Linear(in_dim, 1)
74        nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
75        nn.init.zeros_(output_layer.bias)
76        layers.append(output_layer)
77
78        self.network = nn.Sequential(*layers)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

feature_indices
hidden_sizes
network
def forward(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
80    def forward(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
81        """Forward pass.
82
83        Parameters
84        ----------
85        x_i, x_j:
86            Shape (batch,) or (batch, 1). Values for the two interacting
87            features.
88
89        Returns
90        -------
91        torch.Tensor
92            Shape (batch, 1). Interaction contribution g_{ij}(x_i, x_j).
93        """
94        if x_i.dim() == 1:
95            x_i = x_i.unsqueeze(-1)
96        if x_j.dim() == 1:
97            x_j = x_j.unsqueeze(-1)
98        combined = torch.cat([x_i, x_j], dim=-1)  # (batch, 2)
99        return self.network(combined)

Forward pass.

Parameters

x_i, x_j: Shape (batch,) or (batch, 1). Values for the two interacting features.

Returns

torch.Tensor Shape (batch, 1). Interaction contribution g_{ij}(x_i, x_j).

def interaction_grid( self, xi_min: float, xi_max: float, xj_min: float, xj_max: float, n_points: int = 50) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
101    def interaction_grid(
102        self,
103        xi_min: float,
104        xi_max: float,
105        xj_min: float,
106        xj_max: float,
107        n_points: int = 50,
108    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109        """Evaluate interaction surface over a 2D grid.
110
111        Returns (xi_grid, xj_grid, g_values) where g_values is an
112        (n_points, n_points) tensor representing the interaction surface.
113        Useful for heatmap visualisation.
114        """
115        self.eval()
116        with torch.no_grad():
117            xi = torch.linspace(xi_min, xi_max, n_points)
118            xj = torch.linspace(xj_min, xj_max, n_points)
119            xi_grid, xj_grid = torch.meshgrid(xi, xj, indexing="ij")
120            xi_flat = xi_grid.reshape(-1)
121            xj_flat = xj_grid.reshape(-1)
122            g_flat = self.forward(xi_flat, xj_flat).squeeze(-1)
123            g_values = g_flat.reshape(n_points, n_points)
124        return xi_grid, xj_grid, g_values

Evaluate interaction surface over a 2D grid.

Returns (xi_grid, xj_grid, g_values) where g_values is an (n_points, n_points) tensor representing the interaction surface. Useful for heatmap visualisation.
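
A sketch of visualising a learned interaction surface as a heatmap (matplotlib assumed available); the value ranges are illustrative.

import matplotlib.pyplot as plt
from insurance_gam.anam import InteractionNetwork

net = InteractionNetwork(feature_indices=(0, 1), hidden_sizes=[32, 16])
xi_grid, xj_grid, g = net.interaction_grid(-2.0, 2.0, -2.0, 2.0, n_points=50)

plt.pcolormesh(xi_grid.numpy(), xj_grid.numpy(), g.numpy(), shading="auto")
plt.colorbar(label="g_ij(x_i, x_j)")
plt.xlabel("feature i")
plt.ylabel("feature j")
plt.show()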

class ANAMTrainer:
113class ANAMTrainer:
114    """Manages the training loop for ANAMModel.
115
116    Parameters
117    ----------
118    model:
119        The ANAMModel to train. Modified in-place.
120    config:
121        Training hyperparameters.
122    """
123
124    def __init__(self, model: ANAMModel, config: TrainingConfig) -> None:
125        self.model = model
126        self.config = config
127
128        if config.device is None:
129            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130        else:
131            self.device = torch.device(config.device)
132
133        self.model.to(self.device)
134        self.history = TrainingHistory()
135
136    def fit(
137        self,
138        X: np.ndarray,
139        y: np.ndarray,
140        exposure: Optional[np.ndarray] = None,
141    ) -> TrainingHistory:
142        """Train the model.
143
144        Parameters
145        ----------
146        X:
147            Feature matrix, shape (n, n_features). Continuous features
148            should be normalised before calling fit().
149        y:
150            Target vector, shape (n,). Claim counts, rates, or severities.
151        exposure:
152            Exposure vector, shape (n,). Policy years or similar. If None,
153            uniform exposure (all 1.0) is assumed.
154
155        Returns
156        -------
157        TrainingHistory
158            Training and validation loss per epoch.
159        """
160        n = len(y)
161
162        if exposure is None:
163            exposure = np.ones(n, dtype=np.float32)
164
165        # Convert to tensors
166        X_t = torch.tensor(X, dtype=torch.float32)
167        y_t = torch.tensor(y, dtype=torch.float32)
168        exp_t = torch.tensor(exposure, dtype=torch.float32)
169
170        # Train/val split
171        n_val = max(1, int(n * self.config.val_fraction))
172        perm = torch.randperm(n)
173        val_idx = perm[:n_val]
174        train_idx = perm[n_val:]
175
176        X_train, y_train, exp_train = X_t[train_idx], y_t[train_idx], exp_t[train_idx]
177        X_val, y_val, exp_val = X_t[val_idx], y_t[val_idx], exp_t[val_idx]
178
179        # Precompute feature ranges for smoothness penalty
180        feature_ranges = self._compute_feature_ranges(X_train)
181
182        # DataLoader
183        train_dataset = TensorDataset(X_train, y_train, exp_train)
184        train_loader = DataLoader(
185            train_dataset,
186            batch_size=self.config.batch_size,
187            shuffle=True,
188            drop_last=False,
189        )
190
191        # Optimiser and LR scheduler
192        optimizer = torch.optim.Adam(
193            self.model.parameters(),
194            lr=self.config.learning_rate,
195            weight_decay=0.0,  # L2 handled manually for flexibility
196        )
197        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
198            optimizer, T_max=self.config.n_epochs, eta_min=self.config.learning_rate * 0.01
199        )
200
201        best_val_loss = float("inf")
202        best_state = copy.deepcopy(self.model.state_dict())
203        patience_counter = 0
204
205        for epoch in range(self.config.n_epochs):
206            t0 = time.time()
207            train_loss = self._train_epoch(
208                train_loader, optimizer, feature_ranges
209            )
210            val_loss = self._evaluate(X_val, y_val, exp_val)
211
212            scheduler.step()
213
214            self.history.train_loss.append(train_loss)
215            self.history.val_loss.append(val_loss)
216            self.history.epoch_times.append(time.time() - t0)
217
218            if self.config.verbose > 0 and (epoch + 1) % self.config.verbose == 0:
219                print(
220                    f"Epoch {epoch + 1:4d}/{self.config.n_epochs} | "
221                    f"train={train_loss:.6f} | val={val_loss:.6f}"
222                )
223
224            # Early stopping
225            if val_loss < best_val_loss - self.config.min_delta:
226                best_val_loss = val_loss
227                best_state = copy.deepcopy(self.model.state_dict())
228                self.history.best_epoch = epoch
229                patience_counter = 0
230            else:
231                patience_counter += 1
232                if patience_counter >= self.config.patience:
233                    if self.config.verbose > 0:
234                        print(f"Early stopping at epoch {epoch + 1}.")
235                    self.history.stopped_early = True
236                    break
237
238        # Restore best weights
239        self.model.load_state_dict(best_state)
240        return self.history
241
242    def _train_epoch(
243        self,
244        loader: DataLoader,
245        optimizer: torch.optim.Optimizer,
246        feature_ranges: Dict[str, Tuple[float, float]],
247    ) -> float:
248        """One pass through the training data."""
249        self.model.train()
250        total_loss = 0.0
251        total_weight = 0.0
252
253        for X_batch, y_batch, exp_batch in loader:
254            X_batch = X_batch.to(self.device)
255            y_batch = y_batch.to(self.device)
256            exp_batch = exp_batch.to(self.device)
257
258            optimizer.zero_grad()
259
260            # Log exposure offset for log-link models
261            log_exp = torch.log(exp_batch.clamp(min=1e-8)) if self.model.link == "log" else None
262
263            # Forward pass
264            y_pred = self.model(X_batch, log_exposure=log_exp)
265
266            # Distributional loss
267            loss = self._distributional_loss(y_pred, y_batch, exp_batch)
268
269            # Smoothness penalty over all continuous features
270            if self.config.lambda_smooth > 0.0:
271                for cfg in self.model.feature_configs:
272                    if cfg.feature_type == "continuous":
273                        x_min, x_max = feature_ranges[cfg.name]
274                        net = self.model.feature_nets[cfg.name]
275                        loss = loss + smoothness_penalty(
276                            net, x_min, x_max,
277                            n_points=self.config.smooth_n_points,
278                            lambda_smooth=self.config.lambda_smooth,
279                        )
280
281            # L2 ridge
282            if self.config.lambda_l2 > 0.0:
283                all_nets = list(self.model.feature_nets.values()) + list(
284                    self.model.interaction_nets.values()
285                )
286                loss = loss + l2_ridge_penalty(all_nets, self.config.lambda_l2)
287
288            # L1 sparsity
289            if self.config.lambda_l1 > 0.0:
290                all_nets = list(self.model.feature_nets.values())
291                loss = loss + l1_sparsity_penalty(all_nets, self.config.lambda_l1)
292
293            loss.backward()
294            optimizer.step()
295
296            # Monotonicity projection: clamp weights after the gradient step
297            self.model.project_monotone_weights()
298
299            batch_weight = exp_batch.sum().item()
300            total_loss += loss.item() * batch_weight
301            total_weight += batch_weight
302
303        return total_loss / max(total_weight, 1e-8)
304
305    def _evaluate(
306        self,
307        X: torch.Tensor,
308        y: torch.Tensor,
309        exposure: torch.Tensor,
310    ) -> float:
311        """Evaluate distributional loss on a dataset (no regularisation)."""
312        self.model.eval()
313        with torch.no_grad():
314            X = X.to(self.device)
315            y = y.to(self.device)
316            exposure = exposure.to(self.device)
317
318            log_exp = torch.log(exposure.clamp(min=1e-8)) if self.model.link == "log" else None
319            y_pred = self.model(X, log_exposure=log_exp)
320            loss = self._distributional_loss(y_pred, y, exposure)
321
322        return float(loss.item())
323
324    def _distributional_loss(
325        self,
326        y_pred: torch.Tensor,
327        y_true: torch.Tensor,
328        weights: torch.Tensor,
329    ) -> torch.Tensor:
330        """Compute weighted distributional loss."""
331        cfg = self.config
332        if cfg.loss == "poisson":
333            return poisson_deviance(y_pred, y_true, weights)
334        elif cfg.loss == "tweedie":
335            return tweedie_deviance(y_pred, y_true, p=cfg.tweedie_p, weights=weights)
336        elif cfg.loss == "gamma":
337            return gamma_deviance(y_pred, y_true, weights)
338        elif cfg.loss == "mse":
339            err = (y_pred - y_true) ** 2
340            return (weights * err).sum() / weights.sum().clamp(min=1e-8)
341        else:
342            raise ValueError(f"Unknown loss: {cfg.loss!r}")
343
344    def _compute_feature_ranges(
345        self, X_train: torch.Tensor
346    ) -> Dict[str, Tuple[float, float]]:
347        """Compute per-feature min/max ranges from training data."""
348        ranges: Dict[str, Tuple[float, float]] = {}
349        for i, cfg in enumerate(self.model.feature_configs):
350            if cfg.feature_type == "continuous":
351                col = X_train[:, i]
352                ranges[cfg.name] = (float(col.min().item()), float(col.max().item()))
353        return ranges

Manages the training loop for ANAMModel.

Parameters

model:
    The ANAMModel to train. Modified in-place.
config:
    Training hyperparameters.

ANAMTrainer( model: ANAMModel, config: TrainingConfig)
124    def __init__(self, model: ANAMModel, config: TrainingConfig) -> None:
125        self.model = model
126        self.config = config
127
128        if config.device is None:
129            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130        else:
131            self.device = torch.device(config.device)
132
133        self.model.to(self.device)
134        self.history = TrainingHistory()
model
config
history
def fit( self, X: numpy.ndarray, y: numpy.ndarray, exposure: Optional[numpy.ndarray] = None) -> TrainingHistory:
136    def fit(
137        self,
138        X: np.ndarray,
139        y: np.ndarray,
140        exposure: Optional[np.ndarray] = None,
141    ) -> TrainingHistory:
142        """Train the model.
143
144        Parameters
145        ----------
146        X:
147            Feature matrix, shape (n, n_features). Continuous features
148            should be normalised before calling fit().
149        y:
150            Target vector, shape (n,). Claim counts, rates, or severities.
151        exposure:
152            Exposure vector, shape (n,). Policy years or similar. If None,
153            uniform exposure (all 1.0) is assumed.
154
155        Returns
156        -------
157        TrainingHistory
158            Training and validation loss per epoch.
159        """
160        n = len(y)
161
162        if exposure is None:
163            exposure = np.ones(n, dtype=np.float32)
164
165        # Convert to tensors
166        X_t = torch.tensor(X, dtype=torch.float32)
167        y_t = torch.tensor(y, dtype=torch.float32)
168        exp_t = torch.tensor(exposure, dtype=torch.float32)
169
170        # Train/val split
171        n_val = max(1, int(n * self.config.val_fraction))
172        perm = torch.randperm(n)
173        val_idx = perm[:n_val]
174        train_idx = perm[n_val:]
175
176        X_train, y_train, exp_train = X_t[train_idx], y_t[train_idx], exp_t[train_idx]
177        X_val, y_val, exp_val = X_t[val_idx], y_t[val_idx], exp_t[val_idx]
178
179        # Precompute feature ranges for smoothness penalty
180        feature_ranges = self._compute_feature_ranges(X_train)
181
182        # DataLoader
183        train_dataset = TensorDataset(X_train, y_train, exp_train)
184        train_loader = DataLoader(
185            train_dataset,
186            batch_size=self.config.batch_size,
187            shuffle=True,
188            drop_last=False,
189        )
190
191        # Optimiser and LR scheduler
192        optimizer = torch.optim.Adam(
193            self.model.parameters(),
194            lr=self.config.learning_rate,
195            weight_decay=0.0,  # L2 handled manually for flexibility
196        )
197        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
198            optimizer, T_max=self.config.n_epochs, eta_min=self.config.learning_rate * 0.01
199        )
200
201        best_val_loss = float("inf")
202        best_state = copy.deepcopy(self.model.state_dict())
203        patience_counter = 0
204
205        for epoch in range(self.config.n_epochs):
206            t0 = time.time()
207            train_loss = self._train_epoch(
208                train_loader, optimizer, feature_ranges
209            )
210            val_loss = self._evaluate(X_val, y_val, exp_val)
211
212            scheduler.step()
213
214            self.history.train_loss.append(train_loss)
215            self.history.val_loss.append(val_loss)
216            self.history.epoch_times.append(time.time() - t0)
217
218            if self.config.verbose > 0 and (epoch + 1) % self.config.verbose == 0:
219                print(
220                    f"Epoch {epoch + 1:4d}/{self.config.n_epochs} | "
221                    f"train={train_loss:.6f} | val={val_loss:.6f}"
222                )
223
224            # Early stopping
225            if val_loss < best_val_loss - self.config.min_delta:
226                best_val_loss = val_loss
227                best_state = copy.deepcopy(self.model.state_dict())
228                self.history.best_epoch = epoch
229                patience_counter = 0
230            else:
231                patience_counter += 1
232                if patience_counter >= self.config.patience:
233                    if self.config.verbose > 0:
234                        print(f"Early stopping at epoch {epoch + 1}.")
235                    self.history.stopped_early = True
236                    break
237
238        # Restore best weights
239        self.model.load_state_dict(best_state)
240        return self.history

Train the model.

Parameters

X:
    Feature matrix, shape (n, n_features). Continuous features should be normalised before calling fit().
y:
    Target vector, shape (n,). Claim counts, rates, or severities.
exposure:
    Exposure vector, shape (n,). Policy years or similar. If None, uniform exposure (all 1.0) is assumed.

Returns

TrainingHistory
    Training and validation loss per epoch.
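
A sketch of an end-to-end training run on synthetic frequency data. The ANAMModel constructor keywords are assumed from the attributes documented earlier and may differ from the actual signature; ANAMTrainer, TrainingConfig and fit() are used exactly as documented above.

import numpy as np
from insurance_gam.anam import ANAMModel, ANAMTrainer, FeatureConfig, TrainingConfig

rng = np.random.default_rng(0)
n = 5_000
X = rng.normal(size=(n, 2)).astype(np.float32)
exposure = rng.uniform(0.25, 1.0, size=n).astype(np.float32)
y = rng.poisson(exposure * np.exp(0.3 * X[:, 0] - 0.2 * X[:, 1])).astype(np.float32)

model = ANAMModel(
    feature_configs=[FeatureConfig(name="age"), FeatureConfig(name="power")],
    interaction_configs=[],   # assumed keyword
    link="log",               # assumed keyword
)
config = TrainingConfig(loss="poisson", n_epochs=50, batch_size=512, patience=5, verbose=0)
history = ANAMTrainer(model, config).fit(X, y, exposure=exposure)
print(history.best_epoch, min(history.val_loss))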

@dataclass
class TrainingConfig:
48@dataclass
49class TrainingConfig:
50    """Hyperparameters for the ANAM training loop.
51
52    Parameters
53    ----------
54    loss:
55        Distributional loss type. 'poisson' for frequency, 'tweedie' for
56        pure premium, 'gamma' for severity, 'mse' for Gaussian.
57    tweedie_p:
58        Tweedie power parameter. Only used when loss='tweedie'.
59        Typical range (1.0, 2.0). Common choices: 1.5 (compound Poisson).
60    n_epochs:
61        Maximum training epochs.
62    batch_size:
63        Mini-batch size.
64    learning_rate:
65        Initial Adam learning rate.
66    lambda_smooth:
67        Smoothness penalty weight. 0.0 disables smoothness regularisation.
68    lambda_l1:
69        L1 sparsity penalty weight. 0.0 disables.
70    lambda_l2:
71        L2 ridge penalty weight (weight decay).
72    smooth_n_points:
73        Number of grid points for smoothness penalty evaluation.
74    val_fraction:
75        Fraction of training data held out for early stopping.
76    patience:
77        Number of epochs without improvement before stopping.
78    min_delta:
79        Minimum improvement in validation loss to count as progress.
80    verbose:
81        Print training progress every `verbose` epochs. 0 = silent.
82    device:
83        PyTorch device string. 'cpu' or 'cuda'. Auto-detected if None.
84    """
85
86    loss: LossType = "poisson"
87    tweedie_p: float = 1.5
88    n_epochs: int = 100
89    batch_size: int = 512
90    learning_rate: float = 1e-3
91    lambda_smooth: float = 1e-4
92    lambda_l1: float = 0.0
93    lambda_l2: float = 1e-4
94    smooth_n_points: int = 100
95    val_fraction: float = 0.1
96    patience: int = 10
97    min_delta: float = 1e-6
98    verbose: int = 10
99    device: Optional[str] = None

Hyperparameters for the ANAM training loop.

Parameters

loss:
    Distributional loss type. 'poisson' for frequency, 'tweedie' for pure premium, 'gamma' for severity, 'mse' for Gaussian.
tweedie_p:
    Tweedie power parameter. Only used when loss='tweedie'. Typical range (1.0, 2.0). Common choices: 1.5 (compound Poisson).
n_epochs:
    Maximum training epochs.
batch_size:
    Mini-batch size.
learning_rate:
    Initial Adam learning rate.
lambda_smooth:
    Smoothness penalty weight. 0.0 disables smoothness regularisation.
lambda_l1:
    L1 sparsity penalty weight. 0.0 disables.
lambda_l2:
    L2 ridge penalty weight (weight decay).
smooth_n_points:
    Number of grid points for smoothness penalty evaluation.
val_fraction:
    Fraction of training data held out for early stopping.
patience:
    Number of epochs without improvement before stopping.
min_delta:
    Minimum improvement in validation loss to count as progress.
verbose:
    Print training progress every verbose epochs. 0 = silent.
device:
    PyTorch device string. 'cpu' or 'cuda'. Auto-detected if None.

TrainingConfig( loss: Literal['poisson', 'tweedie', 'gamma', 'mse'] = 'poisson', tweedie_p: float = 1.5, n_epochs: int = 100, batch_size: int = 512, learning_rate: float = 0.001, lambda_smooth: float = 0.0001, lambda_l1: float = 0.0, lambda_l2: float = 0.0001, smooth_n_points: int = 100, val_fraction: float = 0.1, patience: int = 10, min_delta: float = 1e-06, verbose: int = 10, device: Optional[str] = None)
loss: Literal['poisson', 'tweedie', 'gamma', 'mse'] = 'poisson'
tweedie_p: float = 1.5
n_epochs: int = 100
batch_size: int = 512
learning_rate: float = 0.001
lambda_smooth: float = 0.0001
lambda_l1: float = 0.0
lambda_l2: float = 0.0001
smooth_n_points: int = 100
val_fraction: float = 0.1
patience: int = 10
min_delta: float = 1e-06
verbose: int = 10
device: Optional[str] = None
@dataclass
class TrainingHistory:
102@dataclass
103class TrainingHistory:
104    """Records per-epoch training and validation losses."""
105
106    train_loss: List[float] = field(default_factory=list)
107    val_loss: List[float] = field(default_factory=list)
108    epoch_times: List[float] = field(default_factory=list)
109    best_epoch: int = 0
110    stopped_early: bool = False

Records per-epoch training and validation losses.

TrainingHistory( train_loss: List[float] = <factory>, val_loss: List[float] = <factory>, epoch_times: List[float] = <factory>, best_epoch: int = 0, stopped_early: bool = False)
train_loss: List[float]
val_loss: List[float]
epoch_times: List[float]
best_epoch: int = 0
stopped_early: bool = False
def poisson_deviance( y_pred: torch.Tensor, y_true: torch.Tensor, weights: torch.Tensor | None = None, eps: float = 1e-08) -> torch.Tensor:
36def poisson_deviance(
37    y_pred: torch.Tensor,
38    y_true: torch.Tensor,
39    weights: torch.Tensor | None = None,
40    eps: float = 1e-8,
41) -> torch.Tensor:
42    """Poisson deviance loss (mean over batch).
43
44    Deviance = 2 * w * [y * log(y / mu) - (y - mu)]
45    where mu = y_pred, y = y_true, w = observation weight.
46
47    Parameters
48    ----------
49    y_pred:
50        Predicted mean (mu), strictly positive. Shape (batch,).
51    y_true:
52        Observed values. Shape (batch,).
53    weights:
54        Observation weights (e.g. exposure). Shape (batch,). If None,
55        uses uniform weights.
56    eps:
57        Numerical floor for y_true in log computation.
58
59    Returns
60    -------
61    torch.Tensor
62        Scalar mean deviance.
63    """
64    mu = y_pred.clamp(min=eps)
65    y = y_true.clamp(min=eps)
66
67    # y * log(y/mu) term: undefined at y=0, use 0 * log(0) = 0 convention
68    log_term = torch.where(
69        y_true > eps,
70        y * (torch.log(y) - torch.log(mu)),
71        torch.zeros_like(y),
72    )
73    d = 2.0 * (log_term - (y_true - mu))
74
75    if weights is not None:
76        return (weights * d).sum() / weights.sum().clamp(min=eps)
77    return d.mean()

Poisson deviance loss (mean over batch).

Deviance = 2 * w * [y * log(y / mu) - (y - mu)] where mu = y_pred, y = y_true, w = observation weight.

Parameters

y_pred:
    Predicted mean (mu), strictly positive. Shape (batch,).
y_true:
    Observed values. Shape (batch,).
weights:
    Observation weights (e.g. exposure). Shape (batch,). If None, uses uniform weights.
eps:
    Numerical floor for y_true in log computation.

Returns

torch.Tensor
    Scalar mean deviance.
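
A small sketch of exposure-weighted Poisson deviance on a toy batch; the numbers are illustrative only.

import torch
from insurance_gam.anam import poisson_deviance

y_true = torch.tensor([0.0, 1.0, 2.0])
y_pred = torch.tensor([0.4, 0.9, 1.6])
exposure = torch.tensor([0.5, 1.0, 1.0])
print(poisson_deviance(y_pred, y_true, weights=exposure))  # scalar weighted deviance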

def gamma_deviance( y_pred: torch.Tensor, y_true: torch.Tensor, weights: torch.Tensor | None = None, eps: float = 1e-08) -> torch.Tensor:
 80def gamma_deviance(
 81    y_pred: torch.Tensor,
 82    y_true: torch.Tensor,
 83    weights: torch.Tensor | None = None,
 84    eps: float = 1e-8,
 85) -> torch.Tensor:
 86    """Gamma deviance loss (mean over batch).
 87
 88    Deviance = 2 * w * [log(mu/y) + (y - mu)/mu]
 89    where mu = y_pred, y = y_true.
 90
 91    Used for claim severity (positive, right-skewed).
 92    """
 93    mu = y_pred.clamp(min=eps)
 94    y = y_true.clamp(min=eps)
 95
 96    d = 2.0 * (torch.log(mu / y) + (y - mu) / mu)
 97
 98    if weights is not None:
 99        return (weights * d).sum() / weights.sum().clamp(min=eps)
100    return d.mean()

Gamma deviance loss (mean over batch).

Deviance = 2 * w * [log(mu/y) + (y - mu)/mu] where mu = y_pred, y = y_true.

Used for claim severity (positive, right-skewed).

def tweedie_deviance( y_pred: torch.Tensor, y_true: torch.Tensor, p: float = 1.5, weights: torch.Tensor | None = None, eps: float = 1e-08) -> torch.Tensor:
103def tweedie_deviance(
104    y_pred: torch.Tensor,
105    y_true: torch.Tensor,
106    p: float = 1.5,
107    weights: torch.Tensor | None = None,
108    eps: float = 1e-8,
109) -> torch.Tensor:
110    """Tweedie deviance loss (mean over batch).
111
112    For p in (1, 2):
113        D(y, mu) = 2 * [y^(2-p)/((1-p)*(2-p)) - y*mu^(1-p)/(1-p) + mu^(2-p)/(2-p)]
114
115    Special cases:
116    - p=1: Poisson (use poisson_deviance for numerical stability)
117    - p=2: Gamma (use gamma_deviance)
118    - p=1.5: compound Poisson-Gamma, a common choice for pure premium
119
120    Parameters
121    ----------
122    p:
123        Tweedie power parameter. Typical range (1.0, 2.0) for compound
124        Poisson-Gamma; values near 1 or 2 fall back to Poisson/Gamma deviance.
125    """
126    if abs(p - 1.0) < 1e-6:
127        return poisson_deviance(y_pred, y_true, weights, eps)
128    if abs(p - 2.0) < 1e-6:
129        return gamma_deviance(y_pred, y_true, weights, eps)
130
131    mu = y_pred.clamp(min=eps)
132    y = y_true.clamp(min=eps)
133
134    term1 = y.pow(2 - p) / ((1 - p) * (2 - p))
135    term2 = y * mu.pow(1 - p) / (1 - p)
136    term3 = mu.pow(2 - p) / (2 - p)
137
138    # Handle y=0 in term1: 0^(2-p) for p<2 is 0, so term1 -> 0
139    term1 = torch.where(y_true < eps, torch.zeros_like(term1), term1)
140
141    d = 2.0 * (term1 - term2 + term3)
142
143    if weights is not None:
144        return (weights * d).sum() / weights.sum().clamp(min=eps)
145    return d.mean()

Tweedie deviance loss (mean over batch).

For p in (1, 2): D(y, mu) = 2 * [y^(2-p)/((1-p)*(2-p)) - y*mu^(1-p)/(1-p) + mu^(2-p)/(2-p)]

Special cases:

  • p=1: Poisson (use poisson_deviance for numerical stability)
  • p=2: Gamma (use gamma_deviance)
  • p=1.5: compound Poisson-Gamma, common for pure premium

Parameters

p: Tweedie power parameter. Must not equal 1 or 2. Typical range (1.0, 2.0) for compound Poisson-Gamma.
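
A sketch of the pure-premium case on synthetic values; p=1.5 keeps the compound Poisson-Gamma branch, while p close to 1 or 2 dispatches to the Poisson/Gamma deviances:

import torch

from insurance_gam.anam import tweedie_deviance

y_true = torch.tensor([0.0, 0.0, 1200.0, 450.0])   # pure premium per policy, mostly zeros
y_pred = torch.tensor([310.0, 95.0, 880.0, 510.0])
exposure = torch.tensor([1.0, 0.5, 1.0, 1.0])

dev = tweedie_deviance(y_pred, y_true, p=1.5, weights=exposure)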

def bernoulli_deviance( y_pred_logit: torch.Tensor, y_true: torch.Tensor, weights: torch.Tensor | None = None, eps: float = 1e-08) -> torch.Tensor:
148def bernoulli_deviance(
149    y_pred_logit: torch.Tensor,
150    y_true: torch.Tensor,
151    weights: torch.Tensor | None = None,
152    eps: float = 1e-8,
153) -> torch.Tensor:
154    """Binary cross-entropy deviance (logit inputs).
155
156    For binary outcomes (lapse, catastrophic event indicators).
157    y_pred_logit is the raw network output (before sigmoid).
158    """
159    d = F.binary_cross_entropy_with_logits(y_pred_logit, y_true, reduction="none")
160    if weights is not None:
161        return (weights * d).sum() / weights.sum().clamp(min=eps)
162    return d.mean()

Binary cross-entropy deviance (logit inputs).

For binary outcomes (lapse, catastrophic event indicators). y_pred_logit is the raw network output (before sigmoid).
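
A short sketch with made-up logits for a lapse indicator:

import torch

from insurance_gam.anam import bernoulli_deviance

logits = torch.tensor([-2.1, 0.4, 1.3, -0.7])   # raw network output, before sigmoid
lapsed = torch.tensor([0.0, 1.0, 1.0, 0.0])

dev = bernoulli_deviance(logits, lapsed)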

def smoothness_penalty( feature_network: torch.nn.modules.module.Module, x_min: float, x_max: float, n_points: int = 100, lambda_smooth: float = 0.0001) -> torch.Tensor:
170def smoothness_penalty(
171    feature_network: "torch.nn.Module",
172    x_min: float,
173    x_max: float,
174    n_points: int = 100,
175    lambda_smooth: float = 1e-4,
176) -> torch.Tensor:
177    """Second-order difference penalty on a feature network's shape.
178
179    Evaluates f_i at n_points evenly spaced over [x_min, x_max] and
180    penalises second differences: sum((f_{k+2} - 2*f_{k+1} + f_k)^2).
181
182    This discourages shape functions that change direction rapidly.
183    Lambda_smooth controls the trade-off between fit and smoothness.
184    """
185    # Determine device from network parameters so the grid lands on the same
186    # device as the model (GPU or CPU). torch.linspace defaults to CPU which
187    # would cause device-mismatch errors when the network is on CUDA.
188    try:
189        device = next(feature_network.parameters()).device
190    except StopIteration:
191        device = torch.device("cpu")
192    x_grid = torch.linspace(x_min, x_max, n_points, device=device)
193    f_vals = feature_network(x_grid).squeeze(-1)  # (n_points,)
194
195    # Second-order differences: f[k+2] - 2*f[k+1] + f[k]
196    second_diff = f_vals[2:] - 2 * f_vals[1:-1] + f_vals[:-2]
197    penalty = lambda_smooth * (second_diff ** 2).sum()
198    return penalty

Second-order difference penalty on a feature network's shape.

Evaluates f_i at n_points evenly spaced over [x_min, x_max] and penalises second differences: sum((f_{k+2} - 2*f_{k+1} + f_k)^2).

This discourages shape functions that change direction rapidly. Lambda_smooth controls the trade-off between fit and smoothness.
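
A minimal sketch using a toy stand-in for a feature subnetwork, assumed here to accept a 1-D grid and return one value per point, which matches how the penalty evaluates the network:

import torch
import torch.nn as nn

from insurance_gam.anam import smoothness_penalty

class ToyShapeNet(nn.Module):
    # Stand-in for a feature subnetwork: maps a 1-D grid to one output per point.
    def __init__(self) -> None:
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x.unsqueeze(-1))

net = ToyShapeNet()
penalty = smoothness_penalty(net, x_min=-2.0, x_max=2.0, n_points=50, lambda_smooth=1e-4)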

def l1_sparsity_penalty( feature_networks: list[torch.nn.modules.module.Module], lambda_l1: float = 1e-05) -> torch.Tensor:
201def l1_sparsity_penalty(
202    feature_networks: list["torch.nn.Module"],
203    lambda_l1: float = 1e-5,
204) -> torch.Tensor:
205    """L1 penalty on output layer weights of each subnetwork.
206
207    Encourages some subnetworks to output near-zero (feature selection).
208    Applied only to the output layer to avoid over-shrinking intermediate
209    representations.
210    """
211    import torch.nn as nn
212
213    penalty = torch.tensor(0.0)
214    for net in feature_networks:
215        # Penalise only the output layer (the last nn.Linear in each network):
216        # shrinking every weight layer would over-penalise intermediate
217        # representations; output-layer sparsity acts as feature selection.
218        linear_layers = [m for m in net.modules() if isinstance(m, nn.Linear)]
219        if linear_layers:
220            output_layer = linear_layers[-1]
221            penalty = penalty + output_layer.weight.abs().sum()
222    return lambda_l1 * penalty

L1 penalty on output layer weights of each subnetwork.

Encourages some subnetworks to output near-zero (feature selection). Applied only to the output layer to avoid over-shrinking intermediate representations.

def l2_ridge_penalty( feature_networks: list[torch.nn.modules.module.Module], lambda_l2: float = 0.0001) -> torch.Tensor:
225def l2_ridge_penalty(
226    feature_networks: list["torch.nn.Module"],
227    lambda_l2: float = 1e-4,
228) -> torch.Tensor:
229    """L2 ridge penalty across all subnetwork weights.
230
231    Standard weight decay. Stabilises training especially when many
232    subnetworks sum together — without it, individual nets can grow
233    large while cancelling each other out.
234    """
235    penalty = torch.tensor(0.0)
236    for net in feature_networks:
237        for param in net.parameters():
238            penalty = penalty + (param ** 2).sum()
239    return lambda_l2 * penalty

L2 ridge penalty across all subnetwork weights.

Standard weight decay. Stabilises training especially when many subnetworks sum together — without it, individual nets can grow large while cancelling each other out.
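
Both penalties take a plain list of subnetworks, so a sketch with small stand-in MLPs (sizes are illustrative) is enough to show how the terms are combined with the deviance loss:

import torch.nn as nn

from insurance_gam.anam import l1_sparsity_penalty, l2_ridge_penalty

# Three stand-in subnetworks; any modules containing nn.Linear layers work.
subnets = [nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1)) for _ in range(3)]

reg = l1_sparsity_penalty(subnets, lambda_l1=1e-5) + l2_ridge_penalty(subnets, lambda_l2=1e-4)
# In a training loop this regulariser is simply added to the deviance loss before backward().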

@dataclass
class ShapeFunction:
 33@dataclass
 34class ShapeFunction:
 35    """Extracted shape function for one feature.
 36
 37    Contains the evaluated curve (x_values, f_values) and metadata for
 38    reporting. Created by ANAM.shape_functions() after fitting.
 39
 40    Attributes
 41    ----------
 42    feature_name:
 43        Feature identifier.
 44    feature_type:
 45        'continuous' or 'categorical'.
 46    x_values:
 47        Grid of feature values (continuous) or category indices (categorical).
 48    f_values:
 49        Corresponding subnetwork outputs f_i(x_i).
 50    x_label:
 51        Human-readable x-axis label for plots.
 52    monotonicity:
 53        Monotonicity constraint applied during training.
 54    category_labels:
 55        Optional mapping from category index to label string.
 56    """
 57
 58    feature_name: str
 59    feature_type: str
 60    x_values: np.ndarray
 61    f_values: np.ndarray
 62    x_label: str = ""
 63    monotonicity: str = "none"
 64    category_labels: Optional[Dict[int, str]] = None
 65
 66    def to_polars(self) -> pl.DataFrame:
 67        """Export shape function as a Polars DataFrame.
 68
 69        For continuous features: columns [x, f_x]
 70        For categorical features: columns [category_index, category_label, f_x]
 71        """
 72        if self.feature_type == "continuous":
 73            return pl.DataFrame(
 74                {
 75                    "x": self.x_values.tolist(),
 76                    "f_x": self.f_values.tolist(),
 77                    "feature": [self.feature_name] * len(self.x_values),
 78                }
 79            )
 80        else:
 81            labels = [
 82                self.category_labels.get(int(i), str(int(i)))
 83                if self.category_labels
 84                else str(int(i))
 85                for i in self.x_values
 86            ]
 87            return pl.DataFrame(
 88                {
 89                    "category_index": self.x_values.astype(int).tolist(),
 90                    "category_label": labels,
 91                    "f_x": self.f_values.tolist(),
 92                    "feature": [self.feature_name] * len(self.x_values),
 93                }
 94            )
 95
 96    def to_relativities(self, base_level: Optional[float] = None) -> pl.DataFrame:
 97        """Convert shape function to GLM-style multiplicative relativities.
 98
 99        For log-link models, the shape function is on the log scale.
100        exp(f_i(x_i)) gives the multiplicative factor. This is then
101        normalised by the value at the base level (default: median x).
102
103        Parameters
104        ----------
105        base_level:
106            Feature value to use as base (relativity = 1.0). For
107            continuous features, defaults to median. For categorical,
108            defaults to category with f_x closest to zero.
109
110        Returns
111        -------
112        pl.DataFrame
113            Columns: [x, relativity, log_relativity]
114        """
115        if self.feature_type == "continuous":
116            if base_level is None:
117                mid_idx = len(self.x_values) // 2
118                base_f = self.f_values[mid_idx]
119            else:
120                # Find nearest grid point
121                idx = np.argmin(np.abs(self.x_values - base_level))
122                base_f = self.f_values[idx]
123
124            log_rel = self.f_values - base_f
125            rel = np.exp(log_rel)
126
127            return pl.DataFrame(
128                {
129                    "x": self.x_values.tolist(),
130                    "relativity": rel.tolist(),
131                    "log_relativity": log_rel.tolist(),
132                    "feature": [self.feature_name] * len(self.x_values),
133                }
134            )
135        else:
136            if base_level is None:
137                base_f = self.f_values[np.argmin(np.abs(self.f_values))]
138            else:
139                base_f = self.f_values[int(base_level)]
140
141            log_rel = self.f_values - base_f
142            rel = np.exp(log_rel)
143            labels = [
144                self.category_labels.get(int(i), str(int(i)))
145                if self.category_labels
146                else str(int(i))
147                for i in self.x_values
148            ]
149
150            return pl.DataFrame(
151                {
152                    "category_index": self.x_values.astype(int).tolist(),
153                    "category_label": labels,
154                    "relativity": rel.tolist(),
155                    "log_relativity": log_rel.tolist(),
156                    "feature": [self.feature_name] * len(self.x_values),
157                }
158            )
159
160    def to_dict(self) -> Dict[str, Any]:
161        """Serialise to a plain dict (JSON-compatible)."""
162        return {
163            "feature_name": self.feature_name,
164            "feature_type": self.feature_type,
165            "monotonicity": self.monotonicity,
166            "x_values": self.x_values.tolist(),
167            "f_values": self.f_values.tolist(),
168            "x_label": self.x_label,
169            "category_labels": self.category_labels,
170        }
171
172    def to_json(self, indent: int = 2) -> str:
173        """Serialise to JSON string."""
174        return json.dumps(self.to_dict(), indent=indent)
175
176    def plot(
177        self,
178        ax: Optional[Any] = None,
179        show_monotonicity: bool = True,
180        title: Optional[str] = None,
181        figsize: Tuple[int, int] = (8, 4),
182    ) -> Any:
183        """Plot the shape function.
184
185        Parameters
186        ----------
187        ax:
188            Matplotlib axes object. Creates a new figure if None.
189        show_monotonicity:
190            Annotate the plot with the monotonicity constraint.
191        title:
192            Plot title. Defaults to feature name.
193
194        Returns
195        -------
196        matplotlib.axes.Axes
197        """
198        import matplotlib.pyplot as plt
199
200        if ax is None:
201            _, ax = plt.subplots(figsize=figsize)
202
203        if self.feature_type == "continuous":
204            ax.plot(self.x_values, self.f_values, linewidth=2, color="#2c6fad")
205            ax.axhline(0, color="gray", linewidth=0.8, linestyle="--", alpha=0.5)
206            ax.fill_between(
207                self.x_values, self.f_values, 0,
208                alpha=0.15, color="#2c6fad"
209            )
210            ax.set_xlabel(self.x_label or self.feature_name)
211            ax.set_ylabel("log contribution")
212
213            if show_monotonicity and self.monotonicity != "none":
214                mono_label = f"monotone {self.monotonicity}"
215                ax.annotate(
216                    mono_label,
217                    xy=(0.02, 0.95),
218                    xycoords="axes fraction",
219                    fontsize=9,
220                    color="darkgreen",
221                    va="top",
222                )
223
224        else:
225            # Categorical: bar chart
226            labels = [
227                self.category_labels.get(int(i), str(int(i)))
228                if self.category_labels
229                else str(int(i))
230                for i in self.x_values
231            ]
232            colors = [
233                "#d63031" if v < 0 else "#0984e3" for v in self.f_values
234            ]
235            ax.bar(labels, self.f_values, color=colors, edgecolor="white")
236            ax.axhline(0, color="gray", linewidth=0.8, linestyle="--")
237            ax.set_xlabel(self.x_label or self.feature_name)
238            ax.set_ylabel("log contribution")
239            ax.tick_params(axis="x", rotation=45)
240
241        ax.set_title(title or f"Shape function: {self.feature_name}")
242        ax.spines["top"].set_visible(False)
243        ax.spines["right"].set_visible(False)
244
245        return ax

Extracted shape function for one feature.

Contains the evaluated curve (x_values, f_values) and metadata for reporting. Created by ANAM.shape_functions() after fitting.

Attributes

feature_name: Feature identifier.
feature_type: 'continuous' or 'categorical'.
x_values: Grid of feature values (continuous) or category indices (categorical).
f_values: Corresponding subnetwork outputs f_i(x_i).
x_label: Human-readable x-axis label for plots.
monotonicity: Monotonicity constraint applied during training.
category_labels: Optional mapping from category index to label string.

ShapeFunction( feature_name: str, feature_type: str, x_values: numpy.ndarray, f_values: numpy.ndarray, x_label: str = '', monotonicity: str = 'none', category_labels: Optional[Dict[int, str]] = None)
feature_name: str
feature_type: str
x_values: numpy.ndarray
f_values: numpy.ndarray
x_label: str = ''
monotonicity: str = 'none'
category_labels: Optional[Dict[int, str]] = None
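
ShapeFunction is a plain dataclass, so it can also be built by hand. A sketch with illustrative driver-age values on the log scale, showing the two export methods:

import numpy as np

from insurance_gam.anam import ShapeFunction

sf = ShapeFunction(
    feature_name="driver_age",          # illustrative feature and values
    feature_type="continuous",
    x_values=np.linspace(18, 80, 5),
    f_values=np.array([0.35, 0.10, -0.05, -0.02, 0.15]),
    x_label="Driver age",
)

curve = sf.to_polars()          # columns: x, f_x, feature
rels = sf.to_relativities()     # exp(f - f_base), base = median grid point
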
def to_polars(self) -> polars.dataframe.frame.DataFrame:
66    def to_polars(self) -> pl.DataFrame:
67        """Export shape function as a Polars DataFrame.
68
69        For continuous features: columns [x, f_x]
70        For categorical features: columns [category_index, category_label, f_x]
71        """
72        if self.feature_type == "continuous":
73            return pl.DataFrame(
74                {
75                    "x": self.x_values.tolist(),
76                    "f_x": self.f_values.tolist(),
77                    "feature": [self.feature_name] * len(self.x_values),
78                }
79            )
80        else:
81            labels = [
82                self.category_labels.get(int(i), str(int(i)))
83                if self.category_labels
84                else str(int(i))
85                for i in self.x_values
86            ]
87            return pl.DataFrame(
88                {
89                    "category_index": self.x_values.astype(int).tolist(),
90                    "category_label": labels,
91                    "f_x": self.f_values.tolist(),
92                    "feature": [self.feature_name] * len(self.x_values),
93                }
94            )

Export shape function as a Polars DataFrame.

For continuous features: columns [x, f_x]
For categorical features: columns [category_index, category_label, f_x]

def to_relativities( self, base_level: Optional[float] = None) -> polars.dataframe.frame.DataFrame:
 96    def to_relativities(self, base_level: Optional[float] = None) -> pl.DataFrame:
 97        """Convert shape function to GLM-style multiplicative relativities.
 98
 99        For log-link models, the shape function is on the log scale.
100        exp(f_i(x_i)) gives the multiplicative factor. This is then
101        normalised by the value at the base level (default: median x).
102
103        Parameters
104        ----------
105        base_level:
106            Feature value to use as base (relativity = 1.0). For
107            continuous features, defaults to median. For categorical,
108            defaults to category with f_x closest to zero.
109
110        Returns
111        -------
112        pl.DataFrame
113            Columns: [x, relativity, log_relativity]
114        """
115        if self.feature_type == "continuous":
116            if base_level is None:
117                mid_idx = len(self.x_values) // 2
118                base_f = self.f_values[mid_idx]
119            else:
120                # Find nearest grid point
121                idx = np.argmin(np.abs(self.x_values - base_level))
122                base_f = self.f_values[idx]
123
124            log_rel = self.f_values - base_f
125            rel = np.exp(log_rel)
126
127            return pl.DataFrame(
128                {
129                    "x": self.x_values.tolist(),
130                    "relativity": rel.tolist(),
131                    "log_relativity": log_rel.tolist(),
132                    "feature": [self.feature_name] * len(self.x_values),
133                }
134            )
135        else:
136            if base_level is None:
137                base_f = self.f_values[np.argmin(np.abs(self.f_values))]
138            else:
139                base_f = self.f_values[int(base_level)]
140
141            log_rel = self.f_values - base_f
142            rel = np.exp(log_rel)
143            labels = [
144                self.category_labels.get(int(i), str(int(i)))
145                if self.category_labels
146                else str(int(i))
147                for i in self.x_values
148            ]
149
150            return pl.DataFrame(
151                {
152                    "category_index": self.x_values.astype(int).tolist(),
153                    "category_label": labels,
154                    "relativity": rel.tolist(),
155                    "log_relativity": log_rel.tolist(),
156                    "feature": [self.feature_name] * len(self.x_values),
157                }
158            )

Convert shape function to GLM-style multiplicative relativities.

For log-link models, the shape function is on the log scale. exp(f_i(x_i)) gives the multiplicative factor. This is then normalised by the value at the base level (default: median x).

Parameters

base_level: Feature value to use as base (relativity = 1.0). For continuous features, defaults to median. For categorical, defaults to category with f_x closest to zero.

Returns

pl.DataFrame
    Columns: [x, relativity, log_relativity]

def to_dict(self) -> Dict[str, Any]:
160    def to_dict(self) -> Dict[str, Any]:
161        """Serialise to a plain dict (JSON-compatible)."""
162        return {
163            "feature_name": self.feature_name,
164            "feature_type": self.feature_type,
165            "monotonicity": self.monotonicity,
166            "x_values": self.x_values.tolist(),
167            "f_values": self.f_values.tolist(),
168            "x_label": self.x_label,
169            "category_labels": self.category_labels,
170        }

Serialise to a plain dict (JSON-compatible).

def to_json(self, indent: int = 2) -> str:
172    def to_json(self, indent: int = 2) -> str:
173        """Serialise to JSON string."""
174        return json.dumps(self.to_dict(), indent=indent)

Serialise to JSON string.

def plot( self, ax: Optional[Any] = None, show_monotonicity: bool = True, title: Optional[str] = None, figsize: Tuple[int, int] = (8, 4)) -> Any:
176    def plot(
177        self,
178        ax: Optional[Any] = None,
179        show_monotonicity: bool = True,
180        title: Optional[str] = None,
181        figsize: Tuple[int, int] = (8, 4),
182    ) -> Any:
183        """Plot the shape function.
184
185        Parameters
186        ----------
187        ax:
188            Matplotlib axes object. Creates a new figure if None.
189        show_monotonicity:
190            Annotate the plot with the monotonicity constraint.
191        title:
192            Plot title. Defaults to feature name.
193
194        Returns
195        -------
196        matplotlib.axes.Axes
197        """
198        import matplotlib.pyplot as plt
199
200        if ax is None:
201            _, ax = plt.subplots(figsize=figsize)
202
203        if self.feature_type == "continuous":
204            ax.plot(self.x_values, self.f_values, linewidth=2, color="#2c6fad")
205            ax.axhline(0, color="gray", linewidth=0.8, linestyle="--", alpha=0.5)
206            ax.fill_between(
207                self.x_values, self.f_values, 0,
208                alpha=0.15, color="#2c6fad"
209            )
210            ax.set_xlabel(self.x_label or self.feature_name)
211            ax.set_ylabel("log contribution")
212
213            if show_monotonicity and self.monotonicity != "none":
214                mono_label = f"monotone {self.monotonicity}"
215                ax.annotate(
216                    mono_label,
217                    xy=(0.02, 0.95),
218                    xycoords="axes fraction",
219                    fontsize=9,
220                    color="darkgreen",
221                    va="top",
222                )
223
224        else:
225            # Categorical: bar chart
226            labels = [
227                self.category_labels.get(int(i), str(int(i)))
228                if self.category_labels
229                else str(int(i))
230                for i in self.x_values
231            ]
232            colors = [
233                "#d63031" if v < 0 else "#0984e3" for v in self.f_values
234            ]
235            ax.bar(labels, self.f_values, color=colors, edgecolor="white")
236            ax.axhline(0, color="gray", linewidth=0.8, linestyle="--")
237            ax.set_xlabel(self.x_label or self.feature_name)
238            ax.set_ylabel("log contribution")
239            ax.tick_params(axis="x", rotation=45)
240
241        ax.set_title(title or f"Shape function: {self.feature_name}")
242        ax.spines["top"].set_visible(False)
243        ax.spines["right"].set_visible(False)
244
245        return ax

Plot the shape function.

Parameters

ax: Matplotlib axes object. Creates a new figure if None.
show_monotonicity: Annotate the plot with the monotonicity constraint.
title: Plot title. Defaults to feature name.

Returns

matplotlib.axes.Axes

def extract_shape_functions( model: "'ANAMModel'", X_train: numpy.ndarray, n_points: int = 200, category_labels: Optional[Dict[str, Dict[int, str]]] = None) -> Dict[str, ShapeFunction]:
248def extract_shape_functions(
249    model: "ANAMModel",
250    X_train: np.ndarray,
251    n_points: int = 200,
252    category_labels: Optional[Dict[str, Dict[int, str]]] = None,
253) -> Dict[str, ShapeFunction]:
254    """Extract all shape functions from a trained model.
255
256    Evaluates each subnetwork over the observed range of the training data.
257
258    Parameters
259    ----------
260    model:
261        Trained ANAMModel.
262    X_train:
263        Training feature matrix. Used to determine observed feature ranges
264        and category sets.
265    n_points:
266        Number of grid points for continuous features.
267    category_labels:
268        Optional dict mapping feature_name -> {category_idx -> label string}.
269
270    Returns
271    -------
272    Dict[str, ShapeFunction]
273        Feature name -> ShapeFunction for each feature.
274    """
275    shapes: Dict[str, ShapeFunction] = {}
276
277    model.eval()
278    with torch.no_grad():
279        for i, cfg in enumerate(model.feature_configs):
280            col = X_train[:, i]
281            net = model.feature_nets[cfg.name]
282            cat_labels = (
283                category_labels.get(cfg.name) if category_labels else None
284            )
285
286            if cfg.feature_type == "continuous":
287                x_min, x_max = float(col.min()), float(col.max())
288                x_grid = torch.linspace(x_min, x_max, n_points)
289                f_vals = net(x_grid.unsqueeze(-1)).squeeze(-1).cpu().numpy()
290                x_vals = x_grid.cpu().numpy()
291
292            else:
293                unique_cats = np.unique(col.astype(int))
294                x_cat = torch.tensor(unique_cats, dtype=torch.long)
295                f_vals = net(x_cat).squeeze(-1).cpu().numpy()
296                x_vals = unique_cats.astype(float)
297
298            shapes[cfg.name] = ShapeFunction(
299                feature_name=cfg.name,
300                feature_type=cfg.feature_type,
301                x_values=x_vals,
302                f_values=f_vals,
303                x_label=cfg.name,
304                monotonicity=cfg.monotonicity if hasattr(cfg, "monotonicity") else "none",
305                category_labels=cat_labels,
306            )
307
308    return shapes

Extract all shape functions from a trained model.

Evaluates each subnetwork over the observed range of the training data.

Parameters

model: Trained ANAMModel.
X_train: Training feature matrix. Used to determine observed feature ranges and category sets.
n_points: Number of grid points for continuous features.
category_labels: Optional dict mapping feature_name -> {category_idx -> label string}.

Returns

Dict[str, ShapeFunction]
    Feature name -> ShapeFunction for each feature.

def plot_all_shapes( shapes: Dict[str, ShapeFunction], n_cols: int = 3, figsize_per_plot: Tuple[int, int] = (5, 3), suptitle: str = 'ANAM Shape Functions') -> Any:
311def plot_all_shapes(
312    shapes: Dict[str, ShapeFunction],
313    n_cols: int = 3,
314    figsize_per_plot: Tuple[int, int] = (5, 3),
315    suptitle: str = "ANAM Shape Functions",
316) -> Any:
317    """Plot all shape functions in a grid layout.
318
319    Parameters
320    ----------
321    shapes:
322        Dict returned by extract_shape_functions().
323    n_cols:
324        Number of columns in the subplot grid.
325    figsize_per_plot:
326        Width and height per subplot panel.
327    suptitle:
328        Overall figure title.
329
330    Returns
331    -------
332    matplotlib.figure.Figure
333    """
334    import matplotlib.pyplot as plt
335
336    n = len(shapes)
337    n_rows = (n + n_cols - 1) // n_cols
338    fig_w = figsize_per_plot[0] * n_cols
339    fig_h = figsize_per_plot[1] * n_rows
340
341    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h))
342    if n == 1:
343        axes = np.array([[axes]])
344    elif n_rows == 1:
345        axes = axes.reshape(1, -1)
346    elif n_cols == 1:
347        axes = axes.reshape(-1, 1)
348
349    axes_flat = axes.flatten()
350
351    for ax, (name, sf) in zip(axes_flat, shapes.items()):
352        sf.plot(ax=ax)
353
354    # Hide unused subplots
355    for ax in axes_flat[n:]:
356        ax.set_visible(False)
357
358    fig.suptitle(suptitle, fontsize=13, y=1.01)
359    fig.tight_layout()
360    return fig

Plot all shape functions in a grid layout.

Parameters

shapes: Dict returned by extract_shape_functions().
n_cols: Number of columns in the subplot grid.
figsize_per_plot: Width and height per subplot panel.
suptitle: Overall figure title.

Returns

matplotlib.figure.Figure
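
A self-contained sketch with two hand-built shapes (illustrative curves; in practice the dict comes from extract_shape_functions()):

import numpy as np

from insurance_gam.anam import ShapeFunction, plot_all_shapes

shapes = {
    "driver_age": ShapeFunction(
        "driver_age", "continuous",
        np.linspace(18, 80, 50), 0.2 * np.cos(np.linspace(0, 3, 50)),
    ),
    "vehicle_power": ShapeFunction(
        "vehicle_power", "continuous",
        np.linspace(40, 300, 50), np.linspace(-0.1, 0.3, 50),
    ),
}

fig = plot_all_shapes(shapes, n_cols=2, suptitle="Toy shape functions")
fig.savefig("shapes.png", dpi=150, bbox_inches="tight")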

class StandardScaler:
33class StandardScaler:
34    """Simple StandardScaler that tracks feature names and ranges.
35
36    Stores mean/std for inverse transformation (needed to recover original
37    feature values when plotting shape functions).
38    """
39
40    def __init__(self) -> None:
41        self.means_: Optional[np.ndarray] = None
42        self.stds_: Optional[np.ndarray] = None
43        self.feature_names_: Optional[List[str]] = None
44
45    def fit(
46        self,
47        X: np.ndarray,
48        feature_names: Optional[List[str]] = None,
49    ) -> "StandardScaler":
50        self.means_ = X.mean(axis=0)
51        self.stds_ = X.std(axis=0).clip(min=1e-8)
52        self.feature_names_ = feature_names or [f"f{i}" for i in range(X.shape[1])]
53        return self
54
55    def transform(self, X: np.ndarray) -> np.ndarray:
56        if self.means_ is None:
57            raise RuntimeError("Call fit() before transform().")
58        return (X - self.means_) / self.stds_
59
60    def fit_transform(self, X: np.ndarray, feature_names: Optional[List[str]] = None) -> np.ndarray:
61        return self.fit(X, feature_names).transform(X)
62
63    def inverse_transform(self, X_scaled: np.ndarray) -> np.ndarray:
64        if self.means_ is None:
65            raise RuntimeError("Call fit() before inverse_transform().")
66        return X_scaled * self.stds_ + self.means_
67
68    def inverse_transform_col(self, x_scaled: np.ndarray, col_idx: int) -> np.ndarray:
69        """Inverse transform a single column."""
70        assert self.means_ is not None
71        return x_scaled * self.stds_[col_idx] + self.means_[col_idx]

Simple StandardScaler that tracks feature names and ranges.

Stores mean/std for inverse transformation (needed to recover original feature values when plotting shape functions).

means_: Optional[numpy.ndarray]
stds_: Optional[numpy.ndarray]
feature_names_: Optional[List[str]]
def fit( self, X: numpy.ndarray, feature_names: Optional[List[str]] = None) -> StandardScaler:
45    def fit(
46        self,
47        X: np.ndarray,
48        feature_names: Optional[List[str]] = None,
49    ) -> "StandardScaler":
50        self.means_ = X.mean(axis=0)
51        self.stds_ = X.std(axis=0).clip(min=1e-8)
52        self.feature_names_ = feature_names or [f"f{i}" for i in range(X.shape[1])]
53        return self
def transform(self, X: numpy.ndarray) -> numpy.ndarray:
55    def transform(self, X: np.ndarray) -> np.ndarray:
56        if self.means_ is None:
57            raise RuntimeError("Call fit() before transform().")
58        return (X - self.means_) / self.stds_
def fit_transform( self, X: numpy.ndarray, feature_names: Optional[List[str]] = None) -> numpy.ndarray:
60    def fit_transform(self, X: np.ndarray, feature_names: Optional[List[str]] = None) -> np.ndarray:
61        return self.fit(X, feature_names).transform(X)
def inverse_transform(self, X_scaled: numpy.ndarray) -> numpy.ndarray:
63    def inverse_transform(self, X_scaled: np.ndarray) -> np.ndarray:
64        if self.means_ is None:
65            raise RuntimeError("Call fit() before inverse_transform().")
66        return X_scaled * self.stds_ + self.means_
def inverse_transform_col(self, x_scaled: numpy.ndarray, col_idx: int) -> numpy.ndarray:
68    def inverse_transform_col(self, x_scaled: np.ndarray, col_idx: int) -> np.ndarray:
69        """Inverse transform a single column."""
70        assert self.means_ is not None
71        return x_scaled * self.stds_[col_idx] + self.means_[col_idx]

Inverse transform a single column.
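
A short usage sketch (illustrative column names and values):

import numpy as np

from insurance_gam.anam import StandardScaler

X = np.array([[18.0, 1200.0], [35.0, 800.0], [62.0, 450.0]])

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X, feature_names=["driver_age", "vehicle_weight"])

# Recover original units for one column, e.g. to relabel a shape-function x-axis.
ages = scaler.inverse_transform_col(X_scaled[:, 0], col_idx=0)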

def select_interactions_correlation( X: numpy.ndarray, feature_names: List[str], threshold: float = 0.3, top_k: Optional[int] = 10, exclude_categorical: Optional[List[int]] = None) -> List[Tuple[str, str, float]]:
 79def select_interactions_correlation(
 80    X: np.ndarray,
 81    feature_names: List[str],
 82    threshold: float = 0.3,
 83    top_k: Optional[int] = 10,
 84    exclude_categorical: Optional[List[int]] = None,
 85) -> List[Tuple[str, str, float]]:
 86    """Screen feature pairs for interaction candidates using Pearson correlation.
 87
 88    Pairs with |r| above threshold are candidates for interaction subnetworks.
 89    Highly correlated features often reflect rate structures where the joint
 90    effect differs from the sum of the main effects, so this heuristic ranks
 91    pairs by |r| and keeps the strongest. Weakly correlated features can still
 92    interact; use select_interactions_residual() to screen for interactions the
 93    fitted additive model has left unexplained.
 94
 95    Parameters
 96    ----------
 97    X:
 98        Feature matrix (n, p). Should be the continuous features only.
 99    feature_names:
100        Names matching columns of X.
101    threshold:
102        Minimum |correlation| to include pair. Default 0.3.
103    top_k:
104        Maximum number of pairs to return. None = all above threshold.
105    exclude_categorical:
106        Column indices of categorical features to exclude from screening.
107
108    Returns
109    -------
110    List[Tuple[str, str, float]]
111        List of (feature_i, feature_j, correlation) sorted by |r| descending.
112    """
113    exclude = set(exclude_categorical or [])
114    p = X.shape[1]
115
116    candidates: List[Tuple[str, str, float]] = []
117
118    for i in range(p):
119        if i in exclude:
120            continue
121        for j in range(i + 1, p):
122            if j in exclude:
123                continue
124            corr = float(np.corrcoef(X[:, i], X[:, j])[0, 1])
125            if abs(corr) >= threshold:
126                candidates.append((feature_names[i], feature_names[j], corr))
127
128    candidates.sort(key=lambda x: abs(x[2]), reverse=True)
129
130    if top_k is not None:
131        candidates = candidates[:top_k]
132
133    return candidates

Screen feature pairs for interaction candidates using Pearson correlation.

Pairs with |r| above threshold are candidates for interaction subnetworks. Highly correlated features often reflect rate structures where the joint effect differs from the sum of the main effects, so this heuristic ranks pairs by |r| and keeps the strongest. Weakly correlated features can still interact; use select_interactions_residual() to screen for interactions the fitted additive model has left unexplained.

Parameters

X: Feature matrix (n, p). Should be the continuous features only.
feature_names: Names matching columns of X.
threshold: Minimum |correlation| to include pair. Default 0.3.
top_k: Maximum number of pairs to return. None = all above threshold.
exclude_categorical: Column indices of categorical features to exclude from screening.

Returns

List[Tuple[str, str, float]]
    List of (feature_i, feature_j, correlation) sorted by |r| descending.
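
A sketch on synthetic data where one pair is strongly correlated by construction:

import numpy as np

from insurance_gam.anam import select_interactions_correlation

rng = np.random.default_rng(0)
age = rng.uniform(18, 80, size=500)
power = 0.6 * age + rng.normal(0, 10, size=500)   # correlated with age by construction
density = rng.uniform(0, 1, size=500)             # independent noise feature

X = np.column_stack([age, power, density])
pairs = select_interactions_correlation(
    X, feature_names=["age", "power", "density"], threshold=0.3, top_k=5
)
# Only pairs with |r| >= 0.3 are returned, strongest first.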

def select_interactions_residual( X: numpy.ndarray, y_residuals: numpy.ndarray, feature_names: List[str], top_k: int = 5, exclude_categorical: Optional[List[int]] = None) -> List[Tuple[str, str, float]]:
136def select_interactions_residual(
137    X: np.ndarray,
138    y_residuals: np.ndarray,
139    feature_names: List[str],
140    top_k: int = 5,
141    exclude_categorical: Optional[List[int]] = None,
142) -> List[Tuple[str, str, float]]:
143    """Select interaction pairs by pairwise product correlation with residuals.
144
145    After fitting the additive model, residuals contain unexplained variance.
146    This method checks whether x_i * x_j correlates with residuals — if so,
147    the interaction term x_i x_j may be worth adding.
148
149    More principled than pure feature-space correlation for identifying
150    interactions that actually improve model fit.
151
152    Parameters
153    ----------
154    X:
155        Feature matrix.
156    y_residuals:
157        Model residuals (y - predicted).
158    feature_names:
159        Feature names.
160    top_k:
161        Number of top pairs to return.
162
163    Returns
164    -------
165    List[Tuple[str, str, float]]
166        Sorted by |correlation with residuals| descending.
167    """
168    exclude = set(exclude_categorical or [])
169    p = X.shape[1]
170    candidates: List[Tuple[str, str, float]] = []
171
172    for i in range(p):
173        if i in exclude:
174            continue
175        for j in range(i + 1, p):
176            if j in exclude:
177                continue
178            product = X[:, i] * X[:, j]
179            corr = float(np.corrcoef(product, y_residuals)[0, 1])
180            candidates.append((feature_names[i], feature_names[j], abs(corr)))
181
182    candidates.sort(key=lambda x: x[2], reverse=True)
183    return candidates[:top_k]

Select interaction pairs by pairwise product correlation with residuals.

After fitting the additive model, residuals contain unexplained variance. This method checks whether x_i * x_j correlates with residuals — if so, the interaction term x_i x_j may be worth adding.

More principled than pure feature-space correlation for identifying interactions that actually improve model fit.

Parameters

X: Feature matrix.
y_residuals: Model residuals (y - predicted).
feature_names: Feature names.
top_k: Number of top pairs to return.
exclude_categorical: Column indices of categorical features to exclude from screening.

Returns

List[Tuple[str, str, float]]
    Sorted by |correlation with residuals| descending.
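
A sketch with synthetic residuals driven by a known x0*x1 interaction, which the method should rank first:

import numpy as np

from insurance_gam.anam import select_interactions_residual

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
residuals = 0.5 * X[:, 0] * X[:, 1] + rng.normal(0, 0.1, size=500)

pairs = select_interactions_residual(
    X, residuals, feature_names=["x0", "x1", "x2"], top_k=2
)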

def shapes_to_relativity_table( shapes: "Dict[str, 'ShapeFunction']", feature_names: Optional[List[str]] = None) -> polars.dataframe.frame.DataFrame:
191def shapes_to_relativity_table(
192    shapes: Dict[str, "ShapeFunction"],
193    feature_names: Optional[List[str]] = None,
194) -> pl.DataFrame:
195    """Aggregate all shape functions into a single relativity table.
196
197    Outputs a long-format DataFrame with one row per (feature, level)
198    combination, suitable for actuarial review in Excel.
199
200    Parameters
201    ----------
202    shapes:
203        Dict returned by extract_shape_functions().
204    feature_names:
205        Subset of feature names to include. None = all features.
206
207    Returns
208    -------
209    pl.DataFrame
210        Columns: [feature, level, f_x, relativity, log_relativity]
211    """
212    names = feature_names or list(shapes.keys())
213    dfs: List[pl.DataFrame] = []
214
215    for name in names:
216        if name not in shapes:
217            continue
218        sf = shapes[name]
219        rel_df = sf.to_relativities()
220
221        if sf.feature_type == "continuous":
222            level_col = pl.Series("level", [str(round(float(x), 4)) for x in sf.x_values])
223        else:
224            level_col = pl.Series(
225                "level",
226                [
227                    sf.category_labels.get(int(i), str(int(i))) if sf.category_labels else str(int(i))
228                    for i in sf.x_values
229                ],
230            )
231
232        feature_col = pl.Series("feature", [name] * len(sf.x_values))
233        f_x_col = pl.Series("f_x", sf.f_values.tolist())
234
235        rel_col = rel_df["relativity"]
236        log_rel_col = rel_df["log_relativity"]
237
238        dfs.append(
239            pl.DataFrame(
240                {
241                    "feature": feature_col,
242                    "level": level_col,
243                    "f_x": f_x_col,
244                    "relativity": rel_col,
245                    "log_relativity": log_rel_col,
246                }
247            )
248        )
249
250    if not dfs:
251        return pl.DataFrame(
252            schema={
253                "feature": pl.Utf8,
254                "level": pl.Utf8,
255                "f_x": pl.Float64,
256                "relativity": pl.Float64,
257                "log_relativity": pl.Float64,
258            }
259        )
260
261    return pl.concat(dfs)

Aggregate all shape functions into a single relativity table.

Outputs a long-format DataFrame with one row per (feature, level) combination, suitable for actuarial review in Excel.

Parameters

shapes: Dict returned by extract_shape_functions().
feature_names: Subset of feature names to include. None = all features.

Returns

pl.DataFrame
    Columns: [feature, level, f_x, relativity, log_relativity]
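
A self-contained sketch with two hand-built shapes (illustrative names, levels and values); the resulting long-format table can then be exported for review:

import numpy as np

from insurance_gam.anam import ShapeFunction, shapes_to_relativity_table

shapes = {
    "driver_age": ShapeFunction(
        "driver_age", "continuous",
        np.linspace(18, 80, 5), np.array([0.30, 0.10, 0.0, -0.05, 0.10]),
    ),
    "region": ShapeFunction(
        "region", "categorical",
        np.array([0.0, 1.0, 2.0]), np.array([0.12, 0.0, -0.08]),
        category_labels={0: "Urban", 1: "Suburban", 2: "Rural"},
    ),
}

table = shapes_to_relativity_table(shapes)
table.write_csv("relativities.csv")   # columns: feature, level, f_x, relativity, log_relativity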

def compare_shapes_to_glm( anam_shapes: "Dict[str, 'ShapeFunction']", glm_coefficients: Dict[str, Dict[str, float]]) -> polars.dataframe.frame.DataFrame:
264def compare_shapes_to_glm(
265    anam_shapes: Dict[str, "ShapeFunction"],
266    glm_coefficients: Dict[str, Dict[str, float]],
267) -> pl.DataFrame:
268    """Compare ANAM shape functions to GLM log-relativities.
269
270    For each feature level in the GLM, find the nearest ANAM shape function
271    value and compute the deviation. Useful for model validation against
272    existing production GLMs.
273
274    Parameters
275    ----------
276    anam_shapes:
277        ANAM shape functions (from extract_shape_functions).
278    glm_coefficients:
279        Dict mapping feature_name -> {level_str -> log_relativity}.
280        GLM log-relativities in the same scale as ANAM f_i outputs.
281
282    Returns
283    -------
284    pl.DataFrame
285        Columns: [feature, level, anam_f, glm_log_rel, deviation]
286    """
287    rows: List[Dict] = []
288
289    for feature_name, glm_levels in glm_coefficients.items():
290        if feature_name not in anam_shapes:
291            continue
292
293        sf = anam_shapes[feature_name]
294
295        for level_str, glm_val in glm_levels.items():
296            # Find ANAM value at this level
297            if sf.feature_type == "continuous":
298                try:
299                    x_level = float(level_str)
300                    idx = int(np.argmin(np.abs(sf.x_values - x_level)))
301                    anam_val = float(sf.f_values[idx])
302                except (ValueError, IndexError):
303                    continue
304            else:
305                try:
306                    cat_idx = int(level_str)
307                    if cat_idx < len(sf.f_values):
308                        anam_val = float(sf.f_values[cat_idx])
309                    else:
310                        continue
311                except (ValueError, IndexError):
312                    continue
313
314            rows.append(
315                {
316                    "feature": feature_name,
317                    "level": level_str,
318                    "anam_f": anam_val,
319                    "glm_log_rel": glm_val,
320                    "deviation": anam_val - glm_val,
321                }
322            )
323
324    if not rows:
325        return pl.DataFrame(
326            schema={
327                "feature": pl.Utf8,
328                "level": pl.Utf8,
329                "anam_f": pl.Float64,
330                "glm_log_rel": pl.Float64,
331                "deviation": pl.Float64,
332            }
333        )
334
335    return pl.DataFrame(rows)

Compare ANAM shape functions to GLM log-relativities.

For each feature level in the GLM, find the nearest ANAM shape function value and compute the deviation. Useful for model validation against existing production GLMs.

Parameters

anam_shapes: ANAM shape functions (from extract_shape_functions).
glm_coefficients: Dict mapping feature_name -> {level_str -> log_relativity}. GLM log-relativities in the same scale as ANAM f_i outputs.

Returns

pl.DataFrame
    Columns: [feature, level, anam_f, glm_log_rel, deviation]
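
A sketch comparing a hand-built ANAM shape to hypothetical GLM log-relativities keyed by level string (all names and numbers are illustrative):

import numpy as np

from insurance_gam.anam import ShapeFunction, compare_shapes_to_glm

anam_shapes = {
    "driver_age": ShapeFunction(
        "driver_age", "continuous",
        np.linspace(18, 80, 63), np.linspace(0.30, -0.10, 63),
    ),
}
# Hypothetical production-GLM log-relativities (same log scale as the f_i outputs).
glm_coefficients = {"driver_age": {"20": 0.25, "40": 0.05, "65": -0.08}}

diff = compare_shapes_to_glm(anam_shapes, glm_coefficients)
# One row per GLM level: feature, level, anam_f, glm_log_rel, deviation.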

def compute_deviance_stat( y_true: numpy.ndarray, y_pred: numpy.ndarray, exposure: Optional[numpy.ndarray] = None, loss: str = 'poisson', tweedie_p: float = 1.5, eps: float = 1e-08) -> float:
338def compute_deviance_stat(
339    y_true: np.ndarray,
340    y_pred: np.ndarray,
341    exposure: Optional[np.ndarray] = None,
342    loss: str = "poisson",
343    tweedie_p: float = 1.5,
344    eps: float = 1e-8,
345) -> float:
346    """Compute weighted deviance statistic for model comparison.
347
348    Returns the mean deviance (lower is better). Useful for comparing
349    ANAM to GLM/EBM baselines with a single number.
350    """
351    import torch
352
353    y_t = torch.tensor(y_true, dtype=torch.float32)
354    yp_t = torch.tensor(y_pred, dtype=torch.float32)
355    w_t = torch.tensor(exposure, dtype=torch.float32) if exposure is not None else None
356
357    from .losses import gamma_deviance, poisson_deviance, tweedie_deviance
358
359    if loss == "poisson":
360        return float(poisson_deviance(yp_t, y_t, w_t, eps).item())
361    elif loss == "tweedie":
362        return float(tweedie_deviance(yp_t, y_t, tweedie_p, w_t, eps).item())
363    elif loss == "gamma":
364        return float(gamma_deviance(yp_t, y_t, w_t, eps).item())
365    else:
366        raise ValueError(f"Unknown loss: {loss!r}")

Compute weighted deviance statistic for model comparison.

Returns the mean deviance (lower is better). Useful for comparing ANAM to GLM/EBM baselines with a single number.
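
A sketch comparing two frequency models on the same held-out data (synthetic arrays; lower mean deviance is better):

import numpy as np

from insurance_gam.anam import compute_deviance_stat

y = np.array([0.0, 1.0, 0.0, 2.0, 1.0])
pred_a = np.array([0.2, 0.9, 0.1, 1.6, 1.1])
pred_b = np.array([0.4, 0.7, 0.3, 1.2, 0.9])
exposure = np.array([0.5, 1.0, 1.0, 1.0, 0.75])

dev_a = compute_deviance_stat(y, pred_a, exposure=exposure, loss="poisson")
dev_b = compute_deviance_stat(y, pred_b, exposure=exposure, loss="poisson")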